画像の異常検知 ind_knn_ad を学習と推論で分割する 【SPADE推論編】

AI
スポンサーリンク

スポンサーリンク

はじめに

前回は、ind_knn_ad の SPADE について、学習のみ・推論のみに分ける方法の、学習部分について説明しました。

今回は、残りの推論部分と、結果の見方に関して説明していきます。

前提条件

  • python3.9
  • torch == 1.12.1+cu113

ind_knn_ad の github はこちらです。

推論の準備

まずは、indad/models.py の SPADEクラスに以下の関数を追加してください。

def standby(self):
		self.z_lib = np.load("./npy_data/spade.npy")
		self.z_lib = torch.from_numpy(self.z_lib.astype(np.float32)).clone()
		self.feature_maps = []
		maps = np.load("./npy_data/featuremaps1.npy")
		self.feature_maps.append(torch.from_numpy(maps.astype(np.float32)).clone())
		maps = np.load("./npy_data/featuremaps2.npy")
		self.feature_maps.append(torch.from_numpy(maps.astype(np.float32)).clone())
		maps = np.load("./npy_data/featuremaps3.npy")
		self.feature_maps.append(torch.from_numpy(maps.astype(np.float32)).clone())

前回作成した npy ファイルを読み込むための関数となります。

推論コード

推論のみを実行するプログラムを、inference.py としました。

from indad.models import SPADE, PaDiM, PatchCore
import cv2
import torch
import numpy as np
from torchvision import transforms
from torch import tensor

IMAGENET_MEAN = tensor([.485, .456, .406])
IMAGENET_STD = tensor([.229, .224, .225])
SIZE = 224
filename = "./weights/SPADE.pth"

load_model = SPADE(k=25, backbone_name="wide_resnet50_2")
load_model.load_state_dict(torch.load(filename))
load_model.standby()

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

# get predictions
good_frame = cv2.imread("./good.png")
good_frame = cv2.cvtColor(good_frame,cv2.COLOR_BGR2RGB)
good_x = transform(good_frame)
good_x = good_x.unsqueeze(0)

defect_frame = cv2.imread("./defect.png")
defect_frame = cv2.cvtColor(defect_frame,cv2.COLOR_BGR2RGB)
defect_x = transform(defect_frame)
defect_x = defect_x.unsqueeze(0)

load_model.eval()
with torch.no_grad():
    img_lvl_anom_score, pxl_lvl_anom_score = load_model.predict(good_x)
    print("good frame score is: ", img_lvl_anom_score)
    img_lvl_anom_score, pxl_lvl_anom_score = load_model.predict(defect_x)
    print("defect frame score is: ", img_lvl_anom_score)

    print(pxl_lvl_anom_score.shape)

上記を実行すると、以下のような出力が得られます。

good frame score is:  tensor(6.7276)
defect frame score is:  tensor(6.8548)
torch.Size([1, 224, 224])

good, defect ともに推論結果に影響はなさそうです。
([1, 224, 224]) は、アノマリーマップのサイズとなります。

コードの説明

モデルの読込

load_model = SPADE(k=25, backbone_name="wide_resnet50_2")
load_model.load_state_dict(torch.load(filename))
load_model.standby()

前回訓練したモデルは state_sict で保存したので、SPADEモデルを作成してから読み込みます。
モデルを丸ごと保存すると、5GBを超えてしまいます。

standby() は、先ほど作成した関数です。これによって npy ファイルを読み込みます。

画像を推論用に変換

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

transforms.Compose で複数の変換をまとめることができます。

  • transforms.ToPILImage() … OpenCV 画像(numpy) から、PIL形式の画像へと変換します。
  • transforms.Resize … SIZE x SIZE に画像をリサイズします。
  • transforms.CenterCrop … 画像の中心から SIZE x SIZE を切り出します。
  • transforms.ToTensor() … PIL 形式から Tensor 形式へ変換します。
  • transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD) … IMAGENET の形式で画像を標準化します。

画像の読込

# get predictions
good_frame = cv2.imread("./good.png")
good_frame = cv2.cvtColor(good_frame,cv2.COLOR_BGR2RGB)
good_x = transform(good_frame)
good_x = good_x.unsqueeze(0)

transform で画像を推論用に変換した後、unsqueeze でバッチから画像を取り出します。

推論の実行

load_model.eval()
with torch.no_grad():
    img_lvl_anom_score, pxl_lvl_anom_score = load_model.predict(good_x)
    print("good frame score is: ", img_lvl_anom_score)
    img_lvl_anom_score, pxl_lvl_anom_score = load_model.predict(defect_x)
    print("defect frame score is: ", img_lvl_anom_score)

    print(pxl_lvl_anom_score.shape)

model.eval() でモデルを推論に切り替え、torch.no_grad() で重みが変わらないようにします。
predict() で推論を実行します。

返り値は異常度と異常マップとなります。

異常マップの可視化

inference.py の末尾に以下を追加してください。

anom_frame = pxl_lvl_anom_score.numpy().reshape(224,224,1).astype("uint8")
print(anom_frame.shape)

cv2.imshow("frame", anom_frame)
cv2.waitKey(0)
cv2.destroyAllWindows()

異常マップを numpy.ndarray に変換した後に、PIL 形式から opencv で扱えるようにリサイズします。
異常マップを uint8 形式に変換すると、cv2.imshow() で表示できるようになります。

おわりに

今回は異常検知AIの SPADE に関して、推論のみ実行する方法について説明しました。

少し手間ですが、教師あり学習のアノテーションに比べたらマシだと思います。

次回は、PaDiM か PatchCore で同様に学習と推論に分けて実行する方法を説明します。

コメント

タイトルとURLをコピーしました