ML.NET Tutorial: Taxi Fare Prediction (a Regression Problem)

Understanding the problem

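A taxi fare depends not only on distance but also on factors such as the number of passengers and whether the rider pays by credit card (the taxis here are New York City taxis). So this is not a simple single-variable equation problem.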

Preparing the data

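Create a console application project, add a Data folder, and place the taxi-fare-train.csv and taxi-fare-test.csv files in it. Remember to set each file's Copy to Output Directory property to Copy if newer. Then add the Microsoft.ML NuGet package.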

Loading the data

Create an MLContext object and a TextLoader object. The TextLoader reads data from a file.

MLContext mlContext = new MLContext(seed: 0);

_textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
{
    Separator = ",",
    HasHeader = true,
    Column = new[]
    {
        new TextLoader.Column("VendorId", DataKind.Text, 0),
        new TextLoader.Column("RateCode", DataKind.Text, 1),
        new TextLoader.Column("PassengerCount", DataKind.R4, 2),
        new TextLoader.Column("TripTime", DataKind.R4, 3),
        new TextLoader.Column("TripDistance", DataKind.R4, 4),
        new TextLoader.Column("PaymentType", DataKind.Text, 5),
        new TextLoader.Column("FareAmount", DataKind.R4, 6)
    }
});
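
To make the index-to-column mapping above concrete, here is a minimal sketch in plain C# that splits one comma-separated row and converts each field by position. It only illustrates the idea and is not how TextLoader is implemented; the CsvRowSketch class and its Parse method are hypothetical names introduced for this example.

```csharp
using System;
using System.Globalization;

// Hypothetical helper for illustration only; not part of ML.NET.
public static class CsvRowSketch
{
    // Splits one comma-separated row and converts each field by column index,
    // mirroring the index-to-column mapping declared for the TextLoader.
    public static (string VendorId, string RateCode, float PassengerCount,
                   float TripTime, float TripDistance, string PaymentType,
                   float FareAmount) Parse(string line)
    {
        string[] fields = line.Split(',');
        if (fields.Length != 7)
            throw new FormatException("Expected 7 comma-separated fields.");

        // Parse numbers with the invariant culture so "3.75" is read the same
        // way regardless of the machine's regional settings.
        float ToFloat(string s) => float.Parse(s, CultureInfo.InvariantCulture);

        return (fields[0], fields[1], ToFloat(fields[2]), ToFloat(fields[3]),
                ToFloat(fields[4]), fields[5], ToFloat(fields[6]));
    }
}
```

For example, CsvRowSketch.Parse("VTS,1,1,1140,3.75,CRD,15.5") yields a tuple whose TripDistance is 3.75 and whose FareAmount is 15.5. That row is assembled from the sample trip values used later in this post, not copied from the dataset file.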

Extracting features

The dataset files have seven columns: the first six serve as features, and the last one is the label.

public class TaxiTrip
{
    [Column("0")]
    public string VendorId;

    [Column("1")]
    public string RateCode;

    [Column("2")]
    public float PassengerCount;

    [Column("3")]
    public float TripTime;

    [Column("4")]
    public float TripDistance;

    [Column("5")]
    public string PaymentType;

    [Column("6")]
    public float FareAmount;
}

public class TaxiTripFarePrediction
{
    [ColumnName("Score")]
    public float FareAmount;
}

Training the model

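First read the training dataset, then build the pipeline. The first step of the pipeline copies the FareAmount column to the Label column, which serves as the label. The second step uses OneHotEncoding to convert the three string columns VendorId, RateCode, and PaymentType into numeric columns. The third step concatenates the six data columns into a single Features column. The last step selects FastTreeRegressionTrainer as the training algorithm.
Once the pipeline is complete, train the model.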

IDataView dataView = _textLoader.Read(dataPath);
var pipeline = mlContext.Transforms.CopyColumns("FareAmount", "Label")
    .Append(mlContext.Transforms.Categorical.OneHotEncoding("VendorId"))
    .Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
    .Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
    .Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripTime", "TripDistance", "PaymentType"))
    .Append(mlContext.Regression.Trainers.FastTree());
var model = pipeline.Fit(dataView);
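
The one-hot step in the pipeline above may be easier to picture with a small standalone sketch. The plain C# below shows the general idea only: each distinct string gets an index, and a value becomes a vector with a single 1 at that index. It is not ML.NET's implementation, and OneHotSketch, BuildVocabulary, and Encode are hypothetical names introduced for this example.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper for illustration only; not part of ML.NET.
public static class OneHotSketch
{
    // Assigns each distinct category the next free index,
    // in order of first appearance.
    public static Dictionary<string, int> BuildVocabulary(IEnumerable<string> values)
    {
        var vocabulary = new Dictionary<string, int>();
        foreach (string value in values)
        {
            if (!vocabulary.ContainsKey(value))
                vocabulary.Add(value, vocabulary.Count);
        }
        return vocabulary;
    }

    // Returns a vector with 1 at the category's index and 0 elsewhere.
    // In this sketch, a category missing from the vocabulary stays all zeros.
    public static float[] Encode(Dictionary<string, int> vocabulary, string value)
    {
        var vector = new float[vocabulary.Count];
        if (vocabulary.TryGetValue(value, out int index))
            vector[index] = 1f;
        return vector;
    }
}
```

With a vocabulary built from the payment types "CRD", "CSH", "CRD" (the second value is a hypothetical one chosen for the example), Encode turns "CRD" into [1, 0] and "CSH" into [0, 1]. Encoding strings this way is what lets the string columns be concatenated with the numeric ones into a single numeric feature vector.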

Evaluating the model

Use the test dataset here, and evaluate the model with the Evaluate method for regression problems.

IDataView dataView = _textLoader.Read(_testDataPath);
var predictions = model.Transform(dataView);
var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");
Console.WriteLine();
Console.WriteLine($"*************************************************");
Console.WriteLine($"*       Model quality metrics evaluation         ");
Console.WriteLine($"*------------------------------------------------");
Console.WriteLine($"*       R2 Score:      {metrics.RSquared:0.##}");
Console.WriteLine($"*       RMS loss:      {metrics.Rms:#.##}");
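
As a rough guide to what the two printed numbers mean, the sketch below computes them from their textbook definitions in plain C#: the RMS loss as the square root of the mean squared error, and R2 as 1 minus the ratio of the residual sum of squares to the total sum of squares. This is an assumption about the standard formulas, not a copy of ML.NET's evaluator, and RegressionMetricsSketch is a hypothetical name.

```csharp
using System;

// Hypothetical helper for illustration only; not part of ML.NET.
public static class RegressionMetricsSketch
{
    // Root mean squared error: sqrt(mean((actual - predicted)^2)).
    public static double Rms(double[] actual, double[] predicted)
    {
        CheckLengths(actual, predicted);
        double sumSquares = 0;
        for (int i = 0; i < actual.Length; i++)
        {
            double error = actual[i] - predicted[i];
            sumSquares += error * error;
        }
        return Math.Sqrt(sumSquares / actual.Length);
    }

    // Coefficient of determination: 1 - SS_res / SS_tot.
    public static double RSquared(double[] actual, double[] predicted)
    {
        CheckLengths(actual, predicted);
        double mean = 0;
        foreach (double a in actual) mean += a;
        mean /= actual.Length;

        double residual = 0, total = 0;
        for (int i = 0; i < actual.Length; i++)
        {
            residual += (actual[i] - predicted[i]) * (actual[i] - predicted[i]);
            total += (actual[i] - mean) * (actual[i] - mean);
        }
        return 1 - residual / total;
    }

    private static void CheckLengths(double[] actual, double[] predicted)
    {
        if (actual.Length == 0 || actual.Length != predicted.Length)
            throw new ArgumentException("Arrays must be non-empty and of equal length.");
    }
}
```

A perfect predictor gives an RMS loss of 0 and an R2 of 1, while always predicting the mean of the actual values gives an R2 of 0. So an R2 close to 1 and a small RMS loss both indicate a better fit.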

Saving the model

The trained model can be saved as a zip file for later use.

using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
    mlContext.Model.Save(model, fileStream);

Using the model

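First load the saved model. Then create a prediction function object, with TaxiTrip as its input type and TaxiTripFarePrediction as its output type. Finally call the Predict method, passing in the sample to predict.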

ITransformer loadedModel;
using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
{
    loadedModel = mlContext.Model.Load(stream);
}

var predictionFunction = loadedModel.MakePredictionFunction<TaxiTrip, TaxiTripFarePrediction>(mlContext);

var taxiTripSample = new TaxiTrip()
{
    VendorId = "VTS",
    RateCode = "1",
    PassengerCount = 1,
    TripTime = 1140,
    TripDistance = 3.75f,
    PaymentType = "CRD",
    FareAmount = 0 // To predict. Actual/Observed = 15.5
};

var prediction = predictionFunction.Predict(taxiTripSample);

Console.WriteLine($"**********************************************************************");
Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
Console.WriteLine($"**********************************************************************");

Complete sample code

using Microsoft.ML;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using System;
using System.IO;

namespace TaxiFarePredictor
{
    class Program
    {
        static readonly string _trainDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "taxi-fare-train.csv");
        static readonly string _testDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "taxi-fare-test.csv");
        static readonly string _modelPath = Path.Combine(Environment.CurrentDirectory, "Data", "Model.zip");
        static TextLoader _textLoader;

        static void Main(string[] args)
        {
            MLContext mlContext = new MLContext(seed: 0);

            _textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
            {
                Separator = ",",
                HasHeader = true,
                Column = new[]
                {
                    new TextLoader.Column("VendorId", DataKind.Text, 0),
                    new TextLoader.Column("RateCode", DataKind.Text, 1),
                    new TextLoader.Column("PassengerCount", DataKind.R4, 2),
                    new TextLoader.Column("TripTime", DataKind.R4, 3),
                    new TextLoader.Column("TripDistance", DataKind.R4, 4),
                    new TextLoader.Column("PaymentType", DataKind.Text, 5),
                    new TextLoader.Column("FareAmount", DataKind.R4, 6)
                }
            });

            var model = Train(mlContext, _trainDataPath);

            Evaluate(mlContext, model);

            TestSinglePrediction(mlContext);

            Console.Read();
        }

        public static ITransformer Train(MLContext mlContext, string dataPath)
        {
            IDataView dataView = _textLoader.Read(dataPath);
            var pipeline = mlContext.Transforms.CopyColumns("FareAmount", "Label")
                .Append(mlContext.Transforms.Categorical.OneHotEncoding("VendorId"))
                .Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
                .Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
                .Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripTime", "TripDistance", "PaymentType"))
                .Append(mlContext.Regression.Trainers.FastTree());
            var model = pipeline.Fit(dataView);
            SaveModelAsFile(mlContext, model);
            return model;
        }

        private static void SaveModelAsFile(MLContext mlContext, ITransformer model)
        {
            using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
                mlContext.Model.Save(model, fileStream);
        }

        private static void Evaluate(MLContext mlContext, ITransformer model)
        {
            IDataView dataView = _textLoader.Read(_testDataPath);
            var predictions = model.Transform(dataView);
            var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");
            Console.WriteLine();
            Console.WriteLine($"*************************************************");
            Console.WriteLine($"*       Model quality metrics evaluation         ");
            Console.WriteLine($"*------------------------------------------------");
            Console.WriteLine($"*       R2 Score:      {metrics.RSquared:0.##}");
            Console.WriteLine($"*       RMS loss:      {metrics.Rms:#.##}");
        }

        private static void TestSinglePrediction(MLContext mlContext)
        {
            ITransformer loadedModel;
            using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
            {
                loadedModel = mlContext.Model.Load(stream);
            }

            var predictionFunction = loadedModel.MakePredictionFunction<TaxiTrip, TaxiTripFarePrediction>(mlContext);

            var taxiTripSample = new TaxiTrip()
            {
                VendorId = "VTS",
                RateCode = "1",
                PassengerCount = 1,
                TripTime = 1140,
                TripDistance = 3.75f,
                PaymentType = "CRD",
                FareAmount = 0 // To predict. Actual/Observed = 15.5
            };

            var prediction = predictionFunction.Predict(taxiTripSample);

            Console.WriteLine($"**********************************************************************");
            Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
            Console.WriteLine($"**********************************************************************");
        }
    }
}

Output of the program:

*************************************************
*       Model quality metrics evaluation
*------------------------------------------------
*       R2 Score:      0.92
*       RMS loss:      2.81
**********************************************************************
Predicted fare: 15.7855, actual fare: 15.5
**********************************************************************

The final prediction is fairly close to the actual value.

Posted 2018-12-24 22:56 by Ken.W