64注意力汇聚:Nadaraya-Watson 核回归

点击查看代码
import torch
from torch import nn
from d2l import torch as d2l

# 生成数据集
n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本
def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数
print(n_test)


def plot_kernel_reg(y_hat):
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);


# 平均汇聚
# 𝑓(𝑥)=1 / 𝑛 ∑𝑖=1𝑛𝑦𝑖,
# print(y_train.mean())
y_hat = torch.repeat_interleave(y_train.mean(), n_test)
# print([y_truth, y_hat])
plot_kernel_reg(y_hat)
# d2l.plt.show()

# 非参数注意力汇聚
# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询).
print('x_test.shape', x_test.shape)
print('x_test.repeat_interleave(n_train).shape', x_test.repeat_interleave(n_train).shape)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
print('X_repeat.shape', X_repeat.shape)
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
# 矩阵相乘
print('attention_weights.shape', attention_weights.shape)
print('y_train.shape', y_train.shape)
y_hat = torch.matmul(attention_weights, y_train)
# torch.matmul(input, other) → Tensor
# 若input为一维,other为二维,则先将input的一维向量扩充到二维(维数前面插入长度为1的新维度),
# 然后进行矩阵乘积,得到结果后再将此维度去掉,得到的与input的维度相同。
print('y_hat.shape', y_hat.shape)
"""
attention_weights.shape torch.Size([50, 50])
y_train.shape torch.Size([50])
y_hat.shape torch.Size([50])
"""
plot_kernel_reg(y_hat)
# 注意力权重
# [要显示的行数,要显示的列数,查询的数目,键的数目]
d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')
print('attention_weights.shape', attention_weights.shape)
print('attention_weights.unsqueeze(0).shape', attention_weights.unsqueeze(0).shape)
print('attention_weights.unsqueeze(0).unsqueeze(0).shape', attention_weights.unsqueeze(0).unsqueeze(0).shape)
"""
attention_weights.shape torch.Size([50, 50])
attention_weights.unsqueeze(0).shape torch.Size([1, 50, 50])
attention_weights.unsqueeze(0).unsqueeze(0).shape torch.Size([1, 1, 50, 50])
"""

# 带参数注意力汇聚
# 批量矩阵乘法
# 假定两个张量的形状分别是 (𝑛,𝑎,𝑏) 和 (𝑛,𝑏,𝑐) , 它们的批量矩阵乘法输出的形状为 (𝑛,𝑎,𝑐)
X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
print('torch.bmm(X, Y).shape', torch.bmm(X, Y).shape)
# torch.bmm(X, Y).shape torch.Size([2, 1, 6])

posted @   荒北  阅读(97)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
点击右上角即可分享
微信分享提示