您的位置:首页 > 其它

tensorflow学习笔记六:保存和加载训练模型

2017-03-13 22:02 841 查看
对于机器学习,尤其是深度学习DL的算法,模型训练可能很耗时,几个小时或者几天,所以如果是测试模块出了问题,每次都要重新运行就显得很浪费时间,所以如果训练部分没有问题,那么可以直接将训练的模型保存起来,然后下次运行直接加载模型,然后进行测试很方便。

在tensorflow中保存(save)和加载(restore)模型的类是tf.train.Saver(),其中变量保存的是key-value,不传参数默认是全部变量。 



保存模型使用的是save函数,先创建一个saver对象, 


 

保存模型如下:
import tensorflow as tf
"""
声明variable和op
初始化op声明
"""
#创建saver对象,它添加了一些op用来save和restore模型参数
saver = tf.train.Saver()

with tf.Session() as sess:
sess.run(init_op)
#训练模型过程
#使用saver提供的简便方法去调用 save op
saver.save(sess, "save_path/file_name.ckpt")


加载模型使用的是restore函数,先创建一个saver对象, 


 

恢复模型如下:
import tensorflow as tf
"""
声明variable和op
初始化op声明
"""
#创建saver 对象
saver = tf.train.Saver()

with tf.Session() as sess:
sess.run(init_op)#可以执行或不执行,restore的值会override初始值
saver.restore(sess, "save_path/file_name.ckpt")
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