【749】Kirby - Temporal Fusion Transformer related materials
参考:Darts - Temporal Fusion Transformer(Examples)
参考:什么是协变量以及协变量的定义是什么?(Covariate,研究某种自变量对因变量的影响,则实验过程中除研究的自变量因变量之外,还有其他很多变量对实验造成影响,而这些其他变量中,可以被控制的叫控制变量,不可被控制的叫协变量)
Temporal Fusion Transformer: Time Series Forecasting | Towards Data Science
Inter-Series Attention Model for COVID-19 Forecasting (siam.org)
Interpretable Temporal Attention Network for COVID-19 forecasting - PMC (nih.gov)
A Transformer-based Framework for Multivariate Time Series Representation Learning (acm.org)
Time Series Made Easy in Python — darts documentation (unit8co.github.io)
Temporal Fusion Transformer (TFT) — darts documentation (unit8co.github.io)
TemporalFusionTransformer — pytorch-forecasting documentation
官网🌰解读!
1. 读取数据
# Read data series = AirPassengersDataset().load()
2. 根据每月的天数做平均
# we convert monthly number of passengers to average daily number of passengers per month series = series / TimeSeries.from_series(series.time_index.days_in_month) series = series.astype(np.float32)
3. 获取train和val数据集
# Create training and validation sets: training_cutoff = pd.Timestamp("19571201") train, val = series.split_after(training_cutoff)
4. 数据Normalization
# Normalize the time series, different functions of transformer transformer = Scaler() train_transformed = transformer.fit_transform(train) val_transformed = transformer.transform(val) series_transformed = transformer.transform(series)
5. 构建协变量
- year: 年份的信息
- month: 季节性的月份信息
- integer index: 连续性的时间递进
# create year, month and integer index covariate series covariates = datetime_attribute_timeseries(series, attribute="year", one_hot=False) covariates = covariates.stack( datetime_attribute_timeseries(series, attribute="month", one_hot=False) ) covariates = covariates.stack( TimeSeries.from_times_and_values( times=series.time_index, values=np.arange(len(series)), columns=["linear_increase"], ) ) covariates = covariates.astype(np.float32)
6. 数据Normalization
# transform covariates scaler_conv = Scaler() cov_train, cov_val = covariates.split_after(training_cutoff) scaler_covs.fit(cov_train) covariates_transformed = scaler_covs.transform(covariates)
7. 构建模型
# default quantiles for QuantileRegression quantiles = [ 0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99, ] input_chunk_length = 24 forecast_horizon = 12 my_model = TFTModel( input_chunk_length=input_chunk_length, output_chunk_length=forecast_horizon, hidden_size=64, lstm_layers=1, num_attention_heads=4, dropout=0.1, batch_size=16, n_epochs=10, add_relative_index=False, add_encoders=None, likelihood=QuantileRegression( quantiles=quantiles ), # QuantileRegression is set per default # loss_fn=MSELoss(), random_state=42, )
8. 模型训练
my_model.fit(train_transformed, future_covariates=covariates_transformed, verbose=True)
9. 结果显示
def eval_model(model, n, actual_series, val_series): pred_series = model.predict(n=n, num_samples=num_samples) # plot actual series plt.figure(figsize=figsize) actual_series[: pred_series.end_time()].plot(label="actual") # plot prediction with quantile ranges pred_series.plot( low_quantile=lowest_q, high_quantile=highest_q, label=label_q_outer ) pred_series.plot(low_quantile=low_q, high_quantile=high_q, label=label_q_inner) plt.title("MAPE: {:.2f}%".format(mape(val_series, pred_series))) plt.legend() eval_model(my_model, 24, series_transformed, val_transformed)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)
2019-10-12 【442】Remote control GUP Linux
2013-10-12 【128】Word中的VBA