はじめに
前回は、ind_knn_ad の SPADE について、学習のみ・推論のみに分ける方法の、学習部分について説明しました。
今回は、残りの推論部分と、結果の見方に関して説明していきます。
前提条件
- python3.9
- torch == 1.12.1+cu113
ind_knn_ad の github はこちらです。
推論の準備
まずは、indad/models.py の SPADEクラスに以下の関数を追加してください。
def standby(self):
self.z_lib = np.load("./npy_data/spade.npy")
self.z_lib = torch.from_numpy(self.z_lib.astype(np.float32)).clone()
self.feature_maps = []
maps = np.load("./npy_data/featuremaps1.npy")
self.feature_maps.append(torch.from_numpy(maps.astype(np.float32)).clone())
maps = np.load("./npy_data/featuremaps2.npy")
self.feature_maps.append(torch.from_numpy(maps.astype(np.float32)).clone())
maps = np.load("./npy_data/featuremaps3.npy")
self.feature_maps.append(torch.from_numpy(maps.astype(np.float32)).clone())
前回作成した npy ファイルを読み込むための関数となります。
推論コード
推論のみを実行するプログラムを、inference.py としました。
from indad.models import SPADE, PaDiM, PatchCore
import cv2
import torch
import numpy as np
from torchvision import transforms
from torch import tensor
IMAGENET_MEAN = tensor([.485, .456, .406])
IMAGENET_STD = tensor([.229, .224, .225])
SIZE = 224
filename = "./weights/SPADE.pth"
load_model = SPADE(k=25, backbone_name="wide_resnet50_2")
load_model.load_state_dict(torch.load(filename))
load_model.standby()
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(SIZE),
transforms.ToTensor(),
transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])
# get predictions
good_frame = cv2.imread("./good.png")
good_frame = cv2.cvtColor(good_frame,cv2.COLOR_BGR2RGB)
good_x = transform(good_frame)
good_x = good_x.unsqueeze(0)
defect_frame = cv2.imread("./defect.png")
defect_frame = cv2.cvtColor(defect_frame,cv2.COLOR_BGR2RGB)
defect_x = transform(defect_frame)
defect_x = defect_x.unsqueeze(0)
load_model.eval()
with torch.no_grad():
img_lvl_anom_score, pxl_lvl_anom_score = load_model.predict(good_x)
print("good frame score is: ", img_lvl_anom_score)
img_lvl_anom_score, pxl_lvl_anom_score = load_model.predict(defect_x)
print("defect frame score is: ", img_lvl_anom_score)
print(pxl_lvl_anom_score.shape)
上記を実行すると、以下のような出力が得られます。
good frame score is: tensor(6.7276)
defect frame score is: tensor(6.8548)
torch.Size([1, 224, 224])
good, defect ともに推論結果に影響はなさそうです。
([1, 224, 224]) は、アノマリーマップのサイズとなります。
コードの説明
モデルの読込
load_model = SPADE(k=25, backbone_name="wide_resnet50_2")
load_model.load_state_dict(torch.load(filename))
load_model.standby()
前回訓練したモデルは state_sict で保存したので、SPADEモデルを作成してから読み込みます。
モデルを丸ごと保存すると、5GBを超えてしまいます。
standby() は、先ほど作成した関数です。これによって npy ファイルを読み込みます。
画像を推論用に変換
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(SIZE),
transforms.ToTensor(),
transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])
transforms.Compose で複数の変換をまとめることができます。
- transforms.ToPILImage() … OpenCV 画像(numpy) から、PIL形式の画像へと変換します。
- transforms.Resize … SIZE x SIZE に画像をリサイズします。
- transforms.CenterCrop … 画像の中心から SIZE x SIZE を切り出します。
- transforms.ToTensor() … PIL 形式から Tensor 形式へ変換します。
- transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD) … IMAGENET の形式で画像を標準化します。
画像の読込
# get predictions
good_frame = cv2.imread("./good.png")
good_frame = cv2.cvtColor(good_frame,cv2.COLOR_BGR2RGB)
good_x = transform(good_frame)
good_x = good_x.unsqueeze(0)
transform で画像を推論用に変換した後、unsqueeze でバッチから画像を取り出します。
推論の実行
load_model.eval()
with torch.no_grad():
img_lvl_anom_score, pxl_lvl_anom_score = load_model.predict(good_x)
print("good frame score is: ", img_lvl_anom_score)
img_lvl_anom_score, pxl_lvl_anom_score = load_model.predict(defect_x)
print("defect frame score is: ", img_lvl_anom_score)
print(pxl_lvl_anom_score.shape)
model.eval() でモデルを推論に切り替え、torch.no_grad() で重みが変わらないようにします。
predict() で推論を実行します。
返り値は異常度と異常マップとなります。
異常マップの可視化
inference.py の末尾に以下を追加してください。
anom_frame = pxl_lvl_anom_score.numpy().reshape(224,224,1).astype("uint8")
print(anom_frame.shape)
cv2.imshow("frame", anom_frame)
cv2.waitKey(0)
cv2.destroyAllWindows()
異常マップを numpy.ndarray に変換した後に、PIL 形式から opencv で扱えるようにリサイズします。
異常マップを uint8 形式に変換すると、cv2.imshow() で表示できるようになります。
おわりに
今回は異常検知AIの SPADE に関して、推論のみ実行する方法について説明しました。
少し手間ですが、教師あり学習のアノテーションに比べたらマシだと思います。
次回は、PaDiM か PatchCore で同様に学習と推論に分けて実行する方法を説明します。
コメント