TensorFlow.js 手写数字识别

digit.js

import * as tf from '@tensorflow/tfjs'
import * as tfvis from '@tensorflow/tfjs-vis'
import {MnistData} from "./data/data";

/**
 * Draw a sample of MNIST training images into a tfjs-vis visor tab.
 * All tensors created here are disposed once the images are drawn.
 *
 * @param {MnistData} data - an already-loaded MnistData instance
 * @param {number} [count=20] - how many training examples to draw
 */
async function showExamples(data, count = 20) {
  const examples = data.nextTrainBatch(count)
  const surface = tfvis.visor().surface({name: '输入示例', tab: '显示图片'})
  try {
    for (let i = 0; i < count; i++) {
      // Take row i of the batch (784 values) and view it as a 28x28x1 image.
      const imageTensor = tf.tidy(() =>
        examples.xs.slice([i, 0], [1, 784]).reshape([28, 28, 1])
      )
      const canvas = document.createElement('canvas')
      canvas.width = 28
      canvas.height = 28
      canvas.style.margin = '4px'
      try {
        await tf.browser.toPixels(imageTensor, canvas)
      } finally {
        // imageTensor escaped the tidy above, so it must be freed explicitly.
        imageTensor.dispose()
      }
      surface.drawArea.appendChild(canvas)
    }
  } finally {
    // The preview batch is only used for display; release it.
    examples.xs.dispose()
    examples.labels.dispose()
  }
}

/**
 * Build and compile the CNN. It has two conv2d + maxPool2d stages, a flatten
 * layer, and a 10-unit softmax dense output layer.
 *
 * @returns {tf.Sequential} the compiled model
 */
function createModel() {
  const model = tf.sequential()

  // Stage 1: 8 filters of 5x5 over the 28x28x1 input, then 2x2 max-pooling.
  model.add(tf.layers.conv2d({
    inputShape: [28, 28, 1],
    kernelSize: 5,
    filters: 8,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }))
  model.add(tf.layers.maxPool2d({poolSize: [2, 2], strides: [2, 2]}))

  // Stage 2: 16 filters of 5x5, then 2x2 max-pooling.
  model.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }))
  model.add(tf.layers.maxPool2d({poolSize: [2, 2], strides: [2, 2]}))

  // Classifier: flatten the feature maps into a 10-way softmax.
  model.add(tf.layers.flatten())
  model.add(tf.layers.dense({
    units: 10,
    activation: 'softmax',
    kernelInitializer: 'varianceScaling'
  }))

  model.compile({
    loss: 'categoricalCrossentropy',
    optimizer: tf.train.adam(),
    metrics: 'accuracy'
  })
  return model
}

/**
 * Train the model on a training batch, validate on a test batch, and chart
 * loss/accuracy per epoch in the visor. The dataset tensors are disposed when
 * training finishes, whether it succeeds or fails.
 *
 * @param {tf.LayersModel} model - a compiled model taking [n, 28, 28, 1] input
 * @param {MnistData} data - an already-loaded MnistData instance
 * @param {{trainSize?: number, testSize?: number, epochs?: number}} [options]
 */
async function trainModel(model, data, {trainSize = 6000, testSize = 600, epochs = 50} = {}) {
  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(trainSize)
    return [d.xs.reshape([trainSize, 28, 28, 1]), d.labels]
  })
  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(testSize)
    return [d.xs.reshape([testSize, 28, 28, 1]), d.labels]
  })

  try {
    await model.fit(trainXs, trainYs, {
      validationData: [testXs, testYs],
      epochs,
      callbacks: tfvis.show.fitCallbacks({name: '训练效果'},
        ['loss', 'val_loss', 'acc', 'val_acc'],
        {callbacks: ['onEpochEnd']})
    })
  } finally {
    tf.dispose([trainXs, trainYs, testXs, testYs])
  }
}

/**
 * Attach the drawing pad behaviour to `canvas`:
 *  - dragging with the primary mouse button paints 25x25 white squares;
 *  - window.clear() fills the whole canvas with black;
 *  - window.predict() classifies the drawing with `model` and alerts the digit.
 *
 * @param {tf.LayersModel} model - the trained model
 * @param {HTMLCanvasElement} canvas - the page's drawing canvas
 */
function setupDrawingPad(model, canvas) {
  const ctx = canvas.getContext('2d')

  canvas.addEventListener('mousemove', (e) => {
    if (e.buttons === 1) {
      ctx.fillStyle = 'rgb(255,255,255)'
      ctx.fillRect(e.offsetX, e.offsetY, 25, 25)
    }
  })

  window.clear = () => {
    ctx.fillStyle = 'rgb(0,0,0)'
    // Use the real canvas size instead of a hard-coded 300x300.
    ctx.fillRect(0, 0, canvas.width, canvas.height)
  }

  window.clear()

  window.predict = () => {
    // Every intermediate tensor (pixels, resized image, logits) is freed by
    // tidy; only the argMax result escapes and is disposed below.
    const pred = tf.tidy(() => {
      const input = tf.image.resizeBilinear(tf.browser.fromPixels(canvas), [28, 28], true)
        .slice([0, 0, 0], [28, 28, 1]) // keep a single channel
        .toFloat()
        .div(255)
        .reshape([1, 28, 28, 1])
      return model.predict(input).argMax(1)
    })
    let digit
    try {
      digit = pred.dataSync()[0]
    } finally {
      pred.dispose()
    }
    alert(`预测结果为 ${digit}`)
  }
}

window.onload = async () => {
  const data = new MnistData()
  await data.load()

  await showExamples(data)

  const model = createModel()
  await trainModel(model, data)

  // The visor's canvases are appended after the page's own canvas, so this
  // picks up the drawing pad declared in index.html.
  setupDrawingPad(model, document.querySelector('canvas'))
}

// 卷积神经网络
// 卷积层 池化层 全连接层
// 卷积层通过卷积核(Image Kernels)提取特征
// 卷积层有需要训练的权重,卷积核就是权重
// 池化层用于提取最强特征、扩大感受野、减少计算量;池化层没有需要训练的权重
// 全连接层作为输出层(分类器),全连接层有需要训练的权重
// 设置损失函数与优化器
// 准备训练集和验证集

index.html
<!doctype html>
<!-- The UI text is Chinese, so declare the document language accordingly. -->
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport"
          content="width=device-width, user-scalable=no, initial-scale=1.0, maximum-scale=1.0, minimum-scale=1.0">
    <meta http-equiv="X-UA-Compatible" content="ie=edge">
    <title>手写数字识别</title>
</head>
<body>
<!-- Drawing pad: digit.js uses the first <canvas> in the document as its input. -->
<canvas width="300" height="300" style="border: 2px solid #666;"></canvas>
<br>
<!-- window.clear / window.predict are assigned by digit.js only after training finishes.
     NOTE(review): this page does not reference digit.js itself; presumably the bundler injects it. Confirm. -->
<button onclick="window.clear();" style="margin: 4px;">清除</button>
<button onclick="window.predict();" style="margin: 4px;">预测</button>
</body>
</html>
posted @ 2024-11-03 00:12  hotMemo  阅读(6)  评论(0)  编辑  收藏  举报