Skip to content

分布式训练框架

📅 发表于 2025/07/18
🔄 更新于 2025/07/18
👁️ -- 次访问
📝 0 字
0 分钟
llm-infra
#浮点数
#有偏指数
#FP32
#FP16
#BF16
#梯度溢出
#舍入误差
#动态损失缩放
#混合精度优化
#FP32权重备份
#模型显存分析
#参数
#梯度
#优化器
#激活值
#DeepSpeed
#ZeRO-DP
#Zero1
#Zero2
#Zero3
#Zero-Offload

背景知识

分布式基础概念

基础概念

1. 机器概念

  • 主节点 master ip/port:协调其他节点、分配任务、结果汇总等。
  • 节点编号 node_rank:每个节点的唯一标识,不同计算机通信
  • 局部进程编号 local_rank节点内部的进程编号
  • 全局进程编号 rank:整个系统全局进程编号,唯一标识
  • 全局进程总数 word_size:整个系统所有进程总数

2. 通信策略

  • mpi:跨节点通信库,常用于CPU集群
  • gloo: 高性能分布式训练框架,支持CPU和GPU集群
  • nccl:英伟达的GPU专有通信库,适用于GPU

浮点数

浮点数(三要素+FP64+FP32)

1. 浮点数三要素🐦

  • 符号为S:0正,1负
  • 指数位E:实际有偏指数=存储指数值-偏移量 ⭐,2位底
    • 为何需要有偏置指数:指数需要有正有负,减去偏置值就能变有符号整数,做范围平衡
    • 偏置值如何计算:2n11,IEEE754规定该位置值位0,大于其是正数,小于其是负数
      • FP32:281=127
      • FP16:2511=15
      • BF16:281=127
    • 如,FP32是127,指数部分是01111100,实际指数:124-127=-3,即23
    • 指数位全1:特数值;尾数位全0 - 无穷尾数位不全0 - NaN 熟悉的NaN ‼️
    • 指数位全00,+0, -0,非规范化浮点数,
  • 尾数位M:有效数字,精度;实际尾数=1+尾数部分
    • 因为标准设计从首位从1开始,不足的需要小数点右移。
    • 比如尾数部分是11000..,实际尾数=1+21+22=1+0.5+0.25=1.75

2. 单精度+双精度FP32/FP64🐔

  • 双精度,FP64:共64位,1 + 11 + 52。8个字节。
  • 单精度,FP32:共32位,1 + 8 + 23,指数位8位。4个字节。
FP16+BF16

3. 半精度FP16🐱 🚀

  • 共16位,1 + 5 + 10,2个字节 。
  • 优点
    • 减少显存占用⭐:只占原始的一半;能训更大模型、更大bs;通信成本降低;
    • 计算更快:半精度吞吐量是单精度的2-8倍
  • 缺点
    • 数值下溢或上溢问题:数值仅5位范围有限
      • 训练后期梯度很小,乘以学习率后会更小
    • 精度较低+舍入误差问题:尾数位仅10
      • 🔑FP16最小间隔为213权重+梯度=23+214仍然等于23表达不了214
      • 梯度消失了‼️
  • 解法
    • 动态损失缩放(Dynamic Loss Scaling)
    • 混合精度优化(Mixed Precision Optimizer)等

4. BF16 (Brain float point)🐻

  • 共16位,1 + 8 + 7
  • 核心
    • 扩大指数位数到8位数值范围对齐FP32,对精度做折中
    • 避免梯度溢出、弥补了数值稳定
  • 优点
    • 更宽的数值范围;适合深度学习,精度足够,数值更广;硬件家属。
    • 适用部分场景,如梯度累加权重更新等操作。
  • 缺点
    • 精度较低;不适用所有场景。

Fp64&fp32:

非常重要,指数位全0或全1的特殊情况💡 :

最大值:

浮点数的示意图:

混合精度训练

混合精度 (同时使用FP16+FP32)

FP16缺点

  • FP16优点及缺点(上文),以及依赖的相应解决办法

解法一:FP32权重备份

  • 目的:解fp16舍入误差问题
  • 思路:
    • 参数+激活+梯度 在训练时用FP16 🎉 ‼️
    • 更新时拷贝使用FP32 参数 ‼️
      • 权重=旧权重+lr*梯度lr*梯度很小,使用FP32避免FP16更新无效
  • 拷贝FP32权重,是否会导致更高内存?
    • fp32权重确实额外增加
    • 但内存中大部分是激活的值(bs大时更明显),只要激活是fp16,内存就能减半

解法二:Loss Scale 🎓

  • 背景
    • 解决fp16带来的梯度下溢出问题,保留loss关键信息
    • 训练后期梯度特别小:67%梯度 < 224如果用fp16,则全为0
  • 方法
    • 对loss进行scale链式法则原理,会传导到每个梯度,梯度平移到fp16区间
    • Scale UPloss增大2k
    • Scale Down权重梯度缩小2k
  • Dynamic Loss Scale
    • 前期:使用大缩放因子224,判断是否梯度溢出
      • 无梯度溢出:不调整缩放因子,继续迭代
      • 有梯度溢出缩放因子减半,重新确认缩放因子情况,直到不溢出
    • 后期:loss稳定、梯度幅度变小,可以使用更高缩放因子防止数据再次下溢

解法三:提高算数精度

  • 乘法:fp16;加法:fp32

解法四:BF16

  • 扩大了数值范围,解决了溢出问题;但精度不足。可以替代一些关键操作,如梯度累加、权重更新等。

其他方法

  • 梯度裁切:防止梯度爆炸
  • 学习率调整:帮助模型更好收敛
混合精度训练流程

主要流程

  • 把FP32模型参数转成FP16
    • 减少内存占用,计算更快
  • FP32 计算loss乘以loss scale平移到fp16空间
    • 避免反向时梯度下溢
  • FP16 计算梯度,再把梯度转为FP32
    • 保持数值稳定性,避免低精度导致的梯度消失或爆炸问题
  • FP32梯度除以loss scale做还原,梯度再乘以学习率,更新FP32权重
  • FP32权重转为FP16
  • 注意:最终存储到磁盘的模型权重是FP32精度

模型显存占用分析

[LLM]大模型显存计算公式与优化估算大模型显存

模型显存占用(混合梯度)

具体显存占用项 (模型参数为m或Ψ)

  • 模型参数fp162m,1参数2Byte

    • 1参数2字节;所以1B参数就是2GB
    • 1 billon=10亿参数,G=10亿字节,除以1024*1024*1024 单位就是G
  • 模型梯度fp162m,1参数2Byte,和模型参数一致

  • 优化器(Adam)12m,1参数12Byte

    • 模型参数副本fp32:4m,1参数4Byte
    • Adam m fp32:4m,1参数4Byte, 梯度平均移动值
    • Adam v fp32:4m,1参数4Byte,梯度平方的移动平均值
      • 每个参数都有一个变量momentum和variance
    • 如果是8bit优化器:则是4+1+1=6Byte。
  • 激活值(Activation)

    • 激活在训练中会占用大量显存‼️
      • 可通过激活checkpoint重计算来降低,但仍然很大。
    • s序列长度,b是micro_batchsize,h是hidden_size,a是attention头数,L是Transformer层数,γ是系数,
    • Megatron及阿里云计算公式
activation 显存计算=sbh(34+5as/h)Lγ激活缓存后显存计算=sbhL(2)
  • 临时缓冲区:如allreduce梯度、fp32模型参数buffer等
  • 无法使用的显存碎片

总结

  • 显存总占用
    • 模型+梯度+优化器+激活值+其他;前3者统称为模型状态,是ZeRO主要优化对象。
    • 单卡:模型 + 梯度 + 优化器 + 激活值 + 其他
    • 并行 (PP+TP+Zero1):
      • 模型/(PP*TP) + 梯度/(PP) + 优化器/N + 激活/TP
  • 模型状态
    • 模型+梯度是4Byte,优化器是12Byte; 优化目标:k从12 -> x?
    • 即:1个参数16Byte,4m+12m=16m,也等于4m+km
  • 1.5B 模型显存要求 (混合训练)
    • 模型状态:1.5*16=24GB
    • 激活值:s=1k, bs=32,需要60GB,激活重计算后,降至8GB
    • 缓冲区:1.5*4=6GB

DeepSeed

GPU存储太多内容,模型、梯度、优化器、激活、buffer等内容。DeepSpeed 采用Zero Redundancy Optimizer,减内存占用

ZeRO-DP 去除冗余的数据并行方案

ZeRO-DP

核心思想

  • 普通DDP模型,但每张卡只存一部分优化器状态os、梯度g、参数p

  • ZeRO-DP: Pos+Pg+Pp 3种分片,Pos对应ZeRO-1,Pos+g对应ZeRO-2,Pos+g+p对应ZeRO3。

ZeRO1:+优化器状态划分 Pos

  • 核心思想
    • 对优化器状态os做划分每卡只持有1N的os,g和p存完整。
    • 每卡仅更新1N参数,并广播告知新参数
    • 各卡单独做前向、反向、参数更新
      • 反向:在梯度AllReduce后
      • 参数更新:各step末尾执行all-gather更新整个参数
    • 适用场景
      • 适合Adam,因为有额外参数m和v。
      • 不适合SGD,因为其只有较少参数内存。
  • 显存占用
    • 4m+12m -> 4m+12mN4m,为标准的1/4

ZeRO2:+梯度划分 Pg

  • 核心思想
    • 优化器os已划分;对梯度g做划分,每卡持有1N梯度
    • 各卡单独更新1N参数,并广播告知。
    • 梯度reduce到各rank,无需allredcue,节省开销
  • 显存占用
    • 4m+12m -> 2m+2mN+12mN2m,为标准的1/8

ZeRO3:+参数划分 Pp

  • 核心思想
    • 优化器os已切分;梯度g已切分;对参数p做切分,每卡持有1N参数
    • 确保能被单个GPU存放下
  • 显存占用
    • 4m+12m -> 2mN+2mN+12mN=16mN,N比较大时,趋近于0

Zero-Offload 策略

Zero-Offload

背景

  • GPU贵少,使用便宜的CPU内存
  • 把部分GPU计算下放到CPU和内存,需注意
    • 不能让CPU和GPU通信成为瓶颈
    • 不能让CPU参与过多计算

核心思想

  • 四类节点:前向、反向、参数更新、float2half
  • GPU和CPU
    • 前向和反向,合成一个节点,放在GPU上,FWD-BWD Super Node
    • 参数更新和float2half,合成一个节点,放在CPU上,Update Super Node
  • 计算流程
    • 在GPU上进行前后向计算,
    • 梯度传给CPU进行参数更新,
    • 再把新参数回传给GPU
  • 多卡场景
    • 利用ZeRO-2,把1N的优化器状态os和梯度p都offload到内存,在CPU上做参数更新
    • 每张卡至少对应1个CPU进程,每个CPU进程只负责1N的计算。
总访客数:   ·   总访问量:
PLM's Blog @ 2016 - 2025