ResNet18の構造をやんわりと理解する
最近ResNet18を使用する機会があり、かつ構造を理解することを迫られたので、色々調べていると以下の表にたどり着きました。
- Deep Residual Learning for Image Recognitionより引用
なるほどなるほどこうなっているのね...
入力がこれで出力がこれで...???いっちょんわからん...
7×7や3×3の隣にある数字は何?計算してもoutput sizeにならない...そもそも各層がどのようにつながっているのかわからない...
と多くの疑問を持ったので色々調べたことからこの表を読み解いていこうと思います。
目次
そもそもResNetとは?
ResNetがCNNの一つであるというのはconvやらpoolやらが前出の表に出てきていることからもお分かりかと思います。まずCNNをよくわかっていないという方はこちら の記事がわかりやすかったので読むことをお勧めします。
本来CNNは理論的には層が深いほどより高次の特徴を捉えることができ精度が高まるとされています。しかし層を深くすると学習が進まないという問題 (学習が進まないといえば勾配消失問題が思い浮かびますが、これはdegradation problemらしい) があり、それを解決したのが残差(Residual)を足し合わせる処理があるResidual Network通称ResNetです。具体的には
input→
↓
畳み込み層
↓
畳み込み層
↓
output + input ←
のように畳み込み層を通った値と通る前の値を足し合わせます。この外側のinputの流れをshortcutと呼び、表の
のようなブロックごとにshortcutが作られています。
表を読み解く
表をよく見ると基本的にはデータの流れが分かりにくいために構造が理解できなかったと思います。
そのため表を入出力のデータサイズに着目して上から読み解いていこうと思いますが、
- relu,batchnormなどの直接データサイズに関係ないものについては言及しない
- shortcutについても触れない
ので留意してください。
なおResNetはImageNetを前提にしているので入力サイズは224*224です。(余談ですがpytorchではnn.AdaptiveAvgPool2d((1, 1))を全結合層の前に入れることによって画像のサイズによらず学習ができるようになっているらしいです。)
conv1
224*224の入力が入ると7*7のカーネル(フィルタ)、stride 2で畳み込まれます。この時入力値を上下左右3ずつpaddingするので230*230のデータを畳み込んでると考えて問題ないです。すると1~7マス目,3~9マス目...223~229マス目が畳み込まれ結果として112*112の出力が得られます。この操作を64(channel)枚のカーネルで行うため結果112*112*64のサイズの出力が得られます。
簡潔にサイズの遷移をまとめると
224*224*1
↓
(230*230*1)
↓
112*112*64
となりました。
conv2_x
この層では最初にpooling層があります。3*3のカーネルサイズ、stride 2,padding 1なので実質114*114から縦横1~3マス目,3~9マス目...111~113マス目の範囲の最大値が抽出され結果56*56*64の出力が得られます。
次に、畳み込み層があり3*3のカーネルサイズ、stride 1,padding 1で畳み込まれます。今までのように実質58*58から縦横1~3マス目...55~57マス目の範囲が畳み込まれ、結局56*56*64の出力が得られます。
このように3*3のカーネルサイズ、stride 1,padding 1で畳み込むと入力サイズと同様な出力が得られるので今後説明は割愛します。そして、ここではこのような層が後三層続きます。
この層でも同じようにサイズの遷移をまとめると
112*112*64
↓
(114*114*64)
↓
56*56*64
↓
...
↓
56*56*64
となりました。
conv3_x
以下畳み込み層が続きますが、表の説明をよく読むと各層の先頭の畳み込み層だけstrideが2であることがわかります。
- Deep Residual Learning for Image Recognitionより引用
また、64から128に横の数字が増えていますが、これがカーネルの枚数だと辻褄が合わないことから出力のチャネル数だということが推測でき、pytorchの実装コードを読むことで確信に変わると思います。
入力のチャネル数と出力のチャネル数が変わってもいいの??
とイメージがわかない方はこちらを読むことでイメージの助けになると思います。
この層の先頭畳み込み層では3*3のカーネルサイズ、stride 2,padding 1出力チャネル128なので実質58*58から縦横1~3マス目,3~9マス目...55~57マス目の範囲の最大値が抽出され結果28*26*128の出力が得られます。
以下三つの畳み込み層はconv2_xの最後三層と出力チャネル以外同じなので省略してサイズの遷移をまとめると
56*56*64
↓
(58*58*64)
↓
28*28*128
↓
...
↓
28*28*128
conv4_x
conv2_xと同じなので省略してまとめると
28*28*128
↓
...
14*14*256
conv5_x
同上
14*14*256
↓
...
7*7*512
avg pool及びcf(全結合層)
avg pool(Average pooling)層で7*7のそれぞれの要素の平均を取って1*1*512にしています。また、全結合層で1*1*1000にしています。1000はImageNetのクラス数です。
まとめると
7*7*512
↓
1*1*512
↓
1*1*1000
よってResNet18のすべての工程が読み解けました!!
終わりに
もともとは表のすべてを解説しようと思っていたのですが面d...時間がないので辞めました。ResNet34はResNet18とほぼ同じ構成、ResNet50以上はソースコードのBottleneckクラスあたりを読めば解読できると思います。
参考文献
Author And Source
この問題について(ResNet18の構造をやんわりと理解する), 我々は、より多くの情報をここで見つけました https://qiita.com/teacat/items/d6b24fb5353872f6b3a3著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .