alex_bn_lee


【749】Kirby - Temporal Fusion Transformer related materials

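Reference: Darts - Temporal Fusion Transformer (Examples)

Reference: What is a covariate, and how is it defined? (Covariate: when you study how an independent variable affects a dependent variable, many variables other than those two also influence the experiment. Of these other variables, the ones that can be controlled are called control variables, and the ones that cannot be controlled are called covariates.)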


Temporal Fusion Transformer Unleashed: Deep Forecasting of Multivariate Time Series in Python | by Heiko Onnen | Medium

Temporal Fusion Transformer: Time Series Forecasting | Towards Data Science

Inter-Series Attention Model for COVID-19 Forecasting (siam.org)

Interpretable Temporal Attention Network for COVID-19 forecasting - PMC (nih.gov)

A Transformer-based Framework for Multivariate Time Series Representation Learning (acm.org) 

Time Series Made Easy in Python — darts documentation (unit8co.github.io)

Temporal Fusion Transformer (TFT) — darts documentation (unit8co.github.io)

TemporalFusionTransformer — pytorch-forecasting documentation


Walkthrough of the official example!

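1. Read the data

from darts.datasets import AirPassengersDataset
# Read data: monthly airline passenger totals, Jan 1949 to Dec 1960
series = AirPassengersDataset().load()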

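2. Convert to a daily average using the number of days in each month

import numpy as np
from darts import TimeSeries
# we convert monthly number of passengers to average daily number of passengers per month
series = series / TimeSeries.from_series(series.time_index.days_in_month)
series = series.astype(np.float32)  # match the float32 precision of the PyTorch model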
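The division above turns each monthly total into an average per day, so months of different length become comparable. The same arithmetic can be sketched with only the Python standard library. `monthly_to_daily_average` is a hypothetical helper (not part of Darts), and it assumes the totals are consecutive calendar months; the input numbers are illustrative, not taken from the dataset.

```python
import calendar


def monthly_to_daily_average(totals, start_year, start_month):
    """Divide each monthly total by the number of days in that month.

    totals: monthly totals in chronological order, starting at
    (start_year, start_month). Hypothetical helper for illustration.
    """
    result = []
    year, month = start_year, start_month
    for total in totals:
        # calendar.monthrange returns (weekday of day 1, number of days)
        days = calendar.monthrange(year, month)[1]
        result.append(total / days)
        # advance to the next calendar month
        month += 1
        if month > 12:
            month, year = 1, year + 1
    return result


# illustrative totals for Jan-Mar 1949 (31, 28 and 31 days)
print(monthly_to_daily_average([310, 280, 341], 1949, 1))  # [10.0, 10.0, 11.0]
```

Leap years are handled by `calendar.monthrange`, so February 1952 is divided by 29 rather than 28.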

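3. Split into training and validation sets

import pandas as pd
# Create training and validation sets:
training_cutoff = pd.Timestamp("19571201")
# split_after keeps the cutoff month (Dec 1957) in train; val starts in Jan 1958
train, val = series.split_after(training_cutoff)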
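Darts' `split_after` is inclusive: the cutoff timestamp ends up in the first series. A plain-Python sketch of that behaviour on the monthly axis of AirPassengers (Jan 1949 to Dec 1960, 144 points) shows how many months land on each side. The function here is a hypothetical stand-in that works on `(year, month)` tuples, not the Darts method itself.

```python
def split_after(points, cutoff):
    """Split chronologically ordered (year, month) points.

    The cutoff point itself goes into the first (training) part,
    matching the inclusive behaviour of Darts' split_after.
    """
    train = [p for p in points if p <= cutoff]
    val = [p for p in points if p > cutoff]
    return train, val


# monthly time axis of the AirPassengers series: Jan 1949 .. Dec 1960
months = [(y, m) for y in range(1949, 1961) for m in range(1, 13)]
train, val = split_after(months, (1957, 12))
print(len(months), len(train), len(val))  # 144 108 36
print(train[-1], val[0])  # (1957, 12) (1958, 1)
```

So the model trains on 9 years (108 months) and is validated on the last 3 years (36 months).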

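4. Normalize the data

from darts.dataprocessing.transformers import Scaler
# Normalize the time series. Scaler() defaults to scikit-learn's MinMaxScaler (range 0-1).
# Fit on the training split only, then reuse the fitted scaler on val and on the full series,
# so no information from the validation period leaks into the scaling.
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)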
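Min-max scaling maps the training range onto [0, 1]. Because the minimum and maximum are learned from the training split only, later (larger) values can land above 1. The sketch below is a minimal one-dimensional stand-in for the default scaler behind Darts' `Scaler`; `MinMaxSketch` is a hypothetical class, and it assumes the training values are not all equal.

```python
class MinMaxSketch:
    """Minimal 1-D min-max scaler (hypothetical, for illustration)."""

    def fit(self, values):
        # learn min and max from the data seen at fit time only
        self.lo = min(values)
        self.hi = max(values)
        return self

    def transform(self, values):
        # assumes hi > lo, i.e. the fitted values were not all equal
        span = self.hi - self.lo
        return [(v - self.lo) / span for v in values]

    def fit_transform(self, values):
        return self.fit(values).transform(values)


train = [100.0, 150.0, 200.0]  # illustrative training values
val = [250.0]                  # a later, larger value

scaler = MinMaxSketch()
print(scaler.fit_transform(train))  # [0.0, 0.5, 1.0]
print(scaler.transform(val))        # [1.5]
```

With a trending series such as AirPassengers, validation values above 1.0 are expected and are not an error.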

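5. Build the covariates

  • year: the year
  • month: the month, which carries the seasonality
  • integer index: a counter that increases by one each month, encoding the steady progression of time
All three are known in advance for any date, so they can be passed to the model as future covariates.

from darts.utils.timeseries_generation import datetime_attribute_timeseries
# create year, month and integer index covariate series
covariates = datetime_attribute_timeseries(series, attribute="year", one_hot=False)
covariates = covariates.stack(
    datetime_attribute_timeseries(series, attribute="month", one_hot=False)
)
covariates = covariates.stack(
    TimeSeries.from_times_and_values(
        times=series.time_index,
        values=np.arange(len(series)),
        columns=["linear_increase"],
    )
)
covariates = covariates.astype(np.float32)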
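The three stacked covariate columns can be pictured as one row per month. The sketch below builds such rows with plain Python; `build_covariates` is a hypothetical helper, and it uses calendar month numbers 1-12. Darts may encode the month attribute differently (for example starting at 0), but after min-max scaling that offset makes no difference.

```python
def build_covariates(n_months, start_year=1949, start_month=1):
    """Return one row [year, month, linear_increase] per month.

    Roughly mirrors the year, month and integer-index columns stacked
    in the Darts code above (hypothetical helper, for illustration).
    """
    rows = []
    for i in range(n_months):
        # number of months elapsed since January of start_year
        total = (start_month - 1) + i
        year = start_year + total // 12
        month = total % 12 + 1
        rows.append([year, month, i])
    return rows


covs = build_covariates(144)
print(covs[0], covs[11], covs[12], covs[-1])
# [1949, 1, 0] [1949, 12, 11] [1950, 1, 12] [1960, 12, 143]
```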

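6. Normalize the covariates

# transform covariates: fit the scaler on the training part only,
# then transform the entire covariate series (needed later as future covariates)
scaler_covs = Scaler()
cov_train, cov_val = covariates.split_after(training_cutoff)
scaler_covs.fit(cov_train)
covariates_transformed = scaler_covs.transform(covariates)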
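For a multi-column series, the scaler learns a separate minimum and maximum per column (component). Fitting on `cov_train` and transforming the whole covariate series means the year and the linear index exceed 1.0 in the validation period, while the month stays within [0, 1]. A stdlib sketch under those assumptions; both helper names are hypothetical.

```python
def fit_minmax_columns(rows):
    """Learn (min, max) for every column from the given rows."""
    columns = list(zip(*rows))
    return [(min(col), max(col)) for col in columns]


def transform_columns(rows, stats):
    """Scale each column with its own (min, max); assumes max > min."""
    return [
        [(v - lo) / (hi - lo) for v, (lo, hi) in zip(row, stats)]
        for row in rows
    ]


# covariate rows [year, month, linear_increase] for Jan 1949 .. Dec 1960
rows = [[1949 + i // 12, i % 12 + 1, i] for i in range(144)]
train_rows = rows[:108]  # Jan 1949 .. Dec 1957, like cov_train

stats = fit_minmax_columns(train_rows)
scaled = transform_columns(rows, stats)
print(stats)       # [(1949, 1957), (1, 12), (0, 107)]
print(scaled[-1])  # year and index are above 1.0 for Dec 1960
```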

7. Build the model

from darts.models import TFTModel
from darts.utils.likelihood_models import QuantileRegression

# default quantiles for QuantileRegression
quantiles = [
    0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5,
    0.6, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99,
]
input_chunk_length = 24  # months of history fed to the encoder
forecast_horizon = 12  # months predicted per forward pass
my_model = TFTModel(
    input_chunk_length=input_chunk_length,
    output_chunk_length=forecast_horizon,
    hidden_size=64,
    lstm_layers=1,
    num_attention_heads=4,
    dropout=0.1,
    batch_size=16,
    n_epochs=10,
    add_relative_index=False,
    add_encoders=None,
    likelihood=QuantileRegression(
        quantiles=quantiles
    ),  # QuantileRegression is set per default
    # loss_fn=MSELoss(),
    random_state=42,
)
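With `QuantileRegression`, the model outputs one value per requested quantile and, as far as I understand Darts' implementation, is trained with the quantile (pinball) loss averaged over those quantiles. The sketch below illustrates that loss in plain Python; it is not Darts' code, and both function names are hypothetical.

```python
def pinball_loss(y_true, y_pred, q):
    """Quantile (pinball) loss for one target value and quantile level q."""
    error = y_true - y_pred
    # under-prediction (error > 0) is weighted by q,
    # over-prediction (error < 0) is weighted by (1 - q)
    return max(q * error, (q - 1) * error)


def mean_quantile_loss(y_true, preds_per_quantile, quantiles):
    """Average pinball loss over all requested quantiles."""
    losses = [
        pinball_loss(y_true, p, q) for p, q in zip(preds_per_quantile, quantiles)
    ]
    return sum(losses) / len(losses)


# the median (q=0.5) penalises errors in both directions equally
print(pinball_loss(10.0, 8.0, 0.5), pinball_loss(10.0, 12.0, 0.5))  # 1.0 1.0
# the 0.9 quantile is penalised far more for predicting too low than too high
print(pinball_loss(10.0, 8.0, 0.9) > pinball_loss(10.0, 12.0, 0.9))  # True
```

This asymmetry is what pushes each output toward its own quantile, which in turn makes the 1%-99% and 10%-90% bands in the final plot possible.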

8. Train the model

my_model.fit(train_transformed, future_covariates=covariates_transformed, verbose=True)
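During `fit`, the training series is cut into fixed-size windows: `input_chunk_length` past values as input and the next `output_chunk_length` values as the target (the future covariates are, presumably, windowed over the same time steps). The sketch below shows that sliding-window idea on a stand-in list; Darts builds its samples internally and its exact sampling may differ between versions, so treat the count as an approximation. `make_training_windows` is a hypothetical helper.

```python
def make_training_windows(series, input_len, output_len):
    """Cut (past, future) training pairs out of one series.

    Each sample uses input_len consecutive values as encoder input and
    the next output_len values as the prediction target.
    """
    samples = []
    last_start = len(series) - input_len - output_len
    for start in range(last_start + 1):
        past = series[start : start + input_len]
        future = series[start + input_len : start + input_len + output_len]
        samples.append((past, future))
    return samples


train = list(range(108))  # stand-in for the 108 monthly training points
windows = make_training_windows(train, input_len=24, output_len=12)
print(len(windows))  # 73
first_past, first_future = windows[0]
print(first_past[-1], first_future[0])  # 23 24
```

With only 73 windows and `batch_size=16`, each of the 10 epochs is just a handful of gradient steps.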

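9. Display the results

import matplotlib.pyplot as plt
from darts.metrics import mape

# settings used by eval_model (illustrative values; adjust as needed)
num_samples = 200  # Monte Carlo samples drawn per forecast
figsize = (9, 6)
lowest_q, low_q, high_q, highest_q = 0.01, 0.1, 0.9, 0.99
label_q_outer = f"{int(lowest_q * 100)}-{int(highest_q * 100)}th percentiles"
label_q_inner = f"{int(low_q * 100)}-{int(high_q * 100)}th percentiles"

def eval_model(model, n, actual_series, val_series):
    # sample the probabilistic forecast so quantile bands can be drawn
    pred_series = model.predict(n=n, num_samples=num_samples)
    # plot actual series
    plt.figure(figsize=figsize)
    actual_series[: pred_series.end_time()].plot(label="actual")
    # plot prediction with quantile ranges
    pred_series.plot(
        low_quantile=lowest_q, high_quantile=highest_q, label=label_q_outer
    )
    pred_series.plot(low_quantile=low_q, high_quantile=high_q, label=label_q_inner)
    plt.title("MAPE: {:.2f}%".format(mape(val_series, pred_series)))
    plt.legend()

eval_model(my_model, 24, series_transformed, val_transformed)
plt.show()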
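The plotted bands are quantiles of the `num_samples` forecast samples at each time step, and the title shows the mean absolute percentage error. The sketch below computes both in plain Python. It assumes a stochastic forecast is reduced to its median before MAPE is computed (my understanding of how Darts metrics treat sampled series, not confirmed by the source). The quantile uses linear interpolation, the same default rule as `numpy.quantile`; both function names are hypothetical.

```python
def empirical_quantile(samples, q):
    """Quantile of a list of samples using linear interpolation."""
    ordered = sorted(samples)
    pos = q * (len(ordered) - 1)
    lower = int(pos)
    upper = min(lower + 1, len(ordered) - 1)
    frac = pos - lower
    return ordered[lower] + (ordered[upper] - ordered[lower]) * frac


def mape_percent(actual, predicted):
    """Mean absolute percentage error in percent; actual values must be non-zero."""
    terms = [abs((a - p) / a) for a, p in zip(actual, predicted)]
    return 100.0 * sum(terms) / len(terms)


# illustrative samples for one time step
samples = [1.0, 2.0, 3.0, 4.0, 5.0]
print(empirical_quantile(samples, 0.5))  # 3.0
print(mape_percent([100.0, 200.0], [110.0, 180.0]))
```

Note that in this walkthrough MAPE is computed on the normalized series (`val_transformed`), so it is not the same number you would get on the original passenger counts.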

Posted by McDelfino
