CNNモデル”ResNet”の解説

AI基礎

旬ではありませんが、畳み込みニューラルネットワーク(CNN)モデルのResNetについて説明をしていこうと思います。経緯はPCを買い替えたんですが、GPU環境を設定がてらResNetモデルコードを書いたので備忘録として書き残しておこうと思います。プログラミング素人の方でも、各コードで何が動いているか?できるだけ細かく書いていこうと思います。

  • AI・機械学習を勉強したいけど、何からやればよいか分からない
  • 書籍が難しすぎて挫折した
  • これからデータサイエンティストを目指したい
  • CNNの中の動きを知りたい

今回はモデルの説明になります。ResNetを速攻でやってみたい人は↓を参照ください。

ResNetモデル論文の解釈

ResNetは言わずもがな超有名なCNNを代表するモデルの一つですが、あえて説明させてください。

  • 2015年のILSVRCの優勝モデル
  • Skip Connection構造を用い100層を超えるモデルでの高精度化を実現した
  • ResNetをベースとした派生型が現在でも多く研究されており高性能・高速で話題

まとめるとこんな感じですが、なんといってもSkip Connection構造(Residualモジュール)が特徴的です。ResNetの論文Deep Residual Learning for Image Recognitionを基に中身をみていきましょう。

CNNの層を深くすると表現力が向上し、精度が上がるといわれていましたが、単純に層を増やしただけではある一定の層の深さを超えると、それ以降悪化していくという実験結果があります。さらにこれは勾配消失や、過学習の影響ではなくdegradation問題とされています。↓のグラフを参照してください。

出典:Deep Residual Learning for Image Recognition, Kaiming Heら

ResNetの論文↑図より、20層モデルより56層モデルの方が誤差が大きいことが分かります。層を深くすると発生するこのようなdegradation問題を解決するためにSkip Connection構造(論文中ではshortcut connections)を使います。Skip Connection構造を簡単に表すと↓のようになります。

出典:Deep Residual Learning for Image Recognition, Kaiming Heら

図はSkip Connection構造を伴う1つのブロック(Residual Block)をしめしており、weight layer(畳み込み層)の出力値に、2層前の出力を足し合わせるという構造を取っています。この効果としては、より浅い層の出力を取り入れることで、浅い層より誤差が大きくならないように学習が進むという理論です。このようなResidual Blockを積み重ねていくモデルがResNetです。

ResNetコードの解説

ResNetを実装する場合、特にモデルの層数などをいじる必要がなければTensorFlowなどにあらかじめ学習されているモデルを使用する方法が簡単で良いですが、それでは中身を理解できないので、あえてコードを書いていきます。今回はTensorFlowでResNetモデルをつくっていきます。ここのkoshianさんのコード書き方が分かりやすかったので参考にさせていただきました。コメントアウトで説明付けますので参考にしてください。CNNの基礎用語など分からない方、不安な方は別の記事で説明していますので、そちらも見ながらだと理解しやすいと思います。

コードの構成は、def residual_blockで畳み込み2層+Skip Connectionを定義しており、def create_resnetでresidual_blockを積み重ねて全体のモデルを構築していく構成となっています。

#とりあえずTensorFlowの呼出し
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.layers as layers


#Residual Blockの定義
def residual_block(inputs, ch, strides):
    # main path
    #(Batch Normalization ⇒ ReLU ⇒ 畳み込み)×2セット
    x = layers.BatchNormalization()(inputs) 
    x = layers.ReLU()(x)
    x = layers.Conv2D(ch, 3, strides=strides, padding="same", 
                      kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x) # Conv2Dに1e-4のL2正則化を入れます(論文より)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(ch, 3, padding="same", kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x)
    
    # shortcut path
    # inputデータ(2層畳み込む前のデータ)を呼び出す
    if inputs.shape[-1] != ch or strides > 1:
        s = layers.Conv2D(ch, 3, strides=strides, padding="same", 
                          kernel_regularizer=tf.keras.regularizers.l2(1e-4))(inputs)
  else:
        s = inputs

    # add
    # inputデータと2層畳み込んだデータを足す
    x = layers.Add()([x, s])
    return x


#ResNetモデルの作成
def create_resnet():
    # inputは(縦pixel, 横pixel, チャンネル数)なので入力画像へあった形へ調整
    inputs = layers.Input((32, 32, 3))
    # ResNet論文に合わせ畳み込みカーネル(フィルタ)数を16, 32, 64に設定。最初の層で16に設定
    x = layers.Conv2D(16, 3, padding="same")(inputs)
    # フィルタ数を増やしながらモデルを作成する。リストの中身がフィルタ数。
    for ch in [16, 32, 64]:
    # SkipConnectionを何回入れるか設定。今回は7回に設定。
        for i in range(7):
            # stridesの設定
            strides = 2 if i == 0 else 1
            if ch == 16:
                strides = 1
            # 上で定義したresidual blockの呼出し
            x = residual_block(x, ch, strides)
    # 最後にBatch Normalization ⇒ ReLU ⇒ プーリング層 ⇒ 全結合層を設定しモデル完成
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # pooling層を追加。
    x = layers.GlobalAveragePooling2D()(x)
    # 全結合層を追加。活性化関数は分類問題の場合softmax
    x = layers.Dense(10, activation="softmax")(x)
    
    return tf.keras.models.Model(inputs, x)

モデル作成の定義は以上です。↓のコードを入れるとモデルを確認することができます。

model = create_resnet()
model.summary()

↑コードの結果はかなり長いので割愛します。

まとめ

  • ResNetはSkip Connection構造を用い100層を超えるモデルでの高精度化を実現した
  • Skip Connectionは入力データを2層畳み込んだ出力へ、畳み込む前の出力値を加える動きをしている
  • Skip Connection構造をとることで浅い層より誤差が大きくならないように学習が進む
  • Skip Connection構造を重ね合わせたモデルがResNet

コメント

タイトルとURLをコピーしました