点击查看代码
import torch
from torch import nn
from d2l import torch as d2l
# Number of training samples to generate.
n_train = 50
# Sample the training inputs from torch.rand scaled by 5 and keep them in
# ascending order; sort() also returns the permutation indices, which are
# not needed here and are bound to the throwaway name `_`.
x_train, _ = torch.rand(n_train).mul(5).sort()
def f(x):
    """Return the ground-truth curve ``2 * sin(x) + x ** 0.8`` at ``x``.

    Args:
        x: Tensor of inputs; the expression is applied elementwise.
            Inputs are expected to be non-negative, since a negative
            base raised to the fractional power 0.8 yields NaN.

    Returns:
        Tensor with the same shape as ``x``.
    """
    return 2 * torch.sin(x) + x**0.8
# Test inputs on a regular grid: start 0, step 0.1, stop value 5 excluded.
x_test = torch.arange(0, 5, 0.1)
# Noise-free function values at the test inputs, used as the reference curve.
y_truth = f(x_test)
n_test = x_test.shape[0]
# Training targets: the function value at each training input plus Gaussian
# noise with mean 0.0 and standard deviation 0.5 (one draw per sample).
y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))
print(n_test)
def plot_kernel_reg(y_hat):
    """Plot a prediction curve against the ground truth and training data.

    Draws ``y_truth`` and ``y_hat`` over ``x_test`` (legend entries
    'Truth' and 'Pred') with the x-axis limited to [0, 5] and the y-axis
    to [-1, 5], then adds the training samples ``(x_train, y_train)`` as
    semi-transparent circle markers.

    Relies on the module-level ``x_test``, ``y_truth``, ``x_train`` and
    ``y_train`` being defined before the call.

    Args:
        y_hat: Predicted values to plot against ``x_test``.
    """
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    # Add the training samples as semi-transparent circles.
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5)
# Baseline prediction: the mean of the training targets, repeated once per
# test input, so the predicted curve is a constant.
y_hat = y_train.mean().repeat_interleave(n_test)
plot_kernel_reg(y_hat)
print('x_test.shape', x_test.shape)
print('x_test.repeat_interleave(n_train).shape', x_test.repeat_interleave(n_train).shape)
# Each element of x_test is repeated n_train times and the result is reshaped
# so that every row holds one test input copied across n_train columns.
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
print('X_repeat.shape', X_repeat.shape)
# Scores are the negated squared differences between each test input (row)
# and each training input (column), halved; softmax over dim=1 normalizes
# each row into weights over the training inputs.
attention_weights = torch.softmax(-((X_repeat - x_train) ** 2) / 2, dim=1)
print('attention_weights.shape', attention_weights.shape)
print('y_train.shape', y_train.shape)
# Each prediction is the weighted sum of the training targets for that row.
y_hat = attention_weights @ y_train
print('y_hat.shape', y_hat.shape)
# Recorded output of the shape prints above, kept as a bare string literal.
"""
attention_weights.shape torch.Size([50, 50])
y_train.shape torch.Size([50])
y_hat.shape torch.Size([50])
"""
plot_kernel_reg(y_hat)
# Two leading singleton axes are added so the 2-D weight matrix is passed as
# a 4-D tensor (presumably the layout d2l.show_heatmaps expects; confirm
# against the d2l documentation).
d2l.show_heatmaps(
    attention_weights[None, None, :, :],
    xlabel='Sorted training inputs',
    ylabel='Sorted testing inputs',
)
print('attention_weights.shape', attention_weights.shape)
print('attention_weights.unsqueeze(0).shape', attention_weights.unsqueeze(0).shape)
print('attention_weights.unsqueeze(0).unsqueeze(0).shape', attention_weights.unsqueeze(0).unsqueeze(0).shape)
# Recorded output of the shape prints above, kept as a bare string literal.
"""
attention_weights.shape torch.Size([50, 50])
attention_weights.unsqueeze(0).shape torch.Size([1, 50, 50])
attention_weights.unsqueeze(0).unsqueeze(0).shape torch.Size([1, 1, 50, 50])
"""
# Batched matrix multiplication demo: a batch of 2 matrices of shape (1, 4)
# multiplied by a batch of 2 matrices of shape (4, 6).
X = torch.ones(2, 1, 4)
Y = torch.ones(2, 4, 6)
print('torch.bmm(X, Y).shape', torch.bmm(X, Y).shape)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)