tensorflow.js示例笔记 - mnist

使用层来进行数字识别,使用tf.layers api训练模型识别MNIST数据库中的手写数字。

index.html

<html>
    <head>
        <title>MNIST</title>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1">
        <link rel="stylesheet" href="../shared/tfjs-examples.css"/>

        <style>
            #train {
                margin-top: 10px;
            }

            label {
                display: inline-block;
                width: 250px;
                padding: 6px 0 6px 0;
            }

            .canvases {
                display: inline-block;
            }

            .prediction-canvas {
                width: 100px;
            }

            .pred {
                font-size: 20px;
                line-height: 25px;
                width: 100px;
            }

            .pred-correct {
                background-color: #00cf00;
            }

            .pred-incorrect {
                background-color: red;
            }

            .pred-container {
                display: inline-block;
                width: 100px;
                margin: 10px;
            }

            #train-epochs {
                width: 82px;
                font-size: 14px;
            }
        </style>
    </head>
    <body>
        <div class="tfjs-example-container">
            <section class="title-area">
                <h1>Digit Recognizer with Layers</h1>
                <p class="subtitle">Train a model to recognize handwritten digits from the MNIST database using the tf.layers
                    api.
                </p>
            </section>
            <section>
                <p class="section-head">Description</p>
                <p>
                    This examples lets you train a handwritten digit recognizer using either a Convolutional Neural Network
                    (also known as a ConvNet or CNN) or a Fully Connected Neural Network (also known as a DenseNet).
                </p>
                <p>The MNIST dataset is used as training data.</p>
            </section>
            <section>
                <p class="section-head">Training Parameters</p>
                <div>
                    <label>Model Type:</label>
                    <select id="model-type">
                        <option>ConvNet</option>
                        <option>DenseNet</option>
                    </select>
                </div>

                <div>
                    <label>Number of training epochs:</label>
                    <input id="train-epochs" type="number" value="3" />
                </div>

                <button id="train" disabled>Train Model</button>
            </section>
            <section>
                <p class="section-head">Training Progress</p>
                <p id="status"></p>
                <p id="message"></p>
                <div id="stats">
                    <div class="canvases">
                        <label id="loss-label"></label>
                        <div id="loss-canvas"></div>
                    </div>
                    <div class="canvases">
                        <label id="accuracy-label"></label>
                        <div id="accuracy-canvas"></div>
                    </div>
                    <br />
                    <div class="canvases">
                        <div id="loss-val-canvas"></div>
                    </div>
                    <div class="canvases">
                        <div id="accuracy-val-canvas"></div>
                    </div>
                </div>
            </section>
            <section>
                <p class="section-head">Inference Examples</p>
                <div id="images"></div>
            </section>
        </div>
        <!-- TODO(cais): Decide. DO NOT SUBMIT. -->
        <!-- <script src="https://cdn.plot.ly/plotly-latest.min.js"></script> -->
        <script type="module" src="index.js"></script>
    </body>
</html>

index.js

import * as tf from '@tensorflow/tfjs';
import {IMAGE_H, IMAGE_W, MnistData} from './data';
import * as ui from './ui';

/**
 * Creates a model consisting of only flatten, dense and dropout layers.
 *
 * The model create here has approximately the same number of parameters
 * (~31k) as the convnet created by `createConvModel()`, but is
 * expected to show a significantly worse accuracy after training, due to the
 * fact that it doesn't utilize the spatial information as the convnet does.
 *
 * This is for comparison with the convolutional network above.
 *
 * @returns {tf.Model} An instance of tf.Model.
 */
function createDenseModel() {
    const model = tf.sequential();
    model.add(tf.layers.flatten({inputShape: [IMAGE_H, IMAGE_W, 1]}));
    model.add(tf.layers.dense({units: 42, activation: 'relu'}));
    model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
    return model;
    // 该普通稠密层模型与卷积神经网络模型相比,参数数量都在32000个左右,基本维持了一个公平的形式。但在最终的损失和准
    // 确率上,普通稠密层模型都比不过后者。虽然与卷积神经网络的模型相比,准确率差异只有2%左右,但错误率却是几倍。因此
    // 在处理图片类业务上,卷积神经网络模型较稠密层模型有明显的优势。
}

/**
 * Creates a convolutional neural network (Convnet) for the MNIST data.
 *
 * @returns {tf.Model} An instance of tf.Model.
 */
