学習不要で何でもセグメンテーションできる segment-anything-2-【Python】

AI
スポンサーリンク
スポンサーリンク

はじめに

前回は segment-anything について、背景・前景を指定して推論する方法について説明しました。

今回は 背景・前景を指定せずに推論する方法について説明します。

前提条件

前提条件は以下の通りです。

  • Python == 3.9.13
  • torch == 1.12.1
  • torchvision == 0.13.1
  • 前回の記事で、環境構築・モデルのダウンロードが済んでいる

GPU は無くても問題ありません。使用する場合は8GB 以上必要となります。

cpu 推論:150 – 350s

segment-anything の github はこちらにあります。

使用する画像について

今回は以下の画像を使用します。名称は sample_image.png とします。

MVTec-AD の cable 画像を使用します。

MVTec-AD のページはこちらです。

推論プログラムの作成

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import time

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)

if __name__ == "__main__":
    image = cv2.imread('images/sample_image.png')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    import sys
    sys.path.append("..")
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

    sam_checkpoint = "weights/sam_vit_h_4b8939.pth"
    model_type = "vit_h"

    device = "cpu"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)

    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=32,
        points_per_batch=64,
        pred_iou_thresh=0.86,
        stability_score_thresh=0.92,
        crop_n_layers=1,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100,  # Requires open-cv to run post-processing
    )
    start = time.time()
    print("inference start")
    masks = mask_generator.generate(image)
    print(time.time() - start)

    print(len(masks))
    print(masks[0].keys())

    plt.figure(figsize=(20,20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    plt.show() 

上記を実行すると、以下の出力が得られます。

ケーブル1本1本を領域検出していますね。

別の画像でテストしてみる

次は、こちらの画像を使用してみます。

結果は、以下のような画像になります。

とても綺麗に領域検出できています。

SamAutomaticMaskGenerator の引数について

精度や処理速度に関する引数は、以下の二つとなります。

  • points_per_side … 画像の一辺に沿ってサンプリングされる点の数。下げると精度が下がるが処理は早くなります。
  • points_per_batch … サンプリングされた点のバッチサイズ。数値が小さければメモリ使用量も低くなる。数値を大きくすると、処理速度が上がります。
  • crop_n_layers … マスクで切り抜かれた画像にもう一度マスク予測をするかどうかの回数。crop_n_layers > 0 で再予測。処理速度は遅くなるが、精度は上がる。

おわりに

今回は segment-anything のもう一つの使い方について説明しました。

GPUが 8GB 以上必要なので、エッジPCでの推論には向いていないかもしれませんが、非常に強力なセグメンテーション手法かと思います。

次回は、Mask2Former について説明できればと思います。

コメント

タイトルとURLをコピーしました