tensorflow.js示例笔记 - iris

根据鸢尾花的数据对花进行分类,使用神经网络对结构化(表格)数据进行分类。

index.html

<html>
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1">
        <link rel="stylesheet" href="../shared/tfjs-examples.css"/>
        <style>
            input {
                width: 75px;
            }

            .input-div {
                padding: 5px;
                font-family: monospace;
                font-size: 16px;
            }

            .input-label {
                display: inline-block;
                width: 160px;
            }

            td {
                padding-left: 5px;
                padding-right: 5px;
                padding-bottom: 5px;
            }

            #predict-header {
                font-weight: bold;
            }

            .output-div {
                padding: 5px;
                padding-top: 20px;
                font-family: monospace;
                font-weight: bold;
            }

            #evaluate-table {
                display: inline-block;
            }

            #evaluate-table td, #evaluate-table th {
                font-family: monospace;
                border: 1px solid #ddd;
                padding: 8px;
            }

            #evaluate-table th {
                padding-top: 12px;
                padding-bottom: 12px;
                text-align: left;
                background-color: #4CAF50;
                color: white;
            }

            .region {
                border-left: 1px dashed #ccc;
                margin-bottom: 5px;
                padding-left: 24px;
                margin-left: -24px;
            }

            .load-save-section {
                padding-top: 3px;
                padding-bottom: 3px;
            }

            .logit-span {
                padding-right: 1em;
            }

            .correct-prediction {
                background-color: greenyellow
            }

            .wrong-prediction {
                background-color: red;
            }
        </style>
    </head>
    <body>
        <div class='tfjs-example-container'>
            <section class='title-area'>
                <h1>鸢尾花</h1>
                <p class='subtitle'>Classify structured (tabular) data with a neural network.</p>
            </section>
            <section>
                <p class='section-head'>Description</p>
                <p>
                    This example uses a neural network to classify tabular data representing different flowers. The data used
                    for
                    each flower are the petal length and width as well as the sepal length and width. The goal
                    is to predict what kind of flower it is based on those features of each data point. The
                    data comes from the famous <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set">Iris flower</a> data
                    set.
                </p>
            </section>
            <section>
                <p class='section-head'>Instructions</p>
                <p>
                    Using the buttons below you can either train a new model from scratch or load a pre-trained model and
                    test its performance.
                </p>
                <p>
                    If you train a model from scratch you can also save it to browser local storage.
                </p>
                <p>
                    If you load a pre-trained model you can edit the properties in first row of "Test Examples" to generate
                    a prediction for those data points.
                </p>
            </section>
            <section>
                <p class='section-head'>Controls</p>
                <div class="region">
                    <h3>Train Model</h3>
                    <div class="create-model">
                        <div class="input-div">
                            <label class="input-label">Train Epochs:</label>
                            <input id="train-epochs" type="number" value="40"></input>
                        </div>
                        <div class="input-div">
                            <span class="input-label">Learning Rate:</span>
                            <input id="learning-rate" type="number" value="0.01"></input>
                        </div>
                        <button id="train-from-scratch">Train model from scratch</button>
                    </div>
                </div>
                <div class="region">
                    <h3>Save/Load Model</h3>
                    <div class="load-save-section">
                        <button id="load-pretrained-remote">Load hosted pretrained model</button>
                    </div>
                    <div class="load-save-section">
                        <button id="load-local" disabled="true">Load locally-saved model</button>
                        <button id="save-local" disabled="true">Save model locally</button>
                        <button id="remove-local" disabled="true">Remove model locally</button>
                        <span id='local-model-status'>Status unavailable.</span>
                    </div>
                </div>
            </section>
            <section>
                <p class='section-head'>Status</p>
                <div>
                    <span id="demo-status">Standing by.</span>
                </div>
            </section>
            <section>
                <p class='section-head'>Training Progress</p>
                <div class='with-cols'>
                    <div>
                        <h4>Loss</h4>
                        <div class="canvases" id="lossCanvas"></div>
                    </div>
                    <div>
                        <h4>Accuracy</h4>
                        <div class="canvases" id="accuracyCanvas"></div>
                    </div>
                    <div>
                        <h4>Confusion Matrix (on validation set)</h4>
                        <div id="confusion-matrix"></div>
                    </div>
                </div>
            </section>
            <section>
                <p class='section-head'>Test Examples</p>
                <div id="evaluate">
                    <table id="evaluate-table">
                        <tr>
                            <th>Petal length</th>
                            <th>Petal width</th>
                            <th>Sepal length</th>
                            <th>Sepal width</th>
                            <th>True class</th>
                            <th>Predicted class</th>
                            <th>Class Probabilities</th>
                        </tr>
                        <tbody id="evaluate-tbody">
                        <tr>
                            <td>
                                <input id="petal-length" value="5.1" />
                                <button id="petal-length-inc">+</button>
                                <button id="petal-length-dec">-</button>
                            </td>
                            <td>
                                <input id="petal-width" value="3.5" />
                                <button id="petal-width-inc">+</button>
                                <button id="petal-width-dec">-</button>
                            </td>
                            <td>
                                <input id="sepal-length" value="1.4" />
                                <button id="sepal-length-inc">+</button>
                                <button id="sepal-length-dec">-</button>
                            </td>
                            <td>
                                <input id="sepal-width" value="0.2" />
                                <button id="sepal-width-inc">+</button>
                                <button id="sepal-width-dec">-</button>
                            </td>
                            <td></td>
                            <td id="winner"></td>
                            <td id="logits"></td>
                        </tr>
                        </tbody>
                    </table>
                </div>
            </section>
            <div>
                <div class="horizontal-section">
                    <div id="horizontal-section"></div>
                </div>
            </div>
        </div>
        <script type="module" src="index.js"></script>
    </body>
</html>

index.js

import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import * as data from './data';
import * as loader from './loader';
import * as ui from './ui';

let model;

/**
 * Train a `tf.Model` to recognize Iris flower type.
 *
 * @param xTrain Training feature data, a `tf.Tensor` of shape [numTrainExamples, 4].
 *               The second dimension include the features petal length, petalwidth,
 *               sepal length and sepal width.
 * @param yTrain One-hot training labels, a `tf.Tensor` of shape [numTrainExamples, 3].
 * @param xTest Test feature data, a `tf.Tensor` of shape [numTestExamples, 4].
 * @param yTest One-hot test labels, a `tf.Tensor` of shape [numTestExamples, 3].
 *
 * @returns The trained `tf.Model` instance.
 */