function createConvModel() {
    // Create a sequential neural network model. tf.sequential provides an API
    // for creating "stacked" models where the output from one layer is used as
    // the input to the next layer.
    const model = tf.sequential();

    // The first layer of the convolutional neural network plays a dual role:
    // it is both the input layer of the neural network and a layer that performs
    // the first convolution operation on the input. It receives the 28x28 pixels
    // black and white images. This input layer uses 16 filters with a kernel size
    // of 5 pixels each. It uses a simple RELU activation function which pretty
    // much just looks like this: __/
    // 第1层。
    model.add(tf.layers.conv2d({
        inputShape: [IMAGE_H, IMAGE_W, 1],
        // 要应用于输入数据的滑动卷积过滤器窗口的尺寸。在这里我们将kernelSize设为3,以指定方形的3*3卷积窗口。
        kernelSize: 3,
        // 要应用于输入数据的尺寸为kernelSize的过滤器窗口数量。在这里我们将对数据应用16个过滤器。
        filters: 16,
        activation: 'relu'
    }));

    // After the first layer we include a MaxPooling layer. This acts as a sort of
    // downsampling using max values in a region instead of averaging.
    // https://www.quora.com/What-is-max-pooling-in-convolutional-neural-networks
    // 第2层。
    model.add(tf.layers.maxPooling2d({
        poolSize: 2,
        strides: 2
    }));

    // Our third layer is another convolution, this time with 32 filters.
    // 第3层。
    // 第3层和第4层这两个层是前两个层的完全重复(除了conv2d层在其过滤器配置中具有更大的值并且不具有inputShape字段)。
    // 这种由卷积层和极化层组成的几乎重复的"模体"在convnets中是常见的。它在convnet中扮演着关键的角色:特征的分层提取。
    // 为了理解它的含义,可以考虑一个训练过的convnet,它的任务是对图像中的动物进行分类。在convnet的早期阶段,卷积层中
    // 的滤波器(即channel)可以编码诸如直线、曲线和角等低级几何特征。这些低级特征转化为更复杂的特征,如猫的眼睛、鼻子和
    // 耳朵。在convnet的顶层,一个层可能有对整个cat的存在进行编码的过滤器。级别越高,表示越抽象,从像素级值中移除的特征
    // 越多。但是,这些抽象的特征正是convnet任务实现良好精度所需要的,例如,在图像中时检测出猫。此外,这些特征不是手工制
    // 作的,而是通过有监督的学习并以自动方式从数据中提取的。这是一个典型的有代表性的例子,我们也把它描述为层-层转换。
    model.add(tf.layers.conv2d({
        kernelSize: 3,
        filters: 32,
        activation: 'relu'
    }));

    // Max pooling again.
    // 第4层。
    model.add(tf.layers.maxPooling2d({
        poolSize: 2,
        strides: 2
    }));

    // Add another convolution layer.
    // 第5层。
    model.add(tf.layers.conv2d({
        kernelSize: 3,
        filters: 32,
        activation: 'relu'
    }));

    // Now we flatten the output from the 2D filters into a 1D vector to prepare
    // it for input into our last layer. This is common practice when feeding
    // higher dimensional data to a final classification output layer.
    // 第6层。
    // 第6层为扁平层。它将多维张量“压缩”为一维张量,从而保持元素的总数。在我们的例子中,形状为[3,3,32]的3D张量被展平
    // 为1D张量[288](没有批次维度)。挤压操作的一个明显问题是如何对元素排序,因为原始三维空间中没有内在的顺序。答案是:
    // 我们对元素进行排序,这样,如果你沿着展开的一维张量中的元素向下看,看看它们的原始索引(来自三维张量)如何变化,最后
    // 一个索引变化最快,倒数第二个索引变化第二快,以此类推,而第一个索引变化最慢。
    // 扁平层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。扁平层不影响batch的大小。
    // 扁平层中没有权重,它只是将其输入展开为一个长数组。
    model.add(tf.layers.flatten({}));

    // Previous flattened 1D vector will be this layer's input.
    // 第7层。
    // 第7层和第8层我们添加了两个稠密层,为什么要添加两个而不是一个呢?原因是:添加具有非线性激活的层会增加网络的容量。
    // 通常多层神经网络内,必须含有级联的线性函数和非线性函数,有助于表示能力的增强,这意味着模型容量的增大,预测精度的
    // 提高。在boston-housing项目案例中,我们也做出过相同的总结。https://mapleroyals.com/forum/threads/4th-job-skill-changes.135025
    // 实际上,您可以将convnet看作是由两个模型堆叠在一起:
    // 1. 包含conv2d、maxPooling2d和flatten层的模型,用于从输入图像中提取视觉特征。
    // 2. 一种多层感知器(MLP),具有两个密集层,以提取的特征作为输入,并基于它进行数字类预测,这就是这两个密集层的本
    // 质。在深度学习中,许多模型都利用了特征提取层的这种模式,然后最终预测使用MLP。在本书的其余部分中,我们将看到更多
    // 这样的例子,从音频信号分类器到自然语言处理。
    model.add(tf.layers.dense({
        units: 64,
        activation: 'relu'
    }));

    // Our last layer is a dense layer which has 10 output units, one for each
    // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9). Here the classes actually
    // represent numbers, but it's the same idea if you had classes that
    // represented other entities like dogs and cats (two output classes: 0, 1).
    // We use the softmax function as the activation for the output layer as it
    // creates a probability distribution over our 10 classes so their output
    // values sum to 1.
    // 第8层。
    // 这一层将输出10个值在0-1之间的数字元素,表示对0-9这10个数字的预测概率,它们的和为1,最大的概率对应的数值为预测的倾向值。
    model.add(tf.layers.dense({
        units: 10,
        activation: 'softmax'
    }));

    return model;
}

