高速画像分類 Vision Transformer をカスタムデータで学習させる

AI
スポンサーリンク

スポンサーリンク

はじめに

前回は、Vision Transformer の example について説明しました。

今回は、Vision Transformer をカスタムデータで学習する方法について説明します。

前提条件

前提条件は、以下の通りです。

  • vit-pytorch == 0.40.2
  • pytorch == 1.11.0+cu113
  • numpy == 1.23.3
  • timm == 0.6.7

カスタムデータについて

今回使用するデータは工業製品にの不良のデータセットを用います。ダウンロードはこちらから。
ページ中盤の赤丸内の Class1, Class1_def, Class2, Class2_def をダウンロードしてください。

ダウンロードしたデータを解凍してから使用します。

mkdir -p ~/vit-dev/vit-pytorch/examples/custom_data/train
cp Class1 ~/vit-dev/vit-pytorch/examples/custom_data/train
cp Class1_def ~/vit-dev/vit-pytorch/examples/custom_data/train
cp Class2 ~/vit-dev/vit-pytorch/examples/custom_data/train
cp Class2_def ~/vit-dev/vit-pytorch/examples/custom_data/train

mkdir -p ~/vit-dev/vit-pytorch/examples/custom_data/val
mv Class1 ~/vit-dev/vit-pytorch/examples/custom_data/val
mv Class1_def ~/vit-dev/vit-pytorch/examples/custom_data/val
mv Class2 ~/vit-dev/vit-pytorch/examples/custom_data/val
mv Class2_def ~/vit-dev/vit-pytorch/examples/custom_data/val

これでデータの準備は完了です。

作成した trainフォルダと valフォルダに適当に画像を振り分けておいてください。

学習用のコード

先に学習用のコードを掲載しておきます。
実際は、Jupyter Notebook で動かしていますが、.py に変更しました。


import glob
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from pathlib import Path
import timm
import cv2, glob

print(f"Torch: {torch.__version__}")

# Training settings
batch_size = 8
epochs = 20
lr = 1e-5
gamma = 0.7
seed = 42

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(seed)
device = 'cuda'

train_transforms = transforms.Compose(
    [
        transforms.Pad(16, padding_mode='edge'),
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(0.3),
        transforms.RandomVerticalFlip(0.3),
        transforms.RandomAffine(degrees=[-10, 10], translate=(0.1, 0.1), scale=(0.5, 1.5)),
        transforms.ToTensor(),
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Pad(16, padding_mode='edge'),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

train_dataset_dir = Path('./custom_data/train')
val_dataset_dir = Path('./custom_data/val')

train_data = datasets.ImageFolder(train_dataset_dir,train_transforms)
valid_data = datasets.ImageFolder(val_dataset_dir, val_transforms)

train_loader = DataLoader(dataset = train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset = valid_data, batch_size=batch_size, shuffle=False)

model = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=4)
model.to("cuda:0")

# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

model.eval()

torch.save(model.state_dict(), "custom-data.pth")

def cv2pil(image):
    new_image = image.copy()
    new_image = cv2.cvtColor(new_image, cv2.COLOR_BGR2RGB)
    new_image = Image.fromarray(new_image)
    return new_image

data_list = glob.glob("./custom_data/val/*/*.png")
for im in data_list:
    frame = cv2.imread(im)
    frame = cv2pil(frame)
    img = transforms.Compose(
    [
        transforms.Pad(16, padding_mode='edge'),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
    )(frame).unsqueeze(0)
    img = img.cuda()
    with torch.no_grad():
        outputs = model(img)
    result = outputs.cpu().numpy()
    result_ = np.argmax(result)
    print(result, result_, im)

こちらを実行すると、custom-data.pth が出力されます。
また、20epoch の acc:0.9825, val_acc:1.000 と非常に高い精度を確認できました。

コード説明

まずは、画像の前処理について説明します。

train_transforms = transforms.Compose(
    [
        transforms.Pad(16, padding_mode='edge'),
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(0.3),
        transforms.RandomVerticalFlip(0.3),
        transforms.RandomAffine(degrees=[-10, 10], translate=(0.1, 0.1), scale=(0.5, 1.5)),
        transforms.ToTensor(),
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Pad(16, padding_mode='edge'),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

今回は工業製品の不良データということで、画面内の隅に不良が映っている可能性も考えられます。
そこで、transforms.Pad で画面端をedgeの色で延長します。
その後、224 x 224 にリサイズします。

train_dataset_dir = Path('./custom_data/train')
val_dataset_dir = Path('./custom_data/val')

train_data = datasets.ImageFolder(train_dataset_dir,train_transforms)
valid_data = datasets.ImageFolder(val_dataset_dir, val_transforms)

train_loader = DataLoader(dataset = train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset = valid_data, batch_size=batch_size, shuffle=False)

datasets.ImageFolder は、フォルダ内の画像全てをフォルダ名でラベリングしてくれる便利な関数です。前回の cats_and_dogs では一つ一つラベルを付けていましたが、こちらの方法は非常に簡単です。

model = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=4)
model.to("cuda:0")

事前学習済みのモデルを フォルダの数と同じ数の num_classes = 4 として読込みます。


def cv2pil(image):
    new_image = image.copy()
    new_image = cv2.cvtColor(new_image, cv2.COLOR_BGR2RGB)
    new_image = Image.fromarray(new_image)
    return new_image

data_list = glob.glob("./custom_data/val/*/*.png")
for im in data_list:
    frame = cv2.imread(im)
    frame = cv2pil(frame)
    img = transforms.Compose(
    [
        transforms.Pad(16, padding_mode='edge'),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
    )(frame).unsqueeze(0)
    img = img.cuda()
    with torch.no_grad():
        outputs = model(img)
    result = outputs.cpu().numpy()
    result_ = np.argmax(result)
    print(result, result_, im)

こちらは推論用のコードとなります。
opencv で読み込んだ画像を cv2pil で Pillow の画像フォーマットに変換します。
outputs.cpu().numpy() で cuda で計算した結果を cuda から cpu へ取り込みます。

outputs は、1 x num_classes となりますので、np.argmax で最大値の index を取得します。
result_ が予測したクラスラベルとなります。

おわりに

今回は、Vision Transformer をカスタムデータで学習する方法について説明しました。

非常に簡単かつ高速に学習することができたと思います。

学習コストがほとんどかからないので、PoC が捗ると思います。

次回は、工業製品画像で教師なし異常検知をする方法について説明していきます。

コメント

タイトルとURLをコピーしました