您的位置:首页 > 移动开发 > Objective-C

Tensorflow训练自己的Object Detection模型并进行目标检测

2017-12-08 15:34 3091 查看
准备工作

项目目录概览

准备数据集和相关文件

制作TFRecord

修改配置文件

修改trainpy文件

tensorboard查看运行情况

生成pb文件

摄像头目标检测

0.准备工作

安装TensorFlow: 基于win10,GPU的Tensorflow Object Detection API部署及USB摄像头目标检测

下载TensorFlow/models: https://github.com/tensorflow/models

下载VOC2007数据集: voc2007数据集的下载和解压

下载预训练模型: ssd_inception_v2_coco_11_06_2017.tar.gz

1.项目目录概览



图1 object detection项目目录

2.准备数据集和相关文件

下载VOC2007数据集,解压放到dataset目录下,如图1。

复制
models\research\object_detection\dataset_tools\create_pascal_tf_record.py
文件到dataset目录下,如图1。

复制
models\research\object_detection\data\pascal_label_map.pbtxt
文件到dataset目录下,如图1。

解压预训练模型
ssd_inception_v2_coco_11_06_2017.tar.gz
文件到models目录下,如图1。

复制
models\research\object_detection\samples\configs\ssd_inception_v2_coco.config
到项目根目录下。

复制
models\research\object_detection
目录下的
train.py、eval.py和export_inference_graph.py
文件到项目根目录下。

复制基于win10,GPU的Tensorflow Object Detection API部署及USB摄像头目标检测文档中的
webcamdetect.py
文件到项目根目录下。

复制
models\research\object_detection
文件夹下的utils目录到项目根目录下,create_pascal_tf_record.py会用到。

3.制作TFRecord

create_pascal_tf_record.py第160行 :

examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main', 'aeroplane_' + FLAGS.set + '.txt')


为:

examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main/' + FLAGS.set + '.txt')


运行如下指令:

python dataset/create_pascal_tf_record.py \
--data_dir=dataset/VOCtrainval_06-Nov-2007/VOCdevkit \
--year=VOC2007 \
--set=train \
--output_path=record/pascal_train.record

python dataset/create_pascal_tf_record.py \
--data_dir=dataset/VOCtrainval_06-Nov-2007/VOCdevkit \
--year=VOC2007 \
--set=val \
--output_path=record/pascal_val.record


在record文件夹下生成
pascal_train.record、pascal_val.record
文件,如图1。

4.修改配置文件<

修改
ssd_inception_v2_coco.config
的关键语句:

...
model {
ssd {
num_classes: 20
...
train_config: {
batch_size: 24
optimizer {
rms_prop_optimizer: {
learning_rate: {
exponential_decay_learning_rate {
initial_learning_rate: 0.004
decay_steps: 10000
decay_factor: 0.95
}
...
num_steps: 20000
...
train_input_reader: {
tf_record_input_reader {
input_path: "record/pascal_train.record"
}
label_map_path: "dataset/pascal_label_map.pbtxt"
}

eval_config: {
num_examples: 4952
# Note: The below line limits the evaluation process to 10 evaluations.
# Remove the below line to evaluate indefinitely.
max_evals: 10
}

eval_input_reader: {
tf_record_input_reader {
input_path: "record/pascal_val.record"
}
label_map_path: "dataset/pascal_label_map.pbtxt"
shuffle: false
num_readers: 1
num_epochs: 1
}


5.修改train.py文件

去除警告:
Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2


import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


如果GPU内存不够大,务必使用CPU clones

flags.DEFINE_boolean('clone_on_cpu', True,
'Force clones to be deployed on CPU.  Note that even if '
'set to False (allowing ops to run on gpu), some ops may '
'still be run on the CPU if they have no GPU kernel.')


训练模型输出文件夹:

flags.DEFINE_string('train_dir', 'train',
'Directory to save the checkpoints and training summaries.')


设置pipeline_config_path:

flags.DEFINE_string('pipeline_config_path', 'ssd_inception_v2_coco.config',
'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
'file. If provided, other configs are ignored')


训练:

项目根目录下执行:

python train.py --logtostderr


6.tensorboard查看运行情况

项目根目录下执行:

tensorboard --logdir=train


7.生成pb文件

将train文件夹下的如下文件复制到pb文件夹下,并去除ckpt后面的“-数字”,checkpoint文件内相应也要改:

checkpoint
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta


在项目根目录下执行:

python export_inference_graph.py \
--pipeline_config_path ssd_inception_v2_coco.config \
--trained_checkpoint_prefix pb/model.ckpt \
--output_directory pb


在pb目录下可以找到生成的pb文件:

frozen_inference_graph.pb


8.摄像头目标检测

修改webcamdetect.py文件:

PATH_TO_CKPT = 'pb/frozen_inference_graph.pb'


屏蔽:

# opener = urllib.request.URLopener()
# opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)

# tar_file = tarfile.open(MODEL_FILE)
# for file in tar_file.getmembers():
#   file_name = os.path.basename(file.name)
#   if 'frozen_inference_graph.pb' in file_name:
#     tar_file.extract(file, os.getcwd())


在项目根目录下执行:

python webcamdetect.py






参考文献:

利用TensorFlow Object Detection API 训练自己的数据集

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  TensorFlow 目标检测
相关文章推荐