YOLOX をカスタムデータで学習【Python】

AI
スポンサーリンク
スポンサーリンク

はじめに

前回までは、教師なし学習のAIライブラリについて説明しました。

今回は、今更ですが YOLOX に関してカスタムデータの学習手順をまとめておきます。

YOLOX の github はこちらです。

前提条件

前提条件は以下の通りです。

  • Windows11
  • Python3.9
  • torch==1.12.1+cu113

データの準備

カスタムデータは、以前の FastFlow の記事を参考に用意しました。
以前と同様に、Class1, Class1_def のみ使用します。
学習データ、テストデータは適当に分けておいてください。

今回は Class1_def の 1 – 10 を使用します。

images/train フォルダに配置しておいてください。

labelImg の準備

labelImg を入手しておきます。

git clone https://github.com/HumanSignal/labelImg.git
cd labelImg
pyrcc5 -o libs/resources.py resources.qrc
rm data/predefined_classes.txt
cd ..

アノテーション

アノテーションは Yolov5, Yolov8 と同じようにアノテーションを行います。

python labelImg/labelImg.py

上記を実行すると、labelImg が立ち上がります。

左上の ディレクトリを開く から、先ほどの画像データを保存したフォルダを選択してください。

キーボードの w を押して、マウスドラッグで矩形を作成してアノテーションしてください。

今回は 10枚 アノテーションしたら、labelImg を閉じます。

データ拡張とCOCO形式への変換

変換プログラムの作成はこちらを参考にしました。

change_and_aug_yoloXsystem.py

# coding: utf-8
import glob, os, pathlib, shutil, yaml, cv2
import tkinter
from tkinter import filedialog
import sys
import argparse
import json
import shutil
import time
import warnings
from pathlib import Path
from tqdm import tqdm

