您的位置:  首页 > 技术杂谈 > 正文

Swin-Unet最强分割网络

2022-06-09 12:00 https://my.oschina.net/u/3768341/blog/5535941 算法之名 次阅读 条评论

Swin-Unet是基于Swin Transformer为基础(可参考Swin Transformer介绍 ),结合了U-Net网络的特点(可参考Tensorflow深度学习算法整理(三) 中的U-Net)组合而成的新的分割网络

它与Swin Transformer不同的地方在于,在编码器(Encoder)这边虽然跟Swin Transformer一样的4个Stage,但Swin Transformer Block的数量为[2,2,2,1],而不是Swin Transformer的[2,2,6,2]。而在解码器(Decoder)这边,由于是升采样,使用的不再是Patch Embedding和Patch Merging,而使用的是Patch Expanding,它是Patch Merging的逆过程。

我们来看一下Patch Expanding的代码实现

from einops import rearrange
class PatchExpand(nn.Module):
    """
    块状扩充,尺寸翻倍,通道数减半
    """
    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        """
        Args:
            input_resolution: 解码过程的feature map的宽高
            dim: frature map通道数
            dim_scale: 通道数扩充的倍数
            norm_layer: 通道方向归一化
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        # 通过全连接层来扩大通道数
        self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        # 先把通道数翻倍
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        # 将各个通道分开,再将所有通道拼成一个feature map
        # 增大了feature map的尺寸
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)
        # 通道翻倍后再除以4,实际相当于通道数减半
        x = x.view(B, -1, C // 4)
        x = self.norm(x)

        return x

在编码器这边基本上跟Swin Transformer是一样的,我们重点来看解码器这边。它是使用BasicLayer_up类来对SwinTransformerBlock和Patch Expanding来进行搭配的。

class BasicLayer_up(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    一个BasicLayer_up包含偶数个SwinTransformerBlock和一个upsamele层(即Patch Expanding层)
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
        """
        Args:
            dim: feature map通道数
            input_resolution: feature map的宽高
            depth: 各个Stage中,Swin Transformer Block的数量
            num_heads: 多头注意力各个Stage中的头数
            window_size: 窗口自注意力机制的窗口中的patch数
            mlp_ratio: 层感知机模块中第一个全连接层输出的通道倍数
            qkv_bias: 如果是True的话,对自注意力公式中的Q、K、V增加一个可学习的偏置
            qk_scale: 窗口自注意力公式常数
            drop: dropout rate,默认为0
            attn_drop: 用于自注意力机制中的dropout rate,默认为0
            drop_path: 在Swin Transformer Block中,有一定概率丢弃整个直连分支,包括
                       LN、W-MSA或者SW-MSA,只保留直连的连接,是一种网络深度的随机性,默认为0
            norm_layer: 通道方向归一化
            upsample: 使用Patch Expanding来升采样
            use_checkpoint: 是否使用Pytorch中间数据保存机制
        """

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build SwinTransformerBlock
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 # 用于区分是使用W-MSA还是SW-MSA,0为W-MSA,1为SW-MSA
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        # 当stage=4的时候为None
        if upsample is not None:
            self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
        else:
            self.upsample = None

    def forward(self, x):
        # 通过每一个SwinTransformerBlock
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        # 进行块状扩充(PatchExpanding)上采样
        if self.upsample is not None:
            x = self.upsample(x)
        return x

SwinTransformerBlock跟SwinTransformer中的代码也是一样的,这里就不重复了。

然后还有一个从编码器到解码器之间的跳连。这里需要看一下Swin-Unet的主类代码

class SwinTransformerSys(nn.Module):
    """ Swin-UNet网络模型
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, final_upsample="expand_first", **kwargs):
        """
        Args:
            img_size: 原始图像尺寸
            patch_size: 一个patch中的像素点数
            in_chans: 进入网络的图片通道数
            num_classes: 分类数量
            embed_dim: feature map通道数
            depths: 编码器各个Stage中,Swin Transformer Block的数量
            depths_decoder: 解码器各个Stage中,Swin Transformer Block的数量
            num_heads: 多头注意力各个Stage中的头数
            window_size: 窗口自注意力机制的窗口中的patch数
            mlp_ratio: 多层感知机模块中第一个全连接层输出的通道倍数
            qkv_bias: 如果是True的话,对自注意力公式中的Q、K、V增加一个可学习的偏置
            qk_scale: 自注意力公式中的常量
            drop_rate: dropout rate,默认为0
            attn_drop_rate: 用于自注意力机制中的dropout rate,默认为0
            drop_path_rate: 在Swin Transformer Block中,有一定概率丢弃整个直连分支,包括
                            LN、W-MSA或者SW-MSA,只保留直连的连接,是一种网络深度的随机性,默认为0.1
            norm_layer: 通道方向归一化
            ape: 是否进行绝对位置嵌入,默认False
            patch_norm: 如果是True的话,在patch embedding之后加上归一化
            use_checkpoint: 是否使用Pytorch中间数据保存机制
            final_upsample: 解码器stage4后的Patch Expanding
            **kwargs:
        """
        super().__init__()

        print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths,
              depths_decoder, drop_path_rate, num_classes))

        self.num_classes = num_classes
        # stage的数量
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        # 编码器stage4输出特征的通道数(Swin-Tiny:768)
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        # 解码器stage4输出特征的通道数(192)
        self.num_features_up = int(embed_dim * 2)
        self.mlp_ratio = mlp_ratio
        self.final_upsample = final_upsample

        # 把图像分割成不重叠的patch
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        # 获取feature map的高宽
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # 绝对位置嵌入
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # 不同的stage,舍弃整个直连分支的概率不同,从小到大,最小为0,最大为0.1
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # 创建编码器layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):  # layer相当于stage
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               # 只有前3个stage有patchmerging,最后一个没有
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)
        
        # 创建解码器layers
        self.layers_up = nn.ModuleList()
        self.concat_back_dim = nn.ModuleList()
        for i_layer in range(self.num_layers):  # layer相当于stage
            # 每一个stage结束后,通道数减半的全连接层
            concat_linear = nn.Linear(2 * int(embed_dim * 2**(self.num_layers - 1 - i_layer)),
                                      int(embed_dim * 2**(self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity()
            if i_layer == 0:  # 第一个stage只进行上采样
                layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
                                       patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer)
            else:
                layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)),
                                         input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
                                                           patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),
                                         depth=depths[(self.num_layers-1-i_layer)],
                                         num_heads=num_heads[(self.num_layers-1-i_layer)],
                                         window_size=window_size,
                                         mlp_ratio=self.mlp_ratio,
                                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                                         drop=drop_rate, attn_drop=attn_drop_rate,
                                         drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers - 1 - i_layer) + 1])],
                                         norm_layer=norm_layer,
                                         # 只有前3个stage有PatchExpand,最后一个没有
                                         upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
                                         use_checkpoint=use_checkpoint)
            self.layers_up.append(layer_up)
            self.concat_back_dim.append(concat_linear)

        self.norm = norm_layer(self.num_features)
        self.norm_up = norm_layer(self.embed_dim)
        # 解码器最后一个stage进行FinalPatchExpand处理
        if self.final_upsample == "expand_first":
            print("---final upsample expand_first---")
            self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size), dim_scale=4, dim=embed_dim)
            self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)

        self.apply(self._init_weights)

这里有一个FinalPatchExpand_X4的方法,我们来看一下它的实现

class FinalPatchExpand_X4(nn.Module):
    """
    stage4之后的PatchExpand
    尺寸翻倍,通道数不变
    """
    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        """
        Args:
            input_resolution: feature map的宽高
            dim: feature map通道数
            dim_scale: 通道数扩充的倍数
            norm_layer: 通道方向归一化
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.dim_scale = dim_scale
        # 通过全连接层来扩大通道数
        self.expand = nn.Linear(dim, 16 * dim, bias=False)
        self.output_dim = dim 
        self.norm = norm_layer(self.output_dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        # 先把通道数翻倍
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        # 将各个通道分开,再将所有通道拼成一个feature map
        # 增大了feature map的尺寸
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))
        # 把扩大的通道数转成原来的通道数
        x = x.view(B, -1, self.output_dim)
        x = self.norm(x)

        return x

回到SwinTransformerSys代码中

def _init_weights(self, m):
    """
    对全连接层或者通道归一化进行权重以及偏置的初始化
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)

@torch.jit.ignore
def no_weight_decay(self):
    return {'absolute_pos_embed'}

@torch.jit.ignore
def no_weight_decay_keywords(self):
    return {'relative_position_bias_table'}

#Encoder and Bottleneck
def forward_features(self, x):
    """
    编码器过程
    """
    # 图像分割
    x = self.patch_embed(x)
    # 绝对位置嵌入
    if self.ape:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)
    # 跳连点
    x_downsample = []
    # 通过各个编码过程的stage
    for layer in self.layers:
        x_downsample.append(x)
        x = layer(x)

    x = self.norm(x)  # B L C

    return x, x_downsample

#Dencoder and Skip connection
def forward_up_features(self, x, x_downsample):
    """
    解码器过程,包含了跳连拼接
    """
    # 通过各个解码过程的stage
    for inx, layer_up in enumerate(self.layers_up):
        if inx == 0:
            x = layer_up(x)
        else:
            # 拼接编码器的跳连部分再进入Swin Transformer Block
            x = torch.cat([x, x_downsample[3-inx]], -1)
            x = self.concat_back_dim[inx](x)
            x = layer_up(x)

    x = self.norm_up(x)  # B L C

    return x

def up_x4(self, x):
    """
    完成解码器的最后一个stage后进入
    """
    H, W = self.patches_resolution
    B, L, C = x.shape
    assert L == H * W, "input features has wrong size"

    if self.final_upsample == "expand_first":
        x = self.up(x)
        x = x.view(B, 4 * H, 4 * W, -1)
        x = x.permute(0, 3, 1, 2) #B,C,H,W
        x = self.output(x)
        
    return x

def forward(self, x):
    """
    前向运算
    """
    x, x_downsample = self.forward_features(x)
    x = self.forward_up_features(x, x_downsample)
    x = self.up_x4(x)

    return x

def flops(self):
    flops = 0
    flops += self.patch_embed.flops()
    for i, layer in enumerate(self.layers):
        flops += layer.flops()
    flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
    flops += self.num_features * self.num_classes
    return flops
展开阅读全文
  • 0
    感动
  • 0
    路过
  • 0
    高兴
  • 0
    难过
  • 0
    搞笑
  • 0
    无聊
  • 0
    愤怒
  • 0
    同情
热度排行
友情链接