tensorflow.js示例笔记 - mnist
使用层来进行数字识别,使用tf.layers api训练模型识别MNIST数据库中的手写数字。
index.html
<html> <head> <title>MNIST</title> <meta charset="UTF-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <link rel="stylesheet" href="../shared/tfjs-examples.css"/> <style> #train { margin-top: 10px; } label { display: inline-block; width: 250px; padding: 6px 0 6px 0; } .canvases { display: inline-block; } .prediction-canvas { width: 100px; } .pred { font-size: 20px; line-height: 25px; width: 100px; } .pred-correct { background-color: #00cf00; } .pred-incorrect { background-color: red; } .pred-container { display: inline-block; width: 100px; margin: 10px; } #train-epochs { width: 82px; font-size: 14px; } </style> </head> <body> <div class="tfjs-example-container"> <section class="title-area"> <h1>Digit Recognizer with Layers</h1> <p class="subtitle">Train a model to recognize handwritten digits from the MNIST database using the tf.layers api. </p> </section> <section> <p class="section-head">Description</p> <p> This examples lets you train a handwritten digit recognizer using either a Convolutional Neural Network (also known as a ConvNet or CNN) or a Fully Connected Neural Network (also known as a DenseNet). </p> <p>The MNIST dataset is used as training data.</p> </section> <section> <p class="section-head">Training Parameters</p> <div> <label>Model Type:</label> <select id="model-type"> <option>ConvNet</option> <option>DenseNet</option> </select> </div> <div> <label>Number of training epochs:</label> <input id="train-epochs" type="number" value="3" /> </div> <button id="train" disabled>Train Model</button> </section> <section> <p class="section-head">Training Progress</p> <p id="status"></p> <p id="message"></p> <div id="stats"> <div class="canvases"> <label id="loss-label"></label> <div id="loss-canvas"></div> </div> <div class="canvases"> <label id="accuracy-label"></label> <div id="accuracy-canvas"></div> </div> <br /> <div class="canvases"> <div id="loss-val-canvas"></div> </div> <div class="canvases"> <div id="accuracy-val-canvas"></div> </div> </div> </section> <section> <p class="section-head">Inference Examples</p> <div id="images"></div> </section> </div> <!-- TODO(cais): Decide. DO NOT SUBMIT. --> <!-- <script src="https://cdn.plot.ly/plotly-latest.min.js"></script> --> <script type="module" src="index.js"></script> </body> </html>
index.js
import * as tf from '@tensorflow/tfjs'; import {IMAGE_H, IMAGE_W, MnistData} from './data'; import * as ui from './ui'; /** * Creates a model consisting of only flatten, dense and dropout layers. * * The model create here has approximately the same number of parameters * (~31k) as the convnet created by `createConvModel()`, but is * expected to show a significantly worse accuracy after training, due to the * fact that it doesn't utilize the spatial information as the convnet does. * * This is for comparison with the convolutional network above. * * @returns {tf.Model} An instance of tf.Model. */ function createDenseModel() { const model = tf.sequential(); model.add(tf.layers.flatten({inputShape: [IMAGE_H, IMAGE_W, 1]})); model.add(tf.layers.dense({units: 42, activation: 'relu'})); model.add(tf.layers.dense({units: 10, activation: 'softmax'})); return model; // 该普通稠密层模型与卷积神经网络模型相比,参数数量都在32000个左右,基本维持了一个公平的形式。但在最终的损失和准 // 确率上,普通稠密层模型都比不过后者。虽然与卷积神经网络的模型相比,准确率差异只有2%左右,但错误率却是几倍。因此 // 在处理图片类业务上,卷积神经网络模型较稠密层模型有明显的优势。 } /** * Creates a convolutional neural network (Convnet) for the MNIST data. * * @returns {tf.Model} An instance of tf.Model. */ function createConvModel() { // Create a sequential neural network model. tf.sequential provides an API // for creating "stacked" models where the output from one layer is used as // the input to the next layer. const model = tf.sequential(); // The first layer of the convolutional neural network plays a dual role: // it is both the input layer of the neural network and a layer that performs // the first convolution operation on the input. It receives the 28x28 pixels // black and white images. This input layer uses 16 filters with a kernel size // of 5 pixels each. It uses a simple RELU activation function which pretty // much just looks like this: __/ // 第1层。 model.add(tf.layers.conv2d({ inputShape: [IMAGE_H, IMAGE_W, 1], // 要应用于输入数据的滑动卷积过滤器窗口的尺寸。在这里我们将kernelSize设为3,以指定方形的3*3卷积窗口。 kernelSize: 3, // 要应用于输入数据的尺寸为kernelSize的过滤器窗口数量。在这里我们将对数据应用16个过滤器。 filters: 16, activation: 'relu' })); // After the first layer we include a MaxPooling layer. This acts as a sort of // downsampling using max values in a region instead of averaging. // https://www.quora.com/What-is-max-pooling-in-convolutional-neural-networks // 第2层。 model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })); // Our third layer is another convolution, this time with 32 filters. // 第3层。 // 第3层和第4层这两个层是前两个层的完全重复(除了conv2d层在其过滤器配置中具有更大的值并且不具有inputShape字段)。 // 这种由卷积层和极化层组成的几乎重复的"模体"在convnets中是常见的。它在convnet中扮演着关键的角色:特征的分层提取。 // 为了理解它的含义,可以考虑一个训练过的convnet,它的任务是对图像中的动物进行分类。在convnet的早期阶段,卷积层中 // 的滤波器(即channel)可以编码诸如直线、曲线和角等低级几何特征。这些低级特征转化为更复杂的特征,如猫的眼睛、鼻子和 // 耳朵。在convnet的顶层,一个层可能有对整个cat的存在进行编码的过滤器。级别越高,表示越抽象,从像素级值中移除的特征 // 越多。但是,这些抽象的特征正是convnet任务实现良好精度所需要的,例如,在图像中时检测出猫。此外,这些特征不是手工制 // 作的,而是通过有监督的学习并以自动方式从数据中提取的。这是一个典型的有代表性的例子,我们也把它描述为层-层转换。 model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' })); // Max pooling again. // 第4层。 model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })); // Add another convolution layer. // 第5层。 model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' })); // Now we flatten the output from the 2D filters into a 1D vector to prepare // it for input into our last layer. This is common practice when feeding // higher dimensional data to a final classification output layer. // 第6层。 // 第6层为扁平层。它将多维张量“压缩”为一维张量,从而保持元素的总数。在我们的例子中,形状为[3,3,32]的3D张量被展平 // 为1D张量[288](没有批次维度)。挤压操作的一个明显问题是如何对元素排序,因为原始三维空间中没有内在的顺序。答案是: // 我们对元素进行排序,这样,如果你沿着展开的一维张量中的元素向下看,看看它们的原始索引(来自三维张量)如何变化,最后 // 一个索引变化最快,倒数第二个索引变化第二快,以此类推,而第一个索引变化最慢。 // 扁平层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。扁平层不影响batch的大小。 // 扁平层中没有权重,它只是将其输入展开为一个长数组。 model.add(tf.layers.flatten({})); // Previous flattened 1D vector will be this layer's input. // 第7层。 // 第7层和第8层我们添加了两个稠密层,为什么要添加两个而不是一个呢?原因是:添加具有非线性激活的层会增加网络的容量。 // 通常多层神经网络内,必须含有级联的线性函数和非线性函数,有助于表示能力的增强,这意味着模型容量的增大,预测精度的 // 提高。在boston-housing项目案例中,我们也做出过相同的总结。https://mapleroyals.com/forum/threads/4th-job-skill-changes.135025 // 实际上,您可以将convnet看作是由两个模型堆叠在一起: // 1. 包含conv2d、maxPooling2d和flatten层的模型,用于从输入图像中提取视觉特征。 // 2. 一种多层感知器(MLP),具有两个密集层,以提取的特征作为输入,并基于它进行数字类预测,这就是这两个密集层的本 // 质。在深度学习中,许多模型都利用了特征提取层的这种模式,然后最终预测使用MLP。在本书的其余部分中,我们将看到更多 // 这样的例子,从音频信号分类器到自然语言处理。 model.add(tf.layers.dense({ units: 64, activation: 'relu' })); // Our last layer is a dense layer which has 10 output units, one for each // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9). Here the classes actually // represent numbers, but it's the same idea if you had classes that // represented other entities like dogs and cats (two output classes: 0, 1). // We use the softmax function as the activation for the output layer as it // creates a probability distribution over our 10 classes so their output // values sum to 1. // 第8层。 // 这一层将输出10个值在0-1之间的数字元素,表示对0-9这10个数字的预测概率,它们的和为1,最大的概率对应的数值为预测的倾向值。 model.add(tf.layers.dense({ units: 10, activation: 'softmax' })); return model; } /** * This callback type is used by the `train` function for insertion into the * model.fit callback loop. * * @callback onIterationCallback * @param {string} eventType Selector for which type of event to fire on. * @param {number} batchOrEpochNumber The current epoch / batch number * @param {tf.Logs} logs Logs to append to */ /** * Compile and train the given model. * * @param {tf.Model} model The model to train. * @param {onIterationCallback} onIteration A callback to execute every 10 batches & epoch end. */ async function train(model, onIteration) { ui.logStatus('Training model...'); // We compile our model by specifying an optimizer, a loss function, and a // list of metrics that we will use for model evaluation. Here we're using a // categorical crossentropy loss, the standard choice for a multi-class // classification problem like MNIST digits. // The categorical crossentropy loss is differentiable and hence makes // model training possible. But it is not amenable to easy interpretation // by a human. This is why we include a "metric", namely accuracy, which is // simply a measure of how many of the examples are classified correctly. // This metric is not differentiable and hence cannot be used as the loss // function of the model. model.compile({ // Now that we've defined our model, we will define our optimizer. The // optimizer will be used to optimize our model's weight values during // training so that we can decrease our training loss and increase our // classification accuracy. // We are using rmsprop as our optimizer. // An optimizer is an iterative method for minimizing a loss function. // It tries to find the minimum of our loss function with respect to the // model's weight parameters. optimizer: 'rmsprop', // 损失函数categoricalCrossentropy分类交叉熵,适用于诸如MNIST之类的多分类问题。在iris分类项目案例中,我 // 们也使用了相同的损失函数,一般情况下,当模型的输出为概率分布时,就会使用此函数。分类交叉熵会生成一个数字,指 // 示预测向量与真实标签向量的相似程度。 loss: 'categoricalCrossentropy', // 度量标准函数调用accuracy。假设预测是基于convnet输出的10个元素中的最大元素进行的,则此函数度量的示例中的一 // 部分已正确分类。回想一下交叉熵损失和精度度量之间的区别:交叉熵是可微的,因此使基于反向传播的训练成为可能,而精 // 度度量是不可微的,但却更容易解释,因此对于分类问题,这是正确预测在所有预测中所占的百分比。 metrics: ['accuracy'], }); // Batch size is another important hyperparameter. It defines the number of // examples we group together, or batch, between updates to the model's weights // during training. A value that is too low will update weights using too few // examples and will not generalize well. Larger batch sizes require more memory // resources and aren't guaranteed to perform better. // 一般而言,使用较大的批次与较小的批次相比好处是,它对模型的权重产生了更一致且变化较小的渐变更新。但是批次大小越大,训练 // 期间就需要更多的内存。您还应该记住,在给定相同数量的训练数据的情况下,较大的批次大小会导致每个时期的梯度更新数量较少。 // 因此,如果您使用较大的批量,请确保相应地增加时期数,以免在训练过程中无意中减少了权重更新的次数。 const batchSize = 320; // Leave out the last 15% of the training data for validation, to monitor // overfitting during training. // 预留15%的训练数据用于训练过程中的验证。 const validationSplit = 0.15; // Get number of training epochs from the UI. const trainEpochs = ui.getTrainEpochs(); // We'll keep a buffer of loss and accuracy values over time. let trainBatchCount = 0; const trainData = data.getTrainData(); const testData = data.getTestData(); const totalNumBatches = Math.ceil(trainData.xs.shape[0] * (1 - validationSplit) / batchSize) * trainEpochs; // During the long-running fit() call for model training, we include callbacks, // so that we can plot the loss and accuracy values in the page as the training // progresses. let valAcc; await model.fit( trainData.xs, // 特征输入。 trainData.labels, // 标签输入。 { batchSize, // 每次梯度更新的样本数。 validationSplit, // 末尾15%的数据用于验证。 epochs: trainEpochs, // 在训练数据上的迭代次数。 callbacks: { onBatchEnd: async (batch, logs) => { trainBatchCount++; ui.logStatus( `Training... (${(trainBatchCount / totalNumBatches * 100).toFixed(1)}% complete). ` + 'To stop training, refresh or close page.' ); // 绘制损失和准确率图表。 ui.plotLoss(trainBatchCount, logs.loss); ui.plotAccuracy(trainBatchCount, logs.acc); // 每10个batch结束时更新测试集的预测结果。 if (onIteration && (batch % 10 === 0)) { onIteration('onBatchEnd', batch, logs); } // 主动让出线程,允许UI在训练过程中更新。 await tf.nextFrame(); }, onEpochEnd: async (epoch, logs) => { valAcc = logs.val_acc; // 绘制损失和准确率图表。 ui.plotValLoss(trainBatchCount, logs.val_loss); ui.plotValAccuracy(trainBatchCount, logs.val_acc); // 每个epoch结束时更新测试集的预测结果。 if (onIteration) { onIteration('onEpochEnd', epoch, logs); } await tf.nextFrame(); } } } ); // 在fit执行完成后(训练结束),对模型进行评估。 const testResult = model.evaluate(testData.xs, testData.labels); const testAccPercent = testResult[1].dataSync()[0] * 100; const finalValAccPercent = valAcc * 100; // 更新最后的验证准确率和评估(测试)准确率。 ui.logStatus( `Final validation accuracy: ${finalValAccPercent.toFixed(1)}%; ` + `Final test accuracy: ${testAccPercent.toFixed(1)}%` ); } /** * Show predictions on a number of test examples. * * @param {tf.Model} model The model to be used for making the predictions. */ async function showPredictions(model) { const testExamples = 100; const examples = data.getTestData(testExamples); // Code wrapped in a tf.tidy() function callback will have their tensors freed // from GPU memory after execution without having to call dispose(). // The tf.tidy callback runs synchronously. tf.tidy(() => { // output为形状为[100, 10]的二维张量,第一维对应100个特征输入(手写数字图片),第二维对应 // 10个可能的数字的概率。 const output = model.predict(examples.xs); // tf.argMax() returns the indices of the maximum values in the tensor along // a specific axis. Categorical classification tasks like this one often // represent classes as one-hot vectors. One-hot vectors are 1D vectors with // one element for each output class. All values in the vector are 0 // except for one, which has a value of 1 (e.g. [0, 0, 0, 1, 0]). The // output from model.predict() will be a probability distribution, so we use // argMax to get the index of the vector element that has the highest // probability. This is our prediction. (e.g. argmax([0.07, 0.1, 0.03, 0.75, 0.05]) == 3) // dataSync() synchronously downloads the tf.tensor values from the GPU so // that we can use them in our normal CPU JavaScript code // (for a non-blocking version of this function, use data()). // output中的axis 0的每一行,代表某个图片可能的10个数字的概率,需要找出其中最大的,将其作为 // 模型的预测值。 // argMax()函数返回沿给定轴的最大值的索引。在这种情况下,此轴是第二维,即const axis = 1。 // argMax()的返回值是形状为[100,1]的张量。通过调用 dataSync(),我们将[100,1]形张量转 // 换为长度为100的Float32Array。然后Array.from()将Float32Array转换为一个普通的JavaScript // 数组,该数组由100个介于0和9之间的整数组成。此预测数组的含义非常简单:这是模型对100个输入图 // 像进行分类的结果。在MNIST数据集中,目标标签恰好与输出索引完全匹配。因此,我们甚至不需要将数 // 组转换为字符串标签。预测数组由下一行使用,该行调用一个UI函数,该函数将分类结果与测试图像一起呈现。 const axis = 1; const labels = Array.from(examples.labels.argMax(axis).dataSync()); const predictions = Array.from(output.argMax(axis).dataSync()); // 更新测试集的预测结果,对100个图标注对它的预测数字和正确性(绿色)。 ui.showTestResults(examples, predictions, labels); }); } function createModel() { let model; const modelType = ui.getModelTypeId(); if (modelType === 'ConvNet') { model = createConvModel(); } else if (modelType === 'DenseNet') { model = createDenseModel(); } else { throw new Error(`Invalid model type: ${modelType}`); } return model; } let data; // When page ready, loads the MNIST data, trains the model, and then shows what // the model predicted on unseen test data. window.onload = async () => { ui.logStatus('Loading MNIST data...'); data = new MnistData(); await data.load(); ui.logStatus('Creating model...'); const model = createModel(); // 输出模型规格。 model.summary(); // __________________________________________________________________________________________ // Layer (type) Input Shape Output shape Param # // ========================================================================================== // conv2d_Conv2D1 (Conv2D) [[null,28,28,1]] [null,26,26,16] 160 // __________________________________________________________________________________________ // max_pooling2d_MaxPooling2D1 [[null,26,26,16]] [null,13,13,16] 0 // __________________________________________________________________________________________ // conv2d_Conv2D2 (Conv2D) [[null,13,13,16]] [null,11,11,32] 4640 // __________________________________________________________________________________________ // max_pooling2d_MaxPooling2D2 [[null,11,11,32]] [null,5,5,32] 0 // __________________________________________________________________________________________ // conv2d_Conv2D3 (Conv2D) [[null,5,5,32]] [null,3,3,32] 9248 // __________________________________________________________________________________________ // flatten_Flatten1 (Flatten) [[null,3,3,32]] [null,288] 0 // __________________________________________________________________________________________ // dense_Dense1 (Dense) [[null,288]] [null,64] 18496 // __________________________________________________________________________________________ // dense_Dense2 (Dense) [[null,64]] [null,10] 650 // ========================================================================================== // Total params: 33194 // Trainable params: 33194 // Non-trainable params: 0 // __________________________________________________________________________________________ // // 从上述图表信息可见该模型总参数为33194个,其中池化层、扁平层无参数。 // 在深度学习中,通常模型自身的参数和模型的输出会占用显存。 // 有参数的层主要包括:卷积层、全连接层、BatchNorm层、Embedding层等。 // 无参数的层主要包括:多数的激活层(Sigmoid/ReLU)、池化层、Dropout层、扁平层等。 ui.logStatus('Model ready.'); // Enable the "Train Model" button. document.getElementById('train').removeAttribute('disabled'); ui.setTrainButtonCallback(async () => { ui.logStatus('Starting model training...'); document.getElementById('images').innerHTML = ''; await train(model, () => showPredictions(model)); }); };
data.js
import * as tf from '@tensorflow/tfjs'; // This is a helper class for loading and managing MNIST data specifically. // It is a useful example of how you could create your own data manager class // for arbitrary data though. It's worth a look :). // MnistData类封装了对MNIST灰度图片训练集和标签集的数据加载和预处理过程,为了使数据对下层算法和框架有友好性,数据的 // 格式和存储都以二进制缓冲区和类型化数组为主。这里的数据处理相对于以JSON为主的数据处理的普通web应用而言是稍显复杂的。 // 安东尼·戈德布卢姆(Kaggle的CEO)曾经这样说过:有人开玩笑说有80%的数据科学家在清理数据,剩下的20%在抱怨清理数据。 // 在实际的数据科学(大数据、机器学习)工作中,清理数据所占比例比外人想象的要多得多。一般而言,训练模型通常只占机器学习 // 或数据科学家工作的一小部分(少于10%)。当下在机器学习的纯应用领域上,以成熟框架和算法作为基础的开发上,更是如此。 const MNIST_IMAGES_SPRITE_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png'; const MNIST_LABELS_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8'; export const IMAGE_H = 28; export const IMAGE_W = 28; const IMAGE_SIZE = IMAGE_H * IMAGE_W; const NUM_CLASSES = 10; // 0-9。 const NUM_DATASET_ELEMENTS = 65000; // MNIST精灵图有65000个子图,精灵图的高度也为65000像素。 const NUM_TRAIN_ELEMENTS = 55000; // 我们用于训练的子图为55000个。 // A class that fetches the sprited MNIST dataset and provide data as tf.Tensors. export class MnistData { constructor() { } async load() { // Make a request for the MNIST sprited image. let img = new Image(); // 这里我们用到了canvas是因为DOM的img对象并没有提供给我们获取其像素的API,DOM这个级别的API是无法操作像素的。 // 只能通过canvas来做一个中间层,将img内容分批次绘制到canvas上,再从canvas提取出像素数据,看下面for循环部分。 const canvas = document.createElement('canvas'); const ctx = canvas.getContext('2d', { willReadFrequently: true, }); // 请求手写数字灰度图片,并将图片转化为数据格式。 const imgRequest = new Promise((resolve) => { // CORS配置。 img.crossOrigin = ''; // naturalWidth和naturalHeight指图片的原始大小,在计算时可以强制校正图片尺寸。 img.onload = () => { img.width = img.naturalWidth; // 图片本身宽度784。 img.height = img.naturalHeight; // 图片本身长度65000。 // 初始化一个新的二进制缓冲区,包含整个MMINST精灵图的每个像素(每一张子图的每个像素)。它将图像总数和每张图像的尺寸和通道数量相乘。 // 图片为PNG格式,所以有rgba4个通道,最后乘以4。 const datasetBytesBuffer = new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4); // 65000 * (28 * 28) * 4。 const chunkSize = 5000; canvas.width = img.width; canvas.height = chunkSize; // 分13次处理完,每次处理MNIST精灵图中的5000个子图。 // NUM_DATASET_ELEMENTS和img.height一致,都为65000。每次取5000像素的高度,分13次就能取完。为什么它们两个的一致?我们可以这 // 么看:因为精灵图宽度为784像素,784整好是每个子图的总像素,因为单个子图的宽度为28*28,整好就是784,所以MNIST一横排的像素总量正 // 好就是一个子图的像素总量,从计数上就可以看成一个子图占了MNIST的一排,MNIST有多少排就有多少张子图,MNIST有65000像素高,也就是 // 有65000排,也就是我们可以算出子图有65000张,这个数量和实际的子图数量是一致的。从获取到的MNIST精灵图来看,子图并不是以28*28的 // 长宽方位比依次存放在MNIST上的,而是以1*784的长宽方位比来存放的,正好就是和我们想的一致:一排就是一张子图。将每张子图按照28来切割 // 并换行,就可以组成一个28*28的方形原始图,这下我们就能看明白手写的数字了,虽然这种排列方式利于眼睛的分辨,但显然不适合像素的提取, // 一旦这样排列,我们在提取子图像素的时候,就变成了去处理矩阵数据(单个子图的正方形就可以看做是一个28*28矩阵)。 // 下面内容涉及到的两个for循环,其中外层的循环遍历chunkSize,内层的循环遍历某个chunkSize下的像素。 const chunkNum = NUM_DATASET_ELEMENTS / chunkSize; // => 65000 / 5000 = 13。 for (let i = 0; i < chunkNum; i++) { // drawImage允许我们从一个图片源上的裁剪某个特定矩形区域并绘制到canvas的特定位置、大小的区域上。 // 这里我们从image上裁剪同image宽度,高度为chunkSize的区域绘制到canvas上。下一次循环就裁剪下 // 一个chunkSize高度的区域。 ctx.drawImage( img, // 绘制到上下文的元素,允许任何画布图像源,例如:HTMLImageElement、SVGImageElement等。 // 以下参数都基于需要绘制到目标上下文的image的矩形(裁剪)选择框。 0, // 左上角X轴坐标。 i * chunkSize, // 左上角Y轴坐标。 img.width, // 宽度。 chunkSize, // 高度。 // 以下参数都基于image的左上角在目标画布上的绘制。 0, // 左上角X轴坐标。 0, // 左上角Y轴坐标。 img.width, // 宽度。 chunkSize // 高度。 ); // getImageData返回一个ImageData对象,包含canvas给定的矩形区域的像素数据,这里的话每个循环里的canvas上线文 // 就是15680000个像素(784 * 5000 * 4 = 15680000,4个通道总的需要乘以4)。ImageData结构示例: // { // width: 100, // height: 100, // colorSpace: 'srgb', // data: Uint8ClampedArray[40000] // 像素数据。 // } const imageData = ctx.getImageData( // 以下参数都基于将要被提取的图像数据矩形区域。 0, // 左上角x坐标。 0, // 左上角y坐标。 canvas.width, // 宽度。 canvas.height // 高度。 ); // datasetBytesView长度为3920000,即精灵图中chunkSize个子图的像素总数量((28 * 28) * 5000)。 // 采用Float32类型数组当做数据视图来操作它,这个视图即表示了datasetBytesBuffer的一段空间。 const datasetBytesView = new Float32Array( // length => 3920000 // 数据。 datasetBytesBuffer, // 起始偏移量。第一次循环的时候i为0,也就是从0开始,即无偏移。第二次循环代的时候i为1,即偏移从5000个子图的四通道数 // 量以后开始。剩下的以此类推,每次循环都处理完当前chunkSize的数据,循环完13个chunkSize后,即整个MNIST精灵图的 // 数据就在datasetBytesBuffer中了。也就是说每次对chunkSize的迭代处理中,其实都是针对datasetBytesBuffer进行 // 处理,往它里面存放数据,也就是下面的内层for循环中,针对当前chunkSize下的每个像素,对datasetBytesBuffer从 // 当前的偏移值i * IMAGE_SIZE * chunkSize * 4开始存储,共存储IMAGE_SIZE * chunkSize个数据。 i * IMAGE_SIZE * chunkSize * 4, // => i * (28 * 28) * 5000 * 4。 // 元素数量。 IMAGE_SIZE * chunkSize // => (28 * 28) * 5000 = 392000 ); // 将该区域内总的像素数量除以4得到单通道像素数量。因为原数据包含了4个通道的数据,除以4才表示单个通道的数据长度,也就是该矩 // 形区域内所有像素的按单通道计算的数量。等同于MAGE_SIZE * chunkSize,也就是本次for循环的datasetBytesView的元素长度。 const len = imageData.data.length / 4; // => 15680000 / 4 = 3920000。 // 遍历当前chunkSize高度区域内的每一个像素(按单通道计算)。通过datasetBytesView视图将数据放入datasetBytesBuffer // 中的对应位置,来改变datasetBytesBuffer的值。 for (let j = 0; j < len; j++) { // All channels hold an equal value since the image is grayscale, so just read the red channel. // 因为图片为灰度化的,所有rgba通道都有相同的值,因此只需要读取r(红色)通道的值即可(间隔4个取1次,就能保证每次取的 // 都是r通道的值因此乘以了4)。最后将取出的r通道值除以255,得到0-1之间的数。赋值给datasetBytesView视图对应的下标, // datasetBytesView和datasetBytesBuffer是连通的,视图被改,datasetBytesBuffer就对应改变了,整个外层for循 // 环结束后,datasetBytesBuffer(this.datasetImages)就是整个MINST精灵图的所有像素的总数据(灰度数据)。 // 这里除以255是为了避免训练数据与预测数据之间出现不匹配的情况,因为我们的MNIST卷积网络应使用归一化到0-1之间的图像张 // 量数据进行训练。 datasetBytesView[j] = imageData.data[j * 4] / 255; } // 当上面这个for循环完成后,我们就得到了该chunkSize高度区域的所有的r通道的像素转成灰度的值。然后再接着遍历下一个chunkSize // 高度区域,并获得它的所有r通道的像素数据。 } // chunkSize高度区域都全部遍历完以后,datasetBytesBuffer里就已经有了整个MNIST精灵图的r通道的灰度像素信息。最终datasetImages // 的长度为50960000,即整个MNIST精灵图的像素总数量(65000 * (28 * 28))。 // 采用Float32类型数组当做数据视图来操作它。 this.datasetImages = new Float32Array(datasetBytesBuffer); // length => 50960000 // 图片DOM用完就回收。 img = null; resolve(); }; // 赋值src触发图片onload加载回调。 img.src = MNIST_IMAGES_SPRITE_PATH; }); // 请求标签数据。 const labelsRequest = fetch(MNIST_LABELS_PATH); // 等待图片、标签数据都加载并处理完成。 const [, labelsResponse] = await Promise.all([imgRequest, labelsRequest]); // 将标签数据(uint8)存在一个二进制缓冲区中(字节数组),并采用Uint8类型数组当做数据视图来操作它。 // 标签数据为针对每个数字的独热编码集合。以数字6举例: // 样本 0 1 2 3 4 5 6 7 8 9 // 标签 0 0 0 0 0 0 1 0 0 0 // 只有一个是1,即数字6独热。 // 预测 0.10 0.01 0.01 0.01 0.09 0.01 0.71 0.01 0.03 0.02 // datasetLabels为一个1维数组,数据长度为65000 * 10,因为每组独热编码长度为10,分别对应10个数字的编码,只有一个为1,其他为0。 // datasetLabels每10个下标的元素即为一组针对一个数字的真实值的独热编码。所以下面在划分的时候乘以了NUM_CLASSES(10)。 this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer()); // Slice the images and labels into train and test sets. // 将图片数据和标签数据划分为训练集和测试集。 this.trainImages = this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS); // 0到(28 * 28) * 55000。 this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS); // 上面剩下的。 this.trainLabels = this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS); // 0到10 * 55000。 this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS); // 上面剩下的。 } /** * Get all training data as a data tensor and a label tensor. * * @returns * xs: The data tensor, of shape `[numTrainExamples, 28, 28, 1]` * labels: The one-hot encoded labels tensor, of shape '[numTrainExamples, 10]'. */ getTrainData() { // 这包括表示为NHWC形状[N,28、28、1]的4维张量(批次示例的第一维)的输入MNIST图像,其中N是图像总数。 const xs = tf.tensor4d( this.trainImages, // NHWC:[55000, 28, 28, 1]。 [this.trainImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1] ); // 这包括输入标签,表示为形状为[N,10]的独热编码2维张量。 const labels = tf.tensor2d( this.trainLabels, // NHWC:[55000, 10]。 [this.trainLabels.length / NUM_CLASSES, NUM_CLASSES] ); // 返回的训练数据就是标准的特征集和标签集配对。 return {xs, labels}; } /** * Get all test data as a data tensor and a labels tensor. * * @param {number} numExamples Optional number of examples to get. If not provided, * all test examples will be returned. * @returns xs: The data tensor, of shape `[numTestExamples, 28, 28, 1]`. * labels: The one-hot encoded labels tensor, of shape `[numTestExamples, 10]`. */ getTestData(numExamples) { let xs = tf.tensor4d( this.testImages, [this.testImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1] ); let labels = tf.tensor2d( this.testLabels, [this.testLabels.length / NUM_CLASSES, NUM_CLASSES] ); if (numExamples != null) { // 修正形状为[100, 28, 28, 1],即从测试集中取出100张图片的特征和标签。 xs = xs.slice([0, 0, 0, 0], [numExamples, IMAGE_H, IMAGE_W, 1]); labels = labels.slice([0, 0], [numExamples, NUM_CLASSES]); } // 返回的测试数据就是标准的特征集和标签集配对。这与训练集是相似的,只是它不包含在训练集以内,模型未曾接触到。 return {xs, labels}; } }
ui.js
import * as tfvis from '@tensorflow/tfjs-vis'; // This is a helper class for drawing loss graphs and MNIST images to the // window. For the purposes of understanding the machine learning bits, you can // largely ignore it const statusElement = document.getElementById('status'); const messageElement = document.getElementById('message'); const imagesElement = document.getElementById('images'); const trainButton = document.getElementById('train'); const modelType = document.getElementById('model-type'); const epochInput = document.getElementById('train-epochs'); const lossLabelElement = document.getElementById('loss-label'); const lossContainer = document.getElementById('loss-canvas'); const lossValContainer = document.getElementById('loss-val-canvas'); const accuracyLabelElement = document.getElementById('accuracy-label'); const accuracyContainer = document.getElementById('accuracy-canvas'); const accuracyValContainer = document.getElementById('accuracy-val-canvas'); export function logStatus(message) { statusElement.innerText = message; } export function trainingLog(message) { messageElement.innerText = `${message}\n`; } export function showTestResults(batch, predictions, labels) { const testExamples = batch.xs.shape[0]; imagesElement.innerHTML = ''; for (let i = 0; i < testExamples; i++) { const image = batch.xs.slice([i, 0], [1, batch.xs.shape[1]]); const div = document.createElement('div'); div.className = 'pred-container'; // 预测的值。 const pred = document.createElement('div'); // 判断预测是否正确。 const prediction = predictions[i]; const label = labels[i]; const correct = prediction === label; // 预测正确显示绿色,错误显示红色。 pred.className = `pred ${(correct ? 'pred-correct' : 'pred-incorrect')}`; // 预测的值。 pred.innerText = `pred: ${prediction}`; // 创建一个canvas来展示手写数字灰度图片。 const canvas = document.createElement('canvas'); canvas.className = 'prediction-canvas'; draw(image.flatten(), canvas); div.appendChild(pred); div.appendChild(canvas); imagesElement.appendChild(div); } } const lossValues = []; const lossValValues = []; const accuracyValues = []; const accuracyValValues = []; export function plotLoss(batch, loss) { lossValues.push({ x: batch, y: loss }); tfvis.render.linechart( lossContainer, { values: lossValues, series: ['train'] }, { xLabel: 'Batch Number', yLabel: 'Loss', width: 400, height: 300, } ); lossLabelElement.innerText = `last loss: ${loss.toFixed(3)}`; } export function plotValLoss(batch, loss) { lossValValues.push({ x: batch, y: loss }); tfvis.render.linechart( lossValContainer, { values: lossValValues, series: ['validation'] }, { xLabel: 'Batch Number', yLabel: 'Loss', width: 400, height: 300, seriesColors: ['#f16528'] } ); } export function plotAccuracy(batch, accuracy) { accuracyValues.push({ x: batch, y: accuracy }); tfvis.render.linechart( accuracyContainer, { values: accuracyValues, series: ['train'] }, { xLabel: 'Batch Number', yLabel: 'Accuracy', width: 400, height: 300, } ); accuracyLabelElement.innerText = `last accuracy: ${(accuracy * 100).toFixed(1)}%`; } export function plotValAccuracy(batch, accuracy) { accuracyValValues.push({ x: batch, y: accuracy }); tfvis.render.linechart( accuracyValContainer, { values: accuracyValValues, series: ['validation'] }, { xLabel: 'Batch Number', yLabel: 'Accuracy', width: 400, height: 300, seriesColors: ['#f16528'] } ); } export function draw(image, canvas) { const [width, height] = [28, 28]; canvas.width = width; canvas.height = height; const ctx = canvas.getContext('2d'); const imageData = new ImageData(width, height); const data = image.dataSync(); for (let i = 0; i < height * width; ++i) { const j = i * 4; imageData.data[j + 0] = data[i] * 255; imageData.data[j + 1] = data[i] * 255; imageData.data[j + 2] = data[i] * 255; imageData.data[j + 3] = 255; } ctx.putImageData(imageData, 0, 0); } export function getModelTypeId() { return document.getElementById('model-type').value; } export function getTrainEpochs() { return Number.parseInt(document.getElementById('train-epochs').value); } export function setTrainButtonCallback(train) { trainButton.addEventListener('click', async () => { // Disable button during the training. trainButton.setAttribute('disabled', true); modelType.setAttribute('disabled', true); epochInput.setAttribute('disabled', true); // Start training. await train(); // Release button and reset chart data for next training without refreshing page. trainButton.removeAttribute('disabled'); modelType.removeAttribute('disabled'); epochInput.removeAttribute('disabled'); // Rest data array. lossValues.splice(0, lossValues.length); lossValValues.splice(0, lossValValues.length); accuracyValues.splice(0, accuracyValues.length); accuracyValValues.splice(0, accuracyValValues.length); }); }