async function trainModel(xTrain, yTrain, xTest, yTest) {
    ui.status('Training model... Please wait.');

    const params = ui.loadTrainParametersFromUI();

    // Define the topology of the model: two dense layers.
    const model = tf.sequential();

    model.add(tf.layers.dense({
        units: 10,
        activation: 'sigmoid',
        inputShape: [xTrain.shape[1]]
    }));

    // 与我们看到的sigmoid激活函数不同,softmax激活函数不是对每个元素进行处理的,因为输入矢量的每个元素都依赖于所有其他元素。
    // 具体来说,输入的每个元素都将转换为其自然指数(即以e = 2.718为基数的exp函数)。然后将指数除以所有元素的指数之和。这样做
    // 有一下几个目的:
    // 1. 确保每个数字都在0到1之间。
    // 2. 确保输出矢量的所有元素总和为1。这是一个理想的属性,因为这两个原因:1.输出可以解释为分配的概率;2.为了与分类交叉熵损失
    // 函数兼容,输出必须满足该属性。
    // 3. 确保输入向量中的较大元素映射到输出向量中的较大元素。
    // 这里给出一个具体的例子,假设在最后致密层经过矩阵乘法和加法产生的向量:[-3 ,0 ,-8]。其长度为3,因为该稠密层配置具有3个
    // 神经元。注意,这些元素是不受任何特定范围限制的浮点数。softmax会将向量转换为:[0.0474107,0.9522698,0.0003195]。
    // softmax函数输出的三个元素有以下特点:
    // 1. 全部在[0,1]中。
    // 2. 总和为1。
    // 3. 以与输入向量中的顺序匹配的方式进行排序。
    // 由于这些特点,可以将输出解释为(由模型)分配给所有可能类别的概率值。在上面的示例中,第二个类别的概率最高,而第一个类别的概率
    // 最低。当使用这种类型的多类分类器的输出时,可以选择最高softmax元素的索引作为最终的对输入所属类别的决策。
    model.add(tf.layers.dense({
        // 最后输出一个3个元素的向量,比如[-3, 0, -8]。
        units: 3,
        // 该激活函数将上述向量转化为[0.0474107,0.9522698,0.0003195]。
        activation: 'softmax'
    }));

    model.summary();

    // 在二分类示例中,我们看到了如何将二进制交叉熵用作损失函数,以及为什么不能将其他更易于理解的指标(如准确性和召回率)用作损失
    // 函数。多类分类的情况非常相似。一个简单的指标,即准确性,该准确性是模型正确分类的示例的一部分。其度量标准对于人类理解模型的
    // 性能非常重要。
    model.compile({
        optimizer: tf.train.adam(params.learningRate),
        // 但是,准确度不是损失函数好的选择,因为它与二分类的准确度一样会遭受零梯度问题。因此,人们为多分类设计了一种特殊的损失函
        // 数:分类交叉熵。它只是将二元交叉熵泛化为两个以上类别的情况。
        // 分类交叉熵损失函数(单个输入)如下:
        // function categoricalCrossentropy(oneHotTruth, probs):
        //     for i in (0 to length of oneHotTruth)
        //         if oneHotTruth(i) is equal to 1
        //             return -log(probs[i]);
        //
        // 在上面的代码中,oneHotTruth是输入示例实际类的oneHot编码。概率是模型输出的softmax概率。上面的代码的关键之处在于,
        // 就分类交叉熵而言,概率中只有一个元素很重要,也就是索引与实际类相对应的元素。概率的其他元素可能会有所不同,但只要它们不
        // 更改实际类的元素,就不会影响分类交叉熵。对于特定的概率元素,它越接近1,则交叉熵的值将越低。
        loss: 'categoricalCrossentropy',
        // 像二分类的交叉熵一样,多分类交叉熵可以直接用tf.metrics命名的函数,使用它来计算简单的分类交叉熵。
        metrics: ['accuracy'],
        // 图表:不同概率输出下分类交叉熵的值。在不失一般性的前提下,所有示例(行)都基于以下情况:存在三个类别,而实际类别是第二
        // 个类别
        // One-hot truth label      概率(softmax 输出)   分类交叉熵      MSE
        // [0,1,0]                 [0.2、0.5、0.3]       0.693         0.127
        // [0,1,0]                 [0.0,0.5,0.5]       0.693         0.167
        // [0,1,0]                 [0.0,0.9,0.1]       0.105         0.006
        // [0,1,0]                 [0.1,0.9,0.0]       0.105         0.006
        // [0,1,0]                 [0.0,0.99,0.01]     0.010         0.00006
        // 通过比较上表中的第1行和第2行或比较第3行和第4行,应该清楚的是,更改与实际类不对应的概率元素不会更改交叉熵,即使它可能会
        // 更改one-hot真值标签和pros,但是会影响MSE。同样,就像在二元交叉熵中一样,当实际类别的概率值接近1时,MSE的回报率会降低,
        // 因此不利于影响正确类别的概率值。这就是为什么分类交叉熵比MSE更适合作为损失函数用于多类分类问题的原因。
    });

    const trainLogs = [];
    const lossContainer = document.getElementById('lossCanvas');
    const accContainer = document.getElementById('accuracyCanvas');
    const beginMs = performance.now();

    // Call 'model.fit' to train the model.
    const history = await model.fit(xTrain, yTrain, {
        epochs: params.epochs,
        validationData: [xTest, yTest],
        callbacks: {
            onEpochEnd: async (epoch, logs) => {
                // Plot the loss and accuracy values at the end of every training epoch.
                const secPerEpoch = (performance.now() - beginMs) / (1000 * (epoch + 1));
                ui.status(`Training model... Approximately ${secPerEpoch.toFixed(4)} seconds per epoch`);
                // 更新损失和准确率图表。
                trainLogs.push(logs);
                tfvis.show.history(lossContainer, trainLogs, ['loss', 'val_loss'])
                tfvis.show.history(accContainer, trainLogs, ['acc', 'val_acc'])
                // 绘制混淆矩阵。
                calculateAndDrawConfusionMatrix(model, xTest, yTest);
            },
        }
    });

    const secPerEpoch = (performance.now() - beginMs) / (1000 * params.epochs);

    ui.status(`Model training complete:  ${secPerEpoch.toFixed(4)} seconds per epoch`);

    return model;
}

/**
 * Run inference on manually-input Iris flower data.
 *
 * @param model The instance of 'tf.Model' to run the inference with.
 */
async function predictOnManualInput(model) {
    if (model == null) {
        ui.setManualInputWinnerMessage('ERROR: Please load or train model first.');
        return;
    }

    // Use a 'tf.tidy' scope to make sure that WebGL memory allocated for the
    // 'predict' call is released at the end.
    tf.tidy(() => {
        // Prepare input data as a 2D `tf.Tensor`.
        const inputData = ui.getManualInputData();
        const input = tf.tensor2d([inputData], [1, 4]);

        // Call 'model.predict' to get the prediction output as probabilities for
        // the Iris flower categories.
        // predictOut是二维张量[numExamples,3]。
        const predictOut = model.predict(input);
        const logits = Array.from(predictOut.dataSync());
        // 调用argMax方法会使形状减小为[numExample]。参数值-1表示argMax应该在最后一个维度上寻找最大值并返回其索引。
        // 比如假设predictOut具有以下值:[[0,0.6,0.4],[0.8,0,0.2]]。然后argMax(-1)将返回一个张量,该张量指
        // 示在第一个和第二个示例的索引1和0处分别找到沿最后一个(第二个)维的最大值:[1,0]。
        const winner = data.IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]];
        ui.setManualInputWinnerMessage(winner);
        ui.renderLogitsForManualInput(logits);
    });
}

