TL;DR 本讲将分布式通信概念落地为可执行代码,通过 PyTorch 和 NCCL 展示了集合通信的实际实现,并手写了三种并行策略的简化版本。从 All-Reduce 的基准测试到 MLP 的切分实现,揭示了通信开销与计算模式的核心逻辑。
通信原语:从接口到底层#
集合通信在 PyTorch 中的接口简洁直白。初始化进程组后,每个进程执行相同的通信操作,NCCL 自动处理底层数据流动。
1
2
3
4
5
6
7
8
9
|
# 初始化进程组(每个进程都要执行)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "15623"
dist.init_process_group("nccl", rank=rank, world_size=world_size)
# All-Reduce 示例
tensor = torch.tensor([0., 1, 2, 3], device=device) + rank
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) # 原地修改
print(f"Rank {rank}: {tensor}") # 所有进程输出相同的求和结果
|
基准测试揭示了硬件带宽的实际限制。测量 All-Reduce 的有效带宽时,需要计算传输数据总量和总耗时:
1
2
3
4
5
6
7
8
9
10
11
|
# 基准测试核心逻辑
start_time = time.time()
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
torch.cuda.synchronize()
end_time = time.time()
duration = end_time - start_time
size_bytes = tensor.element_size() * tensor.numel()
sent_bytes = size_bytes * 2 * (world_size - 1) # 发送和接收
bandwidth = sent_bytes / (world_size * duration)
print(f"带宽:{round(bandwidth / 1024**3)} GB/s")
|
数据并行:梯度同步的集体操作#
数据并行的核心是在反向传播后同步梯度。每个进程处理部分数据,计算局部梯度,然后通过 All-Reduce 得到全局平均梯度。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
|
def data_parallelism_main(rank, world_size, data, num_layers):
setup(rank, world_size)
# 数据切片:每个进程获取部分批次
batch_size = data.size(0)
local_batch_size = batch_size // world_size
local_data = data[rank*local_batch_size:(rank+1)*local_batch_size].to(device)
# 完整模型副本
params = [get_init_params(dim, dim, rank) for _ in range(num_layers)]
optimizer = torch.optim.AdamW(params, lr=1e-3)
# 前向传播(使用本地数据)
x = local_data
for param in params:
x = x @ param
x = F.gelu(x)
loss = x.square().mean()
# 反向传播
loss.backward()
# 关键:梯度同步
for param in params:
dist.all_reduce(tensor=param.grad, op=dist.ReduceOp.AVG, async_op=False)
# 参数更新(各进程同步更新相同参数)
optimizer.step()
|
张量并行:层内矩阵的切分与聚合#
张量并行将权重矩阵按列切分,每个进程计算部分输出,然后通过 All-Gather 聚合完整结果。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
|
def tensor_parallelism_main(rank, world_size, data, num_layers):
setup(rank, world_size)
data = data.to(device)
num_dim = data.size(1)
local_num_dim = num_dim // world_size # 特征维度切分
# 每个进程只持有部分参数
params = [get_init_params(num_dim, local_num_dim, rank) for _ in range(num_layers)]
# 前向传播
x = data
for i in range(num_layers):
# 局部矩阵乘法
x = x @ params[i] # 输出形状:[batch_size, local_num_dim]
x = F.gelu(x)
# 聚合所有进程的部分结果
activations = [torch.empty_like(x) for _ in range(world_size)]
dist.all_gather(tensor_list=activations, tensor=x, async_op=False)
# 拼接得到完整特征
x = torch.cat(activations, dim=1) # 形状:[batch_size, num_dim]
|
流水线并行:层间激活值的流动#
流水线并行将模型按深度切分,激活值在进程间传递。使用微批次可以减少流水线气泡。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
def pipeline_parallelism_main(rank, world_size, data, num_layers, num_micro_batches):
setup(rank, world_size)
# 层分配
local_num_layers = num_layers // world_size
local_params = [get_init_params(dim, dim, rank) for _ in range(local_num_layers)]
# 微批次划分
micro_batch_size = data.size(0) // num_micro_batches
if rank == 0:
micro_batches = data.chunk(num_micro_batches, dim=0)
else:
micro_batches = [torch.empty(micro_batch_size, dim, device=device)
for _ in range(num_micro_batches)]
# 前向流水线
for micro_batch in micro_batches:
# 从上一进程接收(如果不是第一个进程)
if rank > 0:
dist.recv(tensor=micro_batch, src=rank-1)
# 本地计算
for param in local_params:
micro_batch = micro_batch @ param
micro_batch = F.gelu(micro_batch)
# 发送到下一进程(如果不是最后一个进程)
if rank < world_size - 1:
dist.send(tensor=micro_batch, dst=rank+1)
|
总结:代码揭示的模式#
混合并行策略的选择在代码层面体现为通信原语的组合。数据并行的 All-Reduce、张量并行的 All-Gather、流水线并行的 Send/Recv,这些操作共同构成了分布式训练的通信骨架。
实际系统会在这些基础模式上添加优化:通信与计算的重叠、梯度检查点、异步执行等。但理解这些基础实现是分析和优化分布式训练性能的起点。
注:以上代码为教学用简化实现,省略了错误处理、设备管理、性能优化等工程细节。实际项目应使用 PyTorch 的 DistributedDataParallel、FullyShardedDataParallel 或第三方框架如 DeepSpeed。