/**
 * This callback type is used by the `train` function for insertion into the
 * model.fit callback loop.
 *
 * @callback onIterationCallback
 * @param {string} eventType Selector for which type of event to fire on.
 * @param {number} batchOrEpochNumber The current epoch / batch number
 * @param {tf.Logs} logs Logs to append to
 */

/**
 * Compile and train the given model.
 *
 * @param {tf.Model} model The model to train.
 * @param {onIterationCallback} onIteration A callback to execute every 10 batches & epoch end.
 */
async function train(model, onIteration) {
    ui.logStatus('Training model...');

    // We compile our model by specifying an optimizer, a loss function, and a
    // list of metrics that we will use for model evaluation. Here we're using a
    // categorical crossentropy loss, the standard choice for a multi-class
    // classification problem like MNIST digits.
    // The categorical crossentropy loss is differentiable and hence makes
    // model training possible. But it is not amenable to easy interpretation
    // by a human. This is why we include a "metric", namely accuracy, which is
    // simply a measure of how many of the examples are classified correctly.
    // This metric is not differentiable and hence cannot be used as the loss
    // function of the model.
    model.compile({
        // Now that we've defined our model, we will define our optimizer. The
        // optimizer will be used to optimize our model's weight values during
        // training so that we can decrease our training loss and increase our
        // classification accuracy.

        // We are using rmsprop as our optimizer.
        // An optimizer is an iterative method for minimizing a loss function.
        // It tries to find the minimum of our loss function with respect to the
        // model's weight parameters.
        optimizer: 'rmsprop',
        // 损失函数categoricalCrossentropy分类交叉熵,适用于诸如MNIST之类的多分类问题。在iris分类项目案例中,我
        // 们也使用了相同的损失函数,一般情况下,当模型的输出为概率分布时,就会使用此函数。分类交叉熵会生成一个数字,指
        // 示预测向量与真实标签向量的相似程度。
        loss: 'categoricalCrossentropy',
        // 度量标准函数调用accuracy。假设预测是基于convnet输出的10个元素中的最大元素进行的,则此函数度量的示例中的一
        // 部分已正确分类。回想一下交叉熵损失和精度度量之间的区别:交叉熵是可微的,因此使基于反向传播的训练成为可能,而精
        // 度度量是不可微的,但却更容易解释,因此对于分类问题,这是正确预测在所有预测中所占的百分比。
        metrics: ['accuracy'],
    });

    // Batch size is another important hyperparameter. It defines the number of
    // examples we group together, or batch, between updates to the model's weights
    // during training. A value that is too low will update weights using too few
    // examples and will not generalize well. Larger batch sizes require more memory
    // resources and aren't guaranteed to perform better.
    // 一般而言,使用较大的批次与较小的批次相比好处是,它对模型的权重产生了更一致且变化较小的渐变更新。但是批次大小越大,训练
    // 期间就需要更多的内存。您还应该记住,在给定相同数量的训练数据的情况下,较大的批次大小会导致每个时期的梯度更新数量较少。
    // 因此,如果您使用较大的批量,请确保相应地增加时期数,以免在训练过程中无意中减少了权重更新的次数。
    const batchSize = 320;

    // Leave out the last 15% of the training data for validation, to monitor
    // overfitting during training.
    // 预留15%的训练数据用于训练过程中的验证。
    const validationSplit = 0.15;

    // Get number of training epochs from the UI.
    const trainEpochs = ui.getTrainEpochs();

    // We'll keep a buffer of loss and accuracy values over time.
    let trainBatchCount = 0;

    const trainData = data.getTrainData();
    const testData = data.getTestData();

    const totalNumBatches = Math.ceil(trainData.xs.shape[0] * (1 - validationSplit) / batchSize) * trainEpochs;

    // During the long-running fit() call for model training, we include callbacks,
    // so that we can plot the loss and accuracy values in the page as the training
    // progresses.
    let valAcc;
    await model.fit(
        trainData.xs,  // 特征输入。
        trainData.labels,  // 标签输入。
        {
            batchSize,  // 每次梯度更新的样本数。
            validationSplit,  // 末尾15%的数据用于验证。
            epochs: trainEpochs, // 在训练数据上的迭代次数。
            callbacks: {
                onBatchEnd: async (batch, logs) => {
                    trainBatchCount++;
                    ui.logStatus(
                        `Training... (${(trainBatchCount / totalNumBatches * 100).toFixed(1)}% complete). ` +
                        'To stop training, refresh or close page.'
                    );
                    // 绘制损失和准确率图表。
                    ui.plotLoss(trainBatchCount, logs.loss);
                    ui.plotAccuracy(trainBatchCount, logs.acc);

                    // 每10个batch结束时更新测试集的预测结果。
                    if (onIteration && (batch % 10 === 0)) {
                        onIteration('onBatchEnd', batch, logs);
                    }

                    // 主动让出线程,允许UI在训练过程中更新。
                    await tf.nextFrame();
                },
                onEpochEnd: async (epoch, logs) => {
                    valAcc = logs.val_acc;
                    // 绘制损失和准确率图表。
                    ui.plotValLoss(trainBatchCount, logs.val_loss);
                    ui.plotValAccuracy(trainBatchCount, logs.val_acc);

                    // 每个epoch结束时更新测试集的预测结果。
                    if (onIteration) {
                        onIteration('onEpochEnd', epoch, logs);
                    }

                    await tf.nextFrame();
                }
            }
        }
    );

    // 在fit执行完成后(训练结束),对模型进行评估。
    const testResult = model.evaluate(testData.xs, testData.labels);
    const testAccPercent = testResult[1].dataSync()[0] * 100;
    const finalValAccPercent = valAcc * 100;
    // 更新最后的验证准确率和评估(测试)准确率。
    ui.logStatus(
        `Final validation accuracy: ${finalValAccPercent.toFixed(1)}%; ` +
        `Final test accuracy: ${testAccPercent.toFixed(1)}%`
    );
}

