Data Augmentation(画像データの水増し)※は画像認識系のディープラーニング学習で必須の技術となっています。今回はData Augmentation用のライブラリであるAlbumentationsについてPyTorchでの使い方を説明します。
※Data Augmentationは画像を拡大・縮小、回転したり、明るさ・コントラスト変えたり、画像にバリエーションを持たせディープラーニングにおける精度を向上させたり、過学習を抑制する技術です。
画像タスクのディープラーニングは日々新しい手法が生み出されていますが、新しい手法は精度が高い反面、過学習に注意が必要です。今回は過去に実装したViT↓へ適用してみようと思います。
Albumentationsについて
Albumentationsは多様なData Augmentationを手軽に実施するライブラリです。実装は直感でできるような仕様になっています。よくあるData Augmentationの方法以外にも面白い方法があるのでオススメです。
GitHub
↑図はAlbumentationsによる画像拡張一例ですが、カメラのボケを意識した「Blur」であったり、画像のjpeg圧縮を意識した「jpegCompression」など面白いモノがたくさんあります。なんと70種類を超える拡張機能があるようです。
特筆すべきは画像分類以外にも、セマンティックセグメンテーション、インスタンスセグメンテーション、物体検出、姿勢推定にもData Augmentationできることです。これらはData Augmentationすると正解データのマスク画像だったり、バウンディングボックスの位置も修正する必要があるのですが、それをやってくれるのは便利です↓
今回は分類タスクなので試しませんが、機会があればセマンティックセグメンテーションでも試してみようと思います。
Albumentationsを試してみる
以下のコードはAnacondaのJupyterNotebookを想定して記載します。必要あればインストールしてください。Anacondaって!?という方はここを参考にしてください。
まずはAlbumentationsライブラリをインストールしていきます。↓のコードをJupyterNotebookで実行してください。
! pip install -U albumentations
Albumentationsを試してみましょう!まずは必要なライブラリをインポートします。
from pathlib import Path
import glob
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import copy
import albumentations as A
次に画像データを読み込んでいきます。データは過去の記事と同じモノをつかっていきます。
train_dataset_dir = Path('./data/train')
val_dataset_dir = Path('./data/val')
files = glob.glob('./data/*/*/*.jpg')
random_idx = np.random.randint(1, len(files), size=9)
fig, axes = plt.subplots(3, 3, figsize=(8, 6))
for idx, ax in enumerate(axes.ravel()):
img = Image.open(files[idx])
ax.imshow(img)
ax.axis("off")
AlbumentationsのData Augmentationを試してみます。このサイトで機能説明のコードが公開されていましたので、使わせていただきました。
test_image = cv2.imread(files[0])
test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)
def compose_augmentation():
transform = [
# リサイズ
A.Resize(224,224),
# ぼかし
A.Blur(blur_limit=15, p=1.0),
# 明るさ、コントラスト
A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, brightness_by_max=True, p=1.0),
# 回転
A.RandomRotate90(p=0.5),
#Random Erasing
A.CoarseDropout(max_holes=4, max_height=100, max_width=100, min_holes=1, min_height=50, min_width=50, fill_value=0, p=1.0)
]
return A.Compose(transform)
augment_function = compose_augmentation()
test_image = cv2.imread(files[102])
test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)
augmented_result = augment_function(image=test_image)
augmented_image = augmented_result['image']
debug_image = copy.deepcopy(augmented_image)
debug_image = np.ascontiguousarray(debug_image)
plt.imshow(debug_image)
plt.axis("off")
transform = 以下を変えてあげることで、いろいろなData Augmentationを試すことができます。
こんな感じでいろいろ試せると思います。画像を確認してData Augmentationを実装していくと良いかと思います👍
Albumentationsを使うといろいろなData Augmentationができるので、何を使ったらよいか迷ってしまいますが、私の場合は「現実にありそうな画像をつくり出していく」というポリシーで拡張しています。実際にありえない画像を使って学習しても気持ち悪いですよね😨
Albumentationsを使ったViT学習の実装
Albumentationsを実装していきます。過去の記事で紹介したViTを用いてPyTorchで実装していきます。
学習環境はここを参照してください。Windows11、ローカルGPU環境で実装していきます。
↓にViTの全実装コードを記載します。ImageFolderを使ってデータセットをつくっているのでデータの配置にご注意ください。
import glob
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from pathlib import Path
import timm
import albumentations as A
import seaborn as sns
device = 'cuda'
def seed_everything(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
seed_everything(seed)
device = 'cuda'
データを読み込んでいきますが、配置は↓の感じでお願いします。
#Load data
train_dataset_dir = Path('./data/train')
val_dataset_dir = Path('./data/val')
↓でAlbumentationsを実装し、ImageFolderでデータセットをつくっていきます。Data Augmentationの種類はここで設定してください。
#Albumentations Augumentation
A_transforms = A.Compose([
A.Resize(224,224),
A.RandomRotate90(p=0.2),
A.RandomGamma(gamma_limit=(85, 150), p=0.2),
A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, brightness_by_max=True, p=0.2),
A.CoarseDropout(max_holes=4, max_height=100, max_width=100, min_holes=1, min_height=50, min_width=50, fill_value=0, p=0.2)
])
def albumentations_transform(image, transform=A_transforms):
if transform:
image_np = np.array(image)
augmented = transform(image=image_np)
image = Image.fromarray(augmented['image'])
return image
data_transforms = transforms.Compose([
transforms.Lambda(albumentations_transform),
transforms.ToTensor()
])
A_transforms_val = A.Compose([
A.Resize(224,224)
])
def albumentations_transform_val(image, transform=A_transforms_val):
if transform:
image_np = np.array(image)
augmented = transform(image=image_np)
image = Image.fromarray(augmented['image'])
return image
val_transforms = transforms.Compose([
transforms.Lambda(albumentations_transform_val),
transforms.ToTensor()
])
train_data = datasets.ImageFolder(root=train_dataset_dir, transform=data_transforms)
valid_data = datasets.ImageFolder(root=val_dataset_dir, transform=val_transforms)
#Load Datasets
train_loader = DataLoader(dataset = train_data, batch_size=16, shuffle=True)
valid_loader = DataLoader(dataset = valid_data, batch_size=16, shuffle=True)
#modeling
model = timm.create_model('vit_base_patch16_224_miil_in21k', pretrained=True, num_classes=3)
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
#Training
train_acc_list = []
val_acc_list = []
train_loss_list = []
val_loss_list = []
model.to("cuda:0")
for epoch in range(epochs):
epoch_loss = 0
epoch_accuracy = 0
for data, label in tqdm(train_loader):
data = data.to(device)
label = label.to(device)
output = model(data)
loss = criterion(output, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = (output.argmax(dim=1) == label).float().mean()
epoch_accuracy += acc / len(train_loader)
epoch_loss += loss / len(train_loader)
with torch.no_grad():
epoch_val_accuracy = 0
epoch_val_loss = 0
for data, label in valid_loader:
data = data.to(device)
label = label.to(device)
val_output = model(data)
val_loss = criterion(val_output, label)
acc = (val_output.argmax(dim=1) == label).float().mean()
epoch_val_accuracy += acc / len(valid_loader)
epoch_val_loss += val_loss / len(valid_loader)
print(
f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
)
train_acc_list.append(epoch_accuracy)
val_acc_list.append(epoch_val_accuracy)
train_loss_list.append(epoch_loss)
val_loss_list.append(epoch_val_loss)
こんな感じで学習が進むと思います。
#学習結果の可視化
device2 = torch.device('cpu')
train_acc = []
train_loss = []
val_acc = []
val_loss = []
for i in range(epochs):
train_acc2 = train_acc_list[i].to(device2)
train_acc3 = train_acc2.clone().numpy()
train_acc.append(train_acc3)
train_loss2 = train_loss_list[i].to(device2)
train_loss3 = train_loss2.clone().detach().numpy()
train_loss.append(train_loss3)
val_acc2 = val_acc_list[i].to(device2)
val_acc3 = val_acc2.clone().numpy()
val_acc.append(val_acc3)
val_loss2 = val_loss_list[i].to(device2)
val_loss3 = val_loss2.clone().numpy()
val_loss.append(val_loss3)
sns.set()
num_epochs=50
fig = plt.subplots(figsize=(12, 4), dpi=80)
ax1 = plt.subplot(1,2,1)
ax1.plot(range(num_epochs), train_acc, c='b', label='train acc')
ax1.plot(range(num_epochs), val_acc, c='r', label='val acc')
ax1.set_xlabel('epoch', fontsize='12')
ax1.set_ylabel('accuracy', fontsize='12')
ax1.set_title('training and val acc', fontsize='14')
ax1.legend(fontsize='12')
ax2 = plt.subplot(1,2,2)
ax2.plot(range(num_epochs), train_loss, c='b', label='train loss')
ax2.plot(range(num_epochs), val_loss, c='r', label='val loss')
ax2.set_xlabel('epoch', fontsize='12')
ax2.set_ylabel('loss', fontsize='12')
ax2.set_title('training and val loss', fontsize='14')
ax2.legend(fontsize='12')
plt.show()
こんな感じでAccuracyとLossの学習曲線が出力されると思います。この結果だとまだ過学習気味なので違ったData Augmentationを試してみるというのも手だと思います。ですが、val_lossが下がっていく気配がないので、Data Augmentationだけに頼らず学習データがうまく分類できているか?も同時に見ていった方が良さそうですね😅
Albumentationsを使うとData Augmentationの設定が割と簡単に実装できるので、是非挑戦してみてください👍
今日も良いディープラーニングライフを👍
コメント