pytoch SENet実現例

1624 ワード

余計なことを言わないで、コードを見てください。

from torch import nn

class SELayer(nn.Module):
 def __init__(self, channel, reduction=16):
  super(SELayer, self).__init__()

  //  1X1      ,     
  self.avg_pool = nn.AdaptiveAvgPool2d(1)
  self.fc = nn.Sequential(
   nn.Linear(channel, channel // reduction, bias=False),
   nn.ReLU(inplace=True),
   nn.Linear(channel // reduction, channel, bias=False),
   nn.Sigmoid()
  )

 def forward(self, x):
  b, c, _, _ = x.size()

  //      ,batch channel         
  y = self.avg_pool(x).view(b, c)

  //    +  
  y = self.fc(y).view(b, c, 1, 1)

  //       
  return x * y.expand_as(x)
補足知識:pytouchはSE Blockを実現します。
論文のモジュール図

コード

import torch.nn as nn
class SE_Block(nn.Module):
 def __init__(self, ch_in, reduction=16):
  super(SE_Block, self).__init__()
  self.avg_pool = nn.AdaptiveAvgPool2d(1)				#        
  self.fc = nn.Sequential(
   nn.Linear(ch_in, ch_in // reduction, bias=False),
   nn.ReLU(inplace=True),
   nn.Linear(ch_in // reduction, ch_in, bias=False),
   nn.Sigmoid()
  )

 def forward(self, x):
  b, c, _, _ = x.size()
  y = self.avg_pool(x).view(b, c)
  y = self.fc(y).view(b, c, 1, 1)
  return x * y.expand_as(x)
今はSEの変形についてたくさんありますが、大体大同小異です。
以上のpytouch SENetの実現例は小編集が皆さんに提供した内容の全部です。参考にしてもらいたいです。どうぞよろしくお願いします。