/**
 * Show predictions on a number of test examples.
 *
 * @param {tf.Model} model The model to be used for making the predictions.
 */
async function showPredictions(model) {
    const testExamples = 100;
    const examples = data.getTestData(testExamples);

    // Code wrapped in a tf.tidy() function callback will have their tensors freed
    // from GPU memory after execution without having to call dispose().
    // The tf.tidy callback runs synchronously.
    tf.tidy(() => {
        // output为形状为[100, 10]的二维张量,第一维对应100个特征输入(手写数字图片),第二维对应
        // 10个可能的数字的概率。
        const output = model.predict(examples.xs);

        // tf.argMax() returns the indices of the maximum values in the tensor along
        // a specific axis. Categorical classification tasks like this one often
        // represent classes as one-hot vectors. One-hot vectors are 1D vectors with
        // one element for each output class. All values in the vector are 0
        // except for one, which has a value of 1 (e.g. [0, 0, 0, 1, 0]). The
        // output from model.predict() will be a probability distribution, so we use
        // argMax to get the index of the vector element that has the highest
        // probability. This is our prediction. (e.g. argmax([0.07, 0.1, 0.03, 0.75, 0.05]) == 3)
        // dataSync() synchronously downloads the tf.tensor values from the GPU so
        // that we can use them in our normal CPU JavaScript code
        // (for a non-blocking version of this function, use data()).
        // output中的axis 0的每一行,代表某个图片可能的10个数字的概率,需要找出其中最大的,将其作为
        // 模型的预测值。
        // argMax()函数返回沿给定轴的最大值的索引。在这种情况下,此轴是第二维,即const axis = 1。
        // argMax()的返回值是形状为[100,1]的张量。通过调用 dataSync(),我们将[100,1]形张量转
        // 换为长度为100的Float32Array。然后Array.from()将Float32Array转换为一个普通的JavaScript
        // 数组,该数组由100个介于0和9之间的整数组成。此预测数组的含义非常简单:这是模型对100个输入图
        // 像进行分类的结果。在MNIST数据集中,目标标签恰好与输出索引完全匹配。因此,我们甚至不需要将数
        // 组转换为字符串标签。预测数组由下一行使用,该行调用一个UI函数,该函数将分类结果与测试图像一起呈现。
        const axis = 1;
        const labels = Array.from(examples.labels.argMax(axis).dataSync());
        const predictions = Array.from(output.argMax(axis).dataSync());

        // 更新测试集的预测结果,对100个图标注对它的预测数字和正确性(绿色)。
        ui.showTestResults(examples, predictions, labels);
    });
}

function createModel() {
    let model;
    const modelType = ui.getModelTypeId();
    if (modelType === 'ConvNet') {
        model = createConvModel();
    } else if (modelType === 'DenseNet') {
        model = createDenseModel();
    } else {
        throw new Error(`Invalid model type: ${modelType}`);
    }
    return model;
}

let data;

// When page ready, loads the MNIST data, trains the model, and then shows what
// the model predicted on unseen test data.
window.onload = async () => {
    ui.logStatus('Loading MNIST data...');
    data = new MnistData();
    await data.load();

    ui.logStatus('Creating model...');
    const model = createModel();

    // 输出模型规格。
    model.summary();
    // __________________________________________________________________________________________
    // Layer (type)                Input Shape               Output shape              Param #
    // ==========================================================================================
    // conv2d_Conv2D1 (Conv2D)     [[null,28,28,1]]          [null,26,26,16]           160
    // __________________________________________________________________________________________
    // max_pooling2d_MaxPooling2D1 [[null,26,26,16]]         [null,13,13,16]           0
    // __________________________________________________________________________________________
    // conv2d_Conv2D2 (Conv2D)     [[null,13,13,16]]         [null,11,11,32]           4640
    // __________________________________________________________________________________________
    // max_pooling2d_MaxPooling2D2 [[null,11,11,32]]         [null,5,5,32]             0
    // __________________________________________________________________________________________
    // conv2d_Conv2D3 (Conv2D)     [[null,5,5,32]]           [null,3,3,32]             9248
    // __________________________________________________________________________________________
    // flatten_Flatten1 (Flatten)  [[null,3,3,32]]           [null,288]                0
    // __________________________________________________________________________________________
    // dense_Dense1 (Dense)        [[null,288]]              [null,64]                 18496
    // __________________________________________________________________________________________
    // dense_Dense2 (Dense)        [[null,64]]               [null,10]                 650
    // ==========================================================================================
    // Total params: 33194
    // Trainable params: 33194
    // Non-trainable params: 0
    // __________________________________________________________________________________________
    //
    // 从上述图表信息可见该模型总参数为33194个,其中池化层、扁平层无参数。
    // 在深度学习中,通常模型自身的参数和模型的输出会占用显存。
    // 有参数的层主要包括:卷积层、全连接层、BatchNorm层、Embedding层等。
    // 无参数的层主要包括:多数的激活层(Sigmoid/ReLU)、池化层、Dropout层、扁平层等。

    ui.logStatus('Model ready.');

    // Enable the "Train Model" button.
    document.getElementById('train').removeAttribute('disabled');

    ui.setTrainButtonCallback(async () => {
        ui.logStatus('Starting model training...');
        document.getElementById('images').innerHTML = '';
        await train(model, () => showPredictions(model));
    });
};

 

