使用ML.NET实现NBA得分预测
导读:ML.NET系列文章
ML.NET已经发布了v0.2版本,新增了聚类训练器,执行性能进一步增强。本文将介绍一种特殊的回归——泊松回归,并以NBA比赛得分预测的案例来演练。
泊松回归 Poisson regression
前面的文章已提过,回归是用来预测连续值的。泊松回归是回归的一种,其特殊之处在于预测目标是非负整数,通常为计数类的数值(模型输出的是该计数的期望值)。泊松分布是离散分布,描述的是在相同(或大致相同)的时间间隔内独立随机事件发生的次数,因此标签值应当是这样的计数。
那么什么场景是符合计数,可以适用泊松回归呢?举几个例子,比如共享单车的调度,每一处地域中心,每隔1小时都要统计借车和还车数,根据这个统计我们就可以预测下一个小时此处地域需要调配多少车辆才能满足需要。再比如,公司每个月都有离职员工,那么人力资源部门就可以对月人员流失数进行计数,然后通过泊松回归来预测下个月的流失情况,以便提早采取措施做好招聘计划。
是不是有一点感觉了?本次我们用大家喜欢的NBA比赛得分来进行演练。比赛得分正好也是一种计数,也符合连续相同时间间隔(比赛时长大体相近),而且比赛结果具有不确定性,所以也是泊松回归大显身手的地方。为了易于理解,我将示范预测的是主场球队的得分。
NBA比赛数据
本案例数据来源于Kaggle.com,数据集名为NBA Team Game Stats from 2014 to 2018,收集了最近4个赛季的NBA比赛,格式类似如下:
"","Team","Game","Date","Home","Opponent","WINorLOSS","TeamPoints","OpponentPoints","FieldGoals","FieldGoalsAttempted","FieldGoals.","X3PointShots","X3PointShotsAttempted","X3PointShots.","FreeThrows","FreeThrowsAttempted","FreeThrows.","OffRebounds","TotalRebounds","Assists","Steals","Blocks","Turnovers","TotalFouls","Opp.FieldGoals","Opp.FieldGoalsAttempted","Opp.FieldGoals.","Opp.3PointShots","Opp.3PointShotsAttempted","Opp.3PointShots.","Opp.FreeThrows","Opp.FreeThrowsAttempted","Opp.FreeThrows.","Opp.OffRebounds","Opp.TotalRebounds","Opp.Assists","Opp.Steals","Opp.Blocks","Opp.Turnovers","Opp.TotalFouls"
"1","ATL","1",2014-10-29,"Away","TOR","L","102","109","40","80",".500","13","22",".591","9","17",".529","10","42","26","6","8","17","24","37","90",".411","8","26",".308","27","33",".818","16","48","26","13","9","9","22"
"2","ATL","2",2014-11-01,"Home","IND","W","102","92","35","69",".507","7","20",".350","25","33",".758","3","37","26","10","6","12","20","31","81",".383","12","32",".375","18","21",".857","11","44","25","5","5","18","26"
"3","ATL","3",2014-11-05,"Away","SAS","L","92","94","38","92",".413","8","25",".320","8","11",".727","10","37","26","14","5","13","25","31","69",".449","5","17",".294","27","38",".711","11","50","25","7","9","19","15"
"4","ATL","4",2014-11-07,"Away","CHO","L","119","122","43","93",".462","13","33",".394","20","26",".769","7","38","28","8","3","19","33","48","97",".495","6","21",".286","20","27",".741","11","51","31","6","7","19","30"
"5","ATL","5",2014-11-08,"Home","NYK","W","103","96","33","81",".407","9","22",".409","28","36",".778","12","41","18","10","5","8","17","40","84",".476","8","21",".381","8","11",".727","13","44","26","2","6","15","29"
"6","ATL","6",2014-11-10,"Away","NYK","W","91","85","27","71",".380","10","27",".370","27","28",".964","9","38","20","7","3","15","16","36","83",".434","6","26",".231","7","12",".583","11","40","23","4","2","15","26"
"7","ATL","7",2014-11-12,"Home","UTA","W","100","97","39","76",".513","9","20",".450","13","18",".722","13","46","23","8","4","18","12","43","86",".500","5","23",".217","6","12",".500","8","30","28","12","8","11","17"
"8","ATL","8",2014-11-14,"Home","MIA","W","114","103","42","75",".560","11","28",".393","19","23",".826","3","36","33","10","5","13","20","35","74",".473","10","21",".476","23","25",".920","5","32","27","10","3","14","20"
各字段如下:
比赛基本信息:球队Team(该行统计数据所属的球队,并不一定是主队),该球队的比赛场次序号Game,比赛日期Date,主客场标识Home(取值为Home或Away,表示Team在该场比赛中是主场还是客场),对手Opponent,Team的胜负WINorLOSS(W或L)。
比赛双方的技术数据(不带前缀的字段属于Team一方,带Opp.前缀的字段属于对手一方):Team Points,Opponent Points,Field Goals,Field Goals Attempted,Field Goals Percentage,3 Point Shots,3 Point Shots Attempted,3 Point Shots Percentage,Free Throws,Free Throws Attempted,Free Throws Percentage,Offensive Rebounds,Total Rebounds,Assists,Steals,Blocks,Turnovers,Total Fouls。
这些指标反映了比赛双方的投篮出手次数、命中数、命中率,三分球的出手次数、命中数、命中率,罚球的出手次数、命中数、命中率,以及篮板、助攻、抢断、盖帽、失误、犯规等,这些都是我们在看NBA时常见的统计。
由于只有这一份数据,为了分别用于训练、评估和预测,我将数据集按7:2:1的比例进行分割。
代码片段分解
定义原始数据结构和预测数据结构。TeamPoints是该行球队(Team)的得分(当Home为Home时即为主队得分),是本次示例要预测的目标,因此映射为标签(Label)列。
/// <summary>
/// One data row of nba.games.stats.csv, bound to ML.NET columns by ordinal position.
/// Fields ending in "_" hold the percentage columns whose CSV header ends in "." (e.g. "FieldGoals.");
/// fields prefixed "Opp_" hold the opponent's statistics (CSV header prefix "Opp.").
/// TeamPoints is exposed to the pipeline as the "Label" column, i.e. the regression target.
/// </summary>
public class Match
{
    // Game identification (text columns). In the sample data Home is "Home" or "Away" for Team,
    // and WINorLOSS is "W" or "L".
    [Column("0")] public string Id;
    [Column("1")] public string Team;
    [Column("2")] public string Game;
    [Column("3")] public string Date;
    [Column("4")] public string Home;
    [Column("5")] public string Opponent;
    [Column("6")] public string WINorLOSS;

    // Final scores; TeamPoints is renamed to "Label" so the trainer learns to predict it.
    [Column("7", "Label")] public float TeamPoints;
    [Column("8")] public float OpponentPoints;

    // Team box-score statistics.
    [Column("9")] public float FieldGoals;
    [Column("10")] public float FieldGoalsAttempted;
    [Column("11")] public float FieldGoals_;
    [Column("12")] public float X3PointShots;
    [Column("13")] public float X3PointShotsAttempted;
    [Column("14")] public float X3PointShots_;
    [Column("15")] public float FreeThrows;
    [Column("16")] public float FreeThrowsAttempted;
    [Column("17")] public float FreeThrows_;
    [Column("18")] public float OffRebounds;
    [Column("19")] public float TotalRebounds;
    [Column("20")] public float Assists;
    [Column("21")] public float Steals;
    [Column("22")] public float Blocks;
    [Column("23")] public float Turnovers;
    [Column("24")] public float TotalFouls;

    // Opponent box-score statistics.
    [Column("25")] public float Opp_FieldGoals;
    [Column("26")] public float Opp_FieldGoalsAttempted;
    [Column("27")] public float Opp_FieldGoals_;
    [Column("28")] public float Opp_3PointShots;
    [Column("29")] public float Opp_3PointShotsAttempted;
    [Column("30")] public float Opp_3PointShots_;
    [Column("31")] public float Opp_FreeThrows;
    [Column("32")] public float Opp_FreeThrowsAttempted;
    [Column("33")] public float Opp_FreeThrows_;
    [Column("34")] public float Opp_OffRebounds;
    [Column("35")] public float Opp_TotalRebounds;
    [Column("36")] public float Opp_Assists;
    [Column("37")] public float Opp_Steals;
    [Column("38")] public float Opp_Blocks;
    [Column("39")] public float Opp_Turnovers;
    [Column("40")] public float Opp_TotalFouls;
}
/// <summary>
/// Model output row: the predicted TeamPoints, read from the trainer's "Score" column.
/// </summary>
public class MatchPrediction
{
    [ColumnName("Score")] public float TeamPoints;
}
加载数据部分
const string DATA_PATH = "data/nba.games.stats.csv";

/// <summary>
/// Reads every data row of <see cref="DATA_PATH"/> into a <see cref="Match"/>, skipping the header row.
/// </summary>
/// <returns>All parsed rows, in file order.</returns>
/// <exception cref="FormatException">
/// A non-blank row has fewer than 41 fields, or one of its statistics is not a number.
/// </exception>
/// <remarks>
/// Fields are split on ',' without CSV quote handling.
/// NOTE(review): that relies on no quoted value containing a comma (true for the sample rows) — confirm
/// for the full file. Numbers are parsed with the invariant culture so values such as ".500" are read the
/// same way regardless of the machine's regional settings.
/// </remarks>
static ICollection<Match> LoadData()
{
    const int ExpectedFieldCount = 41;
    var matches = new List<Match>();
    using (var sr = new StreamReader(File.OpenRead(DATA_PATH)))
    {
        // The first line is the header row.
        sr.ReadLine();
        var lineNumber = 1;
        while (!sr.EndOfStream)
        {
            var line = sr.ReadLine();
            lineNumber++;

            // Skip blank lines (e.g. extra newlines at the end of the file) instead of failing on them.
            if (string.IsNullOrWhiteSpace(line))
            {
                continue;
            }

            var values = line.Split(',');
            if (values.Length < ExpectedFieldCount)
            {
                throw new FormatException(
                    $"Line {lineNumber} of {DATA_PATH} has {values.Length} fields; expected {ExpectedFieldCount}.");
            }

            // Strips the double quotes the file puts around most values.
            string Text(int index) => values[index].Trim('"');

            // Culture-independent parse: Convert.ToSingle(string) alone would use the current culture.
            float Number(int index) =>
                Convert.ToSingle(Text(index), System.Globalization.CultureInfo.InvariantCulture);

            matches.Add(new Match
            {
                Id = Text(0),
                Team = Text(1),
                Game = Text(2),
                Date = Text(3),
                Home = Text(4),
                Opponent = Text(5),
                WINorLOSS = Text(6),
                TeamPoints = Number(7),
                OpponentPoints = Number(8),
                FieldGoals = Number(9),
                FieldGoalsAttempted = Number(10),
                FieldGoals_ = Number(11),
                X3PointShots = Number(12),
                X3PointShotsAttempted = Number(13),
                X3PointShots_ = Number(14),
                FreeThrows = Number(15),
                FreeThrowsAttempted = Number(16),
                FreeThrows_ = Number(17),
                OffRebounds = Number(18),
                TotalRebounds = Number(19),
                Assists = Number(20),
                Steals = Number(21),
                Blocks = Number(22),
                Turnovers = Number(23),
                TotalFouls = Number(24),
                Opp_FieldGoals = Number(25),
                Opp_FieldGoalsAttempted = Number(26),
                Opp_FieldGoals_ = Number(27),
                Opp_3PointShots = Number(28),
                Opp_3PointShotsAttempted = Number(29),
                Opp_3PointShots_ = Number(30),
                Opp_FreeThrows = Number(31),
                Opp_FreeThrowsAttempted = Number(32),
                Opp_FreeThrows_ = Number(33),
                Opp_OffRebounds = Number(34),
                Opp_TotalRebounds = Number(35),
                Opp_Assists = Number(36),
                Opp_Steals = Number(37),
                Opp_Blocks = Number(38),
                Opp_Turnovers = Number(39),
                Opp_TotalFouls = Number(40)
            });
        }
    }
    return matches;
}
训练、评估、预测部分
/// <summary>
/// Builds and trains an ML.NET LearningPipeline that predicts TeamPoints (the "Label" column) with a
/// Poisson regressor: the categorical columns are one-hot encoded, then everything except Id and the
/// label is concatenated into the "Features" vector.
/// </summary>
/// <param name="trainData">Rows used for training.</param>
/// <returns>The trained prediction model.</returns>
/// <remarks>
/// NOTE(review): the features include FieldGoals, X3PointShots and FreeThrows, and in the data
/// TeamPoints equals 2*FieldGoals + X3PointShots + FreeThrows (first row: 2*40 + 13 + 9 = 102). The label
/// is therefore almost fully determined by the features (target leakage). That suits in-game estimation,
/// but a pre-game forecast would likely need to drop these (and OpponentPoints / WINorLOSS).
/// </remarks>
static PredictionModel<Match, MatchPrediction> Train(IEnumerable<Match> trainData)
{
    // Text columns that are one-hot encoded and placed first in the feature vector.
    string[] categoricalColumns = { "Team", "Game", "Date", "Home", "Opponent", "WINorLOSS" };

    // Numeric columns appended after the encoded ones (TeamPoints is the label, so it is not listed).
    string[] numericColumns =
    {
        "OpponentPoints", "FieldGoals", "FieldGoalsAttempted", "FieldGoals_",
        "X3PointShots", "X3PointShotsAttempted", "X3PointShots_",
        "FreeThrows", "FreeThrowsAttempted", "FreeThrows_",
        "OffRebounds", "TotalRebounds", "Assists", "Steals", "Blocks", "Turnovers", "TotalFouls",
        "Opp_FieldGoals", "Opp_FieldGoalsAttempted", "Opp_FieldGoals_",
        "Opp_3PointShots", "Opp_3PointShotsAttempted", "Opp_3PointShots_",
        "Opp_FreeThrows", "Opp_FreeThrowsAttempted", "Opp_FreeThrows_",
        "Opp_OffRebounds", "Opp_TotalRebounds", "Opp_Assists", "Opp_Steals", "Opp_Blocks",
        "Opp_Turnovers", "Opp_TotalFouls"
    };

    var pipeline = new LearningPipeline();
    pipeline.Add(CollectionDataSource.Create(trainData));
    pipeline.Add(new ColumnDropper { Column = new[] { "Id" } });
    pipeline.Add(new CategoricalOneHotVectorizer(categoricalColumns));
    pipeline.Add(new ColumnConcatenator("Features", categoricalColumns.Concat(numericColumns).ToArray()));
    pipeline.Add(new PoissonRegressor());
    return pipeline.Train<Match, MatchPrediction>();
}
/// <summary>
/// Scores <paramref name="model"/> on <paramref name="evaluateData"/> with ML.NET's RegressionEvaluator
/// and prints the loss, R-squared and RMS error to the console.
/// </summary>
/// <param name="model">Trained model to evaluate.</param>
/// <param name="evaluateData">Held-out rows with known TeamPoints.</param>
static void Evaluate(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> evaluateData)
{
    var source = CollectionDataSource.Create(evaluateData);
    var metrics = new RegressionEvaluator().Evaluate(model, source);

    Console.WriteLine("LossFn: {0}", metrics.LossFn);
    Console.WriteLine("RSquared: {0}", metrics.RSquared);
    Console.WriteLine("Rms: {0}", metrics.Rms);
}
/// <summary>
/// Runs <paramref name="model"/> over <paramref name="predictData"/> and prints each game's actual score
/// next to the predicted TeamPoints.
/// </summary>
/// <param name="model">Trained model.</param>
/// <param name="predictData">Rows to predict; their real scores are printed for comparison.</param>
static void Predict(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> predictData)
{
    var predictions = model.Predict(predictData);

    // Pair every input row with its prediction, in order.
    foreach (var (match, prediction) in predictData.Zip(predictions, (d, p) => (d, p)))
    {
        Console.WriteLine("Date: {0}, Team: {1} Opponent: {2}, Score: {3}-{4}, Predict Home Score: {5}",
            match.Date, match.Team, match.Opponent, match.TeamPoints, match.OpponentPoints, prediction.TeamPoints);
    }
}
最后是Main调用部分
/// <summary>
/// Entry point: loads the games, keeps the home-team rows, splits them chronologically 70/20/10 into
/// training, evaluation and prediction sets, then trains, evaluates and prints predictions.
/// </summary>
static void Main(string[] args)
{
    // Each row is one team's box score, and Home is "Home" or "Away" for that team (see the sample rows).
    // TeamPoints is therefore the home team's score only on "Home" rows, which is what we predict.
    // The sample rows are grouped by team (all ATL, games 1-8), so a positional split would hold out whole
    // teams rather than the most recent games. Dates are ISO yyyy-MM-dd strings, so an ordinal sort is
    // chronological; OrderBy is stable, keeping same-day rows in file order.
    var data = LoadData()
        .Where(m => string.Equals(m.Home, "Home", StringComparison.Ordinal))
        .OrderBy(m => m.Date, StringComparer.Ordinal)
        .ToList();

    var trainCount = Convert.ToInt32(data.Count * 0.7);
    var evaluateCount = Convert.ToInt32(data.Count * 0.2);
    if (trainCount == 0)
    {
        Console.WriteLine("No home-game rows to train on.");
        return;
    }

    // Oldest games train the model, the next ones evaluate it, the most recent ones are predicted.
    var trainData = data.Take(trainCount);
    var evaluateData = data.Skip(trainCount).Take(evaluateCount);
    var predictData = data.Skip(trainCount + evaluateCount);

    var model = Train(trainData);
    Evaluate(model, evaluateData);
    Predict(model, predictData);
}
执行结果
结尾
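可以看到,将预测集中比赛的主队预测得分与真实结果对比,误差已经相当小了。需要注意的是,特征值都是比赛技术数据,其中的投篮命中数、三分命中数和罚球命中数几乎直接决定了得分(得分 = 2×FieldGoals + X3PointShots + FreeThrows,例如第一行数据:2×40 + 13 + 9 = 102),因此这样的准确度主要适用于比赛进行中的估算。用在以后的比赛时,可根据比赛进行的实时情况不断更新这些数据,预测便会越来越接近最终结果;若想在赛前预测,则需改用赛前即可获得的特征。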
对球迷来说这可是一件神器呀。想想2018世界杯也马上要开始了,保罗、阿喀琉斯什么的都弱爆了,相信小伙伴们也要尝试一下ML.NET的套路了吧,记得拿到历年完整的数据哟!
完整代码如下:
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Models;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
namespace NBAPrediction
{
class Program
{
    const string DATA_PATH = "data/nba.games.stats.csv";

    /// <summary>
    /// Reads every data row of <see cref="DATA_PATH"/> into a <see cref="Match"/>, skipping the header row.
    /// </summary>
    /// <returns>All parsed rows, in file order.</returns>
    /// <exception cref="FormatException">
    /// A non-blank row has fewer than 41 fields, or one of its statistics is not a number.
    /// </exception>
    /// <remarks>
    /// Fields are split on ',' without CSV quote handling.
    /// NOTE(review): that relies on no quoted value containing a comma (true for the sample rows) — confirm
    /// for the full file. Numbers are parsed with the invariant culture so values such as ".500" are read
    /// the same way regardless of the machine's regional settings.
    /// </remarks>
    static ICollection<Match> LoadData()
    {
        const int ExpectedFieldCount = 41;
        var matches = new List<Match>();
        using (var sr = new StreamReader(File.OpenRead(DATA_PATH)))
        {
            // The first line is the header row.
            sr.ReadLine();
            var lineNumber = 1;
            while (!sr.EndOfStream)
            {
                var line = sr.ReadLine();
                lineNumber++;

                // Skip blank lines (e.g. extra newlines at the end of the file) instead of failing on them.
                if (string.IsNullOrWhiteSpace(line))
                {
                    continue;
                }

                var values = line.Split(',');
                if (values.Length < ExpectedFieldCount)
                {
                    throw new FormatException(
                        $"Line {lineNumber} of {DATA_PATH} has {values.Length} fields; expected {ExpectedFieldCount}.");
                }

                // Strips the double quotes the file puts around most values.
                string Text(int index) => values[index].Trim('"');

                // Culture-independent parse: Convert.ToSingle(string) alone would use the current culture.
                float Number(int index) =>
                    Convert.ToSingle(Text(index), System.Globalization.CultureInfo.InvariantCulture);

                matches.Add(new Match
                {
                    Id = Text(0),
                    Team = Text(1),
                    Game = Text(2),
                    Date = Text(3),
                    Home = Text(4),
                    Opponent = Text(5),
                    WINorLOSS = Text(6),
                    TeamPoints = Number(7),
                    OpponentPoints = Number(8),
                    FieldGoals = Number(9),
                    FieldGoalsAttempted = Number(10),
                    FieldGoals_ = Number(11),
                    X3PointShots = Number(12),
                    X3PointShotsAttempted = Number(13),
                    X3PointShots_ = Number(14),
                    FreeThrows = Number(15),
                    FreeThrowsAttempted = Number(16),
                    FreeThrows_ = Number(17),
                    OffRebounds = Number(18),
                    TotalRebounds = Number(19),
                    Assists = Number(20),
                    Steals = Number(21),
                    Blocks = Number(22),
                    Turnovers = Number(23),
                    TotalFouls = Number(24),
                    Opp_FieldGoals = Number(25),
                    Opp_FieldGoalsAttempted = Number(26),
                    Opp_FieldGoals_ = Number(27),
                    Opp_3PointShots = Number(28),
                    Opp_3PointShotsAttempted = Number(29),
                    Opp_3PointShots_ = Number(30),
                    Opp_FreeThrows = Number(31),
                    Opp_FreeThrowsAttempted = Number(32),
                    Opp_FreeThrows_ = Number(33),
                    Opp_OffRebounds = Number(34),
                    Opp_TotalRebounds = Number(35),
                    Opp_Assists = Number(36),
                    Opp_Steals = Number(37),
                    Opp_Blocks = Number(38),
                    Opp_Turnovers = Number(39),
                    Opp_TotalFouls = Number(40)
                });
            }
        }
        return matches;
    }

    /// <summary>
    /// Builds and trains an ML.NET LearningPipeline that predicts TeamPoints (the "Label" column) with a
    /// Poisson regressor: the categorical columns are one-hot encoded, then everything except Id and the
    /// label is concatenated into the "Features" vector.
    /// </summary>
    /// <param name="trainData">Rows used for training.</param>
    /// <returns>The trained prediction model.</returns>
    /// <remarks>
    /// NOTE(review): the features include FieldGoals, X3PointShots and FreeThrows, and in the data
    /// TeamPoints equals 2*FieldGoals + X3PointShots + FreeThrows (first row: 2*40 + 13 + 9 = 102). The
    /// label is therefore almost fully determined by the features (target leakage). That suits in-game
    /// estimation, but a pre-game forecast would likely need to drop these (and OpponentPoints / WINorLOSS).
    /// </remarks>
    static PredictionModel<Match, MatchPrediction> Train(IEnumerable<Match> trainData)
    {
        // Text columns that are one-hot encoded and placed first in the feature vector.
        string[] categoricalColumns = { "Team", "Game", "Date", "Home", "Opponent", "WINorLOSS" };

        // Numeric columns appended after the encoded ones (TeamPoints is the label, so it is not listed).
        string[] numericColumns =
        {
            "OpponentPoints", "FieldGoals", "FieldGoalsAttempted", "FieldGoals_",
            "X3PointShots", "X3PointShotsAttempted", "X3PointShots_",
            "FreeThrows", "FreeThrowsAttempted", "FreeThrows_",
            "OffRebounds", "TotalRebounds", "Assists", "Steals", "Blocks", "Turnovers", "TotalFouls",
            "Opp_FieldGoals", "Opp_FieldGoalsAttempted", "Opp_FieldGoals_",
            "Opp_3PointShots", "Opp_3PointShotsAttempted", "Opp_3PointShots_",
            "Opp_FreeThrows", "Opp_FreeThrowsAttempted", "Opp_FreeThrows_",
            "Opp_OffRebounds", "Opp_TotalRebounds", "Opp_Assists", "Opp_Steals", "Opp_Blocks",
            "Opp_Turnovers", "Opp_TotalFouls"
        };

        var pipeline = new LearningPipeline();
        pipeline.Add(CollectionDataSource.Create(trainData));
        pipeline.Add(new ColumnDropper { Column = new[] { "Id" } });
        pipeline.Add(new CategoricalOneHotVectorizer(categoricalColumns));
        pipeline.Add(new ColumnConcatenator("Features", categoricalColumns.Concat(numericColumns).ToArray()));
        pipeline.Add(new PoissonRegressor());
        return pipeline.Train<Match, MatchPrediction>();
    }

    /// <summary>
    /// Scores <paramref name="model"/> on <paramref name="evaluateData"/> with ML.NET's RegressionEvaluator
    /// and prints the loss, R-squared and RMS error to the console.
    /// </summary>
    /// <param name="model">Trained model to evaluate.</param>
    /// <param name="evaluateData">Held-out rows with known TeamPoints.</param>
    static void Evaluate(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> evaluateData)
    {
        var source = CollectionDataSource.Create(evaluateData);
        var metrics = new RegressionEvaluator().Evaluate(model, source);

        Console.WriteLine("LossFn: {0}", metrics.LossFn);
        Console.WriteLine("RSquared: {0}", metrics.RSquared);
        Console.WriteLine("Rms: {0}", metrics.Rms);
    }

    /// <summary>
    /// Runs <paramref name="model"/> over <paramref name="predictData"/> and prints each game's actual score
    /// next to the predicted TeamPoints.
    /// </summary>
    /// <param name="model">Trained model.</param>
    /// <param name="predictData">Rows to predict; their real scores are printed for comparison.</param>
    static void Predict(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> predictData)
    {
        var predictions = model.Predict(predictData);

        // Pair every input row with its prediction, in order.
        foreach (var (match, prediction) in predictData.Zip(predictions, (d, p) => (d, p)))
        {
            Console.WriteLine("Date: {0}, Team: {1} Opponent: {2}, Score: {3}-{4}, Predict Home Score: {5}",
                match.Date, match.Team, match.Opponent, match.TeamPoints, match.OpponentPoints, prediction.TeamPoints);
        }
    }

    /// <summary>
    /// Entry point: loads the games, keeps the home-team rows, splits them chronologically 70/20/10 into
    /// training, evaluation and prediction sets, then trains, evaluates and prints predictions.
    /// </summary>
    static void Main(string[] args)
    {
        // Each row is one team's box score, and Home is "Home" or "Away" for that team. TeamPoints is
        // therefore the home team's score only on "Home" rows, which is what the output claims to predict.
        // The sample rows are grouped by team (all ATL, games 1-8), so a positional split would hold out
        // whole teams rather than the most recent games. Dates are ISO yyyy-MM-dd strings, so an ordinal
        // sort is chronological; OrderBy is stable, keeping same-day rows in file order.
        var data = LoadData()
            .Where(m => string.Equals(m.Home, "Home", StringComparison.Ordinal))
            .OrderBy(m => m.Date, StringComparer.Ordinal)
            .ToList();

        var trainCount = Convert.ToInt32(data.Count * 0.7);
        var evaluateCount = Convert.ToInt32(data.Count * 0.2);
        if (trainCount == 0)
        {
            Console.WriteLine("No home-game rows to train on.");
            return;
        }

        // Oldest games train the model, the next ones evaluate it, the most recent ones are predicted.
        var trainData = data.Take(trainCount);
        var evaluateData = data.Skip(trainCount).Take(evaluateCount);
        var predictData = data.Skip(trainCount + evaluateCount);

        var model = Train(trainData);
        Evaluate(model, evaluateData);
        Predict(model, predictData);
    }
}
/// <summary>
/// One data row of nba.games.stats.csv, bound to ML.NET columns by ordinal position.
/// Fields ending in "_" hold the percentage columns whose CSV header ends in "." (e.g. "FieldGoals.");
/// fields prefixed "Opp_" hold the opponent's statistics (CSV header prefix "Opp.").
/// TeamPoints is exposed to the pipeline as the "Label" column, i.e. the regression target.
/// </summary>
public class Match
{
    // Game identification (text columns). In the sample data Home is "Home" or "Away" for Team,
    // and WINorLOSS is "W" or "L".
    [Column("0")] public string Id;
    [Column("1")] public string Team;
    [Column("2")] public string Game;
    [Column("3")] public string Date;
    [Column("4")] public string Home;
    [Column("5")] public string Opponent;
    [Column("6")] public string WINorLOSS;

    // Final scores; TeamPoints is renamed to "Label" so the trainer learns to predict it.
    [Column("7", "Label")] public float TeamPoints;
    [Column("8")] public float OpponentPoints;

    // Team box-score statistics.
    [Column("9")] public float FieldGoals;
    [Column("10")] public float FieldGoalsAttempted;
    [Column("11")] public float FieldGoals_;
    [Column("12")] public float X3PointShots;
    [Column("13")] public float X3PointShotsAttempted;
    [Column("14")] public float X3PointShots_;
    [Column("15")] public float FreeThrows;
    [Column("16")] public float FreeThrowsAttempted;
    [Column("17")] public float FreeThrows_;
    [Column("18")] public float OffRebounds;
    [Column("19")] public float TotalRebounds;
    [Column("20")] public float Assists;
    [Column("21")] public float Steals;
    [Column("22")] public float Blocks;
    [Column("23")] public float Turnovers;
    [Column("24")] public float TotalFouls;

    // Opponent box-score statistics.
    [Column("25")] public float Opp_FieldGoals;
    [Column("26")] public float Opp_FieldGoalsAttempted;
    [Column("27")] public float Opp_FieldGoals_;
    [Column("28")] public float Opp_3PointShots;
    [Column("29")] public float Opp_3PointShotsAttempted;
    [Column("30")] public float Opp_3PointShots_;
    [Column("31")] public float Opp_FreeThrows;
    [Column("32")] public float Opp_FreeThrowsAttempted;
    [Column("33")] public float Opp_FreeThrows_;
    [Column("34")] public float Opp_OffRebounds;
    [Column("35")] public float Opp_TotalRebounds;
    [Column("36")] public float Opp_Assists;
    [Column("37")] public float Opp_Steals;
    [Column("38")] public float Opp_Blocks;
    [Column("39")] public float Opp_Turnovers;
    [Column("40")] public float Opp_TotalFouls;
}
/// <summary>
/// Model output row: the predicted TeamPoints, read from the trainer's "Score" column.
/// </summary>
public class MatchPrediction
{
    [ColumnName("Score")] public float TeamPoints;
}
}