/**
 * Draw confusion matrix.
 */
async function calculateAndDrawConfusionMatrix(model, xTest, yTest) {
    // 计算和绘制混淆矩阵。
    // 在多类别分类中,混淆矩阵比简单的准确性更能提供更多信息,就像精度和召回率一起构成的分类比二分类的准确性更全面。混淆矩阵可
    // 以提供有助于与模型和训练过程相关的决策的信息。比如将体育网站误认为游戏网站,可能比将体育网站混淆为网络钓鱼诈骗更是一个大
    // 问题。在这种情况下,您可以调整模型的超参数以最大程度地减少代价最高的错误。到目前为止,我们所看到的模型都将数字数组作为输
    // 入。换句话说,每个输入示例都表示为一个简单的数字列表,其长度是固定的,并且元素的顺序无关紧要,只要它们对于馈入模型的所有
    // 示例都是一致的即可。虽然这种类型的模型涵盖了重要的和实用的机器学习问题的很大一部分,但它远非唯一。在接下来的案例项目中,
    // 我们将介绍更复杂的输入数据类型,包括图像和序列。图像是一种无处不在且用途广泛的输入数据,已针对其开发了强大的神经网络结构,
    // 以将机器学习模型的准确性推向超人的水平。
    //
    //           预测值
    //           类别1 类别2 类别3
    // 真 类别1    a     b    c
    // 实 类别2    d     e    f
    // 值 类别3    g     h    i
    //
    // 与二分类的混淆矩阵(TP、FN、FP、TN)一样,多分类混淆矩阵行数据相加是真实值类别数,列数据相加是分类后的类别数,也就有了以
    // 下计算公式(以类别1为例):
    // 精确率 = a / (a + d + g)。
    // 召回率(真肯定率、查全率) = a / (a + b + c)。
    const [preds, labels] = tf.tidy(() => {
        const preds = model.predict(xTest).argMax(-1);
        const labels = yTest.argMax(-1);
        return [preds, labels];
    });

    const confMatrixData = await tfvis.metrics.confusionMatrix(labels, preds);
    const container = document.getElementById('confusion-matrix');
    tfvis.render.confusionMatrix(
        container,
        {
            values: confMatrixData,
            labels: data.IRIS_CLASSES
        },
        {shadeDiagonal: true},
    );

    tf.dispose([preds, labels]);
}

/**
 * Run inference on some test Iris flower data.
 *
 * @param model The instance of `tf.Model` to run the inference with.
 * @param xTest Test data feature, a `tf.Tensor` of shape [numTestExamples, 4].
 * @param yTest Test true labels, one-hot encoded, a `tf.Tensor` of shape [numTestExamples, 3].
 */
async function evaluateModelOnTestData(model, xTest, yTest) {
    ui.clearEvaluateTable();

    tf.tidy(() => {
        const xData = xTest.dataSync();
        const yTrue = yTest.argMax(-1).dataSync();
        const predictOut = model.predict(xTest);
        const yPred = predictOut.argMax(-1);
        ui.renderEvaluateTable(xData, yTrue, yPred.dataSync(), predictOut.dataSync());
        calculateAndDrawConfusionMatrix(model, xTest, yTest);
    });

    predictOnManualInput(model);
}

const HOSTED_MODEL_JSON_URL = 'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json';

/**
 * The main function of the Iris demo.
 */
async function iris() {
    const [xTrain, yTrain, xTest, yTest] = data.getIrisData(0.15);

    const localLoadButton = document.getElementById('load-local');
    const localSaveButton = document.getElementById('save-local');
    const localRemoveButton = document.getElementById('remove-local');

    document.getElementById('train-from-scratch').addEventListener('click', async () => {
        model = await trainModel(xTrain, yTrain, xTest, yTest);
        await evaluateModelOnTestData(model, xTest, yTest);
        localSaveButton.disabled = false;
    });

    if (await loader.urlExists(HOSTED_MODEL_JSON_URL)) {
        ui.status('Model available: ' + HOSTED_MODEL_JSON_URL);
        document.getElementById('load-pretrained-remote').addEventListener('click', async () => {
            ui.clearEvaluateTable();
            model = await loader.loadHostedPretrainedModel(HOSTED_MODEL_JSON_URL);
            await predictOnManualInput(model);
            localSaveButton.disabled = false;
        });
    }

    localLoadButton.addEventListener('click', async () => {
        model = await loader.loadModelLocally();
        await predictOnManualInput(model);
    });

    localSaveButton.addEventListener('click', async () => {
        await loader.saveModelLocally(model);
        await loader.updateLocalModelStatus();
    });

    localRemoveButton.addEventListener('click', async () => {
        await loader.removeModelLocally();
        await loader.updateLocalModelStatus();
    });

    await loader.updateLocalModelStatus();

    ui.status('Standing by.');
    ui.wireUpEvaluateTableCallbacks(() => predictOnManualInput(model));
}

iris();

data.js

import * as tf from '@tensorflow/tfjs';

// 定义鸢尾花的3种类型。
export const IRIS_CLASSES = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'];
export const IRIS_NUM_CLASSES = IRIS_CLASSES.length;

