Transformer でインスタンスセグメンテーション Mask2Former 学習編【Python】※未完

AI
スポンサーリンク

スポンサーリンク

はじめに

前回は、Mask2Former で推論する方法について説明しました。

今回は、Mask2Former を学習させる方法について説明します。

前提条件

前提条件は以下の通りです。

  • Ubuntu == 22.04 (WindowsはDetectron2が使用できないので不可)
  • Python == 3.10.6
  • torch == 1.13.1, torchvision == 0.14.1
  • detectron2 == 0.6
  • 学習する場合は GPU 10GB 以上必要
  • Mask2Former がインストールされている (前回の記事を参照)
  • Detectron2 のライブラリをいじるので、仮想環境が望ましい

今回は(おそらく) Ubuntu でしか使用できないAIとなります。

学習データの準備

学習データは、coco 形式で準備します。

yolact-edge, detectron2 のような形式で準備する必要があります。

準備に関しては、こちらの記事を参考にしてください。

cd Mask2Former/datasets
mkdir coco
cd coco
mkdir annotations
mkdir train2017
mkdir val2017

以下のようなフォルダ構成にしてください。

コードの変更

Mask2Former/train_net.py はそのままでは動かないので、コードを変更していきます。

train_net.py の 149行目付近

@classmethod
def build_train_loader(cls, cfg):
    MetadataCatalog.remove("coco_2017_train")  # 追加

site-packages/detectron2/engine/hooks.py の 551行目付近を以下のようにしてください。

def after_step(self):
next_iter = self.trainer.iter + 1
if self._period > 0 and next_iter % self._period == 0:
    # do the last eval in after_train
    # if next_iter != self.trainer.max_iter:
    #     self._do_eval()
    print("## pass ##")

_do_eval の実行を回避します。

これができていないと、以下のエラーが出ます。

AssertionError: Attribute ‘thing_classes’ in the metadata of ‘coco_2017_train’ cannot be set to a different value!

次に、mask2former/data/dataset_mappers/coco_instance_new_baseline_dataset_mapper.py の 134 行目付近を変更します。

# image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
↓
image = utils.read_image(dataset_dict["file_name"].replace("\\", "/"), format=self.img_format)

これをしないと、画像データが見つかりませんとのエラーが出ます。

学習実行

学習は、以下のコマンドで実行できます。

python train_net.py \
  --config-file ./configs/coco/instance-segmentation/maskformer2_R50_bs16_50ep.yaml \
  --num-gpus 1 SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0001

SOLVER.IMS_PER_BATCH は 2 で GPU 9GB 程度使用します。
SOLVER.BASE_LR は学習率です。大きすぎると発散してエラーが出ます。
エラー内容は
ValueError: matrix contains invalid numeric entries
となります。

学習はバッチサイズ 2 なので非常に時間がかかります。

私の環境では RTX3060 12GB で 3日ほどです。(画像は12枚)

学習中は以下の文字がコンソールに表示されます。

赤丸の total_loss が減少していくことを確認してください。

eval のタイミングで COCO データセットの重複エラーが出るかもしれませんが、その場合は eval を切ってください。

おわりに

使用できるマシンの関係で今回はここまでとします。申し訳ございません。

改めてマシンが確保できた場合は、続きを記載していこうと思います。

次回は omni3D について説明していきます。その際に三次元姿勢推定のデータセットについても解析していけたらと思います。

コメント

タイトルとURLをコピーしました