Tensorflow实现部分参数梯度更新操作
2020-02-13 11:29
1331 查看
在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。
本文主要介绍,使用tensorflow部分更新模型参数的方法。
1. 根据Variable scope剔除需要固定参数的变量
def get_variable_via_scope(scope_lst): vars = [] for sc in scope_lst: sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope) vars.extend(sc_variable) return vars trainable_vars = tf.trainable_variables() no_change_scope = ['your_unchange_scope_name'] no_change_vars = get_variable_via_scope(no_change_scope) for v in no_change_vars: trainable_vars.remove(v) grads, _ = tf.gradients(loss, trainable_vars) optimizer = tf.train.AdamOptimizer(lr) train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)
2. 使用tf.stop_gradient()函数
在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数
3. 一个矩阵中部分行或列参数更新
如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。
import tensorflow as tf import numpy as np def entry_stop_gradients(target, mask): mask_h = tf.abs(mask-1) return tf.stop_gradient(mask_h * target) + mask * target mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1]) mask_h = np.abs(mask-1) emb = tf.constant(np.ones([10, 5])) matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1)) parm = np.random.randn(5, 1) t_parm = tf.constant(parm) loss = tf.reduce_sum(tf.matmul(matrix, t_parm)) grad1 = tf.gradients(loss, emb) grad2 = tf.gradients(loss, matrix) print matrix with tf.Session() as sess: print sess.run(loss) print sess.run([grad1, grad2])
以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考
您可能感兴趣的文章:
相关文章推荐
- 通过操作指针,与指针做函数参数'实现字串在主串中出现的次数,然后将出现的部分按照要求进行替换
- tensorflow中optimizer如何实现神经网络的权重,偏移等系数的更新和梯度计算
- Repeater控件实现编辑、更新、删除等操作示例代码
- PHP,操作多个用户,多个线程的session,实现用户登陆状态session值的自动更新
- 数据结构面试之五—二叉树的常见操作(递归实现部分
- Repeater控件实现编辑、更新、删除操作
- 稳扎稳打Silverlight(57) - 4.0通信之WCF RIA Services: 概述, 通过 DomainDataSource 实现数据的添加、查询、更新和删除操作
- 稳扎稳打Silverlight(57) - 4.0通信之WCF RIA Services: 概述, 通过 DomainDataSource 实现数据的添加、查询、更新和删除操作
- 二叉树的部分操作实现
- JS操作DOM节点实现网页更新
- 使用PrepareStatement接口,实现数据表的更新操作
- sql语句,oracal更新操作传入参数为对象,判断对象中的字段是否有值,如果有就更新,如果没有就不更新
- Java实现二叉树的创建和遍历操作(有更新)
- [TensorFlow深度学习入门]实战八·简便方法实现TensorFlow模型参数保存与加载(pb方式)
- MySQL 如何实现插入时如果不存在则插入,如果存在则更新的操作?
- 实现属于自己的TensorFlow(二) - 梯度计算与反向传播
- Repeater控件实现编辑、更新、删除操作
- MongoDB——更新操作(Update)c#实现
- TensorFlow 用 tf.nn.max_pool 实现最大池化操作
- shell脚本依据参数执行操作的实现