class YOLOV5ToCOCO():
    def __init__(self, data_dir):
        self.raw_data_dir = Path(data_dir)

        self.verify_exists(self.raw_data_dir / 'images')
        self.verify_exists(self.raw_data_dir / 'labels')

        save_dir_name = 'train_COCO_format'
        self.output_dir = Path("./annotation",save_dir_name)
        self.mkdir(self.output_dir)

        self._init_json()

    def __call__(self, mode_list: list):
        if not mode_list:
            raise ValueError('mode_list is empty!!')

        for mode in mode_list:
            # Read the image txt.
            txt_path = self.raw_data_dir / f'{mode}.txt'
            self.verify_exists(txt_path)
            img_list = self.read_txt(txt_path)
            if mode == 'train':
                img_list = self.append_bg_img(img_list)

            # Create the directory of saving the new image.
            save_img_dir = self.output_dir / f'{mode}2017'
            self.mkdir(save_img_dir)

            # Generate json file.
            anno_dir = self.output_dir / "annotations"
            self.mkdir(anno_dir)

            save_json_path = anno_dir / f'{mode}_annotations.json'
            json_data = self.convert(img_list, save_img_dir, mode)

            self.write_json(save_json_path, json_data)
        print(f'Successfully convert, detail in {self.output_dir}')

    def _init_json(self):
        classes_path = self.raw_data_dir / 'classes.txt'
        self.verify_exists(classes_path)
        self.categories = self._get_category(classes_path)
        self.type = 'instances'
        self.annotation_id = 1

        self.cur_year = time.strftime('%Y', time.localtime(time.time()))
        self.info = {
            'year': int(self.cur_year),
            'version': '1.0',
            'description': 'For object detection',
            'date_created': self.cur_year,
        }

        self.licenses = [{
            'id': 1,
            'name': 'Apache License v2.0',
            'url': 'https://github.com/RapidAI/YOLO2COCO/LICENSE',
        }]

    def append_bg_img(self, img_list):
        bg_dir = self.raw_data_dir / 'background_images'
        if bg_dir.exists():
            bg_img_list = list(bg_dir.iterdir())
            for bg_img_path in bg_img_list:
                img_list.append(str(bg_img_path))
        return img_list

    def _get_category(self, classes_path):
        class_list = self.read_txt(classes_path)
        categories = []
        for i, category in enumerate(class_list, 1):
            categories.append({
                'supercategory': category,
                'id': i,
                'name': category,
            })
        return categories

    def convert(self, img_list, save_img_dir, mode):
        images, annotations = [], []
        for img_id, img_path in enumerate(tqdm(img_list, desc=mode), 1):
            image_dict = self.get_image_info(img_path, img_id, save_img_dir)
            images.append(image_dict)

            label_path = self.raw_data_dir / 'labels' / f'{Path(img_path).stem}.txt'
            annotation = self.get_annotation(label_path,
                                             img_id,
                                             image_dict['height'],
                                             image_dict['width'])
            annotations.extend(annotation)

        json_data = {
            'info': self.info,
            'images': images,
            'licenses': self.licenses,
            'type': self.type,
            'annotations': annotations,
            'categories': self.categories,
        }
        return json_data

    def get_image_info(self, img_path, img_id, save_img_dir):
        img_path = Path(img_path)
        if self.raw_data_dir.as_posix() not in img_path.as_posix():
            # relative path (relative to the raw_data_dir)
            # e.g. images/images(3).jpg
            img_path = self.raw_data_dir / img_path

        self.verify_exists(img_path)

        new_img_name = f'{img_id:012d}.jpg'
        save_img_path = save_img_dir / new_img_name
        img_src = cv2.imread(str(img_path))
        if img_path.suffix.lower() == ".jpg":
            shutil.copyfile(img_path, save_img_path)
        else:
            cv2.imwrite(str(save_img_path), img_src)

        height, width = img_src.shape[:2]
        image_info = {
            'date_captured': self.cur_year,
            'file_name': new_img_name,
            'id': img_id,
            'height': height,
            'width': width,
        }
        return image_info

    def get_annotation(self, label_path: Path, img_id, height, width):
        def get_box_info(vertex_info, height, width):
            cx, cy, w, h = [float(i) for i in vertex_info]

            cx = cx * width
            cy = cy * height
            box_w = w * width
            box_h = h * height

            # left top
            x0 = max(cx - box_w / 2, 0)
            y0 = max(cy - box_h / 2, 0)

            # right bottom
            x1 = min(x0 + box_w, width)
            y1 = min(y0 + box_h, height)

            segmentation = [[x0, y0, x1, y0, x1, y1, x0, y1]]
            bbox = [x0, y0, box_w, box_h]
            area = box_w * box_h
            return segmentation, bbox, area

        if not label_path.exists():
            annotation = [{
                'segmentation': [],
                'area': 0,
                'iscrowd': 0,
                'image_id': img_id,
                'bbox': [],
                'category_id': -1,
                'id': self.annotation_id,
            }]
            self.annotation_id += 1
            return annotation

        annotation = []
        label_list = self.read_txt(str(label_path))
        for i, one_line in enumerate(label_list):
            label_info = one_line.split(' ')
            if len(label_info) < 5:
                warnings.warn(
                    f'The {i+1} line of the {label_path} has been corrupted.')
                continue

            category_id, vertex_info = label_info[0], label_info[1:]
            segmentation, bbox, area = get_box_info(vertex_info, height, width)
            annotation.append({
                'segmentation': segmentation,
                'area': area,
                'iscrowd': 0,
                'image_id': img_id,
                'bbox': bbox,
                'category_id': int(category_id)+1,
                'id': self.annotation_id,
            })
            self.annotation_id += 1
        return annotation

    @staticmethod
    def read_txt(txt_path):
        with open(str(txt_path), 'r', encoding='utf-8') as f:
            data = list(map(lambda x: x.rstrip('\n'), f))
        return data

    @staticmethod
    def mkdir(dir_path):
        Path(dir_path).mkdir(parents=True, exist_ok=True)

    @staticmethod
    def verify_exists(file_path):
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f'The {file_path} is not exists!!!')

    @staticmethod
    def write_json(json_path, content: dict):
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(content, f, ensure_ascii=False)

class YOLOAugment():
    def __init__(self, file_path):
        self.dir_ = file_path + "/"
        self.path_xml = glob.glob(self.dir_+"*.xml.txt")
        self.path_txt = glob.glob(self.dir_+"*.txt")
        self.path_png = glob.glob(self.dir_+"*.jpg")

    def doAugment(self):
        if len(self.path_png)==0:
            return False
        if len(self.path_txt)==0:
            return False
        #ファイル処理
        for png_ in self.path_png:
            png_p = png_[:-4]
            #.txtに変更
            png_t = png_p + ".txt"
            png_x = png_p + ".xml.txt"

            #path_txt内に.txtがあるか確認
            for txt_ in self.path_txt:
                #.txtがある場合
                if png_t == txt_:
                    #.xml.txtと.txtが共存する場合、.xml.txtを削除
                    if os.path.exists(png_x):
                        os.remove(png_x)
                #.xml.txtのみの場合
                elif png_x == txt_:
                    #名前変更
                    if os.path.exists(png_x):
                        os.rename(png_x, png_t)
            #ラベリングしていないデータは削除
            if os.path.exists(png_t) == False:
                if os.path.exists(png_x) == False:
                    os.remove(png_)

        #メイン
        for file in os.listdir(self.dir_):
            name, ext = os.path.splitext(file)
            print('base:' + name + '  ext:' + ext)
            if ext == '.jpg':
                img = cv2.imread(self.dir_ + file)
                label_info = self.load_labeldata(name)
                print(label_info, len(label_info))
                if len(label_info)==0:
                    print("move to: ", self.dir_ + "background/" + file)
                else:
                    self.flipped_y(img, name, label_info)
                    self.flipped_x(img, name, label_info)
                    self.flipped_xy(img, name, label_info)

        #yamlファイル作成
        if not os.path.exists(self.dir_+"classes.txt"):
            return False
        with open(self.dir_+"classes.txt", 'r', encoding='UTF-8') as f:
            data = f.read().split("\n")
        data = [a for a in data if a != '']
        print(data, len(data))

        with open(self.dir_+"data.yaml", "w") as yf:
            yaml.dump({
                "train": "datasets/train/images",
                "val": "datasets/valid/images",
                "nc": len(data),
                "names":data,
            }, yf, default_flow_style=False)

        #フォルダをimagesとlabelsに分割
        os.makedirs(self.dir_+"images", exist_ok=True)
        os.makedirs(self.dir_+"labels", exist_ok=True)
        images_dir = self.dir_ + "images/"
        labels_dir = self.dir_ + "labels/"
        path_labels = pathlib.Path(self.dir_).glob('*.txt')
        path_images = pathlib.Path(self.dir_).glob('*.jpg')

        for labels in path_labels:
            if labels.name == "classes.txt":
                continue
            print(labels.name)
            shutil.move(self.dir_+str(labels.name), labels_dir+str(labels.name))
        for images in path_images:
            with open(self.dir_+"train.txt", "a") as f:
                f.write(images_dir+str(images.name))
                f.write("\n")
            if not "_flipped_" in images.name:
                with open(self.dir_+"val.txt", "a") as f:
                    f.write(images_dir+str(images.name))
                    f.write("\n")
            shutil.move(self.dir_+str(images.name), images_dir+str(images.name))
        return True

    #データ拡張
    def flipped_y(self, img, filenm_base, label_info):
        img_flip_ud = cv2.flip(img, 0)
        cv2.imwrite(self.dir_ + filenm_base + '_flipped_y.jpg', img_flip_ud)

        f = open(self.dir_ + filenm_base + '_flipped_y.txt', 'a')
        for data in label_info:
            label, x_coordinate, y_coordinate, x_size, y_size = data.split()
            f.write(label + ' ' + x_coordinate + ' ' + self.turn_over(y_coordinate) + ' ' + x_size + ' ' + y_size + '\n')
        f.close()

    def flipped_x(self, img, filenm_base, label_info):
        img_flip_lr = cv2.flip(img, 1)
        cv2.imwrite(self.dir_ + filenm_base + '_flipped_x.jpg', img_flip_lr)

        f = open(self.dir_ + filenm_base + '_flipped_x.txt', 'a')
        for data in label_info:
            label, x_coordinate, y_coordinate, x_size, y_size = data.split()
            f.write(label + ' ' + self.turn_over(x_coordinate) + ' ' + y_coordinate + ' ' + x_size + ' ' + y_size + '\n')
        f.close()

    def flipped_xy(self, img, filenm_base, label_info):
        img_flip_ud_lr = cv2.flip(img, -1)
        cv2.imwrite(self.dir_ + filenm_base + '_flipped_xy.jpg', img_flip_ud_lr)

        f = open(self.dir_ + filenm_base + '_flipped_xy.txt', 'a')
        for data in label_info:
            label, x_coordinate, y_coordinate, x_size, y_size = data.split()
            f.write(label + ' ' + self.turn_over(x_coordinate) + ' ' + self.turn_over(y_coordinate) + ' ' + x_size + ' ' + y_size + '\n')
        f.close()

    def turn_over(self, coordinate):
        val = 1 - float(coordinate)
        return '{:.6f}'.format(val)


    def load_labeldata(self, filenm):
        try:
            f = open(self.dir_ + filenm + '.txt', 'r')
            return f.readlines()
        except Exception as e:
            print(filenm)
            print(e)
        finally:
            f.close()


