您的位置:首页 > 其它

YOLOv4实战尝鲜 --- 教你从零开始训练自己的数据集(安全头盔佩戴识别检测)

2020-05-13 21:49 211 查看

目录

  • YOLOv4配置
  • 模型训练
  • 模型测试
  • 本文代码基于:https://github.com/ultralytics/yolov3
    YOLOv4理论部分我的另一篇博客请参考:YOLOv4真的来了!!论文翻译

    数据准备

    首先介绍数据集,来源于AI研习设的一个比赛,见链接:https://god.yanxishe.com/32
    之前我已经用该数据集训练过YOLOv3,感觉效果不是非常好,这次直接把之前训练YOLOv3的数据集放在

    YOLOv4_path/data/
    目录下即可。

    不过,,这样看的可能会一脸懵逼,还是介绍下数据集的准备过程吧。。

    划分数据集

    首先

    clone git
    到本地,解压,得到YOLOv4的文件目录如图:

    然后将下载下来的数据集放到

    data/
    下,文件夹命名为
    原始数据集
    data/原始数据集/
    下如图所示,分别有
    train
    文件夹(训练数据集),
    labels
    文件夹(标签集),
    test
    文件夹(测试数据,用于测试提交。)


    先划分数据集,将train数据集划分为train/valid数据集,运行

    1.split_data.py
    如下代码:

    import os
    import glob
    import json
    import shutil
    import numpy as np
    import xml.etree.ElementTree as ET
    
    img_train = './images/train/'
    img_val = './images/valid/'
    
    label_train = './labels/train/'
    label_val = './labels/valid/'
    
    allimgs = glob.glob('原始数据集/train/' + "/*.jpg")
    allimgs = np.sort(allimgs)
    np.random.seed(100)
    np.random.shuffle(allimgs)
    
    train_ratio = 0.9
    train_num = int(len(allimgs) * train_ratio)
    
    # 得到训练和验证数据集列表
    img_list_train = allimgs[:train_num]
    img_list_val = allimgs[train_num:]
    
    # 创建文件夹
    if os.path.exists(img_train):
    shutil.rmtree(img_train)
    os.mkdir(img_train)
    else:
    os.mkdir(img_train)
    
    if os.path.exists(img_val):
    shutil.rmtree(img_val)
    os.mkdir(img_val)
    else:
    os.mkdir(img_val)
    
    if os.path.exists(label_train):
    shutil.rmtree(label_train)
    os.mkdir(label_train)
    else:
    os.mkdir(label_train)
    
    if os.path.exists(label_val):
    shutil.rmtree(label_val)
    os.mkdir(label_val)
    else:
    os.mkdir(label_val)
    
    # 移动val数据到指定位置
    for i in img_list_val:
    img_id = i.split('.')[0].split('/')[2]
    print(img_id)
    # jpg
    shutil.copy(i, img_val + img_id + '.jpg')
    # xml
    shutil.copy('原始数据集/label/' + 'new_' + img_id + '.xml', label_val + img_id + '.xml')
    
    # 移动train数据到指定位置
    for i in img_list_train:
    img_id = i.split('.')[0].split('/')[2]
    print(img_id)
    # jpg
    shutil.copy(i, img_train + img_id + '.jpg')
    # xml
    shutil.copy('原始数据集/label/' + 'new_' + img_id + '.xml', label_train + img_id + '.xml')

    至此,得到的后面用于训练验证的数据集images和对应labels文件夹。见如下文件夹

    xml2yolo

    yolo的标签格式是txt格式,所以我们还需将xml标签转为txt格式。先新建txt_train、txt_valid文件夹,用于存放转换后的txt标签文件。运行

    2.xml2txt.py
    如下代码:

    import xml.etree.ElementTree as ET
    import pickle
    import os
    from os import listdir, getcwd
    from os.path import join
    import bs4
    from PIL import Image
    
    classes = ["person", "hat"]  #为了获得cls id
    
    def convert(size, box):
    dw = 1. / (size[0])
    dh = 1. / (size[1])
    x = (box[0] + box[1]) / 2.0 - 1
    y = (box[2] + box[3]) / 2.0 - 1
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return (x, y, w, h)
    
    def convert_annotation(image_id):
    global none_counts
    
    # 输入文件xml
    in_file = open('./labels/valid/%s.xml' % (image_id))
    # 输出label txt
    out_file = open('./labels/txt_valid/%s.txt' % (image_id), 'w')
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    print(image_id)
    
    if size == None:
    print('{}不存在size字段'.format(image_id))
    # 第一个处理方法
    img = Image.open('./VOCPerson/JPEGImages/' + image_id + '.jpg')
    w, h = img.size  # 大小/尺寸
    print('{}.xml缺失size字段, 读取{}图片得到对应 w:{} h:{}'.format(image_id, image_id, w, h))
    # # 第二种处理方法
    # # 移除xml
    # os.remove('./VOCPerson/Annotations/' + image_id + '.xml')
    # # 移除上面被移除掉xml对应的jpg
    # os.remove('./VOCPerson/JPEGImages/' + image_id + '.jpg')
    
    none_counts += 1
    else:
    
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    for obj in root.iter('object'):
    cls = obj.find('name').text
    if cls not in classes:
    continue
    cls_id = classes.index(cls)
    xmlbox = obj.find('bndbox')
    b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
    float(xmlbox.find('ymax').text))
    bb = convert((w, h), b)
    out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
    
    if __name__=='__main__':
    xml_count = 0
    none_counts = 0
    list_file = os.listdir('./labels/valid/')
    for file in list_file:
    print(file)
    # image_id = file.replace('.xml', '')
    image_id = file.split('.')[0]
    convert_annotation(image_id)
    xml_count = xml_count + 1
    print('没有size字段的xml文件数目:{}'.format(none_counts))
    print('总xml个数是 {}'.format(xml_count))

    得到yolo训练的标签文件。
    注意1:这里需要修改文件路径list_file,将train,valid数据集的xml全部转换为txt格式放到对应label文件夹下
    注意2:转换完成后将原来的train,valid文件夹删除并将txt_train,txt_valid文件夹重命名看那个为train,valid。

    生成train/valid.txt

    运行如下代码

    3.creattxt.py
    ,生成train.txt文件和valid.txt文件。train.txt文件内容是train数据集的路径,valid.txt文件内容是valid的数据集路径。
    train.txt
    部分内容如下:

    data/images/train/1.jpg
    data/images/train/10.jpg
    data/images/train/1000.jpg
    data/images/train/1001.jpg
    data/images/train/1002.jpg
    data/images/train/1005.jpg
    data/images/train/1006.jpg
    data/images/train/1007.jpg
    data/images/train/1008.jpg
    data/images/train/1009.jpg

    3.creattxt.py
    如下:

    # 根据训练数据集和验证数据集persontrain.txt and personvalid.txt
    import os, random, shutil
    
    trainDir = 'images/train/'
    validDir = 'images/valid/'
    
    train_pathDir = os.listdir(trainDir)  # 取图片的原始路径
    print('训练集图片数目: {}'.format(len(train_pathDir)))
    
    valid_pathDir = os.listdir(validDir)  # 取图片的原始路径
    print('验证集图片数目: {}'.format(len(valid_pathDir)))
    
    # 删除persontrain.txt and personvalid.txt
    if(os.path.exists('train.txt')):
    os.remove('train.txt')
    print('删除train.txt成功')
    
    if(os.path.exists('valid.txt')):
    os.remove('valid.txt')
    print('删除valid.txt成功')
    
    def text_save(root, filename, data):  # filename为写入CSV文件的路径,data为要写入数据列表.
    file = open(filename, 'a')
    for i in range(len(data)):
    s = str(data[i]).replace('[', '').replace(']', '')  # 去除[],这两行按数据不同,可以选择
    s = 'data/' + root + s.replace("'", '').replace(',', '') + '\n'  # 去除单引号,逗号,每行末尾追加换行符
    file.write(s)
    file.close()
    print("保存文件成功")
    
    if __name__ == '__main__':
    text_save(trainDir, './train.txt', train_pathDir)
    text_save(validDir, './valid.txt', valid_pathDir)
    print('train.txt 有 {} 行'.format(len([i for i in open('./train.txt', 'r')])))
    print('valid.txt 有 {} 行'.format(len([i for i in open('./valid.txt', 'r')])))

    至此,数据准备工作完成。

    YOLOv4配置

    需要修改配置的地方主要有三处:

    cfg
    文件,
    data
    文件,
    names
    文件。

    cfg文件修改

    这里推荐使用

    yolov4-relu.cfg
    文件,因为我用了
    yolov4.cfg
    训练太吃显存,2080TI的机子
    batchsize
    设置为2也会爆显存。其主要原因是
    mish
    函数太占显存了。。。
    yolov4-relu.cfg
    mish
    函数替代为
    relu
    函数,大大降低显存使用!

    修改其中一处如下,一共有三处,分别对应三个detect header:

    [convolutional]
    size=1
    stride=1
    pad=1
    filters=21  # 修改为 (class数目 + 4 +1) × 3 = 21
    activation=linear
    
    [yolo]
    mask = 6,7,8
    anchors = 12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401
    classes=2   #修改为类别数目,这里是2
    num=9
    jitter=.3
    ignore_thresh = .7
    truth_thresh = 1
    random=1
    scale_x_y = 1.05
    iou_thresh=0.213
    cls_normalizer=1.0
    iou_normalizer=0.07
    iou_loss=ciou
    nms_kind=greedynms
    beta_nms=0.6

    data文件修改

    在data/下新建hat.data文件,填入以下内容即可:

    classes= 2
    train=data/train.txt
    valid=data/valid.txt
    names=data/hat.names

    names文件修改

    在data/下新建hat.names文件,填入以下类别名字即可:

    person
    hat

    以上,完成配置修改。

    模型训练

    晚上以上就可以正常开始训练了,

    train.py
    文件内容配置修改如下:

    if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=300)  # 500200 batches at bs 16, 117263 COCO images = 273 epochs
    parser.add_argument('--batch-size', type=int, default=6)  # effective bs = batch_size * accumulate = 16 * 4 = 64
    parser.add_argument('--cfg', type=str, default='cfg/yolov4-relu-hat.cfg', help='*.cfg path')
    parser.add_argument('--data', type=str, default='data/hat.data', help='*.data path')
    parser.add_argument('--multi-scale', action='store_true', help='adjust (67%% - 150%%) img_size every 10 batches')
    parser.add_argument('--img-size', nargs='+', type=int, default=[320, 640], help='[min_train, max-train, test]')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--notest', action='store_true', help='only test final epoch')
    parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
    parser.add_argument('--weights', type=str, default='weights/yolov4.weights', help='initial weights path')
    parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
    parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)')
    parser.add_argument('--adam', action='store_true', help='use adam optimizer')
    parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
    opt = parser.parse_args()
    opt.weights = last if opt.resume else opt.weights
    check_git_status()
    print(opt)

    训练日志:

    电费太贵了!我训练100多点epoch就终止了,验证集

    map
    达到
    89%
    (之前用YOLOv3验证集
    map
    不到
    60%
    。。。。),相当可以。

    模型测试

    放两张

    YOLOv4
    的测试图片如下:


    可以看到
    YOLOv4
    效果还是挺可以。此外我计算YOLOv4的实际推断速度发现,
    YOLOv4
    也大大快于
    YOLOv3
    ,相比
    YOLOv3
    又快又准!!真香!!
    不愧是怼了一推佐料的大杀器。。。

    当然,YOLOv4也有错检漏检的情况,如下图,毕竟是一阶段检测器,但是在精度和准确度的权衡上,YOLOv4是相当可以了!

    ————————————————————————————————————————————————
    更新github链接如下:
    https://github.com/cendelian/YOLOv4-Hat-detection

    ————————————————2020.8.6————————————————————————————
    share数据(包括原始数据)和训练权重如下:

    链接:https://pan.baidu.com/s/1QaoyxeI5_95R2pbR9Elhaw
    提取码:fvnj

    OVER!

    内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
    标签: 
    相关文章推荐