您的位置:首页 > 其它

懒阳的深度学习日记tf2.0 mnist分类

2020-04-02 18:30 1131 查看

mnist分类

实现mnist训练 实现简单的分类任务

标题1、引用模块并下载数据

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0
#print(x_train.shape)
#print(y_train.shape)
#print(x_test.shape)
#print(y_test.shape)
#print(y_train[2]) #第三个图片属于第四类
#plt.imshow(x_train[2])

训练模型和评估

model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))    #将输入变为一维数据
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dense(units = 10,activation = 'softmax')) #设置全连接层
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)

优化器介绍

模型预测

loss, accuracy = model.evaluate(x_test, y_test)
print('accuracy is {}, loss is {}'.format(accuracy, loss))
  • 点赞
  • 收藏
  • 分享
  • 文章举报
逐原之鹿_懒阳 发布了2 篇原创文章 · 获赞 0 · 访问量 23 私信 关注
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: