您的位置:首页 > Web前端 > JavaScript

用tensorflow.js把手写字识别项目部署到网上(三)

2020-06-03 05:03 417 查看

script.js 包含训练和推理过程的相关代码,同时还包括训练和推理的可视化。

import {MnistData} from './data.js';
// DOM handles; all of them are assigned inside init().
let canvas;
let ctx;
let saveButton;
let clearButton;
// Most recent pointer position in canvas coordinates; mutated by setPosition().
const pos = {x: 0, y: 0};
// <img> element that receives a PNG snapshot of the canvas after each stroke.
let rawImage;
// The compiled network; assigned by getModel() and read by save().
let model;

/**
 * Builds and compiles the classifier and stores it in the module-level
 * `model` variable.
 *
 * Layer stack, in order: conv2d (8 filters, 3x3, relu, input 28x28x1),
 * 2x2 max-pooling, conv2d (16 filters, 3x3, relu), 2x2 max-pooling,
 * flatten, dense (128 units, relu), dense (10 units, softmax).
 * Compiled with the Adam optimizer, categorical cross-entropy loss and the
 * 'accuracy' metric.
 *
 * @returns {object} The compiled model (the same object stored in `model`).
 */
function getModel() {
  const relu = 'relu';
  // Each entry is [tf.layers factory name, ...factory arguments].
  const layerSpecs = [
    ['conv2d', {inputShape: [28, 28, 1], kernelSize: 3, filters: 8, activation: relu}],
    ['maxPooling2d', {poolSize: [2, 2]}],
    ['conv2d', {filters: 16, kernelSize: 3, activation: relu}],
    ['maxPooling2d', {poolSize: [2, 2]}],
    ['flatten'],
    ['dense', {units: 128, activation: relu}],
    ['dense', {units: 10, activation: 'softmax'}],
  ];

  model = tf.sequential();
  // Create and append one layer at a time, in declaration order.
  for (const [factory, ...args] of layerSpecs) {
    model.add(tf.layers[factory](...args));
  }

  model.compile({
    optimizer: tf.train.adam(),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  return model;
}

/**
 * Fits `model` on one batch of training data while streaming the
 * loss/accuracy curves to a tfjs-vis panel named 'Model Training'.
 *
 * Pulls 5500 training examples and 1000 validation examples from `data`,
 * reshapes the inputs to [N, 28, 28, 1], and trains for 20 epochs with a
 * batch size of 512 and shuffling enabled.
 *
 * The four tensors created here are released once fitting settles
 * (successfully or not) so they do not leak.
 *
 * @param {object} model - Compiled model exposing `fit(xs, ys, options)`.
 * @param {object} data - Source exposing `nextTrainBatch(n)` and
 *     `nextTestBatch(n)`, each returning `{xs, labels}`.
 * @returns {Promise<object>} Resolves with whatever `model.fit` resolves to.
 */
async function train(model, data) {
  const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
  const container = { name: 'Model Training', styles: { height: '640px' } };
  const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

  const BATCH_SIZE = 512;
  const TRAIN_DATA_SIZE = 5500;
  const TEST_DATA_SIZE = 1000;

  // tf.tidy frees the intermediate (pre-reshape) tensors; the returned ones
  // stay alive and are disposed explicitly in the `finally` below.
  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    return [
      d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });

  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(TEST_DATA_SIZE);
    return [
      d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });

  try {
    // `await` (rather than a bare return) so the tensors outlive the fit.
    return await model.fit(trainXs, trainYs, {
      batchSize: BATCH_SIZE,
      validationData: [testXs, testYs],
      epochs: 20,
      shuffle: true,
      callbacks: fitCallbacks
    });
  } finally {
    tf.dispose([trainXs, trainYs, testXs, testYs]);
  }
}

/**
 * Records the pointer location of a mouse event in `pos`, expressed relative
 * to the canvas drawing area.
 *
 * The canvas origin is measured at call time rather than assumed to be at a
 * fixed (100, 100) viewport offset, so the result stays correct when the
 * canvas is positioned elsewhere, has a border, or the page is scrolled.
 * Note: no correction is made for a canvas whose CSS size differs from its
 * width/height attributes.
 *
 * @param {MouseEvent} e - Event supplying `clientX` / `clientY`.
 */
function setPosition(e) {
  const rect = canvas.getBoundingClientRect();
  // rect.left/top is the border-box corner; clientLeft/clientTop skip the border.
  pos.x = e.clientX - rect.left - canvas.clientLeft;
  pos.y = e.clientY - rect.top - canvas.clientTop;
}