// Iris flowers data. Source: https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
// 鸢尾花分类的原始数据,格式为二元数组,每个子数组为一个鸢尾花样本。子数组中的前4个元素为
// 每个鸢尾花样本的特征:花瓣(Petal)长度、花瓣宽度、萼片(Sepal)长度、萼片宽度,第5个元素为鸢尾花的类型,
// 值为0、1、2,即iris-setosa、iris-versicolor和iris-virginica这3种类型。
const IRIS_DATA = [
    [5.1, 3.5, 1.4, 0.2, 0], [4.9, 3.0, 1.4, 0.2, 0], [4.7, 3.2, 1.3, 0.2, 0],
    [4.6, 3.1, 1.5, 0.2, 0], [5.0, 3.6, 1.4, 0.2, 0], [5.4, 3.9, 1.7, 0.4, 0],
    [4.6, 3.4, 1.4, 0.3, 0], [5.0, 3.4, 1.5, 0.2, 0], [4.4, 2.9, 1.4, 0.2, 0],
    [4.9, 3.1, 1.5, 0.1, 0], [5.4, 3.7, 1.5, 0.2, 0], [4.8, 3.4, 1.6, 0.2, 0],
    [4.8, 3.0, 1.4, 0.1, 0], [4.3, 3.0, 1.1, 0.1, 0], [5.8, 4.0, 1.2, 0.2, 0],
    [5.7, 4.4, 1.5, 0.4, 0], [5.4, 3.9, 1.3, 0.4, 0], [5.1, 3.5, 1.4, 0.3, 0],
    [5.7, 3.8, 1.7, 0.3, 0], [5.1, 3.8, 1.5, 0.3, 0], [5.4, 3.4, 1.7, 0.2, 0],
    [5.1, 3.7, 1.5, 0.4, 0], [4.6, 3.6, 1.0, 0.2, 0], [5.1, 3.3, 1.7, 0.5, 0],
    [4.8, 3.4, 1.9, 0.2, 0], [5.0, 3.0, 1.6, 0.2, 0], [5.0, 3.4, 1.6, 0.4, 0],
    [5.2, 3.5, 1.5, 0.2, 0], [5.2, 3.4, 1.4, 0.2, 0], [4.7, 3.2, 1.6, 0.2, 0],
    [4.8, 3.1, 1.6, 0.2, 0], [5.4, 3.4, 1.5, 0.4, 0], [5.2, 4.1, 1.5, 0.1, 0],
    [5.5, 4.2, 1.4, 0.2, 0], [4.9, 3.1, 1.5, 0.1, 0], [5.0, 3.2, 1.2, 0.2, 0],
    [5.5, 3.5, 1.3, 0.2, 0], [4.9, 3.1, 1.5, 0.1, 0], [4.4, 3.0, 1.3, 0.2, 0],
    [5.1, 3.4, 1.5, 0.2, 0], [5.0, 3.5, 1.3, 0.3, 0], [4.5, 2.3, 1.3, 0.3, 0],
    [4.4, 3.2, 1.3, 0.2, 0], [5.0, 3.5, 1.6, 0.6, 0], [5.1, 3.8, 1.9, 0.4, 0],
    [4.8, 3.0, 1.4, 0.3, 0], [5.1, 3.8, 1.6, 0.2, 0], [4.6, 3.2, 1.4, 0.2, 0],
    [5.3, 3.7, 1.5, 0.2, 0], [5.0, 3.3, 1.4, 0.2, 0], [7.0, 3.2, 4.7, 1.4, 1],
    [6.4, 3.2, 4.5, 1.5, 1], [6.9, 3.1, 4.9, 1.5, 1], [5.5, 2.3, 4.0, 1.3, 1],
    [6.5, 2.8, 4.6, 1.5, 1], [5.7, 2.8, 4.5, 1.3, 1], [6.3, 3.3, 4.7, 1.6, 1],
    [4.9, 2.4, 3.3, 1.0, 1], [6.6, 2.9, 4.6, 1.3, 1], [5.2, 2.7, 3.9, 1.4, 1],
    [5.0, 2.0, 3.5, 1.0, 1], [5.9, 3.0, 4.2, 1.5, 1], [6.0, 2.2, 4.0, 1.0, 1],
    [6.1, 2.9, 4.7, 1.4, 1], [5.6, 2.9, 3.6, 1.3, 1], [6.7, 3.1, 4.4, 1.4, 1],
    [5.6, 3.0, 4.5, 1.5, 1], [5.8, 2.7, 4.1, 1.0, 1], [6.2, 2.2, 4.5, 1.5, 1],
    [5.6, 2.5, 3.9, 1.1, 1], [5.9, 3.2, 4.8, 1.8, 1], [6.1, 2.8, 4.0, 1.3, 1],
    [6.3, 2.5, 4.9, 1.5, 1], [6.1, 2.8, 4.7, 1.2, 1], [6.4, 2.9, 4.3, 1.3, 1],
    [6.6, 3.0, 4.4, 1.4, 1], [6.8, 2.8, 4.8, 1.4, 1], [6.7, 3.0, 5.0, 1.7, 1],
    [6.0, 2.9, 4.5, 1.5, 1], [5.7, 2.6, 3.5, 1.0, 1], [5.5, 2.4, 3.8, 1.1, 1],
    [5.5, 2.4, 3.7, 1.0, 1], [5.8, 2.7, 3.9, 1.2, 1], [6.0, 2.7, 5.1, 1.6, 1],
    [5.4, 3.0, 4.5, 1.5, 1], [6.0, 3.4, 4.5, 1.6, 1], [6.7, 3.1, 4.7, 1.5, 1],
    [6.3, 2.3, 4.4, 1.3, 1], [5.6, 3.0, 4.1, 1.3, 1], [5.5, 2.5, 4.0, 1.3, 1],
    [5.5, 2.6, 4.4, 1.2, 1], [6.1, 3.0, 4.6, 1.4, 1], [5.8, 2.6, 4.0, 1.2, 1],
    [5.0, 2.3, 3.3, 1.0, 1], [5.6, 2.7, 4.2, 1.3, 1], [5.7, 3.0, 4.2, 1.2, 1],
    [5.7, 2.9, 4.2, 1.3, 1], [6.2, 2.9, 4.3, 1.3, 1], [5.1, 2.5, 3.0, 1.1, 1],
    [5.7, 2.8, 4.1, 1.3, 1], [6.3, 3.3, 6.0, 2.5, 2], [5.8, 2.7, 5.1, 1.9, 2],
    [7.1, 3.0, 5.9, 2.1, 2], [6.3, 2.9, 5.6, 1.8, 2], [6.5, 3.0, 5.8, 2.2, 2],
    [7.6, 3.0, 6.6, 2.1, 2], [4.9, 2.5, 4.5, 1.7, 2], [7.3, 2.9, 6.3, 1.8, 2],
    [6.7, 2.5, 5.8, 1.8, 2], [7.2, 3.6, 6.1, 2.5, 2], [6.5, 3.2, 5.1, 2.0, 2],
    [6.4, 2.7, 5.3, 1.9, 2], [6.8, 3.0, 5.5, 2.1, 2], [5.7, 2.5, 5.0, 2.0, 2],
    [5.8, 2.8, 5.1, 2.4, 2], [6.4, 3.2, 5.3, 2.3, 2], [6.5, 3.0, 5.5, 1.8, 2],
    [7.7, 3.8, 6.7, 2.2, 2], [7.7, 2.6, 6.9, 2.3, 2], [6.0, 2.2, 5.0, 1.5, 2],
    [6.9, 3.2, 5.7, 2.3, 2], [5.6, 2.8, 4.9, 2.0, 2], [7.7, 2.8, 6.7, 2.0, 2],
    [6.3, 2.7, 4.9, 1.8, 2], [6.7, 3.3, 5.7, 2.1, 2], [7.2, 3.2, 6.0, 1.8, 2],
    [6.2, 2.8, 4.8, 1.8, 2], [6.1, 3.0, 4.9, 1.8, 2], [6.4, 2.8, 5.6, 2.1, 2],
    [7.2, 3.0, 5.8, 1.6, 2], [7.4, 2.8, 6.1, 1.9, 2], [7.9, 3.8, 6.4, 2.0, 2],
    [6.4, 2.8, 5.6, 2.2, 2], [6.3, 2.8, 5.1, 1.5, 2], [6.1, 2.6, 5.6, 1.4, 2],
    [7.7, 3.0, 6.1, 2.3, 2], [6.3, 3.4, 5.6, 2.4, 2], [6.4, 3.1, 5.5, 1.8, 2],
    [6.0, 3.0, 4.8, 1.8, 2], [6.9, 3.1, 5.4, 2.1, 2], [6.7, 3.1, 5.6, 2.4, 2],
    [6.9, 3.1, 5.1, 2.3, 2], [5.8, 2.7, 5.1, 1.9, 2], [6.8, 3.2, 5.9, 2.3, 2],
    [6.7, 3.3, 5.7, 2.5, 2], [6.7, 3.0, 5.2, 2.3, 2], [6.3, 2.5, 5.0, 1.9, 2],
    [6.5, 3.0, 5.2, 2.0, 2], [6.2, 3.4, 5.4, 2.3, 2], [5.9, 3.0, 5.1, 1.8, 2],
];

