歡迎您光臨本站 註冊首頁

使用pytorch實現論文中的unet網絡

←手機掃碼閱讀     e36605 @ 2020-06-26 , reply:0

設計神經網絡的一般步驟:

1. 設計框架

2. 設計骨幹網絡

Unet網絡設計的步驟:

1. 設計Unet網絡工廠模式

2. 設計編解碼結構

3. 設計卷積模塊

4. unet實例模塊

Unet網絡最重要的特徵:

1. 編解碼結構。

2. 解碼結構,比FCN更加完善,採用連接方式。

3. 本質是一個框架,編碼部分可以使用很多圖像分類網絡。

示例代碼:

  import torch  import torch.nn as nn    class Unet(nn.Module):   #初始化參數:Encoder,Decoder,bridge   #bridge默認值為無,如果有參數傳入,則用該參數替換None   def __init__(self,Encoder,Decoder,bridge = None):    super(Unet,self).__init__()    self.encoder = Encoder(encoder_blocks)    self.decoder = Decoder(decoder_blocks)    self.bridge = bridge   def forward(self,x):    res = self.encoder(x)    out,skip = res[0],res[1,:]    if bridge is not None:     out = bridge(out)    out = self.decoder(out,skip)    return out  #設計編碼模塊  class Encoder(nn.Module):   def __init__(self,blocks):    super(Encoder,self).__init__()    #assert:斷言函數,避免出現參數錯誤    assert len(blocks) > 0    #nn.Modulelist():模型列表,所有的參數可以納入網絡,但是沒有forward函數    self.blocks = nn.Modulelist(blocks)   def forward(self,x):    skip = []    for i in range(len(self.blocks) - 1):     x = self.blocks[i](x)     skip.append(x)    res = [self.block[i+1](x)]    #列表之間可以通過+號拼接    res += skip    return res  #設計Decoder模塊  class Decoder(nn.Module):   def __init__(self,blocks):    super(Decoder, self).__init__()    assert len(blocks) > 0    self.blocks = nn.Modulelist(blocks)   def ceter_crop(self,skips,x):    _,_,height1,width1 = skips.shape()    _,_,height2,width2 = x.shape()    #對圖像進行剪切處理,拼接的時候保持對應size參數一致    ht,wt = min(height1,height2),min(width1,width2)    dh1 = (height1 - height2)//2 if height1 > height2 else 0    dw1 = (width1 - width2)//2 if width1 > width2 else 0    dh2 = (height2 - height1)//2 if height2 > height1 else 0    dw2 = (width2 - width1)//2 if width2 > width1 else 0    return skips[:,:,dh1:(dh1 + ht),dw1:(dw1 + wt)],      x[:,:,dh2:(dh2 + ht),dw2 : (dw2 + wt)]     def forward(self, skips,x,reverse_skips = True):    assert len(skips) == len(blocks) - 1    if reverse_skips is True:     skips = skips[: : -1]    x = self.blocks[0](x)    for i in range(1, len(self.blocks)):     skip = skips[i-1]     x = torch.cat(skip,x,1)     x = self.blocks[i](x)    return x  #定義了一個卷積block  def unet_convs(in_channels,out_channels,padding = 0):   #nn.Sequential:與Modulelist相比,包含了forward函數   return nn.Sequential(    nn.Conv2d(in_channels, out_channels, kernal_size = 3, padding = padding, bias = False),    nn.BatchNorm2d(outchannels),    nn.ReLU(inplace = True),    nn.Conv2d(in_channels, out_channels, kernal_size=3, padding=padding, bias=False),    nn.BatchNorm2d(outchannels),    nn.ReLU(inplace=True),   )  #實例化Unet模型  def unet(in_channels,out_channels):   encoder_blocks = [unet_convs(in_channels, 64),        nn.Sequential(nn.Maxpool2d(kernal_size = 2, stride = 2, ceil_mode = True),           unet_convs(64,128)),         nn.Sequential(nn.Maxpool2d(kernal_size=2, stride=2, ceil_mode=True),            unet_convs(128, 256)),        nn.Sequential(nn.Maxpool2d(kernal_size=2, stride=2, ceil_mode=True),            unet_convs(256, 512)),        ]   bridge = nn.Sequential(unet_convs(512, 1024))   decoder_blocks = [nn.conTranpose2d(1024, 512),         nn.Sequential(unet_convs(1024, 512),           nn.conTranpose2d(512, 256)),        nn.Sequential(unet_convs(512, 256),           nn.conTranpose2d(256, 128)),         nn.Sequential(unet_convs(512, 256),           nn.conTranpose2d(256, 128)),         nn.Sequential(unet_convs(256, 128),           nn.conTranpose2d(128, 64))        ]   return Unet(encoder_blocks,decoder_blocks,bridge)

 

