画像の異常検知 ind_knn_ad を学習と推論で分割する 【SPADE学習編】

AI
スポンサーリンク

スポンサーリンク

はじめに

前回は ind_knn_ad の推論用のコードを作成しました。

しかしながら、このままでは起動するたびに学習をしないと推論をすることができません。
そこで今回は、学習と推論でコードを切り離したいと思います。

前提条件

前提条件は以下の通りです。

  • python3.9
  • torch == 1.12.1+cu113

ind_knn_ad の github はこちらです。

フォルダの準備

フォルダの準備をしていきます。

cd ind_knn_ad の作業ディレクトリ
mkdir npy_data
mkdir weights
touch train.py
touch inference.py

また、indad/models.py の 12行目を以下のように変更してください。

from utils import GaussianBlur, get_coreset_idx_randomp, get_tqdm_params
↓
from indad.utils import GaussianBlur, get_coreset_idx_randomp, get_tqdm_params

学習用のコード

まずは、学習用のコードの作成をしていきます。

from indad.models import SPADE, PaDiM, PatchCore
from indad.data import MVTecDataset
import cv2
import torch
import numpy as np
from torchvision import transforms
from torch import tensor

IMAGENET_MEAN = tensor([.485, .456, .406])
IMAGENET_STD = tensor([.229, .224, .225])
SIZE = 224
filename = "./weights/SPADE.pth"

model = SPADE(k=25, backbone_name="wide_resnet50_2")
# model = PaDiM(d_reduced=350, backbone_name="wide_resnet50_2")
# model = PatchCore(f_coreset=.10, backbone_name="wide_resnet50_2")

# model.to("cuda")
train_ds, test_ds = MVTecDataset("custom", SIZE).get_dataloaders()

# feed healthy dataset
model.fit(train_ds)

# torch.save(model, filename)
torch.save(model.state_dict(), filename)

print("model saved to: ", filename)

上記を実行すると、weights フォルダ内に SPADE.pth が作成されます。

メモリを非常に消費するので、model.to()

しかしながら、このモデルだけでは動きません。
SPADE は、モデルの学習もありますが推論には特徴量マップと z_lib(マハラノビス距離?)も使用します。

なので、これらを保存しておく必要があります。

models.py の変更

indad/models.py の 109 行目付近の SPADEクラスの fit 関数を変更します。

def fit(self, train_dl):
		for sample, _ in tqdm(train_dl, **get_tqdm_params()):
			feature_maps, z = self(sample)

			# z vector
			self.z_lib.append(z)

			# feature maps
			if len(self.feature_maps) == 0:
				for fmap in feature_maps:
					self.feature_maps.append([fmap])
			else:
				for idx, fmap in enumerate(feature_maps):
					self.feature_maps[idx].append(fmap)

		self.z_lib = torch.vstack(self.z_lib)
		x = self.z_lib.to('cpu').detach().numpy().copy()
		np.save("./npy_data/spade.npy", x)
		print("numpy saved: ")
		for idx, fmap in enumerate(self.feature_maps):
			self.feature_maps[idx] = torch.vstack(fmap)
		
		x = np.array(self.feature_maps[0], dtype=float)
		np.save("./npy_data/featuremaps1.npy", x)
		x = np.array(self.feature_maps[1], dtype=float)
		np.save("./npy_data/featuremaps2.npy", x)
		x = np.array(self.feature_maps[2], dtype=float)
		np.save("./npy_data/featuremaps3.npy", x)
		print("numpy saved: ")

z_lib の保存

self.z_lib = torch.vstack(self.z_lib)
x = self.z_lib.to('cpu').detach().numpy().copy()
np.save("./npy_data/spade.npy", x)

z_lib を numpy 配列に変更し、npy 拡張子で保存します。

特徴量マップの保存

x = np.array(self.feature_maps[0], dtype=float)
np.save("./npy_data/featuremaps1.npy", x)
x = np.array(self.feature_maps[1], dtype=float)
np.save("./npy_data/featuremaps2.npy", x)
x = np.array(self.feature_maps[2], dtype=float)
np.save("./npy_data/featuremaps3.npy", x)

特徴量マップは全部で三つあり、それぞれサイズが異なります。
まとめて保存しようとすると dtype=object にしてください といった警告が発生します。

面倒ですが、一つずつ numpy.ndarray で保存してください。

おわりに

推論部分もかなり長くなる予定なので今回はここまでとします。

npy 形式の保存は、scikit-learn 等を使用するうえでかなり重宝すると思うので、覚えておいて損はないです。

次回は、推論部分のコードの説明と、結果の見方について説明します。

コメント

タイトルとURLをコピーしました