画像の教師なし異常検知 FastFlow をカスタムデータセットで学習させる【Python】

AI
スポンサーリンク
スポンサーリンク

はじめに

前回は FastFlow のデモを試しました。

今回は、FastFlow をカスタムデータで学習させる方法について説明していきます。

前提条件

前提条件は以下の通りです。

  • python3.9
  • pytorch == 1.12.1+cu113
  • pytorch-ignite == 0.2.0
  • MVTec-AD に会員登録してある (データセットの準備に必要)

下記コマンドで、FrEIA を pip install します。

python -m pip install FrEIA@git+https://github.com/VLL-HD/FrEIA@1779d1fba1e21000fda1927b59eeac0a6fcaa284

カスタムデータセットの準備

前回と同様に、こちらのデータセットを使用します。Class1, Class1_def を使用します。

datasets フォルダ内に custom という名前のフォルダを作成し、その中に train, test フォルダを作成します。良品データを good 、不良品データを defect としてフォルダを作成します。

cd mvtec-ad
mkdir custom
cd custom
mkdir train
mkdir test
cd train
mkdir good
cd ../test
mkdir good
mkdir defect

また、FastFlow フォルダに good.png と defect.png を用意します。

コードの変更

カスタムデータセット用にコードを変更していきます。

main.py を変更します。train 関数内を以下のように変更してください。

if (epoch + 1) % const.EVAL_INTERVAL == 0:
    eval_once(test_dataloader, model)  
↓
# if (epoch + 1) % const.EVAL_INTERVAL == 0:
#     eval_once(test_dataloader, model)   

次に、parse_args 関数内を以下のように変更してください。

parser.add_argument(
        "-cat",
        "--category",
        type=str,
        choices=const.MVTEC_CATEGORIES,
        required=True,
        help="category name in mvtec",
    )
↓
parser.add_argument(
        "-cat",
        "--category",
        type=str,
        #choices=const.MVTEC_CATEGORIES,
        required=True,
        help="category name in mvtec",
    )

次に、dataset.py を変更します。__getitem__ 関数を変更してください。

def __getitem__(self, index):
        image_file = self.image_files[index]
        image = Image.open(image_file)
        image = image.convert("L").convert("RGB")  # 追加
        image = self.image_transform(image)

今回のデータセットはグレースケールなので、RGB に拡張しておきます。

学習の実行

学習は以下のコマンドで実行します。

python main.py -cfg configs/resnet18.yaml --data ./mvtec-ad -cat custom

loss が減少していくことを確認してください。今回は 20epoch 学習させました。

最終的な出力は以下のようになります。

Epoch 20 - Step 10: loss = -772677.062(-769559.769)
Epoch 20 - Step 20: loss = -798395.750(-779039.241)
Epoch 20 - Step 30: loss = -800987.188(-786451.925)
Epoch 20 - Step 31: loss = -798397.438(-786837.264)

次に、wide_resnet50 で学習させてみます。

python main.py -cfg configs/wide_resnet50_2.yaml --data ./mvtec-ad -cat custom

GPU のメモリエラーが出る場合は、constants.py の BATCH_SIZE を変更してください。
もしくは、データローダーの num_workers を変更します。

20epoch の出力は以下のようになります。

Epoch 20 - Step 10: loss = -7480169.500(-7537587.850)
Epoch 20 - Step 20: loss = -7433349.000(-7535516.900)
Epoch 20 - Step 30: loss = -7473068.500(-7519969.533)
Epoch 20 - Step 40: loss = -7545652.500(-7532244.213)
Epoch 20 - Step 50: loss = -7537684.000(-7535402.120)
Epoch 20 - Step 60: loss = -7602153.000(-7543391.567)
Epoch 20 - Step 70: loss = -7585042.500(-7549508.971)
Epoch 20 - Step 80: loss = -7616941.000(-7552665.013)
Epoch 20 - Step 90: loss = -7525526.500(-7557880.233)
Epoch 20 - Step 100: loss = -7658119.000(-7561838.615)
Epoch 20 - Step 110: loss = -7581773.000(-7565925.764)
Epoch 20 - Step 120: loss = -7621588.500(-7570404.138)
Epoch 20 - Step 125: loss = -7645789.000(-7574360.252)

resnet18 と wide_resnet50_2 では、10倍程度の loss の差が出ました。

推論の実行

推論プログラムを inference.py としました。


import cv2
import torch
import yaml
import fastflow
from torchvision import transforms
from torch import tensor

IMAGENET_MEAN = tensor([.485, .456, .406])
IMAGENET_STD = tensor([.229, .224, .225])

class FastFlowObject():
    def __init__(self):
        config_path = "configs/resnet18.yaml"
        ckpt_path = "_fastflow_experiment_checkpoints/custom_resnet18/19.pt"

        self.config = yaml.safe_load(open(config_path, "r"))
        self.model = self.build_model()
        checkpoint = torch.load(ckpt_path)
        self.model.load_state_dict(checkpoint["model_state_dict"])

        self.target_transform = transforms.Compose(
                [
                    transforms.ToPILImage(),
                    transforms.Resize(self.config["input_size"]),
                    transforms.ToTensor(),
                    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
                ]
            )
        self.model.eval()

    def build_model(self):
        model = fastflow.FastFlow(
            backbone_name=self.config["backbone_name"],
            flow_steps=self.config["flow_step"],
            input_size=self.config["input_size"],
            conv3x3_only=self.config["conv3x3_only"],
            hidden_ratio=self.config["hidden_ratio"],
        )
        return model

    def inference(self):
        defect_frame = cv2.imread("./defect.png")
        defect_frame = cv2.cvtColor(defect_frame,cv2.COLOR_BGR2RGB)
        defect_x = self.target_transform(defect_frame)
        defect_x = defect_x.unsqueeze(0)

        with torch.no_grad():
            ret = self.model(defect_x)

        output_image = ret["anomaly_map"].cpu().detach().numpy()
        print(output_image.shape)
        output_image = output_image[0] * -255
        output_image = output_image.astype("uint8").reshape(256,256,1)

        image = cv2.inRange(output_image, 0, 10)
        # 輪郭抽出
        contours, _ = cv2.findContours(image, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
        if len(contours) > 0:
            contour = max(contours, key=lambda x: cv2.contourArea(x))
        else:
            contour = contours
        defect_frame = cv2.resize(defect_frame, (self.config["input_size"],self.config["input_size"]))
        defect_frame = cv2.drawContours(defect_frame, contour, -1, (0, 255, 0), 5)

        cv2.imshow("output", output_image)
        # cv2.imshow("frame", image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

        cv2.imwrite("output.png", output_image)

if __name__ == "__main__":
    ff = FastFlowObject()
    ff.inference()

上記を実行すると、good.png と defect.png で以下の出力が得られます。

resnet18

good.png

defect.py

wide_resnet50_2

good.png

defect.png

resnet18 も wide_resnet50_2 も不良の位置を含めて捉えることができました。

このサイズの不良であれば、resnet18 の方が推論が早いので、resnet18 を使ったほうが良さそうです。

おわりに

今回は画像の異常検知AI FastFlow をカスタムデータで学習させ、アプリに組み込めるように推論プログラムも作成しました。

決定版と言ってもいいほどだと思います。

2023/09/17 さらに決定版出ました。

コメント

タイトルとURLをコピーしました