用tensorflow.js把手写字识别项目部署到网上(三)
2020-06-03 05:03
417 查看
script.js 包含训练和推理过程的代码，同时还负责训练过程和推理结果的可视化。
import {MnistData} from './data.js';

// Drawing surface and preview image; assigned in init() once training finishes.
let canvas;
let ctx;
let rawImage;
// Last pointer position, in canvas pixel coordinates.
const pos = {x: 0, y: 0};
// The trained network; assigned by getModel() and used by save() for inference.
let model;

// The model consumes IMAGE_SIZE x IMAGE_SIZE single-channel images.
const IMAGE_SIZE = 28;

/**
 * Builds and compiles the CNN classifier and stores it in the module-level `model`.
 * Input: [28, 28, 1] images; output: softmax over 10 digit classes.
 * @returns {tf.Sequential} the compiled model
 */
function getModel() {
  model = tf.sequential();
  model.add(tf.layers.conv2d({inputShape: [IMAGE_SIZE, IMAGE_SIZE, 1], kernelSize: 3, filters: 8, activation: 'relu'}));
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
  model.add(tf.layers.conv2d({filters: 16, kernelSize: 3, activation: 'relu'}));
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({units: 128, activation: 'relu'}));
  model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
  model.compile({optimizer: tf.train.adam(), loss: 'categoricalCrossentropy', metrics: ['accuracy']});
  return model;
}

/**
 * Trains `model` on a batch drawn from `data`, charting loss/accuracy with tfjs-vis.
 * The training and validation tensors are disposed once fitting ends, whether it
 * succeeds or fails.
 * @param {tf.LayersModel} model - compiled model to train
 * @param {MnistData} data - loaded dataset exposing nextTrainBatch/nextTestBatch
 * @returns {Promise<tf.History>} the history returned by model.fit
 */
async function train(model, data) {
  // With metrics: ['accuracy'], tfjs logs accuracy under the 'acc' / 'val_acc' keys.
  const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
  const container = {name: 'Model Training', styles: {height: '640px'}};
  const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

  const BATCH_SIZE = 512;
  const TRAIN_DATA_SIZE = 5500;
  const TEST_DATA_SIZE = 1000;

  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    return [d.xs.reshape([TRAIN_DATA_SIZE, IMAGE_SIZE, IMAGE_SIZE, 1]), d.labels];
  });
  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(TEST_DATA_SIZE);
    return [d.xs.reshape([TEST_DATA_SIZE, IMAGE_SIZE, IMAGE_SIZE, 1]), d.labels];
  });

  try {
    return await model.fit(trainXs, trainYs, {
      batchSize: BATCH_SIZE,
      validationData: [testXs, testYs],
      epochs: 20,
      shuffle: true,
      callbacks: fitCallbacks,
    });
  } finally {
    // tf.tidy kept these alive for fit(); release their backing memory now.
    tf.dispose([trainXs, trainYs, testXs, testYs]);
  }
}

/**
 * Records the pointer position of mouse event `e` in canvas pixel coordinates.
 * The conversion accounts for the canvas' page offset, its border, and any CSS
 * scaling, instead of assuming a fixed page position.
 * @param {MouseEvent} e
 */
function setPosition(e) {
  const rect = canvas.getBoundingClientRect();
  // clientWidth/clientHeight exclude the border; they guard against a CSS size
  // that differs from the canvas' intrinsic pixel size.
  const scaleX = canvas.clientWidth ? canvas.width / canvas.clientWidth : 1;
  const scaleY = canvas.clientHeight ? canvas.height / canvas.clientHeight : 1;
  pos.x = (e.clientX - rect.left - canvas.clientLeft) * scaleX;
  pos.y = (e.clientY - rect.top - canvas.clientTop) * scaleY;
}

/**
 * mousemove handler: while the primary button is held, draws a thick white
 * stroke from the previous to the current pointer position. It then mirrors the
 * canvas into the preview image.
 * @param {MouseEvent} e
 */
