IE盒子

搜索
查看: 176|回复: 16

U-Net网络原理分析与pytorch实现

[复制链接]

2

主题

7

帖子

10

积分

新手上路

Rank: 1

积分
10
发表于 2022-12-5 15:43:01 | 显示全部楼层 |阅读模式
笔记:U-Net论文阅读与pytorch实现
原文地址:https://arxiv.org/pdf/1505.04597v1.pdf
U-Net最初应用于医学图像分割上,其后在其它的领域得到了广泛的应用。由于自己对医学图像分割并不了解,所以本次笔记主要是记录关于如何把它应用到图像去雾、去噪、去模糊、超分辨率或者去雨方向上。

笔包含以下3个部分:
(1)U-Net网络结构与提出背景
(2)优点与创新性
(3)pytorch实现U-Net
1、U-Net网络结构与提出背景
U-Net之前图像分割还有一篇经典的FCN网络(全卷积网络,Fully convolutional networks for semantic segmentation),U-Net扩展了FCN使其效果更好并仅仅需要少量的标注数据。其改进包含:通过添加更多的通道数使得网络的上下文信息能流动到更高分辨率的层;pooling操作被上采样操作替代(作者是这么说的,但是自己不确定是不是理解错了)。
按照论文中插图的脚注:蓝色的矩形条表示特征图,矩形图上面的数字是通道数,矩形图侧面的数字是x-y(特征图长和宽)。
从左至右分析整张图:


(1)首先输入的image tile为572*572的单通道图片,随后通过连续的2次卷积(蓝色箭头)变为568*568的64通道特征图(个人猜测,那个时候DL框架还没这么成熟,Padding还不流行,所以作者没有进行边界补0,因此卷积以后长和宽减小)。
(2)左上角第一个红色的箭头表示max pooling操作,会将特征图长宽降低为原来的一半。这里的尺寸变换为从(568,568,64)到(284,284,64),随后的两个卷积层将特征图通道数增加为128。
(3)左半部分的其它操作和(1)(2)分析方法相同
(4)中间处的特征图通道数为1024,随后通过的up-conv(反卷积或者上采样)增加特征图的长和宽。
(5)灰色箭头表示将左边的特征图“复制”到右边的特征图,对其方式为通道。例如,图左半部分的尺寸为(280, 280, 128)的特征图通过第二条灰色箭头,连接到右边(200, 200, 128)的特征图,得到的尺寸为(200, 200, 256)。(猜测:由于没用到padding造成卷积过程中尺寸减小,所以需要把280*280的特征图裁剪为200*200的,论文中说“ The cropping is necessary due to the loss of border pixels in every convolution. ”应该就是这个意思)
(6)重复(5)的分析,可以得出最终输出的图像尺寸为(388, 388,64)
(7)最后的浅绿色箭头表示1x1的卷积,见通道数减少到2。

2、网络优点与创新性
下面关于U-Net的优点不完全是论文中指出的,有一部分是自己认为的。
(1)适用于小规模的数据集。这一点主要还是针对于医学数据来说的,对于图像分类任务或者去噪之类的任务数据集还是很充足的。
(2)不使用全连接层。搭建网络时,全连接层的应用始终受限,主要是由于其参数过多。
假设输入是一张尺寸为(224,224,3)的彩色图片,并假设期望输出的特征图尺寸为(224, 224, 64)。如果采用全连接Linear,那么输入特征数量为224*224*3=150528,输出特征尺寸为224*224*64=3211264,参数的数量为150528*3211264=483,385,147,392,这甚至比很多大型网络参数都多;而如果使用卷积Conv(假设用3x3的卷积核),那么需要的卷积核为64个3x3x3的卷积核,总参数数量为64*3*3*3=1728,所以相比于全连接,卷积层大幅度减少了网络的参数数量。

3、pytorch实现U-Net
由于对论文中采用的crop并不理解,原始论文的caffe代码读不懂,而且网上复现的代码又都不太一样,所以主要还是结合自己的理解和所学对其进行复现。当然,知道了卷积之类的具体计算方式,复现起来就很容易了。
总的来说,复现中有以下几个疑惑:
(1)crop是否由于没有padding造成左边浅层特征图复制到深层时尺寸较大,需要剪切掉多出的边界
(2)特征图复用是否为目前普遍采用的通道跳跃连接,比如DenseNet提供的跳跃连接
(3)是不是需要进行补0操作,不知道对边界进行补零会不会造成负面影响
(4)原论文输出图像尺寸为(388,388,2),自己的输出和输入尺寸相同(通道数也都是3)
(5)原文中没用BN,是不是由于数据很少
3.1 连续的两个卷积块


两个蓝色箭头(Conv+ReLU两次)在网络图中出现了很多次,通过单独的封装可以减少很多后期的代码量:
class ConvBlock(nn.Module):
    """ implement conv+ReLU two times """
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        conv_relu = []
        conv_relu.append(nn.Conv2d(in_channels=in_channels, out_channels=middle_channels,
                                   kernel_size=3, padding=1, stride=1))
        conv_relu.append(nn.ReLU())
        conv_relu.append(nn.Conv2d(in_channels=middle_channels, out_channels=out_channels,
                                   kernel_size=3, padding=1, stride=1))
        conv_relu.append(nn.ReLU())
        self.conv_ReLU = nn.Sequential(*conv_relu)
    def forward(self, x):
        out = self.conv_ReLU(x)
        return out3.2 完整的网络结构
自己实现的代码可读性并不强,所以需要仔细备注方便以后使用。
(1)由于U-Net的左右特征图需要进行通道上的拼接,所以下面的代码中需要多次使用torch.cat()函数进行拼接
(2)以left开头的属性表示左边的卷积操作,pool函数表示左边的池化操作
(3)以right开头的属性表示右边的卷积操作(右边的卷积输入为左边对应位置特征图和前一层特征图的cat结果),de开头表示反卷积
代码如下:
class U_Net(nn.Module):
    def __init__(self):
        super().__init__()

        # 首先定义左半部分网络
        # left_conv_1 表示连续的两个(卷积+激活)
        # 随后进行最大池化
        self.left_conv_1 = ConvBlock(in_channels=3, middle_channels=64, out_channels=64)
        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.left_conv_2 = ConvBlock(in_channels=64, middle_channels=128, out_channels=128)
        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.left_conv_3 = ConvBlock(in_channels=128, middle_channels=256, out_channels=256)
        self.pool_3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.left_conv_4 = ConvBlock(in_channels=256, middle_channels=512, out_channels=512)
        self.pool_4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.left_conv_5 = ConvBlock(in_channels=512, middle_channels=1024, out_channels=1024)

        # 定义右半部分网络
        self.deconv_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.right_conv_1 = ConvBlock(in_channels=1024, middle_channels=512, out_channels=512)

        self.deconv_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.right_conv_2 = ConvBlock(in_channels=512, middle_channels=256, out_channels=256)

        self.deconv_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, padding=1, stride=2 ,output_padding=1)
        self.right_conv_3 = ConvBlock(in_channels=256, middle_channels=128, out_channels=128)

        self.deconv_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, output_padding=1, padding=1)
        self.right_conv_4 = ConvBlock(in_channels=128, middle_channels=64, out_channels=64)
        # 最后是1x1的卷积,用于将通道数化为3
        self.right_conv_5 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1, padding=0)

    def forward(self, x):

        # 1:进行编码过程
        feature_1 = self.left_conv_1(x)
        feature_1_pool = self.pool_1(feature_1)

        feature_2 = self.left_conv_2(feature_1_pool)
        feature_2_pool = self.pool_2(feature_2)

        feature_3 = self.left_conv_3(feature_2_pool)
        feature_3_pool = self.pool_3(feature_3)

        feature_4 = self.left_conv_4(feature_3_pool)
        feature_4_pool = self.pool_4(feature_4)

        feature_5 = self.left_conv_5(feature_4_pool)

        # 2:进行解码过程
        de_feature_1 = self.deconv_1(feature_5)
        # 特征拼接
        temp = torch.cat((feature_4, de_feature_1), dim=1)
        de_feature_1_conv = self.right_conv_1(temp)

        de_feature_2 = self.deconv_2(de_feature_1_conv)
        temp = torch.cat((feature_3, de_feature_2), dim=1)
        de_feature_2_conv = self.right_conv_2(temp)

        de_feature_3 = self.deconv_3(de_feature_2_conv)

        temp = torch.cat((feature_2, de_feature_3), dim=1)
        de_feature_3_conv = self.right_conv_3(temp)

        de_feature_4 = self.deconv_4(de_feature_3_conv)
        temp = torch.cat((feature_1, de_feature_4), dim=1)
        de_feature_4_conv = self.right_conv_4(temp)

        out = self.right_conv_5(de_feature_4_conv)

        return out测试网络输入和输出的尺寸是否一致:
if __name__ == "__main__":
    x = torch.rand(size=(8, 3, 224, 224))
    net = U_Net()
    out = net(x)
    print(out.size())
    print("ok")
回复

使用道具 举报

1

主题

7

帖子

9

积分

新手上路

Rank: 1

积分
9
发表于 2022-12-5 15:43:07 | 显示全部楼层
有完整的分割图片吗
回复

使用道具 举报

1

主题

6

帖子

8

积分

新手上路

Rank: 1

积分
8
发表于 2022-12-5 15:43:16 | 显示全部楼层
ConvTranspose2d的kenel_size,stride,padding怎么设置啊?由公式推出来吗
回复

使用道具 举报

1

主题

6

帖子

8

积分

新手上路

Rank: 1

积分
8
发表于 2022-12-5 15:43:54 | 显示全部楼层
是的,就是反卷积计算公式
回复

使用道具 举报

1

主题

7

帖子

4

积分

新手上路

Rank: 1

积分
4
发表于 2022-12-5 15:44:26 | 显示全部楼层
大佬,我想请问下unet里面的特征融合的意义在哪里,
还有就是特征融合时,是直接element-wise plus 效果好还是cat的效果好
回复

使用道具 举报

1

主题

9

帖子

15

积分

新手上路

Rank: 1

积分
15
发表于 2022-12-5 15:45:00 | 显示全部楼层
写的很好!感谢分享代码!
回复

使用道具 举报

2

主题

8

帖子

15

积分

新手上路

Rank: 1

积分
15
发表于 2022-12-5 15:45:34 | 显示全部楼层
不过backbone网络是什么呢?vgg16?
回复

使用道具 举报

4

主题

7

帖子

15

积分

新手上路

Rank: 1

积分
15
发表于 2022-12-5 15:46:28 | 显示全部楼层
小白问题,复现需要原始数据集吗?
回复

使用道具 举报

1

主题

7

帖子

7

积分

新手上路

Rank: 1

积分
7
发表于 2022-12-5 15:46:46 | 显示全部楼层
您好,cat之前不需要所说的crop保持尺寸一致的操作吗
回复

使用道具 举报

1

主题

4

帖子

6

积分

新手上路

Rank: 1

积分
6
发表于 2022-12-5 15:47:33 | 显示全部楼层
您好还有,图片tensor的shape是(n,h,w),两张图片进行cat操作时dim=1,是在h维度cat的吗,U_net中torch.cat后实现的应该是通道的n的叠加吧,我脑子有点乱5555
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

快速回复 返回顶部 返回列表