您的位置:  首页 > 技术杂谈 > 正文

【大规模训练】transformer 中的张量模型并行

2022-07-19 17:00 https://my.oschina.net/u/5682856/blog/5555783 SenseParrots 次阅读 条评论

文 @ 不愿透露姓名的小 P 同学


0 前言

随着模型和算力的规模越来越大,训练也不再局限于单卡,而是采用多机多卡并行的方式来进行模型训练。同时,目前也存在很多不同的并行训练优化方案,模型并行就是其中的一种,下文将从三个方面来展开介绍。

  • 模型并行的动机和现状
  • 模型并行的原理
  • 数据并行和模型并行的结合

1 模型并行的动机和现状

1.1 动机

在训练大模型时,受限于显卡的内存大小,我们需要应用一系列技术方法来减小计算内存。主要从两方面展开:

  1. 优化单张卡上的计算内存;
  2. 利用多卡来分担。

Optimizer state sharding (ZeRO) 中,我们已经介绍了许多在实际中使用到的方法。模型并行,或者称之为张量并行、层内并行就是其中之一,它利用了分块矩阵的原理,专注于解决单层参数过大的问题。

一个超大模型除了会在深度上做拓展,往往也会拓展宽度。这其中的比例需要模型设计者去把握,当下也有一些工作对这类问题进行了研究分析。对于模型深度拓展带来的超大显存,可以采用流水线并行(层间并行)或者是一些计算序列优化的手段(比如 checkpointing )来解决;而对于模型宽度拓展带来的超大参数,模型并行无疑是当前较为成熟的方案。

transformer 中 linear 层的广泛使用,能充分发挥模型并行的优势。

1.2 现状

早期在人脸识别任务中使用较多的方法是InsightFace 。它对最后的 FC 层和 loss 层进行参数分割和并行计算。

而目前在 transformer 中广泛使用的方法是Megatron ,它使用数据并行+模型并行+流水并行的方法,实现了在 3072 张卡上支持 1 万亿的超大模型训练,其中模型并行的设计充分利用了 transformer 中的 AttentionMLP 等模型特征,组合并行的 linear 以减少通信量。

近期开源了一个基于 Pytorch 的工具包 torchshard ,它重写了 Megatron 里面的接口,同时兼容了AMPZeRO 等技术的混合使用,为用户提供了一种轻量级使用模型并行的方式。

此外,还有 Parallelformersfairscale 等开源库,原理和设计大同小异,在此不再一一赘述了。

2 模型并行的原理

模型并行的主要思想是:将网络层的输入、参数与运算分到不同的卡上,落点在于参数。

2.1 数学原理

对于linear层,最朴素的思想就是利用分块矩阵计算法则,获得结果的一致性。

如下面公式 (1) 和公式(2) 所示,将AB分别做矩阵按块的拆分。

\( \left[ \begin{matrix} X \end{matrix} \right] \times \left[ \begin{matrix} A_1 & A_2 \end{matrix} \right] =\left[ \begin{matrix} XA_1 & XA_2 \end{matrix} \right] \tag{1} \)

\( \left[ \begin{matrix} Y_1 & Y_2 \end{matrix} \right] \times \left[ \begin{matrix} B_1 \\\\ B_2 \end{matrix} \right] =\left[ \begin{matrix} Y_1B_1 + Y_2B_2 \end{matrix} \right] \tag{2} \)

对于非 linear 层,不做额外设计。当接收到的输入是并行式的,可以考虑将数据的通信延后或者想办法降低通信量。比如 softmax 层,可以依赖数学上的传递性,即,局部最大之最大为全局最大;以及结合律,即,局部和之和即为全局之和,来达到降低通信量的效果。

2.1 在 transformer 模型中的做法

2.1.1 linear的并行

在下图中,我们将 XYZ 当作网络中的激活值,将 AB 当作网络中的 linear 层的参数。

第一行显示的并行方式就是将参数按照列来切分,这样得到的输出按照列来拼接就可以得到与不使用并行等价的结果。

第二行显示的则是将参数按照行来切分,为了做矩阵乘法的时候形状能匹配,需要将输入按照列切分,这样得到的输出 Z1Z2 再做加法就可以得到最终的结果 Z

为了简化操作,这里的切分都是指均分。

每一次切分都会带来额外的两次通信(图中显示的是 forward 阶段,相应地,backward 阶段也会有一次对应规则的通信)。为了减小通信的代价,一个比较直接的想法就是,如果能把两次切分组合在一起,那么上面一行最后的 all_gather 和下面一行最开始的 split 就可以省去。也就是图中中间那一条折线所示,参数矩阵列切分和行切分的组合。

transformer 的结构天然支持这种组合,下面两张图来自于 Megatron 的论文 :

先来看 MLP 模块,图中参数 A 按列分成 A1A2,得到的结果进入 GeLU 模块,GeLU 的计算仅与当前元素有关,满足 GeLU([XA1, XA2])=[GeLU(XA1, XA2)],所以可以继续往后传播,直到遇到下一次矩阵分块乘法。

图中的 f 和 g 表示和通信切块相关的操作,在 MLP 中,f 的 forward 是切片(split),backward 是求和(all_reduce);g 的 forward 是求和(all_reduce),backward 是切片(split)。

如上面下图所示,在 Attention 中也有类似的组合,我们常用的是 Multiheads Attention,QKV只在 head 内部做交互,而我们并行拆分的恰好是 heads,所以在两次矩阵分块乘法之间,计算是等价的,前后两次通信操作和 MLP 类似。

一个典型的模型并行版本的 Multiheads Attention 如下代码所示:

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.world_size = get_tensor_model_parallel_world_size() # 模型并行路数
        self.dim_per_partition = divide(dim, self.world_size)
        self.dim_per_attention_head = divide(dim, num_heads)
        self.num_heads_per_partition = divide(num_heads, self.world_size)

        self.qkv = ColumnParallelLinear(dim, dim * 3, gather_output=False) # gather_output=False,延迟 gather 的时间,和后面的 RowParallelLinear 组合起来
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = RowParallelLinear(dim, dim, input_is_parallel=True) # input_is_parallel=True,输入已经是按列切分的了
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, self.num_heads_per_partition, 3, self.dim_per_attention_head).permute(3, 0, 2, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dim_per_partition)
        x = self.proj_drop(self.proj(x))
        return x

2.2.2 cross_entropy_loss 的并行

如果在 cross_entropy_loss 前的输出是一个按列并行的 linear,那么马上进行通信的代价是比较大的。实际上,有一种比较节省通信的方式:接入并行版本的 cross_entropy_loss。首先计算的是 softmax 的值,如下公式所示,其中 p 代表模型并行的组内卡号

\( softmax(x_i) = \frac{e^{x_i}}{\sum_{j}{e^{x_j}}} = \frac{e^{x_i -x_{\max}}}{\sum_{j}{e^{x_j-x_{\max}}}} = \frac{e^{x_i-x_{\max}}}{\sum_p{\sum_{k}{e^{x_k-x_{\max}}}}} \tag{3} \)

\( x_{\max} = \max_p(\max_k(x_k)) \tag{4} \)

为了防止浮点溢出,一般在计算时,会将 \(x_{i}\) 减去 \(x_{\max}\) 的值,所以在这个过程中,会有两次通信,一次是通信得到 \(x_{\max}\),其中 \(\underset{k}{\max}(x_k)\) 指模型并行组内的一张卡上的最大 x,\(\underset{p}{\max}(\underset{k}{\max}(x_k))\) 是指通过通信求得模型并行组内的整体最大 x。另外一次通信是通过求部分和的和得到整体的和。

求完 softmax 的值之后,还需要将 target 做切分,得到部分类别的 loss 之后,最后再做一次求和,得到所有类别的 loss,这样通过一共三次总量极小的通信,就可以完成 cross_entropy_loss 的计算。

2.2.3 额外的一些处理

在多个 linear 的并行组合在一起时,如果中间有一些 dropout 之类的操作,需要尽可能保证随机性的一致。

此外,在一些需要全局参数信息的地方,也需要做对应的参数通信,比如 clip_grad_norm

3 数据并行和模型并行的结合

在实际应用中,我们常常将模型并行与多种并行方法结合在一起。下面将详细介绍与数据并行混合使用时的具体流程,流水并行在这里暂不作讨论。

3.1 Megatron 的方式

在 megatron 中,一般在单机内使用模型并行,在机器之间使用数据并行。

以下图为例,一共 8 张卡,2 路数据并行,4 路模型并行,则在数据并行组(data parallel group)内部卡的序号从 0 到 3,一共两组,在模型并行组(model parallel group)内部卡的序号从 0 到 1,一共四组。

在通信初始化的时候,这些卡号和组别就需要被定义清楚。在训练中,一开始的参数对齐和每轮迭代时的梯度对齐都需要指定对应组别。

在同一数据并行组内,与单纯的数据并行方式相比,几乎没有发生变化。以 PyTorch 为例,数据并行模型一般采用 DistributedDataParallel,有一种方式是将其参数 process_group 替换成数据并行组。

# model 的 ddp 要传入 data_parallel_group 作为它的 process_group:在初始化 broadcast 阶段和 average_gradient 阶段起作用

model = DistributedDataParallel(model, process_group=data_parallel_group)

下面解释一下同一模型并行组的行为。

  • 首先,为了保证在分布式训练中不同卡上的参数在初始化时保持同步,除了上面隐含在 DDP 中的数据并行组里的参数 broadcast,需要额外在模型并行组内做一次非并行 layer 的参数 broadcast;
  • 其次,在同一个模型并行组内,处理的数据是同一份,这不仅要求在 data sampler 里保证,为了防止输入到网络中的数据有不同的变形,一般也在 forward 之前,在组内对数据做 broadcast。
  • 这样,在非并行的 layer,比如 patch embedding,组内执行的是完全一致的行为,在并行的 layer,比如包含 parallel attention 和 parallel mlp 的 parallel transformer layer,组内处理的仍然是同一份数据,但是参数已经切片,所以得到的是部分结果,再通过通信组合到一起。

Megatron 将上面所提到的功能合在一起,写在了一个类似于 DistributedDataParallel 的 class 里面。

3.2 InsightFace 的方式

模型的 layer 大部分都可以被模型并行拆分时,上面这种方式实现起来比较简单,拓展性也强。但是当模型中大部分 layer 都不能被拆分时,这种方式会带来比较大的问题:冗余计算。

在 InsightFace 的人脸识别任务中,前面大部分都是无法拆分的层,仅仅在最后的全连接层会有参数量过大的问题,在这种情况下只对全连接层及其后面的 cross_entropy 做模型并行的计算,前面仍然保持数据并行的方式。

为了适应这种并行方式的切分,需要在数据并行转入模型并行的时候,对于激活值做一次 all_gather,得到全量的输入。

这种方式使用起来比较灵活,但是对使用者的要求比较高。对于参数的通信组的划分要做比较精巧的设计,同时对层与层的衔接也要有一个准确的判断,从而适当地插入通信。

References

  1. Optimizer state sharding (ZeRO): https://zhuanlan.zhihu.com/p/394064174
  2. InsightFace: https://github.com/deepinsight/insightface/tree/master/recognition
  3. Megatron: https://github.com/NVIDIA/Megatron-LM
  4. torchshard: https://github.com/KaiyuYue/torchshard
  5. Parallelformers https://github.com/tunib-ai/parallelformers
  6. fairscale: https://github.com/facebookresearch/fairscale
  7. Megatron 的论文: https://arxiv.org/abs/1909.08053

感谢阅读,欢迎在评论区留言讨论哦~

P.S. 如果喜欢本篇文章,请多多 点赞,让更多的人看见我们 :D

关注 公众号「SenseParrots」,获取人工智能框架前沿业界动态与技术思考。

展开阅读全文
  • 0
    感动
  • 0
    路过
  • 0
    高兴
  • 0
    难过
  • 0
    搞笑
  • 0
    无聊
  • 0
    愤怒
  • 0
    同情
热度排行
友情链接