Albumentationsで拡張した独自画像データを使ってみる【PyTorch】

AI実装

Data Augmentation(画像データの水増し)※は画像認識系のディープラーニング学習で必須の技術となっています。今回はData Augmentation用のライブラリであるAlbumentationsについてPyTorchでの使い方を説明します。

※Data Augmentationは画像を拡大・縮小、回転したり、明るさ・コントラスト変えたり、画像にバリエーションを持たせディープラーニングにおける精度を向上させたり、過学習を抑制する技術です。

画像タスクのディープラーニングは日々新しい手法が生み出されていますが、新しい手法は精度が高い反面、過学習に注意が必要です。今回は過去に実装したViT↓へ適用してみようと思います。

Albumentationsについて

Albumentationsは多様なData Augmentationを手軽に実施するライブラリです。実装は直感でできるような仕様になっています。よくあるData Augmentationの方法以外にも面白い方法があるのでオススメです。

GitHub

GitHub - albumentations-team/albumentations: Fast image augmentation library and an easy-to-use wrapper around other libraries. Documentation: https://albumentations.ai/docs/ Paper about the library: https://www.mdpi.com/2078-2489/11/2/125
Fast image augmentation library and an easy-to-use wrapper around other libraries. Documentation: Paper about the library: -...
出典:https://github.com/albumentations-team/albumentations

↑図はAlbumentationsによる画像拡張一例ですが、カメラのボケを意識した「Blur」であったり、画像のjpeg圧縮を意識した「jpegCompression」など面白いモノがたくさんあります。なんと70種類を超える拡張機能があるようです。

特筆すべきは画像分類以外にも、セマンティックセグメンテーション、インスタンスセグメンテーション、物体検出、姿勢推定にもData Augmentationできることです。これらはData Augmentationすると正解データのマスク画像だったり、バウンディングボックスの位置も修正する必要があるのですが、それをやってくれるのは便利です↓

出典:https://github.com/albumentations-team/albumentations

今回は分類タスクなので試しませんが、機会があればセマンティックセグメンテーションでも試してみようと思います。

Albumentationsを試してみる

以下のコードはAnacondaのJupyterNotebookを想定して記載します。必要あればインストールしてください。Anacondaって!?という方はここを参考にしてください。

まずはAlbumentationsライブラリをインストールしていきます。↓のコードをJupyterNotebookで実行してください。

! pip install -U albumentations

Albumentationsを試してみましょう!まずは必要なライブラリをインポートします。

from pathlib import Path
import glob
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import copy
import albumentations as A

次に画像データを読み込んでいきます。データは過去の記事と同じモノをつかっていきます。

train_dataset_dir = Path('./data/train')
val_dataset_dir = Path('./data/val')

files = glob.glob('./data/*/*/*.jpg')
random_idx = np.random.randint(1, len(files), size=9)
fig, axes = plt.subplots(3, 3, figsize=(8, 6))

for idx, ax in enumerate(axes.ravel()):
    img = Image.open(files[idx])
    ax.imshow(img)
    ax.axis("off")

AlbumentationsのData Augmentationを試してみます。このサイトで機能説明のコードが公開されていましたので、使わせていただきました。

test_image = cv2.imread(files[0])
test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)

def compose_augmentation():
    transform = [
        # リサイズ
        A.Resize(224,224),
        # ぼかし
        A.Blur(blur_limit=15, p=1.0),
        # 明るさ、コントラスト
        A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, brightness_by_max=True, p=1.0),
        # 回転
        A.RandomRotate90(p=0.5),
        #Random Erasing
        A.CoarseDropout(max_holes=4, max_height=100, max_width=100, min_holes=1, min_height=50, min_width=50, fill_value=0, p=1.0)
    ]
    return A.Compose(transform)

augment_function = compose_augmentation()
test_image = cv2.imread(files[102])
test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)
augmented_result = augment_function(image=test_image)
augmented_image = augmented_result['image']
debug_image = copy.deepcopy(augmented_image)
debug_image = np.ascontiguousarray(debug_image)
plt.imshow(debug_image)
plt.axis("off")

transform = 以下を変えてあげることで、いろいろなData Augmentationを試すことができます。

こんな感じでいろいろ試せると思います。画像を確認してData Augmentationを実装していくと良いかと思います👍

Albumentationsを使うといろいろなData Augmentationができるので、何を使ったらよいか迷ってしまいますが、私の場合は「現実にありそうな画像をつくり出していく」というポリシーで拡張しています。実際にありえない画像を使って学習しても気持ち悪いですよね😨

Albumentationsを使ったViT学習の実装

Albumentationsを実装していきます。過去の記事で紹介したViTを用いてPyTorchで実装していきます。

学習環境はここを参照してください。Windows11、ローカルGPU環境で実装していきます。

↓にViTの全実装コードを記載します。ImageFolderを使ってデータセットをつくっているのでデータの配置にご注意ください。

import glob
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from pathlib import Path
import timm
import albumentations as A
import seaborn as sns
device = 'cuda'

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)
device = 'cuda'

データを読み込んでいきますが、配置は↓の感じでお願いします。

#Load data
train_dataset_dir = Path('./data/train')
val_dataset_dir = Path('./data/val')

↓でAlbumentationsを実装し、ImageFolderでデータセットをつくっていきます。Data Augmentationの種類はここで設定してください。

#Albumentations Augumentation

A_transforms = A.Compose([
    A.Resize(224,224),
    A.RandomRotate90(p=0.2),
    A.RandomGamma(gamma_limit=(85, 150), p=0.2),
    A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, brightness_by_max=True, p=0.2),
    A.CoarseDropout(max_holes=4, max_height=100, max_width=100, min_holes=1, min_height=50, min_width=50, fill_value=0, p=0.2)
])

def albumentations_transform(image, transform=A_transforms):
    if transform:
        image_np = np.array(image)
        augmented = transform(image=image_np)
        image = Image.fromarray(augmented['image'])
        return image

data_transforms = transforms.Compose([
    transforms.Lambda(albumentations_transform),
    transforms.ToTensor()
])

A_transforms_val = A.Compose([
    A.Resize(224,224)
])

def albumentations_transform_val(image, transform=A_transforms_val):
    if transform:
        image_np = np.array(image)
        augmented = transform(image=image_np)
        image = Image.fromarray(augmented['image'])
        return image

val_transforms = transforms.Compose([
    transforms.Lambda(albumentations_transform_val),
    transforms.ToTensor()
])

train_data = datasets.ImageFolder(root=train_dataset_dir, transform=data_transforms)
valid_data = datasets.ImageFolder(root=val_dataset_dir, transform=val_transforms)
#Load Datasets
train_loader = DataLoader(dataset = train_data, batch_size=16, shuffle=True)
valid_loader = DataLoader(dataset = valid_data, batch_size=16, shuffle=True)
#modeling
model = timm.create_model('vit_base_patch16_224_miil_in21k', pretrained=True, num_classes=3)
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
#Training
train_acc_list = []
val_acc_list = []
train_loss_list = []
val_loss_list = []

model.to("cuda:0")

for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)              

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )
    train_acc_list.append(epoch_accuracy)
    val_acc_list.append(epoch_val_accuracy)
    train_loss_list.append(epoch_loss)
    val_loss_list.append(epoch_val_loss)

こんな感じで学習が進むと思います。

#学習結果の可視化
device2 = torch.device('cpu')

train_acc = []
train_loss = []
val_acc = []
val_loss = []

for i in range(epochs):
    train_acc2 = train_acc_list[i].to(device2)
    train_acc3 = train_acc2.clone().numpy()
    train_acc.append(train_acc3)
    
    train_loss2 = train_loss_list[i].to(device2)
    train_loss3 = train_loss2.clone().detach().numpy()
    train_loss.append(train_loss3)
    
    val_acc2 = val_acc_list[i].to(device2)
    val_acc3 = val_acc2.clone().numpy()
    val_acc.append(val_acc3)
    
    val_loss2 = val_loss_list[i].to(device2)
    val_loss3 = val_loss2.clone().numpy()
    val_loss.append(val_loss3)  

sns.set()
num_epochs=50

fig = plt.subplots(figsize=(12, 4), dpi=80)

ax1 = plt.subplot(1,2,1)
ax1.plot(range(num_epochs), train_acc, c='b', label='train acc')
ax1.plot(range(num_epochs), val_acc, c='r', label='val acc')
ax1.set_xlabel('epoch', fontsize='12')
ax1.set_ylabel('accuracy', fontsize='12')
ax1.set_title('training and val acc', fontsize='14')
ax1.legend(fontsize='12')

ax2 = plt.subplot(1,2,2)
ax2.plot(range(num_epochs), train_loss, c='b', label='train loss')
ax2.plot(range(num_epochs), val_loss, c='r', label='val loss')
ax2.set_xlabel('epoch', fontsize='12')
ax2.set_ylabel('loss', fontsize='12')
ax2.set_title('training and val loss', fontsize='14')
ax2.legend(fontsize='12')
plt.show()

こんな感じでAccuracyとLossの学習曲線が出力されると思います。この結果だとまだ過学習気味なので違ったData Augmentationを試してみるというのも手だと思います。ですが、val_lossが下がっていく気配がないので、Data Augmentationだけに頼らず学習データがうまく分類できているか?も同時に見ていった方が良さそうですね😅

Albumentationsを使うとData Augmentationの設定が割と簡単に実装できるので、是非挑戦してみてください👍

今日も良いディープラーニングライフを👍

参考資料

Albumentationsのaugmentationをひたすら動かす - Qiita
Albumentationsとは機械学習用データ拡張用PythonライブラリData …
GitHub - Kazuhito00/albumentations-examples: 画像データ拡張ライブラリAlbumentationsのJupyter上での実行例。
画像データ拡張ライブラリAlbumentationsのJupyter上での実行例。. Contribute to Kazuhito00/albumentations-examples development by creating an account on GitHub.

コメント

タイトルとURLをコピーしました