物体の三次元姿勢推定 CenterSnap -CenterSnapを学習させる- 【Python】

AI
スポンサーリンク
スポンサーリンク

はじめに

前回は CenterSnap を学習させるための中間ファイルを作成する方法について説明しました。

今回は中間ファイルを使用して CenterSnap を学習させていきます。

前提条件

前提条件は以下の通りです。

学習設定ファイル

学習の設定は、makeNOCS/CenterSnap/configs/net_config.txt を調整します。

net_config.txt

--max_steps=76000                                ----------|
--model_file=models/panoptic_net.py                        |
--model_name=res_fpn                                       |
--output=results/train/CenterSnap_TrainSynthetic           |
--train_path=file://data/NOCS_Data/CAMERA/train            |
--train_batch_size=6                                       |------------この部分を変更
--train_num_workers=2                                      |
--val_path=file://data/NOCS_Data/CAMERA/val                |
--val_batch_size=6                                         |
--val_num_workers=2                              ----------|
--optim_learning_rate=0.006
--optim_momentum=0.9
--optim_weight_decay=1e-4
--optim_poly_exp=0.9
--optim_warmup_epochs=1
--loss_seg_mult=1.0
--loss_depth_mult=1.0
--loss_vertex_mult=0.1
--loss_rotation_mult=0.1
--loss_heatmap_mult=100.0
--loss_latent_emb_mult=0.1
--loss_abs_pose_mult=0.1
--loss_z_centroid_mult=0.1
--wandb_name=NOCS_Train_Synthetic

上記の部分を変更して GPU エラーとならないようにしてください。
上記の設定は、NVIDIA T1000 4GB の設定項目です。

学習プログラム

cd makeNOCS/CenterSnap
touch net_train.py

net_train.py を以下のようにしてください。

import os

os.environ['PYTHONHASHSEED'] = str(1)
import argparse
from importlib.machinery import SourceFileLoader
import sys

import random

random.seed(12345)
import numpy as np

np.random.seed(12345)
import torch

torch.manual_seed(12345)

# import wandb

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers

from simnet.lib.net import common
from simnet.lib import datapoint
from simnet.lib import camera
from simnet.lib.net.panoptic_trainer import PanopticModel

_GPU_TO_USE = 0


class EvalMethod():

  def __init__(self):

    self.eval_3d = None
    self.camera_model = camera.NOCS_Camera()

  def process_sample(self, pose_outputs, box_outputs, seg_outputs, detections_gt, scene_name):
    return True

  def process_all_dataset(self, log):
    return True
    # log['all 3Dmap'] = self.eval_3d.process_all_3D_dataset()

  def draw_detections(
      self,seg_outputs,left_image_np, llog, prefix
  ):
    seg_vis = seg_outputs.get_visualization_img(np.copy(left_image_np))
    # llog[f'{prefix}/seg'] = wandb.Image(seg_vis, caption=prefix)

if __name__ == "__main__":
  parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
  common.add_train_args(parser)
  hparams = parser.parse_args()
  train_ds = datapoint.make_dataset(hparams.train_path)
  samples_per_epoch = len(train_ds.list())
  samples_per_step = hparams.train_batch_size
  steps = hparams.max_steps
  steps_per_epoch = samples_per_epoch // samples_per_step
  epochs = int(np.ceil(steps / steps_per_epoch))
  actual_steps = epochs * steps_per_epoch
  print('Samples per epoch', samples_per_epoch)
  print('Steps per epoch', steps_per_epoch)
  print('Target steps:', steps)
  print('Actual steps:', actual_steps)
  print('Epochs:', epochs)

  model = PanopticModel(hparams, epochs, train_ds, EvalMethod())
  model_checkpoint = ModelCheckpoint(filepath=hparams.output, save_top_k=-1, period=1, mode='max')
  # wandb_logger = loggers.WandbLogger(name=hparams.wandb_name, project='CenterSnap')

  if hparams.finetune_real:
    print("############# Finetune ################")
    trainer = pl.Trainer(
        max_nb_epochs=epochs,
        early_stop_callback=None,
        gpus=[_GPU_TO_USE],
        checkpoint_callback=model_checkpoint,
        val_check_interval=1.0,
        # logger=wandb_logger,
        # logger=None,
        default_save_path=hparams.output,
        use_amp=False,
        print_nan_grads=True,
        resume_from_checkpoint=hparams.checkpoint
    )
  else:
    print("############# Train ################")
    trainer = pl.Trainer(
        max_nb_epochs=epochs,
        early_stop_callback=None,
        gpus=[_GPU_TO_USE],
        checkpoint_callback=model_checkpoint,
        val_check_interval=1.0,
        # logger=wandb_logger,
        # logger=None,
        default_save_path=hparams.output,
        use_amp=False,
        print_nan_grads=True,
        #num_sanity_val_steps=4,
    )

  trainer.fit(model)

学習プログラムは以下のように実行してください。

python3 net_train.py @configs/net_config.txt

weight ファイルがとにかく大容量なので注意してください。ストレージの空きは300GB以上が望ましいです。

results/train/CenterSnap_TrainSynthetic に weight ファイル(.ckpt)が保存されています。

Real データセットがある場合は Finetune が必要ですが、まだその段階ではないので、無視します。

おわりに

今回は CenterSnap を学習させる方法について説明しました。

次回は学習させたモデルで推論する方法について説明します。

コメント

タイトルとURLをコピーしました