[Scala] NDCG 的 Scala 实现
一、关于 NDCG
[LTR] 信息检索评价指标(RP/MAP/DCG/NDCG/RR/ERR)
二、代码实现
1、训练数据的加载解析
import scala.io.Source /* * 训练行数据 * */ case class TrainDataRow(target: Int, qid: Int, features: Array[Double]) object TrainDataRow { // 加载文件数据 // 格式: // <line> .=. <target> qid:<qid> <feature>:<value> <feature>:<value> ... <feature>:<value> # <info> // <target> .=. <positive integer> // <qid> .=. <positive integer> // <feature> .=. <positive integer> // <value> .=. <float> // <info> .=. <string> def loadFile(file: String): List[TrainDataRow] = { Source.fromFile(file).getLines.toList.par.map(x => { val strArray = x.split(' ') val label = strArray(0).toInt val qid = strArray(1).split(':')(1).toInt val fValArray = strArray.drop(2).map(x => x.split(':')(1).toDouble) new TrainDataRow(label, qid, fValArray) }).toList } }
2、NDCG 的实现
object NDCG { /* * 计算 NDCG 分值 * */ def score(rows: List[TrainDataRow], k: Int): Double = { val size = k.min(rows.length - 1) // 理想 DCG var idealDcg: Double = 0 val sortedList = rows.sortWith((x, y) => x.target > y.target) for (i <- 0 to size) { // 计算累计效益 val gain = (1 << sortedList(i).target) - 1 // 计算折扣因子 val discount = 1.0 / (Math.log(i + 2) / Math.log(2)) idealDcg += gain * discount } if (idealDcg > 0) { var dcg: Double = 0 for (i <- 0 to size) { // 计算累计效益 val gain = (1 << rows(i).target) - 1 // 计算折扣因子 val discount = 1.0 / (Math.log(i + 2) / Math.log(2)) dcg += gain * discount } dcg / idealDcg } else 0 } }
3、训练数据集的 NDCG 计算
def calcNDCG(trainDataFile: String, k: Int): Double = { println("开始计算...") val start = System.nanoTime() val data = TrainDataRow.loadFile(trainDataFile) // 加载训练数据文件 println("数据量:" + data.length + ",用时:" + (System.nanoTime() - start) / 1000000 + " ms") val grpData: Map[Int, List[TrainDataRow]] = data.groupBy(_.qid) // 根据 qid 分组 val resultNDCG = grpData.map(x => NDCG.score(x._2, k)).sum / grpData.size println(s"NDCG@$k: $resultNDCG") val end = System.nanoTime() println("计算运行时间:" + (end - start) / 1000000 + " ms") resultNDCG }
by. Memento
文章作者:Memento
博客地址:http://www.cnblogs.com/Memento/
版权声明:Memento所有文章遵循创作共用版权协议,要求署名、非商业、保持一致。在满足创作共用版权协议的基础上可以转载,但请以超链接形式注明出处。
博客地址:http://www.cnblogs.com/Memento/
版权声明:Memento所有文章遵循创作共用版权协议,要求署名、非商业、保持一致。在满足创作共用版权协议的基础上可以转载,但请以超链接形式注明出处。