Datawhale AI 夏令营——电力需求挑战赛——Task2学习笔记

一、实先准备

import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.metrics import mean_squared_log_error, mean_absolute_error, mean_squared_error
import tqdm
import sys
import os
import gc
import argparse
import warnings
warnings.filterwarnings('ignore')

 

二、测试集与数据集读取

# 使用 read_csv() 函数从文件中读取训练集数据,文件名为 'train.csv'
train = pd.read_csv('./data/data283931/train.csv')
# 使用 read_csv() 函数从文件中读取测试集数据,文件名为 'train.csv'
test = pd.read_csv('./data/data283931/test.csv')

 

三、特征工程

# 合并训练数据和测试数据,并进行排序
data = pd.concat([test, train], axis=0, ignore_index=True)
data = data.sort_values(['id','dt'], ascending=False).reset_index(drop=True)

# 历史平移
for i in range(10,30):
    data[f'last{i}_target'] = data.groupby(['id'])['target'].shift(i)
    
# 窗口统计
data[f'win3_mean_target'] = (data['last10_target'] + data['last11_target'] + data['last12_target']) / 3

# 进行数据切分
train = data[data.target.notnull()].reset_index(drop=True)
test = data[data.target.isnull()].reset_index(drop=True)

# 确定输入特征
train_cols = [f for f in data.columns if f not in ['id','target']]
  1. 合并数据并排序
    1. 将测试数据和训练数据合并,并重置索引。
    2. 按照 iddt(日期)字段对数据进行降序排序。
  2. 历史平移
    1. 对于每个 id,创建历史目标值的滞后特征(last10_targetlast29_target),即每个 id 的目标值向前移动10到29个时间步长。
  3. 窗口统计
    1. 计算窗口统计特征 win3_mean_target,即 last10_targetlast11_targetlast12_target 的平均值。
  4. 数据切分
    1. 根据目标值是否为空,将数据切分为训练集和测试集。
  5. 确定输入特征
    1. train_cols 包含除了 idtarget 之外的所有列。

四、 定义模型

from lightgbm.callback import log_evaluation  # 导入lightgbm库中的log_evaluation回调函数

def time_model(lgb, train_df, test_df, cols):
    # 训练集和验证集切分
    trn_x, trn_y = train_df[train_df.dt>=31][cols], train_df[train_df.dt>=31]['target']  # 训练集特征和标签,使用日期大于等于31的数据
    val_x, val_y = train_df[train_df.dt<=30][cols], train_df[train_df.dt<=30]['target']  # 验证集特征和标签,使用日期小于等于30的数据
    
    # 构建模型输入数据
    train_matrix = lgb.Dataset(trn_x, label=trn_y)  # 训练数据集
    valid_matrix = lgb.Dataset(val_x, label=val_y)  # 验证数据集
    
    # lightgbm参数设定
    lgb_params = {
        'boosting_type': 'gbdt',  # boosting类型为gbdt
        'objective': 'regression',  # 目标函数为回归
        'metric': 'mse',  # 评估指标为均方误差
        'min_child_weight': 5,  # 最小叶子节点样本权重和
        'num_leaves': 2 ** 5,  # 叶子节点数
        'lambda_l2': 10,  # L2正则化权重
        'feature_fraction': 0.8,  # 每次迭代使用的特征比例
        'bagging_fraction': 0.8,  # 每次迭代时用于训练的数据比例
        'bagging_freq': 4,  # bagging的频率
        'learning_rate': 0.05,  # 学习率
        'seed': 2024,  # 随机种子
        'nthread': 16,  # 线程数
        'verbose': -1,  # 不打印训练过程中的信息
    }
    
    # 训练模型
    model = lgb.train(lgb_params, train_matrix, 50000, valid_sets=[train_matrix, valid_matrix],
                      callbacks=[log_evaluation(period=100)])  # 使用50000轮训练,每100轮打印一次评估信息
    
    # 验证集和测试集结果预测
    val_pred = model.predict(val_x, num_iteration=model.best_iteration)  # 预测验证集结果
    test_pred = model.predict(test_df[cols], num_iteration=model.best_iteration)  # 预测测试集结果
    
    # 离线分数评估
    score = mean_squared_error(val_pred, val_y)  # 计算均方误差作为评分
    print(score)  # 打印评分
    
    return val_pred, test_pred  # 返回验证集和测试集预测结果

lgb_oof, lgb_test = time_model(lgb, train, test, train_cols)  # 调用time_model函数进行训练和预测

# 保存结果文件到本地
test['target'] = lgb_test  # 将预测结果添加到测试集中
test[['id','dt','target']].to_csv('submit.csv', index=None)  # 将预测结果保存为submit.csv文件,包括id、日期和目标预测值

这段代码实现了使用LightGBM进行时间序列模型训练和预测的过程。

  1. 首先,根据日期将数据集划分为训练集和验证集。
  2. 使用LightGBM的Dataset类构建模型所需的数据结构。
  3. 设定LightGBM模型的参数,包括boosting类型、目标函数、评估指标等。
  4. 训练模型并通过log_evaluation回调函数定期打印训练过程中的评估信息。
  5. 使用训练好的模型对验证集和测试集进行预测。
  6. 计算验证集预测结果的均方误差作为模型的离线评估指标。
  7. 最后,将测试集预测结果保存为CSV文件,包括每个样本的id、日期和预测目标值。
 

五、 模型评估

1)模型迭代

这段代码会在以下情况下停止:

  1. 训练达到指定的迭代次数: 在 lgb.train 方法中,指定了 50000 作为最大迭代次数 (num_boost_rounds)。模型将会在达到这个迭代次数后停止训练。

  2. 早停机制: LightGBM 提供了早停机制,可以在验证集上的评估指标不再提升时停止训练。在这段代码中,通过 valid_sets=[train_matrix, valid_matrix] 将训练集和验证集传递给 lgb.train 方法,并且使用 callbacks=[log_evaluation(period=100)] 来定期记录评估结果。当验证集上的性能不再改善时,训练将会提前停止,而不会等到达到最大迭代次数。

因此,代码会在达到指定的迭代次数(50000次)或者早停机制触发时停止运行。

2)模型评分

 

posted @ 2024-07-17 22:54  heartBroken  阅读(4)  评论(0编辑  收藏  举报