data.js

import * as tf from '@tensorflow/tfjs';

// This is a helper class for loading and managing MNIST data specifically.
// It is a useful example of how you could create your own data manager class
// for arbitrary data though. It's worth a look :).

// MnistData类封装了对MNIST灰度图片训练集和标签集的数据加载和预处理过程,为了使数据对下层算法和框架有友好性,数据的
// 格式和存储都以二进制缓冲区和类型化数组为主。这里的数据处理相对于以JSON为主的数据处理的普通web应用而言是稍显复杂的。

// 安东尼·戈德布卢姆(Kaggle的CEO)曾经这样说过:有人开玩笑说有80%的数据科学家在清理数据,剩下的20%在抱怨清理数据。
// 在实际的数据科学(大数据、机器学习)工作中,清理数据所占比例比外人想象的要多得多。一般而言,训练模型通常只占机器学习
// 或数据科学家工作的一小部分(少于10%)。当下在机器学习的纯应用领域上,以成熟框架和算法作为基础的开发上,更是如此。

const MNIST_IMAGES_SPRITE_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

export const IMAGE_H = 28;
export const IMAGE_W = 28;
const IMAGE_SIZE = IMAGE_H * IMAGE_W;

const NUM_CLASSES = 10;  // 0-9。

const NUM_DATASET_ELEMENTS = 65000;  // MNIST精灵图有65000个子图,精灵图的高度也为65000像素。
const NUM_TRAIN_ELEMENTS = 55000;    // 我们用于训练的子图为55000个。

// A class that fetches the sprited MNIST dataset and provide data as tf.Tensors.
export class MnistData {
    constructor() {

    }

