TensorFlow2.0基本模型和训练框架
2020-02-04 04:25
302 查看
刚刚学习tensorflow2.0版本,总结了一下tensorflow2.0应用的基本框架,希望能帮到以后跟我一样刚刚接触的萌新。
在tensorflow2.0中,首先我们要设计一个自己的模型,所以要创建一个class类,这里我们用一个简单的CNN手写体识别网络来举例。在创建我们自己的model类时,首先要继承tensorflow为我们写好的模块父类,tf.keras.Model,并且至少重写其中的两个函数,init()和call()。
init函数是模块的初始化函数,在模块被创建时运行一次,我们把要实现的各个网络层在这里命名。在例子中我们共初始化了两个卷积层,两个池化层,两个线性层,一个flatten层和一个dropout层。call()函数会在我们调用模块时运行,例如我们实例化了一个cnn = CNN(),此时y = cnn(x)等价于y = cnn.call(x)。在这一函数中我们进行网络的搭建。至此模型的搭建就算结束了。
class CNN(tf.keras.Model): def __init__(self): super().__init__() # 继承父类的init函数,这里需要注意,在python2中需要写成super(CNN, self).__init__() self.conv1 = tf.keras.layers.Conv2D( filters=32, kernel_size = [5,5], activation=tf.nn.relu, padding='same' ) self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2,2], strides=2) self.conv2 = tf.keras.layers.Conv2D( filters=64, kernel_size=[5,5], activation=tf.nn.relu, padding='same' ) self.pool2 = tf.keras.layers.MaxPool2D(pool_size=[2,2], strides=2) self.flatten = tf.keras.layers.Reshape(target_shape=(7 * 7 * 64,)) self.drop1 = tf.keras.layers.Dropout(0.5) self.dense1 = tf.keras.layers.Dense(units=1024, activation=tf.nn.relu) self.dense2 = tf.keras.layers.Dense(units=10, activation=tf.nn.softmax) @tf.function def call(self, inputs): x = self.conv1(inputs) x = self.pool1(x) x = self.conv2(x) x = self.pool2(x) x = self.flatten(x) x = self.drop1(x) x = self.dense1(x) outputs = self.dense2(x) return outputs
在训练模型时,最简单的可以分为5步。
- 运用现在的模型生成预测,即y_pred
- 通过对比y_pred和y_true生成模型的loss
- 将一个batch内的loss压缩成一个实数
- 获得loss关于模型参数的梯度
- 通过获得的梯度优化模型
相应的python代码自然也就是5行。在那之前,我们要实例化一个优化器optimizer,它会帮我们根据梯度自动优化模型参数。相应代码如下:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001) # 优化器类型根据需求自行选择 for index in range(epoch_num): with tf.GradientTape() as tape: y_pred = cnn(x) # 对应步骤1 loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred) # 对应步骤2,loss模型根据需求自行选择 loss = tf.reduce_mean(loss) # 对应步骤3 print('epoch: %d, loss: %f' %(index, loss)) # 打印出每一次训练的loss grads = tape.gradient(loss, model.variables) # 对应步骤4 optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables)) # 对应步骤5
至此,一个简单模型的搭建和预测步骤就到此结束了。复杂的结构也可以参考这一简单框架。当我们需要应用模型时,只需要调用y_pred = cnn(x)即可。
- 点赞
- 收藏
- 分享
- 文章举报
相关文章推荐
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow2.0教程2:使用keras训练模型
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- 用tensorflow框架和Mnist手写字体,训练cnn模型以及测试一张手写字体
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TF:基于tensorflow框架利用python脚本下将YoloV3训练好的.ckpt模型文件转换为推理时采用的.pb文件
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载