補充知識:Pytorch搭建U-Net網絡

U-Net: Convolutional Networks for Biomedical Image Segmentation

  import torch.nn as nn  import torch  from torch import autograd  from torchsummary import summary    class DoubleConv(nn.Module):   def __init__(self, in_ch, out_ch):    super(DoubleConv, self).__init__()    self.conv = nn.Sequential(     nn.Conv2d(in_ch, out_ch, 3, padding=0),     nn.BatchNorm2d(out_ch),     nn.ReLU(inplace=True),     nn.Conv2d(out_ch, out_ch, 3, padding=0),     nn.BatchNorm2d(out_ch),     nn.ReLU(inplace=True)    )     def forward(self, input):    return self.conv(input)    class Unet(nn.Module):   def __init__(self, in_ch, out_ch):    super(Unet, self).__init__()    self.conv1 = DoubleConv(in_ch, 64)    self.pool1 = nn.MaxPool2d(2)    self.conv2 = DoubleConv(64, 128)    self.pool2 = nn.MaxPool2d(2)    self.conv3 = DoubleConv(128, 256)    self.pool3 = nn.MaxPool2d(2)    self.conv4 = DoubleConv(256, 512)    self.pool4 = nn.MaxPool2d(2)    self.conv5 = DoubleConv(512, 1024)    # 逆卷積,也可以使用上採樣    self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)    self.conv6 = DoubleConv(1024, 512)    self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)    self.conv7 = DoubleConv(512, 256)    self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)    self.conv8 = DoubleConv(256, 128)    self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)    self.conv9 = DoubleConv(128, 64)    self.conv10 = nn.Conv2d(64, out_ch, 1)     def forward(self, x):    c1 = self.conv1(x)    crop1 = c1[:,:,88:480,88:480]    p1 = self.pool1(c1)    c2 = self.conv2(p1)    crop2 = c2[:,:,40:240,40:240]    p2 = self.pool2(c2)    c3 = self.conv3(p2)    crop3 = c3[:,:,16:120,16:120]    p3 = self.pool3(c3)    c4 = self.conv4(p3)    crop4 = c4[:,:,4:60,4:60]    p4 = self.pool4(c4)    c5 = self.conv5(p4)    up_6 = self.up6(c5)    merge6 = torch.cat([up_6, crop4], dim=1)    c6 = self.conv6(merge6)    up_7 = self.up7(c6)    merge7 = torch.cat([up_7, crop3], dim=1)    c7 = self.conv7(merge7)    up_8 = self.up8(c7)    merge8 = torch.cat([up_8, crop2], dim=1)    c8 = self.conv8(merge8)    up_9 = self.up9(c8)    merge9 = torch.cat([up_9, crop1], dim=1)    c9 = self.conv9(merge9)    c10 = self.conv10(c9)    out = nn.Sigmoid()(c10)    return out    if __name__=="__main__":   test_input=torch.rand(1, 1, 572, 572)   model=Unet(in_ch=1, out_ch=2)   summary(model, (1,572,572))   ouput=model(test_input)   print(ouput.size())

 

  ----------------------------------------------------------------    Layer (type)    Output Shape   Param #  ================================================================     Conv2d-1   [-1, 64, 570, 570]    640    BatchNorm2d-2   [-1, 64, 570, 570]    128      ReLU-3   [-1, 64, 570, 570]    0     Conv2d-4   [-1, 64, 568, 568]   36,928    BatchNorm2d-5   [-1, 64, 568, 568]    128      ReLU-6   [-1, 64, 568, 568]    0    DoubleConv-7   [-1, 64, 568, 568]    0     MaxPool2d-8   [-1, 64, 284, 284]    0     Conv2d-9  [-1, 128, 282, 282]   73,856    BatchNorm2d-10  [-1, 128, 282, 282]    256      ReLU-11  [-1, 128, 282, 282]    0     Conv2d-12  [-1, 128, 280, 280]   147,584    BatchNorm2d-13  [-1, 128, 280, 280]    256      ReLU-14  [-1, 128, 280, 280]    0    DoubleConv-15  [-1, 128, 280, 280]    0    MaxPool2d-16  [-1, 128, 140, 140]    0     Conv2d-17  [-1, 256, 138, 138]   295,168    BatchNorm2d-18  [-1, 256, 138, 138]    512      ReLU-19  [-1, 256, 138, 138]    0     Conv2d-20  [-1, 256, 136, 136]   590,080    BatchNorm2d-21  [-1, 256, 136, 136]    512      ReLU-22  [-1, 256, 136, 136]    0    DoubleConv-23  [-1, 256, 136, 136]    0    MaxPool2d-24   [-1, 256, 68, 68]    0     Conv2d-25   [-1, 512, 66, 66]  1,180,160    BatchNorm2d-26   [-1, 512, 66, 66]   1,024      ReLU-27   [-1, 512, 66, 66]    0     Conv2d-28   [-1, 512, 64, 64]  2,359,808    BatchNorm2d-29   [-1, 512, 64, 64]   1,024      ReLU-30   [-1, 512, 64, 64]    0    DoubleConv-31   [-1, 512, 64, 64]    0    MaxPool2d-32   [-1, 512, 32, 32]    0     Conv2d-33   [-1, 1024, 30, 30]  4,719,616    BatchNorm2d-34   [-1, 1024, 30, 30]   2,048      ReLU-35   [-1, 1024, 30, 30]    0     Conv2d-36   [-1, 1024, 28, 28]  9,438,208    BatchNorm2d-37   [-1, 1024, 28, 28]   2,048      ReLU-38   [-1, 1024, 28, 28]    0    DoubleConv-39   [-1, 1024, 28, 28]    0   ConvTranspose2d-40   [-1, 512, 56, 56]  2,097,664     Conv2d-41   [-1, 512, 54, 54]  4,719,104    BatchNorm2d-42   [-1, 512, 54, 54]   1,024      ReLU-43   [-1, 512, 54, 54]    0     Conv2d-44   [-1, 512, 52, 52]  2,359,808    BatchNorm2d-45   [-1, 512, 52, 52]   1,024      ReLU-46   [-1, 512, 52, 52]    0    DoubleConv-47   [-1, 512, 52, 52]    0   ConvTranspose2d-48  [-1, 256, 104, 104]   524,544     Conv2d-49  [-1, 256, 102, 102]  1,179,904    BatchNorm2d-50  [-1, 256, 102, 102]    512      ReLU-51  [-1, 256, 102, 102]    0     Conv2d-52  [-1, 256, 100, 100]   590,080    BatchNorm2d-53  [-1, 256, 100, 100]    512      ReLU-54  [-1, 256, 100, 100]    0    DoubleConv-55  [-1, 256, 100, 100]    0   ConvTranspose2d-56  [-1, 128, 200, 200]   131,200     Conv2d-57  [-1, 128, 198, 198]   295,040    BatchNorm2d-58  [-1, 128, 198, 198]    256      ReLU-59  [-1, 128, 198, 198]    0     Conv2d-60  [-1, 128, 196, 196]   147,584    BatchNorm2d-61  [-1, 128, 196, 196]    256      ReLU-62  [-1, 128, 196, 196]    0    DoubleConv-63  [-1, 128, 196, 196]    0   ConvTranspose2d-64   [-1, 64, 392, 392]   32,832     Conv2d-65   [-1, 64, 390, 390]   73,792    BatchNorm2d-66   [-1, 64, 390, 390]    128      ReLU-67   [-1, 64, 390, 390]    0     Conv2d-68   [-1, 64, 388, 388]   36,928    BatchNorm2d-69   [-1, 64, 388, 388]    128      ReLU-70   [-1, 64, 388, 388]    0    DoubleConv-71   [-1, 64, 388, 388]    0     Conv2d-72   [-1, 2, 388, 388]    130  ================================================================  Total params: 31,042,434  Trainable params: 31,042,434  Non-trainable params: 0  ----------------------------------------------------------------  Input size (MB): 1.25  Forward/backward pass size (MB): 3280.59  Params size (MB): 118.42  Estimated Total Size (MB): 3400.26  ----------------------------------------------------------------  torch.Size([1, 2, 388, 388])

 

 

   


[e36605 ] 使用pytorch實現論文中的unet網絡已經有491次圍觀

http://coctec.com/docs/python/shhow-post-239773.html