您的位置:首页 > 其它

在tensorflow上使用FCN训练自己的数据

2018-12-21 17:38 603 查看

在tensorflow上使用FCN训练自己的数据

参考文档:https://blog.csdn.net/m0_37407756/article/details/83379026 tensorflow实现FCN完成训练自己标注的数据

准备数据

  1. 在github上下载fcn的tensorflow版本实现
  2. 下载labelme对自己的图片进行标注,把生成的json文件放入你的json_file中
  3. 使用json2png.py批量将json文件转化为可以进行训练的png格式图片,此时生成的文件夹下每个json文件夹对应img.png,lable.png,label_viz.png三张图片。img.png作为输入,label.png作为标注图像(显示为全黑)。程序中将生成的png图片转化成8位的图片存储,此时label.png中像素实际以0,1,2…来分割图像,可以将灰度值放大来进行验证。
import argparse
import json
import os
import os.path as osp
import warnings

import numpy as np
import PIL.Image

from labelme import utils

def main():
'''
usage: python json2png.py json_file
'''
parser = argparse.ArgumentParser()
parser.add_argument('json_file')
parser.add_argument('-o', '--out', default=None)
args = parser.parse_args()

json_file = args.json_file

list = os.listdir(json_file)
for i in range(0, len(list)):
path = os.path.join(json_file, list[i])
if os.path.isfile(path):
data = json.load(open(path))
img = utils.img_b64_to_arr(data['imageData'])
lbl, lbl_names = utils.labelme_shapes_to_label(img.shape, data['shapes'])

captions = ['%d: %s' % (l, name) for l, name in enumerate(lbl_names)]
lbl_viz = utils.draw_label(lbl, img, captions)
out_dir = osp.basename(list[i]).replace('.', '_')
# out_dir = osp.join(osp.dirname(list[i]), out_dir)
out_dir = osp.join('./png', out_dir)
if not osp.exists(out_dir):
os.mkdir(out_dir)

PIL.Image.fromarray(img).save(osp.join(out_dir, 'img.png'))
lbl = PIL.Image.fromarray(np.uint8(lbl))
lbl.save(osp.join(out_dir, 'label.png'))
# PIL.Image.fromarray(lbl).save(osp.join(out_dir, 'label.png'))
PIL.Image.fromarray(lbl_viz).save(osp.join(out_dir, 'label_viz.png'))
print('Saved to: %s' % out_dir)

if __name__ == '__main__':
main()
  1. 将生成的img和label按原数据集http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip中MIT_SceneParsing/ADEChallengeData2016文件夹中类似存储。下面是批量存入annotations/training和images/training中的代码,validation中的图片可以从training中剪切,也可以修改下面的代码。
import os
import shutil
filedir = './png'
outdir_a = './FCN.tensorflow-master/Data_zoo/power/myData/annotations'
outdir_i = './FCN.tensorflow-master/Data_zoo/power/myData/images'
filelists = os.listdir(filedir)
filelists.sort()
for i,filename in enumerate(filelists):
print(filename)
filename = os.path.join(filedir, filename)

shutil.move(os.path.join(filename,'label.png'),os.path.join(outdir_a,'training'))
name1 = 'train-' + str(i+1).zfill(3) + '.png'
os.rename(os.path.join(outdir_a,'training/label.png'),os.path.join(outdir_a,'training',name1))

shutil.move(os.path.join(filename,'img.png'),os.path.join(outdir_i,'training'))
name2 = 'train-' + str(i+1).zfill(3) + '.png'
os.rename(os.path.join(outdir_i,'training/img.png'),os.path.join(outdir_i,'training',name2))

训练

  1. 将FCN.py中NUM_OF_CLASSESS改为自己训练数据的类别数,注意加上背景。vgg-19预训练模型在程序运行中会进行下载,也可以在训练前下载好http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat放在Model_zoo文件夹中。将data_dir改为自己的数据集所在文件,训练时mode为train。
  2. read_MITSceneParsingData.py中将SceneParsing_folder令为自己的文件夹
    SceneParsing_folder = 'myData'
    ,而不是下载。
  3. 默认batch_size大小是2,迭代次数为10000次
  4. 损失函数可视化:
tensorboard --logdir ./logs/train

测试

在FCN.py中将mode改为visualize,将网络生成的预测图像中灰度值不为0的点,可以在原图上对应位置将其灰度值修改为某固定值如200,就完成了可视化。FCN.py中修改如下。

elif FLAGS.mode == "visualize":
#num: the number of images to be tested which can be a single batch_size or all validation set
valid_images, valid_annotations, num = validation_dataset_reader.get_random_batch(FLAGS.batch_size)
pred = sess.run(pred_annotation, feed_dict={image: valid_images, keep_probability: 1.0})
pred = np.squeeze(pred, axis=3)
for itr in range(num):
src_img = valid_images[itr].astype(np.uint8)
pred_img = pred[itr].astype(np.uint8)
#save images to ./logs/test_visualize
utils.save_image(src_img, FLAGS.logs_dir + 'test_visualize/', name="inp_" + str(itr))
utils.save_image(pred_img, FLAGS.logs_dir + 'test_visualize/', name="pred_" + str(itr))
for i in range(pred_img.shape[0]):
for j in range(pred_img.shape[1]):
if pred_img[i,j] != 0:
#if your source images are RGB format, you need to change three channels
src_img[i,j]=200
utils.save_image(src_img, FLAGS.logs_dir + 'test_visualize/', name="visual_" + str(itr))
print("Saved image: %d" % itr)

相关文件都上传到了我的github中。

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: