GAN(Generative Adversarial Network)を使ったアルゴリズムの一つに「Pix2Pix」があります。Pix2PixはDCGANとは異なり画像から画像へ変換する仕組みになっており、学習時に入力と出力画像のペアが必要になるモノです。ペア画像を用意するのは面倒ですが、ペアをつくることでCycleGANなどのアルゴリズムに比べて鮮明に画像を生成できます。
今回はWindowsを用いてgitクローンしたPix2Pixの実装をやってみようと思います。
Pix2Pixについて
Pix2Pixは↑で説明した通り画像から画像へ変換するアルゴリズムです。アルゴリズムの中身はGANをベースとしており、通常のCAN
は入力に乱数を用いるのに対し、Pix2Pixは入力に画像データを用いることが異なります。詳細は論文を読んでいただくと良いと思いますが、↓のようなことができます。
セマンティックセグメンテーションモデル画像を実際の画像へ変換したり、白黒写真をカラーにしたり、線画に着色したり、アイデア次第で様々なことができるのがPix2Pixです。今回は、Pix2Pixでは定番の建築物画像の生成をやってみようと思います。
学習環境は↓
- Windows 10
- CPU : AMD Ryzen 7 5800
- GPU : RTX3070 8GB
GPUがないと学習時間がかなりかかります。その場合はgoogle colaboratoryの使用を検討してください。今回はローカル環境での実装方法について説明していきます。
Pix2Pixを実装していく
前準備
まず環境設定ですが↓の組み合わせで進めていきます。CUDA、cuDNNのインストールはクセがあります。このサイトに詳しく掲載されている通りにインストールしています。インストールしたバージョンは下記です。
- NVIDIA CUDA Toolkit 11.4
- NVIDIA cuDNN v8.2.0
コードは↓サイトのモノを使わせていただきました。Gitクローンもしくはダウンロードして準備してください。この辺りの操作はここのソースコードの準備に記載してあります。
TensorFlow1.x系(旧バージョン)を使ったモノが一般的とおもいますが、RTX3000系のグラフィックボードとCUDA10.0の相性が悪く、起動に時間がかかるため、今回はPyTorchで実装していきます。
学習データセット
今回は、建物の実際の写真と、窓やドアなどをセグメンテーションした画像のデータセットを学習していきます。画像をダウンロードし、学習可能な状態へ変換していきます。具体的には↓のように入力画像と出力画像を横に結合した状態の画像を準備します。
ここで簡単にデータセットを準備するためにLinux環境で下記のコードを入力すると一発で変換してくれます。Windows10でLinuxを動かす場合はwsl2(Windows Subsystem for Linux 2)をインストールし、Ubuntu上で実行することができます。ここを参考に設定することができます。※Linux環境を使うのはここだけです。Linux環境を使わなくてもできます。詳細は下記。
Linuxの準備が整いましたらpytorch-CycleGAN-and-pix2pixディレクトリまで移動し↓のコードを実行してください。
cd desktop/pytorch-CycleGAN-and-pix2pix #デスクトップへpytorch-CycleGAN-and-pix2pixをダウンロードした場合
bash ./datasets/download_pix2pix_dataset.sh facades
実行するとpytorch-CycleGAN-and-pix2pix中のdatasetsディレクトリにfacadesディレクトリが生成されると思います。facadesディレクトリの中にtest, train, valディレクトリが生成され、その中に画像データ(.jpg)が格納されていることを確認してください。ディレクトリ構成は↓
/datasets
-> facades
->test
->train
1.jpg
2.jpg
3.jpg
4.jpg
....
->val
- train : 400画像
- test : 106画像
- val : 100画像
※Linux環境を用いない場合でも入力と出力画像を並べて結合した画像を準備すればOKです。ここからデータをダウンロードし、OpenCVなどを使って画像を準備してください。画像は↑のディレクトリを作成し保存してください。
学習実行
AnacondaPromptやWindowsPowerShellで学習を実行していきます。Linuxでも学習できますが、Linux環境上でGPUを使う設定が難しいのでここからはWindowsを使います。
AnacondaPromptへ必要なライブラリをインストールしていきます。
pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install dominate
pip install visdom
PyTorchはv1.8を使っていきます。CUDAのバージョンが違う場合や、CPUで実行(非推奨)する場合は、ここから必要なPyTorchのインストールコマンドを取得してください。
↓のコードを実行すると学習が始まりますので待ちです。
python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
私の環境で3時間弱かかりました。。。最終epochのLossはこんな感じです。
(epoch: 200, iters: 400, time: 8.547, data: 0.000) G_GAN: 2.715 G_L1: 14.973 D_real: 0.037 D_fake: 0.109
Pix2Pix学習結果
結果を見ていきましょう。./checkpoints/facades_pix2pix/web/index.htmlをブラウザで開くことで各epochでの出力を確認することができます。
epochが進むにつれ画像が鮮明になっていくのが確認できると思います。epoch200でかなり鮮明な画像が得られることが確認できると思います。
モデルの重みが得られたので学習に未使用のデータを用い、モデルのテストをしてみます。コードは↓です。
python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
良い感じに学習できていると思います。Pix2Pixは少数のデータでも本当に鮮明な画像が生成できるので面白いです。
調子に乗って↓の画像もテストしてみました。
結果はこちら↓
顔怖いです😅なぜか鼻が生成されているところが更に怖いです。ロゴはレンガ調の壁画っぽい感じが見られますが、やはり学習で似たようなデータが無い場合、うまく表現するのは難しいようです。
Pix2Pixは今回紹介したセグメンテーション画像だけでなく、色々な画像から目的の画像を生成できる技術なので本当に面白いです。学習に大変時間がかかるのがネックですが、お時間あれば是非挑戦してみてください👍
コメント
機械学習というものに今回初めて触れたのですが、この記事にしたがってpix2pixを動かしたら、facadesのデータセットで無事に学習ができました!
(私のところにはGPU環境がなかったので、Google Colaboratoryを使いました)
わかりやすい記事をありがとうございます!