/**
 * Convert Iris data arrays to `tf.Tensor`s.
 *
 * @param data The Iris input feature data, an `Array` of `Array`s, each element
 *             of which is assumed to be a length-4 `Array` (for petal length, petal
 *             width, sepal length, sepal width).
 * @param targets An `Array` of numbers, with values from the set {0, 1, 2}:
 *                representing the true category of the Iris flower. Assumed to have the same
 *                 array length as `data`.
 * @param testSplit Fraction of the data at the end to split as test data: anumber between 0 and 1.
 * @return A length-4 `Array`, with
 *           - training data as `tf.Tensor` of shape [numTrainExapmles, 4].
 *           - training one-hot labels as a `tf.Tensor` of shape [numTrainExamples, 3]
 *           - test data as `tf.Tensor` of shape [numTestExamples, 4].
 *           - test one-hot labels as a `tf.Tensor` of shape [numTestExamples, 3]
 */
function convertToTensors(data, targets, testSplit) {
    const numExamples = data.length;
    if (numExamples !== targets.length) {
        throw new Error('data and split have different numbers of examples');
    }

    // Randomly shuffle 'data' and 'targets'.
    // 先将预处理后的dataByClass和targetByClass洗牌。
    // 这里的做法是先创建一个相同长度的数组用于存放数组下标,再将改数组洗牌,然后遍历该随机化后的下标
    // 数组,按随机的下标顺序从原始的dataByClass和targetByClass中取出对应的值,就得到了它们随机
    // 化后的值。
    const indices = [];
    for (let i = 0; i < numExamples; ++i) {
        indices.push(i);
    }
    tf.util.shuffle(indices);

    const shuffledData = [];
    const shuffledTargets = [];
    for (let i = 0; i < numExamples; ++i) {
        shuffledData.push(data[indices[i]]);
        shuffledTargets.push(targets[indices[i]]);
    }

    // Split the data into a training set and a tet set, based on 'testSplit'.
    // 计算用于在训练阶段的训练和验证样本数。
    // testSplit在本案例中定义是0.15,就是说训练输入数据集合汇中最后15%的训练样本用做验证,
    const numTestExamples = Math.round(numExamples * testSplit);
    const numTrainExamples = numExamples - numTestExamples;

    const xDims = shuffledData[0].length;

    // Create a 2D 'tf.Tensor' to hold the feature data.
    // 将随机化以后的输入样本数据集合转成二维张量。
    // [numExamples, xDims] => [样本数量,每个样本的特征数量] => [axis 0,axis 1]d
    // 样本数量为第1个秩(即轴axis 0)的长度,每个样本的特征数量为第2个秩(即axis 1)的长度。
    const xs = tf.tensor2d(shuffledData, [numExamples, xDims]);

    // Create a 1D 'tf.Tensor' to hold the labels, and convert the number label
    // from the set {0, 1, 2} into one-hot encoding (.e.g., 0 --> [1, 0, 0]).
    // ***在研究解决分类问题的模型之前,我们需要重点介绍在此多类分类任务中分类目标(物种)的表示方式。将在这里
    // 重点介绍独热编码。
    // 到目前为止,我们在该例子项目集合中看到的所有机器学习示例都涉及简单的目标表示,例如下载时间预测问题和
    // Boston Housing问题中的单个数字,以及0-1的二进制表示钓鱼检测问题中的目标。但是,在本节问题中,以一种
    // 不太熟悉的方式(称为one-hot编码,即独热编码、一位有效编码)来表示这三种鸢尾花。shuffledTargets为随
    // 机化后的鸢尾花类型数组,以[1, 0, 2]为例,先调用把其转为元素为int32的一维张量,实际print出来还是[1, 0, 2],
    // 再对其做独热编码,独热深度为IRIS_NUM_CLASSES即3,表示最后一个轴的长度为3(放到这里就表示每个子数组的
    // 元素个数为3个),编码的结果是返回一个二维张量:
    // [
    //     [0, 1, 0],
    //     [1, 0, 0],
    //     [0, 0, 1],
    // ]
    // 如果上所示,每一行只有一个数字是1(热、有效),其他都是0(非热,非有效),这就是独热的含义。
    const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES);
    // 尽管上面有内容介绍了独热编码,你也可能对上述tfjs中的独热编码的结果感到不可理解,它其实是这么计算出来的:
    // 首先,tf.tensor1d(shuffledTargets).toInt()的返回张量,在这里被称为indices,即索引集合张量,严
    // 格上,它的元素都是从0开始的整数数字。第二个参数为depth,就是编码后输出张量的最后一个轴(秩)的长度,这
    // 在上面的内容也提到过。就当前案例来讲,在indices为[1, 0, 2],depth为3的情况下,index 0应该对应100
    // 即[1, 0, 0]编码,其原理就是基于独热编码是使用N位状态寄存器来对N个状态进行编码的:对于indices来说,它
    // 内部最小数为0,最大数为2,即表示它内部有3个状态(即便没有1,比如indices为[0, 0, 2],也是按最大和最小
    // 数的差异来计算的,也是表示有3个状态的。当然和每个状态的排列先后顺序也无关系),就是说1、0、2这三个数字表
    // 示的索引,对于它们每个来说,我们可以在进行编码前暂时用状态自己来代替,表示每个数字有3位,以1为例,合起来
    // 就是3个1,即111,状态的数量决定了每个状态的位数,也就是整个indices在独热编码输出后,其张量形状就变成了
    // 下面这样:
    // [
    //     [1, 1, 1],    // => 1
    //     [0, 0, 0],    // => 0
    //     [2, 2, 2],    // => 2
    // ]
    // 如上所示,axis 0长度为状态数,axis 1长度为depth,编码后的张量形状相当于状态数*depth,或features*depth。
    // 还有一点需要注意的就是,这里我们的depth也和状态数量一致,恰好也为3。如果depth为4,那在编码后indices的
    // 形状就应该是这样的:
    // [
    //     [1, 1, 1, 0],    // => 1
    //     [0, 0, 0, 0],    // => 0
    //     [2, 2, 2, 0],    // => 2
    // ]
    // 如上所示,因为depth为4,但状态数只有3,所以在最后一个轴上缺失了1位状态,默认就用0来填充。
    // 上面提到的只是depth比状态数量多的情况,那么当depth比状态数量少又怎么办呢?比如depth为2,形状就应该是这样的:
    // [
    //     [1, 1],    // => 1
    //     [0, 0],    // => 0
    //     [2, 2],    // => 2
    // ]
    // 没错,每个状态的位数从末尾开始直接被切掉了1位,只保留了和depth一样的位数。也就是说在后面做编码的时候,被切掉
    // 的这个位就可以不考虑了。
    //
    // 接着我们再深入地对编码过程进行介绍,还是以正常的3个状态和depth为3进行编码,即以第一个矩阵的结构为编码后的输出
    // 的张量形状。首先,我们先把所有索引按大小顺序列出来:0、1、2,这将作为一个对照表和每个状态(3位)的每位进行比较,
    // 然后依次进行编码:
    //
    // 对0进行编码:
    //    首先预设0为000
    //
    //    比较第一位:
    //    000
    //    012
    //    上下按位比较比较:000的第一位为0,和012中的第一位0一致,那这一位就是有效位,填充为1,此时000变为100。
    //
    //    比较第二位:
    //    100
    //    012
    //    上下按位比较比较:100的第二位为0,和012中的第二位1不一致,那这一位就是非有效位,填充为0,此时100还是100。
    //
    //    比较第三位:
    //    100
    //    012
    //    上下按位比较比较:100的第三位为0,和012中的第三位2不一致,那这一位就是非有效位,填充为0,此时100还是100。
    //
    // 对0以012进行三位状态的独热编码到此结束,最后得出编码后的0为100,表示为一维张量形式即[1, 0, 0]。
    //
    // 然后对于1和2的编码我们不再赘述,以上述步骤类推,可知编码后的1为010,编码后的2为001。整体即:
    // 0 => 100 => [1, 0, 0]
    // 1 => 010 => [0, 1, 0]
    // 2 => 001 => [0, 0, 1]
    //
    // 然后我们以此编码映射来替换我们的张量,替换后的张量即为:
    // [
    //     [0, 1, 0],    // => 1
    //     [1, 0, 0],    // => 0
    //     [0, 0, 1],    // => 2
    // ]
    //
    // 至此,indices张量在经过depth为3的oneHot独热编码以后就是上面这个二维的张量,和本段内容靠前区域的介绍结果一致。
    // 总结状态数和depth相同或不同的情况下的编码结果如下:
    // indices为[1, 0, 2],depth为3,独热编码输出:
    // [
    //     [0, 1, 0],    // => 1
    //     [1, 0, 0],    // => 0
    //     [0, 0, 1],    // => 2
    // ]
    // indices为[1, 0, 2],depth为4,独热编码输出:
    // [
    //     [0, 1, 0, 0],    // => 1
    //     [1, 0, 0, 0],    // => 0
    //     [0, 0, 1, 0],    // => 2
    // ]
    // indices为[1, 0, 2],depth为2,独热编码输出:
    // [
    //     [0, 1],    // => 1
    //     [1, 0],    // => 0
    //     [0, 0],    // => 2
    // ]
    //
    // 至此还缺少了对一个特殊情况的解释,就是万一有一个状态不是0或者正整数怎么,比如为负整数,因为虽然tfjs的oneHot函数
    // 只接收0和正整数作为indices的元素,但其实它也能处理负数。比如我们这里的indices在编码前是[1, 0, -1, 2],那么
    // 它最终被编码以后应该是什么样呢?很简单,就是将-1完全处理为所有位都是非有效位即可,这时indices在编码以后就应该是:
    // [
    //     [0, 1, 0],    // => 1
    //     [1, 0, 0],    // => 0
    //     [0, 0, 0],    // => -1
    //     [0, 0, 1],    // => 2
    // ]
    // 其实它除了能兼容负整数,还能兼容字符串,只不过对字符串的处理是把他们当做0来处理的,比如当indices为['x', 'g', 'a']
    // 的时候,编码后应该是:
    // [
    //     [1, 0, 0],    // => 'x'
    //     [1, 0, 0],    // => 'g'
    //     [1, 0, 0],    // => 'a'
    // ]
    //
    // 以上内容只是针对tfjs的oneHot独热编码来看介绍的,有它自己的特点,而且我们也只介绍了该编码函数的第一个和第二个参数,
    // 就是输入的indices和depth。除此之外,它还支持onValue、offValue和dtype这三个可选参数,它们分别表示有效位、非有
    // 效位和输出张量的数据类型。onValue、offValue不传,分别默认为1和0;dtype不传,默认为int32。在Python版的tf中,
    // 该函数的兄弟one_hot,还支持轴axis的传入,它控制了输出张量的形状:状态数*depth、depth*状态数或depth*batch*状态数。
    //
    // 上面的内容中,我们只介绍了独热编码的实现,那么为什么要用它呢,它有哪些优势是我们期望得到的。对于我们这里的案例[1, 0, 2],
    // 为什么不就让1表示1,让0表示0,让2表示2呢?非要编个码,改为010之类的?因为独热编码有这些优势:
    // 1. 我们可以创造出公平性。每次出现都携带者团队成员的数量,避免了招摇撞骗,夸大自己的占比,比如2可以说我有N个0和2个1,但2只
    // 表示索引位置,它要这么说的话,显然是无道理的。因此每个成员只能是1,只是用来标记是不是你,无法夸大你的比重。大家的值都是1,
    // 避免了你是1我是2,出现谁大谁小从而干扰计算。这一点的科学性解释如下:
    // 大部分算法是基于向量空间中的度量来进行计算的,为了使非偏序关系的变量取值不具备偏序性,而且到圆点是等距的。使用one-hot编码,
    // 将离散特征的取值扩展到了欧式空间,离散特征的某个取值就对应欧式空间的某个点。将离散型特征使用one-hot编码,会让特征之间的距
    // 离计算更加合理。以本项目案例来讲,我们不能说0(iris-setosa)比2(iris-virginica)更接近1(iris-versicolor),这明
    // 显是不正确的。神经网络以实数为基础,并基于数学运算,例如乘法和加法。因此,它们对数字的大小及其顺序非常敏感。如果将类别编码为
    // 单个数字,则它将成为神经网络必须学习的额外非线性关系。相比之下独热编码类别不涉及任何隐含的排序,因此不会以这种方式来增加神经
    // 网络的学习能力,因为这部分非线性关系是没有意义的,这就能防止神经网络走歪了。
    // 2. 与整数相比,神经网络输出连续的浮点型值要容易得多。对于神经网络的最后一层,更为优雅自然的方法便是输出一些单独的浮点型数字,
    // 通过类似于 sigmoid激活函数,每个浮点数在[0,1]区间内用于二分类。在这种方法中,每个数字都是模型对输入示例属于相应类别概率
    // 的估计,越接近1属于相应类别的概率越高,越接近0责反之。这正是独热编码的目的:这是概率分数的正确答案,模型应针对该分数通过训练
    // 过程进行调整。
    //
    // 但是独热编码也是有局限性的:
    // 1. 首先它不适合大量的数据。如果总共有5条数据,那其中一条是这么表示[1, 0, 0, 0, 0],如果是10条,这么表示
    // [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]。如果是5000条,那就是1个1,4999个0。这种情况,术语上叫过于稀疏,反而更不利于计算。
    // 2. 另外,也是因为独热很公正公平,所以导致成员间没有个人关系。有时候,尤其当自然语言处理时,我们却希望能表示出每个词语间的相
    // 关性。比如我们在表示心情好坏程度的时候,有如下几种心情:悲伤、郁闷、无聊、微笑、大笑、爆笑。那么,我们假设以0为中心点,负面的
    // 情绪定义为负数,正面的情绪定义为正数。这几种心情可以这么表示:
    // 悲伤    -3
    // 郁闷    -2
    // 无聊    -1
    // 微笑    1
    // 大笑    2
    // 爆笑    3
    // 这样,我们就可以了解他们之间的关系了:爆笑(3)程度要大于微笑(1)。大笑(2)和郁闷(-2)是完全相反的状态。这种带成员关系
    // 的数据,就不合适用独热编码的方式来表示了。
    //
    // ***独热编码适用于一般的分类问题,比如手写数字识别,OCR识别,花朵种类识别(就像本项目案例一样),因为一般的训练集存储的都不
    // 是独热编码。我们这里存储的鸢尾花类型是数字0、1、2,我就就可以先对其进行独热编码后,再用于训练中的计算。

    // Split the data into training and test sets, using `slice`.
    const xTrain = xs.slice([0, 0], [numTrainExamples, xDims]);
    const xTest = xs.slice([numTrainExamples, 0], [numTestExamples, xDims]);

    const yTrain = ys.slice([0, 0], [numTrainExamples, IRIS_NUM_CLASSES]);
    const yTest = ys.slice([0, 0], [numTestExamples, IRIS_NUM_CLASSES]);

    return [xTrain, yTrain, xTest, yTest];
}

/**
 * Obtains Iris data, split into training and test sets.
 *
 * @param testSplit Fraction of the data at the end to split as test data: a
 *                  number between 0 and 1.
 *
 * @param return A length-4 'Array', with
 *               - training data as an 'Array' of length-4 'Array' of numbers.
 *               - training labels as an 'Array' of numbers, with the same length as the return training
 *                 data above. Each element of the 'Array' is from the set {0, 1, 2}.
 *               - test data as an 'Array' of length-4 'Array' of numbers.
 *               - test labels as an 'Array' of numbers, with the same length as the return test data above.
 *                 Each element of the 'Array' is from the set {0, 1, 2}.
 */
export function getIrisData(testSplit) {
    return tf.tidy(() => {
        // 对原始数据集进行预处理,整理出用于训练的带特征的输入数据集合和对应的标签集合,
        const dataByClass = [];
        const targetsByClass = [];
        for (let i = 0; i < IRIS_CLASSES.length; ++i) {
            dataByClass.push([]);      // 带特征的输入数据集合。
            targetsByClass.push([]);   // 标签集合。
        }
        for (const example of IRIS_DATA) {
            const target = example[example.length - 1];
            const data = example.slice(0, example.length - 1);
            dataByClass[target].push(data);
            targetsByClass[target].push(target);
        }
        // 预处理后的结构:
        // dataByClass => [
        //     [                               // trainData 1
        //         [5.1, 3.5, 1.4, 0.2, 0],
        //         [4.9, 3.0, 1.4, 0.2, 0],
        //         [4.7, 3.2, 1.3, 0.2, 0],
        //         ...
        //     ],
        //     [                               // trainData 2
        //         [7.0, 3.2, 4.7, 1.4, 1],
        //         [6.4, 3.2, 4.5, 1.5, 1],
        //         [6.9, 3.1, 4.9, 1.5, 1],
        //         ...
        //     ],
        //     [                               // trainData 3
        //         [6.3, 3.3, 6.0, 2.5, 2],
        //         [5.8, 2.7, 5.1, 1.9, 2],
        //         [7.1, 3.0, 5.9, 2.1, 2],
        //         ...
        //     ]
        // ]
        //
        // targetsByClass => [
        //     [0, 0, 0, ...],     // testData 1
        //     [1, 1, 1, ...],     // testData 2
        //     [2, 2, 2, ...],     // testData 3
        // ]
        //
        // 这里为什么是3套数据呢,其实就是我们这里要分别针对鸢尾花的3个类型进行训练,在下面的for循环中,
        // 我们就要把这3套数据的张量给构建出来,如上所示,trainData 1将和testData 1搭配,其他以此类推。

        const xTrains = [];
        const yTrains = [];
        const xTests = [];
        const yTests = [];
        for (let i = 0; i < IRIS_CLASSES.length; ++i) {
            const [xTrain, yTrain, xTest, yTest] = convertToTensors(dataByClass[i], targetsByClass[i], testSplit);
            xTrains.push(xTrain);
            yTrains.push(yTrain);
            xTests.push(xTest);
            yTests.push(yTest);
        }

        const concatAxis = 0;

        return [
            tf.concat(xTrains, concatAxis), tf.concat(yTrains, concatAxis),
            tf.concat(xTests, concatAxis), tf.concat(yTests, concatAxis)
        ];
    });
}

