学習不要で何でもセグメンテーションできる segment-anything-1-【Python】

AI
スポンサーリンク
スポンサーリンク

はじめに

今までは、学習が必要なセグメンテーション (Detectron2, yolact-edge) について説明してきました。

今回は学習不要でセグメンテーション可能な segment-anything について説明します。

前提条件

前提条件は以下の通りです。

  • Python == 3.9.13
  • torch == 1.12.1
  • torchvision == 0.13.1

GPU は無くても問題ありません。使用する場合は8GB 以上必要となります。

cpu 推論:0.13s

segment-anything の github はこちらにあります。

環境構築

以下のコマンドで必要なライブラリをインストールできます。

pip install git+https://github.com/facebookresearch/segment-anything.git

モデルのダウンロード

モデルをダウンロードします。

ページ中腹の以下の青丸のリンクからモデルをダウンロードしてください。約 2.4 GB あります。

画像の準備

今回は以下の画像を使用します。名称は sample_image.png とします。

MVTec-AD の cable 画像を使用します。

MVTec-AD のページはこちらです。

推論プログラム

inference.py を用意します。

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import time

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    


if __name__ == "__main__":
    image = cv2.imread('images/sample_image.png')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    import sys
    sys.path.append("..")
    from segment_anything import sam_model_registry, SamPredictor

    sam_checkpoint = "./weights/sam_vit_h_4b8939.pth"
    model_type = "vit_h"

    device = "cpu"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)

    predictor = SamPredictor(sam)

    predictor.set_image(image)

    input_point = np.array([[600, 600]])
    input_label = np.array([1])

    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_points(input_point, input_label, plt.gca())
    plt.axis('on')
    plt.show()

    start = time.time()
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    print(time.time() - start)

    print("mask shape: ", masks.shape)

    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10,10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(input_point, input_label, plt.gca())
        plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()  

    input_point = np.array([[600, 600], [700, 600]])
    input_label = np.array([1, 1])
    mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )

    print("mask shape: ", masks.shape)

    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(masks, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.axis('off')
    plt.show()

上記を実行すると、以下の出力が得られます。

セグメンテーションは上手くできています。

プログラムの説明

プログラムの説明をしていきます。

if __name__ == "__main__":
    image = cv2.imread('images/sample_image.png')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    import sys
    sys.path.append("..")
    from segment_anything import sam_model_registry, SamPredictor

    sam_checkpoint = "./weights/sam_vit_h_4b8939.pth"
    model_type = "vit_h"

    device = "cpu"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)

AI モデルと 画像を読み込みます。device = “cuda” とすると、GPU で推論が可能です。

predictor = SamPredictor(sam)

predictor.set_image(image)

SamPredictor インスタンスを作成し、インスタンスに画像を渡します。

input_point = np.array([[600, 600]])
input_label = np.array([1])

input_point はセグメントしてほしい箇所を指定します。今回は x=600, y=600 です。

input_label は セグメントしてほしい箇所ごとに設定します。背景=0, 前景=1 です。

plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()

画像を表示します。

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

推論部分となります。multimask_output=True で 3つのマスク、False で 1つのマスクを返します。

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()  

3つのマスクに関してそれぞれ画像とスコアを表示します。

input_point = np.array([[600, 600], [700, 600]])
input_label = np.array([1, 1])
mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)

今回は (600,600) と (700,600) の2点を前景として指定します。
また、補助用に先ほどの推論結果を mask_input として流用します。

print("mask shape: ", masks.shape)

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

思い通りにセグメンテーションできました。

別の画像でセグメンテーションしてみる

次に、こちらの画像でテストしてみます。

データセットの公式サイトはこちらです。

推論プログラム

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import time

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    


if __name__ == "__main__":
    image = cv2.imread('images/sample_image2.png')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    import sys
    sys.path.append("..")
    from segment_anything import sam_model_registry, SamPredictor

    sam_checkpoint = "./weights/sam_vit_h_4b8939.pth"
    model_type = "vit_h"

    device = "cpu"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)

    predictor = SamPredictor(sam)

    predictor.set_image(image)


    input_point = np.array([[200, 300]])
    input_label = np.array([1])

    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )

    print("mask shape: ", masks.shape)

    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(masks, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.axis('off')
    plt.show()


    input_point = np.array([[48, 65], [75, 95]])
    input_label = np.array([1, 1])

    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )

    print("mask shape: ", masks.shape)

    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(masks, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.axis('off')
    plt.show()

上記を実行すると、以下のような画像が得られます。

綺麗にセグメンテーションできています!

おわりに

今回は segment-anything について、使い方を説明しました。

このライブラリを使用すれば、学習無しで実践投入できそうです。特に製造現場ではその恩恵が大きいかと思われます。

次回は、segment-anything のもう一つの使い方について説明します。

コメント

タイトルとURLをコピーしました