ニューラルネットワークの話を前の記事で話ました。皆さん上司なり、先生なり、取引先なりにディープラーニング(ニューラルネットワーク)の精度を当然求められると思います。じつはニューラルネットワークにとって、“学習に使ったデータ”に対する精度を出すことは難しいことではありません。問題は未知のデータに対して同じような精度が出せるか?という話です。これは機械学習の世界では“汎化性能”と言われ、汎化性能を持つモデルをつくるためには交差検証(クロスバリデーション)が重要になってきます。
いろいろと用語がでてきましたが、もう少し用語がでます。。。
ニューラルネットワークの汎化性能を落としてしまう主な原因として”過学習“があげられます。クロスバリデーションはこの過学習を検出するのに役立つ技術ということになります。っということで、クロスバリデーションを説明する前に、過学習を説明していきます。
過学習とは
過学習を会社で説明するとき、わかりやすいと好評な例があります。「テスト」をイメージしてください。皆さん経験あると思いますが、テスト前の対策として「過去問」を解きますよね。過去問しか解かず山を張るなんて人も少なくないかと思います。過去問しか解かない状態が、いわゆる過学習の状態です。もう少し細かく説明すると、過去問の「問題」と「答え」の組み合わせしか覚えない状態のことです。
「問題」と「答え」の組み合わせしか覚えていない状態で試験に挑んだ経験ありますか?ちなみに私はあります。この状態で、いつも過去問から出題する先生が突然血迷って新規問題を出して来たらどうでしょうか。。。青ざめますよね。これこそが過学習です。過学習に陥ったニューラルネットワークモデルは新規問題で的外れな回答をする危険性を秘めています。
クロスバリデーションとは【過学習を回避せよ】
クロスバリデーションのやり方は、名前の割には難しくありません。ニューラルネットワークなどの学習前に、手持ちのデータをある一定の比率で分割しておくことで準備OKです。自分で分割しておいてもOKですが、scikit-learnライブラリをインストールしておけば一発で分けることができます。この辺のやり方は別の機会に紹介しようと思います。
データは3つに分解します。1つを訓練データ、もう1つのデータを検証データ、もう1つのデータをテストデータと言います。注)書籍によって言い回しは様々ですが、ここではこのような言い方をします。ニューラルネットワークは繰り返し学習を実施していくのですが、その過程で訓練データで学習したモデルを検証データで評価し、最終的なモデルをテストデータで評価する。これがクロスバリデーションの流れです。
ちょっと具体的に説明します。今とある学習をモデルをつくった際の学習曲線※を下図に示します。※↓のような縦軸Loss、横軸Epochsで表したグラフのこと。
訓練データ(実線)を見るとEpochsが進むにつれてLossが下がっていることがわかります。一方で、検証データ(破線)を見るとEpochsが進むにつれ最初はLossが低下しますが、ある時を境に増加へ転じており、最終的に訓練データと大きく乖離していることが分かります。最終的にこのモデルは訓練データでは高精度であるが、検証データでは精度がでていない状態:過学習の状態となっていることが、分かります。クロスバリデーション評価によって、このようなことが分かります。分割したテストデータが出てきませんが、次で説明します。
ちなみに過学習が起こってしまったときの対処法としては、
- データ数を増やす!👈まずはこれ
- 学習曲線で検証データのLossが増加する前に学習を切り上げる
などがあります。
ホールドアウト検証、k-分割交差検証
クロスバリデーションはホールドアウト検証と、k-分割交差検証の2つの方法があるので説明していきます。これらはイメージで開設するのが一番なので、↓を見てください。
ホールドアウト検証は学習前に予め訓練データ、検証データ、テストデータに分割しておき、そのまま固定しておく方法です。シンプルで分かりやすい反面、特にデータ数が少ない場合はデータ分割の偏りがでる可能性があります。そんな時はk-分割交差検証を実施します。k-分割交差検証は訓練データ、検証データの分割を複数回実施し、それぞれで学習・評価を実施する方法です。これにより分割の仕方による偏りの影響を軽減することが可能です。
最後にテストデータは学習に一切用いらないデータです。これを用いてモデルの最終的な評価を実施していきます。テストデータを用いた最終評価が一番と思いますが、実際にはテストデータを用いない論文(検証データで評価している)も多くあります。テストデータまで分割するか、あるいは検証データで評価するかはケースバイケースというのが実態です。特にデータが少ない場合はテストデータを省いた方が良いケースもあると思います。
まとめ
- クロスバリデーションとは、未知のデータに対してどのくらい精度が出せるか評価する手法であり、言い換えると過学習を検出する手法
- 過学習とは訓練データでは高精度であるが、検証データでは精度がでていない状態のこと
- クロスバリデーションは予めデータを訓練データ、検証データ、テストデータへ分割しておき、訓練データと検証データでモデルをつくっていき、テストデータで評価する手法
- クロスバリデーションはホールドアウト検証と、k-分割交差検証の2つの方法がある
コメント