bytedance

  博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

整体框架

 训练流程

1.初始化网络:class NetWork;

2.实例化实际的boosting树:依据参数选择不同模型,如GBDT、DART、GOSS、RF;

3.实例化实际的目标函数:依据参数选择不同类型,如RegressionL2loss、BinaryLogloss、MulticlassSoftmax、LambdarankNDCG等;

4.加载训练数据;

5.初始化实例化后boosting树;

6.初始化实例化后的目标函数;

7.训练模型并保存;

 

预测流程

 

细分模块

通信模块:

树结构

 

目标函数

数据加载

训练数据加载时,单机和并行处理的方式有所不同。在单机下,DatasetLoader::LoadFromFile(..)函数实例化Dataset类,然后可以选择利用DatasetLoader::LoadFromBinFile(...)导入二进制数据或者文本文件数据来对该类进行初始化。利用文本文件进行Dataset初始化时,Parser::CreateParser()通过预读取两行训练数据来确定数据格式(libsvm、tsv、csv)并实例化实际的数据解析器(LibSVMParser、TSVParser、CSVParser),随后当文件不大时利用DatasetLoader::LoadTextDataToMemory直接加载入内存并用SampleTextDataFromMemory采样或者直接调用SampleTextDataFromFile采样样本数据,得到采样数据后使用ConstructBinMappersFromTextData()构造Bin Mapper,随后初始化样本标签元数据,利用ExtractFeaturesFromMemory()或ConstructBinMappersFromTextData()提取特征数据,完成训练数据的加载。

-------

1) class Metadata: 存储非特征性质的元数据,例如样本标签、样本权重、boosting时的初始分;

2) class Parser: 样本数据解析接口基类,会调用Parser::CreateParser()依据数据格式自动实例化不同类型的解析器,例如class LibSVMParser、class TSVParser、class CSVParser,对应的在各派生类中会分别实现ParseOneLine()函数来按行解析不同格式的样本数据;

3) class Dataset: 用于训练或验证的数据集主类,class DatasetLoader是其友元类,主要的特征样本存放在class FeatureGroup类型的容器内;

4) class DatasetLoader: 包含一系列方法用于驱动样本数据从文件/内存加载到class Dataset类对象中,调用接口从DatasetLoader::LoadFromFile()开始,在该函数内部首先读取数据(1.小文件,运用DatasetLoader::LoadTextDataToMemory()间接调用TextReader::ReadAllLines()按行读取数据行到string类型容器内以全部导入内存,再用DatasetLoader::SampleTextDataFromMemory()采样;2.大文件,直接调用DatasetLoader::SampleTextDataFromFile()并间接使用TextReader::SampleTextDataFromFile()从

文件采样得到部分样本保存到string容器内;)。得到按行存放在string类型的容器内的采样样本后,DatasetLoader::ConstructBinMappersFromTextDa()调用解析器对数据按行进行解析,并将特征数据按列临时存储到二维数组中,然后按单机/多机对各个特征列对应的class BinMapper对象进行实例化,并用BinMapper::FindBin()找到bin内各个小区间切分的连续特征的上边界或者离散特征的映射关系,即bin_mapper容器。得到不同特征列具体特征值向bin映射的映射关系后,使用Dataset::Construct()

5) class FeatureGroup: 用于存储特征数据,class Dataset和class DatasetLoader是其友元类,数据存放在class Bin类型的对象内,而class BinMapper类型的容器实现对各子特征实际数值向bin值的映射;

6) class BinMapper:用于将实际的特征样本值转化为bin结构,其主要接口BinMapper::FindBin()实现按特征列查找bin,它首先处理当前列中的无效数据(MissingType::Zero、MissingType::None、MissingType::NaN),然后存储该列特征不同的取值及其出现的次数。当该列特征对应于BinType::NumericalBin,则函数内部再通过BinMapper::FindBinWithZeroAsOneBin()并间接调用BinMapper::GreedyFindBin()实际计算bin内子区间的上边界存到std::vector<double> bin_upper_bound_

上边界的具体计算需详细分析;当该列特征对应于BinType::CategoricalBin,离散特征的数值默认均为整型,则存放到std::unordered_map<int, unsigned int> categorical_2_bin_,此过程会移除数目较少的离散值,然后按照该列特征离散值(distinct)计数建立离散值与bin位置间的相互映射;

 

posted on 2019-02-02 16:21  bytedance  阅读(643)  评论(0编辑  收藏  举报