if __name__=="__main__":

    idir = './'
    file_path = tkinter.filedialog.askdirectory(initialdir = idir)
    print(file_path)
    if file_path == "":
        sys.exit()

    aug = YOLOAugment(file_path)
    aug.doAugment()

    parser = argparse.ArgumentParser('Datasets converter from YOLOV5 to COCO')
    parser.add_argument('--mode_list', type=str, default='train,val',
                        help='generate which mode')
    args = parser.parse_args()

    converter = YOLOV5ToCOCO(file_path)
    converter(mode_list=args.mode_list.split(','))
    save_dir_name = 'train_COCO_format'
    output_dir = "./annotation/" + save_dir_name
    shutil.copy(file_path+"/classes.txt", output_dir+"/classes.txt")
    # shutil.rmtree(file_path)
    

上記を実行すると、ファイルダイアログが表示されるので、先ほどアノテーションした結果が格納されているフォルダを選択します。

annotation/train_COCO_format が作成されているはずです。

学習実行

今回は YOLOX-s を使用します。
こちらのページからダウンロードしておいてください。

git clone https://github.com/Megvii-BaseDetection/YOLOX.git
cd YOLOX
mkdir weights
touch yolox_s.py

今作成した weights フォルダに、yolox_s.pth を保存しておいてください。

また、train_COCO_format フォルダを datasets フォルダに保存しておいてください。

custom_yolox_s.py

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.


import os

from yolox.exp import Exp as MyExp


class Exp(MyExp):
    def __init__(self):
        super(Exp, self).__init__()
        self.depth = 0.33
        self.width = 0.50
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

        # データセットの場所を定義
        self.data_dir = "./datasets/train_COCO_format"
        self.train_ann = "train_annotations.json"
        self.val_ann = "val_annotations.json"
	
        # クラス数の変更
        self.num_classes = 2
        
        # 評価間隔を変更
        self.eval_interval = 50

        # epoch
        self.max_epoch = 100

        # not multi scale training
        self.multiscale_range = 0

        # apply mixup aug or not
        self.enable_mixup = False

ここまでできたら、以下のコマンドで学習を実行できます。

python tools/train.py -f .\custom_yolox_s.py -d 1 -b 8 --fp16 -o -c .\weights\yolox_s.pth

ひとまず、100epoch 経過するまで待ちます。

最後のログを見ると、以下のように出力されています。

per class AP:
| class   | AP     |
|:--------|:-------|
| finger  | 68.787 |
per class AR:
| class   | AR     |
|:--------|:-------|
| finger  | 75.000 |

2023-10-29 06:31:15 | INFO     | yolox.core.trainer:364 - Save weights to ./YOLOX_outputs\custom_yolox_s
2023-10-29 06:31:15 | INFO     | yolox.core.trainer:364 - Save weights to ./YOLOX_outputs\custom_yolox_s
2023-10-29 06:31:15 | INFO     | yolox.core.trainer:195 - Training of experiment is done and the best AP is 78.33

best AP 78.33 となりました。
学習した重みは YOLOX_outputs/custom_yolox_s に保存されています。
このフォルダの中の best_ckpt.pth を使用してください。

おわりに

今回はここまでとします。

次回は、カスタムデータで学習した重みで YOLOX を推論する方法について説明します。

コメント

タイトルとURLをコピーしました