本文是 CS336 系列笔记的第二讲,PyTorch 与资源计算。本讲比较简单,从两个问题出发,讨论如何计算计算量,并顺带讲一些 PyTorch 的入门概念。本讲的重点不在于 PyTorch 基础,而在于培养“资源计算”的思维模式。

本节课的目标就是回答两个问题:

  • 使用 1024 张 H100 在 15 T token 的数据集上训练 70B 的模型需要多久?
  • 在 8 张 H100 上使用 AdamW 优化器,最大可以训练多大的模型?

内存计算

Tensor 入门

深度学习的一切都是使用 Tensor 进行存储的,包括参数、梯度、优化器状态、数据、激活层。

入门结束。(逃)

Tensor 内存

几乎所有 Tensor 都以浮点数的形式存储,通常默认的数据格式为 FP32。

Tensor 占据的内存取决于数据类型和 Tensor 内元素的个数。

为了容纳更多的参数,业界还引入了 FP16 、 BF16 和 FP8 数据格式。

总而言之:

  • FP32 对于训练已经够够的了,但是需要大量内存
  • FP8 / FP16 / BF16 训练存在风险,可能会导致不稳定
  • 解决方案:使用混合精度训练

Tensor 操作

课程介绍了一些 Tensor 的基本操作,此处不赘述。值得一提的是 stride 机制,若不熟悉可以查阅一下。

Tensor 操作 FLOPs 计算

一次浮点运算 Float-point Operation FLOP 指的是一次浮点数加法或者浮点数乘法。需要区分的两个概念:

  • FLOPs:float-point operations,浮点计算量,用来表达需要进行多少次浮点计算
  • FLOP/s:FLOP per second,每秒浮点计算量,用来衡量硬件的计算能力

一些直观感受:

  • 训练 GPT-3(2020)使用了 3.14e23 FLOPs
  • 训练 GPT-3(2023)被推测使用了 2e25 FLOPs
  • A100 峰值算力 3.12e14 FLOP/s
  • H100 系数

假定输入数据为 [B, D],线性层将 D 映射到 K 即权重矩阵形状为 [D, K]。对于输出张量的每个元素,都需要通过将两个长度为 D 的向量逐元素相乘后相加,即 2D 个 FLOPs,一共有 B x K 个输出元素,所以一次矩乘的 FLOPs 为 2 x B x D x K

mfu = 模型实际每秒FLOP / 硬件理论最大每秒FLOP,一般来说 MFU 大于 0.5 已经相当不错了。

模型

参数初始化

简单来说,参数初始化也是一门学问,否则会导致梯度爆炸或者消失,这门课对此没有详细介绍。CMU 10414 对此有从数学上的推导,见我当时的学习笔记:《CMU 10-414 deep learning system》学习笔记 | 周鑫的个人博客

其它内容

这里还介绍了如何训练一个模型、优化器、检查点等内容,过于基础,此处不赘述。

混合精度训练

低精度可以减少显存占用并加快计算速度,同时避免低精度带来的不稳定,我们的策略是:默认使用 FP32,并尽可能使用 BF16、FP8,具体来说,在前向时使用低精度,在反向时使用高精度。

小结

本讲作为 PyTorch 的入门课,介绍的都是一些很基本的概念。涉及的 PyTorch 的用法和训练过程都是一个引子,重点还是与培养资源计算的思考模式,毕竟大模型时代,效率就是金钱。