function draw(e) {
  if (e.buttons !== 1) {
    return;
  }
  ctx.beginPath();
  ctx.lineWidth = 24;
  ctx.lineCap = 'round';
  ctx.strokeStyle = 'white';
  ctx.moveTo(pos.x, pos.y);
  setPosition(e);
  ctx.lineTo(pos.x, pos.y);
  ctx.stroke();
  rawImage.src = canvas.toDataURL('image/png');
}

/**
 * Clears the drawing to solid black and syncs the preview image, so the preview
 * no longer shows the previous digit.
 */
function erase() {
  ctx.fillStyle = 'black';
  ctx.fillRect(0, 0, canvas.width, canvas.height);
  rawImage.src = canvas.toDataURL('image/png');
}

/**
 * Classifies the current drawing and alerts the predicted digit.
 * All intermediate tensors are created inside tf.tidy so none of them leak.
 */
function save() {
  const predicted = tf.tidy(() => {
    // Read the canvas itself. rawImage is refreshed asynchronously from a data
    // URL, so it can lag behind what is currently drawn.
    const raw = tf.browser.fromPixels(canvas, 1);
    const resized = tf.image.resizeBilinear(raw, [IMAGE_SIZE, IMAGE_SIZE]);
    // NOTE(review): assumes MnistData (data.js) scales training pixels to [0, 1],
    // as the tfjs-examples loader does; apply the same scaling here. Confirm
    // against data.js.
    const input = resized.div(255).expandDims(0);
    return model.predict(input).argMax(1);
  });
  const [digit] = predicted.dataSync();
  predicted.dispose();
  alert(digit);
}

/**
 * Looks up the canvas, preview image and buttons, paints a black background,
 * and wires up the drawing, predict ("sb") and clear ("cb") handlers.
 */
function init() {
  canvas = document.getElementById('canvas');
  rawImage = document.getElementById('canvasimg');
  ctx = canvas.getContext('2d');
  erase();
  canvas.addEventListener('mousemove', draw);
  canvas.addEventListener('mousedown', setPosition);
  canvas.addEventListener('mouseenter', setPosition);
  document.getElementById('sb').addEventListener('click', save);
  document.getElementById('cb').addEventListener('click', erase);
}

/**
 * Entry point: loads MNIST, builds and trains the model, and only then enables
 * the drawing UI.
 */
async function run() {
  const data = new MnistData();
  await data.load();
  model = getModel();
  tfvis.show.modelSummary({name: 'Model Architecture'}, model);
  await train(model, data);
  init();
  alert("Training is done, try classifying your handwriting!");
}

document.addEventListener('DOMContentLoaded', () => {
  // Surface load/training failures instead of leaving the rejection unhandled.
  run().catch((err) => {
    console.error('Loading or training the model failed:', err);
  });
});
相关文章推荐
- 用tensorflow.js把手写字识别项目部署到网上(四)
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
- TensorFlow MNIST 手写数字识别之过拟合
- 网上商城项目,前后端分离,springboot+vue.js,有线上部署教程
- 一个简单的TensorFlow.js的猫狗识别样例
- TensorFlow MNIST手写数字识别学习笔记(二)
- MNIST-手写数字识别-TensorFlow&&Pytorch
- TensorFlow—mnist手写识别
- TensorFlow在MNIST中的应用 识别手写数字(OpenCV+TensorFlow+CNN)
- [置顶] 【tensorflow CNN】构建cnn网络,识别mnist手写数字识别
- Tensorflow MNIST 手写识别
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
- python tensorflow 使用MNIST数据集实现手写数字识别
- TensorFlow MNIST手写数字识别学习笔记(一)
- TensorFlow CNN 以库函数的方式实现MNIST手写识别
- 在浏览器中进行深度学习:TensorFlow.js (四)用基本模型对MNIST数据进行识别
- 100天搞定机器学习|day39 Tensorflow Keras手写数字识别
- Android+TensorFlow+CNN+MNIST实现手写数字识别
- tensorflow mnist数据集手写字识别
- Anaconda安装python+tensorflow+opencv+dlib(人脸识别项目)的安装