/**
 * Mouse-move handler: while `e.buttons` is exactly 1, strokes a white,
 * round-capped, 24px-wide segment from the previously stored position to the
 * event's position, then copies the canvas into `rawImage` as a PNG data URL.
 *
 * @param {MouseEvent} e - The mouse event being handled.
 */
function draw(e) {
  const isDrawing = e.buttons === 1;
  if (!isDrawing) {
    return;
  }

  ctx.beginPath();
  Object.assign(ctx, {
    lineWidth: 24,
    lineCap: 'round',
    strokeStyle: 'white',
  });

  // Segment start = last recorded position; setPosition() then moves `pos`
  // to the current pointer location, which becomes the segment end.
  const {x: startX, y: startY} = pos;
  ctx.moveTo(startX, startY);
  setPosition(e);
  const {x: endX, y: endY} = pos;
  ctx.lineTo(endX, endY);
  ctx.stroke();

  // Keep the snapshot image in step with what is on the canvas.
  rawImage.src = canvas.toDataURL('image/png');
}

/**
 * Clears the drawing by painting the whole canvas black, then refreshes the
 * `rawImage` snapshot so that a subsequent save() sees the blank canvas
 * instead of the previous stroke.
 */
function erase() {
  ctx.fillStyle = "black";
  // Use the canvas's own dimensions so the entire surface is always covered.
  ctx.fillRect(0, 0, canvas.width, canvas.height);
  // draw() is the only other place that updates rawImage; without this the
  // image would keep showing the erased drawing.
  rawImage.src = canvas.toDataURL('image/png');
}

/**
 * Classifies the current drawing and shows the predicted class index in an
 * alert box.
 *
 * Reads `rawImage` as a single-channel tensor, resizes it to 28x28, adds a
 * leading batch dimension, runs `model.predict`, and takes the argmax along
 * axis 1. All tensor work happens inside `tf.tidy` so the intermediate
 * tensors are released on every call instead of accumulating.
 *
 * NOTE(review): pixel values are passed to the model exactly as
 * `fromPixels` returns them; if the training data is scaled differently the
 * input should be rescaled to match — confirm against data.js.
 */
function save() {
  const pIndex = tf.tidy(() => {
    const raw = tf.browser.fromPixels(rawImage, 1);
    const resized = tf.image.resizeBilinear(raw, [28, 28]);
    const tensor = resized.expandDims(0);
    const prediction = model.predict(tensor);
    // dataSync() copies the result out as a typed array, so it stays valid
    // after tidy disposes the tensors.
    return tf.argMax(prediction, 1).dataSync();
  });

  alert(pIndex);
}

/**
 * Wires up the drawing UI: looks up the canvas ('canvas'), the snapshot
 * image ('canvasimg') and the two buttons ('sb', 'cb'), paints a 280x280
 * black background, and registers the mouse and click handlers.
 */
function init() {
  const byId = (id) => document.getElementById(id);

  canvas = byId('canvas');
  rawImage = byId('canvasimg');

  // Start from a black background.
  ctx = canvas.getContext('2d');
  ctx.fillStyle = 'black';
  ctx.fillRect(0, 0, 280, 280);

  // Drawing happens on move; press/enter only refresh the start position.
  const canvasHandlers = [
    ['mousemove', draw],
    ['mousedown', setPosition],
    ['mouseenter', setPosition],
  ];
  for (const [eventName, handler] of canvasHandlers) {
    canvas.addEventListener(eventName, handler);
  }

  saveButton = byId('sb');
  saveButton.addEventListener('click', save);

  clearButton = byId('cb');
  clearButton.addEventListener('click', erase);
}

/**
 * Application entry point: loads the MNIST data, builds the model, shows its
 * summary in a 'Model Architecture' panel, trains it, and only then enables
 * the drawing UI and notifies the user.
 *
 * @returns {Promise<void>} Settles once training and UI setup are complete.
 */
async function run() {
  const mnist = new MnistData();
  await mnist.load();

  const net = getModel();
  tfvis.show.modelSummary({name: 'Model Architecture'}, net);

  await train(net, mnist);

  // The canvas is wired up only after training has finished.
  init();
  alert("Training is done, try classifying your handwriting!");
}

// Start loading data and training as soon as the DOM has been parsed.
document.addEventListener('DOMContentLoaded', (event) => run(event));
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: