TL;DR 本讲将分布式通信概念落地为可执行代码,通过 PyTorch 和 NCCL 展示了集合通信的实际实现,并手写了三种并行策略的简化版本。从 All-Reduce 的基准测试到 MLP 的切分实现,揭示了通信开销与计算模式的核心逻辑。

通信原语:从接口到底层

集合通信在 PyTorch 中的接口简洁直白。初始化进程组后,每个进程执行相同的通信操作,NCCL 自动处理底层数据流动。

1
2
3
4
5
6
7
8
9
# 初始化进程组(每个进程都要执行)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "15623"
dist.init_process_group("nccl", rank=rank, world_size=world_size)

# All-Reduce 示例
tensor = torch.tensor([0., 1, 2, 3], device=device) + rank
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)  # 原地修改
print(f"Rank {rank}: {tensor}")  # 所有进程输出相同的求和结果

基准测试揭示了硬件带宽的实际限制。测量 All-Reduce 的有效带宽时,需要计算传输数据总量和总耗时:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
# 基准测试核心逻辑
start_time = time.time()
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
torch.cuda.synchronize()
end_time = time.time()

duration = end_time - start_time
size_bytes = tensor.element_size() * tensor.numel()
sent_bytes = size_bytes * 2 * (world_size - 1)  # 发送和接收
bandwidth = sent_bytes / (world_size * duration)
print(f"带宽:{round(bandwidth / 1024**3)} GB/s")

数据并行:梯度同步的集体操作

数据并行的核心是在反向传播后同步梯度。每个进程处理部分数据,计算局部梯度,然后通过 All-Reduce 得到全局平均梯度。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def data_parallelism_main(rank, world_size, data, num_layers):
    setup(rank, world_size)
    
    # 数据切片:每个进程获取部分批次
    batch_size = data.size(0)
    local_batch_size = batch_size // world_size
    local_data = data[rank*local_batch_size:(rank+1)*local_batch_size].to(device)
    
    # 完整模型副本
    params = [get_init_params(dim, dim, rank) for _ in range(num_layers)]
    optimizer = torch.optim.AdamW(params, lr=1e-3)
    
    # 前向传播(使用本地数据)
    x = local_data
    for param in params:
        x = x @ param
        x = F.gelu(x)
    loss = x.square().mean()
    
    # 反向传播
    loss.backward()
    
    # 关键:梯度同步
    for param in params:
        dist.all_reduce(tensor=param.grad, op=dist.ReduceOp.AVG, async_op=False)
    
    # 参数更新(各进程同步更新相同参数)
    optimizer.step()

张量并行:层内矩阵的切分与聚合

张量并行将权重矩阵按列切分,每个进程计算部分输出,然后通过 All-Gather 聚合完整结果。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def tensor_parallelism_main(rank, world_size, data, num_layers):
    setup(rank, world_size)
    
    data = data.to(device)
    num_dim = data.size(1)
    local_num_dim = num_dim // world_size  # 特征维度切分
    
    # 每个进程只持有部分参数
    params = [get_init_params(num_dim, local_num_dim, rank) for _ in range(num_layers)]
    
    # 前向传播
    x = data
    for i in range(num_layers):
        # 局部矩阵乘法
        x = x @ params[i]  # 输出形状:[batch_size, local_num_dim]
        x = F.gelu(x)
        
        # 聚合所有进程的部分结果
        activations = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(tensor_list=activations, tensor=x, async_op=False)
        
        # 拼接得到完整特征
        x = torch.cat(activations, dim=1)  # 形状:[batch_size, num_dim]

流水线并行:层间激活值的流动

流水线并行将模型按深度切分,激活值在进程间传递。使用微批次可以减少流水线气泡。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def pipeline_parallelism_main(rank, world_size, data, num_layers, num_micro_batches):
    setup(rank, world_size)
    
    # 层分配
    local_num_layers = num_layers // world_size
    local_params = [get_init_params(dim, dim, rank) for _ in range(local_num_layers)]
    
    # 微批次划分
    micro_batch_size = data.size(0) // num_micro_batches
    if rank == 0:
        micro_batches = data.chunk(num_micro_batches, dim=0)
    else:
        micro_batches = [torch.empty(micro_batch_size, dim, device=device) 
                        for _ in range(num_micro_batches)]
    
    # 前向流水线
    for micro_batch in micro_batches:
        # 从上一进程接收(如果不是第一个进程)
        if rank > 0:
            dist.recv(tensor=micro_batch, src=rank-1)
        
        # 本地计算
        for param in local_params:
            micro_batch = micro_batch @ param
            micro_batch = F.gelu(micro_batch)
        
        # 发送到下一进程(如果不是最后一个进程)
        if rank < world_size - 1:
            dist.send(tensor=micro_batch, dst=rank+1)

总结:代码揭示的模式

混合并行策略的选择在代码层面体现为通信原语的组合。数据并行的 All-Reduce、张量并行的 All-Gather、流水线并行的 Send/Recv,这些操作共同构成了分布式训练的通信骨架。

实际系统会在这些基础模式上添加优化:通信与计算的重叠、梯度检查点、异步执行等。但理解这些基础实现是分析和优化分布式训练性能的起点。

注:以上代码为教学用简化实现,省略了错误处理、设备管理、性能优化等工程细节。实际项目应使用 PyTorch 的 DistributedDataParallel、FullyShardedDataParallel 或第三方框架如 DeepSpeed。