您的位置:首页 > 其它

深度学习【22】Mxnet多任务(multi-task)训练

2017-10-27 09:10 267 查看
github上有两个版本的多任务训练分别是:

1、https://github.com/miraclewkf/multi-task-MXNet

2、mxnet自带的例子

第一个由于其数据迭代器是Image,可能会比较慢。

第二个的例子是mnist,需要自己修改数据迭代器。

这里主要记录基于ImageRecordIter迭代器的多任务训练。

1、数据制作

需要自己生成*.lst文件,里面内容如下:

index   task1标签   task2标签    task3标签    图片路径(这行是说明,不需要写入,每一列用\t隔开)
2476    0.000000    0.000000    1.000000    photo_02_8159/00022552.jpg
7623    3.000000    2.000000    2.000000    photo_03_7397/00029434.jpg
14149   0.000000    0.000000    1.000000    photo_05_15560/00060839.jpg
6874    3.000000    1.000000    2.000000    photo_03_7397/00028414.jpg
6048    0.000000    0.000000    1.000000    photo_02_8159/00027259.jpg
14479   3.000000    3.000000    2.000000    photo_05_15560/00065068.jpg
10429   2.000000    0.000000    1.000000    photo_04_15224/00040186.jpg
6949    3.000000    0.000000    1.000000    photo_03_7397/00028521.jpg
81      3.000000    3.000000    2.000000    photo_01_19992/00002536.jpg
11725   2.000000    0.000000    1.000000    photo_05_15560/00051778.jpg
1517    2.000000    3.000000    2.000000    photo_02_8159/00021245.jpg


具体是生成方法可以参考mxnet提供的im2rec.py,可以自己写一个make_list函数。

生成*.rec文件。这个文件可以用im2rec.py生成,同时需要把pack-label设置为True。

2、修改模型结构

添加3个mx.symbol.SoftmaxOutput损失函数(因为我这边是3个任务):

fc1 = mx.symbol.FullyConnected(data=flat, num_hidden=5, name='fc1') #任务1 有5个类别
fc2 = mx.symbol.FullyConnected(data=flat, num_hidden=15, name='fc2') #任务2 有15个类别
fc3 = mx.symbol.FullyConnected(data=flat, num_hidden=3, name='fc3') #任务3 有3个类别
#分别为这三个任务添加softmax损失函数,注意每个函数的名称,后面会用到
s1 = mx.symbol.SoftmaxOutput(data=fc1, name='softmax1')
s2 = mx.symbol.SoftmaxOutput(data=fc2, name='softmax2')
s3 = mx.symbol.SoftmaxOutput(data=fc3, name='softmax3')
return  mx.symbol.Group([s1,s2,s3])


3、编写ImageRecordIter选项

train = mx.io.ImageRecordIter(
path_imgrec='/path/to/train.rec',
label_name=['softmax1_label', 'softmax2_label', 'softmax3_label'],#label名称,于softmax名称一样,后面要加入_label
label_width=3, #重要,需要设置label宽度为3,因为有3个任务
data_shape=[3,224,224],
batch_size=64
)

val = mx.io.ImageRecordIter(
path_imgrec='/path/to/val.rec',
label_name=['softmax1_label', 'softmax2_label', 'softmax3_label'],
label_width=3,
batch_size=64,
data_shape=[3,224,224],
)


4、定义多任务训练迭代器

class MultiTask_iter(mx.io.DataIter):
def __init__(self, data_iter):
super(MultiTask_iter,self).__init__('multitask_iter')
self.data_iter = data_iter
self.batch_size = self.data_iter.batch_size

@property
def provide_data(self):
return self.data_iter.provide_data

@property
def provide_label(self):
provide_label = self.data_iter.provide_label[0]
# the name of the label if corresponding to the model you define in get_fine_tune_model() function
return [('softmax1_label', [provide_label[1][0]]),#需要注意的地方
('softmax2_label', [provide_label[1][0]]),
('softmax3_label', [provide_label[1][0]])]

def hard_reset(self):
self.data_iter.hard_reset()

def reset(self):
self.data_iter.reset()

def next(self):
batch = self.data_iter.next()
#需要注意的地方
label = batch.label[0]
ll = label.asnumpy()
label1 = mx.nd.array(ll[:,0]).astype('float32')
label2 = mx.nd.array(ll[:,1]).astype('float32')
label3 = mx.nd.array(ll[:,2]).astype('float32')
# we set task 2 as: if label>0 or not

return mx.io.DataBatch(data=batch.data, label=[label1,label2,label3], \
pad=batch.pad, index=batch.index)


5、定义正确率计算方法

class Multi_Accuracy(mx.metric.EvalMetric):
"""Calculate accuracies of multi label"""

def __init__(self, num=None):
super(Multi_Accuracy, self).__init__('multi-accuracy', num)

def update(self, labels, preds):
mx.metric.check_label_shapes(labels, preds)

if self.num is not None:
assert len(labels) == self.num

for i in range(len(labels)):
pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
label = labels[i].asnumpy().astype('int32')

mx.metric.check_label_shapes(label, pred_label)

if i is None:
self.sum_metric += (pred_label.flat == label.flat).sum()
self.num_inst += len(pred_label.flat)
else:
self.sum_metric[i] += (pred_label.flat == label.flat).sum()
self.num_inst[i] += len(pred_label.flat)


6、训练

train = MultiTask_iter(train)#调用多任务迭代器,其中train参数就是第3步的东西
val = MultiTask_iter(val)

new_sym = get_symbol(10,50,image_shape)

optimizer_params = {
'learning_rate': 0.001,
'momentum' : args.mom,
'wd' : args.wd,
}
initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)
model = mx.mod.Module(
context       = devs,
symbol        = new_sym,
data_names=['data'],
label_names=['softmax1_label','softmax2_label','softmax3_label']
)
saveroot = args.save_result+'/' + args.save_name
checkpoint = mx.callback.do_checkpoint(saveroot)

model.fit(train,
begin_epoch=0,
num_epoch=100000,
eval_data=val,
eval_metric=Multi_Accuracy(num=3),#需要注意的地方
optimizer='sgd',

optimizer_params=optimizer_params,

initializer=initializer,
allow_missing=True,
batch_end_callback=mx.callback.Speedometer(64, 50),

epoch_end_callback=checkpoint
)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  mxnet 多任务学习
相关文章推荐