tf拟合

https://files.cnblogs.com/files/chinasoft/tf.js-demo-v2.rar?t=1656483198

 

<script src = "tf.min.js"> </script>
<script>
  /* 根据身高推测体重 */

  //把数据处理成符合模型要求的格式
  function getData() {
      //学习数据
      const heights = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
      const weights = [3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23];

      //验证数据
      const testh = [100, 101, 102, 103, 104, 105, 106];
      const testw = [201, 203, 205, 207, 208, 210, 212];

      //归一化数据
      const inputs = tf.tensor(heights);//.sub(150).div(50);
      const labels = tf.tensor(weights);//.sub(40).div(60);

      const xs = tf.tensor(testh);//.;//sub(150).div(50);
      const ys = tf.tensor(testw);//.sub(40).div(60);

      // //绘制图表
      // tfvis.render.scatterplot(
      //     { name: '身高体重' },
      //     //x轴身高,y轴体重
      //     { values: heights.map((x, i) => ({ x, y: weights[i] })) },
      //     //设置x轴范围,设置y轴范围
      //     { xAxisDomain: [140, 200], yAxisDomain: [40, 110] }
      // );

      return { inputs, labels, xs, ys };
  }



  async function run(){
      const { inputs, labels, xs, ys } = getData();

      //设置连续模型
      const model = tf.sequential();

      //设置全连接层
      model.add(tf.layers.dense({
          units: 1,
          inputShape: [1]
      }));

      // model.add(tf.layers.dense({
      //     units: 1
      // }));

      //设置损失函数,优化函数学习速率为0.1
      model.compile({
          loss: tf.losses.meanSquaredError,
          optimizer: tf.train.adam(0.1)


      });

      await model.fit(inputs, labels, {
          batchSize: 1,
          epochs: 20,
          //设置验证集
          validationData: [xs, ys],
          // callbacks: tfvis.show.fitCallbacks(
          //     { name: '训练过程' },
          //     ['loss', 'val_loss', 'acc', 'val_acc'],
          //     { callbacks: ['onEpochEnd'] }
          // )
          callbacks:function(){
              console.log("1");
          }
      });

      //对身高180的体重进行推测
      // let res = model.predict(tf.tensor([180]).sub(150).div(50));
      // console.log(res.mul(60).add(40).dataSync()[0]);

      let res = model.predict(tf.tensor([180]));
      console.log(res.dataSync()[0]);
      //保存模型
      window.download = async () => {
          await model.save('downloads://my-model');
      }
  }

  run();

</script>

  

posted @ 2022-06-29 14:14  China Soft  阅读(13)  评论(0编辑  收藏  举报