loader.js

import * as tf from '@tensorflow/tfjs';
import * as ui from './ui';

// Test whether a given URL is retrievable.
export async function urlExists(url) {
    ui.status('Testing url ' + url);
    try {
        const response = await fetch(url, {method: 'HEAD'});
        return response.ok;
    } catch (err) {
        return false;
    }
}

/**
 * Load pretrained model stored at a remote URL.
 *
 * @return An instance of `tf.Model` with model topology and weights loaded.
 */
export async function loadHostedPretrainedModel(url) {
    ui.status('Loading pretrained model from ' + url);
    try {
        const model = await tf.loadLayersModel(url);
        ui.status('Done loading pretrained model.');
        return model;
    } catch (err) {
        console.error(err);
        ui.status('Loading pretrained model failed.');
    }
}

// The URL-like path that identifies the client-side location where downloaded
// or locally trained models can be stored.
const LOCAL_MODEL_URL = 'indexeddb://tfjs-iris-demo-model/v1';

export async function saveModelLocally(model) {
    const saveResult = await model.save(LOCAL_MODEL_URL);
}

export async function loadModelLocally() {
    return await tf.loadLayersModel(LOCAL_MODEL_URL);
}

export async function removeModelLocally() {
    return await tf.io.removeModel(LOCAL_MODEL_URL);
}

// Check the presence and status of locally saved models (e.g., in IndexedDB).
// Update the UI control states accordingly.
export async function updateLocalModelStatus() {
    const localModelStatus = document.getElementById('local-model-status');
    const localLoadButton = document.getElementById('load-local');
    const localRemoveButton = document.getElementById('remove-local');

    const modelsInfo = await tf.io.listModels();
    if (LOCAL_MODEL_URL in modelsInfo) {
        localModelStatus.textContent = 'Found locally-stored model saved at ' + modelsInfo[LOCAL_MODEL_URL].dateSaved.toDateString();
        localLoadButton.disabled = false;
        localRemoveButton.disabled = false;
    } else {
        localModelStatus.textContent = 'No locally-stored model is found.';
        localLoadButton.disabled = true;
        localRemoveButton.disabled = true;
    }
}

ui.js

import {IRIS_CLASSES, IRIS_NUM_CLASSES} from './data';

// Clear the evaluation table.
export function clearEvaluateTable() {
    const tableBody = document.getElementById('evaluate-tbody');
    while (tableBody.children.length > 1) {
        tableBody.removeChild(tableBody.children[1]);
    }
}

// Get manually input Iris data from the input boxes
export function getManualInputData() {
    return [
        Number(document.getElementById('petal-length').value),
        Number(document.getElementById('petal-width').value),
        Number(document.getElementById('sepal-length').value),
        Number(document.getElementById('sepal-width').value),
    ];
}

const confusionMatrixCanvas = document.getElementById('confusion-matrix');

/**
 * Render a confusion matrix.
 *
 * @param {tf.Tensor} confusionMat Confusion matrix as a 2D tf.Tensor object.
 *                    The value at row `r` and column `c` is the number of times examples of
 *                    actual class `r` were predicted as class `c`.
 */
export function drawConfusionMatrix(confusionMat) {
    const w = confusionMatrixCanvas.width;
    const h = confusionMatrixCanvas.height;
    const ctx = confusionMatrixCanvas.getContext('2d');
    ctx.clearRect(0, 0, w, h);
    const n = confusionMat.shape[0];
    const rawConfusion = confusionMat.dataSync();
    const normalizedConfusion = confusionMat.div(confusionMat.sum(-1).expandDims(0)).dataSync();
    for (let i = 0; i < n; ++i) {
        for (let j = 0; j < n; ++j) {
            const rgbValue = Math.round(255 * (1 - normalizedConfusion[i * n + j]));
            ctx.fillStyle = `rgb(${rgbValue}, ${rgbValue}, ${rgbValue})`;
            ctx.fillRect(w / n * j, h / n * i, w / n, h / n);
            ctx.stroke();
            ctx.strokeStyle = '#808080';
            ctx.rect(w / n * j, h / n * i, w / n, h / n);
            ctx.stroke();
            ctx.font = '18px Arial';
            ctx.fillStyle = '#ff00ff';
            ctx.fillText(`${rawConfusion[i * n + j]}`, w / n * (j + 0.45), h / n * (i + 0.66));
            ctx.stroke();
        }
    }
}

export function setManualInputWinnerMessage(message) {
    const winnerElement = document.getElementById('winner');
    winnerElement.textContent = message;
}

function logitsToSpans(logits) {
    let idxMax = -1;
    let maxLogit = Number.NEGATIVE_INFINITY;
    for (let i = 0; i < logits.length; ++i) {
        if (logits[i] > maxLogit) {
            maxLogit = logits[i];
            idxMax = i;
        }
    }
    const spans = [];
    for (let i = 0; i < logits.length; ++i) {
        const logitSpan = document.createElement('span');
        logitSpan.textContent = logits[i].toFixed(3);
        if (i === idxMax) {
            logitSpan.style['font-weight'] = 'bold';
        }
        logitSpan.classList = ['logit-span'];
        spans.push(logitSpan);
    }
    return spans;
}

function renderLogits(logits, parentElement) {
    while (parentElement.firstChild) {
        parentElement.removeChild(parentElement.firstChild);
    }
    logitsToSpans(logits).map(logitSpan => {
        parentElement.appendChild(logitSpan);
    });
}

export function renderLogitsForManualInput(logits) {
    const logitsElement = document.getElementById('logits');
    renderLogits(logits, logitsElement);
}

export function renderEvaluateTable(xData, yTrue, yPred, logits) {
    const tableBody = document.getElementById('evaluate-tbody');

    for (let i = 0; i < yTrue.length; ++i) {
        const row = document.createElement('tr');
        for (let j = 0; j < 4; ++j) {
            const cell = document.createElement('td');
            cell.textContent = xData[4 * i + j].toFixed(1);
            row.appendChild(cell);
        }
        const truthCell = document.createElement('td');
        truthCell.textContent = IRIS_CLASSES[yTrue[i]];
        row.appendChild(truthCell);
        const predCell = document.createElement('td');
        predCell.textContent = IRIS_CLASSES[yPred[i]];
        predCell.classList = yPred[i] === yTrue[i] ? ['correct-prediction'] : ['wrong-prediction'];
        row.appendChild(predCell);
        const logitsCell = document.createElement('td');
        const exampleLogits = logits.slice(i * IRIS_NUM_CLASSES, (i + 1) * IRIS_NUM_CLASSES);
        logitsToSpans(exampleLogits).map(logitSpan => {
            logitsCell.appendChild(logitSpan);
        });
        row.appendChild(logitsCell);
        tableBody.appendChild(row);
    }
}

export function wireUpEvaluateTableCallbacks(predictOnManualInputCallback) {
    const petalLength = document.getElementById('petal-length');
    const petalWidth = document.getElementById('petal-width');
    const sepalLength = document.getElementById('sepal-length');
    const sepalWidth = document.getElementById('sepal-width');

    const increment = 0.1;
    document.getElementById('petal-length-inc').addEventListener('click', () => {
        petalLength.value = (Number(petalLength.value) + increment).toFixed(1);
        predictOnManualInputCallback();
    });
    document.getElementById('petal-length-dec').addEventListener('click', () => {
        petalLength.value = (Number(petalLength.value) - increment).toFixed(1);
        predictOnManualInputCallback();
    });
    document.getElementById('petal-width-inc').addEventListener('click', () => {
        petalWidth.value = (Number(petalWidth.value) + increment).toFixed(1);
        predictOnManualInputCallback();
    });
    document.getElementById('petal-width-dec').addEventListener('click', () => {
        petalWidth.value = (Number(petalWidth.value) - increment).toFixed(1);
        predictOnManualInputCallback();
    });
    document.getElementById('sepal-length-inc').addEventListener('click', () => {
        sepalLength.value = (Number(sepalLength.value) + increment).toFixed(1);
        predictOnManualInputCallback();
    });
    document.getElementById('sepal-length-dec').addEventListener('click', () => {
        sepalLength.value = (Number(sepalLength.value) - increment).toFixed(1);
        predictOnManualInputCallback();
    });
    document.getElementById('sepal-width-inc').addEventListener('click', () => {
        sepalWidth.value = (Number(sepalWidth.value) + increment).toFixed(1);
        predictOnManualInputCallback();
    });
    document.getElementById('sepal-width-dec').addEventListener('click', () => {
        sepalWidth.value = (Number(sepalWidth.value) - increment).toFixed(1);
        predictOnManualInputCallback();
    });

    document.getElementById('petal-length').addEventListener('change', () => {
        predictOnManualInputCallback();
    });
    document.getElementById('petal-width').addEventListener('change', () => {
        predictOnManualInputCallback();
    });
    document.getElementById('sepal-length').addEventListener('change', () => {
        predictOnManualInputCallback();
    });
    document.getElementById('sepal-width').addEventListener('change', () => {
        predictOnManualInputCallback();
    });
}

export function loadTrainParametersFromUI() {
    return {
        epochs: Number(document.getElementById('train-epochs').value),
        learningRate: Number(document.getElementById('learning-rate').value)
    };
}

export function status(statusText) {
    console.log(statusText);
    document.getElementById('demo-status').textContent = statusText;
}

 

posted @ 2024-05-21 13:14  james·von  阅读(26)  评论(0编辑  收藏  举报