本文是 CS336 系列笔记的第二讲,PyTorch 与资源计算。本讲比较简单,从两个问题出发,讨论如何计算计算量,并顺带讲一些 PyTorch 的入门概念。本讲的重点不在于 PyTorch 基础,而在于培养“资源计算”的思维模式。
本节课的目标就是回答两个问题:
- 使用 1024 张 H100 在 15 T token 的数据集上训练 70B 的模型需要多久?
- 在 8 张 H100 上使用 AdamW 优化器,最大可以训练多大的模型?
内存计算
Tensor 入门
深度学习的一切都是使用 Tensor 进行存储的,包括参数、梯度、优化器状态、数据、激活层。
入门结束。(逃)
Tensor 内存
几乎所有 Tensor 都以浮点数的形式存储,通常默认的数据格式为 FP32。
Tensor 占据的内存取决于数据类型和 Tensor 内元素的个数。
为了容纳更多的参数,业界还引入了 FP16 、 BF16 和 FP8 数据格式。
总而言之:
- FP32 对于训练已经够够的了,但是需要大量内存
- FP8 / FP16 / BF16 训练存在风险,可能会导致不稳定
- 解决方案:使用混合精度训练
Tensor 操作
课程介绍了一些 Tensor 的基本操作,此处不赘述。值得一提的是 stride 机制,若不熟悉可以查阅一下。
Tensor 操作 FLOPs 计算
一次浮点运算 Float-point Operation FLOP 指的是一次浮点数加法或者浮点数乘法。需要区分的两个概念:
- FLOPs:float-point operations,浮点计算量,用来表达需要进行多少次浮点计算
- FLOP/s:FLOP per second,每秒浮点计算量,用来衡量硬件的计算能力
一些直观感受:
- 训练 GPT-3(2020)使用了 3.14e23 FLOPs
- 训练 GPT-3(2023)被推测使用了 2e25 FLOPs
- A100 峰值算力 3.12e14 FLOP/s
- H100 系数
假定输入数据为 [B, D]
,线性层将 D
映射到 K
即权重矩阵形状为 [D, K]
。对于输出张量的每个元素,都需要通过将两个长度为 D 的向量逐元素相乘后相加,即 2D 个 FLOPs,一共有 B x K
个输出元素,所以一次矩乘的 FLOPs 为 2 x B x D x K
。
mfu = 模型实际每秒FLOP / 硬件理论最大每秒FLOP
,一般来说 MFU 大于 0.5 已经相当不错了。
模型
参数初始化
简单来说,参数初始化也是一门学问,否则会导致梯度爆炸或者消失,这门课对此没有详细介绍。CMU 10414 对此有从数学上的推导,见我当时的学习笔记:《CMU 10-414 deep learning system》学习笔记 | 周鑫的个人博客。
其它内容
这里还介绍了如何训练一个模型、优化器、检查点等内容,过于基础,此处不赘述。
混合精度训练
低精度可以减少显存占用并加快计算速度,同时避免低精度带来的不稳定,我们的策略是:默认使用 FP32,并尽可能使用 BF16、FP8,具体来说,在前向时使用低精度,在反向时使用高精度。
小结
本讲作为 PyTorch 的入门课,介绍的都是一些很基本的概念。涉及的 PyTorch 的用法和训练过程都是一个引子,重点还是与培养资源计算的思考模式,毕竟大模型时代,效率就是金钱。