    async load() {
        // Make a request for the MNIST sprited image.
        let img = new Image();

        // 这里我们用到了canvas是因为DOM的img对象并没有提供给我们获取其像素的API,DOM这个级别的API是无法操作像素的。
        // 只能通过canvas来做一个中间层,将img内容分批次绘制到canvas上,再从canvas提取出像素数据,看下面for循环部分。
        const canvas = document.createElement('canvas');
        const ctx = canvas.getContext('2d', {
            willReadFrequently: true,
        });

        // 请求手写数字灰度图片,并将图片转化为数据格式。
        const imgRequest = new Promise((resolve) => {
            // CORS配置。
            img.crossOrigin = '';

            // naturalWidth和naturalHeight指图片的原始大小,在计算时可以强制校正图片尺寸。
            img.onload = () => {
                img.width = img.naturalWidth;     // 图片本身宽度784。
                img.height = img.naturalHeight;   // 图片本身长度65000。

                // 初始化一个新的二进制缓冲区,包含整个MMINST精灵图的每个像素(每一张子图的每个像素)。它将图像总数和每张图像的尺寸和通道数量相乘。
                // 图片为PNG格式,所以有rgba4个通道,最后乘以4。
                const datasetBytesBuffer = new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4); // 65000 * (28 * 28) * 4。

                const chunkSize = 5000;
                canvas.width = img.width;
                canvas.height = chunkSize;

                // 分13次处理完,每次处理MNIST精灵图中的5000个子图。
                // NUM_DATASET_ELEMENTS和img.height一致,都为65000。每次取5000像素的高度,分13次就能取完。为什么它们两个的一致?我们可以这
                // 么看:因为精灵图宽度为784像素,784整好是每个子图的总像素,因为单个子图的宽度为28*28,整好就是784,所以MNIST一横排的像素总量正
                // 好就是一个子图的像素总量,从计数上就可以看成一个子图占了MNIST的一排,MNIST有多少排就有多少张子图,MNIST有65000像素高,也就是
                // 有65000排,也就是我们可以算出子图有65000张,这个数量和实际的子图数量是一致的。从获取到的MNIST精灵图来看,子图并不是以28*28的
                // 长宽方位比依次存放在MNIST上的,而是以1*784的长宽方位比来存放的,正好就是和我们想的一致:一排就是一张子图。将每张子图按照28来切割
                // 并换行,就可以组成一个28*28的方形原始图,这下我们就能看明白手写的数字了,虽然这种排列方式利于眼睛的分辨,但显然不适合像素的提取,
                // 一旦这样排列,我们在提取子图像素的时候,就变成了去处理矩阵数据(单个子图的正方形就可以看做是一个28*28矩阵)。
                // 下面内容涉及到的两个for循环,其中外层的循环遍历chunkSize,内层的循环遍历某个chunkSize下的像素。
                const chunkNum = NUM_DATASET_ELEMENTS / chunkSize; // => 65000 / 5000 = 13。

                for (let i = 0; i < chunkNum; i++) {
                    // drawImage允许我们从一个图片源上的裁剪某个特定矩形区域并绘制到canvas的特定位置、大小的区域上。
                    // 这里我们从image上裁剪同image宽度,高度为chunkSize的区域绘制到canvas上。下一次循环就裁剪下
                    // 一个chunkSize高度的区域。
                    ctx.drawImage(
                        img, // 绘制到上下文的元素,允许任何画布图像源,例如:HTMLImageElement、SVGImageElement等。
                        // 以下参数都基于需要绘制到目标上下文的image的矩形(裁剪)选择框。
                        0,               // 左上角X轴坐标。
                        i * chunkSize,   // 左上角Y轴坐标。
                        img.width,       // 宽度。
                        chunkSize,       // 高度。
                        // 以下参数都基于image的左上角在目标画布上的绘制。
                        0,               // 左上角X轴坐标。
                        0,               // 左上角Y轴坐标。
                        img.width,       // 宽度。
                        chunkSize        // 高度。
                    );

                    // getImageData返回一个ImageData对象,包含canvas给定的矩形区域的像素数据,这里的话每个循环里的canvas上线文
                    // 就是15680000个像素(784 * 5000 * 4 = 15680000,4个通道总的需要乘以4)。ImageData结构示例:
                    // {
                    //        width: 100,
                    //        height: 100,
                    //     colorSpace: 'srgb',
                    //        data: Uint8ClampedArray[40000]   // 像素数据。
                    // }
                    const imageData = ctx.getImageData(
                        // 以下参数都基于将要被提取的图像数据矩形区域。
                        0,               // 左上角x坐标。
                        0,               // 左上角y坐标。
                        canvas.width,    // 宽度。
                        canvas.height    // 高度。
                    );

                    // datasetBytesView长度为3920000,即精灵图中chunkSize个子图的像素总数量((28 * 28) * 5000)。
                    // 采用Float32类型数组当做数据视图来操作它,这个视图即表示了datasetBytesBuffer的一段空间。
                    const datasetBytesView = new Float32Array(  // length => 3920000
                        // 数据。
                        datasetBytesBuffer,
                        // 起始偏移量。第一次循环的时候i为0,也就是从0开始,即无偏移。第二次循环代的时候i为1,即偏移从5000个子图的四通道数
                        // 量以后开始。剩下的以此类推,每次循环都处理完当前chunkSize的数据,循环完13个chunkSize后,即整个MNIST精灵图的
                        // 数据就在datasetBytesBuffer中了。也就是说每次对chunkSize的迭代处理中,其实都是针对datasetBytesBuffer进行
                        // 处理,往它里面存放数据,也就是下面的内层for循环中,针对当前chunkSize下的每个像素,对datasetBytesBuffer从
                        // 当前的偏移值i * IMAGE_SIZE * chunkSize * 4开始存储,共存储IMAGE_SIZE * chunkSize个数据。
                        i * IMAGE_SIZE * chunkSize * 4,  // => i * (28 * 28) * 5000 * 4。
                        // 元素数量。
                        IMAGE_SIZE * chunkSize           // => (28 * 28) * 5000 = 392000
                    );

                    // 将该区域内总的像素数量除以4得到单通道像素数量。因为原数据包含了4个通道的数据,除以4才表示单个通道的数据长度,也就是该矩
                    // 形区域内所有像素的按单通道计算的数量。等同于MAGE_SIZE * chunkSize,也就是本次for循环的datasetBytesView的元素长度。
                    const len = imageData.data.length / 4;  //  => 15680000 / 4 = 3920000。

                    // 遍历当前chunkSize高度区域内的每一个像素(按单通道计算)。通过datasetBytesView视图将数据放入datasetBytesBuffer
                    // 中的对应位置,来改变datasetBytesBuffer的值。
                    for (let j = 0; j < len; j++) {
                        // All channels hold an equal value since the image is grayscale, so just read the red channel.
                        // 因为图片为灰度化的,所有rgba通道都有相同的值,因此只需要读取r(红色)通道的值即可(间隔4个取1次,就能保证每次取的
                        // 都是r通道的值因此乘以了4)。最后将取出的r通道值除以255,得到0-1之间的数。赋值给datasetBytesView视图对应的下标,
                        // datasetBytesView和datasetBytesBuffer是连通的,视图被改,datasetBytesBuffer就对应改变了,整个外层for循
                        // 环结束后,datasetBytesBuffer(this.datasetImages)就是整个MINST精灵图的所有像素的总数据(灰度数据)。
                        // 这里除以255是为了避免训练数据与预测数据之间出现不匹配的情况,因为我们的MNIST卷积网络应使用归一化到0-1之间的图像张
                        // 量数据进行训练。
                        datasetBytesView[j] = imageData.data[j * 4] / 255;
                    }
                    // 当上面这个for循环完成后,我们就得到了该chunkSize高度区域的所有的r通道的像素转成灰度的值。然后再接着遍历下一个chunkSize
                    // 高度区域,并获得它的所有r通道的像素数据。
                }

                // chunkSize高度区域都全部遍历完以后,datasetBytesBuffer里就已经有了整个MNIST精灵图的r通道的灰度像素信息。最终datasetImages
                // 的长度为50960000,即整个MNIST精灵图的像素总数量(65000 * (28 * 28))。
                // 采用Float32类型数组当做数据视图来操作它。
                this.datasetImages = new Float32Array(datasetBytesBuffer);  // length => 50960000

                // 图片DOM用完就回收。
                img = null;

                resolve();
            };

            // 赋值src触发图片onload加载回调。
            img.src = MNIST_IMAGES_SPRITE_PATH;
        });

        // 请求标签数据。
        const labelsRequest = fetch(MNIST_LABELS_PATH);

        // 等待图片、标签数据都加载并处理完成。
        const [, labelsResponse] = await Promise.all([imgRequest, labelsRequest]);

        // 将标签数据(uint8)存在一个二进制缓冲区中(字节数组),并采用Uint8类型数组当做数据视图来操作它。
        // 标签数据为针对每个数字的独热编码集合。以数字6举例:
        // 样本  0     1     2     3     4     5     6     7     8     9
        // 标签  0     0     0     0     0     0     1     0     0     0       // 只有一个是1,即数字6独热。
        // 预测  0.10  0.01  0.01  0.01  0.09  0.01  0.71  0.01  0.03  0.02
        // datasetLabels为一个1维数组,数据长度为65000 * 10,因为每组独热编码长度为10,分别对应10个数字的编码,只有一个为1,其他为0。
        // datasetLabels每10个下标的元素即为一组针对一个数字的真实值的独热编码。所以下面在划分的时候乘以了NUM_CLASSES(10)。
        this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

        // Slice the images and labels into train and test sets.
        // 将图片数据和标签数据划分为训练集和测试集。
        this.trainImages = this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);  // 0到(28 * 28) * 55000。
        this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);      // 上面剩下的。

        this.trainLabels = this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS); // 0到10 * 55000。
        this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);     // 上面剩下的。
    }

    /**
     * Get all training data as a data tensor and a label tensor.
     *
     * @returns
     *   xs: The data tensor, of shape `[numTrainExamples, 28, 28, 1]`
     *   labels: The one-hot encoded labels tensor, of shape '[numTrainExamples, 10]'.
     */
    getTrainData() {
        // 这包括表示为NHWC形状[N,28、28、1]的4维张量(批次示例的第一维)的输入MNIST图像,其中N是图像总数。
        const xs = tf.tensor4d(
            this.trainImages,
            // NHWC:[55000, 28, 28, 1]。
            [this.trainImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]
        );

        // 这包括输入标签,表示为形状为[N,10]的独热编码2维张量。
        const labels = tf.tensor2d(
            this.trainLabels,
            // NHWC:[55000, 10]。
            [this.trainLabels.length / NUM_CLASSES, NUM_CLASSES]
        );

        // 返回的训练数据就是标准的特征集和标签集配对。
        return {xs, labels};
    }

    /**
     * Get all test data as a data tensor and a labels tensor.
     *
     * @param {number} numExamples Optional number of examples to get. If not provided,
     *                             all test examples will be returned.
     * @returns xs: The data tensor, of shape `[numTestExamples, 28, 28, 1]`.
     *              labels: The one-hot encoded labels tensor, of shape `[numTestExamples, 10]`.
     */
    getTestData(numExamples) {
        let xs = tf.tensor4d(
            this.testImages,
            [this.testImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]
        );

        let labels = tf.tensor2d(
            this.testLabels,
            [this.testLabels.length / NUM_CLASSES, NUM_CLASSES]
        );

        if (numExamples != null) {
            // 修正形状为[100, 28, 28, 1],即从测试集中取出100张图片的特征和标签。
            xs = xs.slice([0, 0, 0, 0], [numExamples, IMAGE_H, IMAGE_W, 1]);
            labels = labels.slice([0, 0], [numExamples, NUM_CLASSES]);
        }

        // 返回的测试数据就是标准的特征集和标签集配对。这与训练集是相似的,只是它不包含在训练集以内,模型未曾接触到。
        return {xs, labels};
    }
}

ui.js

import * as tfvis from '@tensorflow/tfjs-vis';

// This is a helper class for drawing loss graphs and MNIST images to the
// window. For the purposes of understanding the machine learning bits, you can
// largely ignore it
const statusElement = document.getElementById('status');
const messageElement = document.getElementById('message');
const imagesElement = document.getElementById('images');

const trainButton = document.getElementById('train');
const modelType = document.getElementById('model-type');
const epochInput = document.getElementById('train-epochs');

const lossLabelElement = document.getElementById('loss-label');
const lossContainer = document.getElementById('loss-canvas');
const lossValContainer = document.getElementById('loss-val-canvas');
const accuracyLabelElement = document.getElementById('accuracy-label');
const accuracyContainer = document.getElementById('accuracy-canvas');
const accuracyValContainer = document.getElementById('accuracy-val-canvas');

export function logStatus(message) {
    statusElement.innerText = message;
}

export function trainingLog(message) {
    messageElement.innerText = `${message}\n`;
}

export function showTestResults(batch, predictions, labels) {
    const testExamples = batch.xs.shape[0];
    imagesElement.innerHTML = '';

    for (let i = 0; i < testExamples; i++) {
        const image = batch.xs.slice([i, 0], [1, batch.xs.shape[1]]);

        const div = document.createElement('div');
        div.className = 'pred-container';

        // 预测的值。
        const pred = document.createElement('div');

        // 判断预测是否正确。
        const prediction = predictions[i];
        const label = labels[i];
        const correct = prediction === label;

        // 预测正确显示绿色,错误显示红色。
        pred.className = `pred ${(correct ? 'pred-correct' : 'pred-incorrect')}`;
        // 预测的值。
        pred.innerText = `pred: ${prediction}`;

        // 创建一个canvas来展示手写数字灰度图片。
        const canvas = document.createElement('canvas');
        canvas.className = 'prediction-canvas';
        draw(image.flatten(), canvas);

        div.appendChild(pred);
        div.appendChild(canvas);

        imagesElement.appendChild(div);
    }
}

const lossValues = [];
const lossValValues = [];
const accuracyValues = [];
const accuracyValValues = [];

export function plotLoss(batch, loss) {
    lossValues.push({
        x: batch,
        y: loss
    });

    tfvis.render.linechart(
        lossContainer,
        {
            values: lossValues,
            series: ['train']
        },
        {
            xLabel: 'Batch Number',
            yLabel: 'Loss',
            width: 400,
            height: 300,
        }
    );

    lossLabelElement.innerText = `last loss: ${loss.toFixed(3)}`;
}

export function plotValLoss(batch, loss) {
    lossValValues.push({
        x: batch,
        y: loss
    });

    tfvis.render.linechart(
        lossValContainer,
        {
            values: lossValValues,
            series: ['validation']
        },
        {
            xLabel: 'Batch Number',
            yLabel: 'Loss',
            width: 400,
            height: 300,
            seriesColors: ['#f16528']
        }
    );
}

export function plotAccuracy(batch, accuracy) {
    accuracyValues.push({
        x: batch,
        y: accuracy
    });

    tfvis.render.linechart(
        accuracyContainer,
        {
            values: accuracyValues,
            series: ['train']
        },
        {
            xLabel: 'Batch Number',
            yLabel: 'Accuracy',
            width: 400,
            height: 300,
        }
    );

    accuracyLabelElement.innerText = `last accuracy: ${(accuracy * 100).toFixed(1)}%`;
}

export function plotValAccuracy(batch, accuracy) {
    accuracyValValues.push({
        x: batch,
        y: accuracy
    });

    tfvis.render.linechart(
        accuracyValContainer,
        {
            values: accuracyValValues,
            series: ['validation']
        },
        {
            xLabel: 'Batch Number',
            yLabel: 'Accuracy',
            width: 400,
            height: 300,
            seriesColors: ['#f16528']
        }
    );
}

export function draw(image, canvas) {
    const [width, height] = [28, 28];
    canvas.width = width;
    canvas.height = height;

    const ctx = canvas.getContext('2d');
    const imageData = new ImageData(width, height);
    const data = image.dataSync();

    for (let i = 0; i < height * width; ++i) {
        const j = i * 4;
        imageData.data[j + 0] = data[i] * 255;
        imageData.data[j + 1] = data[i] * 255;
        imageData.data[j + 2] = data[i] * 255;
        imageData.data[j + 3] = 255;
    }

    ctx.putImageData(imageData, 0, 0);
}

export function getModelTypeId() {
    return document.getElementById('model-type').value;
}

export function getTrainEpochs() {
    return Number.parseInt(document.getElementById('train-epochs').value);
}

export function setTrainButtonCallback(train) {
    trainButton.addEventListener('click', async () => {
        // Disable button during the training.
        trainButton.setAttribute('disabled', true);
        modelType.setAttribute('disabled', true);
        epochInput.setAttribute('disabled', true);

        // Start training.
        await train();

        // Release button and reset chart data for next training without refreshing page.
        trainButton.removeAttribute('disabled');
        modelType.removeAttribute('disabled');
        epochInput.removeAttribute('disabled');
        // Rest data array.
        lossValues.splice(0, lossValues.length);
        lossValValues.splice(0, lossValValues.length);
        accuracyValues.splice(0, accuracyValues.length);
        accuracyValValues.splice(0, accuracyValValues.length);
    });
}

 

posted @ 2024-05-21 13:23  james·von  阅读(48)  评论(0编辑  收藏  举报