TowardsDataScience-2023-博客中文翻译-三十七-
TowardsDataScience 2023 博客中文翻译(三十七)
PyTorch 介绍 — 构建你的第一个线性模型
原文:
towardsdatascience.com/pytorch-introduction-building-your-first-linear-model-d868a8681a41
学习如何通过使用“神奇”的 Linear 层来构建你的第一个 PyTorch 模型。
·发表于 Towards Data Science ·阅读时长 8 分钟·2023 年 12 月 12 日
--
回归模型 — AI 生成的图像
在我上一篇博客中,我们学习了如何使用 PyTorch 张量,这是 PyTorch 库中最重要的对象。张量是深度学习模型的骨架,因此我们可以利用它们来将更简单的机器学习模型拟合到我们的数据集上。
尽管 PyTorch 以其深度学习能力而闻名,但我们也可以使用该框架来拟合简单的线性模型——这实际上是熟悉torch
API 的最佳方式之一!
在这篇博客中,我们将继续 PyTorch 介绍系列,查看如何使用 torch
库开发一个简单的线性回归模型。在这个过程中,我们将了解 torch
的优化器、权重和其他学习模型参数,这对于更复杂的架构将非常有用。
让我们开始吧!
加载和处理数据
在这篇博客中,我们将使用歌曲流行度数据集,我们希望根据一些歌曲特征来预测某首歌曲的流行度。让我们先看一下数据集的前几行:
songPopularity = pd.read_csv(‘./data/song_data.csv’)
歌曲流行度特征列 — 作者提供的图像
这个数据集的一些特征包括关于每首歌曲的有趣指标,例如:
-
歌曲的“能量”级别。
-
对歌曲的关键(例如 A、B、C、D 等)进行标签编码
-
歌曲响度
-
歌曲节奏。
我们的目标是利用这些特征来预测歌曲流行度,这是一个从 0 到 100 的指数。在我们上面展示的示例中,我们旨在预测以下歌曲流行度:
歌曲受欢迎程度目标列— 图片由作者提供
我们将使用 PyTorch 模块来预测这个连续变量,而不是使用 sklearn
。学习如何在 pytorch
中拟合线性回归的好处是什么?我们将获得的知识可以应用于其他复杂模型,如深层神经网络!
让我们从准备数据集开始,首先对特征和目标进行子集划分:
features = ['song_duration_ms',
'acousticness', 'danceability',
'energy', 'instrumentalness',
'key', 'liveness', 'loudness',
'audio_mode', 'speechiness',
'tempo', 'time_signature', 'audio_valence']
target = 'song_popularity'
songPopularityFeatures = songPopularity[features]
songPopularityTarget = songPopularity[target]
我们将使用 train_test_split
将数据划分为训练集和测试集。我们将在将数据转换为 tensors
之前执行此转换,因为 sklearn
的方法会自动将数据转换为 pandas
或 numpy
格式:
X_train, X_test, y_train, y_test = train_test_split(songPopularityFeatures, songPopularityTarget, test_size = 0.2)
创建了 X_train
、 X_test
、 y_train
和 y_test
后,我们现在可以将数据转换为 torch.tensor
— 通过将数据传递给 torch.Tensor
函数来完成这个过程很简单:
import torch
def dataframe_to_tensor(df):
return torch.tensor(df.values, dtype=torch.float32)
# Transform DataFrames into PyTorch tensors using the function
X_train = dataframe_to_tensor(X_train)
X_test = dataframe_to_tensor(X_test)
y_train = dataframe_to_tensor(y_train)
y_test = dataframe_to_tensor(y_test)
我们的对象现在是 torch.Tensor
格式,这是 nn.Module
期望的格式。下面是 X_train
:
X_train 张量 — 图片由作者提供
很棒——我们已经将训练和测试数据转换为 tensor
格式。我们准备好创建我们的第一个 torch
模型了,接下来我们将做这件事!
构建我们的线性模型
我们将使用一个继承自 nn.Module
父类的 LinearRegressionModel class
来训练我们的模型。nn.Module
类是 Pytorch 所有神经网络的基础类。
from torch import nn
class LinearRegressionModel(nn.Module):
'''
Torch Module class.
Initializes weight randomly and gets trained via train method.
'''
def __init__(self, optimizer):
super().__init__()
self.optimizer = optimizer
# Initialize Weights and Bias
self.weights = nn.Parameter(
torch.randn(1, 5, dtype=torch.float),
requires_grad=True)
self.bias = nn.Parameter(
torch.randn(1, 5, dtype=torch.float),
requires_grad=True
)
在这个类中,我们创建对象时只需要一个参数——optimizer
。我们这样做是因为我们希望在训练过程中测试不同的优化器。在上面的代码中,让我们聚焦于 # Initialize Weights and Bias
之后的权重初始化:
self.weights = nn.Parameter(
torch.randn(1, 13, dtype=torch.float),
requires_grad=True)
self.bias = nn.Parameter(
torch.randn(1, dtype=torch.float),
requires_grad=True
)
线性回归是一个非常简单的函数,公式为 y = b0 + b1x1 + ... bnxn
其中:
-
y 等于我们想要预测的目标
-
b0 等于偏差项。
-
b1, …, bn 等于模型的权重(每个变量在最终决策中的权重以及它是负面还是正面贡献)。
-
x1, …, xn 是特征的值。
nn.Parameter
的理念是初始化 b0(在 self.bias
中初始化的偏差)和 b1, … , bn(在 self.weights
中初始化的权重)。我们正在初始化 13 个权重,因为我们的训练数据集中有 13 个特征。
由于我们处理的是线性回归,因此只有一个偏差值,所以我们只需初始化一个随机标量(如果这个名字对你来说很陌生,可以查看我的第一篇文章!)。此外,请注意,我们正在使用 torch.randn
随机初始化这些参数。
现在,我们的目标是通过反向传播优化这些权重——为此,我们需要设置我们的线性层,包括回归公式:
def forward(self, x: torch.Tensor) -> torch.Tensor:
return (self.weights * x + self.bias).sum(axis=1)
trainModel
方法将帮助我们执行反向传播和权重调整:
def trainModel(
self,
epochs: int,
X_train: torch.Tensor,
X_test: torch.Tensor,
y_train: torch.Tensor,
y_test: torch.Tensor,
lr: float
):
'''
Trains linear model using pytorch.
Evaluates the model against test set for every epoch.
'''
torch.manual_seed(42)
# Create empty loss lists to track values
self.train_loss_values = []
self.test_loss_values = []
loss_fn = nn.L1Loss()
if self.optimizer == 'SGD':
optimizer = torch.optim.SGD(
params=self.parameters(),
lr=lr
)
elif self.optimizer == 'Adam':
optimizer = torch.optim.Adam(
params=self.parameters(),
lr=lr
)
for epoch in range(epochs):
self.train()
y_pred = self(X_train)
loss = loss_fn(y_pred, y_train)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Set the model in evaluation mode
self.eval()
with torch.inference_mode():
self.evaluate(X_test, y_test, epoch, loss_fn, loss)
在这个方法中,我们可以选择使用随机梯度下降(SGD)或自适应矩估计(Adam)优化器。更重要的是,让我们深入了解每个训练周期(对整个数据集的一次遍历)之间发生了什么:
self.train()
y_pred = self(X_train)
loss = loss_fn(y_pred, y_train)
optimizer.zero_grad()
loss.backward()
optimizer.step()
这段代码在神经网络的背景下极为重要。它包括了典型的torch
模型的训练过程:
-
我们使用
self.train()
将模型设置为训练模式 -
接下来,我们使用
self(X_train)
将数据传递给模型 — 这将把数据传递通过前向层。 -
loss_fn
计算训练数据上的损失。我们的损失函数是torch.L1Loss
,由平均绝对误差组成。 -
optimizer.zero_grad()
将梯度设置为零(它们会在每个周期累积,因此我们希望在每次遍历时从零开始)。 -
loss.backward()
计算每个权重相对于损失函数的梯度。这是优化权重的步骤。 -
最后,我们使用
optimizer.step()
更新模型的参数
最后的步骤是揭示如何使用evaluate
方法评估我们的模型:
def evaluate(self, X_test, y_test, epoch_nb, loss_fn, train_loss):
'''
Evaluates current epoch performance on the test set.
'''
test_pred = self(X_test)
test_loss = loss_fn(test_pred, y_test.type(torch.float))
if epoch_nb % 10 == 0:
self.train_loss_values.append(train_loss.detach().numpy())
self.test_loss_values.append(test_loss.detach().numpy())
print(f"Epoch: {epoch_nb} - MAE Train Loss: {train_loss} - MAE Test Loss: {test_loss} ")
这段代码计算测试集中的损失。此外,我们将使用这种方法在每 10 个周期中打印训练和测试集的损失。
模型准备好后,让我们在数据上训练它,并可视化训练和测试的学习曲线!
拟合模型
让我们使用构建的代码来训练模型并观察训练过程 — 首先,我们将使用Adam
优化器和0.001
学习率训练模型 500 个周期:
adam_model = LinearRegressionModel('Adam')
adam_model.trainModel(200, X_train, X_test, y_train, y_test, 0.001)
这里我使用adam
优化器训练模型 200 个周期。训练和测试损失的概述如下:
前 150 个周期的训练和测试演变 — 作者提供的图片
我们还可以绘制整个周期的训练和测试损失:
训练和测试损失 — 作者提供的图片
我们的损失仍然有点高(在最后一个周期,MAE 约为 21),因为线性回归可能无法解决这个问题。
最后,让我们只用SGD
拟合模型:
sgd_model = LinearRegressionModel(‘SGD’)
sgd_model.trainModel(500, X_train, X_test, y_train, y_test, 0.001)
SGD 模型的训练和测试损失 — 作者提供的图片
有趣的是 — 训练和测试损失没有改善!这发生是因为 SGD 对特征缩放非常敏感,可能在处理不同尺度的特征时难以计算梯度。作为挑战,尝试对特征进行缩放,并用SGD
检查结果。缩放后,你还会注意到Adam
优化器模型的行为更稳定!
结论
感谢您抽出时间阅读这篇文章!在这篇博客文章中,我们检查了如何使用torch
训练一个简单的线性回归模型。虽然 PyTorch 因其深度学习(更多层和复杂功能)而闻名,但学习简单模型是玩转这个框架的好方法。此外,这也是熟悉“损失”函数和梯度概念的绝佳用例。
我们还看到了 SGD
和 Adam
优化器的工作原理,特别是它们对未缩放特征的敏感程度。
最后,我希望你保留这个过程,它可以扩展到其他类型的模型、函数和过程:
-
train()
将模型设置为训练模式。 -
使用
torch.model
将数据传递给模型。 -
使用
nn.L1Loss()
进行回归问题。其他损失函数可以在这里找到。 -
optimizer.zero_grad()
将梯度设置为零。 -
loss.backward()
计算每个权重相对于损失函数的梯度。 -
使用
optimizer.step()
更新模型的权重。
下次 PyTorch 的帖子见!我还推荐你访问 PyTorch 从零到精通课程,这是一个精彩的免费资源,激发了这篇文章的方法论。
欢迎加入我新创建的 YouTube 频道——数据之旅。
本博客文章中使用的数据集可以在 Kaggle 平台上获得,并通过 Spotify 官方 APP 提取(www.kaggle.com/datasets/yasserh/song-popularity-dataset/data
)。 数据集的许可证为 CC0:公有领域*
PyTorch 简介——张量与张量计算
原文:
towardsdatascience.com/pytorch-introduction-tensors-and-tensor-calculations-412ff818bd5b
了解张量以及如何在最著名的机器学习库之一 pytorch
中使用它们
·发表于 Towards Data Science ·8 分钟阅读·2023 年 11 月 30 日
--
数学魔法——由 AI 生成的图像
在深度学习领域(包括 ChatGPT 的构建基础)中最重要的库之一是 pytorch
。与 Tensorflow 框架一起,pytorch
是可供软件开发人员和数据科学家使用的最著名的神经网络训练框架之一。除了其可用性和简单的 API 外,它在灵活性和内存使用方面表现出色,使其在多维微积分中极其快速(这是反向传播的一个主要组成部分,反向传播是优化神经网络权重的重要技术)——这些细节使其成为公司在构建深度学习模型时最受欢迎的库之一。
在这篇博客文章中,我们将检查一些使用pytorch
的基本操作,并了解如何处理tensor
对象!张量是数据的数学表示,通常有不同的名称:
-
1 维张量:通常称为标量,由一个数学值组成。
-
1 维张量:由 n 个示例组成,通常称为 1-D 向量,能够在单一维度中存储不同的数学元素。
-
2 维张量:通常称为矩阵,能够在两个维度中存储数据。可以想象成一个普通的 SQL 表或一个 Excel 电子表格。
-
3 维张量及以上:具有这种维度的数据通常更难以可视化,通常称为 n-维 张量*。
在对数学概念进行简单介绍后,让我们探索如何在 Python 中使用pytorch
!
张量对象
如我们所述,张量对象是 n 维 对象的数学泛化,可以扩展到几乎任何维度。尽管在深度学习的上下文中,tensors
通常是多维的,但我们也可以使用 torch
创建单元素张量(通常称为标量)(尽管名为 pytorch
,我们使用 torch
来操作 Python 中的库)。
如果张量是 torch
(或 pytorch
)中的核心对象,我们如何在库中创建它们?
超简单!让我们创建第一个单元素张量:
import torch
scalar = torch.tensor(5)
我们的 scalar
对象包含一个数字 — 5。让我们通过在 Python 控制台中调用它来可视化我们的张量:
标量对象 — 作者图像
事实 1:
torch.tensor
用于创建张量对象
当然,我们不仅限于单元素张量 — 我们还可以创建具有多个元素的 1 维对象。让我们将一个列表传递到 torch.tensor
中,看看结果如何:
vector = torch.tensor([7, 7])
vector
vector 对象 — 作者图像
我们的对象 vector
现在包含一个维度上的两个元素。可以将这些数据视为 1 行或 1 列的数据。
拥有“维度”允许我们访问张量中的有趣属性 — 例如 ndim
:
vector.ndim
vector 对象的 ndim — 作者图像
事实 2:
*tensor.ndim*
用于获取张量对象的维度数量
在我们的例子中,vector
对象只有一个维度。我们如何知道张量对象包含多少元素?通过使用另一个属性 - shape
!
vector.shape
vector 对象的 shape — 作者图像
事实 3:
*tensor.shape*
用于获取张量对象的形状
我们的张量对象在一个维度中包含两个元素。我们将看看这个输出与多维对象的比较。
torch
张量还包含一个附加的数据类型。要了解是哪个,我们可以使用:
vector.dtype
vector 对象的 dtype — 作者图像
事实 4:
*tensor.dtype*
输出张量对象的类型。
我们的张量包含 int64
格式的数据。
现在让我们将对象扩展为一个 2-D 张量:
matrix = torch.tensor([[10.0, 20.0],
[30.0, 40.0]])
matrix
matrix 对象 — 作者图像
让我们看看关于我们的 matrix
对象的一些属性:
print(matrix.ndim)
print(matrix.shape)
print(matrix.dtype)
ndim、shape 和 dtype 的矩阵对象 — 作者图像
我们的 matrix
对象包含两个维度中每个有 2 个元素的数据,dtype 为 float32
。
为了完成我们对创建张量的探索,让我们看看如何使用 torch.rand
生成随机张量:
torch.rand(size=(4, 4))
随机张量 — 作者图像
例如,在上面的张量中,我们使用tensor.rand
生成一个 4x4 的矩阵。这是在深度学习中非常常见的操作(例如,生成随机的神经网络层权重以便后续优化)。
张量操作
现在我们来看看如何对张量进行操作。如果你已经熟悉numpy
,这应该很简单!从一个简单的加法操作开始:
tensor = torch.tensor([1, 2, 3])
tensor + 20
张量 + 10 计算 — 作者提供的图片
向张量添加标量很简单 — 只需使用普通的数学操作即可!你能猜到如何将张量与标量相乘吗?
简单!
tensor * 10
张量 * 10 计算 — 作者提供的图片
你也可以使用抽象的torch.multiply
:
torch.multiply(tensor, 10)
张量 * 10 计算 — 作者提供的图片
张量的两个最常见的操作是Hadamard和点积,后者是注意力机制中广泛使用的著名计算之一。
让我们创建两个 2 维张量来检查这些操作:
tensor_1 = torch.tensor([[1,2,3],[2,3,4]])
tensor_2 = torch.tensor([[1,2],[2,3],[3,4]])
张量 _1,一个 2x3 张量 — 作者提供的图片
张量 _2,一个 3x2 张量 — 作者提供的图片
要执行 Hadamard 积,张量的形状必须匹配。让我们计算tensor_1
与其自身的 Hadamard 积:
# Hadamard product
tensor_1 * tensor_1
张量 _1 乘以张量 _1 — 作者提供的图片
对于点积操作,张量的内维度必须匹配。让我们将张量 _1(一个 2x3 张量)与张量 _2(一个 3x2 张量)相乘:
torch.matmul(tensor_1, tensor_2)
张量 _1 与张量 _2 的点积 — 作者提供的图片
我们也可以使用优雅的@操作,它执行相同的操作:
tensor_1 @ tensor_2
张量 _1 与张量 _2 的点积 — 作者提供的图片
张量索引
在最后的示例中,让我们看看如何从张量中提取某些元素。对于这些示例,我们将使用:
indexing_example = torch.tensor([[10,20,30],[40,50,60],[70,80,90]])
indexing_example
2-D 张量示例 — 作者提供的图片
pytorch
中的索引与其他 Python 对象类似 — 让我们尝试索引第一列:
indexing_example[0,:]
第 1 行示例 — 作者提供的图片
使用[]
中的 0 索引将使我们能够提取对象的第一行。:
符号使我们能够从某一维度提取所有元素。在我们的例子中,我们想要从列(第 2 维)中提取所有元素。
你能猜到如何提取第一列吗?只需交换索引的位置!
indexing_example[:,0]
第 1 列示例 — 作者提供的图片
对于更复杂的对象,我们也可以使用相同的逻辑。让我们尝试从一个 3D tensor
中索引一个元素:
indexing_example_3d = torch.tensor([[[10,20,30],[40,50,60],[70,80,90]], [[100,200,300],[400,500,600],[700,800,900]]])
indexing_example_3d
3D 张量 — 图片由作者提供
我们如何从这个张量中提取元素“100”?让我们来看看,我们需要:
-
第一行
-
第一列
-
第二个矩阵
使用索引逻辑,我们可以轻松实现这一点:
indexing_example_3d[1,0,0]
从 indexing_example_3d 提取的 100 元素 — 图片由作者提供
在torch
中,3D 对象的索引顺序如下:矩阵、行、列。
你能尝试索引一个 4D 对象吗?
额外内容 — 张量存储在哪里?
使用torch
相对于其他数组库(如numpy
)的一个优点是能够将张量保存在gpu
中——这在我们需要加速神经网络计算时特别有用。
默认情况下,你的张量存储在cpu
上(而大多数计算机仅有一个 cpu 可用),但你可以通过以下操作将张量发送到gpu
:
device = "cuda" if torch.cuda.is_available() else "cpu"
如果torch.cuda.is_available()
在你的计算机上找到特定的 NVIDIA gpu,它会让你将张量发送到该 gpu。
想象一下你有一个存储在tensor
命名对象中的张量,你可以使用.to
方法将其发送到设备:
tensor_on_gpu = tensor.to(device)
结论
感谢你抽时间阅读这篇文章!处理张量非常有趣,确实能为你提供一个坚实的基础,来使用高级神经网络。
torch
API 非常优雅且易于可视化。之后,你可以使用这些张量来训练神经网络(这是我将在本系列的下一篇博客中展示的内容)。此外,在学习过程中稍微掌握一些线性代数将对学习其他数据科学和机器学习算法极有帮助。
这篇文章的灵感来源于www.learnpytorch.io/
——这是一个关于 Pytorch 的优秀免费课程,我强烈推荐。在DareData我们参与了许多深度学习项目,我不能强调这门课程对我们培训人员学习这种机器学习范式及其相关框架的重要性。
在下一篇文章中,我们将探讨如何使用torch
训练线性回归——敬请期待!
如果你想参加我的 Python 课程,随时加入 我的 16 小时 Python 课程 (初学者完整 Python 训练营)。我的 Python 课程适合初学者/中级开发者,我非常希望你能加入我的课堂!
Python 初学者训练营 — 图片由作者提供
PyTorch 模型性能分析与优化
如何使用 PyTorch Profiler 和 TensorBoard 加速训练并降低成本
·
关注 发表在 Towards Data Science ·14 分钟阅读·2023 年 6 月 12 日
--
图片由 Torsten Dederichs 提供,来自 Unsplash
训练深度学习模型,尤其是大型模型,可能是一项昂贵的开支。我们可以用来管理这些成本的主要方法之一是性能优化。性能优化是一个迭代的过程,在这个过程中我们不断寻找提高应用程序性能的机会,并利用这些机会。在之前的文章中(例如,这里),我们强调了进行这种分析时拥有适当工具的重要性。所选择的工具可能会依赖于多种因素,包括训练加速器的类型(例如,GPU、HPU 或其他)以及训练框架。
性能优化流程(作者提供)
本文将重点讨论在 GPU 上使用 PyTorch 进行训练。更具体地说,我们将重点关注 PyTorch 内置的性能分析器,PyTorch Profiler,以及查看其结果的方式之一,PyTorch Profiler TensorBoard 插件。
本文并不打算取代官方的 PyTorch 文档,无论是关于PyTorch Profiler还是使用 TensorBoard 插件分析性能分析结果。我们的意图是展示这些工具如何在日常开发过程中使用。实际上,如果你还没有阅读官方文档,我们建议你在阅读本文之前先查看一下官方文档。
一段时间以来,我特别对TensorBoard 插件教程中的一个部分感到好奇。该教程介绍了一个基于 Resnet 架构的分类模型,并在流行的 Cifar10 数据集上进行训练。它接着展示了如何使用 PyTorch Profiler 和 TensorBoard 插件来识别和修复数据加载器中的瓶颈。输入数据管道中的性能瓶颈并不少见,我们在之前的一些文章中对此进行了详细讨论(例如,这里)。令我惊讶的是教程中呈现的最终(优化后)结果(截至本文撰写时),我们已将其粘贴如下:
优化后的性能(来自 PyTorch 网站)
如果你仔细观察,你会发现优化后的 GPU 使用率为 40.46%。现在无法掩饰这一点:这些结果绝对糟糕,应该让你夜不能寐。正如我们以前扩展过的(例如,这里),GPU 是我们训练机器中最昂贵的资源,我们的目标应是最大化其使用率。40.46% 的利用率结果通常代表了显著的训练加速和成本节省机会。我们肯定能做得更好!在这篇博客文章中,我们将尝试做得更好。我们将从尝试重现官方教程中展示的结果开始,并看看是否可以使用相同的工具进一步提高训练性能。
玩具示例
以下代码块包含了由 TensorBoard-plugin tutorial 定义的训练循环,做了两个小的修改:
-
我们使用了一个与教程中使用的 CIFAR10 数据集具有相同属性和行为的虚拟数据集。进行此更改的动机可以在 这里 找到。
-
我们用 warmup 标志设置为 3 和 repeat 标志设置为 1 初始化了 torch.profiler.schedule。我们发现,这稍微增加的预热步骤数量改善了分析结果的稳定性。
import numpy as np
import torch
import torch.nn
import torch.optim
import torch.profiler
import torch.utils.data
import torchvision.datasets
import torchvision.models
import torchvision.transforms as T
from torchvision.datasets.vision import VisionDataset
from PIL import Image
class FakeCIFAR(VisionDataset):
def __init__(self, transform):
super().__init__(root=None, transform=transform)
self.data = np.random.randint(low=0,high=256,size=(10000,32,32,3),dtype=np.uint8)
self.targets = np.random.randint(low=0,high=10,size=(10000),dtype=np.uint8).tolist()
def __getitem__(self, index):
img, target = self.data[index], self.targets[index]
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self) -> int:
return len(self.data)
transform = T.Compose(
[T.Resize(224),
T.ToTensor(),
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_set = FakeCIFAR(transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32,
shuffle=True)
device = torch.device("cuda:0")
model = torchvision.models.resnet18(weights='IMAGENET1K_V1').cuda(device)
criterion = torch.nn.CrossEntropyLoss().cuda(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model.train()
# train step
def train(data):
inputs, labels = data[0].to(device=device), data[1].to(device=device)
outputs = model(inputs)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# training loop wrapped with profiler object
with torch.profiler.profile(
schedule=torch.profiler.schedule(wait=1, warmup=4, active=3, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/resnet18'),
record_shapes=True,
profile_memory=True,
with_stack=True
) as prof:
for step, batch_data in enumerate(train_loader):
if step >= (1 + 4 + 3) * 1:
break
train(batch_data)
prof.step() # Need to call this at the end of each step
在教程中使用的 GPU 是 Tesla V100-DGXS-32GB。在这篇文章中,我们尝试重现并改进教程中的性能结果,使用的是包含 Tesla V100-SXM2–16GB GPU 的 Amazon EC2 p3.2xlarge 实例。尽管它们共享相同的架构,但这两种 GPU 之间存在一些差异,你可以在 这里 了解更多信息。我们使用 AWS PyTorch 2.0 Docker 镜像 运行训练脚本。TensorBoard 观众概览页面中显示的训练脚本性能结果如下面的图像所示:
TensorBoard Profiler 概览标签中的基线性能结果(作者捕获)
我们首先注意到,与教程相反,我们实验中的概览页面(torch-tb-profiler 版本 0.4.1)将三个分析步骤合并为一个。因此,平均总体步骤时间是 80 毫秒,而不是报告中的 240 毫秒。这在 Trace 标签中可以清楚地看到(根据我们的经验,Trace 标签几乎总是提供更准确的报告),每个步骤大约需要 80 毫秒。
TensorBoard Profiler 跟踪视图标签中的基准性能结果(作者捕获)
请注意,我们的起始点 31.65% GPU 利用率和 80 毫秒步骤时间与教程中的起始点 23.54%和 132 毫秒有所不同。这可能是由于训练环境的差异,包括 GPU 类型和 PyTorch 版本。我们还注意到,虽然教程基准结果明确诊断出 DataLoader 是瓶颈,我们的结果则没有。我们发现数据加载瓶颈往往会伪装成“CPU Exec”或“Other”在概览标签中显示为高百分比。
优化 #1:多进程数据加载
让我们开始应用多进程数据加载如教程中所述。鉴于Amazon EC2 p3.2xlarge实例具有 8 个 vCPUs,我们将DataLoader工作进程数设置为 8,以获得最佳性能:
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32,
shuffle=True, num_workers=8)
该优化的结果如下所示:
TensorBoard Profiler 概览标签中的多进程数据加载结果(作者捕获)
单行代码的更改使 GPU 利用率提高了 200%以上(从 31.65%提升至 72.81%),并且将我们的训练步骤时间减少了一半以上(从 80 毫秒降至 37 毫秒)。
这就是教程中的优化过程结束的地方。尽管我们的 GPU 利用率(72.81%)明显高于教程中的结果(40.46%),但我毫不怀疑,你会发现这些结果仍然相当不令人满意。
个人评论,你可以选择跳过:想象一下如果 PyTorch 在 GPU 训练时默认应用多进程数据加载将节省多少全球资金!的确,使用多进程可能会有一些不良副作用。然而,必须有某种形式的自动检测算法可以运行,以排除潜在问题的存在并相应地应用此优化。
优化 #2:内存锁定
如果我们分析最后一次实验的 Trace 视图,我们可以看到仍然有相当多的时间(37 毫秒中的 10 毫秒)花费在将训练数据加载到 GPU 上。
Trace 视图标签中的多进程数据加载结果(由作者捕获)
为了解决这个问题,我们将应用另一种 PyTorch 推荐的优化来简化数据输入流程,内存固定。使用固定内存可以提高主机到 GPU 的数据复制速度,更重要的是,可以使它们异步进行。这意味着我们可以在对当前批次进行训练步骤的同时在 GPU 中准备下一个训练批次。需要注意的是,尽管异步执行通常会提高性能,但也可能降低时间测量的准确性。为了我们博客文章的目的,我们将继续使用 PyTorch Profiler 报告的测量结果。有关如何获得准确测量的说明,请参见这里。有关内存固定及其副作用的更多详细信息,请参阅PyTorch 文档。
这种内存固定优化需要对两行代码进行更改。首先,我们将DataLoader的pin_memory标志设置为 True。
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32,
shuffle=True, num_workers=8, pin_memory=True)
然后我们将主机到设备的内存传输(在train函数中)修改为非阻塞模式:
inputs, labels = data[0].to(device=device, non_blocking=True), \
data[1].to(device=device, non_blocking=True)
内存固定优化的结果如下所示:
TensorBoard Profiler 概览标签中的内存固定结果(由作者捕获)
我们的 GPU 利用率现在达到了 92.37%,并且步骤时间进一步减少。但我们仍然可以做得更好。请注意,尽管进行了优化,性能报告仍继续表明我们花费了大量时间将数据复制到 GPU 上。我们将在下面的第 4 步中回到这个问题。
优化 #3:增加批量大小
对于我们的下一个优化,我们将注意力转向最后一次实验的内存视图:
TensorBoard Profiler 中的内存视图(由作者捕获)
图表显示,在 16 GB 的 GPU 内存中,我们的利用率峰值低于 1 GB。这是资源利用不足的极端例子,通常(但并非总是)表明存在提升性能的机会。控制内存利用率的一种方法是增加批量大小。在下面的图像中,我们展示了将批量大小增加到 512(内存利用率增加到 11.3 GB)时的性能结果。
在 TensorBoard Profiler 概述选项卡中增加批量大小的结果(由作者捕捉)
尽管 GPU 利用率的度量变化不大,但我们的训练速度已经显著提高,从每秒 1200 个样本(批量大小 32 时 46 毫秒)增加到每秒 1584 个样本(批量大小 512 时 324 毫秒)。
警告:与我们之前的优化相反,增加批量大小可能会对您的训练应用程序的行为产生影响。不同的模型对批量大小的变化表现出不同程度的敏感性。有些模型可能只需要对优化器设置进行一些调整。对于其他模型,调整到较大的批量大小可能会更困难甚至不可能。有关训练大批量数据的一些挑战,请参见这篇之前的文章。
优化 #4:减少主机到设备的拷贝
您可能注意到了之前结果中饼图上代表主机到设备数据拷贝的大红色眼睛。解决这种瓶颈的最直接方法是查看是否可以减少每批次中的数据量。注意,在我们的图像输入的情况下,我们将数据类型从 8 位无符号整数转换为 32 位浮点数,并在执行数据拷贝之前应用归一化。在下面的代码块中,我们建议对输入数据流进行更改,将数据类型转换和归一化的操作延迟到数据在 GPU 上之后:
# maintain the image input as an 8-bit uint8 tensor
transform = T.Compose(
[T.Resize(224),
T.PILToTensor()
])
train_set = FakeCIFAR(transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1024,
shuffle=True, num_workers=8,
pin_memory=True)
device = torch.device("cuda:0")
model = torchvision.models.resnet18(weights='IMAGENET1K_V1').cuda(device)
criterion = torch.nn.CrossEntropyLoss().cuda(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model.train()
# train step
def train(data):
inputs, labels = data[0].to(device=device, non_blocking=True), \
data[1].to(device=device, non_blocking=True)
# convert to float32 and normalize
inputs = (inputs.to(torch.float32) / 255\. - 0.5) / 0.5
outputs = model(inputs)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
由于此更改,从 CPU 到 GPU 复制的数据量减少了 4 倍,红色的眼睛几乎消失:
在 TensorBoard Profiler 概述选项卡中减少 CPU 到 GPU 拷贝的结果(由作者捕捉)
我们现在达到了 97.51%(!!) 的 GPU 利用率和每秒 1670 个样本的训练速度!让我们看看还能做些什么。
优化 #5:将梯度设置为 None
在这一阶段,我们似乎已经完全利用了 GPU,但这并不意味着我们不能更有效地利用它。一种流行的优化方法是将模型参数梯度设置为 None 而不是 zero,以减少 GPU 中的内存操作。有关此优化的更多细节,请参阅 PyTorch 文档。实施此优化所需的只是将 optimizer.zero_grad 调用的 set_to_none 设置为 True:
optimizer.zero_grad(set_to_none=True)
在我们的案例中,这种优化并未显著提升我们的性能。
优化 #6: 自动混合精度
GPU 内核视图显示了 GPU 内核活动的时间量,可以作为提高 GPU 利用率的有用资源:
TensorBoard Profiler 中的内核视图(作者捕获)
本报告中最明显的细节之一是 GPU Tensor Cores 的使用不足。Tensor Cores 是相对较新的 GPU 架构上提供的矩阵乘法专用处理单元,可以显著提升 AI 应用的性能。它们的未被使用可能代表了一个重要的优化机会。
由于 Tensor Cores 是专门为混合精度计算设计的,一种直接增加其利用率的方法是修改我们的模型以使用 自动混合精度(AMP)。在 AMP 模式下,模型的部分会自动转换为较低精度的 16 位浮点数,并在 GPU TensorCores 上运行。
重要的是,完整实现 AMP 可能需要 梯度缩放,我们在演示中没有包括。请在适应之前务必查看有关混合精度训练的 文档。
启用 AMP 所需对训练步骤的修改在下面的代码块中演示。
def train(data):
inputs, labels = data[0].to(device=device, non_blocking=True), \
data[1].to(device=device, non_blocking=True)
inputs = (inputs.to(torch.float32) / 255\. - 0.5) / 0.5
with torch.autocast(device_type='cuda', dtype=torch.float16):
outputs = model(inputs)
loss = criterion(outputs, labels)
# Note - torch.cuda.amp.GradScaler() may be required
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
Tensor Core 利用率的影响显示在下图中。尽管它继续显示出进一步改进的机会,但仅通过一行代码,利用率从 0% 跃升至 26.3%。
Tensor Core 利用率与 AMP 优化在 TensorBoard Profiler 中的内核视图(作者捕获)
除了增加 Tensor Core 利用率外,使用 AMP 还降低了 GPU 内存利用率,从而释放出更多空间以增加批量大小。下图捕获了在 AMP 优化后和批量大小设置为 1024 时的训练性能结果:
TensorBoard Profiler 概览标签中的 AMP 优化结果(作者捕获)
尽管 GPU 利用率略有下降,但我们的主要吞吐量指标仍然增加了近 50%,从每秒 1670 个样本提高到 2477 个样本。我们正处于一个好状态!
警告:降低模型部分的精度可能会对其收敛性产生重要影响。与增加批量大小的情况一样(见上文),使用混合精度的影响会因模型而异。在某些情况下,AMP 将几乎无需努力即可工作。其他情况下,可能需要更多的工作来调整自动缩放器。还有些情况下,可能需要明确设置模型不同部分的精度类型(即,手动混合精度)。
欲了解有关使用混合精度作为内存优化方法的更多细节,请参阅我们的上一篇博客文章。
优化 #7:以图模式训练
我们将应用的最终优化是模型编译。与 PyTorch 默认的“急切执行”模式相反,compile API 将模型转换为中间计算图,然后将其编译为低级计算内核,以优化底层训练加速器的方式进行处理。有关 PyTorch 2 中模型编译的更多信息,请查看我们的上一篇文章。
以下代码块演示了应用模型编译所需的更改:
model = torchvision.models.resnet18(weights='IMAGENET1K_V1').cuda(device)
model = torch.compile(model)
模型编译优化的结果如下所示:
TensorBoard Profiler 概览标签中的图编译结果(作者捕获)
与之前实验中的 2477 个样本相比,模型编译进一步将我们的吞吐量提高到每秒 3268 个样本,性能提升了额外的 32%(!!)。
图编译改变训练步骤的方式在 TensorBoard 插件的不同视图中非常明显。例如,Kernel View 显示了新(融合)GPU 内核的使用,而 Trace View(如下所示)显示出与之前完全不同的模式。
TensorBoard Profiler 追踪视图标签中的图编译结果(作者捕获)
临时结果
在下表中,我们总结了我们应用的连续优化的结果。
性能结果总结(作者)
通过使用 PyTorch Profiler 和 TensorBoard 插件进行分析和优化的迭代方法,我们将性能提高了 817%!!
我们的工作完成了吗?绝对没有! 我们实施的每项优化都会揭示新的性能改进机会。这些机会以资源被释放的形式呈现(例如,转向混合精度使我们能够增加批量大小),或者以新发现的性能瓶颈的形式呈现(例如,我们最终的优化揭示了主机到设备数据传输中的瓶颈)。此外,还有许多我们在这篇文章中没有尝试的知名优化形式(例如,参见 这里 和 这里)。最后,新库优化(例如,我们在第 7 步演示的模型编译功能)不断发布,进一步推动我们的性能提升目标。正如我们在介绍中强调的,要充分利用这些机会,性能优化必须成为你开发工作流程中的一个迭代且持续的部分。
总结
在这篇文章中,我们展示了对一个玩具分类模型进行性能优化的重大潜力。虽然你可以使用其他性能分析工具,每种工具都有其优缺点,但我们选择了 PyTorch Profiler 和 TensorBoard 插件,因为它们的集成方便。
我们应该强调的是,成功优化的路径会因训练项目的细节而大相径庭,包括模型架构和训练环境。在实践中,实现你的目标可能比我们在这里展示的例子要困难得多。我们描述的一些技术可能对你的性能影响不大,甚至可能使其更糟。我们还指出,我们选择的具体优化以及选择应用的顺序在某种程度上是随意的。强烈建议你根据项目的具体细节开发自己的工具和技术,以实现优化目标。
机器学习工作负载的性能优化有时被视为次要的、非关键的和令人厌烦的。我希望我们已经成功地说服你,开发时间和成本的节省潜力值得对性能分析和优化进行有意义的投资。而且,你可能会发现这其实也很有趣 :)。
下一步?
这仅仅是冰山一角。性能优化远比我们在这里所涵盖的内容要复杂得多。在这篇文章的续集中,我们将深入探讨一个在 PyTorch 模型中相当常见的性能问题,即计算的部分在 CPU 上运行而不是 GPU 上,这种情况通常对开发者来说是未知的。我们还鼓励你查看我们在Medium 上的其他文章,其中许多文章涉及机器学习工作负载性能优化的不同方面。
PyTorch 模型性能分析与优化——第二部分
如何使用 PyTorch Profiler 和 TensorBoard 识别和减少训练步骤中的 CPU 计算
·
关注 发表在 Towards Data Science ·15 分钟阅读·2023 年 6 月 19 日
--
图片由 Denise Chan 提供,来自 Unsplash
这是关于分析和优化在 GPU 上运行的 PyTorch 模型系列帖子的第二部分。在我们的第一篇文章中,我们展示了使用PyTorch Profiler和TensorBoard迭代分析和优化 PyTorch 模型的过程——以及显著的潜力。在这篇文章中,我们将重点关注由于使用急切执行而在 PyTorch 中尤其普遍的一种特定性能问题:模型执行的某些部分依赖于 CPU。识别这些问题的存在和来源可能非常困难,并且通常需要使用专用的性能分析工具。在这篇文章中,我们将分享一些在使用PyTorch Profiler和PyTorch Profiler TensorBoard 插件时识别此类性能问题的技巧。
急切执行的利与弊
PyTorch 的主要吸引力之一是其急切执行模式。在急切模式下,每个 PyTorch 操作在到达时都会立即独立执行。这与图形模式相对,在图形模式中,整个模型会被预先编译成一个单一的图形,以适合在 GPU 上运行,并作为一个整体执行。通常,这种预编译会带来更好的性能(例如,见这里)。在急切模式下,编程上下文在每个操作后返回到应用程序,从而允许我们访问和评估任意张量。这使得构建、分析和调试机器学习模型变得更加容易。另一方面,这也使得我们的模型更容易受到(有时是偶然的)低效代码块的插入。正如我们将展示的那样,知道如何识别和修复这些代码块可以对模型的速度产生显著影响。
玩具示例
在接下来的代码块中,我们介绍了用于演示的玩具示例。代码是非常宽松地基于我们之前的文章和这个 PyTorch 教程中定义的损失函数。
我们从定义一个简单的图像块分类模型开始。其架构对于本篇文章并不重要。
from math import log2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.profiler
import torch.utils.data
import torchvision.transforms as T
from torchvision.datasets.vision import VisionDataset
import numpy as np
from PIL import Image
img_size = 32
num_classes = 8192
hidden_size = 50
# simple patch classification model
class Net(nn.Module):
def __init__(self, img_size=img_size,
hidden_size=hidden_size,
num_classes=num_classes):
super().__init__()
self.conv_in = nn.Conv2d(3, hidden_size, 3, padding='same')
num_hidden = int(log2(img_size))
hidden = []
for i in range(num_hidden):
hidden.append(nn.Conv2d(hidden_size, hidden_size, 3, padding='same'))
hidden.append(nn.ReLU())
hidden.append(nn.MaxPool2d(2))
self.hidden = nn.Sequential(*hidden)
self.conv_out = nn.Conv2d(hidden_size, num_classes, 3, padding='same')
def forward(self, x):
x = F.relu(self.conv_in(x))
x = self.hidden(x)
x = self.conv_out(x)
x = torch.flatten(x, 1)
return x
接下来,我们定义一个相当标准的交叉熵损失函数。这个损失函数将是我们讨论的主要焦点。
def log_softmax(x):
return x - x.exp().sum(-1).log().unsqueeze(-1)
def weighted_nll(pred, target, weight):
assert target.max() < num_classes
nll = -pred[range(target.shape[0]), target]
nll = nll * weight[target]
nll = nll / weight[target].sum()
sum_nll = nll.sum()
return sum_nll
# custom loss definition
class CrossEntropyLoss(nn.Module):
def forward(self, input, target):
pred = log_softmax(input)
loss = weighted_nll(pred, target, torch.Tensor([0.1] * num_classes).cuda())
return loss
最后,我们定义数据集和训练循环:
# dataset with random patches of size 32x32
class FakePatches(VisionDataset):
def __init__(self, transform):
super().__init__(root=None, transform=transform)
self.data = np.random.randint(low=0,high=256,size=(10000,32,32,3),dtype=np.uint8)
self.targets = np.random.randint(low=0,high=num_classes,size=(10000),dtype=np.uint16).tolist()
def __getitem__(self, index):
img, target = self.data[index], self.targets[index]
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self) -> int:
return len(self.data)
transform = T.Compose([T.PILToTensor()])
train_set = FakePatches(transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128,
shuffle=True, num_workers=8, pin_memory=True)
device = torch.device("cuda:0")
model = Net().cuda(device)
criterion = CrossEntropyLoss().cuda(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model.train()
# training loop wrapped with profiler object
with torch.profiler.profile(
schedule=torch.profiler.schedule(wait=1, warmup=4, active=3, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(’./log/example’),
record_shapes=True,
profile_memory=True,
with_stack=True
) as prof:
for step, data in enumerate(train_loader):
inputs = data[0].to(device=device, non_blocking=True)
labels = data[1].to(device=device, non_blocking=True)
inputs = (inputs.to(torch.float32) / 255\. - 0.5) / 0.5
if step >= (1 + 4 + 3) * 1:
break
outputs = model(inputs)
loss = criterion(outputs, labels)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
prof.step()
一个经验丰富的 PyTorch 开发者可能已经注意到,我们的示例中在损失函数中包含了一些低效的代码行。同时,代码中没有明显的问题,这种类型的低效并不罕见。如果你想测试你的 PyTorch 熟练度,请尝试在继续阅读之前找出我们交叉熵损失实现中的三个问题。在接下来的部分中,我们将假设我们未能自行找到这些问题,并展示如何使用 PyTorch Profiler 及其相关的 TensorBoard 插件来识别它们。
正如我们在上一篇文章中所述,我们将迭代地运行实验,识别性能问题,并尝试修复它们。我们将在Amazon EC2 g5.2xlarge 实例(包含 NVIDIA A10G GPU 和 8 个虚拟 CPU)上运行实验,并使用官方的AWS PyTorch 2.0 Docker 镜像。我们选择的训练环境有些随意,不应被视为对其任何组件的认可。
在这篇文章中,我们将重点关注 PyTorch Profiler 报告中的跟踪视图部分。有关如何使用报告的其他视图的示例,请参见我们的上一篇文章。正如我们在上一篇文章中强调的那样,虽然 PyTorch 默认的异步执行可以提高性能,但它也可能降低时间测量的准确性。为了我们博客的目的,我们将依赖于 PyTorch Profiler 报告的步骤时间。对于一些实验,我们将引入人工同步事件以强制主机和设备对齐。
初步性能结果
在下图中,我们展示了上述脚本的跟踪视图结果:
基线模型的跟踪视图(由作者捕获)
如上所述,训练步骤大约需要 9.35 毫秒,其中大约三分之一的时间花在交叉熵损失的前向传递上!这清楚地表明存在问题。我们的损失函数相对于模型包含的计算较少,绝不应占据如此大比例的步骤时间。仔细查看调用栈,我们可以看到一些函数调用加强了我们的怀疑,包括“to”、“copy_”和“cudaStreamSynchronize”。这种组合通常表明数据正在从 CPU 复制到 GPU——这不是我们希望在损失计算中间发生的事情。
我们现在知道损失函数中存在性能问题,而且很可能与将张量从主机复制到 GPU 有关。然而,这可能不足以识别导致问题的确切代码行。为了便于搜索,我们将使用带标签的 torch.profiler.record_function 上下文管理器包装每一行代码,并重新运行性能分析。
# custom loss definition
class CrossEntropyLoss(nn.Module):
def forward(self, input, target):
with torch.profiler.record_function('log_softmax'):
pred = log_softmax(input)
with torch.profiler.record_function('define_weights'):
weights = torch.Tensor([0.1]*num_classes).cuda()
with torch.profiler.record_function('weighted_nll'):
loss = weighted_nll(pred, target, weights)
return loss
标签的添加帮助我们识别出weight定义,或更准确地说,是将权重复制到 GPU 的代码行作为问题代码。
权重定义的性能问题,如追踪视图中所示(由作者捕获)
在当前定义的损失函数中,我们在每个训练步骤中都将权重向量复制到 GPU 上。由于我们的权重向量是常量,这似乎非常浪费。
优化 #1:从训练步骤中移除多余的主机到 GPU 的拷贝
一旦我们识别出第一个问题,修复它相当简单。在下面的代码块中,我们在损失init函数中仅将权重向量复制到 GPU 一次:
class CrossEntropyLoss(nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.Tensor([0.1]*num_classes).cuda()
def forward(self, input, target):
with torch.profiler.record_function('log_softmax'):
pred = log_softmax(input)
with torch.profiler.record_function('weighted_nll'):
loss = weighted_nll(pred, target, self.weight)
return loss
下图显示了修复后的性能分析结果:
优化 #1 后的追踪视图(由作者捕获)
尽管我们的步骤时间已降至 8.9 毫秒,但我们可以清楚地看到,我们有一个新的性能问题需要解决,这次出现在weighted_nll函数中。和之前一样,我们可以使用 torch.profiler.record_function 来识别出问题代码行,即assert调用。
def weighted_nll(pred, target, weight):
with torch.profiler.record_function('assert'):
assert target.max() < num_classes
with torch.profiler.record_function('range'):
r = range(target.shape[0])
with torch.profiler.record_function('index'):
nll = -pred[r, target]
with torch.profiler.record_function('nll_calc'):
nll = nll * weight[target]
nll = nll/ weight[target].sum()
sum_nll = nll.sum()
return sum_nll
请注意,这个问题在基础实验中也存在,但被我们之前的性能问题掩盖了。在性能优化过程中,之前被其他问题掩盖的严重问题突然显现是很常见的。
对调用栈的更详细分析显示了对“item”、“_local_scalar_dense”和“cudaMemcpyAsync”的调用。这通常表明数据正在从 GPU 复制到主机。确实,我们的assert调用是在 CPU 上执行的,需要访问驻留在 GPU 上的target张量,从而引发了高效性低的数据拷贝。
优化 #2:从训练步骤中移除多余的 GPU 到主机的拷贝
虽然验证输入标签的合法性可能是必要的,但应该以不会对训练性能产生如此负面影响的方式进行。在我们的案例中,解决这个问题只是简单地将assert移动到数据输入管道中,在标签被复制到 GPU 之前。
重要提示:虽然我们的目标通常是尝试减少主机和 GPU 之间的拷贝,但有时这要么不可行(例如,如果我们需要 GPU 不支持的内核),要么不理想(例如,如果在 CPU 上运行特定内核会提高性能)。
在移除 assert 后,我们的步骤时间下降到 8.71 毫秒,新的性能问题浮现:
优化 #2 之后的跟踪视图(由作者捕获)
在 weighted_nll 中,pred 张量的索引由 r 和 target 张量定义。虽然 target 张量已经在 GPU 上,但在前一行中定义的 r 张量却没有。这再次触发了低效的主机到 GPU 数据拷贝。
优化 #3:用 torch.arange 替换 range
Python 的 range 函数在 CPU 上输出一个列表。在训练步骤中出现任何列表都应该引起警惕。在下面的代码块中,我们将 range 替换为 torch.arange 并将其配置为直接在 GPU 上创建输出张量:
def weighted_nll(pred, target, weight):
with torch.profiler.record_function('range'):
r = torch.arange(target.shape[0], device="cuda:0")
with torch.profiler.record_function('index'):
nll = -pred[r, target]
with torch.profiler.record_function('nll_calc'):
nll = nll * weight[target]
nll = nll/ weight[target].sum()
sum_nll = nll.sum()
return sum_nll
该优化的结果如下:
优化 #3 之后的跟踪视图(由作者捕获)
现在我们谈到的确实令人兴奋!!我们的步骤时间已经下降到 5.63 毫秒。但如果这对你来说好得不真实,那你是对的。最后的 profiler 步骤后出现的长同步事件表明 GPU 落后于 CPU,整体步骤时间稍高。为了获得更准确的读数,让我们在实验的其余部分中将一个人工同步事件附加到每个训练步骤的末尾:
for step, data in enumerate(train_loader):
inputs = data[0].to(device=device, non_blocking=True)
labels = data[1].to(device=device, non_blocking=True)
inputs = (inputs.to(torch.float32) / 255\. - 0.5) / 0.5
if step >= (1 + 4 + 3) * 1:
break
outputs = model(inputs)
loss = criterion(outputs, labels)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
torch.cuda.synchronize()
prof.step()
添加人工同步事件后,我们新的 基准 步骤时间是十毫秒。
为了识别进一步的优化机会,让我们仔细查看 weighted_nll 函数的跟踪视图,该函数占据了大部分损失计算。
weighted_nll 函数 的跟踪视图(由作者捕获)
从跟踪中我们可以看到,这个函数由多个小块组成,每个小块最终被映射到一个单独的 CUDA 核心,通过CudaLaunchKernel调用加载到 GPU 上。理想情况下,我们希望减少 GPU 核心的总数。这不仅会减少 CPU 和 GPU 之间的交互,还可能减少 GPU 操作的总数。实现这一目标的一种方法是,尽可能优先使用更高层次的 PyTorch 操作符,如 torch.nn.NLLLoss。这些函数被假定为“融合”底层操作,从而减少整体核心的数量。
优化 #4:用 torch.nn.NLLLoss 替换自定义的 NLL
下面的代码块包含我们更新后的损失定义,现在使用 torch.nn.NLLLoss。
class CrossEntropyLoss(nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.Tensor([0.1]*num_classes).cuda()
def forward(self, input, target):
pred = log_softmax(input)
nll = torch.nn.NLLLoss(self.weight)
loss = nll(pred, target)
return loss
在这里,我们擅自引入了另一个常见的错误,并将继续进行演示。
使用更高层次的函数进一步将我们的步骤时间缩短到 9 毫秒。
优化后的跟踪视图 #4(作者拍摄)
然而,放大损失函数后,我们可以看到大量时间现在花在初始化 torch.nn.NLLLoss 对象上!
优化后的损失函数跟踪视图 #4(作者拍摄)
回顾我们的损失函数,我们可以看到在每次训练步骤中我们都会初始化一个新的 NLLLoss 对象!自然地,对象初始化发生在 CPU 上,尽管(在我们的情况下)它相对较快,但这是我们希望在训练步骤中避免的。
优化 #5:避免在训练步骤中初始化对象
在下面的代码块中,我们修改了我们的损失实现,使得在init函数中只创建一个 torch.nn.NLLLoss 实例。
class CrossEntropyLoss(nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.Tensor([0.1]*num_classes).cuda()
self.nll = torch.nn.NLLLoss(self.weight)
def forward(self, input, target):
pred = log_softmax(input)
loss = self.nll(pred, target)
return loss
结果显示出额外的适度改进。
优化后的跟踪视图 #5(作者拍摄)
优化 #6:使用 torch.nn.CrossEntropyLoss 替代自定义损失
PyTorch 包含一个内置的 torch.nn.CrossEntropyLoss,我们现在评估并与我们自定义的损失实现进行比较。
criterion = torch.nn.CrossEntropyLoss().cuda(device)
结果步骤时间降至新的低点 8.5 毫秒,如下所示:
优化后的 Trace View #6(作者捕获)
优化 #7:编译损失函数
对于我们的最终优化尝试,我们将配置损失函数在图形模式下运行,使用 torch.compile API。正如我们在 这篇之前的帖子 中详细讨论并在 这篇文章的前篇 中演示的那样,torch.compile 将使用 内核融合 和 乱序执行 等技术,将损失函数映射到最适合底层训练加速器的低级计算内核中。
criterion = torch.compile(torch.nn.CrossEntropyLoss().cuda(device))
下图显示了该实验的 Trace View 结果,随后是损失函数的放大图。
模型编译后的 Trace View(作者捕获)
模型编译后的损失函数 Trace View(作者捕获)
我们首先看到的是包含“OptimizedModule”和“dynamo”的术语,这些术语表明使用了 torch.compile。我们还可以看到,实际上,模型编译并未减少损失函数加载的内核数量,这意味着它没有识别出额外内核融合的机会。事实上,在我们的案例中,损失编译实际上导致了损失函数前向传递时间的轻微增加。看来,CrossEntropyLoss 不够有“实质”以从这种优化中受益。
你可能会想,为什么我们不能直接将 torch 编译应用于我们的初始损失函数,并依赖它以最佳方式编译我们的代码。这可以省去我们上述描述的逐步优化的麻烦。这个方法的问题在于,尽管 PyTorch 2.0 编译(截至本文撰写时)确实优化了某些类型的 GPU 到 CPU 的交叉,但某些类型会导致图形编译崩溃,其他类型会导致创建多个小图而不是一个大图。最后一种情况会导致 图形断裂,这本质上限制了 torch.compile 功能提升性能的能力。(一种解决方法是调用 torch.compile 时将 fullgraph 标志设置为 True。)有关使用此选项的更多详细信息,请参见 我们之前的帖子。
结论
在本文中,我们展示了发现和解决一些性能问题的过程。尽管应用人工同步事件后,我们的玩具模型的训练速度增加了相对较少的百分比(约 15%),但我们展示的优化潜在影响可能非常显著。
请记住,我们最初使用的损失函数看起来相当简单。如果没有对我们应用程序行为进行深入分析,我们可能永远不会意识到其中存在问题。
让我们总结一些我们所学到的东西。我们将总结分为两部分。在第一部分中,我们描述了一些可能影响训练性能的编码习惯。在第二部分中,我们推荐了一些性能分析的技巧。请注意,这些结论基于我们在本文中分享的示例,可能不适用于您自己的用例。机器学习模型在属性和行为上有很大的差异。因此,强烈建议您根据您自己项目的细节来评估这些结论。
编码提示
您实现模型前向传播的方式可能会显著影响其性能。这里我们仅列出了一些基于我们在本文中涵盖的示例的建议。
-
避免在前向传播中初始化常量张量,应在构造函数中进行初始化。
-
避免在前向传播中对驻留在 GPU 上的张量使用asserts。可以将它们移动到数据输入管道,并/或者检查 PyTorch 是否有任何用于执行所需数据验证的内置方法。
-
避免使用列表。检查是否可以使用 torch.arange 直接在设备上创建张量作为更好的替代方案。
-
使用 PyTorch 操作符,如torch.nn.NLLLoss和torch.nn.CrossEntropyLoss,而不是创建自己的损失实现。
-
避免在前向传播中初始化对象,应在构造函数中进行初始化。
-
在相关情况下考虑使用 torch.compile。
性能分析提示
正如我们展示的那样,Tensorboard PyTorch Profiler 插件的 Trace View 对于识别我们模型中的性能问题至关重要。以下是我们示例中的一些主要收获总结:
-
高 GPU 利用率并不一定意味着您的代码运行良好。
-
注意代码中可能超出预期时间的部分。
-
使用torch.profiler.record_function来准确定位性能问题。
-
GPU 利用率的下降并不一定与性能问题的源对齐。
-
注意防止从主机到 GPU 的意外数据复制。通常通过“to”、“copy_”和“cudaStreamSynchronize”调用来识别这些操作,您可以在 Trace View 中搜索这些调用。
-
注意避免从 GPU 到主机的无意数据拷贝。这些通常通过“item”和“cudaStreamSynchronize”调用来识别,你可以在 Trace View 中搜索这些调用。
尽管我们的演示接近尾声,我们的优化工作还没有完成。请参见我们的上一篇文章以获取下一步的建议。
摘要
在这篇文章中,我们重点讨论了训练应用程序中由于 CPU 和 GPU 在前向传播阶段之间的冗余交互而导致的性能问题。我们展示了如何使用性能分析工具,如 PyTorch Profiler 及其关联的 TensorBoard 插件,来识别这些问题并实现显著的性能提升。
与我们之前的文章一样,我们强调成功优化的路径会根据训练项目的细节,包括模型架构和训练环境,差异很大。在实践中,达到你的目标可能比我们这里展示的例子更为困难。我们描述的一些技术可能对你的性能影响甚微,甚至可能使其变差。我们还指出,我们选择的具体优化和应用它们的顺序是有些随意的。强烈建议你根据项目的具体细节开发自己的工具和技术以实现优化目标。
接下来做什么?
在我们关于 PyTorch 模型性能优化的第三篇文章中,我们讨论了如何识别和减少 cudaMemcpyAsync 事件以及为什么最好避免布尔掩码操作。我们也鼓励你查看我们在Medium 上的其他文章,其中许多文章涵盖了机器学习工作负载性能优化的不同方面。
PyTorch 模型性能分析与优化 — 第三部分
如何减少“Cuda Memcpy Async”事件以及为什么你需要警惕布尔掩码操作
·
关注 发表在 Towards Data Science ·11 分钟阅读·2023 年 8 月 10 日
--
图片由 Braden Jarvis 提供,来源于 Unsplash
这是关于使用 PyTorch Profiler 和 TensorBoard 分析和优化 PyTorch 模型的系列文章的第三部分。我们的意图是突出性能分析和优化在 GPU 训练负载中的好处及其对训练速度和成本的潜在影响。特别是,我们希望展示诸如 PyTorch Profiler 和 TensorBoard 等分析工具对所有 ML 开发者的可及性。你不需要成为 CUDA 专家即可从我们讨论的技术中获得有意义的性能提升。
在我们的 第一篇文章 中,我们演示了 PyTorch Profiler TensorBoard 插件 的不同视图如何用于识别性能问题,并回顾了几种加速训练的流行技术。在 第二篇文章 中,我们展示了 TensorBoard 插件 Trace View 如何用于识别张量从 CPU 到 GPU 的拷贝以及反向拷贝的情况。这种数据移动——可能会导致同步点并显著降低训练速度——通常是无意的,有时可以很容易地避免。本篇文章的主题是我们遇到的 GPU 和 CPU 之间的同步点,这些同步点不与张量拷贝有关。与张量拷贝的情况一样,这些同步点可能会导致训练步骤的停滞,显著减慢整体训练时间。我们将展示此类情况的存在、如何使用 PyTorch Profiler 和 PyTorch Profiler TensorBoard 插件 Trace View 进行识别,以及以最小化此类同步事件的方式构建模型的潜在性能收益。
就像我们之前的帖子一样,我们将定义一个玩具 PyTorch 模型,然后 迭代地 分析其性能,识别瓶颈,并尝试修复它们。我们将在一个 Amazon EC2 g5.2xlarge 实例上运行实验(该实例包含一个 NVIDIA A10G GPU 和 8 个 vCPU),并使用官方的 AWS PyTorch 2.0 Docker 镜像。请记住,我们描述的某些行为可能因 PyTorch 版本而异。
示例
在接下来的块中,我们引入一个玩具 PyTorch 模型,它对 256x256 的输入图像进行语义分割,即,它接收一个 256x256 的 RGB 图像,并输出一个来自十个语义类别的“每像素”标签的 256x256 图。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.profiler
import torch.utils.data
from torch import Tensor
class Net(nn.Module):
def __init__(self, num_hidden=10, num_classes=10):
super().__init__()
self.conv_in = nn.Conv2d(3, 10, 3, padding='same')
hidden = []
for i in range(num_hidden):
hidden.append(nn.Conv2d(10, 10, 3, padding='same'))
hidden.append(nn.ReLU())
self.hidden = nn.Sequential(*hidden)
self.conv_out = nn.Conv2d(10, num_classes, 3, padding='same')
def forward(self, x):
x = F.relu(self.conv_in(x))
x = self.hidden(x)
x = self.conv_out(x)
return x
为了训练我们的模型,我们将使用标准的 交叉熵损失 并做一些修改:
-
我们假设目标标签包含一个 ignore 值,用于表示我们想从损失计算中排除的像素。
-
我们假设其中一个语义标签将某些像素标识为图像的“背景”。我们将定义我们的损失函数将这些视为 ignore 标签。
-
我们只会在遇到包含至少两个唯一值的目标张量的批次时更新模型权重。
尽管我们为演示目的选择了这些修改,但这类操作并不罕见,可以在许多“标准” PyTorch 模型中找到。由于我们已经是性能分析的“专家”,我们已经提前将损失函数中的每个操作都用 torch.profiler.record_function 上下文管理器进行了封装,(如我们 第二篇帖子 中所述)。
class MaskedLoss(nn.Module):
def __init__(self, ignore_val=-1, num_classes=10):
super().__init__()
self.ignore_val = ignore_val
self.num_classes = num_classes
self.loss = torch.nn.CrossEntropyLoss()
def cross_entropy(self, pred: Tensor, target: Tensor) -> Tensor:
# create a boolean mask of valid labels
with torch.profiler.record_function('create mask'):
mask = target != self.ignore_val
# permute the logits in preparation for masking
with torch.profiler.record_function('permute'):
permuted_pred = torch.permute(pred, [0, 2, 3, 1])
# apply the boolean mask to the targets and logits
with torch.profiler.record_function('mask'):
masked_target = target[mask]
masked_pred = permuted_pred[mask.unsqueeze(-1).expand(-1, -1, -1,
self.num_classes)]
masked_pred = masked_pred.reshape(-1, self.num_classes)
# calculate the cross-entropy loss
with torch.profiler.record_function('calc loss'):
loss = self.loss(masked_pred, masked_target)
return loss
def ignore_background(self, target: Tensor) -> Tensor:
# discover all indices where target label is "background"
with torch.profiler.record_function('non_zero'):
inds = torch.nonzero(target == self.num_classes - 1, as_tuple=True)
# reset all "background" labels to the ignore index
with torch.profiler.record_function('index assignment'):
target[inds] = self.ignore_val
return target
def forward(self, pred: Tensor, target: Tensor) -> Tensor:
# ignore background labels
target = self.ignore_background(target)
# retrieve a list of unique elements in target
with torch.profiler.record_function('unique'):
unique = torch.unique(target)
# check if the number of unique items pass the threshold
with torch.profiler.record_function('numel'):
ignore_loss = torch.numel(unique) < 2
# calculate the cross-entropy loss
loss = self.cross_entropy(pred, target)
# zero the loss in the case that the number of unique elements
# is below the threshold
if ignore_loss:
loss = 0\. * loss
return loss
我们的损失函数看起来很无害,对吧?错!正如我们下面将看到的,损失函数包含了许多触发主机-设备同步事件的操作,这些操作会显著降低训练速度——这些操作不涉及将张量复制进出 GPU。就像我们之前的帖子一样,我们挑战你在继续阅读之前尝试识别出三个性能优化的机会。
出于演示目的,我们使用随机生成的图像和每像素标签图,如下所定义。
from torch.utils.data import Dataset
# A dataset with random images and label maps
class FakeDataset(Dataset):
def __init__(self, num_classes=10):
super().__init__()
self.num_classes = num_classes
self.img_size = [256, 256]
def __len__(self):
return 1000000
def __getitem__(self, index):
rand_image = torch.randn([3]+self.img_size, dtype=torch.float32)
rand_label = torch.randint(low=-1, high=self.num_classes,
size=self.img_size)
return rand_image, rand_label
train_set = FakeDataset()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=256,
shuffle=True, num_workers=8, pin_memory=True)
最后,我们使用 PyTorch Profiler 配置到我们的需求来定义训练步骤:
device = torch.device("cuda:0")
model = Net().cuda(device)
criterion = MaskedLoss().cuda(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model.train()
# training loop wrapped with profiler object
with torch.profiler.profile(
schedule=torch.profiler.schedule(wait=1, warmup=4, active=3, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler('/tmp/prof'),
record_shapes=True,
profile_memory=True,
with_stack=True
) as prof:
for step, data in enumerate(train_loader):
inputs = data[0].to(device=device, non_blocking=True)
labels = data[1].to(device=device, non_blocking=True)
if step >= (1 + 4 + 3) * 1:
break
outputs = model(inputs)
loss = criterion(outputs, labels)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
prof.step()
如果你直接运行这个训练脚本,你可能会看到高达 ~90% 的 GPU 使用率,却不知道其中存在任何问题。只有通过性能分析,我们才能识别出潜在的性能瓶颈和训练加速的机会。那么,话不多说,让我们看看模型的表现吧。
初步性能结果
在这篇文章中,我们将重点关注 PyTorch Profiler TensorBoard 插件 的 跟踪视图。请参见我们的 先前文章 获取有关如何使用插件支持的其他 视图 的提示。
在下图中,我们展示了玩具模型单次训练步骤的 跟踪视图。
基线模型的跟踪视图(作者捕获)
我们可以清楚地看到,我们 1.3 秒的训练步骤被 torch.nonzero 操作在损失函数的第一行中 完全 主导了。所有其他操作似乎都集中在巨大 cudaMemcpyAsync 事件的两侧。发生了什么??!! 为什么这么一个看似无害的操作会造成如此大的困扰?
我们或许不应感到如此惊讶,因为 torch.nonzero 的文档确实包含了以下说明:“当 input
在 CUDA 上时,[torch.nonzero()](https://pytorch.org/docs/stable/generated/torch.nonzero.html#torch.nonzero)
会导致主机与设备的同步。” 需要同步的原因在于,与其他常见的 PyTorch 操作不同,torch.nonzero 返回的张量大小是 未 预先确定的。CPU 事先不知道输入张量中有多少个非零元素。它需要等待来自 GPU 的同步事件,以便执行适当的 GPU 内存分配并正确准备后续的 PyTorch 操作。
注意,cudaMempyAsync 的长度并不能体现 torch.nonzero 操作的复杂性,而是反映了 CPU 需要等待 GPU 完成所有之前由 CPU 启动的内核的时间。例如,如果我们在第一次调用后立即进行额外的 torch.nonzero 调用,我们第二次的 cudaMempyAsync 事件会显著比第一次短,因为 CPU 和 GPU 已经或多或少“同步”了。(请记住,这个解释来自非 CUDA 专家,所以请自行斟酌……)
优化 #1: 减少对 torch.nonzero 操作的使用
现在我们理解了瓶颈的来源,挑战在于找到一个执行相同逻辑但不触发主机-设备同步事件的替代操作序列。在我们的损失函数的情况下,我们可以使用torch.where操作符,如下代码块所示:
def ignore_background(self, target: Tensor) -> Tensor:
with torch.profiler.record_function('update background'):
target = torch.where(target==self.num_classes-1,
-1*torch.ones_like(target),target)
return target
在下图中,我们展示了Trace View的变化。
优化 #1 后的 Trace View(作者拍摄)
尽管我们成功去除了来自torch.nonzero操作符的cudaMempyAsync,但它立即被来自torch.unique操作符的cudaMempyAsync所替代,我们的步骤时间没有变化。这里 PyTorch 文档并不友好,但根据我们之前的经验,我们可以假设,我们再次因为使用未确定大小的张量而遭遇了主机-设备同步事件。
优化 #2:减少使用torch.unique操作符
替换torch.unique操作符为等效的替代方案并非总是可能的。然而,在我们的情况中,我们实际上不需要知道唯一标签的值,只需要知道唯一标签的数量。这可以通过在展平的target张量上应用torch.sort操作并计算结果步函数中的步骤数量来完成。
def forward(self, pred: Tensor, target: Tensor) -> Tensor:
# ignore background labels
target = self.ignore_background(target)
# sort the list of labels
with torch.profiler.record_function('sort'):
sorted,_ = torch.sort(target.flatten())
# indentify the steps of the resultant step function
with torch.profiler.record_function('deriv'):
deriv = sorted[1:]-sorted[:-1]
# count the number of steps
with torch.profiler.record_function('count_nonzero'):
num_unique = torch.count_nonzero(deriv)+1
# calculate the cross-entropy loss
loss = self.cross_entropy(pred, target)
# zero the loss in the case that the number of unique elements
# is below the threshold
with torch.profiler.record_function('where'):
loss = torch.where(num_unique<2, 0.*loss, loss)
return loss
在下图中,我们捕捉了第二次优化后的Trace View:
优化 #2 后的 Trace View(作者拍摄)
我们再次解决了一个瓶颈,却面临新的问题,这次来自布尔掩码例程。
布尔掩码是我们常用的例程之一,用于减少所需的总体机器操作数。在我们的情况下,我们的目的是通过去除“忽略”像素来减少计算量,将交叉熵计算限制在感兴趣的像素上。显然,这适得其反。与之前一样,应用布尔掩码会导致一个未确定大小的张量,而触发的cudaMempyAsync大大超过了排除“忽略”像素所节省的任何开销。
优化 #3:注意布尔掩码操作
在我们的案例中,解决这个问题相当简单,因为PyTorch CrossEntropyLoss内置了设置ignore_index的选项。
class MaskedLoss(nn.Module):
def __init__(self, ignore_val=-1, num_classes=10):
super().__init__()
self.ignore_val = ignore_val
self.num_classes = num_classes
self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1)
def forward(self, pred: Tensor, target: Tensor) -> Tensor:
with torch.profiler.record_function('calc loss'):
loss = self.loss(pred, target)
return loss
在下图中,我们展示了结果的Trace View:
最终 Trace View(作者拍摄)
天哪!!我们的步骤时间已经降到了 5.4 毫秒。发生了什么?!!通过简单地调整几个函数调用,而无需修改损失函数逻辑,我们能够从训练步骤中移除同步点。重要的是,当计算几百个步骤的平均时间时,实际上为~330 毫秒,大约是我们开始时的四倍快。这比上面报告的 5.4 毫秒要高得多。这个差异源于 PyTorch Profiler 测量的是每个训练步骤的 CPU 活动时间(例如,内核加载),这不一定与 GPU 活动对齐。虽然上述同步事件引入了不必要的开销,但它们有一个积极的副作用,即提高了 CPU 和 GPU 活动之间的对齐,并提高了时间测量的准确性。在它们不存在的情况下,通过分析器测量的步骤时间出现大幅波动并不罕见。在这种情况下,建议对大量步骤的步骤时间进行平均。有关异步执行对时间测量准确性影响的更多信息,请参见这里。
重要说明:在我们选择的示例中,我们采取的步骤减少了cudaMempyAsync事件的数量,对训练步骤时间产生了明显的影响。然而,也可能存在一些情况,相同类型的改变可能会损害性能而不是提升它。例如,在布尔掩码的情况下,如果我们的掩码极其稀疏且原始张量非常大,应用掩码所节省的计算可能会超过主机与设备同步的成本。重要的是,每种优化的影响应根据具体情况进行评估。
总结
在这篇文章中,我们重点讨论了由于主机与设备同步事件而引发的训练应用中的性能问题。我们看到了一些触发这些事件的 PyTorch 操作符的例子——它们的共同特点是它们输出的张量的大小取决于输入。你也可能会遇到其他操作符引发的同步事件,这些操作符在这篇文章中没有涉及。我们展示了如何使用性能分析器,如PyTorch Profiler及其关联的TensorBoard 插件来识别这些类型的事件。
在我们的示例中,我们找到了使用固定大小张量的替代操作符,避免了同步事件的需求,从而显著改善了训练时间。然而,在实际应用中,你可能会发现解决这些瓶颈要困难得多——甚至不可能。有时,克服这些问题可能需要重新设计模型的部分内容。
下一步是什么?
在我们关于 PyTorch 模型优化的系列文章的下一部分中,我们将分析和解决 DL 训练工作负载的数据预处理管道中的性能瓶颈。一定要查看。
PyTorch 模型性能分析与优化 — 第六部分
如何使用 PyTorch Profiler、PyTorch Hooks 和 TensorBoard 识别和分析反向传播中的性能问题
·
关注 发表于 Towards Data Science · 8 分钟阅读 · 2023 年 9 月 20 日
--
图片来源:David Clode 于 Unsplash
这是我们关于使用PyTorch Profiler和TensorBoard分析和优化 PyTorch 模型的系列文章的第六部分。在这篇文章中,我们将探讨分析中一种更复杂的性能问题——训练步骤的反向传播过程中的瓶颈。我们将解释这种类型的瓶颈为什么特别具有挑战性,并提出一种使用 PyTorch 内置的为训练步骤的不同部分附加钩子来分析它的方法。非常感谢Yitzhak Levi对本文的贡献。
玩具模型
为了方便讨论,我们使用流行的timm Python 模块(版本 0.9.7)定义了一个简单的基于 Vision Transformer(ViT)的分类模型。我们将模型定义为patch_drop_rate标志设置为0.5,这会使模型在每个训练步骤中随机丢弃一半的补丁。训练脚本被编程为最小化非确定性,使用torch.use_deterministic_algorithms函数和cuBLAS环境变量CUBLAS_WORKSPACE_CONFIG。请参见下面的代码块以获取完整的模型定义:
import torch, time, os
import torch.optim
import torch.profiler
import torch.utils.data
from timm.models.vision_transformer import VisionTransformer
from torch.utils.data import Dataset
# use the GPU
device = torch.device("cuda:0")
# configure PyTorch to use reproducible algorithms
torch.manual_seed(0)
os.environ[
"CUBLAS_WORKSPACE_CONFIG"
] = ":4096:8"
torch.use_deterministic_algorithms(True)
# define the ViT-backed classification model
model = VisionTransformer(patch_drop_rate=0.5).cuda(device)
# define the loss function
loss_fn = torch.nn.CrossEntropyLoss()
# define the training optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# use random data
class FakeDataset(Dataset):
def __len__(self):
return 1000000
def __getitem__(self, index):
rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
label = torch.tensor(data=[index % 1000], dtype=torch.int64)
return rand_image, label
train_set = FakeDataset()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128,
num_workers=8, pin_memory=True)
t0 = time.perf_counter()
summ = 0
count = 0
model.train()
# training loop wrapped with profiler object
with torch.profiler.profile(
schedule=torch.profiler.schedule(wait=1, warmup=4, active=3, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler('/tmp/perf')
) as prof:
for step, data in enumerate(train_loader):
inputs = data[0].to(device=device, non_blocking=True)
label = data[1].squeeze(-1).to(device=device, non_blocking=True)
with torch.profiler.record_function('forward'):
outputs = model(inputs)
loss = loss_fn(outputs, label)
optimizer.zero_grad(set_to_none=True)
with torch.profiler.record_function('backward'):
loss.backward()
with torch.profiler.record_function('optimizer_step'):
optimizer.step()
prof.step()
batch_time = time.perf_counter() - t0
if step > 1: # skip first step
summ += batch_time
count += 1
t0 = time.perf_counter()
if step > 500:
break
print(f'average step time: {summ/count}')
我们将在一个Amazon EC2 g5.2xlarge实例(包含 NVIDIA A10G GPU 和 8 个 vCPUs)上进行实验,并使用官方AWS PyTorch 2.0 Docker 镜像。
初始性能结果
在下面的图像中,我们捕捉了在TensorBoard 插件Trace View中显示的性能结果:
反向传播中的瓶颈(作者)
在训练步骤的前向传递中,操作被集中在顶部线程中,但在底部线程的反向传递中似乎出现了性能问题。在这里我们看到一个操作,GatherBackward,占用了大量的追踪部分。仔细观察,我们可以看到底层操作包括“to”、“copy_”和“cudaStreamSynchronize”。正如我们在第二部分中看到的,这些操作通常表示数据从主机复制到设备——这是我们在训练步骤中希望避免的。
到这个时候,你自然会问:为什么会发生这种情况?我们模型定义的哪个部分导致了这个问题?GatherBackward追踪提示可能涉及一个torch.gather操作,但它来自哪里以及为什么会引发同步事件呢?
在我们之前的帖子中(例如,这里),我们提倡使用标记的torch.profiler.record_function上下文管理器来确定性能问题的来源。这里的问题是性能问题发生在反向传递中,而我们无法控制!特别是,我们无法将反向传递中的单个操作用上下文管理器包裹。 理论上,可以通过深入分析追踪视图并将反向传递中的每个片段与前向传递中的对应操作匹配来识别问题模型操作。然而,这不仅可能非常繁琐,而且还需要对模型训练步骤的所有底层操作有深入的了解。使用torch.profiler.record_function标签的优点在于它允许我们轻松定位模型中的问题部分。理想情况下,我们希望即使在反向传递中出现性能问题时,也能保留相同的能力。在下一部分,我们将描述如何使用 PyTorch hooks 实现这一点。
使用 PyTorch 反向钩子的性能分析
尽管 PyTorch 不允许你封装单独的反向传播操作,但它允许你使用其钩子支持来前置和/或后置自定义功能。PyTorch 支持为torch.Tensors和torch.nn.Modules注册钩子。尽管我们在本文中提出的技术将依赖于向模块注册反向钩子,但张量钩子的注册也可以类似地用于替换或增强基于模块的方法。
在下面的代码块中,我们定义了一个包装函数,该函数接受一个模块并注册一个full_backward_hook和一个full_backward_pre_hook(尽管在实际应用中一个就足够了)。每个钩子被编程为使用torch.profiler.record_function函数向捕获的性能分析跟踪中添加一条消息。backward_pre_hook被编程为打印一个“before”消息,而backward_hook则打印一个“after”消息。一个可选的details字符串被附加以区分相同模块类型的多个实例。
def backward_hook_wrapper(module, details=None):
# define register_full_backward_pre_hook function
def bwd_pre_hook_print(self, output):
message = f'before backward of {module.__class__.__qualname__}'
if details:
message = f'{message}: {details}'
with torch.profiler.record_function(message):
return output
# define register_full_backward_hook function
def bwd_hook_print(self, input, output):
message = f'after backward of {module.__class__.__qualname__}'
if details:
message = f'{message}: {details}'
with torch.profiler.record_function(message):
return input
# register hooks
module.register_full_backward_pre_hook(bwd_pre_hook_print)
module.register_full_backward_hook(bwd_hook_print)
return module
使用backward_hook_wrapper函数,我们可以开始定位性能问题的来源。我们从封装模型和损失函数开始,如下代码块所示:
model = backward_hook_wrapper(model)
loss_fn = backward_hook_wrapper(loss_fn)
使用TensorBoard 插件 Trace View 的搜索框,我们可以识别“before”和“after”消息的位置,并推断出模型和损失的反向传播开始和结束的位置。这使我们能够得出性能问题发生在模型的反向传播过程中。下一步是用我们的backward_hook_wrapper函数封装 Vision Transformer 的内部模块:
model.patch_embed = backward_hook_wrapper(model.patch_embed)
model.pos_drop = backward_hook_wrapper(model.pos_drop)
model.patch_drop = backward_hook_wrapper(model.patch_drop)
model.norm_pre = backward_hook_wrapper(model.norm_pre)
model.blocks = backward_hook_wrapper(model.blocks)
model.norm = backward_hook_wrapper(model.norm)
model.fc_norm = backward_hook_wrapper(model.fc_norm)
model.head_drop = backward_hook_wrapper(model.head_drop)
在上面的代码块中,我们指定了每个内部模块。另一种将模型的一级模块全部封装起来的方法是迭代其命名子模块:
for submodule in model.named_children():
submodule = backward_hook_wrapper(submodule)
下图显示了“before backward of PatchDropout”消息出现在有问题的GatherBackward操作之前:
在 Trace View 中识别有问题的反向操作的来源(作者)
我们的性能分析表明,性能问题的来源是PathDropout模块。检查该模块的forward函数,我们确实看到调用了torch.gather。
对于我们的示例模型,我们只需进行两次分析迭代即可找出性能问题的来源。在实际应用中,可能需要额外的迭代。
请注意,PyTorch 包括torch.nn.modules.module.register_module_full_backward_hook函数,它会在一次调用中将一个 hook 附加到训练步骤中的所有模块上。虽然这在简单的情况下(如我们的示例)可能足够,但它不能区分相同模块类型的不同实例。
现在我们已经知道了性能问题的来源,我们可以开始着手修复它。
优化建议:尽可能使用索引代替 gather
既然我们知道问题的来源在于torch.gather操作的 DropPatches 模块,我们可以研究一下导致长时间主机-设备同步事件的触发因素。我们的调查使我们回到了torch.use_deterministic_algorithms函数的文档,该函数告知我们,当作用于需要梯度的 CUDA 张量时,torch.gather表现出非确定性行为,除非torch.use_deterministic_algorithms函数的mode设置为True。换句话说,通过将脚本配置为使用确定性算法,我们修改了torch.gather反向传递的默认行为。事实证明,正是这种变化导致了同步事件的需求。确实,如果我们去除这个配置,性能问题就会消失!问题是,我们能否在不付出性能代价的情况下保持算法的确定性。
在下面的代码块中,我们提出了一种 PathDropout 模块的forward函数的替代实现,它使用torch.Tensor 索引而不是torch.gather来生成相同的输出。修改的代码行已被突出显示。
from timm.layers import PatchDropout
class MyPatchDropout(PatchDropout):
def forward(self, x):
prefix_tokens = x[:, :self.num_prefix_tokens]
x = x[:, self.num_prefix_tokens:]
B = x.shape[0]
L = x.shape[1]
num_keep = max(1, int(L * (1\. - self.prob)))
keep_indices = torch.argsort(torch.randn(B, L, device=x.device),
dim=-1)[:, :num_keep]
# The following three lines were modified from the original
# to use PyTorch indexing rather than torch.gather
stride = L * torch.unsqueeze(torch.arange(B, device=x.device), 1)
keep_indices = (stride + keep_indices).flatten()
x = x.reshape(B * L, -1)[keep_indices].view(B, num_keep, -1)
x = torch.cat((prefix_tokens, x), dim=1)
return x
model.patch_drop = MyPatchDropout(
prob = model.patch_drop.prob,
num_prefix_tokens = model.patch_drop.num_prefix_tokens
)
在下图中,我们捕捉了上述更改后的跟踪视图:
优化后的跟踪视图(作者提供)
我们可以清楚地看到,冗长的同步事件不再出现。
在我们的玩具模型中,我们很幸运地使用了torch.gather操作的方式使其可以被 PyTorch 索引替代。自然地,这并非总是如此;其他torch.gather的用法可能没有基于索引的等效实现。
结果
在下表中,我们比较了在不同场景下训练我们的玩具模型的性能结果:
优化结果(作者提供)
在我们的玩具示例中,优化带来了适度但可测量的影响——约 2%的性能提升。有趣的是,在可重复模式下的 torch 索引表现优于默认的(非确定性)torch.gather。基于这些发现,评估使用索引而非torch.gather的选项可能是一个好主意。
总结
尽管 PyTorch(正当的)以易于调试和追踪而闻名,torch.autograd仍然有些神秘,分析训练步骤的反向传递可能相当困难。为了应对这一挑战,PyTorch 支持在反向传播的不同阶段插入钩子。在这篇文章中,我们展示了如何使用 PyTorch 反向钩子,以及torch.profiler.record_function,通过迭代过程识别反向传递中的性能问题来源。我们将这种技术应用于一个简单的 ViT 模型,并了解了torch.gather操作的一些细微差别。
在这篇文章中,我们探讨了一种非常特定的性能瓶颈。请务必查看我们的其他中等博客文章,它们涵盖了关于性能分析和机器学习工作负载优化的各种话题。
QA-LoRA:在你的 GPU 上微调量化的大型语言模型
量化感知微调
·
关注 发表在 Towards Data Science · 10 分钟阅读 · 2023 年 10 月 14 日
--
作者插图 — 使用了来自 Pixabay 的图像 (1,2)
最先进的大型语言模型(LLMs)经过数十亿参数的预训练。虽然预训练的 LLMs 可以执行许多任务,但经过微调后,它们的性能会大幅提升。
多亏了 LoRA,微调成本可以显著降低。LoRA 在冻结的原始参数上添加了低秩张量,即少量参数(百万级)。在微调过程中,只有添加的张量中的参数会被训练。
LoRA 仍然需要将模型加载到内存中。为了降低内存成本并加快微调速度,一种新的方法提出了量化感知 LoRA (QA-LoRA) 微调。
在这篇文章中,我将解释 QA-LoRA,并对其与之前的工作(特别是 QLoRA)的性能进行评估。我还展示了如何使用 QA-LoRA 来微调你自己量化感知的 LoRA 以适应 Llama 2。
QLoRA 存在什么问题?
在量化的 LLM 上微调 LoRA 是 QLoRA 已经可以做到的事情。在我之前的文章中,我多次使用它来微调 LLM,例如 Llama 2 和 GPT-NeoX,在我的台式电脑上或使用免费的 Google Colab 实例。
[## 在你的计算机上使用 QLoRa 和 TRL 微调 Llama 2
在 Guanaco 上以及使用正确的填充
kaitchup.substack.com](https://kaitchup.substack.com/p/fine-tune-llama-2-on-your-computer?source=post_page-----c7291866706c--------------------------------)
在深入探讨 QA-LoRA 之前,了解 QLoRA 的当前局限性是很有趣的。
NormalFloat4 (NF4) 量化
LLM 量化算法通常将参数量化为 4 位精度,使用 INT4 数据类型。使用这种数据类型的计算在最近的 GPU 上得到了越来越多的优化。
QLoRA 默认不使用 INT4,而是另一种称为 NormalFloat4 (NF4) 的数据类型。你可以将其视为压缩的浮点数。根据 QLoRA 作者的说法,NF4 优于 INT4。使用 NF4 量化的 LLM 实现了更低的困惑度。
然而,NF4 计算对于快速推理并不理想。这是为什么使用 GPTQ 量化的模型比使用 bitsandbytes NF4 量化的模型更快的原因之一。在之前的文章中,我确认了 GPTQ 模型确实更快。
[## GPTQ 或 bitsandbytes:哪个量化方法适用于 LLM - Llama 2 的示例
大型语言模型量化,以便在你的计算机上进行负担得起的微调和推理
kaitchup.substack.com](https://kaitchup.substack.com/p/gptq-or-bitsandbytes-which-quantization?source=post_page-----c7291866706c--------------------------------)
NF4 也是 QA-LoRA 作者指出的一个弱点。
NF4 基础模型,但 FP16 LoRA
虽然基础模型使用 NF4 进行量化,但训练后的 LoRA 参数仍保持较高的精度,通常是 FP16,如下图所示。
图由作者提供
这对于 QLoRA 性能至关重要,因为天真地训练量化参数会导致性能较差。
因此,在推理时,我们有两种不同的方法来使用用 QLoRA 训练的 LoRA 适配器:
-
将它们加载到基础 LLM 之上,就像我们在 QLoRA 微调期间所做的那样
-
将它们与基础 LLM 合并
加载它们是保持性能的最佳方式。我们将 LoRA 的参数保持在 16 位精度,但由于它们仅有几百万个参数,相对于量化的基础 LLM,它们并不会消耗太多 VRAM。
另一种选择是将 LoRA 参数与基础模型合并。我在上一篇文章中探索了几种合并方案。
理想情况下,我们需要将基础模型解量化到 LoRA 参数所使用的相同精度,然后将 LoRA 参数与解量化的基础模型合并。
但结果是,合并模型不再被量化(FP16)。它是一个大规模解量化模型。我们可以对整个合并模型进行量化,但量化总是会丢失信息。我们将得到一个性能低于原始 QLoRA 微调结束时的性能的模型。
这是我获得的不同配置的结果:
表格和作者提供的结果
注意:较低的困惑度更好。
我们可以看到,量化合并模型会导致显著更高的困惑度。我们不能在不显著降低性能的情况下合并 QLoRA 适配器,同时保持量化。QLoRA 适配器不是“量化感知”的。
使用 QA-LoRA 的量化感知微调
QA-LoRA 在这篇 arXiv 论文中提出:
QA-LoRA: 量化感知的大型语言模型低秩适应(Xu 等人,2023)
这篇论文不容易阅读。QA-LoRA 动机充分,大部分结果/实验令人信服。然而,理解其工作原理需要对量化背后的机制有所了解。
我不会深入探讨数学理论和证明。我认为理解 QA-LoRA 最简单的方法是将其视为联合量化和微调 LoRA 参数的过程。适配器的参数和量化参数都是在微调过程中学习和应用的。
为了突出与 QLoRA 的区别,我们可以参考论文中的这一段:
我们引入了分组操作,将量化的参数数量从 Dout 增加到 L×Dout,同时将适应的参数数量从 Din×Dint+Dint×Dout 减少到 L × Dint + Dint × Dout。正如我们在实验中将看到的,适度的 L 可以在保持计算效率的同时实现令人满意的语言理解准确性。
此外,QA-LoRA 使用标准的 INT4 数据类型,而 QLoRA 使用 NF4。
QA-LoRA 性能
让我们看看 QA-LoRA 的作者报告的性能。他们报告了许多实验,但我认为下面的表格最能概述 QA-LoRA 的性能,相比 QLoRA 以及在各种量化精度下:
表格来自 Xu et al. (2023)
在此表中,我们比较了原始的 LLaMA 7B(16-bit)与:
-
标准 QLoRA 与 NF4 量化基础 LLM 和 FP16 LoRA(标记为 QLoRA”)
-
LLaMA 7B 量化为 GPTQ 到 INT4(标记为 “LLaMA-7B w/ GPTQ”)
-
合并的 QLoRA 适配器量化为 GTPQ(标记为 “QLoRA w/ GPTQ”)
-
QA-LoRA
标准 QLoRA 的表现最佳。这是预期的,因为它使用了非常好的量化数据类型(NF4),而 LoRA 的参数仍然是 FP16。
我们可以看到,当我们想要合并 QLoRA 适配器并对合并后的模型进行量化(QLoRA w/ GPTQ)时,性能显著下降。正如我们在本文前面的部分讨论的那样,这是预期中的情况。
另一方面,QA-LoRA 的表现几乎与标准 QLoRA 相当,同时 LLM 完全量化为 INT4。换句话说,QA-LoRA 能够正常工作。
QA-LoRA 也比 QLoRA 更灵活,允许用量化到较低精度的 LLM 进行微调。QA-LoRA 的 3-bit 精度优于 QLoRA 合并并量化到 4-bit(QA-LoRA 3-bit 的准确率为 60.1%,而 QLoRA w/ GPTQ 4-bit 为 59.8%)。
总体而言,QA-LoRA 的结果看起来非常令人印象深刻。
QA-LoRA 实现概述
QA-LoRA 的作者在 GitHub 上发布了他们的实现(MIT 许可证)。注意:原始实现已不再提供。我做了一个分支,你可以在这里找到。
QA-LoRA 实现严重依赖于 AutoGPTQ(MIT 许可证)。它利用 AutoGPTQ 的一个特定分支并替换了几个函数。
你希望微调的 LLM 必须已经用这个特定的 AutoGPTQ 分支进行量化。你可以尝试使用来自 Hugging Face Hub 的量化 LLM,但由于 AutoGPTQ 经常变化,这些量化的 LLM 可能不完全兼容 QA-LoRA。例如,我为上一篇文章量化的 Llama 2 7B 的量化配置不受这个 AutoGPTQ 分支支持。
如果你想了解 QA-LoRA 在 AutoGPTQ 中的更改,可以更详细地查看文件 “peft_utils.py” 中的 “GPTQLoraLinear” 类。QA-LoRA 的主要创新在于两行代码:
torch.nn.Linear.__init__(self, linear_module.in_features, linear_module.out_features)
LoraLayer.__init__(self, linear_module.in_features//group_size, linear_module.out_features)
Llama 2 支持的代码适配
我不得不修补 QA-LoRA 的实现,使其在 Llama 2 上运行。
如果 QA-LoRA 在你阅读本文时仍然无法运行 Llama 2,可以用这个文件替换 “qalora.py”:
about.benjaminmarie.com/data/py/qalora/qalora.py
我只做了两个修改:
-
将“model.config.torch_dtype=(torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))”替换为“model.config.torch_dtype=torch.float16”(当前版本的第 300 行)
-
将“module = module.to(torch.float32)”替换为“module = module.to(torch.float16)”(当前版本的第 340 行)
当前实现仅适用于使用 pad token 的模型。Llama 2 不使用 pad token。我不得不手动修改量化的 Llama 2 的 config.json,添加这一行:
"pad_token_id": 0,
它仅指定了用于填充的“unk_token”,其 ID 为 0。
使用 QA-LoRA 微调 Llama 2 的要求
我在一个笔记本中实现了以下所有部分,你可以在这里找到:
QA-LoRA 依赖项
我建议在设置 QA-LoRA 之前创建一个沙箱,例如使用 conda。当前实现使用了几个包的过时版本。
由于我们必须替换 AutoGPTQ 中的源文件以添加 QA-LoRA 支持,AutoGPTQ 必须从源代码编译。
由于这一替换,我们必须首先克隆 AutoGPTQ 和 QA-LoRA 两个存储库,然后替换 AutoGPTQ 中的文件:
git clone -b v0.3.0 https://github.com/PanQiWei/AutoGPTQ.git
git clone https://github.com/yuhuixu1993/qa-lora.git
cp qa-lora/peft_utils.py ./AutoGPTQ/auto_gptq/utils/
应用我的补丁(如有必要):
wget https://about.benjaminmarie.com/data/py/qalora/qalora.py
cp qalora.py qa-lora/
AutoGPTQ 现在可以安装了:
cd AutoGPTQ
pip install .[triton]
cd ..
这可能需要多达 10 分钟。
安装 QA-LoRA 依赖项:
cd qa-lora
pip install -r requirements.txt
cd ..
我还根据 QA-LoRA 文档的推荐从源代码安装了 bitsandbytes,但我认为这不是必须的。你可以尝试使用“pip install bitsandbytes”代替(速度更快)。
git clone https://github.com/timdettmers/bitsandbytes.git
cd bitsandbytes
# CUDA_VERSIONS in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 120}
# make argument in {cuda110, cuda11x, cuda12x}
# if you do not know what CUDA you have, try looking at the output of: python -m bitsandbytes
CUDA_VERSION=118 make cuda11x
python setup.py install
pip install -r requirements.txt
pip install protobuf==3.20.*
cd ..
预量化 LLM
目前,QA-LoRA 仅微调已经通过 AutoGPTQ 量化的 LLM(支持 INT4、INT3 和 INT2)。如果你想微调一个 fp16/32 LLM,你需要先使用我们安装的 AutoGPTQ 版本进行量化。
你也可以查看笔记本末尾,我在其中写了量化代码。对于本文,我量化了 Llama 2 7B 并将其上传到 Hugging Face Hub。你可以使用它来运行本教程:
注意:safetensors 格式尚不支持。QA-LoRA 期望在存储库中找到文件 pytorch_model.bin。
量化的组大小设置为 32(32g)。我选择这个组大小是因为在 QA-LoRA 中硬编码为 32。
硬件要求
以下部分可以在免费的 Google Colab 实例或至少具有 10 GB VRAM 的 GPU 上运行。
QA-LoRA 微调
我使用默认数据集 Alpaca 进行了 100 步的微调。批次大小为 1,梯度累积步骤为 16。它消耗了大约 7 GB 的 VRAM。
如果你的微调出现不稳定的情况,调整学习率和/或 LoRA alpha/rank 也可能提高稳定性。
cd /content/qa-lora/
python qalora.py --model_path kaitchup/Llama-2-7b-4bit-32g-autogptq \
--save_steps 10 \
--output_dir output \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--max_steps 100 \
--lora_r 16
使用 Google Colab 的 T4 GPU,这次微调用了 28 分钟。
我在 Hugging Face Hub 上上传了最终的检查点:
合并 QA-LoRA 适配器
与 QLoRA 相比,QA-LoRA 适配器可以在不损失性能的情况下与量化基础 LLM 合并。
这里是合并代码:
import torch
#Path the quantized base model
model_path = 'Llama-2-7b-4bit-32g-autogptq/gptq_model-4bit-32g.bin'
#Path to the adapter fine-tuned with QA-LoRA
lora_path = 'output/adapter_model.bin'
#Where the merged model will be saved
merged_path = 'output_model'
#The scale is the LoRA alpha divided by the LoRA rank. I trained with LoRA_alpha = LoRA_rank = 16
scale = 16 / 16
#The group size of the quantized base LLN
group_size = 32
#We merge using the CPU
model = torch.load(model_path, map_location='cpu')
lora = torch.load(lora_path, map_location='cpu')
tmp_keys = [key[17:-14] for key in lora.keys() if 'lora_A' in key]
for tmp_key in tmp_keys:
model[tmp_key+'.qzeros'] -= (lora['base_model.model.'+tmp_key+'.lora_B.weight'] @ lora['base_model.model.'+tmp_key+'.lora_A.weight']).t() * scale / group_size /model[tmp_key+'.scales']
torch.save(model, merged_path)
该代码在 CPU 上运行,因此你只需要足够的 CPU RAM 来加载要合并的模型。注意基础 LLM 和我们微调的 QA-LoRA 适配器必须在本地可访问。
一旦合并,模型就准备好进行推理。它是一个标准的 GPTQ 模型。
结论
QA-LoRA 有效。我们对 Llama 2 进行了量化感知的 LoRA 微调。
我们发现量化感知微调相较于 QLoRA 有两个显著的优势:
-
这更快
-
它微调了一个可以与基础 LLM 完美合并的适配器。
当前的实现还不够灵活。这仍然是一个非常年轻的项目。主要问题在于它依赖于旧版本的 AutoGPTQ。QA-LoRA 的作者计划稍后支持最新版本的 AutoGPTQ。
注意在我的实验中使用了 4-bit 量化。QA-LoRA 已支持 2-bit 和 3-bit 量化。你可以尝试这些更低的精度以进一步减少内存消耗。
为了支持我的工作,请考虑订阅我的通讯:
[## Kaitchup - 预算内的 AI | 本杰明·玛丽,博士 | Substack]
每周的新闻、提示和有关在你的计算机上微调、运行和服务大型语言模型的教程。每个...
QLoRa:在你的 GPU 上微调大型语言模型
在消费级硬件上微调拥有数十亿参数的模型现在已经成为可能
·
关注 发表在 Towards Data Science ·6 min read·2023 年 5 月 30 日
--
插图由作者提供。
大多数大型语言模型(LLM)过于庞大,无法在消费级硬件上进行微调。例如,微调一个拥有 650 亿参数的模型需要超过 780GB 的 GPU 内存。这相当于十个 A100 80GB 的 GPU。换句话说,你需要云计算才能微调你的模型。
现在,使用 QLoRa(Dettmers 等人,2023),你只需一个 A100 就可以做到这一点。
在这篇博客文章中,我将介绍 QLoRa。我将简要描述它的工作原理,并展示如何在你的 GPU 上使用它来微调一个拥有 200 亿参数的 GPT 模型。
注意:我使用了自己的 nVidia RTX 3060 12 Gb 运行了本文中的所有命令。你也可以使用 Google Colab 的免费实例获得相同的结果。如果你想使用内存较小的 GPU,你需要使用更小的大语言模型。
我在这篇文章中提供了运行 QLoRa 进行微调所需的所有代码。如果你不想自己编写代码,我还在 The Kaitchup(我的 Substack 通讯)上创建了一个 Google Colab 笔记本。这是笔记本 #2。
QLoRa:低秩适配器的量化大语言模型
在 2021 年 6 月,Hu 等人(2021)介绍了针对大语言模型的低秩适配器(LoRa)。
LoRa 为 LLM 的每一层添加了一小部分可训练的参数,即适配器,并冻结了所有原始参数。对于微调,我们只需更新适配器权重,这大大减少了内存占用。
QLoRa 通过引入:4-bit 量化、双重量化以及利用 nVidia 统一内存进行分页,进一步向前迈进了三步。
简而言之,每个步骤的工作原理如下:
-
4-bit NormalFloat 量化:这是一种改进量化分位数的量化方法。它确保每个量化箱中有相等数量的值。这可以避免计算问题和对异常值的错误。
-
双重量化:QLoRa 的作者将其定义为:“对量化常数进行量化以节省额外内存的过程。”
-
统一内存分页:它依赖于 NVIDIA 统一内存功能,自动处理 CPU 和 GPU 之间的页到页传输。它确保了无错误的 GPU 处理,尤其是在 GPU 可能内存不足的情况下。
所有这些步骤都大幅减少了微调的内存要求,同时性能几乎与标准微调相当。
使用 QLoRa 对 GPT 模型进行微调
QLoRa 的硬件要求:
-
GPU:以下演示在具有 12 Gb VRAM 的 GPU 上运行,适用于参数少于 20 亿的模型,例如 GPT-J。例如,我在我的 RTX 3060 12 Gb 上运行了它。如果你有一张具有 24 Gb VRAM 的更大显卡,你可以使用 20 亿参数的模型,例如 GPT-NeoX-20b。
-
RAM:我建议至少 6 Gb。大多数现代计算机都有足够的 RAM。
-
硬盘:GPT-J 和 GPT-NeoX-20b 都是非常大的模型。我建议至少有 80 Gb 的可用空间。
如果你的机器不满足这些要求,Google Colab 的免费实例也足够了。
QLoRa 的软硬件要求:
我们需要 CUDA。确保它已安装在你的机器上。
我们还需要安装所有依赖项:
-
bitsandbytes:一个包含量化 LLM 所需所有工具的库。
-
Hugging Face Transformers 和 Accelerate:这些是用于高效训练来自 Hugging Face Hub 模型的标准库。
-
PEFT:一个提供各种方法实现的库,用于只微调少量(额外)模型参数。我们需要它用于 LoRa。
-
数据集:这不是必需的。我们仅用它来获取一个用于微调的数据集。当然,你可以提供你自己的数据集。
我们可以通过 PIP 获取它们全部:
pip install -q -U bitsandbytes
pip install -q -U git+https://github.com/huggingface/transformers.git
pip install -q -U git+https://github.com/huggingface/peft.git
#pip install -q -U git+https://github.com/huggingface/accelerate.git
#current version of Accelerate on GitHub breaks QLoRa
#Using standard pip instead
pip install -q -U accelerate
pip install -q -U datasets
接下来,我们可以开始编写 Python 脚本。
GPT 模型的加载和量化
我们需要以下导入来加载和量化 LLM。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
对于这个演示,我们将微调 GPT NeoX 模型,这个模型由 EleutherAI 预训练。这个模型拥有 200 亿个参数。 注意:GPT NeoX 具有允许商业使用的宽松许可(Apache 2.0)。
我们可以从 Hugging Face Hub 获取这个模型和相关的分词器:
model_name = "EleutherAI/gpt-neox-20b"
#Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
然后,我们需要详细说明量化器的配置,如下:
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
-
load_in_4bit: 模型将以 4-bit 精度加载到内存中。
-
bnb_4bit_use_double_quant: 我们将进行 QLoRa 提议的双量化。
-
bnb_4bit_quant_type: 这是量化类型。“nf4”表示 4-bit NormalFloat。
-
bnb_4bit_compute_dtype: 当我们以 4-bit 加载和存储模型时,我们会在需要时部分去量化,并以 16-bit 精度(bfloat16)进行所有计算。
所以现在我们可以以 4-bit 加载模型:
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quant_config, device_map={"":0})
然后,我们启用梯度检查点:
model.gradient_checkpointing_enable()
为 LoRa 预处理 GPT 模型
这里我们使用 PEFT。我们为 LoRa 准备模型,为每一层添加可训练的适配器。
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
model = prepare_model_for_kbit_training(model)
config = LoraConfig(
r=8,
lora_alpha=32,
target_modules=["query_key_value"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, config)
在 LoraConfig 中,你可以调整 r、alpha 和 dropout 以获得更好的任务结果。你可以在 PEFT 仓库 找到更多选项和细节。
使用 LoRa,我们仅增加了 800 万参数。我们只训练这些参数,冻结其他所有参数。微调应该会很快。
准备好你的数据集
对于这个演示,我使用“english_quotes”数据集。 这是一个由著名名言组成的数据集 以 CC BY 4.0 许可 进行分发。
from datasets import load_dataset
data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
使用 QLoRa 微调 GPT-NeoX-20B
最终,使用 Hugging Face Transformers 进行微调是非常标准的。
import transformers
tokenizer.pad_token = tokenizer.eos_token
trainer = transformers.Trainer(
model=model,
train_dataset=data["train"],
args=transformers.TrainingArguments(
per_device_train_batch_size=1,
gradient_accumulation_steps=8,
warmup_steps=2,
max_steps=20,
learning_rate=2e-4,
fp16=True,
logging_steps=1,
output_dir="outputs",
optim="paged_adamw_8bit"
),
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
别忘了 optim=”paged_adamw_8bit”。它激活了分页以更好地管理内存。没有它,我们会遇到内存不足的错误。
在 Google Colab 上运行这个微调只需 5 分钟。
VRAM 消耗应达到 15Gb 峰值。
就这样,我们免费微调了一个 LLM!
它有效吗?让我们尝试推理。
GPT 推理与 QLoRa
我们微调的 QLoRa 模型可以直接用于标准的 Hugging Face Transformers 推理,具体如下:
text = "Ask not what your country"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
你应该得到以下输出:
Ask not what your country can do for you, ask what you can do for your country.”
– John F.
我们得到了预期的报价。对于 5 分钟的微调来说还不错!
结论
大型语言模型变得更大,但同时,我们终于拥有了在消费者硬件上进行微调和推理的工具。
多亏了 LoRa,现在 QLoRa 的出现使我们能够在不依赖云计算的情况下微调具有亿级参数的模型,并且根据QLoRa 论文性能下降并不显著。
如果你在运行代码时遇到任何问题,请留言,我会尽量帮助你。你还可以在官方 GitHub 仓库中找到关于 QLoRa 实现的更多信息。
如果你想部署 LLM,可以查看我使用 nVidia Triton 推理服务器的教程:
如何在本地服务器上运行大型语言模型
towardsdatascience.com
如果你喜欢这篇文章并且对阅读下一篇感兴趣,支持我工作的最佳方式是订阅 Kaitchup:
[## Kaitchup - AI 预算低 | Benjamin Marie | Substack
每周提供有关微调、运行和服务大型语言模型的教程、技巧和新闻。…
newsletter.kaitchup.com](https://newsletter.kaitchup.com/?source=post_page-----27bed5a03e2b--------------------------------)
QQ 图绘制你的数据启蒙:分布的银河系旅行指南
你的数据正常吗?终极指南:使用 R 绘制 QQ 图
·
关注 发表在 Towards Data Science · 15 分钟阅读 · 2023 年 4 月 26 日
--
图片来源于 Pete Linforth Pixabay
介绍
统计学是一个神秘的领域。在其广泛的理论、公式和框架中,蕴含着对各个学科领域都有深刻应用的知识。然而,尽管统计学取得了巨大的成功,但它也可能显得不那么友好,因为它的很多能力来自它所做的假设。一些最常见的假设包括独立同分布假设(i.i.d.)和正态性假设(数据可以近似或收敛于正态分布)。一旦这些假设开始破裂,理论就会变得越来越复杂,即使它们工作,它们也往往会失去很多力量。幸运的是,有支持这些假设的定理和法则,比如中心极限定理,它表明对于大样本量的各种分布的估计量,正态性假设是成立的。然而,为了充分实现统计学的潜力并利用其力量,必须在使用统计框架或测试得出结论之前,验证数据是否符合这些假设至关重要。通过仔细评估数据是否满足必要的假设,并相应调整我们的方法,我们可以确保统计分析提供准确可靠的见解,帮助我们在广泛的学科领域做出明智的决策。
在本文中,我们将探讨一种最强大的方法,即 QQ 图,用于测试你的数据是否符合正态分布。顾名思义,这是一种需要通过视觉检查来得出结论的图形方法,以判断数据是否正态分布。在统计学中,并没有太多的视觉方法,因为大多数假设检验框架依赖于估计量、临界区域和 p 值来获得数值估计。测试正态性的框架(如 Kolmogorov Smirnov 检验)也存在,但与使用 QQ 图的简单方法相比,这些框架被发现要弱得多。此外,需要承认的是,QQ 图可以用来测试数据是否遵循任何分布,而不仅仅是正态分布。例如,如果你认为收入是指数分布的,你仍然可以使用 QQ 图来检查。只是该测试在正态分布的背景下使用最广泛,因为正态分布在统计学中有着巨大的实用价值。
为什么不使用直方图?
这是一个有效的问题:如果构建 QQ 图也是一种图形方法,为什么不简单地绘制你的数据的直方图,并将其与正态分布的直方图进行比较呢?然而,这里有一些需要考虑的问题。
-
你会选择什么样的分 bin 尺寸呢?虽然很大的 bin 尺寸可能会导致你错过分布中的一些重要特征,但很小的 bin 宽度可能会显示出很多噪音,使解释变得困难。选择适当的 bin 宽度可以显著改变你对底层分布的理解和认知。
-
此外,直观检查直方图与正态分布曲线之间的拟合情况可能会显得繁琐且主观。你可以覆盖一个正态分布曲线(如图 1 所示),查看它与条形的高度匹配程度,但这个过程容易出错。
图 1:使用直方图的难度 [图像由作者提供]
使用 QQ 图要简单得多,因为绘制图形不需要调整任何参数。你只需输入数据并生成图形。此外,视觉检查要容易得多:你只需检查两条线的匹配程度,而不必比较沿着钟形曲线不同位置的条形,这在数据偏斜或多峰时尤其具有挑战性。
QQ 图是什么?
QQ 图(分位数-分位数图)是将给定数据的样本(或观察到的)分位数与理论(或预期)分位数进行绘图。由于我们将使用这些图来检查数据是否呈正态分布,因此理论分位数必须与正态分布的分位数相对应。但是,这些分位数是什么?
分位数是将分布划分为相等部分的值。例如,中位数是第 50 百分位数(或分位数),将分布分为两个相等的部分。
样本分位数是基于样本数据估计的分位数。给定一个包含 n 个观察值的样本,第 k 个样本分位数是将样本分为两部分的值:第一部分为 k/n,剩余部分为(n-k)/n。例如,假设我们有以下数据:
如何计算第 50 个分位数?第一步是将数据按升序排列,得到:
注意到值 37 将数据分成了两个相等的部分。5,…,21 构成了第一部分,占剩余数据的 5/10。43,…,88 构成了第二部分,占剩余数据的 5/10。一般来说,第 k 小的样本被称为数据的第 k 阶统计量,表示为 X₍ₖ₎。实际上,第 k 阶统计量是数据的 100*k/(n + 1)分位数。
例如,在给定数据中,10 是第 3 小的值。因此,X₍₃₎ = 10。这个值代表数据的 100*3/(11 + 1) = 第 25 个分位数。上述一般化的证明稍微复杂,因为它利用了贝塔分布的性质和阶统计量的一般化 CDF。为了完整性,下面给出了证明。然而,如果你不想了解这些内容,可以选择跳过并继续讨论样本分位数。
理论 [供感兴趣者参考]
定理: 对于大小为 n 的独立同分布数据 X₁, X₂, · · · , Xₙ,第 k 阶统计量 X₍ₖ₎是 100*p 分位数的无偏估计量,其中 p 由下式给出:
换句话说,如果 F 是 Xᵢ的 CDF,
(回顾一下,F 表示累积分布。如果某个值的累积分布函数期望为 k/(n +1),则该值是 k/(n +1)分位数的无偏估计量)
证明: 首先,我们找到均匀分布第 k 阶统计量的分布:回顾一下,标准均匀分布的概率密度函数和累积分布函数由下式给出:
我们现在可以找到均匀分布第 k 阶统计量的 CDF。设 W 为小于 x 的 Xᵢ的数量。显然,W 服从参数为 n = n,p = Fᵤ(x)的二项分布。因此,第 k 阶统计量的 CDF 为:
注:I 是指示函数,当大括号内的条件满足时取值为 1,否则为 0。现在,我们可以对上述函数关于 x 进行微分,以得到均匀分布第 k 阶统计量的概率密度函数:
上述分布类似于贝塔分布的结构。回顾一下,Y ∼ Beta(a, b) 的密度函数及相应的期望值由下式给出:
通过比较均匀分布第 k 阶统计量的概率分布和贝塔分布的 PDF,我们得出结论:
接下来,我们依赖于统计学中最重要的公式之一的证明:我们展示了如果 X 是具有 CDF Fₓ(x)的随机变量,则 Fₓ(X)的 PDF 是标准均匀分布。设 Y = Fₓ(X)。
最后,我们计算 Fₓ(X₍ₖ₎)的分布。我们利用随机变量的 CDF 总是一个递增函数这一事实。因此,
现在,让我们讨论理论分位数。理论分位数是指理论分布(如正态分布)的分位数。与基于给定数据估计的样本分位数不同,理论分位数由分布的参数(如均值和标准差)决定。其值不依赖于提供的样本,而仅依赖于用于检验数据的分布。例如,标准正态分布的 5%理论分位数大约为-1.645,这意味着标准正态曲线下的 5%面积位于-1.645 的左侧。一般来说,如果我们假设给定数据具有分布 N(µ, σ²),则 100*p(对于 p 在 0 和 1 之间)的理论分位数(πₚ)由下式给出:
这里Φ和Φ⁻¹分别指正态分布的累积分布函数和反函数。从上述讨论中可以明显看出,给定一个包含 n 个值的数据集,我们可以做两件事:
-
我们可以确定数据集中每个值的相应的样本分位数。我们通过升序排列值来做到这一点,并得出第 k 个最小样本是第 p = k/(n + 1)个样本分位数。
-
对于每个给定的样本分位数,我们可以计算相应的理论分位数。特别是,如果我们想要测试我们的数据是否符合分布 N(µ, σ²),我们可以得出第 p 个理论分位数为 πₚ = µ+σΦ⁻¹(p)。
这正是构建 QQ 图的方案。让我们通过一个示例来理解这个过程。
构建 QQ 图:一个示例
考虑以下包含 9 个观察值的数据集:
现在我们将为给定的数据集构建 QQ 图,以测试其是否符合 N(1, 2)的分布。为简单起见,我们只取了 10 个观察值。然而,使用 R 或任何其他编程语言,可以将其扩展到任意数量的点,我们将在下一节中看到。为解决这个问题,我们考虑上述两个步骤:
A) 我们可以确定数据集中每个值的相应的样本分位数。我们通过升序排列值来做到这一点,并得出第 k 个最小样本是第 p = k/(n + 1)个样本分位数: 首先,我们将数据集按升序排列,并分配相应的顺序统计值:
表 1 [作者提供的图像]
注意我们如何使用表格格式来表示数据。这在手工构建 QQ 图时可能非常有用,因为它允许您清晰地封装必要的信息。接下来,我们记录每个这些顺序统计量所估计的分位数:
表 2 [作者提供的图像]
B) 对于每个给定的样本分位数,我们可以计算相应的理论分位数。特别是,如果我们想要测试我们的数据是否符合分布 N(µ, σ²),我们可以得出第 p 个理论分位数为 πₚ = µ + σΦ⁻¹(p):利用导出的公式,我们可以估算如下正态分布的相应理论分位数:
表 3 [作者提供的图像]
Φ⁻¹的值可以使用 z 表或统计软件(例如 R 中的 qnorm 函数)找到。因此,我们得到了样本分位数值(ˆπp = X(k))和理论分位数值πₚ = µ + σΦ⁻¹(p)。最后一步是将它们一起绘制在散点图中。这给出了以下图表:
[作者提供的图像]
这就是给定数据的 QQ 图!是的,这只是理论分位数与样本分位数的图示。注意,我还添加了一条虚线 y = x,以便将给定的 QQ 图与理论分位数和样本分位数相同的情况进行比较。如果数据是正态分布的,那么 QQ 图上的点将沿着直线分布。点距离直线越近,数据越正态。相反,如果数据不是正态分布的,那么 QQ 图上的点将偏离直线。例如,在上述图中,样本分位数值总是低于相应的理论分位数值,这支持了数据不是正态分布的假设。实际上,上述数据集实际上是从均匀分布中获得的。因此,我们能够使用 QQ 图的概念来解释给定数据的正态性。
现在,让我们来看一个数据实际呈正态分布的示例。这次考虑一个大小为 14 的数据集:
第一步是绘制一个表格,将值按升序排列,并计算相应的顺序统计量及其代表的分位数。随后,我们填入理论分位数值。完成的表格如下所示:
表 4 [作者提供的图像]
如前所述,将上述点绘制在图表上得到如下 QQ 图:
[作者提供的图像]
如你所见,这些点趋向于沿直线分布,这表明数据的分布与均值为 1 和方差为 2 的正态分布非常相似。请注意,如果数据集的样本点很少,QQ 图可能不会很有帮助。然而,如果样本数量足够多,QQ 图将成为测试数据集分布的强大工具,它将是你通往概率分布银河系的终极指南!
大样本:R 代码
我们将使用基础 R 中的 ToothGrowth 数据集,该数据集包含有关维生素 C 对豚鼠牙齿生长影响的信息。特别是,我们将查看 60 只豚鼠的牙本质细胞(负责牙齿生长的细胞)的长度,并测试牙齿长度的分布是否正态。我们可以如下加载数据集:
# Load Packages
library(tidyverse)
# Load Dataset
data("ToothGrowth")
x = ToothGrowth$len
n = length(x)
# Length:
print(n)
> [1] 60
# First few observations
print(x[1:10])
> [1] 4.2 11.5 7.3 5.8 6.4 10.0 11.2 11.2 5.2 7.0
我们感兴趣的是测试数据集是否符合以下参数 µ 和 σ 的正态分布:
# Mean of target distribution
mu <- mean(x)
print(mu)
> [1] 18.81333
# Variance of target distribution
sigma <- sd(x)
print(sigma)
> [1] 7.649315
步骤 1:计算样本分位数 为此,我们只需将数据按升序排列。相应的分位数值将从 1/(n + 1) 到 n/(n + 1),如下所示:
# Sample Quantiles: Sort Sample Provided
sq = sort(x)
print(sq[1:10])
> [1] 4.2 5.2 5.8 6.4 7.0 7.3 8.2 9.4 9.7 9.7
# Corresponding Quantile Values: k/(n + 1), k = 1, 2, ..., n
p = seq(1/(n + 1), n/(n + 1), by = 1/(n + 1))
print(p[1:5])
> [1] 0.01639344 0.03278689 0.04918033 0.06557377 0.08196721
步骤 2:计算理论分位数 使用公式 πₚ = µ + σΦ⁻¹(p),我们计算映射到每个样本分位数的理论分位数。注意:R 中的 qnorm() 函数允许我们计算正态分布的逆 CDF:
# Theoretical Quantiles: Use Derived Formula
tq = mu + sigma*qnorm(p)
print(tq[1:5])
> [1] 2.484468 4.728449 6.170135 7.265988 8.165790
步骤 3:将理论分位数与样本分位数绘图并与 y = x 比较
data.frame(tq, sq) %>%
ggplot(aes(x = tq, y = sq)) +
# Plot original points
geom_point(size = 2.5, color = "blue") +
# To plot the y = x line
geom_abline(slope = 1, intercept = 0, linewidth = 0.8, linetype = 2) +
# Plot styling [Optional]
theme_minimal() +
labs(x = "Theoretical Quantiles", y = "Sample Quantiles", title = "QQ Plot") +
theme(plot.title = element_text(hjust = 0.5))
这产生了以下 QQ 图:
[图像由作者提供]
如你所见,点趋向于沿线分布,表明数据的分布与均值为 18.81333 和方差为 7.649315² = 58.51202 的正态分布非常相似。最后,为了展示 QQ 图的巨大实用性,我们来看看另一个例子,我们尝试测试数据是否呈指数分布。特别是,我们将从均匀分布中抽取样本,并使用 QQ 图检查它是否呈指数分布。我们可以按如下方式加载样本:
# Load Sample
x = runif(10000)
n = length(x)
# First few observations
print(x[1:5])
> [1] 0.07251043 0.20345894 0.20417683 0.48878998 0.50945799
我们感兴趣的是测试数据集是否呈指数分布,其速率参数等于均值的倒数(速率参数的最大似然估计量)。
步骤 1:计算样本分位数 为此,我们只需将数据按升序排列。相应的分位数值将从 1/(n + 1) 到 n/(n + 1),如下面计算所示:
# Sample Quantiles: Sort Sample Provided
sq = sort(x)
print(sq[1:5])
> [1] 6.094808e-05 9.398628e-05 3.439812e-04 3.590921e-04 4.317588e-04
# Corresponding Quantile Values: k/(n + 1), k = 1, 2, ..., n
p = seq(1/(n + 1), n/(n + 1), by = 1/(n + 1))
print(p[1:5])
> [1] 0.00009999 0.00019998 0.00029997 0.00039996 0.00049995
步骤 2:计算理论分位数 这次,我们使用公式 πₚ = Φ⁻¹(p, λ = 1/mean(X)),来计算映射到每个样本分位数的理论分位数,其中 Φ⁻¹(p, λ) 是指数分布的 CDF 的逆函数,速率为 λ。注意:R 中的 qexp() 函数允许我们计算指数分布的逆 CDF:
# Theoretical Quantiles: Use Derived Formula
tq = qexp(p, rate = 1/mean(x))
print(tq[1:5])
> [1] 5.010184e-05 1.002087e-04 1.503205e-04 2.004374e-04 2.505593e-04
步骤 3:将理论分位数与样本分位数绘图并与 y = x 比较
data.frame(tq, sq) %>%
ggplot(aes(x = tq, y = sq)) +
# Plot original points
geom_point(size = 2.5, color = "blue") +
# To plot the y = x line
geom_abline(slope = 1, intercept = 0, linewidth = 0.8, linetype = 2) +
# Plot styling [Optional]
theme_minimal() +
labs(x = "Theoretical Quantiles", y = "Sample Quantiles", title = "QQ Plot") +
theme(plot.title = element_text(hjust = 0.5))
这产生了以下 QQ 图:
[图像由作者提供]
如你所见,点远离直线,表明数据的分布与指数分布不相似。事实上,我们可以绘制样本与均匀分布的 QQ 图,以展示样本确实适合标准均匀分布:
# Theoretical Quantiles
tq = qunif(p)
print(tq[1:5])
data.frame(tq, sq) %>%
ggplot(aes(x = tq, y = sq)) +
# Plot original points
geom_point(size = 2.5, color = "blue") +
# To plot the y = x line
geom_abline(slope = 1, intercept = 0, linewidth = 0.8, linetype = 2) +
# Plot styling [Optional]
theme_minimal() +
labs(x = "Theoretical Quantiles", y = "Sample Quantiles", title = "QQ Plot") +
theme(plot.title = element_text(hjust = 0.5))
这产生了以下 QQ 图,确实是一个完美的匹配!
[图像由作者提供]
结论
总结来说,我们对 QQ 图的研究揭示了其在统计分析中的关键作用。通过将我们的数据分布与理论分布进行比较,QQ 图使我们能够更深入地理解数据的行为,并识别出任何偏离预期分布的情况。此外,QQ 图的多功能性使我们可以在各种场景中应用它们,例如识别异常值、比较数据集以及模型诊断。这些应用展示了 QQ 图的广泛用途,并强调了它们在统计分析领域的重要性。QQ 图提供了一种强大且易于理解的数据分布探索方法,其在统计文献中的重要性证明了其价值。通过在我们的分析中利用 QQ 图,我们可以获得对数据的新见解,并基于对数据行为的更深入理解做出明智的决策。
如果你有任何疑问或建议,请在评论框中回复。请随时通过邮件与我联系。
如果你喜欢我的文章并想阅读更多,请访问这个链接。
注意:所有包含表格、图表和方程式的图片均由作者制作。
量化 GPT-4 的隐藏回归
研究生成型 AI 使用和测试的第三部分
·
关注 发布于Towards Data Science ·5 分钟阅读·2023 年 9 月 22 日
--
由Randy Fath拍摄,来源于Unsplash
GPT-4 比 GPT-3 更大更好。GPT-4 可以撰写优美的演讲,通过标准化考试-,Exam%20results,-(ordered%20by%20GPT),甚至解读图像。自 2023 年 3 月 14 日发布以来,OpenAI 不断迭代和更新 GPT-4,以提升其对每天接收的数百万个查询的处理能力。然而,OpenAI 的 API 中最新版本的 GPT-4,即“gpt-4”,真的比三月份的初始版本 “gpt-4–0314” 更好吗?
从Kolena机器学习工程师的角度来看,这篇文章是系列讨论的继续,强调了一种针对 LLM 的测试范式,比较了不同场景下 GPT 模型的表现。
尽管通过各种测试基准和指标,“gpt-4”的整体表现可能比“gpt-4–0314”更好,但“更好”一词是相对的。用户在线分享了他们最近在不同环境中 经历了 GPT-4 模型性能回退。一个广为流传的 GPT-4 性能回退实例是,它无法像之前一样识别 17077 是一个质数。
自然地,使用最先进的模型时,它在主观和客观表现上不断下降,这是一个问题。可能还存在什么其他隐藏的回归?
我们可以通过使用 CoQA(对话问答)**数据集来测试 GPT-4 的隐藏回归。CoQA 数据集包含多篇文章,每篇文章有一系列相关的问题,其中理解问题n对于回答问题n+1是必要的。以一篇关于体育历史的文章为例,这里有一些潜在的问题:
1. 谁是获得奖牌最多的奥运选手?
2. 他们来自哪个国家?
3. 他们获得了多少枚金牌?
无法单独回答这些问题,因为在回答第一个问题之前,我们不知道感兴趣的对象。
发现
从高层次来看,GPT-4 的表现比 GPT-3 显著更好,但它仍然不完美:
+------------+-------------+-------------+-----------+
| model | avg_BERT_F1 | avg_ROUGE_1 | n_correct |
+------------+-------------+-------------+-----------+
| gpt-4 | 0.92729 | 0.77735 | 4708 |
| gpt-4–0314 | 0.92497 | 0.77284 | 4718 |
| gpt-3 | 0.90845 | 0.71676 | 4348 |
+------------+-------------+-------------+-----------+
注意:“gpt-3”是 GPT-3.5 系列的最新 Turbo 模型,而 n_correct 是平均值 BERT_F1 和 ROUGE_1 大于 0.75 的问题数量
从上述内容来看,为什么“gpt-4–0314”在指标(BERT_F1 和 ROUGE_1)上表现更差却有更多正确回答问题的次数?也许这两个模型在回答相同问题时都存在错误,但不能保证“gpt-4”和“gpt-4–0314”的失败集是同质的。在假设新模型应该具有更高性能的情况下,当我们观察指标时,无法解释这种差异或退步的原因。我们可以通过逻辑上将数据分解为更小的组来深入理解潜在的失败根源。
当我们根据每篇文章的数据源对 CoQA 数据集进行分层时,我们会发现涉及维基百科文章的问题回答数据在最新的 GPT-4 模型中表现更好,但整体上以及在所有其他数据源中的表现较差。
通过 BERT_F1、ROUGE_1 和正确答案数量对“gpt-4”和“gpt-4–0314”进行比较,数据来自Kolena
上图展示了“gpt-4–0314”作为基准与“gpt-4”的比较,突出了在不同数据源中正确答案数量的差异,这些差异表明了改进或退步。在正确答案的数量方面,GPT-4 唯一的改进来自于维基百科的数据点,而在其他所有地方表现都在下降。
分析
这是否揭示了“gpt-4”是“gpt-4–0314”在维基百科文章上经过微调的版本? 不幸的是,我们并不知道。
那么我们是否可以说 GPT-4 变得更糟了?按照这个标准,不一定。尽管学术界认为维基百科是不可靠的信息来源,但许多人仍然经常使用它以获取快速且可访问的信息。如果 OpenAI 希望 GPT 能回答任何领域的任何问题,那么对维基百科的全面理解比理解新闻文章更有价值,因为用户每天进行数百万次随机查询。新闻文章往往有共同的主题,而普通人可能不会向 GPT 提问涉及新闻文章的维基百科中不存在的话题。
在对数据集进行不同数据源的分层之前,没有具体的解释说明为什么“gpt-4–0314”相比于“gpt-4”获得了更多正确结果。通过一次分层,我们获得了一个合理的解释,说明模型之间的差异及其原因。
结论
随着时间的推移,GPT-4 在处理多个数据源的对话式问题回答上有所退步,但在涉及维基百科文章的查询上表现有所改善。
能够识别隐藏的退步应成为所有工程师在将模型部署到生产环境之前的优先事项。寻找 LLM 的隐藏退步并非易事,但采用正确的方法可以简化这一过程。最佳模型并不一定是整体性能最好的模型,而是最能在最重要的场景中取得最佳结果的模型。
我们将在未来的博客文章中深入探讨 CoQA 的更多层次,以进一步了解 GPT-4 随时间的变化。敬请关注!
CoQA 数据集包含来自七个不同数据集的数据,这些数据集拥有不同的许可证。在本文中,我们不会透露数据集中的任何数据,仅使用这些商业数据源中的数据进行测试和分析,包括 Gutenberg、CNN、MCTest 和 Wikipedia,这些数据源具有 CC BY-SA 4.0、MSR-LA 或 Apache 许可证。
使用 GTFS 数据量化交通模式
在这篇文章中,我探讨了依托于通用交通信息规范(General Transit Feed Specification)和各种空间数据科学工具的四个精选城市的公共交通系统。
·
关注 发表在 Towards Data Science ·12 分钟阅读·2023 年 12 月 4 日
--
我在这个笔记本中选择了四个城市:布达佩斯、柏林、斯德哥尔摩和多伦多,使用公开可用的 GTFS(通用交通信息规范)数据概述它们的公共交通系统。这个笔记本旨在作为一个入门教程,介绍如何使用 Pandas、GeoPandas 和其他标准数据科学工具访问、操作、聚合和可视化公共交通数据,以获取有关公共交通的见解。后来,这种理解在各种应用场景中都可能有帮助,如交通、城市规划和位置智能。
此外,尽管 GTFS 格式应具有通用性和普遍性,我也会指出在接下来的分析步骤中仍需逐一进行城市级别的见解和人工验证的情况。
在这篇文章中,所有图像均由作者创建。
1. 收集和解析 GTFS 数据
对于这篇文章,我从 Transitfeeds.com 下载了公共交通数据,这是一个公共交通数据的在线汇总网站。特别是,我下载了以下城市的最新更新时间的数据:
在接下来的代码块中,我将多次探索这些城市,创建比较图,并强调 GTFS 格式的通用性。此外,为了确保我的分析可以轻松更新到最新的数据转储,我将每个城市的 GTFS 数据存储在一个对应于更新日期的文件夹中:
import os
root = 'data'
cities = ['Budapest', 'Toronto', 'Berlin', 'Stockholm']
updated = {city : [f for f in os.listdir(root + '/' + city) if '20' in f][0] for city in cities}
updated
此单元格的输出:
我的 GTFS 数据转储的确切记录时间。
现在,让我们更详细地查看这些文件夹中存储的不同文件:
for city in cities:
print(len(os.listdir(root + '/' + city + '/' + updated[city])), sorted(os.listdir(root + '/' + city + '/' + updated[city])), '\n')
此单元格的输出:
不同城市 GTFS 数据转储中的文件列表。
看起来所有这些文件中都有八个文件,类似于 GTFS 结构的核心部分。
虽然你可以在这里阅读更多关于 GTFS 结构的内容,但现在我将继续讨论可能最基本的地理空间特征——公共交通站点的位置。
2. 公共交通站点位置
首先,我创建了一个示例可视化,解析了布达佩斯的 stops.txt 文件。然后,我使用Shapely和 GeoPandas,通过使用每个站点位置的经纬度坐标创建 Point 几何体,创建一个几何数据结构,即 GeoDataFrame。接着,使用GeoPandas内置的基于 Matplotlib 的绘图工具,我创建了一个简单的地图。此外,我还使用OSMNx将城市边界添加为一个封闭的多边形。
import geopandas as gpd
from shapely.geometry import Point
import matplotlib.pyplot as plt
import pandas as pd
import osmnx as ox
city = 'Budapest'
df_stops = pd.read_csv(root + '/' + city + '/' + updated[city] + '/stops.txt')
geometry = [Point(xy) for xy in zip(df_stops['stop_lon'], df_stops['stop_lat'])]
gdf_stops = gpd.GeoDataFrame(df_stops, geometry=geometry)
df_stops.crs = 'EPSG:4326'
admin = ox.geocode_to_gdf(city)
f, ax = plt.subplots(1,1,figsize=(10,10))
admin.plot(ax=ax, color = 'none', edgecolor = 'k')
gdf_stops.plot(ax=ax, alpha = 0.2)
结果图:
布达佩斯公共交通停靠点的位置。
起初看起来不错!绝大多数停靠点位置都在城市内;只有少数通勤线路位于城市之外。此外,在图的中心,我们可以清楚地看到多瑙河流经的地方没有停靠点。现在,让我们看看所有城市的这张图!
stops = {}
admin = {}
f, ax = plt.subplots(2,2,figsize=(15,15))
indicies = [(i,j) for i in range(2) for j in range(2)]
for idx, city in enumerate(cities):
bx = ax[indicies[idx]]
df_stops = pd.read_csv(root + '/' + city + '/' + updated[city] + '/stops.txt')
geometry = [Point(xy) for xy in zip(df_stops['stop_lon'], df_stops['stop_lat'])]
gdf_stops = gpd.GeoDataFrame(df_stops, geometry=geometry)
gdf_stops.crs = 'EPSG:4326'
admin_ = ox.geocode_to_gdf(city)
gdf_stops.plot(ax=bx, markersize = 3, alpha = 0.5)
admin_.plot(ax=bx, color = 'none', edgecolor = 'k', linewidth = 4)
bx.set_title(city, fontsize = 15)
bx.axis('off')
stops[city] = gpd.overlay(gdf_stops, admin_)
admin[city] = admin_
plt.tight_layout()
四个目标城市的公共交通停靠点位置——未过滤。
既然我们已经研究了五个城市,我们可以对 GTFS 和公共交通数据的一般性得出更多的结论。一方面,从技术角度来看,这些数据在城市和地方交通机构之间确实非常相似。另一方面,当将数据添加到管道中时,必须单独检查每个城市,因为如斯德哥尔摩和多伦多的案例所示,与城市相关的数据实际上属于更大的,例如市政级别的区域。好消息是,例如,添加来自 OpenStreetMap 的行政边界,使得将数据过滤到实际城市变得容易!
实际上,我在前一个代码块的倒数第二行进行了最后的过滤步骤,并将城市级别的停靠点位置存储在stops字典中。
f, ax = plt.subplots(2,2,figsize=(15,15))
indicies = [(i,j) for i in range(2) for j in range(2)]
for idx, (city, admin_) in enumerate(admin.items()):
bx = ax[indicies[idx]]
gdf_stops = stops[city]
gdf_stops.plot(ax=bx, markersize = 3, alpha = 0.5)
admin_.plot(ax=bx, color = 'none', edgecolor = 'k', linewidth = 4)
bx.axis('off')
plt.tight_layout()
四个目标城市的公共交通停靠点位置——过滤后。
3. 出发时间
首先,我们探讨了停靠点的位置;现在,让我们看看何时应该接近这些点。另外,假设公共交通服务对需求和居民的时间表及习惯有显著见解,我们也可以从这些时间表的时间模式中获取一些见解。
from datetime import datetime
def parse_time_string(time_string):
hour_value = int(time_string.split(':')[0])
if hour_value > 23: hour_value = hour_value-24
time_string = str(hour_value) + ':' + time_string.split(':', 1)[1]
parsed_time = datetime.strptime(time_string, "%H:%M:%S")
return parsed_time
stop_times = {}
for city in cities:
print(city)
df_stop_times = pd.read_csv(root + '/' + city + '/' + updated[city] + '/stop_times.txt')
df_stop_times['departure_time'] = df_stop_times['departure_time'].apply(parse_time_string)
stop_times[city] = df_stop_times
for city, df_stop_times in stop_times.items():
print(city, len(df_stop_times))
df_stop_times.head(3)
这个代码块的输出:
每个城市的停靠点数量及其时间表的样本。
在读取、解析和转换出发时间信息后,让我们创建一系列关于每个城市城市级别出发时间分布的视觉图:
f, ax = plt.subplots(4,1,figsize=(15,20))
for idx, (city, df_stop_times) in enumerate(stop_times.items()):
bx = ax[idx]
df_stop_times = df_stop_times
df_stop_times['hour_minute'] = df_stop_times['departure_time'].dt.strftime('%H:%M')
departure_counts = df_stop_times['hour_minute'].value_counts().sort_index()
departure_counts.plot(kind='bar', color='steelblue', alpha = 0.8, width=0.8, ax = bx)
bx.set_xlabel('Hour-Minute of Departure', fontsize = 18)
bx.set_ylabel('Number of Departures', fontsize = 18)
bx.set_title('Histogram of Departure Times by Hour-Minute in ' + city, fontsize = 26)
bx.set_xticks([ijk for ijk, i in enumerate(departure_counts.index) if ':00' in i])
bx.set_xticklabels([i for i in departure_counts.index if ':00' in i])
plt.tight_layout()
这个代码块的输出图像:
每个目标城市的每日出发频率。
虽然对这些图表进行适当、详细的解释可能需要一些严肃的地方城市规划知识,但也可以根据个人经验和常识进行一些推测。
首先,考虑到这些城市气候、文化背景和昼夜节律的差异,我对它们在早上 8 点的对齐程度感到惊讶。布达佩斯、柏林和斯德哥尔摩都有几乎相同的起床模式,而有趣的是,多伦多的起床时间稍长一些。
其次,下午结束的时间似乎有一个明确的顺序——柏林最早结束,其次是布达佩斯、多伦多和斯德哥尔摩。一个有趣的问题是这些时间是否标志着工作班次的结束或人们回家的时间。
第三,布达佩斯和斯德哥尔摩的早晚高峰明显可区分,而柏林和多伦多则不那么明显,这可能与这些城市的中心化程度有关。
4. 空间分布
对于表格数据、浮点数/整数变量和空间信息,计算一般趋势是可能的。这个过程被称为空间索引。空间索引在实践中意味着我们将一个区域,例如城市的行政边界,划分成规则的网格。我的个人最爱是 Uber 的H3 六边形网格——你可以在这里了解更多。使用这个网格,我们可以通过计算每个网格单元格中的停靠点数量来高效地进行空间聚合!让我们尝试不同的六边形分辨率。
import geopandas as gpd # version: 0.13.1
import h3 # version: 3.7.3
from shapely.geometry import Polygon # version: 1.8.5.post1
import numpy as np # version: 1.22.4
# use this function to split an admin polygon into a hexagon grid
def split_admin_boundary_to_hexagons(admin_gdf, resolution):
coords = list(admin_gdf.geometry.to_list()[0].exterior.coords)
admin_geojson = {"type": "Polygon", "coordinates": [coords]}
hexagons = h3.polyfill(admin_geojson, resolution, geo_json_conformant=True)
hexagon_geometries = {hex_id : Polygon(h3.h3_to_geo_boundary(hex_id, geo_json=True)) for hex_id in hexagons}
return gpd.GeoDataFrame(hexagon_geometries.items(), columns = ['hex_id', 'geometry'])
# let's test two resolutions for Budapest
hexagons_gdf_h8 = split_admin_boundary_to_hexagons(admin['Budapest'], 8)
hexagons_gdf_h9 = split_admin_boundary_to_hexagons(admin['Budapest'], 9)
hexagons_gdf_h8.plot()
hexagons_gdf_h9.plot()
布达佩斯——六边形网格。
技术说明: 当我在处理下一个单元时,我遇到了一个错误——即,柏林的行政边界是一个多边形,而所有其他城市的行政区域都是简单的多边形,这也是我的 split_admin_boundary_to_hexagons 函数期望的输入。因此,我检查了一下,结果发现由于某种原因,存在一个额外的小多边形,面积几乎为零,所以我不得不通过运行以下命令来清理它:
admin['Berlin'] = gpd.GeoDataFrame(admin['Berlin'].geometry.explode(), columns = ['geometry']).head(1)
现在我们已经有了 H3 网格构建的工作原型,让我们计算四个示例城市中每个六边形的停靠点数量和停靠次数。为了计算停靠点位置的数量,我们只需进行空间连接;但要计算停靠次数,我们还需要结合出发时间。这样,我们也使空间和时间维度得以结合!
from matplotlib.colors import LogNorm
resolution = 7
for city in cities:
# create the hexagon grid
hexagons_gdf = split_admin_boundary_to_hexagons(admin[city], resolution)
# merge stops and stopppings
gdf_stop_times = stops[city].merge(stop_times[city], left_on = 'stop_id',right_on = 'stop_id')
# compute the number of unique stops and stoppings in each hexagon
gdf_stops = gpd.sjoin(gdf_stop_times, hexagons_gdf)
nunique = gdf_stops.groupby(by = 'hex_id').nunique().to_dict()['stop_id']
total = gdf_stops.groupby(by = 'hex_id').count().to_dict()['stop_id']
hexagons_gdf['nunique'] = hexagons_gdf.hex_id.map(nunique).fillna(0)
hexagons_gdf['total'] = hexagons_gdf.hex_id.map(total).fillna(0)
# visualize the number of stops and stoppings
f, ax = plt.subplots(1,2,figsize=(15,5))
plt.suptitle(city + ', resolution = ' + str(resolution), fontsize=25)
hexagons_gdf.plot(column = 'nunique', cmap = 'RdYlGn', legend = True, ax = ax[0])
hexagons_gdf.plot(column = 'total', cmap = 'RdYlGn', legend = True, ax = ax[1])
# for log-saled coloring:
# hexagons_gdf.plot(column = 'total', cmap = 'RdYlGn', legend = True, ax = ax[1], norm=LogNorm(vmin=1, vmax = hexagons_gdf.total.max()))
ax[0].set_title('Number of unique stops', fontsize = 17)
ax[1].set_title('Number stoppings', fontsize = 17)
for aax in ax: aax.axis('off')
让我们看看这是什么样的:
5. 不同的交通方式
现在,我们已经看到时间和空间中的聚合趋势。接下来,让我们放大并提取与数据记录相对应的不同交通方式。这些信息通常存储在 routes.txt 中,在 route_type 列下。这个、这个 和 这个 映射可以将交通方式代码编码为其英文名称。
基于这些,我创建了官方地图和一个简化版本,稍后我会使用。在简化版本中,我将例如‘电车、街车、轻轨’(代码 0)和‘电车服务’(代码 900)两个类别都重命名为‘电车’。
map_complete = { 0 : 'Tram, Streetcar, Light rail',
1 : 'Subway, Metro',
2 : 'Rail',
3 : 'Bus',
4 : 'Ferry',
11 : 'Trolleybus',
100 : 'Railway Service',
109 : 'Suburban Railway',
400 : 'Urban Railway Service',
401 : 'Metro Service',
700 : 'Bus Service',
717 : 'Regional Bus',
900 : 'Tram Service',
1000: 'Water Transport Service'}
map_simple = { 0 : 'Tram',
1 : 'Subway',
2 : 'Railway',
3 : 'Bus',
4 : 'Ferry',
11 : 'Trolleybus',
100 : 'Railway',
109 : 'Railway',
400 : 'Railway',
401 : 'Subway',
700 : 'Bus',
717 : 'Bus',
900 : 'Tram',
1000 : 'Ferry', }
现在,直观地展示每种交通方式的频率,测量方式是通过可能的路线数量:
from collections import Counter
import matplotlib.pyplot as plt
# this function does some nice formatting on the axis and labels
def format_axis(ax):
for pos in ['right', 'top']: ax.spines[pos].set_edgecolor('w')
for pos in ['bottom', 'left']: ax.spines[pos].set_edgecolor('k')
ax.tick_params(axis='x', length=6, width=2, colors='k')
ax.tick_params(axis='y', length=6, width=2, colors='k')
for tick in ax.xaxis.get_major_ticks(): tick.label.set_fontsize(12)
for tick in ax.yaxis.get_major_ticks(): tick.label.set_fontsize(12)
f, ax = plt.subplots(1,4,figsize = (15,4))
routes = {}
for idx, city in enumerate(cities):
# get the data, map the english names
df_route = pd.read_csv(root + '/' + city + '/' + updated[city] + '/routes.txt')
df_route['route_type'] = df_route['route_type'].astype(int)
df_route['route_type_en'] = df_route['route_type'].map(map_simple)
D = dict(Counter(df_route.route_type_en))
routes[city] = df_route
# define my color palette
transport_colors = {'Railway': '#A65C47',
'Bus': '#0BB3D9',
'Tram': '#F2B705',
'Ferry': '#997CA6' ,
'Trolleybus' : '#D91818',
'Subway' : '#0869A6'}
# create the bar charts
labels = D.keys()
values = D.values()
colors = [transport_colors[l] for l in labels]
ax[idx].bar(labels, values, color = colors)
ax[idx].set_xticklabels(labels, fontsize = 10, rotation = 60, ha = 'right')
format_axis(ax[idx])
ax[idx].set_title(city, fontsize = 18)
ax[idx].set_ylabel('Number of routes', fontsize = 15)
ax[idx].set_yscale('log')
plt.tight_layout()
这个代码块输出了以下图形:
研究城市间交通模式的分布。
6. 路线形状
在进行时空聚合和交通工具缩放后,最后但同样重要的是,我想强调我最喜欢的部分——记录每条路线形状的 GPS 点序列。
以这种方式访问:
city = 'Budapest'
df_routes = pd.read_csv(root + '/' + city + '/' + updated[city] + '/shapes.txt')
df_routes
本单元的输出:
布达佩斯的路线形状表。
现在,将布达佩斯的路线表转换为由 LineStrings 组成的 GeoDataFrame,添加上一节的路线类型映射,并使用连接路线(和交通工具)到形状的 tripst.txt GTFS 文件。
结果将是一个公共交通道路网络的地理地图,我们可以根据前一部分的颜色编码来上色,例如,按照交通工具的不同颜色进行标记。
import contextily as ctx
for city in cities:
df_shapes = pd.read_csv(root + '/' + city + '/' + updated[city] + '/shapes.txt')
# transform the shape a GeoDataFrame
df_shapes['shape_pt_lat'] = df_shapes['shape_pt_lat'].astype(float)
df_shapes['shape_pt_lon'] = df_shapes['shape_pt_lon'].astype(float)
df_shapes['geometry'] = df_shapes.apply(lambda row: Point(row['shape_pt_lon'], row['shape_pt_lat']), axis=1)
df_shapes = gpd.GeoDataFrame(df_shapes[['shape_id', 'geometry']])
df_shapes.crs = 4326
gdf_shapes = gpd.GeoDataFrame(df_shapes[['shape_id', 'geometry']].groupby(by = 'shape_id').agg(list))
gdf_shapes = gdf_shapes[[len(g)>1 for g in gdf_shapes['geometry'].to_list()]]
gdf_shapes['geometry'] = gdf_shapes['geometry'].apply(lambda x: LineString(x))
gdf_shapes = gpd.GeoDataFrame(gdf_shapes)
# let's only keep those routes which have at least one segment that falls into the city
gdf_shapes['shape_id'] = gdf_shapes.index
shapes_in_city = set(gpd.overlay(gdf_shapes, admin[city]).shape_id.to_list())
gdf_shapes = gdf_shapes[gdf_shapes.shape_id.isin(shapes_in_city)]
# map the means of transport
df_route = routes[city][['route_id', 'route_type_en']].drop_duplicates()
df_trips = pd.read_csv(root + '/' + city + '/' + updated[city] + '/trips.txt')[['route_id', 'shape_id']].drop_duplicates()
df_trips = df_trips.merge(df_route, left_on = 'route_id', right_on = 'route_id')
# merge these
gdf_shapes = gdf_shapes.merge(df_trips, left_index = True, right_on = 'shape_id')
# create the visuals in matplotlib
f, ax = plt.subplots(1,1,figsize=(15,15))
cax = admin[city].plot(ax=ax, edgecolor = 'k', color = 'none')
cax = admin[city].plot(ax=ax, edgecolor = 'k', alpha = 0.52)
for transport, color in transport_colors.items():
gdf_shapes[gdf_shapes.route_type_en==transport].plot(ax=ax, color = color, alpha = 0.9, linewidth = 1)
ax.axis('off')
ctx.add_basemap(ax, alpha = 0.8, crs = 4326, url = ctx.providers.CartoDB.DarkMatterNoLabels)
plt.tight_layout()
plt.savefig('figure7_'+city+ '.png', dpi = 150, bbox_inches = 'tight')
plt.close()
每个城市的结果图:
布达佩斯——仅有少数线路延伸至城市聚集区之外。
多伦多——拥有高度集中的公共交通网络。
柏林——展示了其铁路线路如何从城市开始并交织于乡村。
斯德哥尔摩——其公共交通网络从城市出发,但扩展到乡村。
结论
在这篇文章中,我简要概述了探索 GTFS 数据所需的技术细节,并以流线型的方式探索不同城市的公共交通系统。虽然所有这些城市使用的通用格式使得从一个城市到另一个城市的分析任务变得非常方便和易于采用,但我的发现和结果也突出了由于不同城市定义和术语的细微差别,逐个城市级别的验证和合理性检查的重要性。
最后,如果你想回到过去,查看类似的交通相关分析回溯到古罗马,请查看我之前的 TDS 文章,“所有道路都通往罗马吗?”。
分位数损失与分位数回归
学习如何调整回归算法以预测数据的任意分位数
·
关注 发表在 Towards Data Science ·6 分钟阅读·2023 年 1 月 28 日
--
介绍
回归是一个机器学习任务,其目标是基于一组特征向量预测一个实际值。存在各种回归算法:线性回归、逻辑回归、梯度提升或神经网络。在训练过程中,这些算法中的每一个都根据用于优化的损失函数调整模型的权重。
损失函数的选择依赖于特定任务和所需实现的度量值。许多损失函数(如 MSE、MAE、RMSLE 等)专注于根据特征向量预测变量的期望值。
在本文中,我们将查看一种特殊的损失函数,称为 分位损失,用于预测特定变量分位数。在深入分位损失的细节之前,让我们简要回顾一下分位数的术语。
分位数
分位数 qₐ 是一个值,它将给定的数字集合划分为使得 α ** 100%* 的数字小于该值,而 (1 — α) * 100% 的数字大于该值。
分位数 qₐ 对于 α = 0.25、α = 0.5 和 α = 0.75 经常在统计中使用,被称为 四分位数。这些四分位数分别表示为 Q₁、Q₂ 和 Q₃。三个四分位数将数据分成 4 个相等的部分。
类似地,有 百分位数 p,将给定的数字集合分成 100 个相等的部分。一个百分位数表示为 pₐ,其中 α 是小于相应值的数字百分比。
四分位数 Q₁、Q₂ 和 Q₃ 分别对应于百分位数 p₂₅、p₅₀ 和 p₇₅。
在下面的例子中,找到了给定数字集合的所有三个四分位数。
一个示例显示了给定数字集合的所有三个四分位数。第一个四分位数 Q₁ 等于 10,因为 25%的值小于 10,而 75%的值大于 10。类似地,这种类推适用于其他四分位数。
分位损失
旨在预测特定变量分位的机器学习算法使用分位损失作为损失函数。在进行表述之前,让我们考虑一个简单的例子。
假设有一个问题,其目标是预测一个变量的第 75 百分位。实际上,这句话等同于预测误差在 75%的情况下必须为负,而在其余 25%的情况下必须为正。这就是在分位损失背后的实际直觉。
表述
分位损失公式如下所示。α 参数指的是需要预测的分位数。
分位损失公式
分位损失的值取决于预测值是否小于或大于真实值。为了更好地理解其背后的逻辑,假设我们的目标是预测第 80 分位数,因此将 α = 0.8 插入公式。结果,公式看起来是这样的:
基本上,在这种情况下,分位损失对低估的预测处罚是对高估的预测的 4 倍。这样,模型对低估的误差会更加严格,并且更频繁地预测较高的值。结果是,拟合的模型在平均情况下会高估结果,大约在 80%的情况下,并且在 20%的情况下会产生低估。
现在假设为同一目标获得了两个预测。目标值为 40,而预测值为 30 和 50。让我们计算两种情况的分位损失。尽管在两种情况下绝对误差都是 10,但损失值却不同:
-
对于 30,损失值为l = 0.8 * 10 = 8
-
对于 50,损失值为l = 0.2 * 10 = 2。
下面的图表显示了在真实值为 40 时,不同α参数的损失值。
相反,如果α值为 0.2,则高估的预测将比低估的预测受到 4 倍的惩罚。
预测某个变量分位数的问题称为分位数回归。
示例
我们将创建一个包含 10,000 个样本的合成数据集,其中玩家在视频游戏中的评分将基于游戏时长来估计。
数据集生成
预测变量(小时)与目标(评分)之间的散点图
我们将数据按 80:20 的比例拆分为训练集和测试集:
按 80:20 比例拆分数据集
为了比较,我们将建立三个具有不同α值的回归模型:0.2、0.5 和 0.8。每个回归模型将由 LightGBM 创建——这是一个高效实现梯度提升的库。
根据官方文档的信息,LightGBM 允许通过将objective参数指定为‘quantile’并传递相应的alpha值来解决分位数回归问题。
训练了 3 个模型后,可以用它们来获得预测(第 6 行)。
训练 LGBM 模型,目标 = ‘quantile’
我们通过下面的代码片段来可视化预测:
预测变量(小时)与真实/预测目标值之间的散点图
从上面的散点图可以看出,随着α值的增大,模型生成的结果往往会更加高估。此外,我们还将比较每个模型的预测与所有目标值。
不同模型的预测比较
这将产生以下输出:
从输出中可以清楚地看到:对于任何α,预测值在大约α * 100%的情况下都高于真实值。因此,我们可以实验性地得出结论,我们的预测模型工作正常。
分位数回归模型的预测误差在大约α ** 100%的情况下为负值,而在(1 — α) * 100%的情况下为正值。
结论
我们发现了分位数损失——一种灵活的损失函数,可以纳入任何回归模型中以预测某个变量的分位数。以 LightGBM 为例,我们看到如何调整模型以解决分位数回归问题。事实上,许多其他流行的机器学习库也允许将分位数损失设置为损失函数。
本文中使用的代码可在 GitHub 上获取。
[## ML-medium/quantile_regression.ipynb at master · slavastar/ML-medium
本库包含了我在 Medium 博客上的数据科学文章中的可重现代码……
资源
除非另有说明,否则所有图片均由作者提供。
量化及其他:将 LLMs 的推理时间减少 80%
原文:
towardsdatascience.com/quantisation-and-co-reducing-inference-times-on-llms-by-80-671db9349bdb
·发布于Towards Data Science ·12 分钟阅读·2023 年 10 月 27 日
--
来源:www.pexels.com/photo/cropland-in-autumn-18684338/
量化是一种用于多种算法的技术,但随着最近大型语言模型(LLMs)的涌现,这一技术变得越来越普及。在本文中,我旨在提供有关 LLMs 量化的信息,以及这一技术对在本地运行这些模型的影响。我将探讨量化之外的另一种策略,这可以进一步减少运行这些模型的计算需求。我将解释这些技术为何可能对你有兴趣,并展示一些带有代码示例的基准测试,以说明这些技术的有效性。我还简要介绍了硬件需求/建议以及实现 LLM 目标的现代工具。在后续文章中,我计划提供逐步说明和代码,用于微调你自己的 LLM,请留意。
TL;DR — 通过量化我们的 LLM 并更改张量dtype,我们能够在具有 2 倍参数的 LLM 上进行推理,同时将Wall time减少 80%。
一如既往,如果你希望讨论我在这里提到的任何内容,请联系我。
本文中的所有观点均为我个人意见。本文未获得赞助。
量化(LLMs 的量化)是什么?
量化使我们能够通过将网络的权重和偏差从原始的浮点格式(例如 32 位)转换为更低精度的格式(例如 8 位),来减少神经网络的大小。原始的浮点格式可以根据模型的架构和训练过程有所不同。量化的最终目的是减少模型的大小,从而减少内存和计算需求,以运行推理和训练模型。如果你尝试自己量化模型,量化过程可能会变得非常繁琐。这主要是因为某些供应商的硬件不支持。幸运的是,可以通过使用特定的第三方服务和软件来绕过这些限制。
就我个人而言,我在 Mac 上量化 Meta 的 Llama-2 等 LLMs 时遇到了不少麻烦。这主要是由于对标准库(或任何带有自定义 CUDA 内核的内容)的支持不足。不过,optimum 和 onnx 等第三方工具确实存在,可以使我们的生活稍微轻松一些。
快速且简单的选项是下载HuggingFace(HF)上提供的任何预量化模型。我特别想提到TheBloke,他们提供了许多流行 LLMs 的量化版本,包括本文中将演示的 LLama-2 模型。有关如何对每个模型进行推理的说明可以在各自的模型卡上找到。
如果你想自己运行量化模型但没有自己的 GPU,我建议在以下网站租用 NVIDIA 硬件:
· Vast.ai— 免责声明 — 使用时请自行判断。这里实际上是租用随机人的 GPU。建议在使用此服务时不要共享任何敏感信息。不过,它确实非常便宜。
如果你想购买 NVIDIA 硬件并希望获得最佳性价比,我建议购买 2 个二手 RTX3090。虽然更新的 RTX4090 具有更好的基准性能,但大型语言模型(LLMs)更需要高内存读写速度,而不是更高的处理器速度。3090 和 4090 型号在内存读写速度上差异不大,因此我认为较旧的型号提供了更好的价值。
如果你有预算,选择几乎没有限制。
作为免费的选项,我建议:
· Google colab— 在运行时提供免费的 GPU,但有一定的限制(免费层的 RAM 也有限制,不过你可以付费获得更多)。
· Kaggle 也在他们的笔记本中提供 GPU。
如果你坚持使用 Mac 硬件,我建议使用 M2 Ultra,并配备尽可能多的 RAM(理想情况下为 64GB 以上)。这仍会比上述 NVIDIA 选项慢,但如果你只是希望在 LLM 上运行推理而不是训练自己的模型,这绝对是可行的。如果你在 Mac 硬件上量化自己的模型时遇到问题,我只能推荐Georgi Greganov 的 llama.cpp。使用这个 repo,你可以下载并编译 Meta 的 llama 2 模型为 C++并将其量化为 4 位精度。然后我们可以在这些模型上运行推理。该 repo 的 README 提供了清晰的操作说明。
那么,为什么我们要在本地运行/托管自己的 LLM 呢?
简短的答案是,一如既往,这取决于情况。撰写本文时,OpenAI 的 GPT4(通过 ChatGPT 提供)被广泛认为是表现最好的 LLM。我认为定价也非常合理,而且模型本身无疑比我上述提到的策略更容易互动。你唯一需要安装的依赖是你的账户信息和信用卡号码;)。
我确实认为运行本地 LLM 有很强的理由:
询问关于专有文档/数据的问题。
你可以使用自己的上下文和数据对 LLM 进行微调。通过自己进行这些操作,你不会将任何信息分享给第三方,这是一大优势。
询问关于 2021 年 9 月知识截止后的主题 (GPT4)。
我看到了一些 GPT4 在此时间段之后提供详细信息的情况,但模型经常指出知识截止存在。
对模型进行微调以解决特定于你场景的问题。
再次回到第一点,你可以调整自己的 LLM 模型以适应你的需求。
你可以看到这些 LLM 的底层工作原理。你可以检查模型架构,并进一步发展你对技术的理解。
这很免费(前提是你已经有自己的硬件,并且不计算运行所需的电费)
量化最终将帮助你在本地运行 LLM,使用比在未量化模型上运行推理时更少的计算资源。
基准比较 Llama-2
我现在将演示量化对 Meta 的 Llama-2 7B 和 13B 模型的影响。我在租用的 GPU 上进行了这些实验,如上所述,但也在 Google Colab 笔记本中进行了测试,以确认结果是可重复的。我们唯一需要做的编辑是运行 7B 参数模型的 8 位量化版本作为我们在 Colab 笔记本中的基准,否则在运行推断时会超出内存限制(在我看来,这已经充分证明了在运行 LLMs 时使用量化的理由!)。不过可以跟着代码示例操作 — 这些代码示例直接来自于我的免费 Colab 笔记本版本。
如果你使用的是 colab 笔记本 — 当安装诸如accelerate和bitsandbytes等依赖项时,请在笔记本中使用常规的 pip 安装。安装完成后,重启运行时。否则,软件包将无法被识别。另外,不要忘记将运行时更改为 GPU,选择运行时 > 更改运行时类型 > T4 GPU。
我应该补充说明,这一切的前提是你已经获得了 Meta 和 HF 的模型访问权限。为此,你必须首先通过此链接向 Meta 提交申请表:
ai.meta.com/resources/models-and-libraries/llama-downloads/
接受确认的时间可能从 2 分钟到 2 天不等。请注意,你用于 Meta 表单和 HF 账户的电子邮件地址必须匹配,以便通过 HF API 使用模型。
一旦确认已收到,你可以登录到 Hugging Face 并开始使用模型。
让我们开始吧。
在 8 位量化下对 Llama2–7B 基础模型进行推断。
首先处理我们的导入 — 在此阶段,如果你遇到任何错误消息,请根据需要运行 pip 安装 — 安装完成后不要忘记重启你的运行时(如上所述)。
from transformers import AutoModelForCausalLM,AutoTokenizer
import torch
from accelerate import Accelerator
接下来我们从 Hugging Face 复制模型名称,以便 API 可以下载相关模型。我们还需要输入我们的 HF 访问令牌。可以通过在 HF 网站的右上角选择你的个人资料 > 设置 > 访问令牌 > 生成令牌或复制现有令牌来找到它。
model_name = "meta-llama/Llama-2-7b-hf"
hf_key = "insertyourkeyhere"
现在让我们下载模型:
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=0, load_in_8bit=True, token=hf_key)
在这里我们使用device_map参数来选择我们希望使用的硬件。在这种情况下,它选择了第一个可用的 GPU。这里可以加载自定义的device_maps,但这超出了本文的范围。还要注意load_in_8bit参数。这是我们为了减少推断运行时的内存需求而进行的量化步骤。如果你希望使用 LLMs 构建更大的项目/产品,这个简单的技术可以用于在资源有限的设备(边缘设备或手机等)上部署模型。
接下来我们设置我们的分词器:
tokeniser = AutoTokenizer.from_pretrained(model_name, token=hf_key)
prompt = "A great hobby to have is "
toks = tokeniser(prompt, return_tensors="pt")
输入你希望的提示。我们使用的基础模型经过了文本补全的训练。
现在让我们对标记化的提示进行推理。如果有任何语法对你来说是新的,请随意查看 HF 文档。本质上,我们正在解包toks对象的内容并将其传递给我们的 GPU。输出限制为最多 15 个标记(如果愿意,可以编辑此参数)。model.generate()方法用于使用我们的 Llama2 模型生成输出。完成后,我们再次将输出传输到 CPU 内存中,以便查看我们的输出。
%%time
res = model.generate(**toks.to("cuda"), max_new_tokens=15).to('cpu')
res
# OUPUT
# CPU times: user 7.47 s, sys: 1.17 s, total: 8.64 s
# Wall time: 16.4 s
让我们分解这些时间指标,以更好地理解我们所看到的内容。CPU 时间分解为三个主要组成部分:
1. user — 这表示在用户模式代码中花费的时间。换句话说,就是 CPU 执行 Python 代码所需的时间。在这种情况下,花费了 7.47 秒。这个指标通常也被称为用户时间。
2. sys — 这表示在系统调用或内核模式代码中花费的 CPU 时间。它是 CPU 执行操作系统相关任务的时间。在我们的案例中是 1.17 秒。
3. total — 是用户时间和系统时间的总和。
接下来是墙时间。这指的是运行我们代码块所需的‘现实世界’时间。
CPU 时间和墙时间(7.76 秒)之间的差异是由于运行模型推理时涉及的其他内存密集型操作。这些操作包括但不限于 GPU 内存传输、调度、I/O 操作等。
让我们解码结果以查看模型的输出:
tokeniser.batch_decode(res)
# OUTPUT
# ['<s> A great hobby to have is 3D printing.\nYou can make whatever you want in 3D']
太棒了。我们已经成功地在一个基础量化 LLM 上进行了推理。
我们可以进一步使用的一种技术是将不同的数据类型分配给我们在计算过程中使用的 Llama2 模型中的张量。在之前,我们通过使用load_in_8bit=True参数来量化模型的参数,而现在我们将使用torch_dtype=torch.bfloat16参数,以减少模型在推理过程中的内存使用。这第二种方法不被认为是量化技术,因为它只改变了模型张量使用的数据类型,而第一种方法涉及通过在加载期间将模型参数的精度降低到 8 位来进行量化。
这两种方法被认为是减少运行我们的 LLM 的计算需求的有效技术。让我们看看第二种技术的有效性如何。
让我们用新的参数更新我们的模型:
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=0, torch_dtype=torch.bfloat16)
此阶段 colab 可能会提示内存不足。只需通过选择 Runtime > Restart runtime 来重启运行时,并重新运行笔记本中的所有相关单元格。
现在我们用更新后的张量数据类型在模型上运行推理:
%%time
res = model.generate(**toks.to("cuda"), max_new_tokens=15).to('cpu')
res
# OUTPUT
# CPU times: user 1.65 s, sys: 440 ms, total: 2.09 s
# Wall time: 4.7 s
哇。通过调整张量的数据类型,我们将总 CPU 时间减少了 6.66 秒。我们的墙时间减少了约 71%。让我们解码输出,看看是否注意到任何数据类型的变化影响:
tokeniser.batch_decode(res)
# OUTPUT
# ['<s> A great hobby to have is 3D printing. It’s a fun way to create new things,']
我们可以使用各种指标和测试来评估和比较模型的输出。在这篇文章中,我将简单地采用人工评估。两个输出都合格、连贯且相关。考虑到在我们的第二个示例中墙上时间减少了 71%,我认为我们目前的技术是成功的。
让我们看看我们能多快在预量化的 Llama2–7B 模型上运行推理。
在具有更新张量数据类型的预量化 Llama2–7B 上进行推理。
托TheBloke的福,我们能够访问 Meta 的 Llama-2 模型的预量化版本。有关量化过程的详细信息可以在模型卡中找到。
我们将使用相同的张量数据类型技术,这种技术使我们在墙上时间上取得了显著的减少。这次使用预量化模型。
让我们更新模型:
model_name = 'TheBloke/Llama-2-7b-Chat-GPTQ'
名称末尾的 Q 表示模型已经完成了量化处理。
现在我们下载具有更新张量数据类型的模型:
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=0, torch_dtype=torch.float16)
更新分词器:
tokeniser = AutoTokenizer.from_pretrained(model_name, token=hf_key)
prompt = "A great hobby to have is "
toks = tokeniser(prompt, return_tensors="pt")
运行推理:
%%time
res = model.generate(**toks.to("cuda"), max_new_tokens=15).to('cpu')
res
# OUTPUT
# CPU times: user 1.44 s, sys: 351 ms, total: 1.79 s
# Wall time: 4.33 s
我们进一步进行了改进。正如你所见,总 CPU 时间减少了约 14%。墙上时间减少了约 8%。
让我们检查输出:
tokeniser.batch_decode(res)
# OUTPUT
# ['<s> A great hobby to have is 3D printing.\n 3D printing is a fascinating hob']
现在很明显,由于我们的令牌限制设置为 15,最终的词被裁剪了。我确认我增加了令牌限制,最终的词被评估为 hobby。在人工验证方面,我仍然认为这是合格的。
现在让我们结合我们学到的所有内容,并对更大的 Llama-2–13B 模型进行推理。该模型的参数数量几乎是我们之前测试的模型的 2 倍。我们将与我们训练的第一个模型(基础的 Llama-2–7B,使用 8 位量化)进行基准测试,看看两者的比较情况。
在具有更新张量数据类型的预量化 Llama2–13B 上进行推理。
我们将使用相同的语法,但当然会更新模型名称。
model_name = 'TheBloke/Llama-2-13B-GPTQ'
下载具有更新张量数据类型的模型:
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=0, torch_dtype=torch.float16)
更新分词器:
tokeniser = AutoTokenizer.from_pretrained(model_name, token=hf_key)
prompt = "A great hobby to have is "
toks = tokeniser(prompt, return_tensors="pt")
运行推理:
%%time
res = model.generate(**toks.to("cuda"), max_new_tokens=15).to('cpu')
res
# OUTPUT
# CPU times: user 1.45 s, sys: 167 ms, total: 1.61 s
# Wall time: 3.22 s
让我们将其放入上下文中:
推理时间 Meta-Llama-2–7B(8 位量化)与预量化的 LLama-2–13B(使用 float16 张量)
我们几乎将参数数量翻倍(从 7B 增加到 13B)。我们将总 CPU 时间减少了 81%,墙上时间减少了 80%。我不会撒谎,我对这个结果非常满意。
让我们获取输出:
tokeniser.batch_decode(res)
# OUTPUT
# ['<s> A great hobby to have is 3D printing. It is a great way to create things that you want']
我们不仅通过减少计算需求大幅缩短了推理时间,而且我认为 13B 模型的输出也比我们运行推理的第一个 7B 模型更为连贯。
我希望这篇文章能向你展示这些技术在大幅度减少这些大语言模型(LLMs)推理时间方面的有效性。在我们的第一个例子中,甚至在没有应用我们自己的量化方法之前,模型是无法加载到笔记本中的。通过使用这些技术,我们能够部署一个更大的大语言模型(参数数量),将推理时间减少大约 80%,并且改善输出。如果这还不是一个积极的结果,我不知道什么是了!
我很高兴讨论和交流这里涉及的任何话题。
所有图片属于作者,除非另有说明。
咖啡数据中的数量与质量
咖啡数据科学
我的实验数据收集
·
关注 发表在 Towards Data Science ·5 min read·Mar 24, 2023
--
在咖啡中,味道是至高无上的,但使用总溶解固体 (TDS) 来测量提取效率的量化数据已经成为评估硬件和技术的有用工具。TDS 是通过折射仪来测量的,通常偏好使用数字折射仪。
在过去一年里,数字折射仪的成本显著下降。DiFluid 推出了两款价格远低于标准的 VST 或 Atago 的折射仪。目前的数据表明,DiFluid R2 与 VST 或 Atago 一样有效。我认为这款折射仪提出了一个有趣的问题,即咖啡数据的可获取性越来越高:在数据收集过程中,质量还是数量更为重要?
一些数据的示例,所有图片均由作者提供。
总结一下,折射仪可以测量总溶解固体(TDS),这是理解咖啡浓度和计算萃取效率的绝佳指标。它在我的探索中已成为一个重要工具。
需要明确的是,我只制作高浓度的浓缩咖啡(12% 到 20% TDS,16% 到 24% EY),而折射仪可能在处理如过滤咖啡等低浓度饮品时面临其他挑战。然而,我不会讨论这些话题。
尽管我拥有三台数字折射仪:Atago、DiFluid 和 DiFluid R2,但我尚未发布如何使用折射仪进行 TDS 测量的常规操作。我一直在通过多次探索,以数据证明我的某些操作是否相对于收集样品所花费的时间是有价值的,例如:
-
冷却样品至指定温度(通常与校准温度相同)。
-
过滤样品时使用注射器过滤器
-
每个样品使用新的移液管
-
用酒精清洁玻璃样品盘
-
每个样品都需校准设备
包含所有步骤的常规操作可能需要一段时间,而且这样会减少在相同时间内收集的数据量。
协议遵守
我做数据相关工作已经很长时间了(超过十年)。在用户研究中通常会出现的问题之一是数据质量与数量的关系。为了获得更好的数据质量,一个人必须更好地遵守协议,但遵守协议需要时间。然而,一旦应用机器学习算法,数据中仍会添加一定量的噪音。结果显示,即使质量较低的更多数据,对于某些实验来说也可能更可取,因为我们没有整天的时间去做所有事情。
所以即使信号中存在噪音,快速收集更多样品可以使噪音被平均。我也喜欢将这一点应用到咖啡上,因为我还有其他生活中的事情要做。
并不是每个人都有实验室、资金和时间来控制所有变量。所以尽可能控制。即使有噪音,只要噪音是稳定的,那么它比随机噪音更可控。最糟糕的是噪音中的系统偏差。
咖啡样品
另一个需要考虑的因素是,咖啡的折射测量尚不完全明确。我们知道折射率与 TDS 之间有联系,但仍有一些灰色地带。糖水有非常清晰的折射率,但通过光学折射仪观察,咖啡并没有那么明显的分界线。
从咖啡开始到结束的溶解物是否造成相同的折射?
任何给定的咖啡饮品有多均匀?
基本上,假设折射仪是完美的,那么固有噪声是多少?如果这个噪声比其他协议步骤大得多,那么这些步骤应重新考虑。
其他需要考虑的问题:
-
数字折射仪是否会出现校准漂移?
-
这些设备是否也能优雅地老化?
-
如果两个样品在比校准更高的温度下采集,这重要吗?温度是否以相同的方式影响读数,还是一个受控变量?
DiFluid 设备很有趣,因为它们还输出折射率。这有助于显示读数是否由于温度变化或其他因素造成。
当前常规
我将分享我目前的常规,但这可能会有所变化。这一常规是数据驱动的,以下是简短版本,后跟长版本及其理由:
-
设备:DiFluid R2
-
校准:我很少校准我的设备。
-
样品采集:我搅拌样品并使用移液管进行采集。之后我会冲洗并重复使用移液管。
-
样品过滤:我不过滤我的样品。
-
样品温度:我没有对样品温度进行校正。
-
样品数量:1
-
清洁镜头:我使用微纤维毛巾。
长格式:
-
R2 至少与 Atago 一样准确,数据表明它更准确,可能比 VST 更准确。它还比 Atago 提供的读数更快。
-
校准:我很少校准我的设备。我没有测试校准漂移,但如果有漂移,它应该会对我的所有样品产生相同的影响并进行平均。如果其他数据已经产生,我愿意调整我的常规。
-
样品采集:我搅拌样品并使用移液管进行采集。我不喜欢浪费移液管,所以之后我会冲洗它们并继续使用,直到我决定更换。对于折射仪上的糖测试,我使用新的移液管。目前还不确定这对样品的影响有多大。
-
样品过滤:我不过滤样品。证据表明过滤样品不会提高准确性,只能提高精确度。我通常会收集比需要更多的样品来弥补精确度。
-
样品温度:我没有对样品温度进行校正。我查看了样品温度,发现冷却样品与使用热样品时有一个小但统计上显著的变化。然而,只要我对所有样品都做相同的处理,这个变量就不会影响结论,因为性能是相对的。奇怪的是,我最近一直在做提取冷却,所以我的样品比以前冷得多。
-
样本数量:一个。我对收集更多样本不感兴趣,但我在 过去 曾展示过,如果你把样本留在设备上几分钟,它会蒸发,读取结果会发生变化。我也不确定多次取样是否会提高质量。
-
清洁镜头:我使用微纤维毛巾。我不使用酒精或酒精擦拭布。如果你留意,玻璃会清洁得相当干净。
我希望例行程序不会阻碍你收集数据。最终,数据分析会告诉你是否需要改进数据收集,因为关键因素是你收集和分析数据的经验。
如果你愿意,可以关注我的 Twitter、YouTube 和 Instagram,我会在这些平台上发布不同机器上的浓缩咖啡镜头和相关内容。你也可以在 LinkedIn 找到我。你还可以在 Medium 关注我并 订阅。
我的进一步阅读
使用 GGUF 和 llama.cpp 对 Llama 模型进行量化
GGML 与 GPTQ 与 NF4
·
关注 发表在 Towards Data Science · 9 分钟阅读 · 2023 年 9 月 4 日
--
作者提供的图片
由于大型语言模型(LLMs)的庞大规模,量化已成为高效运行这些模型的关键技术。通过降低权重的精度,可以节省内存并加速推理,同时保持模型的绝大部分性能。近期,8 位和 4 位量化技术使在消费级硬件上运行 LLMs成为可能。加上 Llama 模型的发布及其参数高效微调技术(如 LoRA、QLoRA),这形成了一个丰富的本地 LLM 生态系统,与 OpenAI 的 GPT-3.5 和 GPT-4 竞争。
除了在这篇文章中讨论的天真方法外,还有三种主要的量化技术:NF4、GPTQ 和 GGML。NF4是一种静态方法,QLoRA 使用它以 4 位精度加载模型进行微调。在上一篇文章中,我们探讨了 GPTQ 方法,并将自己的模型量化以在消费级 GPU 上运行。在这篇文章中,我们将介绍 GGML 技术,看看如何量化 Llama 模型,并提供实现最佳结果的技巧和窍门。
你可以在Google Colab和GitHub上找到代码。
什么是 GGML?
GGML 是一个专注于机器学习的 C 语言库。它由 Georgi Gerganov 创建,"GG"就是这个名字的缩写。这个库不仅提供了机器学习的基础元素,如张量,还提供了独特的二进制格式以分发 LLM。
这个格式最近更改为GGUF。这个新格式旨在可扩展,以便新功能不会破坏与现有模型的兼容性。它还将所有元数据集中在一个文件中,如特殊令牌、RoPE 缩放参数等。简而言之,它解决了一些历史痛点,并且应该具有未来保障。有关更多信息,你可以阅读这个地址上的规范。在本文的其余部分,我们将称所有使用 GGUF 或以前格式的模型为“GGML 模型”。
GGML 旨在与llama.cpp库一起使用,该库也由 Georgi Gerganov 创建。这个库是用 C/C++编写的,以便高效推断 Llama 模型。它可以加载 GGML 模型并在 CPU 上运行。最初,这与 GPTQ 模型的主要区别在于 GPTQ 模型是在 GPU 上加载和运行的。然而,现在你可以通过 llama.cpp 将一些 LLM 的层卸载到 GPU 上。举个例子,7b 参数模型有 35 层。这大大加快了推断速度,并允许你运行那些无法完全装入 VRAM 的 LLM。
图片由作者提供
如果命令行工具是你的强项,llama.cpp 和 GGUF 支持已集成到许多 GUI 中,如 oobabooga 的 text-generation-web-ui、koboldcpp、LM Studio 或 ctransformers。你可以使用这些工具加载你的 GGML 模型,并以类似 ChatGPT 的方式与其互动。幸运的是,许多量化模型直接可以在 Hugging Face Hub 上获得。你会很快注意到,大多数模型都是由 LLM 社区中的知名人物 TheBloke 进行量化的。
在下一部分,我们将看到如何量化我们自己的模型并在消费级 GPU 上运行它们。
如何用 GGML 对 LLM 进行量化?
我们来看看 TheBloke/Llama-2–13B-chat-GGML 仓库中的文件。我们可以看到14 种不同的 GGML 模型,对应不同的量化类型。它们遵循特定的命名规则:“q” + 存储权重的位数(精度)+ 特定变体。以下是所有可能的量化方法及其对应的使用场景,基于 TheBloke 制作的模型卡:
-
q2_k
:对 attention.vw 和 feed_forward.w2 张量使用 Q4_K,对其他张量使用 Q2_K。 -
q3_k_l
:对 attention.wv、attention.wo 和 feed_forward.w2 张量使用 Q5_K,其余张量使用 Q3_K -
q3_k_m
:对 attention.wv、attention.wo 和 feed_forward.w2 张量使用 Q4_K,其余张量使用 Q3_K -
q3_k_s
:对所有张量使用 Q3_K -
q4_0
:原始量化方法,4 位。 -
q4_1
:比 q4_0 精度更高,但不如 q5_0 高。不过推理速度比 q5 模型快。 -
q4_k_m
:对一半的 attention.wv 和 feed_forward.w2 张量使用 Q6_K,其余张量使用 Q4_K -
q4_k_s
:对所有张量使用 Q4_K -
q5_0
:更高的精度,更高的资源使用和较慢的推理。 -
q5_1
:更高的精度,更高的资源使用和更慢的推理。 -
q5_k_m
:对一半的 attention.wv 和 feed_forward.w2 张量使用 Q6_K,其余张量使用 Q5_K -
q5_k_s
:对所有张量使用 Q5_K -
q6_k
:对所有张量使用 Q8_K -
q8_0
:几乎无法与 float16 区分。高资源使用和较慢。不推荐给大多数用户。
作为经验法则,我推荐使用 Q5_K_M,因为它保留了大部分模型的性能。或者,如果你想节省一些内存,可以使用 Q4_K_M。一般来说,K_M 版本比 K_S 版本更好。我不推荐 Q2 或 Q3 版本,因为它们会大幅降低模型性能。
既然我们对可用的量化类型有了更多了解,接下来我们来看看如何在真实模型中使用它们。你可以在Google Colab上的免费 T4 GPU上执行以下代码。第一步是编译 llama.cpp 并在我们的 Python 环境中安装所需的库。
# Install llama.cpp
!git clone https://github.com/ggerganov/llama.cpp
!cd llama.cpp && git pull && make clean && LLAMA_CUBLAS=1 make
!pip install -r llama.cpp/requirements.txt
现在我们可以下载我们的模型了。我们将使用我们在上一篇文章中微调过的模型,[mlabonne/EvolCodeLlama-7b](https://huggingface.co/mlabonne/EvolCodeLlama-7b)
。
MODEL_ID = "mlabonne/EvolCodeLlama-7b"
# Download model
!git lfs install
!git clone https://huggingface.co/{MODEL_ID}
这一步可能需要一些时间。一旦完成,我们需要将权重转换为 GGML FP16 格式。
MODEL_NAME = MODEL_ID.split('/')[-1]
# Convert to fp16
fp16 = f"{MODEL_NAME}/{MODEL_NAME.lower()}.fp16.bin"
!python llama.cpp/convert.py {MODEL_NAME} --outtype f16 --outfile {fp16}
最后,我们可以使用一种或几种方法来量化模型。在这种情况下,我们将使用我之前推荐的 Q4_K_M 和 Q5_K_M 方法。这是唯一实际上需要 GPU 的步骤。
QUANTIZATION_METHODS = ["q4_k_m", "q5_k_m"]
for method in QUANTIZATION_METHODS:
qtype = f"{MODEL_NAME}/{MODEL_NAME.lower()}.{method.upper()}.gguf"
!./llama.cpp/quantize {fp16} {qtype} {method}
我们的两个量化模型现在已准备好进行推断。我们可以检查 bin 文件的大小,以查看我们压缩了多少。FP16 模型占用 13.5 GB,而 Q4_K_M 模型占用 4.08 GB(小了 3.3 倍),Q5_K_M 模型占用 4.78 GB(小了 2.8 倍)。
让我们使用 llama.cpp 来高效运行它们。由于我们使用的是具有 16 GB VRAM 的 GPU,我们可以将每一层都卸载到 GPU 上。在这种情况下,它代表 35 层(7b 参数模型),因此我们将使用 -ngl 35
参数。在下面的代码块中,我们还将输入一个提示和我们想使用的量化方法。
import os
model_list = [file for file in os.listdir(MODEL_NAME) if "gguf" in file]
prompt = input("Enter your prompt: ")
chosen_method = input("Name of the model (options: " + ", ".join(model_list) + "): ")
# Verify the chosen method is in the list
if chosen_method not in model_list:
print("Invalid name")
else:
qtype = f"{MODEL_NAME}/{MODEL_NAME.lower()}.{method.upper()}.gguf"
!./llama.cpp/main -m {qtype} -n 128 --color -ngl 35 -p "{prompt}"
让我们使用 Q5_K_M 方法向模型提出“编写一个 Python 函数来打印第 n 个斐波那契数”的请求。如果我们查看日志,我们可以确认我们成功将层卸载了,感谢“llm_load_tensors: offloaded 35/35 layers to GPU”这一行。这里是模型生成的代码:
def fib(n):
if n == 0 or n == 1:
return n
return fib(n - 2) + fib(n - 1)
for i in range(1, 10):
print(fib(i))
这不是一个非常复杂的提示,但它很快生成了一个有效的代码片段。使用 llama.cpp,你可以在终端中使用交互模式(-i
标志)将本地 LLM 作为助手。注意,这也适用于配备苹果 Metal 性能着色器(MPS)的 Macbook,这是一种运行 LLM 的优秀选择。
最后,我们可以将我们的量化模型推送到 Hugging Face Hub 上的新仓库,并使用“ -GGUF”后缀。首先,让我们登录并修改以下代码块以匹配你的用户名。
!pip install -q huggingface_hub
username = "mlabonne"
from huggingface_hub import notebook_login, create_repo, HfApi
notebook_login()
最后,我们可以将量化后的模型推送到 Hugging Face Hub 上的新仓库,并使用 "-GGUF" 后缀。首先,让我们登录并修改以下代码块以匹配你的用户名。你可以在 Google Colab 的“Secrets”选项卡中输入你的 Hugging Face 令牌(huggingface.co/settings/tokens
)。我们使用 allow_patterns
参数仅上传 GGUF 模型,而不是整个目录。
!pip install -q huggingface_hub
from huggingface_hub import create_repo, HfApi
from google.colab import userdata
# Defined in the secrets tab in Google Colab
hf_token = userdata.get('huggingface')
api = HfApi()
username = "mlabonne"
# Create empty repo
create_repo(
repo_id = f"{username}/{MODEL_NAME}-GGUF",
repo_type="model",
exist_ok=True,
token=hf_token
)
# Upload gguf files
api.upload_folder(
folder_path=MODEL_NAME,
repo_id=f"{username}/{MODEL_NAME}-GGUF",
allow_patterns=f"*.gguf",
token=hf_token
)
我们已经成功量化、运行并推送了 GGML 模型到 Hugging Face Hub!在下一部分,我们将深入探讨 GGML 如何实际量化这些模型。
使用 GGML 进行量化
GGML 量化权重的方式不如 GPTQ 那么复杂。基本上,它将值块分组并将它们舍入到较低的精度。一些技术,如 Q4_K_M 和 Q5_K_M,实现了对关键层的更高精度。在这种情况下,每个权重都以 4 位精度存储,除了注意力.wv 和 feed_forward.w2 张量的一半。实验结果表明,这种混合精度在准确性和资源使用之间提供了一个良好的折衷。
如果我们查看ggml.c 文件,可以看到块是如何定义的。例如,block_q4_0
结构定义如下:
#define QK4_0 32
typedef struct {
ggml_fp16_t d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
在 GGML 中,权重按块处理,每块包含 32 个值。对于每块,从最大权重值派生出一个缩放因子(delta)。然后对块中的所有权重进行缩放、量化,并高效打包以进行存储(nibbles)。这种方法显著减少了存储需求,同时允许在原始权重和量化权重之间进行相对简单且确定的转换。
现在我们对量化过程有了更多了解,我们可以将结果与 NF4 和 GPTQ 进行比较。
NF4 vs. GGML vs. GPTQ
哪种技术更适合 4 位量化?为了回答这个问题,我们需要介绍运行这些量化 LLM 的不同后端。对于 GGML 模型,使用 Q4_K_M 模型的 llama.cpp 是最佳选择。对于 GPTQ 模型,我们有两个选项:AutoGPTQ 或 ExLlama。最后,NF4 模型可以直接在 transformers 中通过 --load-in-4bit
标志运行。
Oobabooga 在一篇出色的博客文章中进行了多个实验,比较了不同模型在困惑度方面的表现(困惑度越低越好):
基于这些结果,我们可以说 GGML 模型在困惑度方面略有优势。差异并不特别显著,这也是为什么在生成速度(以 token/秒计)方面更值得关注。最佳技术取决于你的 GPU:如果你有足够的 VRAM 可以容纳整个量化模型,GPTQ 和 ExLlama 将是最快的。如果不是这样,你可以卸载一些层并使用GGML 模型和 llama.cpp 来运行你的 LLM。
结论
在这篇文章中,我们介绍了 GGML 库和新的 GGUF 格式,以高效地存储这些量化模型。我们使用它来量化我们自己的 Llama 模型为不同格式(Q4_K_M 和 Q5_K_M)。随后,我们运行了 GGML 模型,并将我们的二进制文件推送到了 Hugging Face Hub。最后,我们深入研究了 GGML 的代码,以理解它是如何实际量化权重的,并将其与 NF4 和 GPTQ 进行了比较。
量化是使大型语言模型(LLMs)变得更加普及的强大手段,通过降低运行成本来实现。在未来,混合精度和其他技术将不断提升我们使用量化权重时可以达到的性能。在此之前,希望你喜欢阅读这篇文章,并从中学到了新知识。
如果你对 LLMs 的更多技术内容感兴趣,请在 Medium 上关注我。
关于量化的文章
使用 8 位量化减少大型语言模型的体积
towardsdatascience.com ## 第二部分:使用 GPTQ 的 4 位量化
使用 AutoGPTQ 量化你自己的 LLMs
towardsdatascience.com
通过一次点击了解更多关于机器学习的信息并支持我的工作 — 立即成为 Medium 会员:
[## 通过我的推荐链接加入 Medium — Maxime Labonne
作为 Medium 会员,你的一部分会员费用将用于支持你阅读的作者,同时你可以完全访问每一个故事…
medium.com](https://medium.com/@mlabonne/membership?source=post_page-----3612dfbcc172--------------------------------)
量子计算机在人工智能和机器学习革命中的作用
·发表于 Towards Data Science ·阅读时间 7 分钟·2023 年 3 月 18 日
--
一个易于理解的量子计算机工作原理介绍,说明它们为何在演进 AI 和 ML 系统中至关重要。获得对这些机器背后量子原理的简单理解。
由作者使用 Microsoft Icons 创建的图像。
量子计算是一个迅速发展的领域,有能力彻底改变人工智能(AI)和机器学习(ML)。随着对更大、更好和更准确的 AI 和 ML 的需求加速,标准计算机将被推到其能力的极限。量子计算机基于并行化,并能够处理更复杂的算法,将是解锁下一代 AI 和 ML 模型的关键。本文旨在通过拆解一些实现量子计算的关键原理来揭示量子计算机的工作原理。
量子计算机是一种可以并行执行多个任务的机器,使其能够非常快速地解决极其复杂的问题。尽管传统计算机将继续满足普通人的日常需求,但量子计算机的快速处理能力有潜力彻底改变许多行业,远远超出传统计算工具的可能性。通过同时运行数百万次模拟,量子计算可以应用于,
-
化学和生物工程:复杂的模拟能力可能使科学家能够发现和测试新药物和资源,而无需耗费实验室实验的时间、风险和费用。
-
金融投资:市场波动极其难以预测,因为它们受大量复合因素的影响。这些几乎无限的可能性可以通过量子计算机进行建模,从而比标准机器提供更高的复杂性和更好的准确性。
-
操作与制造:一个给定的过程可能有成千上万的相互依赖的步骤,这使得制造中的优化问题变得复杂。由于有如此多的可能排列,模拟制造过程需要大量计算,通常需要假设以缩小可能性的范围以适应计算限制。量子计算机的固有并行性将使不受限制的模拟成为可能,从而在制造中实现前所未有的优化水平。
叠加态——固有的并行性
量子计算机依赖于叠加态的概念。在量子力学中,叠加态是指同时存在于多个状态的概念。叠加态的一个条件是它不能被直接观察,因为观察本身会迫使系统进入一个单一状态。在叠加态中,观察任何给定状态的概率是一定的。
对叠加态的直观理解
在 1935 年,物理学家厄尔温·薛定谔在给阿尔伯特·爱因斯坦的一封信中分享了一个思想实验,这个实验概括了叠加态的概念。在这个思想实验中,薛定谔描述了一只被封闭在一个容器中的猫,容器里有一个放射性原子,这个原子有 50%的几率衰变并释放出致命的辐射。薛定谔解释说,在观察者打开箱子并查看里面之前,猫是活着还是死去的概率是相等的。在箱子未被打开之前,可以认为猫同时存在于活着和死去的状态。打开箱子并观察猫的行为使其被迫进入一个单一的死或活的状态。
对叠加态的实验理解
1801 年,托马斯·扬进行的一个更具象的实验展示了叠加态,尽管叠加态的含义直到许多年后才被理解。在这个实验中,一束光线被照射到一个有两个狭缝的屏幕上。期望是每个狭缝后面会在板上出现一束光。然而,扬观察到了几个光强度增强的峰值和光强度减少的谷值,而不仅仅是两个光点。这一模式使扬得出结论,光子在通过屏幕上的狭缝时必须表现为波动现象。他得出这个结论是因为他知道,当两条波相互干涉时,如果它们都是峰值,它们会叠加在一起, resulting in a 增强的统一波(产生光点)。相反,当两条波处于相反位置时,它们会相互抵消(产生黑暗的谷值)。
双缝实验。左:如果光子仅作为粒子存在时的预期结果。右:实际结果表明光子可以作为波动存在。图像由作者创建。
虽然波粒二象性的结论一直存在,但随着技术的发展,这一实验的意义也发生了变化。科学家们发现,即使一次只发射一个光子,波动模式也会出现在后面的屏幕上。这意味着单个粒子通过了两个缝隙,并作为两波交汇。然而,当光子击中屏幕并被测量时,它显示为单独的光子。测量光子位置的行为迫使它重新组合为单一状态,而不是在通过屏幕时的多重状态。这一实验展示了叠加态。
双缝实验显示了叠加态,因为光子在测量发生之前存在于多种状态中。左:引入测量设备时的结果。右:没有测量时的结果。图像由作者创建。
叠加态在量子计算机中的应用
标准计算机通过操作二进制数字(比特)来工作,比特存储在两种状态中的一种,0 和 1。相比之下,量子计算机使用量子比特(qubits)。量子比特可以处于叠加态,因此它们不仅限于 0 或 1,而是同时存在于 0 和 1 以及若干种介于 0 和 1 之间的状态。这种状态的叠加使得量子计算机能够并行处理数百万个算法。
量子比特通常由如光子和电子等亚原子粒子构成,双缝实验确认这些粒子可以处于叠加态。科学家们通过激光或微波束将这些亚原子粒子强制进入叠加态。
约翰·戴维森用一个简单的例子解释了使用量子比特而非比特的优势。因为标准计算机中的一切都是由 0 和 1 构成的,当在标准机器上运行模拟时,机器会遍历不同的 0 和 1 序列(即比较 00000001 与 10000001)。由于量子比特同时存在于 0 和 1 状态,不需要尝试不同的组合。相反,单次模拟将同时包含所有可能的 0 和 1 组合。这种固有的并行性使得量子计算机能够同时处理数百万次计算。
纠缠 — 指数级的处理能力
在量子力学中,纠缠的概念描述了量子粒子相互作用并纠缠在一起的倾向,以至于它们无法孤立描述,因为一个粒子的状态受到另一个粒子状态的影响。当两个粒子纠缠在一起时,它们的状态是相互依赖的,无论彼此的距离如何。如果一个量子位的状态发生变化,配对的量子位状态也会瞬间变化。爱因斯坦惊叹地将这种距离无关的伙伴关系描述为“遥远的诡异行动”。
因为观察一个量子粒子会迫使其进入孤立状态,科学家们发现如果一个纠缠对中的粒子具有向上的自旋,则配对粒子将具有相反的、向下的自旋。尽管我们仍未完全理解这种现象的发生原因,但其对量子计算的影响已十分强大。
左:两个处于叠加态的粒子变得纠缠在一起。 右:一次观察迫使一个粒子进入向上的自旋。作为回应,配对粒子进入向下的自旋。即使这些粒子被距离分隔,它们仍然保持纠缠状态,并且它们的状态彼此依赖。图片由作者创作。
在量子计算中,科学家利用了这一现象。空间设计的算法在纠缠的量子位上工作,从而大幅加速计算。在标准计算机中,添加一个比特会线性增加处理能力。因此,如果比特数翻倍,处理能力也会翻倍。在量子计算机中,添加量子位会指数级增加处理能力。因此,添加一个量子位会极大地提高计算能力。
去相干 —— 量子错误
尽管纠缠为量子计算提供了巨大的优势,但实际应用却面临严峻的挑战。如前所述,观察一个量子粒子会迫使其进入一个特定的状态,而不是继续处于叠加态。在量子系统中,任何外部干扰(温度变化、振动、光线等)都可以被视为一种“观察”,迫使量子粒子假设一个特定的状态。由于粒子变得越来越纠缠和状态依赖,它们尤其容易受到外部干扰的影响。这是因为干扰只需影响一个量子位,就会对许多纠缠的量子位产生连锁反应。当一个量子位被强迫进入 0 或 1 状态时,它会丧失叠加态中包含的信息,导致算法完成前出现错误。这一挑战被称为去相干,已阻碍了量子计算机的实际使用。去相干的测量方式是错误率。
某些物理误差减少技术已被用来最小化外界干扰,包括将量子计算机保持在冰冻温度和真空环境中,但到目前为止,这些措施在量子误差率方面并未产生足够显著的改善。科学家们还在探索错误更正代码,以修复错误而不影响信息。虽然谷歌最近部署了一个错误更正代码,这导致了历史上最低的错误率,但信息丢失仍然过高,无法在实际中使用量子计算机。误差减少目前是物理学家的主要关注点,因为这是实际量子计算的最大障碍。
尽管还需要更多工作来实现量子计算机的实际应用,但显然有巨大的机会利用量子计算来部署高度复杂的人工智能和机器学习模型,从而提升各种行业。
快乐学习!
来源
叠加态: scienceexchange.caltech.edu/topics/quantum-science-explained/quantum-superposition
纠缠: quantum-computing.ibm.com/composer/docs/iqx/guide/entanglement
量子计算机: builtin.com/hardware/quantum-computing
量子计算完全初学者指南
量子计算基础指南,无需先前知识
·
关注 发表在 Towards Data Science · 25 分钟阅读 · 2023 年 10 月 20 日
--
一台 IBM 量子冷却器用于在纽约约克镇的 IBM 量子实验室中保持 IBM 的 50-qubit 量子计算机的低温。来源:www.flickr.com/photos/ibm_research_zurich/40786969122
一些人将人类对地球资源的统治称为人类世,这个词源自希腊语中的“anthropo”(人类)和“cene”(最近)。特别是过去一个世纪,由于 20 世纪中期计算机的出现带来了技术创新的速度,已被称为第四次工业革命。
在过去七十年里,计算已经改变了社会的各个方面,能够以加速的速度进行高效生产,将人类劳动从主要的生产领域转移到服务领域,并通过电信技术指数级提升了信息存储、生成和传输。
我们是如何到达这里的?从根本上说,技术进步依赖于现有的科学。如果没有对电磁学的性质和原子结构的理解,我们就不会有电力和驱动计算机的集成电路。因此,我们迟早会想到利用量子力学提供的物理现实的最准确、最基本的描述来进行计算。
我对量子计算产生了兴趣,这源于对物理学和计算本质的强烈兴趣。如果成功,量子计算可能会通过指数级提升当前计算机的效率,开启信息时代前所未有的一章。作为一个对数据、计算和信息科学感兴趣的人,理解量子信息的基本知识不仅会使你对量子物理有一个非常基础的理解,还会为你准备迎接信息时代下一个主要技术前沿做好准备。
量子现象与量子信息
为了理解计算的基本概念,有必要对量子计算所利用的物理现象有一个基本的了解。
这些现象包括电子自旋和光偏振,后者是光子自旋的另一种说法。请记住,电子是带负电的亚原子粒子,围绕带正电的原子核旋转,而光子则是电磁或光的粒子等价物。电子和光子的自旋是相关的,因为它们都指的是在经典力学中没有类比的量子属性,经典力学描述的是日常物体的尺度。
尽管如此,引入自旋最简单的方法是与经典属性角动量进行比较。角动量指的是经典系统中线动量的旋转等效量,其中动量被计算为质量和速度的乘积。因此,动量是一个向量量,因为它具有大小和方向。角动量表示为粒子的位矢和动量矢量的叉积。由于角动量是经典属性,它接受连续值,并且可以表示为体积分(从二维中的曲线下的面积推广而来)。
自旋通常定义为内在角动量。回忆一下,在经典力学中,力被定义为动量的变化。此外,系统的能量是以运动或运动变化率来定义的,这假设了质量。与经典力学不同,爱因斯坦的相对论通过等式 E = mc²将内在能量归于静止质量。同样,内在角动量与亚原子粒子的内在能量状态密切相关。事实上,这是一个基本粒子所具备的属性,无论它们是否实际旋转,即不考虑诸如动量和位置等外在因素,因此称为内在。像经典角动量一样,量子自旋在旋转下会发生变化。然而,与经典角动量不同,自旋是量子化的,这意味着它只接受一组离散的值。
基本粒子的最大自旋由 n(任何半整数 n/2 值)与约化普朗克常数ℏ(h/2𝜋)的乘积给出。所有普通粒子,称为费米子,都具有半整数(1/2)自旋,而力载体粒子,如光子,具有整数(1)自旋。电子和光子都有两种可能的自旋状态:自旋“向上”或“向下”。用数学术语来说,电子的最大自旋为 1/2ℏ或-1/2ℏ,即正向或负向的“旋转”。光子的最大自旋为 1ℏ和-1ℏ,因为它取整数自旋值。尽管我们使用了“旋转”一词,但最好不要将其视为空间变换。
现在让我们来看一下将被利用于量子计算的奇异量子性质。我们之前提到电子可以有两种可能的自旋状态,但它在任何给定时刻处于哪个状态呢?在这里区分系统的状态和测量是很有用的。在经典力学中,状态和测量完全一致:系统的状态就是你所测量的。在量子力学中却不是如此。没有测量的情况下,系统的状态由波函数𝛹i 的相干叠加给出。测量之后,如果我们测量一个单粒子,系统的状态将由𝛹↓或𝛹↑给出。这种状态和测量之间的分离使得量子计算机能够进行在二维复数向量空间中可以取无限值的操作。
最后,测量遵循某些规则,这些规则与测量的方式有关。具体来说,测量的方向对结果很重要。假设我们有两个方向:垂直和水平。如果我们在垂直方向上测量电子的自旋,我们将得到自旋向上或向下的状态。如果我们进行完全相同的测量,即再次在垂直方向上测量自旋,我们将得到相同的结果。这表明存在一种实验设置可以产生可预测的结果。然而,如果我们首先在垂直方向上测量电子的自旋,然后在水平方向上测量,并不断重复测量,结果将是自旋向上或向下的随机序列,在足够的试验中均匀分布。这意味着,如果我们不小心,量子测量可能会产生随机结果。量子算法的目的是控制操作,以便我们得到所期望的结果。
尽管量子比特(qubits),即量子计算的信息单位,可以通过电子或光子的自旋来表示,但我们将使用前者作为量子计算的物理类比。
从量子现象到量子计算
经典计算机的基本信息单位称为比特(bit),它具有两个离散状态,通常表示为 0 或 1。由于计算机是一个物理机器,这种数学抽象必须映射到某种物理现象上。经典计算机将这些离散状态映射到流动的电流或电压上。当电压较低或几乎为零时,我们用它来表示状态 0,当电压较高时,我们用它来表示状态 1。换句话说,电压幅度的调制使我们能够机械地实现一个二进制表示系统。这些低电压和高电压状态的序列随后被排列成模拟逻辑操作的电路,如与门(AND)、异或门(XOR)等,称为逻辑门。通过电路的逻辑操作组合随后被搭建以执行任何可计算的算法。
现在我们看到经典计算机利用电力来实现计算,这有助于我们理解量子计算机的工作方式。与经典计算机不同,量子计算机利用量子或亚原子尺度现象来进行计算。虽然我们在日常宏观尺度上测量电压时它是一个连续变量,但量子力学告诉我们,在亚原子尺度下情况并非如此。相反,根据所有实验数据,亚原子粒子似乎仅占据离散的能量状态。这意味着电子和光子可以占据某些能量状态,而不能占据其他状态。这与我们对物理对象能够占据任何连续能量状态的直觉相矛盾。例如,尽管我们通常认为时间是一个完全连续的变量,且可以无限分割,但对于亚原子粒子的能量状态,这种情况完全不同。
这有一个独特的结果,那就是亚原子粒子不能被描述为具有固定的位置和动量。虽然我们可以尝试同时描述这些变量,但存在一个物理尺度,在这个尺度下,精度会崩溃,使得知道动量就意味着失去位置的信息,反之亦然。这个物理尺度被称为普朗克尺度,用h表示:6.626070·10⁻³⁴ m²kg/s,代表了经典与量子尺度现象之间的物理阈值。在这个尺度下,根据所有实验证据,亚原子粒子同时占据所有可能的状态。由于这一特性,我们只能将亚原子粒子描述为所有可能状态的概率分布,如薛定谔方程所述。然而,正如我们之前指出的,还有一种称为测量的第二种描述。在测量之前,粒子存在于薛定谔波函数描述的叠加态中。测量之后,粒子会坍缩到一个确定的位置状态。量子计算利用了量子力学这一奇特特性来进行计算,即利用叠加和测量状态。(如果你想弄清楚叠加态客观性的实验基础,可以阅读这篇文章)。
因此,如果经典计算机通过两个可能的离散状态来构建计算,我们可以认为量子计算机则通过离散状态以及叠加态来构建计算。当我们测量量子比特时,它可以处于状态 0 或 1。然而,在测量之前,量子比特处于 0 和 1 的叠加态。在叠加态期间,量子比特可以占据无限多个状态。通过利用量子力学的法则,量子计算超越了经典计算的计算能力,经典计算的状态空间局限于 2^n。值得注意的是,测量将量子状态简化为经典状态,即相同的状态空间 2^n。那么,量子计算以什么方式提供了经典计算无法比拟的优势呢?
正如我们下面将看到的,量子算法使得在叠加态下进行受控操作成为可能,这使我们在测量后能够得到有用的答案。计算机科学家根据解决算法所需的时间步骤来定义算法的复杂性。如果n表示算法的输入长度,T(n)表示解决它所需的时间,那么复杂性指的是描述T(n)增长的函数。如果T(n)是多项式,则该算法被称为属于多项式时间类问题。如果T(n)是指数函数,则它属于指数时间类问题。那些属于指数时间的,例如大数的质因数分解,对于经典计算机来说是难以处理的,因为解决问题所需的时间呈指数级增长,并且很容易超出人类可接受的时间限制。
量子计算的前景部分在于能够足够快地解决指数时间问题。
表示量子比特:线性代数
为了理解量子计算,我们必须了解表示量子比特(qubits)的一些数学基础。这些数学工具需要与我们将要映射计算的基本现象相对应,主要是线性代数。
我们用二维单位向量来表示量子比特(Qubits)。
什么是向量?
向量是由至少两个值表示的量:大小和方向。向量的大小由欧几里得距离给出,而方向由起点给出。例如(1,-3)表示一个二维向量,其长度为 3.162,方向由 x 值给出。
什么是单位向量?
单位向量是长度或大小等于 1 的向量。例如,<0,1>是一个单位向量,因为如果我们使用毕达哥拉斯定理计算欧几里得距离,我们得到的值是 1。
为什么是二维单位向量?
由于存在两种可能的电子自旋测量结果,二维向量空间 ℝ² 足够。我们使用单位向量是因为我们希望将测量结果限制为两个可能的值:0 或 1。正如我们将看到的,我们对量子比特进行的操作将等同于在单位平面上的旋转。然而,可能的结果空间应该包括所有可能的三维球面旋转,其底层空间仍然是二维的,表示两个可能的测量结果。为此,我们将向量表示为复数,而不是实数,表示为复向量空间 ℂ²。(复数是涉及实数和虚数的运算,其中虚数 i 等于 √-1)为了简便起见,我们现在将坚持使用 ℝ²,暂时避开复数。
为了将结果限制为两个可能的值,我们不仅需要单位向量。我们需要一对对正交的单位向量。两个向量相互正交当且仅当它们的 内积或点积 等于零。当单位向量的组合彼此正交时,我们称这些为 正交规范基,结合了表示单位向量的 "normal" 和表示正交的 "orthogonal"。
正交性:两个向量是正交的,当且仅当它们的乘积等于零:<a|b> = 0。
我们可以检查任意 n 维 ket 是否为正交规范基,如果矩阵 A 和其转置 A^T 的乘积等于 单位矩阵 In。
括号符号
在描述这些基之前,我们先说几句关于在线性代数中使用的标准符号的内容,以便你能适当地解释符号。
列向量 被称为 bras,而 行向量 被称为 kets,表示为:<a| & |b>,其中:
bras 是行向量,而 kets 是列向量。
它们一起形成 bra-kets。根据点积规则(即向量乘法),我们只能在每个维度相等的情况下将 bras 与 kets 相乘。
上述 bra-ket 的内积表示为 <a|b>,并表示:
两个单位向量的点积。
然而,相同维度的相同类型(bra 或 ket)的向量可以相加。
我们可以使用箭头来表示与电子自旋测量相关的正交规范对,而不是使用实际值。
自旋有三种正交规范基:
用于测量自旋的三种二维正交规范基(Berhardt, 2019)。
当我们乘以具有相同自旋的 bras 和 kets 时,结果为 1:
相等的正交规范 bra-kets 的点积为 1。
相反,当我们乘以具有相反自旋的 bras 和 kets 时,结果为 0:
相对的正交 bra-kets 的点积为 0。
如你所见,正交括号积给出模拟二进制结果的测量值。
第一个基,由上下箭头表示,称为标准基,对应于自旋的垂直测量,即沿 y 轴的测量。第二个基,由左右箭头表示,对应于自旋的水平测量,即沿 x 轴的测量。通常,排序的正交基代表沿某个方向的自旋测量。实际上,我们可以在任何角度或方向𝛳上测量自旋,结果会坍缩为该方向上自旋向上或向下的离散结果,因为自旋状态只能是离散的。
单个或多个量子位的量子态 将由这些基的线性组合给出。因此,基向量将代表量子态的可能结果。如前所述,量子位的状态可以通过电子或光子的自旋进行建模。在测量之前,粒子或量子态将处于叠加态,由基 |b1> 和 |b2> 的线性组合表示,形式为 c1|b1> + c2|b2>,其中 c1 和 c2 代表概率振幅。由于概率振幅可以为负值,仅使用它们的平方值来表示结果的概率,其中 c1² + c2² = 1。在叠加态中,因此 c1² 和 c2² 各自将具有 0.5 的概率。
测量后,自旋状态坍缩为正交基 |b1> 或 |b2> 之一。坍缩到 |b1> 的概率由 c1² 给出,而坍缩到 |b2> 的概率由 c2² 给出。如果测量坍缩到 |b1>,则 c1² 将等于 1,c2² 为 0,反之亦然。更直接地说,与概率值 1 乘积的基向量将代表测量的结果。
在经典计算中,多位通过这些位的张量积表示,记作符号:⊗
我们说过,[1,0] 和 [0,1] kets 代表标准基,因此分别对应经典位中的 0 和 1。我们还提到任何量子态必须满足以下等式:c1² + c2² = 1。我们称其为单位测量约束(概率论的第二公理),这意味着所有 kets 必须是 ℝ² 中的单位向量。然而,由于实际的量子粒子状态是通过复数表示的,因此实际的状态空间由 ℂ² 给出。因此,实际的单位测量约束为:‖𝛼‖² + ‖β‖² = 1,其中 𝛼 和 β 是复数,代表概率振幅。
因此,为了表示多个量子位状态,我们取标准基 |0> 和 |1> 的张量积。注意,无论我们串联多少量子位,乘积都会保持单位测量约束。
两个 1 量子比特的张量积产生上述向量。
两个 0 量子比特的张量积产生上述向量。
由于我们一直在ℝ²中工作,我们可以直观地通过单位圆表示一个量子比特的状态空间在二维(x,y)中。请记住,所有操作都将在我们上面列举的正交基上进行。此外,所有量子逻辑门都将通过单位矩阵来表示,从而也是正交矩阵。为什么?因为正交矩阵的向量乘法通过保持向量的内积来产生旋转。这会在欧几里得空间上产生等距变换。
还要注意,每 180⁰旋转的负号。这些负号有助于区分等效输出,以便每个操作原则上可以是可逆的或可逆的。所有量子计算都需要是可逆的,以利用量子态的计算能力,即在测量之前的叠加态和纠缠态。正如我们稍后会看到的,叠加态(以及纠缠态)赋予量子计算相对于经典计算的优势。在叠加态下,任何任意数量的量子比特 N 同时占据它们所有的可能状态。如果我们有 4 个量子比特,样本将有 2⁴个可能的状态,但在叠加态下,所有这些状态将同时存在。测量时坍缩到其中一个状态的概率将均匀分布在单位向量的线性组合上。
下图中的单位圆中的线表示通过 Hadamard 门的操作从输入到输出的状态变化,该门将量子比特置于叠加态然后再恢复。由于以下等式 ‖𝛼‖² + ‖β‖² = 1,测量将始终将系统坍缩到一个明确的经典状态,在我们的例子中对应于电子或光子自旋。然而,量子门的操作改变了单个或多个量子比特的状态,而不会坍缩波函数。
量子比特单位圆表示。图片来源于 Wikimedia commons。
如果我们应用位翻转操作符,相当于经典计算中的 NOT 门,它将翻转输入状态的值。例如,|1>态将翻转到|0>态,如下图所示,从 (0,1) 变到 (1,0) 的状态过渡。
与此同时,Hadamard 门将 (0,1) 作为输入,通过将输入与以下正交或单位矩阵相乘来输出 (1/√2,−1/√2):
如果你想知道状态如何变化,这里有一个关于 Hadamard 门在|1>输入下的矩阵乘法的详细示例:
Hadamard 门在|1>输入状态向量下的操作。
Hadamard 门是如何工作的,它有什么特别之处?
从单位圆上可以看出,比特翻转操作对应于单位圆上的 90⁰旋转,而 Hadamard 门对应于 180⁰旋转。需要记住的是,所有量子门都通过正交或单位矩阵完成,这些矩阵沿原点产生旋转。特别地,Hadamard 门在 x 轴和 y 轴之间产生半旋转,对应于 0.5 的概率振幅。
还要注意,向量输出是单位向量,因为(1/√2, −1/√2)观察到以下恒等式 ‖𝛼‖² + ‖β‖² = 1。试着通过将输出值代入𝛼和β来进行计算。
想象一下,在单位圆上的某一点,量子比特的状态将从 0 和 1 的类比中分配概率振幅到一组保持单位和的值之间。Hadamard 门将这种分布精确设置为 50/50 的结果。换句话说,测量时有 50/50 的概率会将量子比特坍缩为|0>或|1>状态向量。这是叠加态的数学类比。我们稍后将看到如何在计算中利用叠加态。
最后,到目前为止,我们的演示利用了单位圆作为量子比特在ℝ²中的可能状态空间。由于实际的量子比特状态由ℂ²中的复数表示,因此量子比特状态空间的更准确表示是所谓的Bloch 球体,旨在捕捉下图所示的复杂值数字的可能状态,这是一种三维球体。
量子比特 Bloch 球体表示。图片来自 Wikimedia commons。
量子逻辑门
与传统计算类似,量子计算中的逻辑门由对量子比特或量子比特集合执行某种操作的电路组成。我们早些时候看到,量子门在数学上相当于对量子比特的矩阵乘法。我们还说过,量子比特由单位向量表示,而量子逻辑门由正交或单位矩阵表示。正如我们所看到的,这些在单位球体或圆上产生旋转,以简化为二维空间。
然而,为了对量子比特进行操作,量子门必须是可逆的。可逆性意味着从输入到输出的每个操作也必须从输出到输入是可逆的。原因在于量子态是可逆的,时间反转不变的,并且在叠加态中保留信息。然而,我们所称的测量会将量子态降解为经典态。测量或坍缩是不可逆的,因此不会保留输入信息。换句话说,我们不能将坍缩恢复到之前的叠加态。因此,量子门构成了受控操作,它们在操控量子态的同时也保持了量子态。这些结果所需的电路利用了几纳米大小的半导体颗粒,称为量子点,这些量子点必须保持在接近零开尔文的温度。然而,需要注意的是,量子计算的期望输出只能通过测量来检索。
因此,当涉及量子逻辑门时,最重要的两个属性是:a) 可逆性 和 b) 普遍性。普遍性指的是一种能够计算所有可能位操作的逻辑门。经典计算中最著名的通用门是 NAND(NOT AND),如下面的表格所示:
请注意,NAND 是 AND 的真值函数补集。最多,仅需两个逻辑运算符即可表达所有可能的逻辑语句,包括逻辑定理。这被称为函数完备性。由于 NAND 将 NOT 和 AND 结合为一个操作,因此它也就成为了功能上完备的、普遍的逻辑运算符和门。为了比较,我们来看一下 AND 的真值表:
在经典计算中,大多数操作是不可逆的。例如,如果我们将序列 010011110 输入到大多数逻辑门中并得到另一个二进制序列作为输出,我们将无法仅从输出序列中恢复输入序列。XOR 和 NAND 都是不可逆的。然而,有些门允许我们仅从输出中检索输入,比如 CNOT(等效于 XOR 但可逆)、Hadamard 和 TOFFOLI 门。在这些门中,Hadamard 和 TOFFOLI 都符合可逆和普遍的标准。然而,还有其他门也满足这些要求,如 FRIEDKIN 门。我们将重点讨论前述三者。
现在让我们来看看对大多数量子计算至关重要的两个门:CNOT 和 Hadamard。CNOT 通过纠缠两个或更多量子比特来操作它们。Hadamard 门通过将一个或多个量子比特置于叠加态来操作它们。我们还将关注第三个门,即 TOFFOLI 门,也称为控制-控制 NOT 门,它是 CNOT 门的通用版本。
CNOT 门的目的是什么?它允许我们通过执行可逆的比特翻转操作来纠缠两个输入量子位。CNOT 门由两个输入组成:一个控制输入和一个目标输入。当控制比特等于 1 时,CNOT 门翻转目标输入。当控制比特等于 0 时,CNOT 门什么也不做。这样,每个输出组合都可以追溯到唯一的输入组合。
带有比特的经典 CNOT 门和张量积输出
标准基和张量积输出上的量子 CNOT 门
CNOT 门在什么意义上等同于纠缠?
让我们来看看 CNOT 门对 |1> 和 |0> 量子位的操作。我们取两个量子位的张量积,并将其与 CNOT 门的单位矩阵相乘,将输入张量积从 |10> 转换为 |11>。为什么?因为 CNOT 仅当控制量子位等于 |1> 时,才会翻转目标量子位的值。
从推论上讲,如果我们的目标量子位是 |1> 而不是 |0>,CNOT 门会将值预测性地翻转回 |0>,如下面所示:
换句话说,CNOT 是 XOR(异或)的可逆经典计算等效物。
正如我们之前提到的,哈达玛门产生了一个完美的叠加态。它是如何做到的?请看下面的正交矩阵:
哈达玛门是一个正交矩阵,它将输入量子位置于叠加态中,反之亦然。
如果我们将矩阵与标准基 |0> 相乘,它会输出以下状态:(1/√2, 1/√2),这等于 (|0>+|1>)/√2. 相反,如果我们将其与 |1> 相乘,我们得到 (1/√2,−1/√2),这等于 (|0>−|1>)/√2. 尽管每个输入将输出转换为概率幅度的均匀分布,但负号使我们能够区分输入(是 |0> 还是 |1>),从而确保操作是可逆的。
在经典世界中,每个值都有 50/50 的结果概率。注意在我们的例子中,我们对单个量子位进行了操作。我们如何将多个量子位置于叠加态?
我们需要分别通过哈达玛门处理每个量子位,然后取它们的张量积。正如我们之前提到的,多量子位状态表示为单个量子位状态的张量积。
下面你可以看到通过将 |1> 标准基与哈达玛矩阵相乘,对单个量子位的操作:
最后,让我们来看看 TOFFOLI 门,也称为受控受控非门(CCNOT 门)。TOFFOLI 门与 CNOT 门相同,只是多了一个控制变量。TOFFOLI 门利用一个 8x8 的正交矩阵对三个输入量子比特进行操作。像 CNOT 一样,TOFFOLI 会产生量子纠缠,并可用于纠缠和解纠缠量子比特。
TOFFOLI 门的输入输出表:
经典 Toffoli 门的输入和输出。
量子 Toffoli 门的输入和输出。
为什么我们需要 TOFFOLI 门而不是 CNOT?
因为像 NAND 一样,TOFFOLI 对于经典计算是通用的,因此可以被量子计算机用来模拟可逆的经典计算。然而,TOFFOLI 在量子计算上并不是通用的,因为它不能产生叠加态。
现在我们对量子逻辑门及其执行的操作有所了解,我们如何将它们结合起来以获得量子算法?
由于量子门保留了叠加态,我们可以使用它们执行可逆的单位运算。从物理角度来看,系统在时间上的演化由薛定谔方程描述。然而,要从量子计算机中检索任何信息,我们需要使波函数坍缩。
通常,我们将比特翻转操作和 Hadamard 门结合起来以获得期望的结果。然而,比特翻转是不可逆的。量子计算的挑战在于设计出以可逆方式书写不可逆函数的方法。我们将通过 Deutsch 算法来了解如何实现这一点。
经典比特是量子比特的特殊情况
原则上,量子计算机可以实例化所有经典计算,因为这些计算是量子计算的一个适当子集。
为了通过量子计算机实现经典计算,我们必须将计算限制在我们一直作为例子的标准基底(类似于经典比特)上,并设计利用经典通用可逆门(如 TOFFOLI 门)的电路。由于 TOFFOLI 门是经典计算的通用门,它可以用于实例化经典计算。
然而,到目前为止,我们还没有讨论过量子算法。我们该如何构造一个量子算法呢?
量子算法:Deutsch Oracle
到目前为止,我们所阐述的一切从计算角度来看将是无意义的,如果我们不能构造出比经典算法具有计算优势的量子算法。
第一个显著实现这一点的算法是由 David Deutsch 于 1985 年提出的,称为 Deutsch 的 Oracle。
假设你有四个函数 f₀-f₃。对于每个输入 0 或 1,f₀ 输出 0。对于每个输入 0 或 1,f₁ 如果输入是 0 则输出 0,如果输入是 1 则输出 1。对于每个输入 0 或 1,f₂ 如果输入是 0 则输出 1,如果输入是 1 则输出 0。对于每个输入 0 或 1,f₃ 输出 1。
我们可以称函数 f₀ 和 f₃ 为常量函数,因为它们无论输入是什么都产生相同的输出。而我们称函数 f₁-f₂ 为平衡函数,因为它们以互补的方式分配输出。
我们接下来要问的问题是:如果我们随机得到这些函数中的一个,我们应该查询算法多少次才能确定函数是常量还是平衡的?
答案是经典计算不能在少于两个查询的情况下确定正确答案。让我们看看这是如何实现的。我们可以选择 0 或 1 作为输入。如果我们输入 0,输出可能是 0 或 1。同样地,如果输入 1,输出可能是 0 或 1。在这两种情况下,我们无法知道输出是由常量函数还是平衡函数产生的。因此,我们必须第二次查询算法才能做出正确的判断。
德意志(Deutsch)通过量子算法演示了,我们可以通过一次查询就知道正确答案。为了实现这一点,我们除了输入量子比特外,还利用了 Hadamard 门和一个控制量子比特。我们将输入通过 Hadamard 门。请记住,H 将量子比特置于叠加态。因此,如果我们输入 |0> 和 |1>,我们会得到以下相应的状态:(1/√2, 1/√2) , (1/√2, −1/√2)。然后我们将目标门应用于随机 fₓ。
由于 Hadamard 是可逆的,fₓ 应该将我们的量子比特置于以下状态之一:
(1/√2) (|0>+|1>); (1/√2) (|0>−|1>); (−1/√2) (|0>−|1>); (−1/√2) (|0>+|1>)
我们再次通过 Hadamard 传递控制量子比特以逆转叠加态。由于操作是可逆的,我们会得到以下可能的结果:
f₀ →|0>; f₁ →|1>; f₂ → −|1>; f₃ →−|0>
这意味着当我们在最后测量量子比特时,如果输出是 |0>,函数是常量的;如果输出是 |1>,函数是平衡的。尽管德意志预言机没有实际应用,但它提供了量子计算相比经典计算的优势的强有力例证。
Deutsch-Jozsa 算法
德意志预言机推广到多个变量被称为 Deutsch-Jozsa 算法。下面的图示提供了该算法的示意量子电路。
Deutsch-Jozsa 算法的电路,其中 H 代表 Hadamard,U 代表常量或比特翻转函数,标准基作为输入。仅测量右上角的输出。图片来源:维基百科。
Shor 算法
Shor 算法是一种用于分解大数的量子算法。该算法由两个部分组成,其中第一部分在经典计算机上执行,第二部分在利用量子傅里叶变换的量子计算机上执行。我们不会深入探讨该算法的数学细节,因为这些细节复杂且超出了本文的范围。
Shor 算法需要两个寄存器,分别具有 1024 和 2048 量子位,以分解一个具有 309 位的 1024 位数字。迄今为止被分解的最大数字长度为 48 位,未达到 RSA 100 位半素数的里程碑。目前没有任何量子计算机解决了 RSA 数字挑战,这是一份列出仅有两个素因子的已知大数的清单。RSA 数字用于公钥加密,以确保政府和金融机构的安全数据传输。
使用足够强大的量子计算机,Shor 算法可以用来解码公钥加密,这种加密使用的非常大的素数被经典计算机认为是计算上难以处理的。
量子霸权与未来
正如我们一开始所说,量子霸权的概念指的是量子计算机在合理的时间范围内解决经典上难以处理的问题的能力。原则上,经典计算机可以解决任何理论上可计算的算法。问题在于实践:有限的处理能力使它们无法在有用的时间范围内解决某些问题。这就是量子计算机有望弥合差距的地方。2019 年,谷歌宣布他们通过其 Sycamore 量子计算机实现了量子霸权,该计算机拥有 53 个量子位。在他们题为使用可编程超导处理器实现量子霸权的《自然》论文中,他们声称 Sycamore 花了 200 秒来对一个量子电路实例进行一百万次采样,而他们进一步声称,这项任务需要经典超级计算机花费 10000 年。IBM 反驳了这一说法,表示他们的一台超级计算机可以在 2.5 天内完成这一任务,从而削弱了谷歌对完成时间的声明。
迄今为止最大的量子计算机,IBM 的 Osprey,拥有一个 433 量子位的处理器。目前尝试构建具有足够大处理器的量子计算机的工作受到逐渐增加的噪声困扰,即量子态通过与周围环境(如温度和磁场变化)的互动而发生退相干或崩溃成经典态的潜力。
噪声问题构成了将量子计算机扩展到如因式分解极其大的质数等计算潜力的一项关键挑战。降噪量子比特可能会缓解这些挑战的一部分,但目前量子计算仍处于起步阶段。
参考文献
Bernhardt, Chris。面向所有人的量子计算。麻省理工学院出版社,2020 年。
IBM 推出 400 量子比特以上的量子处理器和下一代 IBM 量子系统二。IBM 新闻。(无日期)。 newsroom.ibm.com/2022-11-09-IBM-Unveils-400-Qubit-Plus-Quantum-Processor-and-Next-Generation-IBM-Quantum-System-Two
Kaye, P., Laflamme, R., & Mosca, M. (2020)。量子计算导论。牛津大学出版社。
“降噪”量子比特可以最小化量子计算机中的错误。芝加哥大学新闻。(无日期)。 news.uchicago.edu/story/noise-cancelling-qubits-can-minimize-errors-quantum-computers#:~:text=As%20existing%20quantum%20computers%20are,to%20high%20rates%20of%20error.
Roush, W. (2020 年 7 月 13 日)。谷歌-IBM“量子霸权”争端。MIT 技术评论。 www.technologyreview.com/2020/02/26/905777/google-ibm-quantum-supremacy-computing-feud/
Zubairy, Muhammad Suhail。初学者的量子力学:量子通信和量子计算的应用。牛津大学出版社,2020 年。
量子计算在优化问题中的应用——解决背包问题
如何使用量子计算解决优化问题,与传统解决方案相比
·
关注 发表在 Towards Data Science ·12 分钟阅读·2023 年 1 月 16 日
--
许多问题涉及找到一个函数的最大值或最小值,这个函数称为目标函数,依赖于几个(或许多)变量;某些约束可能需要或不需要应用于这些变量。这些变量可以是二进制的、整数的、集合中的元素、浮点数等。约束可能应用于单个变量,也可能更复杂。
具体来说,这些问题可能提供优化工业过程的解决方案,或优化工厂和仓库之间的货物流动(按成本、时间、吞吐量),还可以优化驾驶路线,找到满足某些复杂标准的社交网络成员——示例可以继续列举下去。
解决这些问题的算法有很多:线性规划、整数规划、模拟退火、纯模拟技术等等。这些算法通常在你可能熟悉的领域如规范分析、运筹学等中进行研究。一般来说,这些算法运行在传统的 CPU 上,适用时也会使用 GPU 求解器。
量子计算是一个面向未来的话题,近年来越来越受到关注。虽然目前(2023 年 1 月)全面通用的量子计算机尚未普遍可用,但某些专用机器已经可以用来极其快速地解决优化问题。这是 Ars Technica 发布的一篇文章,展示了一些这些早期量子系统的应用。
在这篇文章中,我们将简要描述其中一种早期的量子架构,并将其与解决典型优化问题的经典方法——背包问题进行比较。
但首先,让我们讨论一下模型。
二次模型——二元、离散、约束
有一类模型称为二次二元模型(BQM)。它们的变量是二元的(可以取两个可能值中的一个),并且它们是线性和二次项的组合。Ising 公式提供了一个 BQM 模型的目标函数示例:
Ising 模型,目标函数
变量 s 可以取 {-1, +1} 的值。h 系数是线性偏差。J 系数是二次耦合。目标函数可以被看作是系统的能量——稍后会详细讲解。
描述相同 BQM 模型的等效方式是 QUBO 公式:
QUBO 模型,目标函数
模型变量是二元值向量 x——可能的值为 {0, 1}。上对角矩阵 Q 包含对角线上的线性系数,二次系数则在非对角线位置。
Ising 和 QUBO 方程描述的是相同的模型,它们之间的转换应该是微不足道的。Ising 从物理背景看可能更为熟悉,而 QUBO 更类似于计算机科学模型。解决这些模型意味着找到使目标函数最小化的变量值。
如果变量可以取自一个比二进制更大的离散值集合,则模型称为离散二次模型——DQM。如果变量可以取任意整数或实数值,并且可能受到各种约束,那么模型称为受约束的二次模型——CQM。这些是基本 BQMs 描述的推广。
让我们探讨一种能够高效解决此类模型的硬件架构。
D-Wave 量子退火机
通用量子计算机可能包含通过量子逻辑门连接的一些量子位(量子比特,值为 0 或 1)。这样的机器可以通过像Qiskit这样的 SDK 进行编程,并可以解决几乎任何问题——因此被称为“通用”。目前,没有这样的机器能够解决实际问题,因为硬件仍然受到量子位数量有限和高错误率的严重限制。
但对于专业应用,量子计算机已经可以使用。D-Wave 已经建立了可以用于解决各种优化问题的量子计算机。这是 D-Wave QPU 的架构:
D-Wave Chimera QPU
这些细长的黑色形状,分布在 4+4 网格(单元格)中,是微小的超导环路,交错布置在芯片表面。由于每个环路是超导的,它可以承载电流。由于电流的存在,环路会产生一个小的磁场。磁场的方向可以是向上或向下(即+1 或-1)。把这些环路看作是量子位——量子比特。
绿色点称为内部耦合器——它们是每个 4x4 单元格内部的耦合器。它们通过一定程度耦合两个交叉环路的磁场,因此这些磁场的值是部分相互依赖的。环路之间并不接触,它们只是通过磁力耦合。
蓝色点称为外部耦合器。它们耦合不同单元格中环路的磁场。
在 Chimera QPU 中有许多单元格,总的量子位数量超过每个 QPU 的 2000 个。
还记得 BQMs 中的 Ising 公式吗?环路中的磁场是模型中的变量 s。每个环路可以有一个硬件偏置,用来增强或削弱磁场——这些偏置是 Ising 公式中的线性系数 h。而耦合器(内部或外部)是 Ising 模型中的二次系数 J。
系统的总能量,通过 Ising 公式计算,表示为:
-
每个环路磁场的方向(向上或向下)(在 Ising 模型中的 s 变量,可以是+1 或-1)
-
每个环路的偏置(在 Ising 模型中的 h 系数)
-
耦合(在 Ising 模型中的 J 系数)
换句话说,量子设备的总能量与 Ising 模型中的目标函数是相同的。如果物理系统达到了其可能的最小能量,那么这应该与目标函数的最小值相同。
为了解决 BQM,需要根据问题的具体情况编写 Ising 公式。Ising 公式中的 h 和 J 系数提供了 QPU 内部的偏差和耦合,并被编码到硬件中。最初,所有回路都被设置在量子叠加状态,其中它们的磁场指向上下,这对应于高能量状态。整个设置过程可以通过 D-Wave API 在 Python 中完成。
系统的总能量随后会逐渐降低(一个叫做退火的过程),直到达到给定 h 和 J 系数(偏差和耦合)下的最小可能值。此时,磁场的量子叠加状态被破坏;任何回路的磁场将指向上或下。
总能量的最小值是对应于 BQM 解的目标函数值。回路的磁场方向是解决优化问题的变量值(上=+1,下=-1)。在那个点读取回路即可得到问题的解决方案。
实际上,量子退火只需几分之一秒。无论问题的大小如何,只要 Ising BQM 可以适配到 QPU 芯片上,解决方案(退火后的最终状态)都能非常迅速地得到。D-Wave 利用物理系统趋向低能量状态的倾向,快速解决 NP 难题。Ising 这个名字可能会让你想起——一个世纪前由 Lenz 和 Ising 确定的方程是铁磁性的一种统计模型,在这里用于通过物理过程解决特定类型的模型。
从某种意义上讲,量子退火机(D-Wave QPU)就像一个模拟计算机,通过物理过程模拟数值问题。它也可以被看作与 FPGA 相关,因为它是一个可以为特定问题配置的设备,快速解决问题后,可以重新配置以处理不同的问题。
回到模型
只要你的问题可以表示为 BQM,它就可以通过 D-Wave 量子退火机直接并且非常快速地解决。
如果问题不是二次的呢?它可能是线性的、立方的等等。通常,通过引入虚拟变量及其他技术,模型可以被重新制定,直到它们变为二次模型。D-Wave 问题解决手册描述了几种这样的技术。
如果模型在重新表述后仍然是二次的但不是纯粹的二进制的怎么办?有些可能是离散模型(DQM——每个变量有多个值),其他的可能有约束且变量可以是任何东西(CQM——受约束的二次模型)。DQM 和 CQM 不能直接在量子退火器上运行。然而,你仍然可以使用 D-Wave API 编码问题,并将其提交到云中的求解器。D-Wave 会将问题分解为 BQM(可以由量子机器解决)和其余部分(由 D-Wave 使用常规求解器解决)。可以将其视为分而治之:D-Wave 尝试使用量子退火器解决尽可能多的模型,但某些问题仍有不符合 BQM 方程的部分。可以解决不同类型二次模型的量子和经典求解器的组合称为混合求解器。
在实践中,即使用混合求解器解决 DQM 或 CQM,过程通常仍然非常迅速。问题的硬核部分被解决为 BQM(可能从大参数空间中采样了许多 BQM),而经典算法可以被看作是对大量 BQM 样本集进行总结。
花时间将问题表述为 BQM 是值得的,因为它将在量子退火器上极其迅速地解决。将问题表述为 DQM 或 CQM 仍然是值得的,因为在许多情况下,问题仍将由 D-Wave 混合求解器在非常短的时间内解决。
如果目标函数需要最大化而不是最小化,只需将整个公式的符号反转。如果你需要解一个等式而不是最小化某些表达式,将所有项移到左侧(这样等式就变为零),然后平方——当平方值最小化时,它将达到零,这就是你所期望的。
可以使用许多类似的数学技巧将问题表述为接近于 D-Wave 求解器(量子和混合型)可以非常高效解决的二次模型——最小化一个看起来很像二次模型的函数。
背包问题
我们在这里用来比较量子技术和传统方法的例子叫做背包问题。
背包问题
你有一个背包,具有重量限制——它只能承载一定的最大重量。你有许多物品,每个物品都有不同的重量和货币价值。你想选择一些物品放入背包中,同时最大化整个运输的货币价值,但又不超过最大允许的重量。
目标函数显然是背包中所有物品的货币价值——越大越好。约束是物品的总重量,必须在某个值以下或等于该值。
解决这个问题的方法有很多。如果项的数量较少,蛮力方法是可行的。但即使对于适中的项数,这种方法也很快变得不可行。
我们将下面描述的经典方法称为整数编程。然后我们将其与 D-Wave 混合求解器进行比较。
整数编程
在解决这个问题的众多方法中,我们将选择Pyomo 库来构建整数编程模型,然后使用GLPK 求解器来解决它。当然,存在更快的方法,但这段代码很简单,并且与下一节的量子代码相似。
完整的 Jupyter notebook 链接在文章末尾。这只是整数编程部分:
print(f"solving for {len(values)} items")
print("build the model")
pecm = pe.ConcreteModel(name="Knapsack")
pecm.x = pe.Var(range(0, len(values)), domain=pe.Boolean)
pecm.worth = pe.Objective(
expr=sum(values[j] * pecm.x[j] for j in range(0, len(values))),
sense=pe.maximize,
)
pecm.weight = pe.ConstraintList()
pecm.weight.add(
sum(weights[j] * pecm.x[j] for j in range(0, len(values))) <= max_weight
)
solver_name = "glpk"
print(f"submit model to solver {solver_name}")
solver = pe.SolverFactory(solver_name)
solver.solve(pecm)
print("parse the solver output")
total_value = int(pecm.worth())
total_weight = int(sum(weights[j] * pecm.x[j]() for j in range(0, len(values))))
selected_items = pecm.x
值和权重列表包含所有项的货币价值和重量。目标函数是通过求和所有选定项的货币价值来创建的。约束是通过求和所有选定项的重量来建立的,然后将该总和限制为不超过 max_weight。
完整的模型被发送到本地运行的求解器。当求解器完成后,结果从求解器输出中解析出来。
量子编程
我们使用 D-Wave SDK,它构建模型,然后将其提交给在云端运行的 D-Wave 求解器。
print(f"solving for {len(values)} items")
print("build the model")
cqm = ConstrainedQuadraticModel()
obj = BinaryQuadraticModel(vartype="BINARY")
constraint = QuadraticModel()
for i in range(len(values)):
obj.add_variable(i)
obj.set_linear(i, -values[i])
constraint.add_variable("BINARY", i)
constraint.set_linear(i, weights[i])
cqm.set_objective(obj)
cqm.add_constraint(constraint, sense="<=", rhs=max_weight, label="capacity")
sampler = LeapHybridCQMSampler()
print(f"submit model to solver {sampler.solver.name}")
sampleset = sampler.sample_cqm(cqm, label="knapsack problem")
print("parse the solver output")
feasible_sampleset = sampleset.filter(lambda row: row.is_feasible)
if not len(feasible_sampleset):
raise ValueError("No feasible solution found")
best = feasible_sampleset.first
selected_items = [key for key, val in best.sample.items() if val == 1.0]
total_weight = sum(list(weights.loc[selected_items]))
total_value = sum(list(values.loc[selected_items]))
我们从相同的权重和值列表开始。这个问题的模型不能是纯 BQM,因为我们有一个约束。因此,我们选择了一个混合采样器。
一个有趣的变化是,D-Wave 可能会返回许多样本集作为解决方案的一部分。其中一些不可行(例如,违反约束),我们需要将它们从结果中筛除。在可行的样本集中,一个或多个将具有最低的 Ising 能量——这些就是量子计算机找到的最佳解决方案。
假设存在一个“最佳”解决方案,量子退火机并不总是能收敛到它。具体细节较为复杂,请参见 D-Wave 手册,但本质上有时会使用启发式搜索来找到一个“足够好”的解决方案。在其他情况下,可以引导求解器更强烈地收敛到绝对最佳的解决方案。最后,另一种方法是反复采样解决方案,直到获得绝对最佳的组合。D-Wave 文档非常优秀,并提供了许多微调求解器以寻找最佳解决方案的技巧。
结果
我们测试了两种算法,使用了一个包含 50,000 项的随机生成数据集,这些项的权重和值在 1 到 9999 的离散均匀分布中随机变化。背包的重量限制设置为所有项总重量的 80%——因此一些项总是会被丢弃。
这是来自 Pyomo / GLPK 的输出:
solving for 50000 items
build the model
submit model to solver glpk
parse the solver output
solver time: 129.44616556167603
knapsack max weight: 200052600
items total weight: 200052581
items total value: 241399378
GLPK 解决问题花费了超过 2 分钟。它得出了一种组合,这种组合的重量略低于允许的最大背包重量。
这是来自 D-Wave 的输出:
solving for 50000 items
build the model
submit model to solver hybrid_constrained_quadratic_model_version1
parse the solver output
solver time server side: 18.006079
solver time QPU: 0.016048
knapsack max weight: 200052600
items total weight: 200052560
items total value: 240318952
best sol. energy: -240318952.0
D-Wave 混合求解器总共花费了 18 秒。这还没有达到比整数编程的数量级改进,但已经在接近了。此外,项目数量的轻微增加会迫使整数编程代码花费更长时间来解决问题——这种增加是非线性的。
D-Wave 求解器并非如此。它也需要更长时间来解决更大的问题,但所需时间并不会随着项目数量的增加而急剧增长。
实际 QPU(量子退火机)上花费的时间仅为 16 毫秒。18 秒的总时间中,大部分时间花费在分解问题、采样解决方案空间,然后将结果返回给编码器上。问题在实际 QPU 上运行的部分越大,速度提升越显著。
对于非常大的问题,你的互联网连接可能会成为限制因素。在我的情况下,使用一个中等质量的连接运行这个问题时,我的代码花了几秒钟的额外时间在互联网上推送和拉取数据。
最后,请注意,混合求解器找到的解决方案实际上比整数编程求解器找到的解决方案稍差。混合求解器可以优化以找到更好的解决方案,如上所述,但这超出了本文的范围。
注释
这篇文章中使用的完整代码的笔记本以及其他文件和材料可以在这里找到:
[## misc/quantum_computing_knapsack_article at master · FlorinAndrei/misc
你目前无法执行该操作。你在另一个标签或窗口中登录。你在另一个标签或窗口中注销了…
量子求解器代码基于 D-Wave 自己的示例库,经过了重新整理以便于理解:
D-Wave Ocean 代码示例。D-Wave 系统示例有 50 个可用的仓库。关注他们在 GitHub 上的代码。
github.com](https://github.com/dwave-examples?source=post_page-----274f01e78ed8--------------------------------)
本文中显示的所有图像均由作者创建。
量子比特魔法:用量子计算创建神话生物
原文:
towardsdatascience.com/qubit-magic-creating-mythical-creatures-with-quantum-computing-49bea0fabf4
教程
利用量子比特和大型语言模型生成美丽的图像。
·发表于Towards Data Science ·阅读时间 6 分钟·2023 年 9 月 18 日
--
一个具有量子态向量[1/√2 e^(π/2)i 1/√2]的神话生物。来源:Stable Diffusion。
你能用量子比特做什么?
想想单个量子比特的本质。乍一看,它可能并不那么宏伟。
事实上,单个量子比特似乎没有太多有趣的用途。
然而,凭借一些创造力,我们可以实现许多令人印象深刻的壮举。其中之一是生成图像、音乐,甚至在本文中——神话生物!
在量子层面生成惊人的图像
可视化量子比特可能看起来是一个相当抽象的概念。毕竟,量子比特在微观层面,使其难以理解。
然而,我们可以利用量子比特的惊人特性,将其可视化为富有想象力和美丽的方式。
一种方法是利用量子比特的独特特性,以生成可用于创建图像、声音等描述属性。
叠加的特性
量子计算使得一个量子比特可以同时表示两个值。这一特性称为叠加。
当一个量子比特处于叠加态时,它可以同时表示零和一。这使得计算可以实现二次甚至指数级的增长。
当量子比特处于叠加态时,它具有独特的幅度和相位属性。这些属性指的是量子比特的自旋,也对应于其测量为 0 或 1 的可能性。
可视化一个具有高概率(83%)测量为 1 的 q 球体。来源:理解 Bloch 球体。
我们可以从量子比特中提取这些属性,并将其解码为描述性文本属性,然后可以通过大型语言模型生成图像,例如 StableDiffusion、Huggingface 或 DALL-E。
创建量子电路
让我们从创建一个单量子比特的量子电路开始。下面的示例展示了这一过程。
# Define the initial state vector.
initial_state = [1/np.sqrt(2), 1/np.sqrt(2)]
# Create a quantum circuit with one qubit.
qc = QuantumCircuit(1)
# Initialize the qubit with the initial state vector.
qc.initialize(initial_state, 0)
接下来,我们将允许用户从一系列量子门(X, Y, Z, H, S, T)中选择,以应用于量子比特以改变其属性。每个门对量子比特执行不同的操作,修改其幅度和相位。
例如,X 门翻转量子比特的状态从 0 到 1。相比之下,Z、S 和 T 门修改量子比特的相位。这会导致其围绕 Z 轴旋转。
每个操作稍微改变量子比特的测量属性,正是这种精确的变化可以被用来生成独特的图像!
获取状态向量
一旦用户对量子比特执行了一系列量子操作,我们可以测量结果并获取状态向量。
量子比特的状态向量类似于其指纹。它提供了对量子比特当前持有的幅度和相位属性的独特视图。
# Simulate the circuit using the statevector simulator
simulator = Aer.get_backend("statevector_simulator")
result = execute(qc, simulator).result()
statevector = result.get_statevector()
下面展示了一个量子比特状态向量的示例:查看。
[-0.70710678–4.44089210e-16j -0.70710678–3.33066907e-16j]
使用状态城市图可视化一个量子比特状态向量。来源:作者。
提取颜色、力量和魔法
状态向量由两个分量组成,我们可以称之为 alpha 和 beta。
状态向量中的每个分量可以进一步分解为幅度(r)和相位(theta)。这导致从量子比特中获得四个不同的值(r1、theta1、r2、theta2)。我们可以将这些值直接映射到描述性文本属性。
# Get the first and second components of the state vector.
alpha = state_vector[0]
beta = state_vector[1]
# Get the magnitude and phase of each component.
r1 = np.abs(alpha)
theta1 = np.angle(alpha)
r2 = np.abs(beta)
theta2 = np.angle(beta)
由于我们想生成神奇的神话生物图像,我们使用的属性可以包括大小、颜色、力量和魔法。
# Define a dictionary of words for each attribute.
attribute_words = {
"size": ["tiny", "small", "medium", "large", "huge"],
"color": ["red", "green", "blue", "black", "rainbow"],
"power": ["feeble", "weak", "strong", "powerful", "mighty"],
"magic": ["mundane", "ordinary", "magical", "enchanting", "mystical"]
}
# Map the magnitude and phase values to indices
size_index = map_value_to_index(r1, 0, 1, len(attribute_words["size"]))
color_index = map_value_to_index(r2, 0, 1, len(attribute_words["color"]))
# ...
# Get the attribute words from the dictionary using the indices
size_word = attribute_words["size"][size_index]
color_word = attribute_words["color"][color_index]
# ...
# Concatenate the attribute words into a string separated by commas
attribute_string = f"{size_word}, {color_word}, {power_word}, {magic_word}"
在上述代码示例中,我们定义了一个文本属性列表,用于描述我们的神话生物。每个属性映射到量子比特的状态向量。
结果是一个句子,可以作为大型语言模型图像生成的提示。
生成 LLM 提示
现在我们已经获得了神话生物的属性,我们需要生成 LLM 提示。
将上述代码整合成一个方法,结果是以下这个辅助方法,用于生成所需的文本。
def generate_prompt(statevector):
# Generate an image from Huggingface using the prompt "A mythical creature with the quantum state vector [α β]".
alpha = statevector[0]
beta = statevector[1]
# Get list of attributes from the state vector.
attributes = generate_attribute_words(statevector)
# Create an LLM prompt with the quantum state vector [{alpha} {beta}].
prompt = f"A mythical monster that is " + attributes + " with the quantum state vector " + statevector
return prompt
生成的状态向量现在将被转换为 LLM 提示。
一种神话怪物,体型巨大,蓝色,强大,神秘,具有量子状态向量 [α β]。
让量子比特成为令人惊叹的生物
最后一步是通过程序调用 LLM 图像生成服务,为我们的量子比特状态向量注入一些生命。
惊人的是,一些大型语言模型(例如 Bing 的 DALL-E)能够直接将量子计算状态向量解读为图像。它们可以直接使用这些向量来以创造性的方式进行可视化。
例如,下面是 Bing 的 DALL-E 在被要求根据量子状态向量创建图像后的响应。
这只昆虫有两只翅膀,形状像量子状态向量 [-0.70710678–4.44089210e-16j -0.70710678–3.33066907e-16j],其中翅膀的长度和角度分别代表实部和虚部。
LLM 的这一强大能力在下面通过可视化量子比特的状态向量,并结合我们程序中的描述属性进一步展示。
一种神话生物,体型巨大,橙色,强大,神秘,具有量子状态向量 [-1/√2 -i/√2]。来源:Stable Diffusion。
完整的源代码可以在 这里 找到。
现实世界应用的潜力
通过大型语言模型生成来自量子比特的内容,具有超越图像的潜力。事实上,凭借一点 创造力 和 实践,这项技术也具有现实世界的应用。
考虑从量子计算状态中生成更抽象的产品的想法。示例可能包括:
-
诗歌或书面故事
-
音乐或演讲
-
商业创意
-
图表和图示
-
软件程序
LLM 擅长生成各种内容,并且能够结合量子计算的力量来完成这些任务。
天空是极限
希望这篇文章能激发你对量子计算的兴趣,并打开更多可能性的大门。
通过结合量子计算和 LLM 的两种令人难以置信的新兴技术,并加入一点想象力,我们可以带来令人惊叹的可能性。
现在,轮到你了!
关于作者
如果你喜欢这篇文章,请考虑在 Medium、Twitter 和我的 网站 上关注我,以便收到我未来的帖子和研究工作通知。
在 Azure “Prompt Flow” 中以 GPT 模式查询文档语料库
如何自动向量化内容并创建类似 LangChain 的机制,以高效查询 语料库 文档
·发布于 Towards Data Science ·6 分钟阅读·2023 年 7 月 21 日
--
由 Kenny Eliason 拍摄,发布于 Unsplash
GPT 热潮
全球所有精通技术的人们已经玩了一段时间的 ChatGPT…
-
他们中的许多人将其用作非常聪明的知识数据库🔎,
-
一些人探索了“提示工程”(或“Prompt Engineering”)的艺术,以获得更相关的结果,有时使用他们自己的数据🤖,
-
但只有少数人进一步利用了诸如 LangChain 等解决方案,来构建复杂的工作流并创建实际应用📚。
的确,掌握诸如“嵌入”或“向量存储”等概念,再加上编程要求,可能对许多人来说显得复杂,阻碍他们真正发挥大型语言模型的潜力。
这时“Prompt Flow”来拯救大家!
让我们发现如何在 Azure 中使用低代码构建一个强大的问答工具现在变得可能了!
前提条件
我假设你拥有创建本教程所需资源的必要权限,其中最重要的一项是拥有一个“Azure 机器学习工作室工作区”。
Azure 机器学习工作室登陆页面(图片来源:作者)
“Prompt Flow”功能,以及“模型目录”(允许你部署由 Azure、Hugging Face、Meta 等策划的大型语言模型),目前处于私人或公开预览阶段,因此你需要加入等待列表才能激活和使用。
在 Azure 机器学习工作室中的模型目录和 Prompt Flow(图片由作者提供)
建立向量索引
理解嵌入
为了高效处理大规模的语料库并克服当前模型的标记限制,你需要将每个文档拆分成块(例如,每一页),并将相应的内容转换成“嵌入”。
嵌入是一个数字向量,允许数学上比较不同内容的潜在含义和概念,如下所示:
-
我爱狗: [-0.020993631333112717, …, -0.013046476989984512]
-
我爱动物: [-0.004775667563080788, …, -0.02461422048509121]
-
我的车坏了: -0.029365260154008865, …, -0.032723452895879745
当使用NumPy “dot” 函数比较前两个嵌入的相似性时,我们得到接近 1 的结果(约 0.94),显示它们的含义非常接近。
相反,将第三个嵌入与前两个嵌入进行比较时(约 0.77),结果表明它们对应于不同的概念。
# skipping the code to connect to Azure Open AI API
# and convert the 3 sentences into embeddings
np.dot(emb_results[0], emb_results[1])
0.9428616001086217
np.dot(emb_results[0], emb_results[2])
0.7730676233478093
np.dot(emb_results[1], emb_results[2])
0.7718532811989359
这将成为下游分析的基础:能够将查询中嵌入的概念与之前审阅过的内容进行比较,以找到相似性。
向量存储或索引
从内容到向量存储(图片由作者提供)
正如你可能猜到的,我们现在需要一个地方来收集所有这些向量,并将它们与初始文档链接在一起……这是“向量存储”的核心功能:
[document.pdf]
— 第 1 页
— 第 1 页 => 分割-1 => 嵌入 #1
— 第 1 页 => 分割-2 => 嵌入 #2
— 第 2 页
— 第 2 页 => 分割-1 => 嵌入 #3
— 第 2 页 => 分割-2 => 嵌入 #4
向量存储创建过程可能繁琐,但“Prompt Flow”简化并加速了这一过程。
对于这个例子,我们将使用三个文档作为语料库:
-
《爱丽丝梦游仙境(第一章)》,作者:刘易斯·卡罗尔
-
《迷失的小机器人》,伊萨克·阿西莫夫的短篇小说(我最喜欢的作家!)
-
关于互联网的几段文字,由 GPT-4 生成
我们启动“stories-vector-index”的创建,并选择存储在我计算机上的 3 个文件。
在 Azure Prompt Flow 中创建新的向量索引(图片由作者提供)
在第二步中,你需要选择一个包含“嵌入”模型的 Azure Open AI 资源。
这适用于标准的“Azure Open AI”资源,这些资源附带 GPT-3.5 模型并提供“text-embedding-ada-002”:
Azure Open AI — 部署在资源中的模型(图片作者提供)
然后,只需在“无服务器计算”或“计算集群”之间选择,并耐心等待 🤔
根据我的经验,使用“无服务器计算”处理少量文档大约需要 10 分钟,而处理 1,000 个文档(约 1GB)的时间可能长达一个小时,包括各种格式如 PDF、PPT、DOC、CSV、XLS、TXT 等。
向量索引创建成功过程在 Azure Prompt Flow 中(图片作者提供)
当过程完成后,向量索引在 AZML 中变得可用:
Azure Prompt Flow 中的可用向量索引(图片作者提供)
注意:到目前为止,我遇到的唯一失败案例是当初始文档包含一个或多个“受保护的 PDF”时,因为脚本无法访问文件的内容。
在这种情况下,您需要检查日志以检测哪个文档导致了问题:
向量索引创建失败登录 Azure Prompt Flow(图片作者提供)
向量索引创建失败登录 Azure Prompt Flow(图片作者提供)
创建新流程
从画廊(不断演变中)中,我们将选择并克隆“带来您自己的数据 QnA”:
Prompt Flow 中的流程画廊(图片作者提供)
这创建了一个完整的类似 LangChain 的机制,只需要几个配置步骤:
克隆的问答提示流配置(图片作者提供)
-
更新初始查询/问题
-
选择一个包括嵌入能力的 Azure Open AI 资源
-
选择一个与运行中的计算实例相关联的运行时
-
更新向量索引的位置(使用其“存储 URI”)
克隆的问答提示流配置(图片作者提供)
成功运行后,您可以检查答案的相关性:
"output":"The main character in the first chapter of Alice in Wonderland is
Alice. The only other character mentioned is the White Rabbit with pink eyes
who runs close by her and leads her down the rabbit hole.
(Source: Alice’s Adventures in Wonderland by Lewis Carroll)"
确实,有人可能会认为您无需提供第一章,因为 GPT 模型能够在没有它的情况下回答……但您明白这是仅用于演示目的 😉
运用相同的原则,我们可以为企业文档创建一个强大的问答应用程序:
基于专有文档的流程输出(图片作者提供)
通过 Web-App 利用该端点
经过几次“合理性检查”以确保提供的答案有意义后,我们可以通过相应的按钮在几次点击中部署应用程序。
该流程变成一个 REST Azure 端点,简单的 Python、C# 或 R 代码片段可以使用:
来自 Azure 的端点消费片段(作者提供的图片)
我的一项建议是使用如 Streamlit 这样的包来使模型易于访问,并设计一个简单的界面,允许用户以自然的方式提问,并获得相应来源的答案,如下所示:
使用 Streamlit 创建的可能的 Web 应用程序设计和 UX(作者提供的图片)
就这样!🎉
凭借如此少的编码,我们从原始的文档语料库转变为一个利用 LLM 功能的 web 应用程序,帮助最终用户以简单自然的方式提取见解和知识!
在即将到来的各种生成式 AI 使用场景中,我真的相信它将成为许多公司或组织处理和利用大量 TB 数据或文档的游戏规则改变者。
像往常一样,我尝试确定所有所需的步骤,但如果在本教程中有任何遗漏的说明,请随时联系我!
不要犹豫,浏览我在 Medium 上的其他贡献:
数据科学、机器学习与创新
pl-bescond.medium.com](https://pl-bescond.medium.com/pierre-louis-besconds-articles-on-medium-f6632a6895ad?source=post_page-----3a79ec23f59c--------------------------------)
PowerBI 中的快速而简单的时间序列预测:实用指南
如何仅用几次点击就设置和配置一个预测系统
·发布于 Towards Data Science ·阅读时间 7 分钟·2023 年 5 月 1 日
--
作者提供的图片。
介绍
时间序列预测已成为商业、政府和个人普遍使用的工具。现在,它几乎被应用于每一个可以想象的领域:在金融领域预测股票和利率,在医疗保健中预测医院床位容量,在交通领域规划路线和交通模式,在能源领域预测电力供应和需求等等。
因此,存在对一个用户友好、易于上手且快速无缝设置的工具的需求。幸运的是,PowerBI 解决了这个问题。凭借其内置工具,用户能够仅需几次点击就能设置和配置预测系统。
本文将为你提供一个逐步指南,教你如何实现这一目标。但在深入实际操作之前,让我们先了解一下在后台运行的算法。
算法:指数平滑
PowerBI 使用指数平滑——一种能够捕捉时间序列数据趋势的强大算法,同时抑制噪声和不必要的变化。
简单来说,指数平滑采用过去观察值的加权平均值,给予最近的观察值更多权重。这意味着权重随着观察值的变老而指数递减。其背后的思想是,最近的观察值比远离的观察值对未来行为的预测更具信息量。
指数平滑有几种变体,每种方法使用不同的水平、趋势和季节性组件的组合来进行预测。
PowerBI 的工具会根据用户提供的历史数据的季节性,自动在ETS 模型系列的两个算法之间进行选择:(1)用于季节性数据的 ETS AAA,和(2)用于非季节性数据的 ETS AAN。
季节性数据:ETS AAA
首先让我们拆解这个相当令人生畏的缩写。有趣的是——或者说令人困惑,取决于你的视角——缩写的第二部分描述了第一部分所代表的组件的性质。
第一部分,ETS,告诉我们时间序列模型所考虑的组件。在这种情况下,它包括误差(E)、趋势(T)和季节性(S)。AAA中的A代表加性。有了这些信息,我们现在可以得出结论,我们的时间序列模型考虑了加性误差、加性趋势和加性季节性。
-
加性误差指的是时间序列中的误差或随机波动被加到预期值上。
-
加性趋势意味着时间序列的预期值随时间以固定的量发生变化。例如,如果时间序列的趋势组件为 2,则在时间t+1时的预期值将比时间t时的预期值高 2 个单位。
-
加性季节性意味着时间序列的季节性组件在每个季节的预期值上增加一个固定量。例如,如果时间序列在七月份的季节性组件为 5,则七月份的预期值将比其他任何月份高 5 个单位。
非季节性数据:ETS AAN
基于前一节对如何解释这个缩写的解释,你可能已经得出结论,ETS AAN也使用了加性误差和加性趋势。AAN末尾的N仅表示非季节性,表明该模型不考虑时间序列中的季节性模式。
在 PowerBI 中的实现
现在让我们来看看在 PowerBI 中设置时间序列预测系统所需遵循的各个步骤。
步骤 1: 将你的时间序列数据加载到 PowerBI 中。
步骤 2: 创建一个包含你的时间序列的折线图,并确保 X 轴类型设置为连续。在本文的插图中,我使用了微软提供的示例数据。
作者提供的图像。数据使用的许可信息:MIT 许可证。
步骤 3: 在可视化窗格中,导航到将更多分析添加到您的视觉效果,然后打开预测。
步骤 4: 在选项下,你可以设置一些参数和自定义配置,如单位、预测长度、忽略最后、季节性和置信区间。
单位参数适用于预测长度和忽略最后(稍后我们将看到这个参数的含义)。在我们的例子中,我将单位设置为“月份”,将预测长度设置为“6”,表示我正在尝试对未来 6 个月进行预测。
如果不提供进一步的输入,结果将如下所示:
作者提供的图片。
看起来不太好,是吗?这是因为季节性的默认值设置为“自动”,而且,可能我们的数据没有足够的季节性周期以便工具准确检测到它。
仅通过查看折线图,我们可以看到我们的数据具有年度季节性。收入通常在夏季达到低点,然后稳步上升直到 4 月中旬,之后开始再次下降,循环重复。
因此,我们可以手动将季节性设置为“365”点。这里的点指的是时间序列数据的粒度。在我们的例子中,我们具有每日粒度——图表中的每个数据点代表一天。因此,365 点意味着 365 天。
按照描述调整参数后,我们得到以下结果:
作者提供的图片。
好多了!模型明显捕捉到了数据中的季节性,并且还预测了过去两年 8 月底的峰值。此外,通过查看每年 8 月中旬的低点,我们可以看到模型还捕捉到了这个时间序列中存在的轻微上升趋势。
当悬停在图表上时,我们还可以检查预测的确切值以及置信区间的上下限,默认情况下设置为 95%。这意味着我们的模型以 95%的概率预测数据将落在这个预测范围内。
作者提供的图片。
我们对数据进行的预测越远,置信区间就越宽。这是因为未来值的预测不确定性增加(想象一下天气预报——我们可以对接下来一两天的天气进行合理的预测,但对于 7 天或 14 天的预报,不确定性会大得多)。
模型评估
一旦你设置了预测系统,你可以通过一种叫做回溯预测的过程来评估模型的表现。回溯预测不是预测未来值,而是对过去的值进行“预测”,并将这些预测与实际值进行比较。
使用这种方法可以让我们了解模型的表现,通过评估它在过去的表现情况来实现。
现在是利用ignore the last参数的时候。将此参数设置为“6”,我们告诉模型对我们时间序列数据的最后 6 个月进行回溯预测。
图片来源于作者。
查看这个回溯预测,我们可以看到,虽然模型不是 100%准确——实际上没有模型是——但总体表现还是相当不错的。平均而言,预测值似乎略低于实际值。这可能是由于模型没有很好地捕捉数据的整体上升趋势,或者上升趋势的斜率在这一周期中实际上在增加。更多的历史数据通常能解决这个问题。
结论
PowerBI 提供了一种真正无缝的方法来设置时间序列预测系统,这只需几分钟的最小调整即可完成。利用这一工具,组织和个人可以利用其历史数据生成对未来趋势和事件的可操作预测。然而,也必须记住,预测需要不断的调整和评估,而预测的准确性在很大程度上依赖于基础历史数据的质量和数量。
喜欢这篇文章吗?
让我们保持联系!你可以在 Twitter、LinkedIn 和 Substack 找到我。
如果你喜欢支持我的写作,可以通过 Medium 会员 来支持我,这样你可以访问我所有的故事以及 Medium 上成千上万其他作家的故事。
## 加入 Medium,通过我的推荐链接 - Thomas A Dorfer
阅读 Thomas A Dorfer 的每个故事(以及 Medium 上成千上万的其他作家的故事)。您的会员费直接支持…
使用 R 进行快速文本情感分析
原文:
towardsdatascience.com/quick-text-sentiment-analysis-with-r-2cc4f04c35c1
使用 TidyText 创建一个简洁且快速的文本分析工具与 R
·发布于 Towards Data Science ·阅读时间 9 分钟·2023 年 3 月 10 日
--
Kenny Eliason 拍摄的照片,来源于 Unsplash
介绍
到处都是文本!自从互联网传播到全球以来,我们每天生成的文本数据量巨大。仅每日发送的文本消息,就估计有大约 180 亿条在 日常基础上流通。
现在想象一下生成的新闻量。这个数量如此庞大,以至于围绕新闻剪辑建立了整个业务,分离出关于特定话题的最佳信息,以帮助公司制定营销策略。
人工智能如何提供帮助?显然,自然语言处理(NLP)在其中发挥了重要作用,提供了良好的工具和算法来分析文本信息。作为数据科学家,我们可以利用 tidytext
,这是 R
的一个优秀库,帮助我们构建快速分析工具来检查文本内容。
接下来,让我们在实践中看看这个。
文本分析
准备你的环境
为了准备好与本文一起编写代码,加载下列列出的库。
# Installing libraries
install.packages('tidyverse')
install.packages('tidytext')
# Loading libraries
library(tidyverse)
library(tidytext)
library(textdata)
tidytext
库的工作方式与 tidyverse
类似,利用直观的函数名称,并通过管道符号 %>%
链接它们。
让我们使用 这篇关于 R 语言 的维基百科文本 来创建我们第一个简单的 文本分析器。
text <- "R is a programming language for statistical computing and graphics
supported by the R Core Team and the R Foundation for Statistical Computing.
Created by statisticians Ross Ihaka and Robert Gentleman, R is used among data
miners, bioinformaticians and statisticians for data analysis and developing
statistical software. Users have created packages to augment the functions
of the R language.
According to user surveys and studies of scholarly literature databases,
R is one of the most commonly used programming languages in data mining.[8]
As of December 2022, R ranks 11th in the TIOBE index, a measure of programming
language popularity, in which the language peaked in 8th place in August 2020.
The official R software environment is an open-source free software
environment within the GNU package, available under the GNU General Public
License. It is written primarily in C, Fortran, and R itself
(partially self-hosting).
Precompiled executables are provided for various operating systems. R has a
command line interface.[11] Multiple third-party graphical user interfaces are
also available, such as RStudio, an integrated development environment,
and Jupyter, a notebook interface."
下一步是将这个文本转换为 tibble
对象,可以理解为 data.frame
。
# Transform to tibble
df_text <- tibble(text)
它不会改变太多你的对象,但这是我们能够使用tidytext
函数所必需的,因为这些函数要求数据来自 tibble 或 data.frame 对象。如果你感兴趣的话,这里是转换后的样子。
文本转换为 tibble 对象。图片来源:作者。
频率计数
继续,我们将对文本进行标记化。令牌是文本中最小的有意义单位。大多数项目使用 1 个单词=1 个令牌,但如果你的项目需要,也可以是其他大小。因此,标记化就是将文本分解为这种最小的有意义的片段,以构成信息。要使用tidytext
标记化我们的文本,请使用此函数。
令牌是文本中最小的有意义单位。[1]
# Tokenizing the text
tokens <- df_text %>%
unnest_tokens(input = text, #name of the input column
output = word) #name of the output column
结果如下。
经过标记化的文本。图片来源:作者。
不过,我们可以看到诸如is
、a
、for
之类的令牌不会为信息添加任何内容。认同吗?这些被称为停用词。我们应该有办法去除这些令牌,只留下干净的数据,即文本信息中的实际意义的令牌。
tidytext
已经附带了集成了停用词的数据集。如果我们输入stop_words
并运行代码,我们将能看到它。
# View stop_words
stop_words
# A tibble: 1,149 × 2
word lexicon
<chr> <chr>
1 a SMART
2 a's SMART
3 able SMART
4 about SMART
5 above SMART
6 according SMART
7 accordingly SMART
8 across SMART
9 actually SMART
10 after SMART
# … with 1,139 more rows
注意到包含单词的列命名为word
。这也是我们将标记化列命名为该变量名的原因,这样更容易将两个数据集连接起来。因此,我们现在的工作是将它们连接起来,去除停用词。我们可以使用anti_join()
函数,它只保留数据集 A 中存在但在 B 中不存在的单词。在序列中,我们只需计数并按出现频率排序。
# Removing stopwords and counting frequencies
tokens %>%
anti_join(stop_words) %>%
count(word, sort = TRUE)
# Result
# A tibble: 79 × 2
word n
<chr> <int>
1 language 4
2 data 3
3 environment 3
4 programming 3
5 software 3
6 statistical 3
7 computing 2
8 created 2
9 gnu 2
10 interface 2
# … with 69 more rows
惊人吧?这样我们就很容易了解这段文本的主题。用于统计数据分析的软件或编程语言。
我们可以使用前面的代码创建一个函数,快速给出任何文本的频率计数。
text_freq_counter <- function(text){
# Transform to tibble
df_text <- tibble(text)
# Tokenizing the text
tokens <- df_text %>%
unnest_tokens(input = text, #name of the input column
output = word) #name of the output column
# Removing stopwords and counting frequencies
freq_count <- tokens %>%
anti_join(stop_words) %>%
count(word, sort = TRUE)
# Return
return(freq_count)
}#close function
让我们来测试一下。我会回到这篇文章的第一部分,复制它,让我们的函数计算频率。
text <- "Text everywhere! Since the Internet was spread around the world,
the amount of textual data we generate everyday is ginormous. Only textual
messages sent everyday, it is estimated that there are around 18 Billion of
them circulating on a daily basis*.
Now imagine the amount of news generated as well. It's a so overwhelming
amount that there are whole businesses built around news clipping, separating
the best information about a given topic to help companies in their marketing
strategies.
How is AI helping that? Certainly, NLP plays a huge part on that providing
good tools and algorithms to analyze textual information. As Data Scientists,
we can profit of tidytext, an excellent library from R to help us building
quick analytical tools to check the content of a text.
Let's see that in practice, next."
# Running the function
text_freq_counter(text)
[OUT]
# A tibble: 50 × 2
word n
<chr> <int>
1 amount 3
2 textual 3
3 data 2
4 everyday 2
5 information 2
6 news 2
7 text 2
8 tools 2
9 18 1
10 ai 1
# … with 40 more rows
效果非常好。
额外进阶
我们可以在这里停下来,但这个话题非常有趣,我觉得我们应该再深入一点。让我们现在向我们的文本分析器添加情感分析吧。
tidytext
还准备好了情感分析,因为它提供了几个情感数据集。选项包括“Bing”、“Afinn”和“nrc”。让我们看看它们之间的区别。
Bing情感数据集包含分类为正面或负面的单词。因此,一个选项是检查你的文本中正面与负面单词的数量,从而了解情感。
# Bing sentiments
get_sentiments('bing')
# A tibble: 6,786 × 2
word sentiment
<chr> <chr>
1 2-faces negative
2 abnormal negative
3 abolish negative
4 abominable negative
5 abominably negative
6 abominate negative
7 abomination negative
8 abort negative
9 aborted negative
10 aborts negative
# … with 6,776 more rows
Afinn情感数据集,可能来源于 affinity,将单词与数字分类。数字越正面,单词越正面,反之亦然。它需要加载library(textdata)
。
library(textdata)
# Sentiments Afinn
get_sentiments('afinn')
# A tibble: 2,477 × 2
word value
<chr> <dbl>
1 abandon -2
2 abandoned -2
3 abandons -2
4 abducted -2
5 abduction -2
6 abductions -2
7 abhor -3
8 abhorred -3
9 abhorrent -3
10 abhors -3
# … with 2,467 more rows
最终,NRC 会将单词分类为情感名称,如信任、惊讶等。
# Sentiments Afinn
get_sentiments('nrc')
# A tibble: 13,875 × 2
word sentiment
<chr> <chr>
1 abacus trust
2 abandon fear
3 abandon negative
4 abandon sadness
5 abandoned anger
6 abandoned fear
7 abandoned negative
8 abandoned sadness
9 abandonment anger
10 abandonment fear
# … with 13,865 more rows
所以,我们接下来要做的是使用afinn创建一个评分,然后绘制我们文本中正面和负面单词的结果。
我将使用关于科技行业裁员的网络文本。现在该功能包含一个图形,显示单词根据afinn值和频率的评分。
# Function for frequency count
text_freq_counter <- function(text){
# get sentiments
sentiments <- get_sentiments(‘afinn’)
# Transform to tibble
df_text <- tibble(text)
# Tokenizing the text
tokens <- df_text %>%
unnest_tokens(input = text, #name of the input column
output = word) #name of the output column
# Removing stopwords and counting frequencies
freq_count <- tokens %>% #dataset
inner_join(sentiments, by=’word’) %>% #join the sentiments
count(word, value, sort = TRUE) %>% #count the words by sentiment value
mutate(score = n * value) %>% # create score by multiplying score * value
arrange( desc(score)) # sort
# Plot
g <- freq_count %>%
ggplot( aes(x= score, y= reorder(word, score),
fill= score > 0) ) +
geom_col(show.legend = F) +
labs( x= ‘Sentiment Score’,
y= ‘WORD’,
subtitle = ‘Negative versus positive sentiments’) +
ggtitle(‘Sentiment Score by Word in the Text’)+
theme(plot.subtitle = element_text(color = "gray", face = "italic")) +
theme_light()
# Return
return(list(freq_count, g))
}#close function
#Applying the function
text_freq_counter(text3)
# Resulting table
# A tibble: 16 × 4
word value n score
<chr> <dbl> <int> <dbl>
1 care 2 2 4
2 best 3 1 3
3 feeling 1 2 2
4 hopes 2 1 2
5 robust 2 1 2
6 save 2 1 2
7 true 2 1 2
8 cool 1 1 1
9 fitness 1 1 1
10 shared 1 1 1
11 cutting -1 2 -2
12 recession -2 1 -2
13 cut -1 3 -3
14 losing -3 1 -3
15 lost -3 1 -3
16 cuts -1 7 -7
上述结果表格已显示。这是结果图形。
针对关于科技行业裁员的新闻文本进行情感分析。图片由作者提供。
在我的 GitHub 上,还有另一个函数,你也可以选择要使用的情感包。结果如下所示,随后是代码链接。
# Enhanced Function
text_freq_sentiment(text3, .sentiment = 'nrc')
text_freq_sentiment(text3, .sentiment = 'bing')
使用“nrc”包捕获的单词频率情感。图片由作者提供。
使用“Bing”包捕获的单词频率情感。图片由作者提供。
你可以在这里查看完整代码: GitHub 代码链接。
你也可以使用这个代码创建的 Shiny 应用进行尝试。
[## [ 文本情感分析器 ]
这个项目 - 由 Gustavo R Santos 创建 - 用于快速创建文本分析,以计算单词频率并绘制图形…
gurezende.shinyapps.io](https://gurezende.shinyapps.io/Sentiment_Analysis/?source=post_page-----2cc4f04c35c1--------------------------------)
Shiny 应用:文本情感分析器。图片由作者提供。
在你离开之前
我喜欢研究自然语言处理和文本挖掘的数据科学工具。我们可以从文本中提取出很多信息,这非常有趣。
我建议你查看下方参考部分的链接,并找到资源以深化你的知识。我的书中也包含一些关于处理文本数据的有趣练习,包括文本挖掘。
如果你喜欢这些内容,别忘了关注我的博客。在 LinkedIn 上找到我。
阅读 Gustavo Santos 在 Medium 上的文章。数据科学家。我从数据中提取洞察,以帮助个人和公司…
gustavorsantos.medium.com](https://gustavorsantos.medium.com/?source=post_page-----2cc4f04c35c1--------------------------------)
参考文献
[1] Santos, G. 2023. 用 R 进行数据处理。第 1 版。Packt Publishing.
使用整洁数据原则是使数据处理更简单、更有效的强大方法,这一点也不例外…
Tidytext Mining R (编程语言) - 维基百科 [## R (编程语言) - 维基百科
R 是一种用于统计计算和图形绘制的编程语言,由 R Core Team 和 R Foundation 支持…
数据整理与 R 数据整理与 R: 加载、探索、转换和可视化数据以进行建模,使用 tidyverse 库 [## 数据整理与 R:加载、探索、转换和可视化数据以进行建模,使用 tidyverse 库
数据整理与 R:加载、探索、转换和可视化数据以进行建模,使用 tidyverse 库[Santos…
快速评估你的 RAG,无需手动标注测试数据
原文:
towardsdatascience.com/quickly-evaluate-your-rag-without-manually-labeling-test-data-43ade0ae187a
自动化评估你的检索增强生成应用程序的过程,无需人工干预
·发布于 Towards Data Science ·阅读时间 12 分钟·2023 年 12 月 21 日
--
用户生成的图像
今天的话题是如何在不手动标注测试数据的情况下评估你的 RAG。
测量你的 RAG 的性能是你应该关注的,尤其是当你正在构建这些系统并在生产环境中提供服务时。
除了让你大致了解应用的表现外,评估你的 RAG 还提供了量化反馈,这些反馈指导实验和参数的适当选择(如 LLMs、嵌入模型、分块大小、top K 等)。
评估你的 RAG 对你的客户或利益相关者也很重要,因为他们总是期望有性能指标来验证你的项目。
不再卖关子,这个问题涵盖了以下内容:
-
自动生成一个合成测试集,从你的 RAG 数据中
-
流行的 RAG 指标概述
-
使用 Ragas 包在合成数据集上计算 RAG 指标
PS: 这个问题的某些部分是稍微动手的。它们包括实现数据集生成和评估 RAG 所需的编码材料。
所有内容也会在这个* notebook中提供。
一起来看看 🔎
1 — 生成合成测试集 🧪
假设你刚刚构建了一个 RAG,现在想要评估其性能。
为此,你需要一个具有以下列的评估数据集:
-
question (str):用来评估 RAG 的问题
-
ground_truths (列表):问题的参考(即真实)答案
-
answer (str):RAG 预测的答案
-
contexts (列表):RAG 用于每个问题生成答案的相关上下文列表
→ 前两列表示真实数据,后两列表示 RAG 预测。
作者提供的截图
要构建这样的数据集,我们首先需要生成问题和相应答案的元组。
然后,在下一步中,我们需要对这些问题运行 RAG 以进行预测。
👉 生成问题和真实答案(理论)
为了生成 (问题, 答案) 的元组,我们首先需要准备 RAG 数据,将其拆分成块,并嵌入到向量数据库中。
一旦拆分完成并嵌入,我们将指示一个 LLM 从 N_c
个主题中生成 N_q
个问题,最终得到 N_q x N_c
个问题和答案的元组。
为了从给定上下文中生成问题和答案,我们需要经过以下步骤:
-
随机抽取一个拆分,并将其用作根上下文
-
从向量数据库中提取 K 个相似上下文
-
将根上下文的文本与其 K 个邻居连接起来,以构建更大的上下文
-
使用下列提示模板中的大
context
和num_questions
生成问题和答案
"""\
Your task is to formulate exactly {num_questions} questions from given context and provide the answer to each one.
End each question with a '?' character and then in a newline write the answer to that question using only
the context provided.
Separate each question/answer pair by "XXX"
Each question must start with "question:".
Each answer must start with "answer:".
The question must satisfy the rules given below:
1.The question should make sense to humans even when read without the given context.
2.The question should be fully answered from the given context.
3.The question should be framed from a part of context that contains important information. It can also be from tables,code,etc.
4.The answer to the question should not contain any links.
5.The question should be of moderate difficulty.
6.The question must be reasonable and must be understood and responded by humans.
7.Do no use phrases like 'provided context',etc in the question
8.Avoid framing question using word "and" that can be decomposed into more than one question.
9.The question should not contain more than 10 words, make of use of abbreviation wherever possible.
context: {context}
"""
然后,重复步骤 1 到 4 N_c
次,以每次变化上下文并生成不同的问题
我已经使用这个工作流来生成 Python 编程相关的问题和答案,以下是我得到的一些结果示例。
| | question | ground_truths |
|---:|:---------------------------------------------------|:---------------------------------------------------|
| 8 | What is the difference between lists and tuples in | ['Lists are mutable and cannot be used as |
| | Python? | dictionary keys, while tuples are immutable and |
| | | can be used as dictionary keys if all elements are |
| | | immutable.'] |
| 4 | What is the name of the Python variant optimized | ['MicroPython and CircuitPython'] |
| | for microcontrollers? | |
| 13 | What is the name of the programming language that | ['ABC programming language'] |
| | Python was designed to replace? | |
| 17 | How often do bugfix releases occur? | ['Bugfix releases occur about every 3 months.'] |
| 3 | What is the significance of Python's release | ['Python 2.0 was released in 2000, while Python |
| | history? | 3.0, a major revision with limited backward |
| | | compatibility, was released in 2008.'] |
👉 现在是编码部分 💻
我们首先从构建一个包含 RAG 使用数据的 vectorstore 开始。
我们可以从维基百科加载这些(但如果你有任何有趣的 Python 编程 PDF,你也可以加载)
from langchain.document_loaders import WikipediaLoader
topic = "python programming"
wikipedia_loader = WikipediaLoader(
query=topic,
load_max_docs=1,
doc_content_chars_max=100000,
)
docs = wikipedia_loader.load()
doc = docs[0]
加载数据后,我们将其拆分成块。
from langchain.text_splitter import RecursiveCharacterTextSplitter
CHUNK_SIZE = 512
CHUNK_OVERLAP = 128
splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=CHUNK_OVERLAP,
separators=[". "],
)
splits = splitter.split_documents([doc])
然后,我们在 Pinecone 中创建一个索引:
import pinecone
pinecone.init(
api_key=os.environ.get("PINECONE_API_KEY"),
environment=os.environ.get("PINECONE_ENV"),
)
index_name = topic.replace(" ", "-")
pinecone.init(
api_key=os.environ.get("PINECONE_API_KEY"),
environment=os.environ.get("PINECONE_ENV"),
)
if index_name in pinecone.list_indexes():
pinecone.delete_index(index_name)
pinecone.create_index(index_name, dimension=768)
并使用 LangChain 包装器来对拆分的嵌入进行索引。
from langchain.vectorstores import Pinecone
docsearch = Pinecone.from_documents(
splits,
embedding_model,
index_name=index_name,
)
现在进入有趣的部分:生成合成数据集。
为此,我们使用 LLM、文档拆分、嵌入模型和 Pinecone 索引名称初始化一个 TestsetGenerator
类的对象。
from langchain.embeddings import VertexAIEmbeddings
from langchain.llms import VertexAI
from testset_generator import TestsetGenerator
generator_llm = VertexAI(
location="europe-west3",
max_output_tokens=256,
max_retries=20,
)
embedding_model = VertexAIEmbeddings()
testset_generator = TestsetGenerator(
generator_llm=generator_llm,
documents=splits,
embedding_model=embedding_model,
index_name=index_name,
key="text",
)
然后,我们通过传递两个参数来调用 generate
方法:
synthetic_dataset = testset_generator.generate(
num_contexts=10,
num_questions_per_context=2,
)
这会生成以下数据框:
| | question | ground_truths |
|---:|:---------------------------------------------------|:---------------------------------------------------|
| 8 | What is the difference between lists and tuples in | ['Lists are mutable and cannot be used as |
| | Python? | dictionary keys, while tuples are immutable and |
| | | can be used as dictionary keys if all elements are |
| | | immutable.'] |
| 4 | What is the name of the Python variant optimized | ['MicroPython and CircuitPython'] |
| | for microcontrollers? | |
| 13 | What is the name of the programming language that | ['ABC programming language'] |
| | Python was designed to replace? | |
| 17 | How often do bugfix releases occur? | ['Bugfix releases occur about every 3 months.'] |
| 3 | What is the significance of Python's release | ['Python 2.0 was released in 2000, while Python |
| | history? | 3.0, a major revision with limited backward |
| | | compatibility, was released in 2008.'] |
很简单,对吧?
如果你对实现细节感兴趣,可以在 notebook** 中找到它们。**
这只是其中的一半。现在,我们需要使用 RAG 来预测每个问题的答案,并提供用于生成响应的上下文列表。
RAG 定义在 RAG 类中,所以我们首先需要初始化它:
from rag import RAG
rag = RAG(
index_name,
"text-bison",
embedding_model,
"text",
)
然后,我们通过对每个问题调用 predict
方法来迭代合成数据集,并收集预测结果。
rag_answers = []
contexts = []
for i, row in synthetic_dataset.iterrows():
question = row["question"]
prediction = rag.predict(question)
rag_answer = prediction["answer"]
rag_answers.append(rag_answer)
source_documents = prediction["source_documents"]
contexts.append([s.page_content for s in source_documents])
synthetic_dataset_rag = synthetic_dataset.copy()
synthetic_dataset_rag["answer"] = rag_answers
synthetic_dataset_rag["contexts"] = contexts
下面是最终结果的样子:
| | question | ground_truths | answer | contexts |
|---:|:----------------------------------------------------------------------------|:----------------------------|:-----------------------------------------------------------------------------------------------------------|:---------------------------------------------------|
| 7 | What are the two types of classes that Python supported before version 3.0? | ['old-style and new-style'] | Before version 3.0, Python had two kinds of classes (both using the same syntax): old-style and new-style. | ['. New instances of classes are constructed by |
| | | | | calling the class (for example, SpamClass() or |
| | | | | EggsClass()), and the classes are instances of the |
| | | | | metaclass type (itself an instance of itself), |
| | | | | allowing metaprogramming and reflection.\nBefore |
| | | | | version 3.0, Python had two kinds of classes (both |
| | | | | using the same syntax): old-style and new-style, |
| | | | | current Python versions only support the semantics |
| | | | | new style.\nPython supports optio .......... |
🏆 恭喜你完成到这一步,现在你准备好评估你的 RAG 了。
2 — 流行的 RAG 指标 📊
在跳入代码之前,让我们介绍一下用于评估 RAG 的 四 个基本指标。
每个指标考察了不同的方面。因此,在评估你的应用时,考虑多个指标以获得全面的视角至关重要。
1 — 答案相关性:
答案相关性指标旨在评估生成答案与提供的提示的相关性。缺乏完整性或包含冗余信息的答案会获得较低的分数。该指标利用问题和答案,生成 0 到 1 之间的值。得分越高,相关性越好。
示例
❓问题:健康饮食的关键特征是什么?
⬇️ 低相关性答案:健康饮食对整体健康非常重要。
⬆️ 高相关性答案:健康饮食应包括多种水果、蔬菜、全谷物、瘦蛋白和乳制品,提供必需的营养素以促进最佳健康。
2 — 准确性
该指标评估生成答案在提供的上下文中的事实一致性。计算涉及答案和检索到的上下文,答案的评分范围在 0 到 1 之间,得分越高表示一致性越好。
为了使答案被认为是准确的,回答中提出的所有主张必须可以从给定的上下文中推断出来。
示例:
❓问题:玛丽·居里的主要成就是什么?
📑 上下文:玛丽·居里(1867–1934)是开创性的物理学家和化学家,第一位获得诺贝尔奖的女性,也是唯一一位在两个不同领域获得诺贝尔奖的女性。
⬆️ 高准确性答案:玛丽·居里获得了物理学和化学两项诺贝尔奖,使她成为首位实现这一成就的女性。
⬇️ 低准确性答案:玛丽·居里仅获得了物理学方面的诺贝尔奖。
3 — 上下文精确度
上下文精确度是一个指标,评估上下文
中所有真实相关条目是否都被排在更高的位置。理想情况下,所有相关片段应出现在排名靠前的位置。该指标使用问题
和上下文
进行计算,得分范围在 0 到 1 之间,得分越高表示精确度越好。
4 — 答案准确性
该指标衡量生成答案与真实答案之间的准确性。这项评估利用真实答案和生成答案,并在 0 到 1 的范围内进行评分。得分越高,表明生成的答案与真实答案之间的对齐越准确,表示更高的正确性。
示例:
🟢 真实答案:埃菲尔铁塔于 1889 年在法国巴黎竣工。
⬆️ 高准确性答案:埃菲尔铁塔的建设于 1889 年在法国巴黎完成。
⬇️ 低准确性答案:埃菲尔铁塔于 1889 年竣工,位于英国伦敦。
3 — 使用 RAGAS 评估 RAG 📏
为了评估 RAG 并计算四个指标,我们可以使用Ragas。
Ragas(用于 Rag 评估)是一个帮助你评估检索增强生成(RAG)管道的框架。
要在我们的数据集上运行 Ragas,你首先需要导入指标并将合成数据的数据框转换为数据集对象。
from datasets import Dataset
from ragas.llms import LangchainLLM
from ragas.metrics import (
answer_correctness,
answer_relevancy,
answer_similarity,
context_precision,
context_recall,
context_relevancy,
faithfulness,
)
synthetic_ds_rag = Dataset.from_pandas(synthetic_dataset_rag)
然后,我们需要配置 Ragas 以使用 VertexAI LLMs 和嵌入。
这一步很重要,因为 Ragas 默认配置为使用 OpenAI。
metrics = [
answer_relevancy,
context_precision,
faithfulness,
answer_correctness,
answer_similarity,
]
for m in metrics:
m.__setattr__("llm", ragas_vertexai_llm)
if hasattr(m, "embeddings"):
m.__setattr__("embeddings", vertexai_embeddings)
answer_correctness.faithfulness = faithfulness
answer_correctness.answer_similarity = answer_similarity
最后,我们在合成数据集上调用evaluate
函数,并指定我们想要计算的指标:
from ragas import evaluate
results_rag = evaluate(
synthetic_ds_rag,
metrics=[
answer_relevancy,
context_precision,
faithfulness,
answer_correctness,
],
)
评估完成后,你可以直接打印结果。
{
'answer_correctness': 0.86875,
'answer_relevancy': 0.9709101875947284,
'context_precision': 0.8541666666143055,
'faithfulness': 0.9375
}
或者你可以将其转换为数据框,以检查每个问题的这些指标。
| | question | contexts | answer | ground_truths | answer_relevancy | context_precision | faithfulness | answer_correctness |
|---:|:----------------------------------------------------------------------------------|:-----------|:--------------------------------------------------|:--------------------------------------------------|-------------------:|--------------------:|---------------:|---------------------:|
| 1 | What is the difference between lists and tuples in Python? | ... | Lists are mutable, while tuples are immutable. | ['Lists and tuples are both ordered sequences of | 0.987162 | 1 | 1 | 0.75 |
| | | | This means that the elements of a list can be | elements in Python. However, lists are mutable, | | | | |
| | | | changed, while the elements of a tuple cannot. | meaning their elements can be changed, while | | | | |
| | | | Additionally, tuples can be used as keys in | tuples are immutable, meaning their elements | | | | |
| | | | dictionaries, while lists cannot. | cannot be changed.'] | | | | |
| 5 | What is the name of the Python runtime that uses just-in-time compilation? | ... | The name of the Python runtime that uses just-in- | ['Pyston'] | 1 | 1 | 1 | 1 |
| | | | time compilation is Pyston. | | | | | |
| 13 | What is the name of the programming language that Python was designed to replace? | ... | The programming language that Python was designed | ['ABC programming language'] | 0.982582 | 0.416667 | 1 | 0.5 |
| | | | to replace is called ABC. | | | | | |
| 2 | What is Python's approach to type checking? | ... | Python uses a combination of dynamic typing and | ["Python follows dynamic typing, where type | 0.890692 | 0.916667 | 1 | 0.666667 |
| | | | duck typing, with optional static type checking | constraints are not checked at compile time but | | | | |
| | | | available through the use of type annotations and | may result in operational failures if an object's | | | | |
| | | | the mypy type checker. | type is unsuitable."] | | | | |
| 12 | Which programming language has been the most popular since 2003? | ... | According to the TIOBE Programming Community | ['Python'] | 0.899585 | 0.5 | 1 | 0.75 |
| | | | Index, Python has consistently ranked in the top | | | | | |
| | | | ten most popular programming languages since 2003 | | | | | |
| | | | and as of December 2022, it was the most popular | | | | | |
| | | | language. | | | | | |
结论
生成一个合成数据集以评估你的 RAG 是一个好的开始,特别是当你没有访问标注数据时。
然而,这个解决方案也有其问题。
一些生成的答案:
-
可能缺乏多样性
-
是冗余的
-
只是对原始文本的简单改述,需要更多的复杂性来反映需要推理的实际问题
-
可能过于笼统(尤其是在非常技术性的领域)
为了应对这些问题,你可以调整和调整你的提示,过滤无关的问题,创建特定主题的合成问题,并使用 Ragas 进行数据集生成。
R 数据分析:如何为您的孩子找到完美的 Cocomelon 视频
如何使用 R 从头开始构建一个端到端的数据项目,探索新的流行 Cocomelon 视频
·发布于 Towards Data Science ·9 分钟阅读·2023 年 3 月 4 日
--
由 Tony Sebastian 提供的照片,来源于 Unsplash
Cocomelon — Nursery Rhymes 是全球第二大 YouTube 频道(155M+ 订阅者)。这是一个如此受欢迎和有用的频道,对于幼儿和父母来说都是不可或缺的。我喜欢和我的儿子一起观看 Cocomelon。
在观看 Cocomelon 视频一个月后,我注意到 YouTube 上重复推荐相同的视频。像 “The wheel on the bus” 和 “bath song” 这样的热门视频虽然有趣,但它们已经发布多年,孩子们看了会感到厌倦。作为父亲,我希望展示一些较新的高质量 Cocomelon 视频。作为数据专业人士,我也希望深入探索全球第二大 YouTube 频道的数据,以获得更多见解并发现有趣的数据。
YouTube 频道中的所有视频只提供用户两个选项:最近上传(按时间排序)和热门(按观看次数排序)。我可以去最近上传的标签页,逐个点击。然而,Cocomelon 频道有 800 多个视频,这会很耗时间。
好消息是,我是一名工程师,知道如何利用数据构建某些东西。因此,我开始编写代码,收集数据,进行清理、可视化,并获得更多见解。我将分享我使用 R 进行数据分析的历程:从头开始构建一个端到端的解决方案,用于探索流行的 Cocomelon 视频。
注意:虽然我在 R 中编写的示例代码和 Youtube 频道是针对 Cocomelon 的,但它们是我的偏好。你也可以使用 Python 或 Rust 的数据分析工具进行编写,我将展示如何从 Youtube 获取数据适用于其他频道。
如何使用 R 获取 Youtube 数据
数据源总是任何数据项目的起点。我已经进行了几次尝试来达到最终解决方案。
我首先在 Google 上搜索了术语:“Cocomelon 的 Youtube 观看统计”,它显示了一些关于频道的统计数据,但没有覆盖每个视频的更详细数据。这些网站广告泛滥,网络爬虫可能会很困难。
然后我查看了 Kaggle 上的公共数据集,像 CC0 数据集中的 Trending YouTube Video Statistics 可能是一个不错的选择。然而,在探索数据集后,我发现了两个问题:
-
数据集中不包含 Cocomelon
-
内容是几年前获取的,需要我想要搜索的更新视频。
我唯一的选择是直接从 Youtube 拉取最新数据。这里还有两个选项:
-
网络爬虫:我可以设置一个爬虫或在 GitHub 上找到一个项目直接使用。我的担忧是,如果爬虫过于激进,可能会封锁我的 Youtube 账户。而且爬虫对于从众多视频中拉取数据并不是很高效。
-
Youtube API: 我最终找到了这个解决方案。它高效且提供一些基本的视频统计信息:观看次数和点赞数。我们可以进一步利用这些信息来构建我们的数据分析项目。
将 Youtube 数据加载到 R 数据框
获取 Youtube API 密钥以拉取数据
Youtube API 允许你从 Youtube 拉取数据。你首先需要访问 console.cloud.google.com/apis
,然后使用 API 密钥“创建凭据”。默认密钥没有限制;你可以将 API 密钥仅限于 Youtube 使用。
Google Cloud 创建凭据 | 作者图片
使用 R 获取 Youtube 频道播放列表
一旦你有了 API 密钥,请参考 Youtube 数据 API 获取更多关于支持的潜在数据的参考。为了在查询阶段检查 API,我们可以使用 Postman 等工具或直接复制完整 URL。
例如,我们想拉取 Cocomelon 的频道信息;然而,我通过检查其 URL 没有找到其频道 id,但通过一些谷歌搜索找到了它。
https://www.youtube.com/channel/UCbCmjCuTUZos6Inko4u57UQ
现在我们可以使用频道 id 来构建 GET 方法,并将 API 密钥填入密钥字段:
https://www.googleapis.com/youtube/v3/channels?part=snippet,contentDetails,statistics&id=UCbCmjCuTUZos6Inko4u57UQ&key=
从返回的 JSON 中,最关键的信息是播放列表信息,它进一步告诉我们所有视频的情况。
"contentDetails": {
"relatedPlaylists": {
"likes": "",
"uploads": "UUbCmjCuTUZos6Inko4u57UQ"
}
}
由于新采用了分页,每页最多 50 项,调用 playlistItems
将需要时间才能达到最终列表。我们需要使用当前的令牌来检索下一页,直到找不到下一页为止。我们可以在 R 中将所有内容整合在一起。
library(shiny)
library(vroom)
library(dplyr)
library(tidyverse)
library(httr)
library(jsonlite)
library(ggplot2)
library(ggthemes)
library(stringr)
key <- "to_be_replace"
playlist_url <-
paste0(
"https://www.googleapis.com/youtube/v3/playlistItems?part=snippet,contentDetails,status&maxResults=50&playlistId=UUbCmjCuTUZos6Inko4u57UQ&key=",
key
)
api_result <- GET(playlist_url)
json_result <- content(api_result, "text", encoding = "UTF-8")
videos.json <- fromJSON(json_result)
videos.json$nextPageToken
videos.json$totalResults
pages <- list(videos.json$items)
counter <- 0
while (!is.null(videos.json$nextPageToken)) {
next_url <-
paste0(playlist_url, "&pageToken=", videos.json$nextPageToken)
api_result <- GET(next_url)
print(next_url)
message("Retrieving page ", counter)
json_result <- content(api_result, "text", encoding = "UTF-8")
videos.json <- fromJSON(json_result)
counter <- counter + 1
pages[[counter]] <- videos.json$items
}
## Combine all the dataframe into one
all_videos <- rbind_pages(pages)
## Get a list of video
videos <- all_videos$contentDetails$videoId
all_videos
应该会给我们所有视频的字段。我们在这个阶段只关心 videoId,这样我们才能获取每个视频的详细信息。
迭代视频列表并获取每个视频的数据
一旦所有视频都存储在一个向量中,我们可以复制类似于播放列表的处理过程。这次会更容易,因为我们不需要处理分页。
在这个阶段,我们会更关注最终从视频 API 调用中提取的数据。我选择了那些用于后续数据分析和可视化的。为了节省再次提取数据的时间,最好将数据持久化到 CSV 文件中,这样我们就不必多次运行 API 调用了。
videos_df = data.frame()
video_url <-
paste0(
"https://www.googleapis.com/youtube/v3/videos?part=contentDetails,id,liveStreamingDetails,localizations,player,recordingDetails,snippet,statistics,status,topicDetails&key=",
key
)
for (v in videos) {
a_video_url <- paste0(video_url, "&id=", v)
print(v)
print(a_video_url)
api_result <- GET(a_video_url)
json_result <- content(api_result, "text", encoding = "UTF-8")
videos.json <- fromJSON(json_result, flatten = TRUE)
# colnames(videos.json$items)
video_row <- videos.json$items %>%
select(
snippet.title,
snippet.publishedAt,
snippet.channelTitle,
snippet.thumbnails.default.url,
player.embedHtml,
contentDetails.duration,
statistics.viewCount,
statistics.commentCount,
statistics.likeCount,
statistics.favoriteCount,
snippet.tags
)
videos_df <- rbind(videos_df, video_row)
}
write.csv(videos_df, "~/cocomelon.csv", row.names=TRUE)
在 R 中探索 Cocomelon YouTube 视频数据
数据已为我们下一阶段探索 Cocomelon YouTube 视频做好准备。现在是进行一些清理并创建可视化以展示发现的结果的时候了。
默认的对象数据类型在后续排序中效果不佳,因此我们需要将一些对象转换为浮点数或日期类型。
videos_df <- videos_df %>% transform(
statistics.viewCount = as.numeric(statistics.viewCount),
statistics.likeCount = as.numeric(statistics.likeCount),
statistics.favoriteCount = as.numeric(statistics.favoriteCount),
snippet.publishedAt = as.Date(snippet.publishedAt)
)
最受欢迎的 5 个 Cocomelon 视频是什么?
这部分很简单。我们需要选择感兴趣的字段,然后按字段 viewCount
降序排序视频。
videos_df %>%
select(snippet.title, statistics.viewCount) %>%
arrange(desc(statistics.viewCount)) %>% head(5)
# Output:
# snippet.title statistics.viewCount
#1 Bath Song | CoComelon Nursery Rhymes & Kids Songs 6053444903
#2 Wheels on the Bus | CoComelon Nursery Rhymes & Kids Songs 4989894294
#3 Baa Baa Black Sheep | CoComelon Nursery Rhymes & Kids Songs 3532531580
#4 Yes Yes Vegetables Song | CoComelon Nursery Rhymes & Kids Songs 2906268556
#5 Yes Yes Playground Song | CoComelon Nursery Rhymes & Kids Songs 2820997030
对于你之前观看过 Cocomelon 视频的人来说,看到“Bath Song”、“Wheels on the Bus”和“Baa Baa Black Sheep”排名前三并不意外。这与 Cocomelon 在 YouTube 上的 popular
标签相匹配。此外,“Bath Song”的播放次数比第二名“Wheels on the Bus”多 20% 以上。我可以看出许多幼儿在洗澡时遇到困难,让孩子们观看这个视频可以让他们知道如何洗澡,并安慰他们让他们平静下来。
我们还创建了一个包含前 5 个视频的条形图:
ggplot(data = chart_df, mapping = aes(x = reorder(snippet.title, statistics.viewCount), y = statistics.viewCount)) +
geom_bar(stat = "identity",fill="lightgreen") +
scale_x_discrete(labels = function(x) str_wrap(x, width = 16)) +
theme_minimal()
最受欢迎的 5 个 Cocomelon 视频 | 图片作者
观看次数和点赞数之间的相关性是什么?
观看次数和点赞数之间是否存在相关性:视频是否更有可能因观看次数多而获得点赞?
我们可以进一步用数据证明这一点。首先,标准化 viewCount
和 likeCount
以便更好地进行可视化。其次,我们还计算了自视频上传以来的天数,以获取流行视频的创建时间。
chart_df <- videos_df %>%
mutate(
views = statistics.viewCount / 1000000,
likes = statistics.likeCount / 10000,
number_days_since_publish = as.numeric(Sys.Date() - snippet.publishedAt)
)
ggplot(data = chart_df, mapping = aes(x = views, y = likes)) +
geom_point() +
geom_smooth(method = lm) +
theme_minimal()
cor(chart_df$views, chart_df$likes, method = "pearson")
## 0.9867712
Cocomelon 视频观看次数和点赞数的相关性 | 图片作者
相关系数为 0.98,非常高的相关性:视频的观看次数越多,获得点赞的可能性越大。令人着迷的是,只有六个视频的观看次数超过 20 亿:家长和孩子们喜欢这六个视频,并且可能会观看很多次。
我们可以进一步绘制热门视频,并发现最热门的视频,年龄在 1500–2000 天之间,显示这些视频大约在 2018 或 2019 年制作。
按观看次数计算的发布天数 | 作者提供的图片
如何检查新的热门 Cocomelon 视频?
热门视频很容易获取。然而,4、5 年前制作的热门视频由于大量的每日视频仍然可能保持热门。
怎么样找到新的 Cocomelon 视频的观看次数?由于我们只能从 Youtube API 拉取当前状态下的观看次数,我们需要在几天之间从 API 拉取数据,暂时存储数据。
f1 <- read_csv("~/cocomelon_2023_2_28.csv")
df2 <- read_csv("~/cocomelon_2023_3_2.csv")
df1<- df1 %>% transform(
statistics.viewCount = as.numeric(statistics.viewCount)
)
df2<- df2 %>% transform(
statistics.viewCount = as.numeric(statistics.viewCount),
snippet.publishedAt = as.Date(snippet.publishedAt)
)
df1 <- df1 %>% select(snippet.title,
statistics.viewCount)
df2 <- df2 %>% select(snippet.title,
snippet.publishedAt,
statistics.viewCount)
# Join data by snippet.title
joined_df <- inner_join(df1, df2, by = 'snippet.title')
joined_df <- joined_df %>%
mutate(
view_delta = statistics.viewCount.y - statistics.viewCount.x,
number_days_since_publish = as.numeric(Sys.Date() - snippet.publishedAt)
)
# Recent Video uploaded within 200 days and top 5 of them by view delta
chart_df <- joined_df %>%
filter(number_days_since_publish<=200) %>%
select(snippet.title, view_delta) %>%
arrange(desc(view_delta)) %>% head(5)
ggplot(data = chart_df,
mapping = aes(
x = reorder(snippet.title, view_delta),
y = view_delta
)) +
geom_bar(stat = "identity", fill = "lightblue") +
scale_x_discrete(
labels = function(x)
str_wrap(x, width = 16)
) +
theme_minimal()
# Output
# snippet.title view_delta
#1 🔴 CoComelon Songs Live 24/7 - Bath Song + More Nursery Rhymes & Kids Songs 2074257
#2 Yes Yes Fruits Song | CoComelon Nursery Rhymes & Kids Songs 1709434
#3 Airplane Song | CoComelon Nursery Rhymes & Kids Songs 977383
#4 Bingo's Bath Song | CoComelon Nursery Rhymes & Kids Songs 951159
#5 Fire Truck Song - Trucks For Kids | CoComelon Nursery Rhymes & Kids Songs 703467
新的热门 Cocomelon 视频 | 作者提供的图片
顶级热门视频是 🔴 CoComelon Songs Live 24/7。这个视频展示了家长可以让孩子们自动轮播视频而无需明确切换视频。其他视频也展示了潜在的好单曲,值得推荐。
最后的想法
在 Youtube 上有很多适合孩子观看的视频。Cocomelon 有许多视频,我希望在孩子每天允许的观看时间内展示好的视频。寻找这些热门视频对数据专业人士来说是一次迷人的探索。
希望我的帖子对你有帮助。接下来,我将继续我的 R 之旅,并使用 Shiny 构建一个与用户互动的应用程序。
希望这个故事对你有帮助。本文是我工程与数据科学故事系列的一部分,目前包括以下内容:
数据工程与数据科学故事
查看列表 53 个故事!
你也可以 订阅我的新文章 或成为 推荐的 Medium 会员,享受对 Medium 上所有故事的无限访问。
如果有任何问题/评论,请随时在此故事的评论中留言或通过 Linkedin 或 Twitter 直接联系我。
R 工具包用于人力分析:讲述你的员工人数故事
原文:
towardsdatascience.com/r-toolkit-for-people-analytics-telling-your-headcount-story-d872402d4e8b
使用 R 解决的人力分析中的常见挑战
·发布于Towards Data Science ·阅读时间 11 分钟·2023 年 7 月 6 日
--
在人力分析工作中,你经常需要讲述公司员工人数的变化及公司如何演变成今天的样子。我经常看到这被展示为瀑布图,这很好,但在分享逐年变化时,特别是对不太懂技术的观众,可能会变得模糊。
为了满足这个需求,我创建了逐年的迭代图,突出显示每年的一些额外背景信息。这些图表可以添加到 PowerPoint 中逐年展示,或者可以动画化为一个 gif。让我们一起制作吧!
用区域图的 gif 讲述员工人数的变化。图片由作者提供。
挑战:讲述我们的员工人数如何逐年变化,最终达到今天的状态。
步骤:
1. 加载必要的包和数据
2. 计算每月员工人数
3. 为每一年增加相关的背景信息
4. 创建图表
5. 设置自动为每年创建图表
6. 调整主题和图表格式
1. 加载必要的包和数据
对于这个挑战,我们将需要以下包:
-
tidyverse
-
hrbrthemes(用于美化我们的图表)
要创建我们的视觉效果,我们需要一个包含唯一标识符(即员工 ID)、入职日期和离职日期的文件。我将使用模拟数据来进行这个例子(在底部我包含了生成模拟数据的代码,如果你想逐步跟随)。
# load packages
library(tidyverse)
library(hrbrthemes)
# load data
employee_data <- mock_data
# alternatively you could use something like employee_data <- read.csv("input.csv")
顺便提一下,我通常会给我最初读取的数据分配一个变量,然后创建一个新变量用于后续的操作。这并不总是必要的,但在处理大数据集时可以加快速度,这样你就不需要每次修改代码时都重新加载数据。
原始输入数据的概览。图片来源于作者。
为了确保计算正确,我们需要确保 R 知道入职日期和终止日期实际上是日期。一般来说,在 R 中处理日期可能会很麻烦,但为了这个挑战,我们需要将日期列格式化为日期,并确保没有 NA。
df <- employee_data %>%
mutate(Hire.Date = as.Date(Hire.Date, format = "%m/%d/%Y"),
Termination.Date = as.Date(Termination.Date, format = "%m/%d/%Y"))
在我的输入文件中,仍在职的员工终止日期为空,因为他们尚未离职。如果日期列中有空值,R 会变得很挑剔,所以我们要添加一行代码,给这些空值赋一个很远的未来日期。
df <- employee_data %>%
mutate(Hire.Date = as.Date(Hire.Date, format = "%m/%d/%Y"),
Termination.Date = as.Date(Termination.Date, format = "%m/%d/%Y")) %>%
mutate(Termination.Date = if_else(is.na(Termination.Date),
as.Date("2100-12-31"), Termination.Date))
这一行代码表示在终止日期列中只要有 NA/空白,就赋一个很远的未来日期。在这种情况下,选择 2100 年 12 月 31 日。希望到那时我还不在工作。
2. 计算每月人员数量
希望这一步看起来很简单,但我在弄明白这个过程时遇到了不少困难,所以请对自己有耐心。
首先,我们将创建一个包含每个月日期的序列,然后设置一个数据框作为我们每月人员数量的占位符,最后我们将使用sapply
函数计算每个月的人员数量。开始吧!
为每个月创建一个日期序列(例如,2023 年 1 月 1 日、2023 年 2 月 1 日等):
month_seq <- seq(from = min(df$hire_date),
to = max(df$hire_date),
by = "1 month")
这表示从最早的入职日期开始,到最晚的入职日期,按月份生成序列。这给我们每个月的数据留下一个值。它的样子是这样的:
显示每月序列。图片来源于作者。
现在我们要利用这个序列来创建一个起始数据框,然后我们可以在其中添加人员数量。
headcount_data <- data.frame(Date = month_seq)
好的,现在进入棘手的部分。我们将计算headcount_data
数据框中每个日期的在职员工人数。也就是说,计算 2018 年 1 月 1 日、2018 年 2 月 2 日等日期的在职员工数量。
假设我们要计算 2018 年 1 月 1 日的情况。我们需要找出入职日期早于或等于 2018 年 1 月 1 日且终止日期晚于 2018 年 1 月 1 日的员工数量。换句话说,就是已经被雇佣但尚未离职的员工数量。
然后我们只需使用sapply
对headcount_data
中的每个日期进行操作。
headcount_data <- headcount_data %>%
mutate(Active.Employees = sapply(Date, function(x) {
sum(x >= df$hire_date & (is.na(df$termination_date) | x < df$termination_date))
}))
还跟得上吗?如果你已经把所有内容都搞定了,给自己一个大大的鼓励吧!如果你遇到问题,也要为自己走到这一步而感到骄傲,并查看完整代码在这里,看看是否能发现代码中的任何不一致之处。
3. 添加相关背景
这是讲故事部分的开始。根据你对组织的了解程度,你可能需要采访一些主题专家或资深员工。基本上,你希望添加有助于解释 headcount 增减的背景信息。
我想为每一年添加背景(你也可以按月添加),所以我将向 headcount_data 添加一个年份列。
headcount_data <- headcount_data %>%
mutate(year = as.integer(year(Date))
这将为每个日期添加一个年份列:
为每个日期添加了年份列。图片来源:作者。
现在,我们要为每一年添加背景。假设对于 2020 年,我们想添加背景“COVID-19”,并希望在 2020 年的每个月都显示出来。
为此,我们将使用 case_when 来添加一个基于年份的“context”列。
headcount_data <- headcount_data %>%
mutate(context = case_when(
year == 2018 ~ "Context for 2018",
year == 2019 ~ "Context for 2019",
year == 2020 ~ "COVID-19",
TRUE ~ "No additional context"
))
上面的例子中,我们是说对于每一行年份为 2018 的数据,我们希望背景列为“2018 的背景”。你可以为每一年感兴趣的年份添加背景,然后在 TRUE 条件下,可以指定对未在上面指定的年份的背景是什么。
到这个时候,你的 headcount_data 应该看起来像这样:
添加了背景列的数据集。图片来源:作者。
现在进入有趣的部分!我们可以开始绘图了。
4. 创建图表
首先,我们将使用 ggplot 创建一个包含所有数据的基本面积图。我们将把 Date 放在 x 轴上,把 Active.Employees 放在 y 轴上,这样我们就可以看到 headcount 随时间的变化。
headcount_data %>%
ggplot(aes(x = Date, y = Active.Employees)) +
geom_area()
这将给你这个基本图表:
完整数据集的基本面积图。图片来源:作者。
现在我们开始进行一些基本的 zhushing,然后再进行一些更高级的 zhushing:
1. 添加注释
2. 添加标题和副标题
我们将添加包含最终 headcount 和年份的注释(当我们为每一年制作图表时,这会变得更相关)。让我们先将它们分配给变量,以便于每年更新:
# annotations
annotation_ending_year <- max(headcount_data$year)
annotation_ending_headcount <- max(headcount_data$Active.Employees)
# titles
labels_title <- "Our Headcount Story"
labels_subtitle <- last(headcount_data$context)
现在我们将把这些添加到我们的基本图表中:
headcount_data %>%
ggplot(aes(x = Date, y = Active.Employees)) +
geom_area() +
labs(title = labels_title,
subtitle = labels_subtitle) +
annotate("text",
x = max(headcount_data$Date),
y = max(headcount_data$Active.Employees),
label = annotation_ending_headcount,
hjust = -.25)
这将给我们一个基本的图表,并附带一些额外的背景信息:
带有标题和注释的基本图表。图片来源:作者。
既然我们创建了基本图表,我们希望自动为每一年创建一个附加图表。所以将会有一个从 2018 年开始到 2018 年底的图表,一个从 2018 年到 2019 年底的图表,一个从 2018 年到 2020 年底的图表,等等。
5. 自动为每年创建一个图表
我们将使用 for 循环来为数据集中的每一年创建一个图表。
基本上,我们将把数据集中每个独特的年份放到一个名为“years”的向量中。然后对于“years”中的每一年,我们将创建一个数据子集,然后绘制该子集的图表。这听起来可能很混乱,但查看代码可能会更清楚。
首先进行一些设置:
# create a vector for unique years
years <- unique(headcount_data$year)
# empty list for plots to go to
plots <- list()
现在进入循环!这可能看起来很复杂,但一步一步来就好:
# loop over the each year in years and create plots
for (i in 2:length(years)) {
# create subset adding one year at a time
subset_df <- headcount_data %>%
filter(year <= years[i])
# calculations for annotation
annotation_ending_year <- max(subset_df$Date)
annotation_ending_active <- subset_df %>%
filter(Date == ending_year) %>%
select(Active.Employees) %>%
as.numeric()
# create a plot (p) using the subset
p <- subset_df %>%
ggplot(aes(x = Date, y = Active.Employees)) +
geom_area() +
labs(title = labels_title,
subtitle = labels_subtitle) +
annotate("text",
x = max(subset_df$Date),
y = max(subset_df$Active.Employees),
label = ending_active,
hjust = -.25)
# save each plot
ggsave(p,
file = paste("example_plot_", years[i], ".png"),
height = 6, width = 8, units = "in")
}
你现在应该在你的工作目录中有一个名为“example_plot_year”的每年图表。我喜欢每年一个单独的图表,这样我可以把每一个放在幻灯片中,并在大家有问题时暂停。或者,你可以将图表动画化并创建一个 gif,或者使用像ScreenToGif这样的屏幕录制工具,得到这样的效果:
使用 ScreenToGif 合成的图表动画 gif。图片由作者提供。
我们做到了!!!!剩下的就是添加一些样式,以使图表更符合你的品牌,并添加一个矩形来突出显示最近的一年。
6. 调整主题和图表格式
我想做的第一件事是添加一个矩形来突出显示最近的一年。这将帮助观众知道要关注的重点,并且在每个图表中都会更新,这样我们可以在更大的背景下逐年查看。
我们将通过添加另一个“rect”注解层来完成,这看起来会是这样的:
annotate("rect", xmin = , xmax = , ymin = , ymax = )
这是另一个花了我一段时间才调整到我想要的方式,但关键点在于:
X 轴:我希望矩形从给定年份内的第一个(即底部)日期(即我们数据子集中的最大年份)开始,并在给定年份内的最后一个(即顶部)日期(即我们数据子集中的最大年份)结束。因此,对于 2019 年的图表,我们希望矩形从 2019 年 1 月 1 日开始,到 2019 年 12 月 1 日结束。
annotate("rect",
xmin = floor_date(max(subset_df$Date), "year"),
xmax = ceiling_date(max(subset_df$Date), "year")
Y 轴:我希望矩形从 y 轴开始,到该年最终人数以上的位置结束,这样更容易阅读而不会显得拥挤。再看一下 2019 年,我希望矩形从 y 轴开始,并在最终人数 240 之上(+300)的位置结束。
annotate("rect",
xmin = floor_date(max(subset_df$Date), "year"),
xmax = ceiling_date(max(subset_df$Date), "year"),
ymin = -Inf, ymax = ending_active + 300)
样式:最后,我会把框设置为灰色,并将透明度改为 0.1,使其相当透明,你可以看到下面的区域图:
annotate("rect",
xmin = floor_date(max(subset_df$Date), "year"),
xmax = ceiling_date(max(subset_df$Date), "year"),
ymin = -Inf, ymax = ending_active + 300,
alpha = .1, color = "gray", fill = "gray")
限制坐标轴:为了使过渡更平滑,我将对 x 轴和 y 轴设置限制,使每个图表的比例相同。
scale_x_date(breaks = "1 year", date_labels = "%Y",
expand = c(.1,.1),
limits = c(min(headcount_data$Date), max(headcount_data$Date)))
太棒了!我们快完成了,现在我要对主题进行一些更改,然后给自己倒一杯酒。现在是时候发挥你自己的创意了,我的最终效果是这样的:
最终产品!图片由作者提供。
这是我最终的 for 循环代码:
# loop over the each year in years and create plots
for (i in 2:length(years)) {
# create subset adding one year at a time
subset_df <- headcount_data %>%
filter(year <= years[i])
# calculations for annotation
ending_year <- max(subset_df$Date)
ending_active <- subset_df %>%
filter(Date == ending_year) %>%
select(Active.Employees) %>%
as.numeric()
# create a plot (p) using the subset
p <- subset_df %>%
ggplot(aes(x = Date, y = Active.Employees)) +
geom_area(fill = "#457b9d") +
labs(title = "Our Headcount Story",
subtitle = paste(years[i],":", last(subset_df$context)),
x = "", y = "") +
scale_x_date(breaks = "1 year", date_labels = "%Y",
expand = c(.1,.1),
limits = c(min(headcount_data$Date), max(headcount_data$Date))) +
theme_classic(base_family = "Arial") +
theme(plot.title = element_text(size = 24, face = "bold", color = "#457b9d"),
plot.subtitle = element_text(size = 18),
panel.grid.major = element_blank(),
panel.grid.minor = element_blank(),
axis.ticks.y = element_blank(),
axis.text.y = element_blank(),
axis.line.y = element_blank()) +
annotate("text", x = ending_year,
y = ending_active, label = ending_active,
vjust = -1.25, hjust = -.25, color = "#457b9d") +
annotate("rect",
xmin = floor_date(max(subset_df$Date), "year"),
xmax = ceiling_date(max(subset_df$Date), "year"),
ymin = -Inf, ymax = ending_active + 300,
alpha = .1, color = "gray", fill = "gray")
# save each plot
ggsave(p,
file = paste("example_plot_final", years[i], ".png"),
height = 6, width = 8, units = "in")
}
全部完成!
我们现在有了一个动态视图,展示了我们的员工人数如何随着时间的推移而变化,并在副标题中提供了额外的背景。一些未来迭代的想法:使用 gganimate 制作图表,为每年的员工人数添加百分比变化,如果员工人数增加或减少则改变图表颜色,添加增长趋势线的预测,可能性无穷无尽!
你尝试制作了吗?如果是的话,我很想看看你做了什么!
想要更多通用的人员分析资源吗?
我推荐给那些想要开始人员分析职业生涯的人的免费资源。
如果你想要更多这样的资源,并且访问网站上的所有优质内容,可以使用我的链接以每月$5 注册(我将获得少量佣金,但你无需额外支付费用)。
[## 使用我的推荐链接加入 Medium - Jenna Eagleson
阅读 Jenna Eagleson 的每一个故事(以及 Medium 上成千上万其他作者的故事)。你的会员费直接支持…
我的背景是工业组织心理学,我在人员分析领域找到了归属。数据可视化让我的工作生动起来。我喜欢使用 Power BI、R、Tableau 以及我遇到的其他工具进行学习和开发。我很想了解你的经历!可以通过Linkedin或Twitter与我联系。
径向树图:将树图扩展到圆形映射
原文:
towardsdatascience.com/radial-treemaps-extending-treemaps-to-circular-mappings-7b47785191da
了解径向树图并用 Python 创建自己的树图
·发表于 Towards Data Science ·阅读时间 16 分钟·2023 年 12 月 10 日
--
径向树图,作者 Nick Gerend
背景
树图概念
“树图”由 Ben Shneiderman 在 1990 年代初期于马里兰大学提出¹。简而言之,它是一种将层次数据以嵌套矩形的形式高效展示的方法。尽管这一概念很简单,但矩形的排列受美学偏好的影响,因此已经开发出各种排列算法来改善最终布局的外观。
树图原理
给定一个层次结构,树图将层次结构中的每个分支表示为一个矩形,然后用代表子分支的较小矩形进行拼接。树图中的空间是根据数据的特定属性(通常是大小或值)进行划分的,每个矩形的面积对应于该属性的大小,使得比较层次结构中的不同部分变得容易。
树图中 a、b 和 c 组的顺序 -> 每个层级的最大项:(a1)、(a1,b1)、(a1,b1,c1)
为了考虑矩形的排列,以下是一些常见的算法,它们控制树图的构造和最终外观:
-
方形树图² - 通过调整矩形的长宽比,使矩形尽可能接近正方形
-
条形树图³ - 根据数据的层次结构,以水平或垂直方式布置矩形
-
切片和切割⁴ - 交替进行水平和垂直分割,虽然直观,但可能会生成较长的矩形
树图特征
-
比例 - 每个矩形的大小与其代表的数据点成比例,使得快速识别较大和较小的项变得容易
-
线条和颜色 - 边框、边框大小和颜色以及缓冲区的巧妙使用可以划分层次级别,而容器颜色通常用来表示数据的不同维度
-
空间效率 - Treemaps 特别适合可视化大型数据集,因为它们有效利用空间,允许同时显示成千上万的项目,平铺算法确定最终布局
-
顺序 - 当收集的数据独立于已知层次结构时,分组中维度的顺序很重要,因为它决定了层次结构每一级的父子关系
总体而言,Treemap 的优势在于能够以空间高效和视觉吸引的方式展示复杂的层次数据,使其成为任何研究领域中受欢迎的可视化工具。
多态性
层次数据及其作为嵌套多边形和形状的表示已经有各种有用且视觉吸引的实现。我最喜欢的之一是 Voronoi Treemaps⁵,而且我特别喜欢与 Voronoi 形状结合的 3D 映射(如 3D Voronoi Treemap Sphere)的想法。
可能还有许多 Treemap 表现形式等待开发,下面我将介绍我自己实现的径向 Treemap。在撰写这篇文章时,我还发现了一个名为 Krona⁶ 的工具(它似乎有类似的输出风格),这是通过反向图像搜索我自己的径向 Treemap 图表时发现的。
径向 Treemap
灵感
当我创建我的第一个径向 Treemap 时,我并没有考虑实现一种 Treemap 类型,而是一个关于飞行器⁷的数据集。我的目标是将这个数据集呈现成一个喷气发动机的样子,作为一种艺术可视化作品。
“Takeoff” 作者:Nick Gerend (3/11/2021)
我最终得到的东西,最初我称之为“饼图树”图表,但后来我意识到这实际上只是 Treemapping 的一种形式,所以现在我称它为径向 Treemap!
这种类型的 Treemap 特别之处在于其“径向”布局,这种布局在圆形空间中打开了各种有用且有趣的组件。我发现将它与其他类型的径向图表结合以分解层次数据的不同方面特别有趣。
在下一部分,我将讨论构建径向 Treemap 时涉及的组件。
数学、算法与布局
元素
径向 Treemap 的数学基础利用了一些基本元素来确定多边形容器的尺寸:
-
内半径 - 勾勒出容器较短的弯曲边缘
-
外半径 - 勾勒出容器较长的弯曲边缘
-
起始角度 - 勾勒出容器在最小角度处的直边
-
结束角度 - 勾勒出容器在最大角度处的直边
这些容器本质上是沿两个半径分割的饼图或甜甜圈楔形,与传统树图的矩形容器相比显得不规则,但在相对大小比较上保持相同的直观性。
来自树图和径向树图的形状,具有类似的区域
函数
现在我们已经确定了容器的基本要素,接下来我们将进入数学部分,从层次结构第一层的外部父容器的面积开始。假设我们希望对外部容器的形状有一定的灵活性,我们可以为甜甜圈切片实现一个面积函数(func_area):
# area of the outer container:
# > r1: inner radius
# > r2: outer radius
# > start_angle: degrees between 0-360
# > end_angle: degress > start_angle
area = (pi*r2**2-pi*r1**2)*((end_angle-start_angle)/360)
这允许在甜甜圈的中间(原点和内半径之间)和起始角度与结束角度之间留有空隙。这种灵活性鼓励以独特和创造性的方式配置布局,以最大限度地利用当前的用例。可能性无限!
接下来我们需要根据容器的预期方向来确定容器的角点(func_container):
# given 3 of the 4 child container paramters (r1, r2, a1, a2),
# gathered from the parent container, and the area of the child container:
# > r1: inner radius
# > r2: outer radius
# > a1: starting angle
# > a2: ending angle
# > area: area of the child container
# split the current container between two angles at a constant radius
# find the radius:
r2 = sqrt(((area/(a2-a1))+pi*r1**2)/pi)
# split the current container at a specific angle between two radii
# find the angle:
a2 = a1 + (area)/(pi*r2**2-pi*r1**2)
为了实现一个好的默认方向方法,让我们计算两个方向选项的弧长和半径长度,以选择具有最小最大长度的容器方向(我称之为“智能”方法):
# "smart" container orientation method:
# calculate the max lengths between both
# orientation options using the following:
arc_length = (2*pi*r2)*((a2-a1)/360)
radius_length = r2-r1
max_lnegth = max(arc_length, radius_length)
# select the orientation with the smallest max_length
# to avoid skinny polygons
剩下的就是将极坐标转换为笛卡尔坐标进行绘图(func_convertion):
# assuming a starting position of 12 o'clock:
# > ad: angle in degrees
# > ar: angles in radians
# > r: radius
ar = (ad-90)*pi/180
x, y = r*cos(ar), r*sin(ar)
径向树图算法
使用已建立的函数,这里是通用算法:
-
从一组互斥的组开始,具有 1 到多个维度,以及它们的计数或值,这些将决定多边形的面积,例如:({a1,b1,c1}, 12.3), ({a1,b2,c1}, 4.5), ({a2,b1,c2}, 32.3), ({a1,b2,c2}, 2.1), ({a2,b1,c1}, 5.9), ({a3,b1,c1}, 3.5], ({a4,b2,c1}, 3.1)
-
设置配置输入:(位置:{起始角度、结束角度、内半径、外半径、旋转}),(排序:{降序、升序、手动}),(容器方向方法:{交替、向外、绕行、智能、图例}),(分组:{开启、关闭})
-
计算外部容器的面积(func_area)
-
递归计算子容器的位置,利用它们相对于外部容器面积的相对面积百分比作为容器函数(func_container)的输入,配合所选择的容器方向方法
-
用点填充弧段中的多边形边界(更多的点以获得更高的曲线分辨率),并将极坐标转换为笛卡尔坐标以进行绘图
容器方向方法
我开发的原始构建方法恰好在半径分割和同心圆分割之间交替(如上所示的“起飞”信息图),类似于早期的树图,这些树图使用了“切片和切割”算法,在水平和垂直分割之间交替。
这是我目前为径向树图创建的容器方向方法:
-
替代 - 原始的!(类似于切片和骰子的替代方法)
-
向外 - 所有分区都绘制为半径
-
周围 - 所有分区都绘制为同心圆
-
智能 - 为每个选项计算弧长和半径长度,并选择最小的作为防止细长多边形的机制
-
图例 - 层级结构的第一层总是绘制为半径,以对齐相应的图例甜甜圈
还可以选择是否首先对项目进行分组。可以切换扁平化数据(移除层次分组)以按元素最低级的自然顺序排序,为另一个层次的见解提供更多信息(特别是对图例很有用)。
使用上述示例数据的径向树图配置
布局
通过包括中心和楔形空白区域、堆叠和旋转径向树图的灵活性,布局选项是无限的!
外围布局参数:
-
总面积(作为两个或更多径向树图之间的相对度量)
-
布局约束(甜甜圈切片由两个角度和两个半径确定)
-
旋转(围绕中心的旋转)
-
与其他径向树图的相对定位(堆叠等)
径向树图周边示例
图例:
- 补充的径向树图(内部、外部或两侧)(有助于说明不同级别的元素排序)
径向树图图例
可视化扩展(内/外/线性连接):
适用的可视化扩展示例(桑基图,和弦图)
3D 径向树图
径向树图结构可以轻松地在数学上扩展到 3D,并且附带了一个额外的切片和切割平面!
可用于容器化的几何形状从球坐标系统中可以看出:
-
径向距离:r ≥ 0,
-
极角:0° ≤ θ ≤ 180°(0 rad ≤ θ ≤ π rad)
-
方位角:0° ≤ φ < 360°(0 rad ≤ φ < 2π rad)
就 3D 径向树图输入而言,这是每个表面可用空间的投影:
3D 径向树图容器边界
方便的是,3D 的通用算法与 2D 相同,调整面积过渡到体积的效果,并解决半径、极角和方位角的三种可能方向。以下是一个简单的 3D 径向树图:
由 Nick Gerend 渲染的 3D 树图,使用 Autodesk Fusion 360
接下来,我将展示一个生成 2D 径向树图可视化的 Python 实现。3D 版本在开发计划中!
Python 实现
我已经通过我的 vizmath 包在 PyPI 上提供了我径向树图算法的初步实现。以下是一个使用示例:
from vizmath import rad_treemap as rt # pip install vizmath==0.0.9
import pandas as pd
# using the example data from above:
data = [
['a1', 'b1', 'c1', 12.3],
['a1', 'b2', 'c1', 4.5],
['a2', 'b1', 'c2', 32.3],
['a1', 'b2', 'c2', 2.1],
['a2', 'b1', 'c1', 5.9],
['a3', 'b1', 'c1', 3.5],
['a4', 'b2', 'c1', 3.1]]
df = pd.DataFrame(data, columns = ['a', 'b', 'c', 'value'])
# create a rad_treemap object
# > df: DataFrame with 1 or more categorical columns of data
# and an optional 'value' column for the areas
# (otherwise groups counts are used for areas)
# > groupers: group-by columns
# > value: optional value column
# > r1, r2: inner and outer radius positions
# > a1, a2: start and end angle positions
# > rotate_deg: overall rotation around the center
# > mode: container orientation method
# > other options: 'points', 'default_sort', 'default_sort_override',
# 'default_sort_override_reversed', 'mode', 'no_groups', 'full'
rt_1 = rt(df=df, groupers=['a','b','c'], value='value', r1=0.5, r2=1,
a1=0, a2=180, rotate_deg=-90, mode='alternate')
# plot the Radial Treemap
rt_1.plot_levels(level=3, fill='w')
使用组值作为区域的 Radial Treemap 通过 vizmath 渲染,使用 Matplotlib
让我们来看看 Radial Treemap 算法的输出:
-
level - 层级:从 1 到 N 层
-
group - 代表树上的每个节点:例如,组 {a1,b1,c1} 属于组 {a1,b1},而 {a1,b1} 属于
-
count - 组的计数:下方可以看到在第 1 层(最高层),组 {a2} 包含 2 个项目
-
value - 组的值(如果指定):可以使用提供的数字来表示大小,而不是使用组中项目的计数
-
层级排名 - 项目在其组中的排名,按其值(如果值不可用则按计数)从高到低排序:1 到 N
-
总体排名 - 项目在所有组中的总体排名,按其值(如果值不可用则按计数)从高到低排序:1 到 N
-
x, y - 布局中点的笛卡尔 2D 坐标
-
path - 描述封闭多边形的整数有序集合,与 Radial Treemap 中每个 (x, y) 点相结合,用于每个组:1 到 N(由‘points’参数指定)
# sample the Radial Treemap DataFrame
rt_1.to_df()[['level','group','count','value',
'level_rank','overall_rank','x','y','path']].head()
Radial Treemap DataFrame
最后,让我们看看一个忽略组值的基于计数的版本是什么样的。
# set 'value' to None or just leave it out since None is the default
# doing this sets the areas equal to the group counts
# in this case, each count will be one since there are no duplicates
rt_2 = rt(df=df, groupers=['a','b','c'], value=None, r1=0.5, r2=1,
a1=0, a2=180, rotate_deg=-90, mode='alternate')
# plot the Radial Treemap
rt_2.plot_levels(level=3, fill='w')
使用组计数作为区域的 Radial Treemap 通过 vizmath 渲染,使用 Matplotlib
Tableau Public 实现
在这一部分,我将展示如何在 Tableau Public (v 2023.3.0) 中实现我的 Radial Treemap 可视化,并介绍一些有趣的交互功能。
要开始,请向我们之前的示例中添加更多的组和值,并将数据输出到 csv 文件中以供 Tableau Public 使用。首先,创建一个包含 3 个类别列和一个数值列的 DataFrame:
import pandas as pd
data = [
['a1', 'b1', 'c1', 9.3],
['a1', 'b1', 'c2', 6.7],
['a1', 'b1', 'c3', 2.4],
['a1', 'b2', 'c1', 4.5],
['a1', 'b2', 'c2', 3.1],
['a2', 'b1', 'c1', 5.9],
['a2', 'b1', 'c2', 32.3],
['a2', 'b1', 'c3', 12.3],
['a2', 'b1', 'c4', 2.3],
['a2', 'b2', 'c1', 9.1],
['a2', 'b2', 'c2', 17.3],
['a2', 'b2', 'c3', 6.7],
['a2', 'b2', 'c4', 4.4],
['a2', 'b2', 'c5', 11.3],
['a3', 'b1', 'c1', 7.5],
['a3', 'b1', 'c2', 9.5],
['a3', 'b2', 'c3', 17.1],
['a4', 'b2', 'c1', 5.1],
['a4', 'b2', 'c2', 2.1],
['a4', 'b2', 'c3', 11.1],
['a4', 'b2', 'c4', 1.5]]
df = pd.DataFrame(data, columns = ['a', 'b', 'c', 'value'])
接下来,我们将使用 vizmath 创建 Radial Treemap 图表和图例,将两者合并到一个文件中,并将绘图信息输出到 csv:
from vizmath import rad_treemap as rt
import os
# Radial Treemap chart object
rt_obj = rt(df=df, groupers=['a','b','c'], value='value',
r1=0.5, r2=1, a1=0, a2=180, rotate_deg=-90 ,mode='legend')
rt_df = rt_obj.to_df()
rt_df['type'] = 'chart'
# Radial Treemap legend object
rt_legend_obj = rt(df=df, groupers=['a','b','c'], value='value',
r1=1.04, r2=1.09, a1=0, a2=180, rotate_deg=-90 ,mode='legend',
no_groups=True)
rt_legend_df = rt_legend_obj.to_df()
rt_legend_df['type'] = 'legend'
# export the drawing data
df_out = pd.concat([rt_df, rt_legend_df], axis=0)
df_out.to_csv(os.path.dirname(__file__) + '/radial_treemap.csv',
encoding='utf-8', index=False)
使用 文本文件 选项将文件导入 Tableau,导航到 Sheet 1,并创建这些参数和计算字段,我们将使用它们绘制图表和图例:
创建参数(从左侧“数据”标签下的汉堡菜单中选择“创建参数…”):
[Chart Level]: {整数, 范围, 最小值: 1, 最大值: 3, 步长: 3}
[Legend Level]:
创建计算字段(从相同菜单下选择“创建计算字段…”):
[rad_treemap]: 如果 [type] = ‘chart’ 且 [Level] = [Chart Level],则 MAKEPOINT([Y],[X]) 否则为 null 结束
[rad_treemap_legend]: 如果 ([type] = ‘legend’ 且 [Level] = [Legend Level]),则 MAKEPOINT([Y],[X]) 否则为 null 结束
[rad_treemap_lines]: 如果 [type] = ‘chart’ 且 [Level] <= [Chart Level],则 MAKEPOINT([Y],[X]) 否则为 null 结束
首先将[radial_treemap]拖动到标记下的详细信息中,以生成第一个地图层,然后右键点击地图区域,选择背景层来调整这些选项:
-
取消选择所有背景地图层(基础、土地覆盖等)。
-
现在在地图区域右键单击,选择地图选项并取消选择所有选项。
关闭背景层并继续以下步骤:
-
将[Group]拖到标记下的详细信息中。
-
在标记下拉菜单中选择多边形(如果此时看起来有些奇怪也不用担心)。
-
将[Path]拖动到标记下的路径中,右键点击现在的SUM(Path)并选择维度。
-
将[Value]拖动到颜色中,并重复将其转换为维度的过程。
-
在颜色下选择“编辑颜色…”,并配置以下选项:
-
点击确定,然后在颜色下将不透明度调整为 50%。
现在,径向树图的结构应该可见。让我们添加另一个层级,以使用层级的第一级项来增强颜色。首先添加一些新的计算列:
[Label]:replace(replace(replace([Group],”’”,’’),’(‘,’’),’)’,’’)
[Level 1]:split([Label],’,’,1)
[Level 2]:split([Label],’,’,2)
[Level 3]:split([Label],’,’,3)
现在让我们使用[Level 1]进行着色:
-
将[radial_treemap]拖动到地图区域,弹出窗口将显示:添加标记层 - 将该图标拖入此处以创建新的地图层。
-
重复上述步骤,但现在使用[Level 1]作为颜色。
-
在颜色下选择黑色边框,将不透明度设置为 50%。
让我们通过添加一些不同厚度的线条来总结图表部分,以指示层级边界的位置:
-
使用[rad_treemap_lines]作为地图层,线条作为标记下拉菜单中的图表类型,并将颜色设置为中等黑色,重复前面的步骤。
-
将[Level]拖动到标记下的大小,并转换为维度和离散。
-
在图表右侧的大小部分标记为Level,从容器右上角显示的下拉菜单中选择“编辑大小…”。
-
选择反向选项,点击确定,然后右键单击图表右下角的空值图标,选择隐藏指示器以隐藏空值标签。
现在图表部分已就位,应与下图类似:
让我们添加一个图例来补充图表:
- 使用[rad_treemap_legend]添加两个图表层,重复之前的所有步骤。
为了完成可视化,让我们添加一些标签层。首先添加这些参数和计算列来定位标签:
创建参数:
[Show Labels Chart]: {布尔值,别名:(True: Yes,False: No)}
[Show Labels Legend]:
创建计算列:
[point_angle]: atan2([X],[Y])*180/pi() — 90
[group_angle]: {固定 [Type],[Group]:avg([point_angle])}
[point_radius]: [X]/cos([point_angle]*pi()/180)
[group_radius_min]: {固定 [Type],[Group]:min([point_radius])}
[group_radius_max]: {固定 [Type],[Group]:max([point_radius])}
[group_radius]: ([group_radius_max]-[group_radius_min])/2+[group_radius_min]
[chart_group_legend]: 如果 [Type] = ‘chart’ 且 [Level] = [Chart Level] 且 [Show Labels Chart] 则
MAKEPOINT(
-[group_radius]sin(([group_angle])pi()/180),
[group_radius]cos(([group_angle])pi()/180)
) 否则为空结束
[legend_group_legend]: 如果 [Type] = ‘legend’ 且 [Level] = [Legend Level] 且 [Show Labels Legend] 则
MAKEPOINT(
-[group_radius]sin(([group_angle])pi()/180),
[group_radius]cos(([group_angle])pi()/180)
) 否则为空结束
现在我们将添加最后两层以完成 Radial Treemap:
-
将 [chart_group_legend] 作为地图图层添加,并将 [Group] 添加到 Marks 下的 Detail。
-
将图表类型更改为 Circle,并将 [Label] 拖到 Marks 下的 Label。
-
将 Color 调整为 50%不透明的白色,并没有边框或光晕,将滑块拖动到 Size 的中心右侧。
-
在 Label 下,点击 […] 菜单旁的 Text,在对话框中选择文本,将大小更改为 {8,粗体},然后点击 OK。
-
返回主 Label 菜单,选择 Allow labels to overlap other marks,并将 Alignment 调整为 {center,center}。
-
目前将参数 [Show Labels Chart] 切换为 False,并重复上述步骤使用 [legend_group_legend] 向图例添加标签。
要完成 Sheet 1,通过将 [Label] 拖到 Marks 下的 Tooltip 中的 Attribute,并右键点击该 pill 选择 Attribute,将 [ATTR(Label)] 添加到 Tooltips 中。以相同方式添加 [ATTR(Items)] 和 [ATTR(Value)]。
为了帮助交互式探索 Radial Treemap 中的数据,让我们创建一个简单的表格条形图。
-
使用底部面板上的第一个加号创建新工作表,生成 Sheet 2。
-
在新工作表中,将 [Level 1]、[Level 2]、[Level 3] 和 [Label] 拖到 Rows。
-
现在将 [Count] 拖到 Rows 并更改为 Dimension 和 Discrete。
-
对 [Value] 进行相同操作,将图表类型更改为 Bar,并将 [Value] 拖到 Marks 下的 Color 和 Size。
-
对于 [Value] 使用与之前工作表相同的颜色方案,并添加 80%不透明度的黑色边框。
-
右键点击列并选择 Rename,将 [Count] 重命名为 [Items]。
最后,将两个工作表汇总到仪表盘中。在创建仪表盘并添加工作表后,在仪表盘顶部菜单中的操作下设置一个操作。点击添加操作下拉菜单,选择高亮显示。在目标高亮显示下选择选择字段并选择[标签]和[ATTR(标签)]字段。最后在右侧的运行操作于菜单下选择悬停选项,现在当鼠标悬停在表格或图表中的每一层级上时,整个仪表盘将高亮显示!
添加参数到仪表盘并以有序的方式进行定位后,这里是我们在 Tableau Public 上的新仪表盘:
结论
在这篇文章中,我简要介绍了树图的历史以及我称之为“径向树图”的内容,这是一种我开发的可视化工具,用于检查循环布局中的层级关系,提供了在甜甜圈切片、堆叠、图例和与其他径向图表类型的协同方面的灵活性。它可以以多种方式使用,从数据中得出新的见解,希望你发现这种可视化技术充满启发性和潜力!
如果你对其他径向图表类型感兴趣,查看我的多弦图:
背景
towardsdatascience.com
如果你发现任何有趣或专业的使用案例,请告诉我,谢谢阅读!
参考文献
本文中的所有图像均由作者创建,除非另有说明。
[1] 本·施奈德曼, “使用树图的树状可视化:二维空间填充方法” (1992),《ACM 图形学报告》
[2] 马克·布鲁尔斯,凯斯·惠辛,贾尔克·J·范·维克,“方形树图” (2000),《数据可视化 2000:欧洲图形学和 IEEE TCVG 联合会议论文集》,荷兰阿姆斯特丹,2000 年 5 月 29-30 日
[3] 本杰明·贝德森,本·施奈德曼,马丁·瓦滕贝格,“有序和量子树图:有效利用二维空间显示层级结构” (2002),《ACM 图形学报告》
[4] 本·施奈德曼,马丁·瓦滕贝格,“有序树图布局” (2001),《INFOVIS》第 73-78 页
[5] 迈克尔·巴尔泽,奥利弗·德伊森,“Voronoi 树图” (2005),IEEE 信息可视化研讨会
[6] 布莱恩·昂多夫,尼古拉斯·伯格曼,亚当·菲利皮,“在网页浏览器中的交互式宏基因组可视化”(2011 年),BMC 生物信息学
[7] 联邦航空管理局,“飞机登记数据库”(2020 年),美国运输部
RAG:如何与您的数据交流
详细指南:如何使用 ChatGPT 分析客户反馈
·
关注 发表在 Towards Data Science ·21 分钟阅读·2023 年 11 月 11 日
--
由 DALL-E 3 提供的图像
在我的以前的文章中,我们讨论了如何使用 ChatGPT 进行主题建模。我们的任务是分析不同酒店连锁的客户评论,并确定每家酒店提到的主要主题。
通过这样的主题建模,我们知道每个客户评论的主题,可以轻松按主题筛选并深入了解。然而,在现实生活中,拥有能够涵盖所有可能用例的详尽主题集是不可能的。
例如,这是我们从客户反馈中先前识别出的主题列表。
这些主题可以帮助我们获得客户反馈的高层次概述,并进行初步预筛选。但是,假设我们想了解客户对健身房或早餐饮品的看法。在这种情况下,我们将需要自己从“酒店设施”和“早餐”主题中浏览相当多的客户反馈。
幸运的是,LLMs 可以帮助我们进行这种分析,节省大量浏览客户评论的时间(尽管自己倾听客户的声音仍可能是有帮助的)。在本文中,我们将讨论这些方法。
我们将继续使用 LangChain(最流行的 LLM 应用框架之一)。你可以在我之前的文章中找到 LangChain 的基本概述。
幼稚的方法
获取与特定主题相关的评论最直接的方法就是在文本中寻找一些特定的词汇,比如“健身房”或“饮料”。在 ChatGPT 出现之前,我曾多次使用这种方法。
这种方法的问题是相当明显的:
-
你可能会得到很多不相关的关于附近健身房或酒店餐厅酒精饮料的评论。这种过滤器不够具体,不能考虑上下文,因此你会有很多假阳性。
-
另一方面,你可能也无法获得足够好的覆盖范围。人们往往对相同的事物使用略微不同的词汇(例如,饮料、茶点、饮品、果汁等)。可能还会有拼写错误。如果你的客户说不同的语言,这个任务可能会变得更加复杂。
因此,这种方法在精准度和召回率方面都有问题。它会给你对问题的粗略理解,但能力有限。
另一种潜在的解决方案是使用与主题建模相同的方法:将所有客户评论发送给 LLM,并让模型确定它们是否与我们的兴趣主题相关(早餐饮品或健身房)。我们甚至可以要求模型总结所有客户反馈并提供结论。
这种方法可能会工作得很好。然而,它也有其局限性:每次你想深入探讨一个特定话题时,你需要将所有文档发送给 LLM。即使基于我们定义的主题进行高水平过滤,传递给 LLM 的数据量也可能相当大,而且成本也会相当高。
幸运的是,还有另一种解决这个任务的方法,它被称为 RAG。
检索增强生成
我们有一组文档(客户评论),我们希望提出与这些文档内容相关的问题(例如,“客户喜欢早餐的哪些方面?”)。正如我们之前讨论的,我们不想将所有客户评论都发送给 LLM,因此我们需要一种方法来定义最相关的评论。然后,任务将变得非常简单:将用户问题和这些文档作为上下文传递给 LLM,就可以了。
这种方法称为检索增强生成或 RAG。
作者提供的方案
RAG 的流水线包括以下几个阶段:
-
加载文档从我们拥有的数据源。
-
将文档分割为更容易进一步使用的块。
-
存储: 向量存储通常用于此用例,以有效处理数据。
-
检索与问题相关的文档。
-
生成是将问题和相关文档传递给 LLM 并获得最终答案。
您可能已经听说 OpenAI 本周推出了助理 API,它可以为您完成所有这些步骤。但我认为值得通过整个过程来理解它的工作原理及其特殊性。
因此,让我们逐步了解所有这些阶段。
加载文档
第一步是加载我们的文档。LangChain 支持不同类型的文档,例如CSV或JSON。
您可能会想知道使用 LangChain 处理这些基本数据类型的好处是什么。毫无疑问,您可以使用标准 Python 库解析 CSV 或 JSON 文件。但我建议使用 LangChain 数据加载器 API,因为它返回包含内容和元数据的文档对象。稍后使用 LangChain 文档会更容易。
让我们看看一些更复杂的数据类型的例子。
我们经常需要分析网页内容,因此必须处理 HTML。即使您已经掌握了BeautifulSoup库,您可能会发现BSHTMLLoader也很有帮助。
与 LLM 应用相关的 HTML 的有趣之处在于,很可能您需要对其进行大量预处理。如果您使用浏览器检查工具查看任何网站,您会注意到比网站上看到的文本要多得多。它用于指定布局、格式和样式等。
作者提供的图片,LangChain 文档
在大多数实际情况下,我们不需要将所有这些数据传递给 LLM。一个站点的整个 HTML 很容易超过 200K 标记(只有用户看到的文本约为 10-20%),因此将其适应上下文大小将是一项挑战。而且,这些技术信息可能会让模型的工作变得更加困难。
因此,从 HTML 中提取文本并将其用于进一步分析是相当标准的做法。要做到这一点,你可以使用下面的命令。结果,你将得到一个文档对象,其中网页内容在page_content
参数中。
from langchain.document_loaders import BSHTMLLoader
loader = BSHTMLLoader("my_site.html")
data = loader.load()
另一个常用的数据类型是 PDF。我们可以解析 PDF,例如使用 PyPDF 库。让我们从 DALL-E 3 论文中加载文本。
from langchain.document_loaders import PyPDFLoader
loader = PyPDFLoader("https://cdn.openai.com/papers/DALL_E_3_System_Card.pdf")
doc = loader.load()
在输出中,你会得到一组文档 — 每页一个文档。在元数据中,source
和page
字段都会被填充。
因此,正如你所见,LangChain 允许你处理广泛的不同文档类型。
让我们回到我们最初的任务。在我们的数据集中,每个酒店都有一个单独的.txt 文件,其中包含顾客的评论。我们需要解析目录中的所有文件并将它们整合在一起。我们可以使用DirectoryLoader
来完成这个任务。
from langchain.document_loaders import TextLoader, DirectoryLoader
text_loader_kwargs={'autodetect_encoding': True}
loader = DirectoryLoader('./hotels/london', show_progress=True,
loader_cls=TextLoader, loader_kwargs=text_loader_kwargs)
docs = loader.load()
len(docs)
82
我们的文本不是标准的 UTF-8 编码,所以我还使用了'autodetect_encoding': True
。
结果,我们得到了文档列表 — 每个文本文件一个文档。我们知道每个文档由独立的客户评论组成。与其处理酒店所有顾客评论的大文本,我们更有效地使用较小的块来处理。因此,我们需要分割我们的文档。让我们继续下一阶段,详细讨论文档分割。
文档分割
下一步是分割文档。也许你会想为什么我们需要这样做。文档通常很长,涵盖多个主题,例如 Confluence 页面或文档。如果我们将这样的长文本传递给 LLMs,我们可能会面临以下问题:要么 LLM 被无关信息分散注意力,要么文本不适合上下文大小。
因此,为了有效地处理 LLMs,值得从我们的知识库(文档集合)中定义最相关的信息,并仅将此信息传递给模型。这就是为什么我们需要将文档分割成较小块的原因。
通常用于一般文本的最常见技术是递归字符分割。在 LangChain 中,它是由RecursiveCharacterTextSplitter
类实现的。
让我们尝试理解它是如何工作的。首先,你需要定义一个优先级列表用于分割器(默认为["\n\n", "\n", " ", ""]
)。然后,分割器会逐个字符地遍历这个列表,并尝试将文档分割成足够小的块。这意味着该方法试图保持语义上紧密相关的部分在一起(段落、句子、单词),直到我们需要分割它们以达到期望的块大小。
让我们使用Python 之禅看看它是如何工作的。这段文字有 824 个字符,139 个单词和 21 个段落。
如果你执行
import this
,你可以看到 Python 之禅。
zen = '''
Beautiful is better than ugly.
Explicit is better than implicit.
Simple is better than complex.
Complex is better than complicated.
Flat is better than nested.
Sparse is better than dense.
Readability counts.
Special cases aren't special enough to break the rules.
Although practicality beats purity.
Errors should never pass silently.
Unless explicitly silenced.
In the face of ambiguity, refuse the temptation to guess.
There should be one -- and preferably only one --obvious way to do it.
Although that way may not be obvious at first unless you're Dutch.
Now is better than never.
Although never is often better than *right* now.
If the implementation is hard to explain, it's a bad idea.
If the implementation is easy to explain, it may be a good idea.
Namespaces are one honking great idea -- let's do more of those!
'''
print('Number of characters: %d' % len(zen))
print('Number of words: %d' % len(zen.replace('\n', ' ').split(' ')))
print('Number of paragraphs: %d' % len(zen.split('\n')))
# Number of characters: 825
# Number of words: 140
# Number of paragraphs: 21
让我们使用RecursiveCharacterTextSplitter
,并从相对较大的块大小开始,设为 300。
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
chunk_size = 300,
chunk_overlap = 0,
length_function = len,
is_separator_regex = False,
)
text_splitter.split_text(zen)
我们将得到三个块:264、293 和 263 个字符。我们可以看到所有的句子都保持在一起。
以下所有图像均由作者制作。
你可能会注意到一个chunk_overlap
参数,它允许你进行重叠分割。这很重要,因为我们将把一些块和问题一起传递给 LLM,而拥有足够的上下文来仅根据每个块中提供的信息做出决策是至关重要的。
作者方案
让我们尝试添加chunk_overlap
。
text_splitter = RecursiveCharacterTextSplitter(
chunk_size = 300,
chunk_overlap = 100,
length_function = len,
is_separator_regex = False,
)
text_splitter.split_text(zen)
现在,我们有四个分割块,字符数分别为 264、232、297 和 263,我们可以看到我们的块有重叠。
让我们把块的大小稍微调小一点。
text_splitter = RecursiveCharacterTextSplitter(
chunk_size = 50,
chunk_overlap = 10,
length_function = len,
is_separator_regex = False,
)
text_splitter.split_text(zen)
现在,我们甚至不得不分割一些较长的句子。这就是递归分割的工作原理:由于按段落("\n"
)分割后,块仍然不够小,因此分割器继续按" "
分割。
你可以进一步自定义分割。例如,你可以指定length_function = lambda x: len(x.split("\n"))
来使用段落的数量作为块的长度,而不是字符的数量。按标记分割也很常见,因为 LLM 的上下文大小基于标记的数量。
另一种潜在的自定义方式是使用其他separators
,而不是用","
而是用" "
来分隔。让我们尝试用几句话来使用它。
text_splitter = RecursiveCharacterTextSplitter(
chunk_size = 50,
chunk_overlap = 0,
length_function = len,
is_separator_regex = False,
separators=["\n\n", "\n", ", ", " ", ""]
)
text_splitter.split_text('''\
If the implementation is hard to explain, it's a bad idea.
If the implementation is easy to explain, it may be a good idea.''')
它有效,但逗号的位置不对。
为了解决这个问题,我们可以使用带回顾的正则表达式作为分隔符。
text_splitter = RecursiveCharacterTextSplitter(
chunk_size = 50,
chunk_overlap = 0,
length_function = len,
is_separator_regex = True,
separators=["\n\n", "\n", "(?<=\, )", " ", ""]
)
text_splitter.split_text('''\
If the implementation is hard to explain, it's a bad idea.
If the implementation is easy to explain, it may be a good idea.''')
现在已修复。
此外,LangChain 提供了处理代码的工具,可以根据编程语言特定的分隔符来分割文本。
然而,在我们的情况下,情况更简单。我们知道每个文件中有用"\n"
分隔的独立评论,我们只需按此分隔即可。不幸的是,LangChain 不支持这种基本用例,因此我们需要进行一些黑客操作以使其按我们想要的方式工作。
from langchain.text_splitter import CharacterTextSplitter
text_splitter = CharacterTextSplitter(
separator = "\n",
chunk_size = 1,
chunk_overlap = 0,
length_function = lambda x: 1, # hack - usually len is used
is_separator_regex = False
)
split_docs = text_splitter.split_documents(docs)
len(split_docs)
12890
你可以在我之前关于 LangChain 的文章中找到更多关于我们为什么需要这个 hack 的详细信息。
文档的重要部分是元数据,因为它可以提供有关该块来源的更多上下文。在我们的例子中,LangChain 自动填充了元数据的source
参数,因此我们知道每条评论涉及哪个酒店。
还有其他方法(例如用于HTML或Markdown的方法),它们在拆分文档时添加标题到元数据。如果您正在处理这些数据类型,这些方法可能非常有帮助。
向量存储
现在我们有评论文本,下一步是学习如何有效地存储它们,以便我们可以获得相关的文档来回答我们的问题。
我们可以将评论存储为字符串,但这对我们解决这个任务没有帮助——我们无法过滤与问题相关的客户评论。
更加功能强大的解决方案是存储文档的嵌入。
嵌入是高维向量。嵌入捕捉单词和短语之间的语义含义和关系,因此语义上接近的文本之间的距离较小。
我们将使用OpenAI 嵌入,因为它们非常流行。OpenAI 建议使用text-embedding-ada-002
模型,因为它具有更好的性能、更广泛的上下文和更低的价格。像往常一样,它有其风险和限制:潜在的社会偏见和对最近事件的有限了解。
让我们尝试在玩具示例上使用嵌入来看看它的工作原理。
from langchain.embeddings.openai import OpenAIEmbeddings
embedding = OpenAIEmbeddings()
text1 = 'Our room (standard one) was very clean and large.'
text2 = 'Weather in London was wonderful.'
text3 = 'The room I had was actually larger than those found in other hotels in the area, and was very well appointed.'
emb1 = embedding.embed_query(text1)
emb2 = embedding.embed_query(text2)
emb3 = embedding.embed_query(text3)
print('''
Distance 1 -> 2: %.2f
Distance 1 -> 3: %.2f
Distance 2-> 3: %.2f
''' % (np.dot(emb1, emb2), np.dot(emb1, emb3), np.dot(emb2, emb3)))
我们可以使用
*np.dot*
作为余弦相似度,因为 OpenAI 嵌入已经被归一化。
我们可以看到第一和第三个向量彼此接近,而第二个向量不同。第一和第三个句子有类似的语义含义(它们都是关于房间大小),而第二个句子不接近,讨论天气。因此,嵌入之间的距离实际上反映了文本之间的语义相似性。
现在,我们知道如何将评论转换为数值向量。下一个问题是如何存储这些数据,以便轻松访问。
让我们考虑一下我们的用例。我们的流程将是:
-
获取一个问题,
-
计算其嵌入,
-
找到与此问题相关的最相关的文档块(与此嵌入距离最小的文档块),
-
最后,将找到的块作为上下文与初始问题一起传递给 LLM。
数据存储的常规任务是找到 K 个最近的向量(K 个最相关的文档)。因此,我们需要计算我们问题的嵌入与我们拥有的所有向量之间的距离(在我们的情况下,余弦相似度)。
通用数据库(如 Snowflake 或 Postgres)在这样的任务中表现不佳。但是有些数据库被优化,特别适合这种用例——向量数据库。
我们将使用一个开源嵌入数据库,Chroma。Chroma 是一个轻量级的内存数据库,非常适合原型设计。你可以在这里找到更多的向量存储选项。
首先,我们需要使用 pip 安装 Chroma。
pip install chromadb
我们将使用persist_directory
来将数据本地存储并从磁盘重新加载。
from langchain.vectorstores import Chroma
persist_directory = 'vector_store'
vectordb = Chroma.from_documents(
documents=split_docs,
embedding=embedding,
persist_directory=persist_directory
)
为了在下次需要时能够从磁盘加载数据,请执行以下命令。
embedding = OpenAIEmbeddings()
vectordb = Chroma(
persist_directory=persist_directory,
embedding_function=embedding
)
数据库初始化可能需要几分钟时间,因为 Chroma 需要加载所有文档并使用 OpenAI API 获取它们的嵌入。
我们可以看到所有文档已经加载完毕。
print(vectordb._collection.count())
12890
现在,我们可以使用相似性搜索来查找关于员工礼貌的顶级客户评论。
query_docs = vectordb.similarity_search('politeness of staff', k=3)
文档看起来与问题非常相关。
我们已经以可访问的方式存储了客户评论,现在是时候更详细地讨论检索了。
检索
我们已经使用了vectordb.similarity_search
来检索与问题最相关的块。在大多数情况下,这种方法将对你有效,但可能会有一些细节:
-
多样性缺乏 — 模型可能会返回极其相似的文本(甚至重复),这不会给 LLM 带来多少新信息。
-
未考虑元数据 —
similarity_search
不会考虑我们拥有的元数据。例如,如果我查询问题“Travelodge Farringdon 的早餐”的前五条评论,结果中只有三条评论的来源等于uk_england_london_travelodge_london_farringdon
。 -
上下文大小限制 — 和往常一样,我们有有限的 LLM 上下文大小,需要将文档适配到其中。
让我们讨论一下可以帮助我们解决这些问题的技术。
解决多样性问题 — MMR(最大边际相关性)
相似性搜索返回与你的问题最接近的响应。但为了向模型提供完整的信息,你可能不想只关注最相似的文本。例如,对于问题“Travelodge Farringdon 的早餐”,前五条客户评论可能都关于咖啡。如果我们仅查看这些评论,就会错过其他提到鸡蛋或员工行为的评论,从而对客户反馈有一定的局限性。
我们可以使用 MMR(最大边际相关性)方法来增加客户评论的多样性。它的工作原理非常简单:
-
首先,我们使用
similarity_search
获取fetch_k
与问题最相似的文档。 -
然后,我们选择了
k
中最具多样性的那些。
作者方案
如果我们想使用 MMR,我们应该使用max_marginal_relevance_search
而不是similarity_search
,并指定fetch_k
数量。值得保持fetch_k
相对较小,以便输出中不会有不相关的答案。就这些。
query_docs = vectordb.max_marginal_relevance_search('politeness of staff',
k = 3, fetch_k = 30)
让我们来看一下相同查询的示例。这次我们收到了更多样化的反馈,甚至还有带有负面情绪的评论。
解决特异性问题 — LLM 辅助检索
另一个问题是我们在检索文档时没有考虑元数据。为了解决这个问题,我们可以让 LLM 将初始问题拆分为两部分:
-
基于文档文本的语义过滤器,
-
基于我们拥有的元数据进行过滤,
这种方法被称为“自查询”。
首先,让我们添加一个手动过滤器,指定与 Travelodge Farringdon 酒店相关的source
参数的文件名。
query_docs = vectordb.similarity_search('breakfast in Travelodge Farrigdon',
k=5,
filter = {'source': 'hotels/london/uk_england_london_travelodge_london_farringdon'}
)
现在,让我们尝试使用 LLM 自动生成这样的过滤器。我们需要详细描述所有元数据参数,然后使用SelfQueryRetriever
。
from langchain.llms import OpenAI
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.chains.query_constructor.base import AttributeInfo
metadata_field_info = [
AttributeInfo(
name="source",
description="All sources starts with 'hotels/london/uk_england_london_' \
then goes hotel chain, constant 'london_' and location.",
type="string",
)
]
document_content_description = "Customer reviews for hotels"
llm = OpenAI(temperature=0.1) # low temperature to make model more factual
# by default 'text-davinci-003' is used
retriever = SelfQueryRetriever.from_llm(
llm,
vectordb,
document_content_description,
metadata_field_info,
verbose=True
)
question = "breakfast in Travelodge Farringdon"
docs = retriever.get_relevant_documents(question, k = 5)
我们的情况很棘手,因为元数据中的source
参数包含多个字段:国家、城市、酒店连锁和位置。在这种情况下,将如此复杂的参数拆分为更详细的子参数是值得的,以便模型可以更容易地理解如何使用元数据过滤器。
然而,通过详细的提示,它确实有效,并仅返回了与 Travelodge Farringdon 相关的文档。但我必须承认,这花了我几个迭代才达到这个结果。
让我们开启调试模式看看它的工作情况。要进入调试模式,只需执行下面的代码。
import langchain
langchain.debug = True
完整的提示非常长,所以让我们看看它的主要部分。这是提示的开头,给模型一个我们期望的概述和结果的主要标准。
然后,使用少量示例提示技术,模型提供了两个输入和期望输出的示例。这是其中一个示例。
我们并没有使用像 ChatGPT 这样的聊天模型,而是使用通用的 LLM(没有针对指令进行微调)。它只是训练来预测文本的后续标记。这就是为什么我们以问题和字符串Structured output:
结束提示,期待模型提供答案的原因。
结果,我们从模型那里得到的初始问题被拆分为两部分:语义部分(breakfast
)和元数据过滤器(source = hotels/london/uk_england_london_travelodge_london_farringdon
)
然后,我们使用了这种逻辑从我们的向量存储中检索文档,并仅获取了我们需要的文档。
解决大小限制 — 压缩
另一种可能有用的检索技术是压缩。尽管 GPT 4 Turbo 的上下文大小为 128K 标记,但它仍然有限。因此,我们可能需要预处理文档并仅提取相关部分。
主要优势有:
-
您将能够将更多文档和信息整合到最终提示中,因为它们将被压缩。
-
您将会得到更好、更集中的结果,因为在预处理期间将清除非相关的上下文。
这些好处是有代价的 — 您将需要更多的 LLM 调用来进行压缩,这意味着更慢的速度和更高的价格。
您可以在文档中找到有关此技术的更多信息。
作者提出的方案
实际上,我们甚至可以结合技术并在这里使用 MMR。我们使用ContextualCompressionRetriever
来获取结果。此外,我们指定了我们只想要三个文档作为返回结果。
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
llm = OpenAI(temperature=0)
compressor = LLMChainExtractor.from_llm(llm)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=vectordb.as_retriever(search_type = "mmr",
search_kwargs={"k": 3})
)
question = "breakfast in Travelodge Farringdon"
compressed_docs = compression_retriever.get_relevant_documents(question)
像往常一样,了解其内部运作方式是最有趣的部分。如果我们看实际调用,可以看到有三次调用 LLM 来从文本中提取仅相关信息的情况。这里有一个例子。
在输出中,我们只得到了与早餐相关的部分句子,所以压缩有所帮助。
有许多更有利的检索方法,例如经典自然语言处理技术:支持向量机(SVM)或者TF-IDF。不同的检索器可能在不同情况下有所帮助,因此我建议您为您的任务比较不同版本,并选择最适合您使用情况的版本。
生成
最后,我们来到了最后阶段:我们将所有内容合并并生成最终答案。
这里是它们将如何运作的一个方案:
-
我们收到了用户的一个问题,
-
我们从向量存储中使用嵌入检索了此问题的相关文档,
-
我们将初始问题与从嵌入中检索到的相关文档一起传递给 LLM,并获得最终答案。
作者提出的方案
在 LangChain 中,我们可以使用RetrievalQA
链快速实现这一流程。
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
llm = ChatOpenAI(model_name='gpt-4', temperature=0.1)
qa_chain = RetrievalQA.from_chain_type(
llm,
retriever=vectordb.as_retriever(search_kwargs={"k": 3})
)
result = qa_chain({"query": "what customers like about staff in the hotel?"})
让我们看一下对 ChatGPT 的调用。正如您所见,我们将检索到的文档与用户查询一起传递。
这里是模型的输出。
我们可以调整模型的行为,定制提示。例如,我们可以要求模型更加简洁。
from langchain.prompts import PromptTemplate
template = """
Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try
to make up an answer.
Keep the answer as concise as possible. Use 1 sentence to sum all points up.
______________
{context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
qa_chain = RetrievalQA.from_chain_type(
llm,
retriever=vectordb.as_retriever(),
return_source_documents=True,
chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
)
result = qa_chain({"query": "what customers like about staff in the hotel?"})
这次我们得到了一个更短的答案。此外,由于我们指定了return_source_documents=True
,我们得到了一组返回的文档。这对于调试可能有帮助。
正如我们所见,所有检索到的文档默认都合并在一个提示中。这种方法既优秀又简单,因为只需要一个调用来执行语言模型。唯一的限制是您的文档必须符合上下文大小。如果不符合,您需要应用更复杂的技术。
让我们看看不同的链类型,它们可以让我们处理任意数量的文档。第一个是 MapReduce。
这种方法类似于经典的MapReduce:我们根据每个检索到的文档生成答案(map 阶段),然后将这些答案合并成最终答案(reduce 阶段)。
作者的方案
所有这些方法的局限性在于成本和速度。你需要为每个检索到的文档进行一次调用,而不是一次调用 LLM。
关于代码,我们只需要指定chain_type="map_reduce"
以改变行为。
qa_chain_mr = RetrievalQA.from_chain_type(
llm,
retriever=vectordb.as_retriever(),
chain_type="map_reduce"
)
result = qa_chain_mr({"query": "what customers like about staff in the hotel?"})
结果,我们得到了以下输出。
让我们在调试模式下看看它是如何工作的。由于这是 MapReduce,我们首先将每个文档发送到 LLM,并根据这个块得到答案。以下是其中一个块的提示示例。
然后,我们将所有结果结合起来,并要求 LLM 给出最终答案。
就这样。
MapReduce 方法还有另一个特定的缺点。模型分别看到每个文档,而不是将它们全部放在同一上下文中,这可能导致更差的结果。
我们可以通过 Refine 链类型克服这个缺点。然后,我们将按顺序查看文档,并允许模型在每次迭代时细化答案。
作者的方案
再次,我们只需要更改chain_type
以测试另一种方法。
qa_chain_refine = RetrievalQA.from_chain_type(
llm,
retriever=vectordb.as_retriever(),
chain_type="refine"
)
result = qa_chain_refine({"query": "what customers like about staff in the hotel?"})
使用 Refine 链,我们得到了一个更详细和完整的答案。
让我们使用调试模式看看它是如何工作的。对于第一个块,我们从头开始。
然后,我们传递当前的答案和一个新的块,并给模型一个机会来细化其答案。
然后,我们对每个剩余的检索文档重复细化提示,最终得到结果。
今天我想告诉你的就这些。让我们快速回顾一下。
总结
在本文中,我们详细介绍了检索增强生成的整个过程:
-
我们已经查看了不同的数据加载器。
-
我们讨论了数据拆分的可能方法及其潜在的细微差别。
-
我们了解了什么是嵌入,并建立了一个向量存储库以有效地访问数据。
-
我们找到了检索问题的不同解决方案,并学习了如何增加多样性、克服上下文大小限制以及使用元数据。
-
最后,我们使用了
RetrievalQA
链来生成基于我们数据的答案,并比较了不同的链类型。
这些知识应该足够你开始使用自己的数据构建类似的东西。
非常感谢您阅读本文。希望本文对您有所启发。如果您有任何后续问题或评论,请在评论部分留言。
数据集
*Ganesan, Kavita 和 Zhai, ChengXiang. (2011). OpinRank 评论数据集.
UCI 机器学习库 (CC BY 4.0).* https://doi.org/10.24432/C5QW4W
参考资料
本文基于以下课程信息:
-
“LLM 应用开发的 LangChain”由 DeepLearning.AI 和 LangChain 提供,
-
“LangChain:与您的数据聊天”由 DeepLearning.AI 和 LangChain 提供。
RAG 与微调——哪种是提升你的 LLM 应用的最佳工具?
为你的用例选择正确方法的权威指南
·发表于 Towards Data Science ·19 分钟阅读·2023 年 8 月 24 日
--
作者提供的图片
前言
随着对大型语言模型(LLMs)兴趣的浪潮涌动,许多开发者和组织忙于构建利用其强大功能的应用。然而,当现成的预训练 LLMs 未能如预期般表现时,如何提升 LLM 应用性能的问题就会浮现。最终,我们会问自己:我们应该使用检索增强生成(RAG)还是模型微调来改善结果?
在深入探讨之前,让我们揭开这两种方法的神秘面纱:
RAG:这种方法将检索(或搜索)的力量集成到 LLM 文本生成中。它结合了一个检索系统,该系统从大规模语料库中获取相关文档片段,以及一个 LLM,该 LLM 使用这些片段中的信息生成答案。实质上,RAG 帮助模型“查找”外部信息以改进其响应。
作者提供的图片
微调:这是将预训练 LLM 进一步在一个较小的特定数据集上进行训练的过程,以便将其适应于特定任务或提升其性能。通过微调,我们根据我们的数据调整模型的权重,使其更加符合我们应用的独特需求。
作者提供的图片
RAG 和微调都是提升基于 LLM 的应用性能的强大工具,但它们针对优化过程的不同方面,这在选择其中一个时至关重要。
之前,我经常建议组织在深入微调之前先尝试 RAG。这是基于我对这两种方法虽然实现了类似结果但在复杂性、成本和质量上有所不同的看法。我甚至曾用类似的图表来说明这一点:
作者提供的图片
在这个图表中,各种因素如复杂性、成本和质量沿着单一维度表示。要点是什么?RAG 更简单且成本更低,但其质量可能无法匹配。我的建议通常是:从 RAG 开始,评估其性能,如果发现不足,则转向微调。
然而,我的观点已经有所变化。我认为将 RAG 和微调视为实现相同结果的两种技术,这种看法过于简化,因为一种技术更便宜、复杂度更低而已。它们本质上是不同的——而不是共线,而是正交——且满足 LLM 应用的不同需求。
为了更清楚地说明这一点,可以考虑一个简单的现实世界类比:当被问到“我应该用刀子还是勺子来吃饭?”时,最合逻辑的反问是:“你在吃什么?”我问过朋友和家人,每个人本能地回答了这个反问,这表明他们并不认为刀子和勺子是可以互换的,或者其中一种是另一种的低级变体。
这是什么内容?
在这篇博客中,我们将深入探讨区分 RAG 和微调的各种维度,这些维度在我看来对于确定特定任务的最佳技术至关重要。此外,我们将查看一些最受欢迎的 LLM 应用场景,并使用第一部分确定的维度来识别哪种技术最适合哪些用例。在博客的最后部分,我们将识别在构建 LLM 应用时应考虑的其他方面。每一个方面可能都值得单独撰写博客,因此在本篇博客中我们只能简要提及。
你为什么要关心这个?
选择适合调整大型语言模型的技术对 NLP 应用的成功有重大影响。选择错误的方法可能导致:
-
在你的具体任务上,模型表现不佳,导致输出不准确。
-
如果技术未针对你的使用案例进行优化,模型训练和推理的计算成本会增加。
-
如果需要以后转向不同的技术,则会增加额外的开发和迭代时间。
-
部署应用程序并让其面向用户的延迟。
-
如果选择了过于复杂的适配方法,则可能缺乏模型解释性。
-
由于大小或计算约束,难以将模型部署到生产环境中。
RAG 和微调之间的细微差别涵盖了模型架构、数据需求、计算复杂性等多个方面。忽视这些细节可能会破坏你的项目时间表和预算。
本博客旨在通过清晰地阐述何时使用每种技术来防止浪费努力。通过这些见解,你可以从第一天起就以正确的适应方法迅速开展工作。详细的比较将使你能够做出最佳的技术选择,以实现你的商业和 AI 目标。本指南将帮助你选择合适的工具,为你的项目成功奠定基础。
那么,让我们深入探讨吧!
提升性能的关键考虑因素
在选择 RAG 还是微调之前,我们应该评估 LLM 项目的需求,并在一些维度上提出几个问题。
我们的用例是否需要访问外部数据源?
在选择对 LLM 进行微调还是使用 RAG 时,一个关键的考虑因素是应用程序是否需要访问外部数据源。如果答案是肯定的,RAG 可能是更好的选择。
RAG 系统的定义是通过在生成回应之前从知识来源检索相关信息,以增强 LLM 的能力。这使得这种技术非常适合需要查询数据库、文档或其他结构化/非结构化数据存储的应用程序。检索器和生成器组件可以被优化以利用这些外部来源。
相比之下,虽然可以对 LLM 进行微调以学习一些外部知识,但这需要来自目标领域的大量标注数据集。随着基础数据的变化,这些数据集必须不断更新,使得它不适用于频繁变化的数据源。微调过程也没有明确建模在查询外部知识时涉及的检索和推理步骤。
总结来说,如果我们的应用程序需要利用外部数据源,使用 RAG 系统可能会比单纯依靠微调“内置”所需知识更有效和可扩展。
我们是否需要修改模型的行为、写作风格或领域特定知识?
另一个非常重要的方面是我们需要模型在多大程度上调整其行为、写作风格或为领域特定应用量身定制其回应。
微调在将 LLM 的行为适应于特定的细微差别、语气或术语方面表现出色。如果我们希望模型听起来更像医学专业人员、用诗意的风格写作,或使用特定行业的术语,在领域特定数据上进行微调可以实现这些定制。这种影响模型行为的能力对于需要与特定风格或领域专长一致的应用程序至关重要。
RAG 虽然在整合外部知识方面强大,但主要关注信息检索,并不会根据检索到的信息自我调整语言风格或领域特定性。它将从外部数据源中提取相关内容,但可能不会展现出微调模型所能提供的定制化细微差别或领域专业知识。
因此,如果我们的应用程序需要特定的写作风格或深度对齐于特定领域的术语和惯例,微调提供了实现这种对齐的更直接途径。它提供了深入的定制,使得内容真正与特定受众或专业领域产生共鸣,确保生成的内容感觉真实且信息丰富。
快速回顾
这两个方面是决定使用哪种方法来提升 LLM 应用性能时最重要的方面。有趣的是,我认为它们是正交的,可以独立使用(也可以组合使用)。
图片由作者提供
但在深入使用案例之前,我们还应该考虑几个关键方面来选择方法:
抑制幻觉有多重要?
LLM 的一个缺点是它们有产生幻觉的倾向——编造没有现实依据的事实或细节。这在准确性和真实性至关重要的应用中可能会非常有问题。
微调可以在一定程度上帮助减少幻觉,通过将模型基于特定领域的训练数据。然而,当面对不熟悉的输入时,模型仍可能会编造响应。需要对新数据进行再训练以持续最小化虚假伪造。
相比之下,RAG 系统本质上较不容易产生幻觉,因为它们将每个响应基于检索到的证据。检索器在生成器构造答案之前会从外部知识源中识别相关事实。这一步骤充当了一个事实检查机制,减少了模型虚构的能力。生成器被限制在基于检索到的上下文合成响应。
因此,在需要抑制虚假信息和想象性伪造至关重要的应用中,RAG 系统提供了内建机制来最小化幻觉。生成响应之前的证据检索使 RAG 在确保输出真实准确方面具有优势。
有多少标记好的训练数据可用?
在决定 RAG 和微调之间时,一个关键因素是我们手头上是否有大量的领域或任务特定的标记训练数据。
对 LLM 进行微调以适应特定任务或领域严重依赖于可用标注数据的质量和数量。丰富的数据集可以帮助模型深入理解特定领域的细微差别、复杂性和独特模式,从而生成更准确和上下文相关的回应。然而,如果我们使用的是有限的数据集,微调的改善可能会很小。在某些情况下,稀少的数据集甚至可能导致过拟合,模型在训练数据上表现良好,但在未见或现实世界输入中表现不佳。
相反,RAG 系统独立于训练数据,因为它们利用外部知识源来检索相关信息。即使我们没有广泛的标注数据集,RAG 系统仍然可以通过访问和整合外部数据源的见解来有效地工作。检索和生成的结合确保了系统保持知情,即使在领域特定的训练数据稀缺时。
实质上,如果我们拥有大量标注数据,能够捕捉领域的复杂性,微调可以提供更加定制化和精细化的模型行为。但在数据有限的情况下,RAG 系统提供了一个稳健的替代方案,通过其检索能力确保应用保持数据驱动和上下文相关。
数据的静态/动态特性如何?
在选择 RAG 和微调时,另一个需要考虑的基本方面是我们数据的动态特性。数据更新的频率如何?对于模型保持最新状态的重要性如何?
对特定数据集进行 LLM 微调意味着模型的知识成为训练时数据的静态快照。如果数据经常更新、变化或扩展,这会迅速使模型过时。为了保持 LLM 在这种动态环境中的时效性,我们必须频繁地重新训练模型,这个过程既耗时又资源密集。此外,每次迭代都需要仔细监控,以确保更新后的模型在不同场景中仍表现良好,并且没有出现新的偏见或理解上的漏洞。
相对而言,RAG 系统在动态数据环境中天生具有优势。它们的检索机制不断查询外部源,确保用于生成回应的信息是最新的。随着外部知识库或数据库的更新,RAG 系统无缝地整合这些变化,保持其相关性而无需频繁的模型重新训练。
总结来说,如果我们在应对快速发展的数据环境,RAG 提供了一种与传统微调难以匹敌的灵活性。通过始终连接到最新的数据,RAG 确保生成的回应与当前的信息状态保持一致,使其成为动态数据场景的理想选择。
我们的 LLM 应用需要多透明/可解释?
最后一个要考虑的方面是我们需要对模型决策过程的洞察程度。
微调 LLM 虽然极具威力,但像一个黑箱一样运行,使得其响应背后的推理更为不透明。随着模型从数据集中内化信息,识别每个响应背后的确切来源或推理变得困难。这可能会使开发者或用户难以信任模型的输出,特别是在需要理解答案背后“为什么”的关键应用中。
RAG 系统提供了一种在纯粹微调模型中通常找不到的透明度。鉴于 RAG 的两步性质——检索和生成——用户可以窥见其过程。检索组件允许检查哪些外部文档或数据点被选择为相关。这提供了可以评估的具体证据或参考,以理解响应的基础。追溯模型答案到特定数据源的能力在需要高度责任感或需要验证生成内容准确性的应用中非常宝贵。
本质上,如果透明性和解释模型响应基础的能力是优先事项,RAG 提供了明确的优势。通过将响应生成分解为不同阶段并允许洞察其数据检索,RAG 促进了对其输出的更大信任和理解。
总结
在考虑这些维度时,选择 RAG 和微调变得更为直观。如果我们倾向于访问外部知识并重视透明性,RAG 是我们的首选。另一方面,如果我们处理的是稳定的标注数据,并且目标是更紧密地调整模型以满足特定需求,微调则是更好的选择。
作者提供的图片
在接下来的部分,我们将看到如何根据这些标准评估流行的 LLM 使用案例。
使用案例
让我们看看一些流行的使用案例,以及上述框架如何用于选择正确的方法:
总结(在专业领域和/或特定风格中)
1. 是否需要外部知识? 对于以之前的总结风格进行总结的任务,主要数据来源将是之前的总结本身。如果这些总结包含在静态数据集中,则无需持续的外部数据检索。然而,如果存在一个动态的总结数据库,经常更新且目标是不断地将风格与最新条目对齐,RAG 可能会在这里发挥作用。
2. 是否需要模型适应? 这个用例的核心在于适应专业领域和/或特定写作风格。微调特别擅长捕捉风格细微差别、语调变化和特定领域的词汇,使其成为这一维度的最佳选择。
3. 是否至关重要的是最小化虚构? 在大多数大型语言模型应用中,包括总结,虚构都是一个问题。然而,在这个用例中,被总结的文本通常作为上下文提供。这使得虚构问题相较于其他用例的关注度较低。源文本限制了模型,减少了虚构的可能性。因此,尽管事实准确性总是值得追求,但在总结中压制虚构的优先级较低,因为有上下文作为基础。
4. 是否有训练数据? 如果有大量标记或结构化的先前总结,模型可以从中学习,那么微调将是一个非常有吸引力的选择。另一方面,如果数据集有限,而我们依赖外部数据库来进行风格对齐,RAG 可能会发挥作用,尽管它的主要强项不是风格适应。
5. 数据的动态性如何? 如果先前总结的数据库是静态的或更新不频繁,那么微调模型的知识可能会在较长时间内保持相关。然而,如果总结更新频繁,并且需要模型不断地与最新的风格变化保持一致,RAG 可能由于其动态数据检索能力而具有优势。
6. 是否需要透明性/可解释性? 这里的主要目标是风格上的一致性,因此某种特定总结风格背后的“原因”可能不像其他用例那样关键。不过,如果需要追溯并了解哪些先前的总结影响了特定输出,RAG 提供了更多的透明度。尽管如此,这对该用例来说可能是次要问题。
推荐: 对于这个用例,微调 似乎是更合适的选择。主要目标是风格一致性,这是微调擅长的一个维度。如果有足够多的先前总结用于训练,微调一个大型语言模型将允许对所需风格进行深度适应,捕捉领域的细微差别和复杂性。然而,如果总结数据库极其动态,并且追溯影响具有价值,可以探索混合方法或倾向于 RAG。
关于组织知识的问答系统(即外部数据)
1. 是否需要外部知识? 依赖于组织知识库的问答系统本质上需要访问外部数据,在这种情况下,即组织的内部数据库和文档存储。系统的有效性取决于其从这些来源中提取和检索相关信息的能力。鉴于此,RAG 在这一维度上更为合适,因为它设计用来通过从知识来源中检索相关数据来增强 LLM 的能力。
2. 是否需要模型调整? 根据组织和其领域的不同,模型可能需要与特定术语、语调或惯例保持一致。虽然 RAG 主要关注信息检索,但微调可以帮助 LLM 调整其回应,以适应公司的内部用语或其领域的细微差别。因此,在这一维度上,根据具体需求,微调可能会发挥作用。
3. 是否至关重要以减少幻觉? 幻觉在这种用例中是一个主要关注点,因为 LLM 的知识截止点。如果模型无法根据其训练数据回答问题,它几乎肯定会(部分或完全)编造一个合理但不正确的答案。
4. 是否有可用的训练数据? 如果组织有结构化和标记化的历史问答数据集,这可以增强微调方法。然而,并非所有内部数据库都已标记或结构化用于训练目的。在数据未被整齐标记或主要关注于检索准确相关答案的情况下,RAG 能够访问外部数据源而无需大量标记数据集,使其成为一个令人信服的选择。
5. 数据的动态性如何? 组织内部的数据库和文档存储可以是高度动态的,经常更新、更改或添加。如果这种动态性是组织知识库的特征,RAG 则具有明显的优势。它不断查询外部来源,确保其回答基于最新的数据。微调则需要定期重新训练以跟上这些变化,这可能不切实际。
6. 是否需要透明度/可解释性? 对于内部应用,特别是在金融、医疗或法律等行业,理解答案背后的推理或来源可能至关重要。由于 RAG 提供了检索和生成的两步过程,它本质上提供了更清晰的洞察力,显示了哪些文档或数据点影响了特定的答案。这种可追溯性对内部利益相关者来说非常宝贵,他们可能需要验证或进一步调查某些答案的来源。
建议: 对于这种使用场景,RAG 系统似乎是更合适的选择。鉴于需要动态访问组织不断发展的内部数据库以及可能需要回答过程中的透明度,RAG 提供的功能与这些需求相契合。然而,如果对模型的语言风格或领域特定细节有重大关注,则可以考虑结合微调的元素。
客户支持自动化(即提供即时响应的自动聊天机器人或帮助台解决方案)
1. 是否需要外部知识? 客户支持通常需要访问外部数据,特别是处理产品详细信息、账户特定信息或故障排除数据库时。虽然许多查询可以通过一般知识来解决,但有些可能需要从公司数据库或产品 FAQ 中提取数据。在这里,RAG 从外部来源检索相关信息的能力将非常有益。然而,值得注意的是,许多客户支持互动也基于预定义的脚本或知识,这些可以通过微调的模型有效解决。
2. 是否需要模型适配? 客户互动需要特定的语调、礼貌和清晰度,并且可能还需要公司特有的术语。微调特别有助于确保语言模型适应公司的声音、品牌和特定术语,从而确保一致且符合品牌的客户体验。
3. 是否必须尽量减少幻觉? 对于客户支持聊天机器人来说,避免虚假信息对于维持用户信任至关重要。仅靠微调会使模型在面对不熟悉的查询时容易出现幻觉。相比之下,RAG 系统通过基于检索到的证据来抑制虚构。这种对获取事实的依赖使 RAG 聊天机器人能够减少有害的虚假信息,并在准确性至关重要的情况下为用户提供可靠的信息。
4. 是否有训练数据可用? 如果公司有客户互动的历史记录,这些数据对于微调非常宝贵。以前客户查询及其解决方案的丰富数据集可以用来训练模型,以便未来处理类似的互动。如果这样的数据有限,RAG 可以通过从外部来源(如产品文档)检索答案来提供备用方案。
5. 数据的动态性如何? 客户支持可能需要处理有关新产品、更新的政策或变化的服务条款的查询。在产品阵容、软件版本或公司政策频繁更新的情况下,RAG 动态从最新文档或数据库中提取信息的能力具有优势。另一方面,对于更静态的知识领域,微调可能就足够了。
6. 是否需要透明度/可解释性? 虽然在某些领域透明度很重要,但在客户支持中,主要关注的是准确、快速和礼貌的响应。然而,对于内部监控、质量保证或处理客户争议,了解答案来源的可追溯性可能是有益的。在这种情况下,RAG 的检索机制提供了额外的透明度。
推荐: 对于客户支持自动化,混合方法可能是最佳选择。调整可以确保聊天机器人符合公司的品牌、语气和一般知识,处理大部分典型的客户查询。RAG 可以作为补充系统,处理更动态或特定的查询,确保聊天机器人可以从最新的公司文档或数据库中提取信息,从而减少虚假信息的生成。通过整合这两种方法,公司可以提供全面、及时和品牌一致的客户支持体验。
作者图片
需要考虑的额外方面
如上所述,决定 RAG 和调整(或两者结合)时还有其他因素需要考虑。我们无法深入探讨这些因素,因为它们都具有多面性,并没有像上述某些方面那样明确的答案(例如,如果没有训练数据,调整根本不可能)。但这并不意味着我们应忽视它们:
可扩展性
随着组织的成长及需求的演变,所用方法的可扩展性如何?由于 RAG 系统具有模块化特性,可能提供更直接的可扩展性,尤其是当知识库增长时。另一方面,频繁调整模型以适应扩展的数据集可能计算量巨大。
延迟和实时要求
如果应用程序需要实时或近实时的响应,请考虑每种方法引入的延迟。RAG 系统涉及在生成响应前检索数据,相比于基于内在知识生成响应的调整后的 LLM,可能会引入更多的延迟。
维护和支持
从长远角度考虑。哪个系统更符合组织提供一致维护和支持的能力?RAG 可能需要维护数据库和检索机制,而调整则需要持续的重新训练,特别是数据或需求发生变化时。
鲁棒性和可靠性
每种方法对不同类型输入的鲁棒性如何?尽管 RAG 系统可以从外部知识源中提取信息,可能处理各种问题,而经过良好调整的模型在某些领域可能提供更多的一致性。
伦理和隐私问题
从外部数据库存储和检索数据可能引发隐私问题,特别是当数据敏感时。另一方面,尽管微调模型不查询实时数据库,但它仍可能基于其训练数据产生输出,这可能有其自身的伦理影响。
与现有系统的集成
组织可能已经有某些基础设施到位。RAG 或微调与现有系统(无论是数据库、云基础设施还是用户界面)的兼容性可以影响选择。
用户体验
考虑最终用户及其需求。如果他们需要详细的、基于参考的答案,RAG 可能更为合适。如果他们重视速度和领域特定的专业知识,微调模型可能更适合。
成本
微调可能会变得非常昂贵,特别是对于非常大的模型。然而,在过去几个月中,由于像QLoRA这样的参数高效技术,成本已大幅下降。设置 RAG 可能需要大额的初始投资——包括集成、数据库访问,甚至可能还有许可费用——但还需要考虑对外部知识库的定期维护。
复杂性
微调可能会迅速变得复杂。虽然许多提供商现在提供一键微调,只需提供训练数据,但跟踪模型版本并确保新模型在各方面仍表现良好是具有挑战性的。另一方面,RAG 也可能迅速变得复杂。涉及多个组件的设置,确保数据库保持最新,并确保各个部分——如检索和生成——恰到好处地配合在一起。
结论
正如我们所探讨的,选择 RAG 还是微调需要对 LLM 应用的独特需求和优先级进行细致的评估。没有一种放之四海而皆准的解决方案;成功在于将优化方法与任务的具体要求对齐。通过评估关键标准——对外部数据的需求、模型行为的调整、训练数据的可用性、数据动态、结果透明度等——组织可以做出明智的决策,确定最佳前进路径。在某些情况下,利用 RAG 和微调的混合方法可能是最优的。
关键在于避免假设某种方法在所有情况下都是优越的。像任何工具一样,它们的适用性取决于具体的任务。方法和目标的不匹配可能会阻碍进展,而正确的方法则会加速进展。在组织评估提升 LLM 应用的选项时,必须抵制过度简化的倾向,不应将 RAG 和微调视为可以互换的工具,而是要选择能够使模型充分发挥其能力并与用例需求对齐的工具。这些方法所解锁的可能性令人惊叹,但仅有可能性是不够的——执行才是关键。工具已经在这里——现在让我们开始使用它们。
海科·霍茨
👋 关注我在Medium和LinkedIn上的动态,阅读更多关于生成式 AI、机器学习和自然语言处理的内容。
👥 如果你在伦敦,可以加入我们的NLP London Meetups。
📔 我对 AI 新闻的想法见😇 Naughty Neural。
作者提供的图片
使用 ggplot2 提高对气候变化的意识
原文:
towardsdatascience.com/raise-awareness-about-climate-change-with-ggplot2-f31f0cae3c70
学会有效绘制历史天气数据
·发表于 Towards Data Science ·8 分钟阅读·2023 年 4 月 17 日
--
图片由 Ganapathy Kumar 在 Unsplash 提供
全球变暖不是预测,而是正在发生的现实。
詹姆斯·汉森
有确凿证据表明地球上的温度正在上升。随着气候变化威胁到人类的生存,了解、研究和提高对这一关键问题的认识比以往任何时候都更为重要。
无论你是学生、政府工作人员、非政府组织成员还是私人公司员工,向同事展示你对相关全球问题的关注是非常重要的。
在本教程中,你将学习如何找到可靠的历史温度数据并使用 ggplot2 将其可视化。在你完成这篇文章后,你将:
-
知道在哪里找到精心整理的历史天气数据集;
-
感到舒适地使用 ggplot2 绘制历史天气数据;
-
能够自定义你的 ggplot2 图表以讲述你的故事。
第一步:查找并加载数据
本教程的数据可在 国家环境信息中心 (NCEI)*** 上获取。NCEI 是美国环境数据的权威机构,提供有关气候、生态系统和水资源的高质量数据。全球年度总结(GSOY)数据集提供按城市和站点划分的历史天气数据。在本教程中,我们将使用来自加州伯克利的数据。如果你愿意,你可以选择你喜欢的城市。如果要使用与本教程相同的数据集,请搜索伯克利并选择包含自 1893 年以来的数据的文件。
文件将通过 read_csv
加载。唯一的参数是文件路径。数据框加载后,我们仅选择 DATE
和 TAVG
变量。DATE
包含观察到温度的年份,TAVG
是以摄氏度表示的年均温度。要了解更多可用变量的信息,请参考 数据集说明书。
library(readr)
library(dplyr)
df <- read_csv('USC00040693.csv') %>%
select("DATE", "TAVG")
summary(df)
R 的 summary()
函数告诉我们数据的范围从 1893 年到 2019 年,在此期间观察到的最小年均温度为 12.9 ºC(地点:加州伯克利)。最大年均温度为 15.93 ºC。它还显示有 33 个温度数据缺失。
第 2 步:使用 na_interpolation()
填补缺失值
由于我们正在处理时间序列,我们将使用线性插值填补缺失值。这种方法假设在缺失期间数据线性变化。实际上,当您使用折线图绘制时间序列时,观察间隔也会用连接两个点的直线填补。
要进行线性插值,我们将使用 imputeTS 包。安装并加载库后,您可以使用 na_interpolation()
填补缺失值。传递两个参数:第一个是您希望处理的数据框列,第二个是您希望用来执行插补的方法。
library(imputeTS)
df$TAVG <- na_interpolation(df$TAVG, option ="linear")
第 3 步:编码我们图表的第一个版本
ggplot2 可视化由多个层组成。如下面的图所示,每一层包含一个 geom 对象,即您在图表中看到的一个元素(例如线条和点)。
图片由作者创建
首先,您需要将数据集传递给 ggplot()
函数。其次,您将变量映射到美学属性——geom 对象的视觉属性。例如,美学属性包括 y 轴位置、x 轴位置、颜色或大小。下面,我们还设置了黑白 ggplot2 主题。如果不添加其他 geom 对象,图表将只有两个轴。
library(ggplot2)
theme_set(theme_bw())
axes <- ggplot(data = df, aes(x = DATE, y = TAVG))
axes
图片由作者创建
现在,您可以添加第二层,用点表示时间上的温度。注意,您可以使用“+”符号将此层添加到前一步骤中制作的图中。
axes +
geom_point()
图片由作者创建
最后,您可以添加第三层,其中包含线条。重要的是要指出,一些作者声称这些线条并不代表观察到的数据,应谨慎使用。有关详细讨论,请查看 《数据可视化基础》第十三章*,作者 Claus O. Wilke。
图片由作者创建
第 4 步:自定义您的图表
在这一部分,你将学习如何定制你的图表,使其既清晰又美观。
首先,为了使温度的增加更加明显,我们将点的颜色美学也映射到TAVG
。由于它是一个数值变量,ggplot2 将使用渐变色来表示连续的值。你可以使用scale_color_gradient()
函数选择代表低温和高温的颜色。
此外,你可以分别使用xlab()
和ylab()
设置 x 和 y 轴的标签。标题可以通过ggtitle()
添加。我们还将增加点的大小,并添加透明度以使重叠的数据可见。
爱德华·塔夫特,数据可视化领域的专家,建议最大限度地利用墨水来显示非冗余数据。作者声称,这会使你的图表更清晰,避免分散读者的注意力。
我们正在使用的 ggplot2 主题theme_bw()
已经符合 Tufte 的建议,但我们仍然可以去除图表的面板网格。为了实现这一点,使用theme()
函数并传递两个参数:panel.grid.minor = element_blank()
和panel.grid.major = element_blank()
。
图片由作者创建
第 5 步:为你的可视化创建一个主题
你现在将学习如何创建你自己的 ggplot2 主题。作为示例,我们将创建theme_tds()
。
首先,我们将加载 Google 字体“Source Serif Pro”。它是 Medium 文章中使用的字体。你可以通过showtext
包轻松加载它。如果你没有这个包,请安装它。安装包后,加载它并使用font_add_google()
函数来加载“Source Serif Pro”。我们还告诉 R 使用showtext
来渲染文本,通过showtext_auto()
。
library(showtext)
font_add_google("Source Serif Pro")
showtext_auto()
注意,一些作者建议在图表中只使用无衬线字体。请查看这篇文章,以了解关于这一问题的讨论。
现在我们将使用theme()
来定制图表。下图展示了一些你可以使用的参数。有关完整的列表,请查看这个ggplot2 参考。
图片由作者创建
你可以通过调用包含你自定义规范的 ggplot2 theme()
函数来创建一个新的主题。注意,我们从黑白主题(theme_bw
)开始,然后去除网格,改变背景、面板和文本颜色。为了便于将来的修改,创建了两个参数供用户指定所需的文本、面板和背景颜色。
theme_tds <- function(text_panel_color, background_color) {
theme_bw()+
theme(text=element_text(size=10,
family="Source Serif Pro",
color = text_panel_color),
# Eliminates grids
panel.grid.minor = element_blank(),
panel.grid.major = element_blank(),
# Changes panel, plot and legend background
panel.background = element_rect(fill = background_color),
plot.background = element_rect(fill = background_color),
legend.background = element_rect(fill= background_color),
# Changes legend texts color
legend.title = element_text(color = text_panel_color),
# Changes plot border color and size
panel.border = element_rect(size = 1, color = text_panel_color),
# Changes color of axis texts
axis.text.x = element_text(color = text_panel_color),
axis.text.y = element_text(color = text_panel_color),
axis.title.x = element_text(color= text_panel_color),
axis.title.y = element_text(color= text_panel_color),
# Changes axis ticks color
axis.ticks.y = element_line(color = text_panel_color),
axis.ticks.x = element_line(color = text_panel_color),
)
}
现在你可以简单地将theme_tds()
添加到你的图表中,并指定你喜欢的颜色。以下是一个示例:
ggplot(data = df, aes(x = DATE, y = TAVG, color = TAVG))+
geom_point(size = 4, alpha = 0.7)+
scale_color_gradient(name = "ºC", low = "#4F88EC", high = "#ec4f88")+
ggtitle("Historical air temperature trend in Berkeley, CA")+
xlab("Year")+
ylab("Annual Mean Temperature [ºC]")+
theme_tds(text_panel_color = "white",
background_color = "#252525")
图片由作者创建
另一个例子,背景为白色,字体颜色为 Towards Data Science:
ggplot(data = df, aes(x = DATE, y = TAVG, color = TAVG))+
geom_point(size = 4, alpha = 0.7)+
scale_color_gradient(name = "ºC", low = "#4F88EC", high = "#ec4f88")+
ggtitle("Historical air temperature trend in Berkeley, CA")+
xlab("Year")+
ylab("Annual Mean Temperature [ºC]")+
theme_tds(text_panel_color = "#365A77",
background_color = "white")
图片由作者创建
最后,你可以使用 LOESS(局部回归散点平滑)平滑器来展示温度趋势,正如 Claus O. Wilke 在《数据可视化基础》第十四章中推荐的那样。你可以通过添加一个包含元素 geom_smooth()
的 ggplot2 图层来实现。
ggplot(data = df, aes(x = DATE, y = TAVG, color = TAVG))+
geom_point(size = 4, alpha = 0.7)+
geom_smooth(color = "#365A77", se = FALSE)+
scale_color_gradient(name = "ºC", low = "#4F88EC", high = "#ec4f88")+
ggtitle("Historical air temperature trend in Berkeley, CA")+
xlab("Year")+
ylab("Annual Mean Temperature [ºC]")+
theme_tds(text_panel_color = "#365A77",
background_color = "white")
图片由作者创建
结论
ggplot2 是一个强大的 R 库,允许你创建和自定义引人入胜的可视化。在本文中,你学会了如何使用它创建一个图表,以提高对全球变暖的认识,并使用了来自NCEI 网站的可靠数据。
如果你对研究和可视化气候数据的更多方法感兴趣,可以查看这篇文章,其中我通过回归分析关联了碳排放和空气温度:R 编程在气候数据分析和可视化中的应用
我希望这篇文章能为你提供新的数据可视化视角,使你的图表更具效果和吸引力。
*数据集使用条款
根据国家海洋和大气管理局(NOAA)的网站,“政府网页上的信息属于公共领域,在美国不受版权保护,除非另有特别说明(版权可能在其他地方持有)。”
随机森林与缺失值
这是一个非常引人入胜的实际解决方案
·
关注 发表在 Towards Data Science ·9 分钟阅读·2023 年 6 月 21 日
--
(分布式)随机森林的特点。在这篇文章中:处理缺失值的能力。来源:作者。
除了在网上找到的一些过度清理的数据集之外,缺失值无处不在。实际上,数据集越复杂、规模越大,缺失值出现的可能性就越高。缺失值是统计研究中的一个有趣领域,但在实际应用中,它们常常是一种麻烦。
如果你处理的是预测问题,希望从 p- 维协变量 X=(X_1,…,X_p)* 中预测一个变量 Y,并且你在 X 中面临缺失值,对于基于树的方法有一个有趣的解决方案。这种方法实际上已经存在很久,但在广泛的数据集中表现得非常好。我所说的是“缺失值纳入属性标准”(MIA;[1])。虽然关于缺失值有许多好的文章(如这篇),但这种强大的方法似乎在使用上有所欠缺。特别是,你不需要以任何方式填补、删除或预测缺失值,而是可以像处理完全观测到的数据一样进行预测。
我将快速解释方法本身的工作原理,然后展示一个用分布随机森林(DRF)进行的示例,详细说明见这里。我选择 DRF 是因为它是随机森林的一个非常通用的版本(特别是,它也可以用来预测随机向量Y),并且我在这里有些偏向。MIA 实际上是为广义随机森林(GRF)实现的,涵盖了广泛的森林实现。特别是,由于 CRAN 上的 DRF 实现基于 GRF,经过轻微修改后,它也可以使用 MIA 方法。
当然,要注意这是一个快速修复(据我所知)没有理论保障。根据缺失机制,这可能会严重偏倚分析。另一方面,大多数处理缺失值的常用方法没有任何理论保障,或者明确已知会偏倚分析,至少在经验上,MIA 似乎效果很好。
它是如何工作的
请记住,在 RF 中,拆分的形式为 X_j < S 或 X_j ≥ S,对于维度 j=1,…,p。为了找到这个拆分值 S,它优化某种关于 Y 的标准,例如 CART 标准。因此,观察值通过依赖于 X 的决策规则逐步划分。
RF 中拆分的示意图。图片作者提供。
原始论文解释得有些混乱,但根据我的理解,MIA 的工作原理如下:我们考虑一个样本(Y_1,X_1),……,(Y_n,* X_n),其中*
X_i=(X_i1,…,X_ip)’。*
没有缺失值的拆分只是寻找如上所述的值 S,然后将所有 Y_i,其中 X_ij < S,丢入节点 1,将所有 Y_i,其中 X_ij ≥ S,丢入节点 2。计算每个值 S 的目标标准,例如 CART,我们可以选择最佳的一个。有缺失值时,对于每个候选拆分值 S,需要考虑 3 种选项:
-
对于所有观察值i,如果X_ij被观察到,则使用通常的规则,如果X_ij缺失,则将i发送到节点 1。
-
对于所有观察值 i,如果X_ij被观察到,则使用通常的规则,如果X_ij缺失,则将i发送到节点 2。
-
忽略通常的规则,只要X_ij缺失就将 i 发送到节点 1,若被观察到则发送到节点 2。
遵循这些规则的选择再次根据我们对Y_i的标准来决定。
说明我如何理解 MIA 过程。给定父节点中的观察值,我们寻找最佳分裂值 S。对于每个分裂值,我们考虑 3 个选项并进行尝试,直到找到最小值。左侧的{}表示被发送到左侧或右侧的观察值 i。图片由作者提供。
一个小例子
需要提到的是,CRAN上的 drf 包尚未更新到最新的方法。未来会有一个时间点,所有这些将会在 CRAN 上的一个包中实现!然而,目前有两个版本:
如果你想使用带有缺失值的快速 drf 实现(不带置信区间),可以使用附在这篇文章末尾的“drfown”函数。这段代码改编自
lorismichel/drf: Distributional Random Forests (Cevid et al., 2020) (github.com)
如果你想要使用带有置信区间的参数,请使用这段(较慢的)代码。
drfinference/drf-foo.R at main · JeffNaef/drfinference (github.com)
特别是,drf-foo.R 包含了在后一种情况下所需的一切。
我们将关注带有置信区间的较慢代码,如在这篇文章中所解释,并考虑与该文章中相同的例子:
set.seed(2)
n<-2000
beta1<-1
beta2<--1.8
# Model Simulation
X<-mvrnorm(n = n, mu=c(0,0), Sigma=matrix(c(1,0.7,0.7,1), nrow=2,ncol=2))
u<-rnorm(n=n, sd = sqrt(exp(X[,1])))
Y<- matrix(beta1*X[,1] + beta2*X[,2] + u, ncol=1)
请注意,这是一个异方差线性模型,p=2,误差项的方差取决于X_1的值。现在我们还以“随机缺失”(MAR)的方式向X_1中添加缺失值:
prob_na <- 0.3
X[, 1] <- ifelse(X[, 2] <= -0.2 & runif(n) < prob_na, NA, X[, 1])
这意味着X_1在X_2的值小于-0.2 时,以 0.3 的概率缺失。因此,X_1缺失的概率取决于X_2,这被称为“随机缺失”(Missing at Random)。这已经是一个复杂的情况,通过查看缺失值的模式可以获得信息。也就是说,缺失情况不是“完全随机缺失(MCAR)”,因为X_1的缺失取决于X_2的值。这反过来意味着我们从中抽取的X_2的分布在X_1缺失与否的条件下是不同的。这尤其意味着删除缺失值的行可能会严重偏倚分析结果。
现在我们固定 x 并估计给定 X=x, 的条件期望值和方差,就像在上一篇文章中一样。
# Choose an x that is not too far out
x<-matrix(c(1,1),ncol=2)
# Choose alpha for CIs
alpha<-0.05
我们还拟合了 DRF,并预测测试点 x 的权重(这相当于预测 Y|X=x 的条件分布):
## Fit the new DRF framework
drf_fit <- drfCI(X=X, Y=Y, min.node.size = 5, splitting.rule='FourierMMD', num.features=10, B=100)
## predict weights
DRF = predictdrf(drf_fit, x=x)
weights <- DRF$weights[1,]
示例 1:条件期望
我们首先估计条件期望值 Y|X=x.
# Estimate the conditional expectation at x:
condexpest<- sum(weights*Y)
# Use the distribution of weights, see below
distofcondexpest<-unlist(lapply(DRF$weightsb, function(wb) sum(wb[1,]*Y) ))
# Can either use the above directly to build confidence interval, or can use the normal approximation.
# We will use the latter
varest<-var(distofcondexpest-condexpest)
# build 95%-CI
lower<-condexpest - qnorm(1-alpha/2)*sqrt(varest)
upper<-condexpest + qnorm(1-alpha/2)*sqrt(varest)
round(c(lower, condexpest, upper),2)
# without NAs: (-1.00, -0.69 -0.37)
# with NAs: (-1.15, -0.67, -0.19)
值得注意的是,通过缺失数据(NAs)获得的结果与上一篇文章中首次分析得到的结果非常接近!这让我感到相当惊讶,因为处理缺失机制并不容易。有趣的是,估计量的估计方差也翻倍了,从没有缺失值时的大约 0.025 增加到有缺失值时的大约 0.06。
真相为:
所以我们有一个轻微的误差,但置信区间包含了真相,正如应该的那样。
对于更复杂的目标,例如条件方差,结果看起来相似:
# Estimate the conditional expectation at x:
condvarest<- sum(weights*Y²) - condexpest²
distofcondvarest<-unlist(lapply(DRF$weightsb, function(wb) {
sum(wb[1,]*Y²) - sum(wb[1,]*Y)²
} ))
# Can either use the above directly to build confidence interval, or can use the normal approximation.
# We will use the latter
varest<-var(distofcondvarest-condvarest)
# build 95%-CI
lower<-condvarest - qnorm(1-alpha/2)*sqrt(varest)
upper<-condvarest + qnorm(1-alpha/2)*sqrt(varest)
c(lower, condvarest, upper)
# without NAs: (1.89, 2.65, 3.42)
# with NAs: (1.79, 2.74, 3.69)
在这里,估计值的差异略大一些。真相为
估计值在有缺失值时甚至略微更准确(虽然这当然可能只是随机性)。再次强调,估计方差(方差)在有缺失值时增加,从 0.15(无缺失值)增加到 0.23。
结论
在本文中,我们讨论了 MIA,它是对随机森林中分割方法的适应,以处理缺失值。由于它在 GRF 和 DRF 中实现,因此可以广泛使用,我们观察到的小示例表明它表现得非常好。
然而,我想再次指出,即使对于大量数据点,也没有理论保证一致性或置信区间的合理性。缺失值的原因很多,必须非常小心,以免由于粗心处理此问题而使分析产生偏差。MIA 方法绝不是解决此问题的一个充分理解的修复方法。然而,它似乎是一个合理的快速修复方法,能够在某种程度上利用数据中的缺失模式。如果有人做了/有更广泛的模拟分析,我会对结果感到好奇。
代码
require(drf)
drfown <- function(X, Y,
num.trees = 500,
splitting.rule = "FourierMMD",
num.features = 10,
bandwidth = NULL,
response.scaling = TRUE,
node.scaling = FALSE,
sample.weights = NULL,
sample.fraction = 0.5,
mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X)),
min.node.size = 15,
honesty = TRUE,
honesty.fraction = 0.5,
honesty.prune.leaves = TRUE,
alpha = 0.05,
imbalance.penalty = 0,
compute.oob.predictions = TRUE,
num.threads = NULL,
seed = stats::runif(1, 0, .Machine$integer.max),
compute.variable.importance = FALSE) {
# initial checks for X and Y
if (is.data.frame(X)) {
if (is.null(names(X))) {
stop("the regressor should be named if provided under data.frame format.")
}
if (any(apply(X, 2, class) %in% c("factor", "character"))) {
any.factor.or.character <- TRUE
X.mat <- as.matrix(fastDummies::dummy_cols(X, remove_selected_columns = TRUE))
} else {
any.factor.or.character <- FALSE
X.mat <- as.matrix(X)
}
mat.col.names.df <- names(X)
mat.col.names <- colnames(X.mat)
} else {
X.mat <- X
mat.col.names <- NULL
mat.col.names.df <- NULL
any.factor.or.character <- FALSE
}
if (is.data.frame(Y)) {
if (any(apply(Y, 2, class) %in% c("factor", "character"))) {
stop("Y should only contain numeric variables.")
}
Y <- as.matrix(Y)
}
if (is.vector(Y)) {
Y <- matrix(Y,ncol=1)
}
#validate_X(X.mat)
if (inherits(X, "Matrix") && !(inherits(X, "dgCMatrix"))) {
stop("Currently only sparse data of class 'dgCMatrix' is supported.")
}
drf:::validate_sample_weights(sample.weights, X.mat)
#Y <- validate_observations(Y, X)
# set legacy GRF parameters
clusters <- vector(mode = "numeric", length = 0)
samples.per.cluster <- 0
equalize.cluster.weights <- FALSE
ci.group.size <- 1
num.threads <- drf:::validate_num_threads(num.threads)
all.tunable.params <- c("sample.fraction", "mtry", "min.node.size", "honesty.fraction",
"honesty.prune.leaves", "alpha", "imbalance.penalty")
# should we scale or not the data
if (response.scaling) {
Y.transformed <- scale(Y)
} else {
Y.transformed <- Y
}
data <- drf:::create_data_matrices(X.mat, outcome = Y.transformed, sample.weights = sample.weights)
# bandwidth using median heuristic by default
if (is.null(bandwidth)) {
bandwidth <- drf:::medianHeuristic(Y.transformed)
}
args <- list(num.trees = num.trees,
clusters = clusters,
samples.per.cluster = samples.per.cluster,
sample.fraction = sample.fraction,
mtry = mtry,
min.node.size = min.node.size,
honesty = honesty,
honesty.fraction = honesty.fraction,
honesty.prune.leaves = honesty.prune.leaves,
alpha = alpha,
imbalance.penalty = imbalance.penalty,
ci.group.size = ci.group.size,
compute.oob.predictions = compute.oob.predictions,
num.threads = num.threads,
seed = seed,
num_features = num.features,
bandwidth = bandwidth,
node_scaling = ifelse(node.scaling, 1, 0))
if (splitting.rule == "CART") {
##forest <- do.call(gini_train, c(data, args))
forest <- drf:::do.call.rcpp(drf:::gini_train, c(data, args))
##forest <- do.call(gini_train, c(data, args))
} else if (splitting.rule == "FourierMMD") {
forest <- drf:::do.call.rcpp(drf:::fourier_train, c(data, args))
} else {
stop("splitting rule not available.")
}
class(forest) <- c("drf")
forest[["ci.group.size"]] <- ci.group.size
forest[["X.orig"]] <- X.mat
forest[["is.df.X"]] <- is.data.frame(X)
forest[["Y.orig"]] <- Y
forest[["sample.weights"]] <- sample.weights
forest[["clusters"]] <- clusters
forest[["equalize.cluster.weights"]] <- equalize.cluster.weights
forest[["tunable.params"]] <- args[all.tunable.params]
forest[["mat.col.names"]] <- mat.col.names
forest[["mat.col.names.df"]] <- mat.col.names.df
forest[["any.factor.or.character"]] <- any.factor.or.character
if (compute.variable.importance) {
forest[['variable.importance']] <- variableImportance(forest, h = bandwidth)
}
forest
}
引文
[1] Twala, B. E. T. H., M. C. Jones, 和 David J. Hand. 应对决策树中缺失数据的良好方法。模式识别信函 29, 2008。
2023 年的随机森林:一种强大方法的现代扩展
随机森林取得了长足的发展
·
关注 发布于 Towards Data Science ·12 min read·2023 年 11 月 7 日
--
现代随机森林方法的特点。来源:作者。
在机器学习的时间线中,随机森林(RFs),在 Breimann 的开创性论文中提出([1]),可以算得上古老了。尽管它们已经有些年头,但凭借其卓越的性能仍然令人印象深刻,并且仍是活跃研究的主题。本文旨在突出随机森林方法已成为多功能工具箱的特点,重点关注广义随机森林(GRF)和分布随机森林(DRF)。
简而言之,这两种方法的主要思想是,RF 隐式产生的权重可以用于估计除条件期望之外的其他目标。GRF 的理念是使用一个随机森林,分裂准则适应于所设想的目标(例如条件均值、条件分位数或条件处理效应)。DRF 的理念是调整分裂准则,以便能够估计整个条件分布。从这个对象中,可以在第二步派生出许多不同的目标。事实上,本文主要讨论 DRF,因为我对这种方法更为熟悉,而且它在处理广泛目标时更加简洁(只需为广泛的目标拟合一个森林)。然而,图中指出的所有优点也适用于 GRF,实际上,R 中的 DRF 包是建立在GRF 的专业实现之上的。此外,GRF 森林的分裂准则适应于目标,这意味着它可能比 DRF 表现更好。这对于二元Y尤其如此,此时应使用 probability_forests()。因此,尽管我主要讨论 DRF,但在整篇文章中应始终记住 GRF。
本文的目标是提供一个概述,并提供相应部分的深入阅读链接。我们将按图中的时钟方向逐一介绍每个要点,引用相关的文章,并通过一个小示例进行强调。我首先快速总结了下面的最重要的进一步阅读链接:
多样性/性能: Medium Article 和原始论文(DRF/GRF)
缺失值纳入: Medium Article
不确定性度量: Medium Article
变量重要性: Medium Article
示例
我们取 X_1, X_2, X_4, …, X_10 独立地均匀分布在(-1,1)之间,并通过 X_3=X_1 + uniform error 来创建 X_1 和 X_3 之间的依赖关系。然后我们模拟 Y 为
## Load packages and functions needed
library(drf)
library(mice)
source("drfnew_v2.R")
## The function drfnew_v2.R is available below, or on
## https://github.com/JeffNaef/drfupdate
## Set parameters
set.seed(10)
n<-1000
##Simulate Data that experiences both a mean as well as sd shift
# Simulate from X
x1 <- runif(n,-1,1)
x2 <- runif(n,-1,1)
x3 <- x1+ runif(n,-1,1)
X0 <- matrix(runif(7*n,-1,1), nrow=n, ncol=7)
Xfull <- cbind(x1,x2, x3, X0)
colnames(Xfull)<-paste0("X", 1:10)
# Simulate dependent variable Y
Y <- as.matrix(rnorm(n,mean = 0.8*(x1 > 0), sd = 1 + 1*(x2 > 0)))
##Also add MAR missing values using ampute from the mice package
X<-ampute(Xfull)$amp
head(cbind(Y,X))
Y X1 X2 X3 X4 X5
1 -3.0327466 -0.4689827 0.06161759 0.27462737 NA -0.624463079
2 1.2582824 -0.2557522 0.36972181 NA -0.04100963 0.009518047
3 -0.8781940 0.1457067 -0.23343321 NA -0.64519687 -0.945426305
4 3.1595623 0.8164156 0.90997600 0.69184618 -0.20573331 -0.007404298
5 1.1176545 -0.5966361 NA -1.21276055 0.62845399 0.894703422
6 -0.4377359 0.7967794 -0.92179989 -0.03863182 0.88271415 -0.237635732
X6 X7 X8 X9 X10
1 -0.9290009 0.5401628 0.39735433 -0.7434697 0.8807558
2 -0.2885927 0.3805251 -0.09051334 -0.7446170 0.9935311
3 -0.5022541 0.3009541 0.29424395 0.5554647 -0.5341800
4 0.7583608 -0.8506881 0.22758566 -0.1596993 -0.7161976
5 -0.3640422 0.8051613 -0.46714833 0.4318039 -0.8674060
6 -0.3577590 -0.7341207 0.85504668 -0.6933918 0.4656891
请注意,通过mice 包中的 ampute 函数,我们将缺失值非随机(MAR)设置在X上,以突出 GRF/DRF 处理缺失值的能力。此外,在上述过程中,仅 X_1 和 X_2 对预测 Y 是相关的,所有其他变量都是“噪声”变量。这种“稀疏”设置在实际数据集中可能很常见。
现在我们选择一个测试点作为这个示例,将在整个过程中使用:
x<-matrix(c(0.2, 0.4, runif(8,-1,1)), nrow=1, ncol=10)
print(x)
[,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8]
[1,] 0.2 0.4 0.7061058 0.8364877 0.2284314 0.7971179 0.78581 0.5310279
[,9] [,10]
[1,] -0.5067102 0.6918785
多样性
DRF 估计条件分布 P_{Y|X=x} 以简单权重的形式:
从这些权重中,可以计算出广泛的目标,或者可以用来从条件分布中进行模拟。它的多功能性一个很好的参考是原始的 研究文章,其中使用了很多例子,还有相应的 medium 文章。
在这个例子中,我们首先从这个分布中进行模拟:
DRF<-drfCI(X=X, Y=Y, B=50,num.trees=1000, min.node.size = 5)
DRFpred<-predictdrf(DRF, newdata=x)
## Sample from P_{Y| X=x}
Yxs<-Y[sample(1:n, size=n, replace = T, DRFpred$weights[1,])]
hist(Yxs, prob=T)
z<-seq(-6,7,by=0.01)
d<-dnorm(z, mean=0.8 * (x[1] > 0), sd=(1+(x[2] > 0)))
lines(z,d, col="darkred" )
模拟条件分布的直方图与真实密度(红色)叠加。来源:作者。
图表显示了从条件分布中大致抽取的样本,并与真实密度(红色)叠加。我们现在使用此数据来估计 条件期望 和 条件 (0.05, 0.95) 分位数 在 x:
# Calculate quantile prediction as weighted quantiles from Y
qx <- quantile(Yxs, probs = c(0.05,0.95))
# Calculate conditional mean prediction
mux <- mean(Yxs)
# True quantiles
q1<-qnorm(0.05, mean=0.8 * (x[1] > 0), sd=(1+(x[2] > 0)))
q2<-qnorm(0.95, mean=0.8 * (x[1] > 0), sd=(1+(x[2] > 0)))
mu<-0.8 * (x[1] > 0)
hist(Yxs, prob=T)
z<-seq(-6,7,by=0.01)
d<-dnorm(z, mean=0.8 * (x[1] > 0), sd=(1+(x[2] > 0)))
lines(z,d, col="darkred" )
abline(v=q1,col="darkred" )
abline(v=q2, col="darkred" )
abline(v=qx[1], col="darkblue")
abline(v=qx[2], col="darkblue")
abline(v=mu, col="darkred")
abline(v=mux, col="darkblue")
模拟条件分布的直方图与真实密度(红色)叠加。此外,估计的条件期望和条件 (0.05, 0.95) 分位数为蓝色,真实值为红色。来源:作者。
同样,GRF 可以计算许多目标,但在这种情况下,每个目标需要拟合不同的森林。特别是,regression_forest() 用于条件期望,quantile_forest() 用于分位数。
GRF 无法处理多变量目标,而 DRF 可以做到这一点。
性能
尽管有大量关于强大新(非参数)方法的工作,如神经网络,但基于树的方法在表格数据上始终能够击败竞争者。请参见,例如,这篇 引人入胜的论文,或这篇关于 RF 在分类中的强度 的旧论文。
公平地说,通过参数调优,提升树方法,如 XGboost,通常在经典预测(对应于条件期望估计)中占据主导地位。然而,RF 方法在没有任何调优下展现的强大性能是值得注意的。此外,还有关于改进随机森林性能的研究,例如,对冲随机森林方法。
缺失值整合
来自 这篇论文 的“缺失数据合并属性标准”(MIA)是一个非常简单但非常强大的想法,它允许基于树的方法处理缺失数据。这已在 GRF R 包中实现,因此 DRF 中也可以使用。详细信息也在这篇 medium 文章 中解释。尽管这一概念很简单,但在实践中效果非常好:在上述示例中,DRF 在处理训练数据中的大量 MAR 缺失值时毫不费力 X (!)
不确定性度量
作为统计学家,我不仅希望得到点估计(即使是分布的点估计),还希望得到我的参数(即使“参数”是我的整个分布)的 估计不确定性。事实证明,DRF/GRF 中内置的简单额外子抽样可以为大样本量提供原则性的不确定性量化。关于 DRF 的理论在这篇 研究文章 中得以推导,我在这篇 medium 文章 中也进行了解释。GRF 的所有理论在 原始论文 中都有描述。
我们将其应用于上述示例:
# Calculate uncertainty
alpha<-0.05
B<-length(DRFpred$weightsb)
qxb<-matrix(NaN, nrow=B, ncol=2)
muxb<-matrix(NaN, nrow=B, ncol=1)
for (b in 1:B){
Yxsb<-Y[sample(1:n, size=n, replace = T, DRFpred$weightsb[[b]][1,])]
qxb[b,] <- quantile(Yxsb, probs = c(0.05,0.95))
muxb[b] <- mean(Yxsb)
}
CI.lower.q1 <- qx[1] - qnorm(1-alpha/2)*sqrt(var(qxb[,1]))
CI.upper.q1 <- qx[1] + qnorm(1-alpha/2)*sqrt(var(qxb[,1]))
CI.lower.q2 <- qx[2] - qnorm(1-alpha/2)*sqrt(var(qxb[,2]))
CI.upper.q2 <- qx[2] + qnorm(1-alpha/2)*sqrt(var(qxb[,2]))
CI.lower.mu <- mux - qnorm(1-alpha/2)*sqrt(var(muxb))
CI.upper.mu <- mux + qnorm(1-alpha/2)*sqrt(var(muxb))
hist(Yxs, prob=T)
z<-seq(-6,7,by=0.01)
d<-dnorm(z, mean=0.8 * (x[1] > 0), sd=(1+(x[2] > 0)))
lines(z,d, col="darkred" )
abline(v=q1,col="darkred" )
abline(v=q2, col="darkred" )
abline(v=qx[1], col="darkblue")
abline(v=qx[2], col="darkblue")
abline(v=mu, col="darkred")
abline(v=mux, col="darkblue")
abline(v=CI.lower.q1, col="darkblue", lty=2)
abline(v=CI.upper.q1, col="darkblue", lty=2)
abline(v=CI.lower.q2, col="darkblue", lty=2)
abline(v=CI.upper.q2, col="darkblue", lty=2)
abline(v=CI.lower.mu, col="darkblue", lty=2)
abline(v=CI.upper.mu, col="darkblue", lty=2)
模拟条件分布的直方图叠加了真实密度(红色)。此外,估计的条件期望和条件 (0.05, 0.95) 分位数为蓝色,真实值为红色。此外,虚线红色的线条是由 DRF 计算出的估计置信区间。来源:作者。
从上述代码可以看出,我们实际上有 B 个子树可以用来每次计算度量。通过这些 B 个均值和分位数样本,我们可以计算方差,并使用正态近似来获得图中虚线所示的(渐近)置信区间。尽管存在缺失值,这一切依然可以完成 X(!)
变量重要性
随机森林的一个重要方面是有效计算变量重要性度量。尽管传统的度量方法有些随意,但对于传统 RF 和 DRF,现在已经有了有原则的度量方法,如在这篇 medium article 中解释的那样。对于 RF,Sobol-MDA 方法可靠地识别出对条件期望估计最重要的变量,而对于 DRF,MMD-MDA 识别出对整体分布估计最重要的变量。如文章中讨论的那样,使用 投影随机森林 的理念,这些度量可以非常高效地实现。我们在例子中展示了 MMD 变量重要性度量的一个较低效实现:
## Variable importance for conditional Quantile Estimation
## For the conditional quantiles we use a measure that considers the whole distribution,
## i.e. the MMD based measure of DRF.
MMDVimp <- compute_drf_vimp(X=X,Y=Y, print=F)
sort(MMDVimp, decreasing = T)
X2 X1 X8 X6 X3 X10
0.852954299 0.124110913 0.012194176 0.009578300 0.008191663 0.007517931
X9 X7 X5 X4
0.006861688 0.006632175 0.005257195 0.002401974
在这里,X_1 和 X_2 在尝试估计分布时被正确地识别为最相关的变量。值得注意的是,尽管 X_3 和 X_1 之间存在依赖关系,但测量方法正确地量化了 X_3 对于预测 Y 分布的重要性。这是随机森林原始 MDA 测量往往做错的事情,如中等文章所示。此外,请再次注意,X 中的缺失值在这里没有问题。
结论
GRF/DRF 以及传统的随机森林在任何数据科学家的工具箱中都不应缺少。尽管像 XGboost 这样的方法在传统预测中可能表现更好,但现代基于 RF 的方法的许多优点使它们成为一个极其多功能的工具。
当然,需要记住的是这些方法仍然完全是非参数的,并且需要大量的数据点才能使拟合有意义。这一点在不确定性量化中尤其真实,它仅在渐近情况下有效,即对于“大”样本。
文献
[1] Breiman, L. (2001). Random forests. Machine learning, 45(1):5–32.
附录:代码
require(drf)
require(Matrix)
require(kernlab)
drfCI <- function(X, Y, B, sampling = "binomial", ...) {
n <- dim(X)[1]
# compute point estimator and DRF per halfsample
# weightsb: B times n matrix of weights
DRFlist <- lapply(seq_len(B), function(b) {
# half-sample index
indexb <- if (sampling == "binomial") {
seq_len(n)[as.logical(rbinom(n, size = 1, prob = 0.5))]
} else {
sample(seq_len(n), floor(n / 2), replace = FALSE)
}
## Using normal Bootstrap on the data and refitting DRF
DRFb <-
drfown(X = X[indexb, , drop = F], Y = Y[indexb, , drop = F], ...)
return(list(DRF = DRFb, indices = indexb))
})
return(list(DRFlist = DRFlist, X = X, Y = Y))
}
predictdrf <- function(DRF, newdata, functional = NULL, ...) {
##Predict for testpoints in newdata, with B weights for each point, representing
##uncertainty
ntest <- nrow(x)
n <- nrow(DRF$Y)
weightsb <- lapply(DRF$DRFlist, function(l) {
weightsbfinal <- Matrix(0, nrow = ntest, ncol = n, sparse = TRUE)
weightsbfinal[, l$indices] <- predict(l$DRF, x)$weights
return(weightsbfinal)
})
weightsall <- Reduce("+", weightsb) / length(weightsb)
if (!is.null(functional)) {
stopifnot("Not yet implemented for several x" = ntest == 1)
thetahatb <-
lapply(weightsb, function(w)
functional(weights = w, X = DRF$X, Y = DRF$Y, x = x))
thetahatbforvar <- do.call(rbind, thetahatb)
thetahat <- functional(weights = weightsall, X = DRF$X, Y = DRF$Y, x = x)
thetahat <- matrix(thetahat, nrow = dim(x)[1])
var_est <- if (dim(thetahat)[2] > 1) {
a <- sweep(thetahatbforvar, 2, thetahat, FUN = "-")
crossprod(a, a) / B
} else {
mean((c(thetahatbforvar) - c(thetahat)) ^ 2)
}
return(list(weights = weightsall, thetahat = thetahat, weightsb = weightsb,
var_est = var_est))
} else {
return(list(weights = weightsall, weightsb = weightsb))
}
}
drfown <- function(X, Y,
num.trees = 500,
splitting.rule = "FourierMMD",
num.features = 10,
bandwidth = NULL,
response.scaling = TRUE,
node.scaling = FALSE,
sample.weights = NULL,
sample.fraction = 0.5,
mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X)),
min.node.size = 15,
honesty = TRUE,
honesty.fraction = 0.5,
honesty.prune.leaves = TRUE,
alpha = 0.05,
imbalance.penalty = 0,
compute.oob.predictions = TRUE,
num.threads = NULL,
seed = stats::runif(1, 0, .Machine$integer.max),
compute.variable.importance = FALSE) {
# initial checks for X and Y
if (is.data.frame(X)) {
if (is.null(names(X))) {
stop("the regressor should be named if provided under data.frame format.")
}
if (any(apply(X, 2, class) %in% c("factor", "character"))) {
any.factor.or.character <- TRUE
X.mat <- as.matrix(fastDummies::dummy_cols(X, remove_selected_columns = TRUE))
} else {
any.factor.or.character <- FALSE
X.mat <- as.matrix(X)
}
mat.col.names.df <- names(X)
mat.col.names <- colnames(X.mat)
} else {
X.mat <- X
mat.col.names <- NULL
mat.col.names.df <- NULL
any.factor.or.character <- FALSE
}
if (is.data.frame(Y)) {
if (any(apply(Y, 2, class) %in% c("factor", "character"))) {
stop("Y should only contain numeric variables.")
}
Y <- as.matrix(Y)
}
if (is.vector(Y)) {
Y <- matrix(Y,ncol=1)
}
#validate_X(X.mat)
if (inherits(X, "Matrix") && !(inherits(X, "dgCMatrix"))) {
stop("Currently only sparse data of class 'dgCMatrix' is supported.")
}
drf:::validate_sample_weights(sample.weights, X.mat)
#Y <- validate_observations(Y, X)
# set legacy GRF parameters
clusters <- vector(mode = "numeric", length = 0)
samples.per.cluster <- 0
equalize.cluster.weights <- FALSE
ci.group.size <- 1
num.threads <- drf:::validate_num_threads(num.threads)
all.tunable.params <- c("sample.fraction", "mtry", "min.node.size", "honesty.fraction",
"honesty.prune.leaves", "alpha", "imbalance.penalty")
# should we scale or not the data
if (response.scaling) {
Y.transformed <- scale(Y)
} else {
Y.transformed <- Y
}
data <- drf:::create_data_matrices(X.mat, outcome = Y.transformed, sample.weights = sample.weights)
# bandwidth using median heuristic by default
if (is.null(bandwidth)) {
bandwidth <- drf:::medianHeuristic(Y.transformed)
}
args <- list(num.trees = num.trees,
clusters = clusters,
samples.per.cluster = samples.per.cluster,
sample.fraction = sample.fraction,
mtry = mtry,
min.node.size = min.node.size,
honesty = honesty,
honesty.fraction = honesty.fraction,
honesty.prune.leaves = honesty.prune.leaves,
alpha = alpha,
imbalance.penalty = imbalance.penalty,
ci.group.size = ci.group.size,
compute.oob.predictions = compute.oob.predictions,
num.threads = num.threads,
seed = seed,
num_features = num.features,
bandwidth = bandwidth,
node_scaling = ifelse(node.scaling, 1, 0))
if (splitting.rule == "CART") {
##forest <- do.call(gini_train, c(data, args))
forest <- drf:::do.call.rcpp(drf:::gini_train, c(data, args))
##forest <- do.call(gini_train, c(data, args))
} else if (splitting.rule == "FourierMMD") {
forest <- drf:::do.call.rcpp(drf:::fourier_train, c(data, args))
} else {
stop("splitting rule not available.")
}
class(forest) <- c("drf")
forest[["ci.group.size"]] <- ci.group.size
forest[["X.orig"]] <- X.mat
forest[["is.df.X"]] <- is.data.frame(X)
forest[["Y.orig"]] <- Y
forest[["sample.weights"]] <- sample.weights
forest[["clusters"]] <- clusters
forest[["equalize.cluster.weights"]] <- equalize.cluster.weights
forest[["tunable.params"]] <- args[all.tunable.params]
forest[["mat.col.names"]] <- mat.col.names
forest[["mat.col.names.df"]] <- mat.col.names.df
forest[["any.factor.or.character"]] <- any.factor.or.character
if (compute.variable.importance) {
forest[['variable.importance']] <- variableImportance(forest, h = bandwidth)
}
forest
}
#' Variable importance for Distributional Random Forests
#'
#' @param X Matrix with input training data.
#' @param Y Matrix with output training data.
#' @param X_test Matrix with input testing data. If NULL, out-of-bag estimates are used.
#' @param num.trees Number of trees to fit DRF. Default value is 500 trees.
#' @param silent If FALSE, print variable iteration number, otherwise nothing is print. Default is FALSE.
#'
#' @return The list of importance values for all input variables.
#' @export
#'
#' @examples
compute_drf_vimp <- function(X, Y, X_test = NULL, num.trees = 500, silent = FALSE){
# fit initial DRF
bandwidth_Y <- drf:::medianHeuristic(Y)
k_Y <- rbfdot(sigma = bandwidth_Y)
K <- kernelMatrix(k_Y, Y, Y)
DRF <- drfown(X, Y, num.trees = num.trees)
wall <- predict(DRF, X_test)$weights
# compute normalization constant
wbar <- colMeans(wall)
wall_wbar <- sweep(wall, 2, wbar, "-")
I0 <- as.numeric(sum(diag(wall_wbar %*% K %*% t(wall_wbar))))
# compute drf importance dropping variables one by one
I <- sapply(1:ncol(X), function(j) {
if (!silent){print(paste0('Running importance for variable X', j, '...'))}
DRFj <- drfown(X = X[, -j, drop=F], Y = Y, num.trees = num.trees)
DRFpredj <- predict(DRFj, X_test[, -j])
wj <- DRFpredj$weights
Ij <- sum(diag((wj - wall) %*% K %*% t(wj - wall)))/I0
return(Ij)
})
# compute retraining bias
DRF0 <- drfown(X = X, Y = Y, num.trees = num.trees)
DRFpred0 = predict(DRF0, X_test)
w0 <- DRFpred0$weights
vimp0 <- sum(diag((w0 - wall) %*% K %*% t(w0 - wall)))/I0
# compute final importance (remove bias & truncate negative values)
vimp <- sapply(I - vimp0, function(x){max(0,x)})
names(vimp)<-colnames(X)
return(vimp)
}
机器学习中的随机数
原文:
towardsdatascience.com/random-numbers-in-machine-learning-9cb5d83d078e
关于伪随机数、种子设置和可重复性的一切
·发表在 Towards Data Science ·8 分钟阅读·2023 年 10 月 20 日
--
图片来源:Riho Kroll 在 Unsplash
机器学习依赖于统计学,随机数对数据处理和模型训练管道中许多步骤的性能至关重要。现代机器学习框架提供了在后台实现随机性的抽象和函数,而对于我们这些数据科学家和机器学习工程师来说,随机数生成的细节往往仍然模糊不清。
在本文中,我想探讨机器学习中的随机数。你将阅读到:
-
机器学习中随机数使用的 3 个示例
-
生成(伪)随机数
-
通过设定随机种子来固定随机数
-
可重复的机器学习:scikit-learn、tensorflow 和 pytorch 的必要代码行
在本文结束时,你将了解在机器学习管道中使用随机数时发生了什么,并学习确保机器学习算法可重复性的必要代码行。
机器学习中随机数的 3 个使用示例
为了说明随机数的重要性,我们讨论了机器学习管道中与随机数相关的三个示例。
-
创建数据集的训练/测试划分
-
神经网络中的权重初始化
-
训练过程中的小批量选择
训练/测试划分 将数据集分为训练数据和测试数据是评估机器学习算法性能的最重要步骤之一。我们关注的是创建能很好泛化到训练过程中未使用的数据的模型。为此,一组数据样本被划分为至少两个不相交的集合。
训练数据用于训练算法,即迭代地调整模型参数。测试数据用于通过将训练好的模型应用于测试数据并报告适当的指标来验证算法。
数据集的随机划分为训练数据和测试数据。图片:作者。
流行的 scikit-learn 函数sklearn.model_selection.train_test_split
使用随机数。在底层,它创建一个与数据集长度对应的索引数组。这些索引随后被随机分配到训练数据和测试数据中。
权重初始化
神经网络的层包含在训练过程中调整的参数。对于线性层,这些参数是权重和偏置。它们最初被赋予随机值。考虑一个 4 个神经元连接到具有 3 个潜在类别的输出层的线性层的示例:该层对应于一个 4 x 3 权重矩阵W。
4 x 3 矩阵 W 的三种不同随机初始化。图片:作者。
权重的初始化对模型训练的收敛至关重要。如果所有权重都具有相同的值,反向传播算法在更新权重时将没有理由对每个神经元进行不同的处理。
因此,当模型首次实例化时,权重是随机选择的。通常,这些随机数背后有合理的统计分布。Xavier Glorot 在 2010 年发现,当网络权重从均匀分布中初始化时,训练一个层级维度为 n x m 的前馈神经网络会得到改进[1]。
选择训练批次 通常,使用整个训练数据集一次性更新模型参数既不可行也不建议。因此,训练数据集被分成固定大小的小批量。数据加载器创建这些小批量并可以随机打乱数据。这是为了防止数据以有偏的方式进入训练,例如由于收集方式而按照时间顺序。
从包含 12 个样本的数据集中形成的大小为 4 的小批量。顶部:数据集没有打乱。底部:数据集被随机打乱。图片:作者。
随机梯度下降依赖于这些小批量的随机性。通过向算法呈现训练数据的随机子集,每一步反向传播都会强调训练数据的一个略微不同的方面。这避免了在训练过程中陷入局部最小值。
伪随机数
现代编程语言和机器学习框架提供了生成随机数的工具,使开发者无需担心底层算法。要在 Python 中生成 100 个 0 到 1 之间的均匀随机数,只需输入
import numpy as np
np.random.rand(100)
当你执行这一行代码时,实际上会发生什么?这时就引入了随机数生成器(RNGs)。
标准计算机算法产生可预测的结果,这恰恰与我们想要的相反——一个随机的数字序列。我们所寻找的是一个能生成难以预测的数字序列的算法。需要注意的是,难并不意味着不可能!因此,我们需要一个熵源和一个加密算法。对于真正的随机数生成器,熵源可以由环境提供,例如通过测量传感器温度或放射性粒子的衰变来获取输入 [3]。
对于机器学习,我们不需要像加密应用那样高端的随机性。Python — 特别是 numpy 库 — 实现了一个 RNG 来生成伪随机数。
在日常语言中,伪随机数序列足够接近真正的随机序列,以至于不完全的随机性不会影响算法的目的。序列的相关长度必须足够满足应用的需求。
从初始种子开始,应用一个函数 f 生成一个状态。另一个函数 g 被用来生成一个随机数。这一过程会重复进行多次。函数 f, g 应该是不可逆的。此外,不同的状态必须提供足够的变异性,以真正生成大量不同的随机数,这被称为序列长度。
遵循[4]的伪随机数生成器。图像:作者。
Python 中使用的底层算法叫做 Mersenne Twister [4]。它为每个状态提供随机数序列,并且比上面展示的简单实现计算效率更高。虽然它在加密上不安全,但它快速且足够强大。
随机数分布
我们通常感兴趣的是生成符合某种统计分布的随机数。例如,均匀分布在给定区间内为所有数字分配相等的概率。正态分布遵循具有固定均值和方差的高斯曲线。
1000 个随机数的正态分布(左)和均匀分布(右)。图像:作者。
随机数的分布可以通过其累积分布函数(CDF)相互转换。因此,只需要能够生成均匀分布的数字,其他任何分布都可以从中生成。具体细节超出了本文的范围,初步了解我们推荐[5]。
通过设定种子来固定随机数
如上所述,种子是一个整数,用作伪随机数生成器的输入。
对于固定的种子,由该算法生成的伪随机数序列总是相同的。
我们可以利用这个洞察来实现两个不同的目标。
选择新的随机种子 当你选择新的随机种子时,所有依赖于随机数的机器学习流程方面都将使用不同的值进行初始化。虽然底层分布和数据集大小保持不变,但训练过程可能会受到影响。
理想情况下,你的训练过程不应过于依赖随机种子的选择。你希望确保训练能够收敛,而不是依赖于那些神奇地产生良好结果的黄金运行,并且这些结果无法重复。
修复随机种子 修复随机种子在教学中尤为流行。当我为小数据集创建一个简单模型时,结果依赖于随机种子的概率相当高。每当我为学生准备教学材料,如 Jupyter 笔记本时,我都会修复随机种子,以便他们获得相同的结果。这减少了混淆,让学生能够专注于新概念,而不是度量标准中不那么重要的数字。
克隆看起来完全一样。照片由Phil Shaw拍摄,来源于Unsplash
机器学习中的可重复性
我们现在已经了解到,随机数对机器学习过程的性能至关重要。然而,如果我们希望算法的精确可重复性,就需要修复随机种子。
重要的是要意识到,随机种子必须为每个使用随机数的库单独修复。例如,即使我们通过np.random.seed(789123)
修复了numpy
的随机种子,torch
的随机种子也不会受到影响,因此训练将无法重复。
在下面,我们总结了修复不同流行机器学习框架随机种子的必要调用。
Scikit-Learn在整个过程中使用 numpy 的随机数生成器。要修复该框架的随机种子,只需设置
import numpy as np
np.random.seed(789456123)
流行的框架提供了另一种访问随机数生成器的函数,check_random_state。要快速生成 1000 个随机数:
from sklearn.utils.validation import check_random_state
rs = check_random_state(12345)
rs.rand(1000)
Tensorflow 2提供了其自己的随机数生成模块tf.random
。根据文档,在那里修复随机种子可以确保可重复性。然而,这在不同版本的 tensorflow 中并不能保证。为了生成相同的伪随机数,可以使用tf.random
模块中的stateless_XXX
函数,例如,对于遵循正态分布的 1000 个随机数序列:
from tf.random import stateless_normal
shape = 1000
seed = 12344
stateless_normal(shape, seed)
Pytorch 以类似于 numpy 语句的方式控制随机数生成。以下将设置所有依赖于该随机数生成器的过程的种子:
import torch
torch.manual_seed(9870)
然而,他们指出一些函数可能依赖于 numpy 随机数生成器,因此建议也固定该种子。
Pytorch 文档指出了一个关于完全可重复性的问题。不仅在不同版本之间不可能精确重现随机数,而且这些数字可能还会依赖于硬件。使用不同的 GPU 配合不同的 CUDA 工具包或 CPU 可能会导致不同的随机数序列。
总结
随机数对机器学习算法的性能至关重要。它们确保数据集以无偏差的方式进行划分,提高算法的泛化能力,并增加训练过程的收敛性。
在幕后,随机数序列由所谓的 RNGs 生成,这些算法能够提供遵循不同分布的随机数。
随机种子可以用来固定随机数序列。这使得机器学习算法具有可重复性,这对于教学目的和学术论文特别有用。最后,请记住,如果你在项目中使用了多种机器学习框架,你可能需要为所有框架设置随机种子!
参考资料
-
Xavier Glorot 和 Yoshua Bengio, “Understanding the difficulty of training deep feedforward neural networks”, Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, PMLR 9:249–256, 2010. [paper]
-
www.redhat.com/en/blog/understanding-random-number-generators-and-their-limitations-linux
-
详细介绍梅森旋转算法:
www.cryptologie.net/article/331/how-does-the-mersennes-twister-work/
-
Tensorflow:
www.tensorflow.org/guide/random_numbers
随机变量和概率分布
掌握随机变量和概率分布,利用我们统计学备忘单系列的第三部分,破解你的下一次数据科学面试
·
关注 发表在 Towards Data Science ·16 分钟阅读·2023 年 4 月 10 日
--
图片由 Naser Tamimi 提供,来源于 Unsplash
随机变量
从数学角度来看,随机变量是一个数值,其值依赖于随机事件的结果。例如,如果我们定义随机变量 (X) 为投掷两次硬币得到正面的次数。
注:我们用大写字母表示变量,用小写字母表示 X 的单一值。这里,x = 1 表示在两次掷硬币中恰好出现一个正面。
让我们再看一个例子。设 Y 表示掷骰子时出现的数字。随机变量 Y 可以是 1 到 6 之间的任何一个数字。
概率分布
让我们写下结果的概率。X(掷硬币两次时出现的正面数)的概率分布可以写作 Pr(X = x)。我们得到以下概率。
我们可以在直方图中绘制这些值
作者提供的图片
对于随机变量 Y(骰子面上的数字),概率如下所示。
Y 的直方图如下所示。
作者提供的图片
表格和直方图描述了随机变量的概率分布。
概率分布是对随机事件中不同结果的概率(或可能性)的数学描述。
概率质量函数
到目前为止,我们考虑的例子都有有限的结果。在掷硬币两次时,我们不可能得到 1.4 个正面。因此,与这些结果相关的概率也是有限的。随机变量因此是离散随机变量。离散随机变量的概率分布也称为概率质量函数(PMF)。我们将在文章稍后部分讨论连续变量。
累积分布函数
另一种表示与随机变量相关的概率的方法是绘制累积分布概率。随机变量 X 在值 x 处的累积分布概率是 X 取值小于或等于 x 的概率。通常我们用大写字母 F 来表示累积分布概率。因此 F(x = a ) = Pr(X ≤ a)。这个函数称为累积分布函数(CDF)。CDF 的优势在于无论变量是离散的还是连续的,其定义都是相同的。对于上述例子,我们可以通过将 PDF 累加来构建 CDF。
对于我们的掷硬币的例子
CDF 可以这样绘制。
作者提供的图片
对于骰子的掷骰子事件,
CDF 可以如下绘制。
作者提供的图片
离散概率分布
理论上,可以有无限多种概率分布。然而,有一些概率分布出现得非常频繁,以至于值得特别提及。让我们看一些突出的离散概率分布。
均匀分布
离散均匀分布描述了一个可以取有限个离散值的随机变量,所有值的概率相等。一个很常见的例子是抛一次硬币。得到正面或反面是等可能的。另一个例子是掷骰子。六个数字中的每一个都是等可能的。
从数学上讲,离散均匀分布的 PMF 由两个参数 a 和 b 定义,其中 a 是随机变量可以取的最小值,b 是最大值。PMF 由下式给出
其中 x 是随机变量的结果,x 的值从 a 到 b,包含 a 和 b。
将其应用于掷骰子的情况
a = 1, b = 6
二项分布
二项分布描述了在固定次数的试验或重复中成功的次数 (x)。每个试验只有两种可能的结果:成功或失败。此外,在给定的试验中成功的概率始终保持不变。我们称之为 p。每个这样的试验或重复被称为伯努利试验。概率分布由下式给出
其中 x 是 n 次试验中的成功次数,
上述表示由下式给出的二项系数
二项分布的应用
在一个办公室中,80% 的员工在早上 8 点到 9 点之间到达。找出在某一天,至少七名员工在早上 8 点到 9 点之间到达的概率。
上述问题表示一个具有 10 次试验和 80% 或 0.8 成功概率的二项分布。
我们需要计算 X ≥ 7 的概率。因此,我们需要计算
P(X = 7) + P(X = 8) + P(X = 9) + P(X = 10)
应用二项公式,我们得到
类似地
P(X = 8) = 0.302
P(X = 9) = 0.268
P(X = 10) = 0.107
因此,P(X ≥ 7) = 0.201 + 0.302 + 0.268 + 0.107 = 0.878 或 87.8%
几何分布
几何分布建立在二项分布的基础上。在这里,试验(或重复)会继续进行,直到成功发生,而不是固定次数的试验。从理论上讲,试验可以无限期地继续。例如,假设你继续掷骰子,直到你得到数字 6。理论上,我们可能会经历一个极其不幸的周期而得不到六。
几何分布的 PDF 由下式给出
其中 X 表示获得第一次成功所需的重复次数。
这是几何分布的 PDF,其中 p = 0.2。
图片来源于作者
几何分布的应用
在一个城镇中,60% 的人口是女性。找到一个人需要见到三个人才会遇到第一个男性的概率。
让我们用简单的英语写下这个问题。
仅第三个人是男性的概率是
- 第一个遇到的人是女性,第二个遇到的人是女性,第三个遇到的人是男性的概率
现在计算这些值更容易了。遇到女性的概率是 60%,而遇到男性的概率是 (100–60)% = 40%
因此,所需的概率 = 60% x 60% x 40% = 14.44%
或者,我们可以使用公式。这里 x = 3 和 p = 40% 或 0.4
因此
或 14.44%
泊松分布
另一个非常常见的离散分布是等待时间分布或泊松分布。泊松分布描述了在一个固定的时间或空间间隔内发生给定数量事件的概率,前提是这些事件以已知的平均速率发生,并且与上一个事件的时间无关。常见的应用包括确定一个购物中心应该设置多少个排队线,以免系统过载。只需一个参数——间隔期间的平均发生次数。这个参数通常用 λ 表示。
泊松过程的概率分布由以下公式给出
其中 λ 是间隔内的平均发生次数,x 是给定间隔内事件的发生次数(或成功次数),e 是自然对数(大约为 2.718)。
不同 λ 值的泊松分布的概率密度函数(PDF)
作者提供的图像
泊松分布的应用
观察到一个 WhatsApp 群组每天收到 45 条消息。它在下一个半小时内收到超过两条消息的概率是多少?
我们关注的时间框架(或间隔)是半小时。我们首先计算每半小时收到的平均消息数量。
一天有 24 小时,所以 24x2 = 48 个半小时周期。
每半小时收到的平均消息数量(λ)=
我们需要找出 P(x > 2)。我们可以通过以下方式找到这个值
将 x = 0, 1, 2 和 λ = 0.9375 代入公式,我们得到
连续概率分布
迄今为止,我们只看了具有有限值集合的变量(如硬币上的正面或反面、骰子上的数字等)。但在现实生活中,变量可以有无限多的值。我们以一个人在下午 1 点到 2 点之间到达进行午餐预约的情况为例。假设一个人可以随时到达,他在下午 1:10 到 1:30 之间到达的机会可以计算为
有利时间段 = 下午 1:10 到 1:30 = 20 分钟
总时间周期 = 1 小时 = 60 分钟。
因此概率 = ⅓..
但是,如果你想计算他在 1:15 PM 准确到达的概率,这将取决于我们如何准确地测量时间。如果我们按最接近的分钟来测量,它将是 1/60。如果我们测量到最后一秒钟,那么就是 1/3600。如果我们测量到毫秒,则是 1/360 万,如此等等。正如你所见,到达的确切时刻可以被进一步分割成更小的段。它可以取无限多个值。因此,在给定时刻他到达的概率是零。
这样的变量称为连续变量。正式来说,如果你绘制一个连续随机变量的累积分布函数(CDF),它将没有像我们在掷骰子时看到的那样的中断(或阶跃)。
作者提供的图片
连续变量的累积分布函数(CDF)如下所示。
作者提供的图片
对于连续随机变量,我们找到的是区间的概率(如上面描述的到达时间的情况),而不是单个值。虽然数学上有点复杂,因为涉及到微积分,但直观上,我们可以使用离散随机变量的概念来处理连续变量,因为我们可以通过四舍五入将连续随机变量转换为离散变量(就像我们在上述秒到分钟的例子中做的那样)。
概率密度函数
正如我们所看到的,连续随机变量在特定值处的概率总是零。因此,对于连续变量,我们找到的是 CDF 在任何时刻的变化量。这种变化称为概率密度函数(PDF)。从微积分的角度来看,PDF 是 CDF 的微分。连续随机变量的 PDF 将如下所示。
作者提供的图片
要找到两个点之间的概率,我们需要找到曲线下的面积(使用积分微积分在这两个点之间积分 PDF)。幸运的是,我们不需要自己计算这些值。我们可以使用 Python 或任何电子表格软件轻松计算这些值。
注意:对于任何连续概率分布,
这是因为特定值的精确概率总是零。
连续均匀分布
与离散均匀分布类似,对于连续均匀分布,事件发生的概率是均等的。只不过可能的值的范围是无限的。
连续分布的概率密度函数(PDF)由以下公式给出:
对于 a ≤ x ≤ b
其中 a 是最低可能值,b 是最高可能值。
连续均匀分布的应用
从在 Suber 应用程序上预定出租车到出租车到达之间的时间通常在五分钟到十五分钟之间。客户等待出租车超过十二分钟的概率是多少?
让我们尝试在不使用公式的情况下解决这个问题。类似于离散均匀分布的 PMF,我们可以绘制连续分布的 PDF。
作者提供的图片
整个曲线下的面积应该等于 1,因为确定在这个时间段内出租车会来。矩形的宽度是 15–5 = 10 单位。由于面积是 1,因此高度应该低于 0.1。
为了找到 P(x > 12),我们可以找到阴影部分 P(5 ≤ x ≤ 12)的面积,然后从 1 中减去。
作者提供的图片
阴影部分的面积
因此,
指数分布
另一种常见的连续分布是指数分布。它建模了直到发生特定事件所需的时间。例如,服务客户所需的时间、一个部件在更换前的使用寿命等。指数分布的 PDF 由下式给出
其中 m 称为衰减因子。它是历史平均等待时间的倒数。
该分布的 CDF 由下式给出
指数分布的应用
平均而言,如果每天使用,一辆车需要在 120 天后进行维护。服务时间可以用指数分布建模。
-
车辆在 150 天后进行服务的概率是多少?
-
多少比例的车辆将在 90 天后但少于 120 天进行服务?
这里的衰减因子(m) =
首先让我们解决 A 部分。
我们需要找到 P(x > 150)。所需的概率由图中的阴影部分给出。
作者提供的图片
在这种情况下,更容易找到未阴影区域的面积。如果你回忆一下 CDF 的讨论,未阴影区域的面积是 CDF (X = 150)。一旦我们找到它,我们可以简单地从 1 中减去以找到阴影部分的面积。
由于我们已经有了 CDF 函数,我们可以将所需的概率重新写为
在这种情况下,x = 150\。以及衰减因子
插入值,我们得到
这在直观上是正确的——由于大多数车在 120 天后需要维护,因此在 150 天后需要维护的车非常少。
完成这个之后,我们也可以类似地解决 B 部分。
B 部分要求我们找到一辆车在 90 到 120 天之间被服务的概率,即 P(90<X<120)。从图形上看,阴影部分的面积就是所需的概率。
图像由作者提供
我们可以将其分解为
再次,我们将使用累积分布函数(CDF)的属性来找到这些值。
因此
从图形上,这可以用以下方式表示。
图像由作者提供
我们现在可以轻松计算概率。
CDF(X = 120) = 0.6321
CDF(X = 90) = 0.5276
P(90 < X < 120) = 0.6321–0.5276 = 0.1045 或 10.45%
泊松分布和指数分布的比较
指数分布中的衰减因子 m 是泊松分布中平均发生时间λ的倒数。然而,它们之间存在一些差异。在泊松分布中,随机变量 x(给定时间段内的发生次数)是离散的。然而,在指数分布中,随机变量 x(下一个成功的时间)是连续的。
以我们之前 WhatsApp 小组中的示例为例。我们在特定时间段内收到的消息数量是离散变量,只能取正整数值(0, 1, 2, 3, …);而下一个消息的时间是连续变量,可以取分数值(2.3 分钟,4.8 小时等)。这种差异可以通过下图来说明。
图像由作者提供
正态分布
不论你的学术或专业背景如何,你都应该听说过正态分布。它是最常用(也是最滥用)的概率分布之一。理解正态分布对于理解推断统计至关重要。正态分布的概率密度函数(PDF)具有钟形曲线,描述如下:
这里μ是分布的均值,σ是标准差。
正态分布的概率密度函数(PDF)图形如下。图形围绕均值对称。
图像由作者提供
注:你无需记住公式,只需知道如何使用它。
由于曲线下的面积必须等于 1,标准差的变化会导致曲线变得更胖或更高,具体取决于σ。
图像由作者提供
因此,我们可以有无限多的正态分布。
标准正态分布
标准正态分布的均值为 0,标准差为 1。我们可以通过执行简单的变换将任何均值为μ、标准差为σ的正态分布转换为标准正态分布。
z 分数表示 x 距离均值有多少个标准差。
这确保了我们只需一个这样的表来处理所有的概率计算。这是在电子表格和统计程序出现之前的做法,但仍然非常有用。
如何使用 z-scores 来找出概率?
大多数软件能够非常轻松地提供 PDF 和 CDF 值。为了更好地理解,让我们使用这里提供的z-score 表。
正如标题所示,表格提供了 z-score 左侧的面积。
图片由作者提供
换句话说,表格提供了特定 z-score 的 CDF 值。因此,如果我们想找出 P(z < -1.63)的概率,我们可以直接从表中读取,如下图所示。首先,垂直读取前两位有效数字(-1.6),然后水平读取最后一位数字(3)。因此,CDF (z = -1.63) = P (z < -1.63) = 0.05155 或 5.155%。
如果我们想找出 P (-1.8 < z < 1.2),可以用与之前处理指数分布相同的方法。
从表格中读取值,
P (-1.8 < z < 1.2) = 0.88493–0.03593 = 0.849 或 84.9%
正态分布的经验规则
图片由作者提供
对于任何具有均值μ和标准差σ的正态分布,经验规则指出
- 均值两侧一个σ范围内的面积大约为 0.68。换句话说,68%的时间,你可以期待 x 在μ — σ和μ + σ之间
我们可以很容易地使用 z-score 表找到这个。我们从将 x 值转换为 z-scores 开始。x = μ + σ的值会给我们一个 z-score
同样,x = μ — σ的值将转换为 z-score -1
我们现在可以通过使用之前的 CDF 表找到这两个 z-score 之间的面积
图片由作者提供
P (-1 < z < 1) = CDF (z < 1) — CDF (z < -1) = 0.84134–0.15866 = 0.68269 或 68.27%
-
类似地,约 95%的情况下,x 会在μ — 2σ和μ + 2σ之间
-
并且,约 99.7%的时间,x 会在μ — 3σ和μ + 3σ之间
让我们在实际生活中使用正态分布。
正态分布的应用
如果已知欧洲成年男性的平均体重为 183 磅,标准差为 19 磅。则在欧洲有多少成年男性体重大于 160 磅但低于 200 磅?
已知:μ = 183,σ = 19。我们需要找出 P (160 < x < 200)。
这个问题可以通过将 x 值转换为 z-score 来解决,如前所述。
因此,问题简化为
P (160 < x < 200) = P (-0.684 < z < 0.895)
= CDF (0.895) — CDF(-0.684)
= 0.8146–0.2470
= 0.5676 或 56.76%
结论
在这篇文章中,我们探讨了处理随机变量所涉及的基本术语和技术。我们还介绍了离散和连续随机变量的不同概率分布,并通过正态分布进行了总结。我们还推荐了这份全面的 统计学备忘单,其中包含了统计学和概率论的重要术语和方程式。
最初发布于 https://www.stratascratch.com*。
随机化非常大的数据集
考虑这样一个问题:如何随机化一个大到连内存都容纳不下的数据集。本文描述了如何在 Python 中轻松且(相对)快速地完成这一操作。
·
关注 发表在 Towards Data Science ·6 min read·2023 年 8 月 26 日
--
如今,发现以 Gigabytes 甚至 Terabytes 计量的数据集并不罕见。如此大量的数据可以极大地帮助训练过程,创建出强大的机器学习模型。但如何随机化如此庞大的数据集呢?
图片由 Jess Bailey 提供,发布在 Unsplash
想象一下你有一个非常大的数据集,每行一个项目在一个文件中。数据的具体细节对于我们的目标无关紧要。数据集可以是 CSV(逗号分隔值)或 TSV(制表符分隔值)文件中的行,或者每行是一个 JSON 对象,或者是大点云中某一点的 X、Y、Z 值。我们只需要的是数据集按每行一个项目进行格式化。
对于包含较小数据集的文件,可以使用像这样的简单 Python 函数在内存中对文件进行随机化(称为“洗牌”):
import random
def shuffle_in_memory(filename_in, filename_out):
# Shuffle a file, line-by-line
with open(filename_in) as fp:
lines = fp.readlines()
# Randomize them in place:
random.shuffle(lines)
# Write the new order out:
with open(filename_out, "w") as fp:
fp.writelines(lines)
shuffle_in_memory()函数接收一个输入文件名和一个输出文件名,在内存中使用内置的random.shuffle()函数洗牌,并将随机化的数据写出。顾名思义,该函数要求文件的所有行一次性加载到内存中。
要测试这个函数,让我们制作一些测试文件。函数make_file()接收你希望在测试文件中包含的行数。该函数将创建文件并返回文件名。
import os
def make_file(lines):
filename = "test-%s.txt" % lines
print("Making test file '%s'..." % filename)
with open(filename, "w") as fp:
for i in range(lines):
fp.write(f"Line {i}\n")
print("Done!")
return filename
例如,要创建一个名为“test-1000.txt”的文件,其中包含 100 行,如下所示:
filename_in = make_file(1000)
运行此函数后,你应该在当前目录中找到一个名为“test-1000.txt”的文件,包含 1,000 行文本,如下所示:
Line 0
Line 1
Line 2
Line 3
Line 4
Line 5
Line 6
Line 7
Line 8
Line 9
...
要测试我们的shuffle_in_memory()函数,我们将命名一个输出文件,将字符串保存在变量filename_out中,并调用该函数:
filename_out = "test-randomized-1000.txt"
shuffle_in_memory(filename_in, filename_out)
现在,你的目录中应该有一个第二个文件,名为“test-randomized-1000.txt”。它的大小应与“test-1000.txt”完全相同,行数也完全相同,但顺序是随机的:
Line 110
Line 592
Line 887
Line 366
Line 52
Line 22
Line 891
Line 83
Line 931
Line 408
...
好的,现在大问题来了:如果我们有一个非常大的文件怎么办?让我们创建一个中等大小的文件,比如说 1000 万行。(对于大多数计算机来说,这仍然足够小,可以在内存中随机化,但大小足够大,可以进行练习。)如前所述,我们通过调用make_file()来创建输入文件:
filename_in_big = make_file(10_000_000)
这将花费几秒钟。之后,你应该在目录中有一个名为“test-10000000.txt”的文件。它应该与之前的文件一样开始,但将包含 1000 万行。文件大小约为 128 MB。
如何进行随机化?如果我们不想使用所有的 RAM,或者 RAM 不够,我们可以改用硬盘。这里有一个基于类似问题的递归算法,排序。以下函数shuffle()是基于归并排序算法。
首先,它检查一个文件是否足够小,可以在内存中进行洗牌(递归函数术语中的基本情况)。参数memory_limit以字节为单位。如果文件大小小于memory_limit,那么它将在内存中进行洗牌。如果太大,则会随机分割成多个较小的文件,每个文件递归地进行洗牌。最后,将较小的洗牌文件的内容合并回一起。
这是函数:
import tempfile
def shuffle(filename_in, filename_out, memory_limit, file_split_count,
depth=0, debug=False):
if os.path.getsize(filename_in) < memory_limit:
if debug: print(" " * depth, f"Level {depth + 1}",
"Shuffle in memory...")
shuffle_in_memory(filename_in, filename_out)
else:
if debug: print(
" " * depth, f"Level {depth + 1}",
f"{os.path.getsize(filename_in)} is too big;",
f"Split into {file_split_count} files..."
)
# Split the big file into smaller files
temp_files = [tempfile.NamedTemporaryFile('w+', delete=False)
for i in range(file_split_count)]
for line in open(filename_in):
random_index = random.randint(0, len(temp_files) - 1)
temp_files[random_index].write(line)
# Now we shuffle each smaller file
for temp_file in temp_files:
temp_file.close()
shuffle(temp_file.name, temp_file.name, memory_limit,
file_split_count, depth+1, debug)
# And merge back in place of the original
if debug: print(" " * depth, f"Level {depth + 1}",
"Merge files...")
merge_files(temp_files, filename_out)
如果这是一个排序算法,我们将以一种小心的方式将文件合并在一起,以创建一个排序顺序。然而,对于洗牌,我们不关心以特定顺序合并它们,因为我们希望它们是随机的。因此,merge_files()函数看起来像这样:
def merge_files(temp_files, filename_out):
with open(filename_out, "w") as fp_out:
for temp_file in temp_files:
with open(temp_file.name) as fp:
line = fp.readline()
while line:
fp_out.write(line)
line = fp.readline()
请注意,我们会小心地避免一次性将所有文件的行读入内存。我们通过将内存洗牌的限制设置为文件的大小来测试这一点。由于文件大小不小于 128,888,890,它将被分割成若干个较小的文件。对于这个例子,我们将大文件分成两个,每个文件都足够小,可以在内存中进行洗牌:
filename_out_big = "test-randomized-10000000.txt"
shuffle(filename_in_big, filename_out_big, 128_888_890, 2, debug=True)
这个调用的结果如下:
Level 1 128888890 is too big; Split into 2 files...
Level 2 Shuffle in memory...
Level 2 Shuffle in memory...
Level 1 Merge files...
结果文件“test-randomized-10000000.txt”的内容应包含 1000 万行,所有行都是随机的。更好的测试方法是将所需内存缩小到远小于文件大小,并将过大的文件分割成超过 2 个。假设我们只想使用约 1 MB 的 RAM,并将文件分割成 20 个较小的文件:
shuffle(filename_in_big, filename_out_big, 1_000_000, 20, debug=True)
这个例子将使用不超过 1 MB 的 RAM,并递归地处理大于此大小的子文件,每次处理 20 个。
这个算法可以处理任何大小的文件(当然,你需要足够的磁盘空间!)。你为shuffle_in_memory()分配的内存越多,运行速度就会越快。如果较小的文件数量过多,你将花费太多时间打开和关闭文件。你可以尝试不同的memory_limit值,但我发现 20 到 200 之间的值效果很好。初始文件越大,你可能需要更多的子文件。
你还可以使用其他算法。我曾对将所有行写入 SQLite 数据库、以随机顺序 SELECT 它们抱有很高的期望,但它的速度并没有比上面的代码更快。
import sqlite3
def shuffle_sql(filename_in, filename_out, memory_limit, depth=0, debug=False):
if os.path.getsize(filename_in) < memory_limit:
if debug: print(" " * depth, f"Level {depth + 1}",
"Shuffle in memory...")
shuffle_in_memory(filename_in, filename_out)
else:
if debug: print(
" " * depth, f"Level {depth + 1}",
f"{os.path.getsize(filename_in)} is too big;",
f"Writing to SQLite database..."
)
temp_db = tempfile.NamedTemporaryFile(delete=False)
connection = sqlite3.connect(temp_db.name)
cursor = connection.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS lines (
line TEXT
);
""")
with open(filename_in) as fp:
line = fp.readline()
while line:
cursor.execute("INSERT INTO lines (line) VALUES (?);", [line])
line = fp.readline()
connection.commit()
with open(filename_out, "w") as fp:
for line in cursor.execute("""
SELECT line FROM lines ORDER BY random();
"""):
fp.write(line[0])
shuffle_sql(filename_in_big, filename_out_big, 1_000_000, debug=True)
你能在纯 Python 中击败递归洗牌算法吗?如果能,我很想听听你的方法!
对人工智能、机器学习和数据科学感兴趣吗?请点赞并关注。告诉我你感兴趣的内容!
SQL 中的 RANK() 与 DENSE_RANK() 和 ROW_NUMBER()
原文:
towardsdatascience.com/rank-vs-dense-rank-vs-row-number-sql-1b6c09097b21
理解 SQL 中这些窗口函数之间的区别
·发布于 Towards Data Science ·6 分钟阅读·2023 年 3 月 6 日
--
在 SQL 的世界里,窗口函数是一个强大的构造,允许用户以精确的方式对数据进行分段和操作。通过根据特定列和排序标准对数据进行分组,窗口函数使得在分区内进行高级计算成为可能。
在这篇全面的教程中,我们将深入探讨三种最常用的窗口函数:ROW_NUMBER()
、DENSE_RANK()
和 RANK()
。无论你是经验丰富的 SQL 专家还是刚刚入门,本指南将提供你掌握这些关键工具所需的知识和实际示例。
所有包含的代码片段均已在 MySQL 数据库上测试,这些函数在几乎任何 SQL 方言中都应运行,只需最小或无需修改。
订阅 Data Pipeline,这是一个专注于数据工程的新闻通讯
首先,让我们创建一个示例表格,在整个教程中将参考这个表格,以演示几个不同的概念。
CREATE TABLE employees (
id integer,
first_name varchar(20),
last_name varchar(20),
position varchar(20),
salary varchar(20)
)
INSERT INTO employees VALUES
(1, 'Andrew', 'Brown', 'Manager', 100000),
(2, 'Maria', 'Johnson', 'Manager', 105000),
(3, 'John', 'Anderson', 'Senior Manager', 130000),
(4, 'Alex', 'Purple', 'Associate', 50000),
(5, 'George', 'Bull', 'Senior Associate', 65000),
(6, 'Jess', 'Fridman', 'Associate', 48000),
(7, 'Marion', 'White', 'Senior Associate', 65000),
(8, 'Andreea', 'Berton', 'Manager', 102000),
(9, 'Bob', 'Johanson', 'Associate', 45000),
(10, 'Georgia', 'Hoffman', 'Senior Associate', 66000),
(11, 'Johan', 'Peterson', 'Senior Associate', 58000);
这是我们的示例表格的样子:
SELECT * FROM employees;
| id | first_name | last_name | position | salary |
| --- | ---------- | --------- | ---------------- | ------ |
| 1 | Andrew | Brown | Manager | 100000 |
| 2 | Maria | Johnson | Manager | 105000 |
| 3 | John | Anderson | Senior Manager | 130000 |
| 4 | Alex | Purple | Associate | 50000 |
| 5 | George | Bull | Senior Associate | 65000 |
| 6 | Jess | Fridman | Associate | 48000 |
| 7 | Marion | White | Senior Associate | 65000 |
| 8 | Andreea | Berton | Manager | 102000 |
| 9 | Bob | Johanson | Associate | 45000 |
| 10 | Georgia | Hoffman | Senior Associate | 66000 |
| 11 | Johan | Peterson | Senior Associate | 58000 |
什么是窗口函数?
窗口函数用于 SQL 中,对一组行执行计算,并为每一行返回一个值。注意与聚合函数的区别,后者用于返回每组的单个值。窗口函数通常用于计算累计总和、排名以及移动平均值。
窗口函数的有效语法应包括
-
**OVER**
子句在函数之后指定,用于引用窗口 -
和窗口规范,指定了行应如何分组(该规范可以包括
PARTITION BY
和/或ORDER BY
子句)
RANK()窗口函数
RANK()
函数返回指定组或分区内每行的排名。如果在分区内有多行具有相同的值,那么所有这些行将被分配相同的排名。在这种情况下,排名中会出现间隙,因为随后的行将被分配与其实际排名相对应的编号(而不是下一个可用的排名)。
现在我们考虑创建一个employees
记录的排名,使得在相同职位上的员工根据收入进行排名。这可以通过以下查询实现:
SELECT
*,
RANK() OVER (PARTITION BY position ORDER BY salary DESC) AS emp_pos_rank
FROM
employees;
| id | first_name | last_name | position | salary | emp_pos_rank |
| --- | ---------- | --------- | ---------------- | ------ | ------------ |
| 4 | Alex | Purple | Associate | 50000 | 1 |
| 6 | Jess | Fridman | Associate | 48000 | 2 |
| 9 | Bob | Johanson | Associate | 45000 | 3 |
| 2 | Maria | Johnson | Manager | 105000 | 1 |
| 8 | Andreea | Berton | Manager | 102000 | 2 |
| 1 | Andrew | Brown | Manager | 100000 | 3 |
| 10 | Georgia | Hoffman | Senior Associate | 66000 | 1 |
| 5 | George | Bull | Senior Associate | 65000 | 2 |
| 7 | Marion | White | Senior Associate | 65000 | 2 |
| 11 | Johan | Peterson | Senior Associate | 58000 | 4 |
| 3 | John | Anderson | Senior Manager | 130000 | 1 |
注意Senior Associate
职位的排名出现了间隙。两名员工被分配了第二名,这意味着紧随这两名记录的记录将被分配为4
(而不是3
)。
DENSE_RANK()窗口函数
DENSE_RANK()
函数返回指定组或分区内每行的排名。与RANK()
不同,DENSE_RANK()
不会出现间隙:
SELECT
*,
DENSE_RANK() OVER (PARTITION BY position ORDER BY salary DESC) AS emp_pos_rank
FROM
employees;
| id | first_name | last_name | position | salary | emp_pos_rank |
| --- | ---------- | --------- | ---------------- | ------ | ------------ |
| 4 | Alex | Purple | Associate | 50000 | 1 |
| 6 | Jess | Fridman | Associate | 48000 | 2 |
| 9 | Bob | Johanson | Associate | 45000 | 3 |
| 2 | Maria | Johnson | Manager | 105000 | 1 |
| 8 | Andreea | Berton | Manager | 102000 | 2 |
| 1 | Andrew | Brown | Manager | 100000 | 3 |
| 10 | Georgia | Hoffman | Senior Associate | 66000 | 1 |
| 5 | George | Bull | Senior Associate | 65000 | 2 |
| 7 | Marion | White | Senior Associate | 65000 | 2 |
| 11 | Johan | Peterson | Senior Associate | 58000 | 3 |
| 3 | John | Anderson | Senior Manager | 130000 | 1 |
注意第 4 名 Senior Associate 现在会获得 3 名排名,因为其他两名员工共享第二名排名,并且DENSE_RANK
窗口函数不会创建间隙。
ROW_NUMBER()窗口函数
最后,ROW_NUMBER
窗口函数会为每个分区内的每一行分配一个从索引1
开始的编号。
SELECT
*,
ROW_NUMBER() OVER (PARTITION BY position ORDER BY salary DESC) AS emp_pos_rank
FROM
employees;
| id | first_name | last_name | position | salary | emp_pos_rank |
| --- | ---------- | --------- | ---------------- | ------ | ------------ |
| 4 | Alex | Purple | Associate | 50000 | 1 |
| 6 | Jess | Fridman | Associate | 48000 | 2 |
| 9 | Bob | Johanson | Associate | 45000 | 3 |
| 2 | Maria | Johnson | Manager | 105000 | 1 |
| 8 | Andreea | Berton | Manager | 102000 | 2 |
| 1 | Andrew | Brown | Manager | 100000 | 3 |
| 10 | Georgia | Hoffman | Senior Associate | 66000 | 1 |
| 5 | George | Bull | Senior Associate | 65000 | 2 |
| 7 | Marion | White | Senior Associate | 65000 | 3 |
| 11 | Johan | Peterson | Senior Associate | 58000 | 4 |
| 3 | John | Anderson | Senior Manager | 130000 | 1 |
最终思考
总之,窗口函数是 SQL 的一个强大功能,允许在结果集的特定子集上执行复杂的计算。它们可以用来计算累计总数、滚动平均值和其他需要参考邻近行的指标。
在今天的文章中,我们讨论了RANK()
、DENSE_RANK()
和ROW_NUMBER()
函数之间的区别。通过理解这些窗口函数之间的差异,你可以选择最适合你特定用例的函数,并优化你的 SQL 查询性能。
👉 订阅数据管道,一份专注于数据工程的新闻通讯
👇相关的文章你可能也会喜欢 👇
对正在主导数据领域的 dbt 的温和介绍
towardsdatascience.com ## ETL 与 ELT:有什么区别?
数据工程背景下 ETL 与 ELT 的比较
towardsdatascience.com ## 什么是 SQL 中的 CTE
理解 SQL 中的公共表表达式(CTE)
towardsdatascience.com
使用 PCA 在 PySpark 中对钻石进行排名
原文:
towardsdatascience.com/ranking-diamonds-with-pca-in-pyspark-a59cab7f4f1a
在 PySpark 中运行主成分分析的挑战
·发布于 Towards Data Science ·阅读时间 8 分钟·2023 年 12 月 22 日
--
照片由 Edgar Soto 提供,来源于 Unsplash
介绍
这里是另一篇关于 PySpark 的帖子。我很享受写关于这个主题的文章,因为我感觉我们缺乏关于 PySpark 的优质博客,特别是当我们谈论 MLlib 中的机器学习时——顺便说一下,MLlib 是 Spark 的Machine Learning Library,旨在与大数据在并行环境中一起工作。
我可以说 Spark 文档非常出色。它组织得很有条理,示例也很容易跟随。但是,在 Spark 中进行机器学习并不是最友好的事情。
在这篇文章中,我使用 PCA 模型来帮助我创建钻石的排名,并且遇到了一些挑战,我们将在接下来的内容中讨论这些挑战。
我之前已经 写过关于 PCA 的文章,以及它在降维方面的作用,还有 创建排名 的方法。然而,这是我第一次使用 Spark 进行这个操作,目的是在大数据环境中复现这一技术。
让我们看看结果。
编码
让我们从导入模块开始我们的编码。
from pyspark.sql.functions import col, sum, when, mean, countDistinct
from pyspark.sql import functions as F
from pyspark.ml.feature import PCA
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.window import Window
数据集
本练习中使用的数据集是 Diamonds,来自 ggplot2 包,并且按照 Creative Commons 4.0 许可协议使用。
在这里,我从 Databricks 示例数据集中加载数据,并从一个变量中删除两个已知的异常值。PCA 受到异常值的影响。它们由于非常大的扭曲方差而倾向于主导一个组件。
# Point file path
path = '/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv'
# Load Data
df = spark.read.csv(path, header=True, inferSchema= True)
df = df.filter( col('y') < 30)
接下来,由于 PCA 是一种用于数值值的技术,我选择使用数据中的carat
、table
和depth
变量。我没有使用price
,因为它完全主导了组件的方差,使得模型失效。
此外,由于尺度不同,我还将数据转换为对数值,以便它们在相同的尺度上。
# Get only our selected columns
cols = [ 'carat', 'table', 'depth']
df_num = df.select(F.log1p('carat').alias('carat'),
F.log1p('table').alias('table'),
F.log1p('depth').alias('depth'))
这就是它的样子。
选择变量的对数值。图片来源于作者。
下一步是对我们的数据集进行向量化。
向量化
当我们在 Spark 中处理大数据时面临的挑战之一是机器学习算法要求数据以向量形式输入。这可能更直接,但不幸的是,它给我们的代码增加了几个步骤。
如果我们在向量化之前还有其他步骤可以执行,我们可以创建一个Pipeline
,可以一次运行之前编写的步骤。在我们的案例中,这一步是不必要的,因为我们只有一个步骤。
所以,接下来我们将创建一个VectorAssembler
。顾名思义,它会获取数字并将它们汇集成一个单一的向量,通过观察,将所有这些数字放在一个名为features
的列中。
assembler = VectorAssembler(inputCols= cols, outputCol="features")
prepared_df = assembler.transform(df_num)
结果将如下所示。
向量化的数据。图片来源于作者。
PCA
使用 MLlib 运行算法非常简单。我们实例化 PCA 并拟合数据。
# Run PCA it with all the possible components
pca = PCA(k= len(df_num.columns), inputCol="features")
pca.setOutputCol("pca_features")
# Fit
model = pca.fit(prepared_df)
如果我们想查看 PCA 运行后每个主成分的解释方差,我们需要 Pandas 来简化我们的工作,能够快速创建一个数据框架并从model.explainedVariance
中插入一个新列,列名为 PC 名称。
# See explained variance of the PCs
df_var = ps.DataFrame(model.explainedVariance, columns=['explained_var'])
df_var.insert(0, 'Component',
value= ['PC'+str(n) for n in range(1,len(df_num.columns)+1)])
在 Databricks 中,创建这个图表只需点击几下按钮。
PC1 解释了 96%的方差。图片来源于作者。
获取转换后的数据
鉴于 Spark 是为大数据设计的,它缺乏对主成分的更完整的输出。像特征向量、特征值和载荷这样的数字不容易找到作为 PCA 模型的属性。
我们得到的输出只有 PC 解释的方差和转换后数据的表。这些数字可以给我们以下信息:(1)方向,对于那些在 PC 相同方向上变化的数据点为正值,而当其变化方向相反时为负值;(2)值,数值越高,表示在该组件中变化越大。
这是获取转换后数据的代码片段。
# Get transformed Output
model.setOutputCol("output")
pca_transformed = model.transform(prepared_df).select('output')
以及生成的向量。
PCA 结果。图片来源于作者。
看看观察值#5 是与 PC1 方向相反的方差最大,同时也是 PC3 中方差最大的。
这些数字可以为我们提供排名观察值的基础。由于我们知道它们在每个组件中的变异程度以及每个组件解释的方差,我们可以通过将每个 PC 中的值乘以相应的解释方差来快速创建分数。
让我们看看怎么做。
清理变换后的数据
这里还有另一个挑战。这些向量并不太容易操作。所以我们必须绕过这个问题。
首先,我们只需收集变换后的值。
# Collect the transformed results
temp = spark.createDataFrame(pca_transformed.collect())
好的。但这些数据还不能直接使用。每行仍然是一个DenseVector
。我找到的处理方法是将临时数据集转换为字符串,以便将其拆分成 3 列。
# Cast data to text
temp = temp.select(col('output').cast('string'))
[OUT]
output
[-0.31666817904639416,-2.8476048469770308,5.090764334358518]
[-0.32446613567339955,-2.833791365368773,5.079981934273924]
[-0.3312692998472932,-2.777878254082261,5.094343264160385]
[-0.33271613063654204,-2.830862619169142,5.089337919087877]
由于值现在是字符串,因此必须去除首尾的[]
。使用split
函数,我们可以用,
作为分隔符拆分行。要去除方括号,只需切除 PC1 的第一个字符和 PC3 的最后一个字符。所有列都cast('float')
以再次转换为数值。
# split columns
df_transformed = (
temp
.select(
F.split('output', ',')[0][2:15].alias('PC1').cast('float'),
F.split('output', ',')[1].cast('float').alias('PC2'),
F.split('output', ',')[2][0:14].alias('PC3').cast('float') )
.fillna(0)
)
瞧。我们的数据准备好了用于排名。
清理后的数据。图像由作者提供。
计算分数
分数可以通过将每个 PC 的变换数据乘以相应的解释方差然后相加来计算。
PC1 * 解释的方差 PC1 +
PC2 * 解释的方差 PC2 +
PC3 * 解释的方差 PC3
然后,我们将创建一个带有索引的列,使用row_number
和Window
,以便将结果与原始数据连接。最后创建一个dense_rank
列,这就是我们的钻石排名。
# Explained variance array
expl_var = model.explainedVariance
df_transformed = (
df_transformed
.withColumn('_c0', F.row_number().over(Window.partitionBy().orderBy(F.lit(1))))
.withColumn('score', (col('PC1') * expl_var[0]) + (col('PC2') * expl_var[1]) + (col('PC3') * expl_var[2]) )
.withColumn('rank', F.dense_rank().over(Window.partitionBy().orderBy('score')) )
.sort('_c0')
)
让我们展示一下。
排名前 10 名。图像由作者提供。
前面的表格是基于carat
、table
和depth
变量的钻石排名。效果相当不错。我们可以看到,至少在这前 10 个结果中,即使没有考虑分类列,我们仍然看到更好的颜色和更高的清晰度排名。颜色从 D(最好)到 J(最差),清晰度的顺序是:(I1(最差)、SI2、SI1、VS2、VS1、VVS2、VVS1、IF(最好))。
例如:排名 1 的颜色比排名 2 好,但克拉更小且清晰度较差。排名 5 的颜色和清晰度都比排名 6 和 7 好,价格更好,但克拉较小。depth
、table
的确在其中起了作用。
离开前
我们已经结束了这篇文章。我在 Spark 中执行 PCA 时发现了一些挑战。这不是我工作过的最友好的方法之一。
我认为目前的功能还远不如统计工具如 R,后者能提供更优的算法实现。另一方面,Spark 并不是为了做统计分析而设计的。它最初是一个为 ETL 创建的大数据工具,后来不断发展并添加了更多功能。
此外,PCA 可以用于降维,当我们更改k
值来提取组件时。对于大数据应用,它可以作为一个辅助工具来减少数据集的维度,并作为分析中其他步骤的输入。
# Reduce to k components
pca = PCA(k = k, inputCol="features")
如果你喜欢这些内容,请关注我的博客获取更多信息。
阅读 Gustavo Santos 在 Medium 上的文章。数据科学家。我从数据中提取洞察,以帮助个人和公司……
gustavorsantos.medium.com](https://gustavorsantos.medium.com/?source=post_page-----a59cab7f4f1a--------------------------------)
你也可以在LinkedIn找到我。
GitHub 上的完整代码:
[## Studying/PySpark/PCA at master · gurezende/Studying
这是一个包含我测试和研究新包的代码库 - Studying/PySpark/PCA at master · gurezende/Studying
有兴趣了解更多关于 PySpark 的内容吗?
点击这里注册:Master Data Wrangling With PySpark。图片由作者提供。
这是我的课程链接,已应用折扣优惠券:Mater Data Wrangling With PySpark 在 Udemy 上。
参考资料
## Creating Scores and Rankings with PCA
使用 R 语言根据多个变量为观察值创建评分
towardsdatascience.com ## PCA: Beyond Dimensionality Reduction
学习如何使用 PCA 算法来找到共同变化的变量
towardsdatascience.com [## PCA - PySpark 3.5.0 documentation
pyspark.sql.SparkSession.builder.getOrCreate
伴我阅读:因果关系书籍俱乐部
原文:
towardsdatascience.com/read-with-me-a-causality-book-club-edd7085d6ae6
图片由Humberto Arellano提供,在Unsplash上
从一个猫咪的故事开始……
·发表于Towards Data Science ·阅读时间 8 分钟·2023 年 10 月 26 日
--
我有三只猫。我爱它们所有的猫,但我不得不承认它们的智力水平不同。最聪明的一只是毛毛。最近,我注意到毛毛养成了一个新习惯。每当我吃午餐时,它就坐在我旁边。我愿意相信它爱我并想和我待在一起,但它冷漠的性格暗示着相反的事情。我还注意到它在我旁边坐着时很警觉,不断寻找什么。经过一些观察,我意识到它在等待一个光线反射去追逐。你知道猫咪对激光、光线反射等是多么狂热。餐桌在正午时能接受到充足的阳光,有时我手表的反射会出现在附近的墙壁或天花板上。
作者提供的图片
毛毛已经发现了一个规律——每当我坐在桌旁时,就会有光线吸引它。我的另外两只猫还没有发现这个规律,但它们有自己的策略。每当它们听到毛毛追逐光线的声音时,它们知道光线来了,很快就会加入追逐。我每次吃午餐时都会看着它们玩,这成了从忙碌的早晨恢复过来的最放松的活动。
你可能会问,这是一篇关于因果关系的技术博客还是猫的故事?尽管研究表明,看猫玩耍或做傻事能显著减轻压力,但这并不是这篇文章的主要目的。这个故事只是我为了更易理解地提出因果关系重要性的一种方式。此外,我很高兴地介绍我的第一个‘跟我读’系列,在这个系列中,我邀请你和我一起读一本书。我希望能提供一个平台,让我们深入探讨共同兴趣并分享我们的学习。我要开始的书是《因果之书》由朱迪亚·珀尔著,这本书“彻底改变了对因果关系的理解。”
我对因果关系深感兴趣,不仅因为我在工作中广泛使用它,还因为我相信它是真正的科学。你可能经常听到‘相关性不是因果关系’。然而,真正的问题是,为什么要追逐因果关系,而相关性已经能为我们做很多事情?这就是猫的故事发挥作用的地方。
基于相关性的模型 vs. 因果结构
在毛毛的脑海中,他对光线反射感兴趣,而且信不信由你,他脑中有一个预测模型。这个模型是:每当我坐在餐桌旁,光线就会出现。这个模型可以用数学方程表示:
P(Light) = 1, if I sit at the dining table
P(Light) = 0, otherwise
因此,每当他观察到我坐在餐桌旁,他就准备去追逐一些光线。很多时候,他是对的。如果我计算他的预测准确率,他在晴天时 99%的时间是对的。1%的不准确可能来自于我没有佩戴手表的情况。其他跟随毛毛行为的猫咪用这些方程表示:
P(Light) = 1, if MaoMao is chasing it
P(Light) = 0, otherwise
它们的准确性也类似。然而,它们动作较慢,总是错过光线争夺中的先机。
这两种猫咪预测模型可能总结了所有基于相关性的预测模型。模型从历史数据中学习模式,并找到最能预测目标变量的变量。在我猫的脑海中,这两种模型比人类聪明开发的机器学习和深度学习模型要简单得多。然而,它们遵循相同的原则,只是复杂性不同。
系统有效! 很多时候都是! 只要我佩戴手表并坐在合适的位置、在合适的天气下,我的猫咪们总是在午餐时间进行日常锻炼。然而,在阴天、晚餐时间太阳下山时,或者当我不佩戴手表时,它们的预测模型准确率都会降到 0。为什么?因为它们没有通过学习因果结构来正确解决问题:
image by author
它们脑海中的模型和心理模型是:
image by author
尽管我们人类很容易判断结构是否错误,我的猫们依然坚持使用它们的预测模型。MaoMao 每当我吃饭时,无论是白天还是晚上,总是和我待在一起。这就是在做预测时不了解因果关系的问题。当一切保持不变,或者没有制度变更时,基于相关性的模型效果非常好,以至于人们怀疑搞清楚因果结构的必要性。然后,当像 COVID、地缘政治摩擦和商业调整等事情发生时,模型表现突然大幅下降,你会摸不着头脑。在这种情况下,你唯一的选择就是重新训练和重建你的模型,以适应新的制度,通常可用的训练数据有限。另一种更糟的情况是你像我的其他猫一样,始终做出滞后的模型预测。在这种情况下,你只是被动反应,而不是主动,这从来不是预测模型的目的,而且你总是迟到。
因果关系如何帮助?
如果 MaoMao 理解了太阳和我的手表之间的因果关系,导致光线出现在墙上,它就会知道在晴天应该去哪里寻找光线。此外,MaoMao 还会知道在雨天或我没有戴手表的时候等待光线是没有意义的。遵循相同的原则,如果我坐在有阳光的沙发上,光线也会出现在对面的墙上;如果我不戴手表但手上拿着手机,手机也会在墙上反射光线。此外,一旦它们搞清楚了结构,它们可以通过它们的可爱来影响我的行为,促使我坐在有阳光的地方。它们可以充分发挥因果关系的潜力,享受光线的乐趣。理解因果关系的猫难道不好吗?
当然,我们的预测模型比猫的模型复杂得多。我们需要面对众多特征来构建和构建更复杂的因果结构。然而,回报也远超玩耍。这是每个公司都想知道的百万美元问题 —— 如果我做 X,我的销售会增加吗?我的客户会减少流失吗?我的利润会更高吗?
这就是 Judea Pearl 希望通过《为什么的书》帮助我们回答的问题。在介绍中,他解释了因果关系的阶梯。
参考 Judea Pearl 的 “因果关系的阶梯”
第一个层级基于被动观察做出预测,它回答类似“如果我看到我的顾客买了牙膏,那么我看到他们买牙线的可能性有多大?”的问题。
第二层级基于干预,即超越观察并改变现状。这个层级的一个典型问题是,“如果我们将牙膏价格翻倍,牙线的销售会发生什么?”或者更直接地,“我们应该设置什么价格才能卖出最多的牙线?”
第三层级涉及反事实。它位于顶层,因为它涉及想象力,这是在认知革命期间展现出的,区别于动物的能力。这里的一个典型问题是,“如果我没有提高牙膏的价格,牙线销售会发生什么?”
类人智能还是类动物能力
图片由Phillip Glickman拍摄,来源于Unsplash
从直立行走到揭示火的力量,从皮尔逊相关系数到复杂的深度神经网络再到 LLM,人类不断进行创新。我们在构建类似人类智能的预测 AI 模型方面取得了许多进展,这些模型赋予了机器令人印象深刻的能力。然而,没有因果关系的嵌入,就不会实现智能。我们可以创建更复杂的具有多个隐藏层和数学方程的模型,但基本原理与我猫的简单模型类似。它们都是基于相关性的初级解决方案,用于复杂问题。
有了这些,我将启动因果关系书友会,邀请你和我一起阅读“因果关系之书”。我计划每周阅读一章,并每两周更新一篇关于两章内容的文章。文章可以是我阅读章节后的笔记或与内容相关的有用信息。如果时间允许,我强烈鼓励你阅读这本书。如果不行,请关注我的总结和学习。这本书共有十章,按照这个速度,我们可能在年底前读完。以下是我迄今为止写的文章:
-
附加内容:因果推断在学术界和工业界的区别是什么?
与我一起阅读 通过 订阅我的邮件列表。欢迎在下方评论中分享你的学习体会。同时,这也是我第一次尝试与我一起阅读系列,请告诉我你是否有任何建议。我也非常鼓励你开设自己的博客,记录你在这个过程中的学习或想法。正如我在最新的 YouTube 视频中提到的,写作让我获得了比预期更丰富的经历。这也可能是你尝试一下的机会。
感谢阅读。希望我的猫的故事能激励你更多地了解因果关系,并与我一起踏上这段旅程。如果你喜欢这篇文章,不要忘记:
-
查看我最近的文章,内容包括 数据讲故事中的 4D:将科学变为艺术;数据科学中的持续学习; 我如何成为数据科学家;
-
订阅 我的邮件列表;
-
或者在 YouTube 上关注我,并观看我最新的 YouTube 视频,了解我的工作日常:
-
我读过的其他书籍:
-
观看我参与的其他创作:
参考资料
《为什么的书》作者:Judea Pearl
因果关系的阶梯照片:
[1] 机器人照片由 Rock’n Roll Monkey 提供,发布在 Unsplash上;
[2] 猫的照片由 Raoul Droog 提供,发布在 Unsplash上;
[3] 干预照片由 British Library 提供,发布在 Unsplash上;
[4] 人类下棋照片由 JESHOOTS.COM 提供,发布在 Unsplash上;
实时火车乘客拥挤度预测
使用无服务器的 Azure 技术为我们的旅行规划应用提供流式预测
·
关注 发布于 Towards Data Science ·11 分钟阅读·2023 年 7 月 7 日
--
图像来源:vecstock 于 Freepik
荷兰铁路公司(Dutch Railways)的旅行者可以使用荷兰铁路公司提供的应用程序来规划他们的行程。在规划行程时,应用程序会显示相关火车的拥挤程度预测。这被分为三个类别:低占用、中等或高。旅行者可以使用这些信息来决定是否要乘坐其他可能稍微不那么拥挤的火车。
图 1:带有预测乘客拥挤程度的旅行应用(使用 1、2 或 3 人表示)。图像由作者提供。
这些预测是通过批处理过程进行的。一个机器学习模型会定期在历史数据上进行训练,每天早晨一个过程会运行以预测未来几天火车的拥挤程度。这是通过预测预期的乘客数量并将其与为该线路计划的火车的容量结合起来完成的。
然而,在白天可能会发生事故,导致火车被取消、改道,或者可能原计划是双层火车,但只提供了单层火车。因此,旅行者将看到过时的拥挤信息。大约 20% 的出发火车在旅行当天会更改容量,且往往是在出发前不久。
在这篇博客中,我们解释了如何构建一个流式处理管道,该管道接收有关计划线路的火车长度和类型的实时信息,并更新应用程序中的预期拥挤程度。我们遵循 Lambda 架构,其中我们的夜间预测实现了批处理层,而更新过程实现了流处理层。该管道目前在生产环境中运行,为使用我们应用程序的荷兰所有火车旅行者提供更实时的预期拥挤程度视图。
图 2:使用 Lambda 架构的应用程序架构简化视图。图像由 draw.io 创建。
我们描述了实现此架构的方法。我们的第一次实现是使用 Spark Structured Streaming,但结果并未如我们预期的那样。基于我们的经验(我们将讨论这一点),我们决定采用不同的方法,即使用 Azure 云中的无服务器资源。
初步尝试:Spark Structured Streaming
我们的每日拥挤度预测在 Databricks 上使用 Spark 进行数据处理。由于 Spark 支持流式数据处理,因此在 Spark Structured Streaming 中实现我们预测的实时更新似乎是一个合乎逻辑的选择。这个决定的好处是平台已经可用,我们可以使用我们已有经验的 DataFrame 方式实现逻辑。
我们从一个批量版本的模型开始实现,然后将其转换为纯 Spark Structured Streaming 实现。最终,我们得到一个小笔记本以引导流作业和一个包含所需逻辑的自定义 Python 包。
在开发过程中,我们学到了一些关于使用 Structured Streaming 编程的知识。首先,SQL DataFrames 和 Structured Streaming DataFrames 的编程接口是不一样的。Structured Streaming 在可做的事情上限制更多,这意味着我们不能将批量模型逐一实现为流模式,因此不得不多次修订算法以使其有效。Structured Streaming 接口的有限表达能力导致代码变得难以阅读,因此也难以维护。
一个简单的例子是,我们希望在两个数据流上基于时间窗口执行外连接。然而,Spark Structured Streaming 需要在连接条件中有等式,而我们没有两个相同数据的列。我们尝试向两个数据流中添加两个具有相同值的文字字段以实现等式,但 Spark 并不那么容易被欺骗。最终我们创建了一个“千年”字段,因为我们的时间戳都在第三千年:这样做有效,但实际上我们制造了一个“Y3K”错误。
此外,我们不得不将算法拆分成多个步骤,因为模型的不同部分有不同的时间限制,而这些限制无法在一个单一的结构化流作业中实现。我们选择将模型拆分成几个部分,使用 Azure Event Hubs 作为持久存储层将它们连接起来。这种方法的好处是处理的每一部分都有明确的目标,并且可以单独测试。
图 3:使用 Spark Structured Streaming 处理列车容量更新的流处理概述。图像使用 draw.io 创建。
我们用两种方式测试了我们的流程。对于单元测试,我们会简单地使用手工制作的批量 Spark SQL DataFrames 来测试流逻辑。这意味着我们可以测试流动流程的部分而无需实际启动流作业。这种方法捕获了许多功能需求,但无法捕获任何时间问题。第二步测试使用了 Spark Structured Streaming 内存接收器,以流模式运行查询来捕获一些时间效应。
最终,我们部署了代码,并且发现我们的云账单急剧增加。我们确定了两个原因:首先,Databricks 是一个很好的批处理分析作业解决方案,但它持续运行以处理流式作业的成本非常高。其次,我们雇主的信息安全政策要求我们记录数据访问。由于结构化流的状态存储可能包含数据,我们也必须记录这些数据。然而,状态存储更新非常频繁,包含许多小文件,这导致产生大量日志,捕获这些日志的成本很高。
最终,我们决定放弃这种方法。由于我们尝试解决的问题导致的云成本过高,再加上 Spark Structured Streaming 表达能力有限,使得模型实现非常难以理解和维护,我们得出结论认为我们不愿进一步投资改进这种方法,而是看看是否可以以不同的方式解决这个问题。
使用无服务器技术重新设计
注意到流程中的许多部分不需要状态,我们选择了一个使用 Azure Functions 作为计算平台的系统,以便每条消息可以单独处理。需要状态的地方我们使用 Stream Analytics。这使我们能够比较消息、重放消息或将其与另一个流连接。为了快速访问辅助数据,我们使用 Cosmos 数据库。我们仍然使用 Azure Event Hubs 将所有部分连接在一起。
图 4:使用无服务器技术的最终架构。图像使用 draw.io 创建。
Azure Functions
Azure Functions 是对事件流应用操作的简单方法。它们对流中的每个事件分别调用,使得业务逻辑容易理解。它们原生支持 Python,使编写可维护的操作变得简单。由于平台管理所有的云连接样板代码,它们可以很容易地在本地开发和测试。我们在流程的各个部分使用它们:
-
一些函数仅仅过滤传入的消息,从而减少后续步骤的计算负载,降低容量和成本;
-
一些函数通过将消息与例如 Cosmos DB 中的其他数据源进行连接来丰富消息;
-
其他功能将消息从一种格式转换为例如最终输出格式;
-
最终,我们使用 Azure Functions 将数据从批处理层摄取到流处理层。
过滤、丰富和转换
执行这些步骤的函数是直接的 Python 代码。例如,过滤函数的主要部分仅仅是几行代码:
def main(event: func.EventHubEvent, evh: func.Out[bytes]) -> None:
"""
Filter messages to only send relevant messages for our streaming flow.
"""
message = json.loads(event.get_body().decode("utf-8"))
if _is_ns_operator(message):
message = _remove_keys(message)
message = _add_build_id(message)
evh.set(str.encode(json.dumps(message)))
列表 1:示例 Azure Function 代码,用于过滤和转换消息。
在这里,我们处理每条消息,过滤出与我们公司运营的列车相关的消息,并从消息中移除我们不感兴趣的键(数据字段)。最后,我们向消息中添加构建 ID 元数据,以便我们在调试时有一些追踪信息。对于感兴趣的读者,JSON 字符串使用 str.encode() 编码为 Bytes 对象。如果将普通字符串发送到 Event Hub,它会自动被美化显示,这会在消息中引入大量空白。Bytes 对象则不会发生任何更改。
数据摄取到快速 Cosmos 数据库
为了重新计算列车拥挤度,需要快速访问预测的乘客数量、新的滚动库存的容量以及低、中和高分类的边界。这些数据是我们批处理流程的一部分,每天生成,并以 parquet 格式写入到我们的数据湖中。每次重新计算时从数据湖加载这些数据太慢。我们利用 Azure Cosmos 数据库的键值存储,以低延迟提供所需的静态数据,用于 Azure Functions 重新计算列车拥挤度。
理想的情况是,我们从每晚的批处理流程中触发数据摄取,并且能够接收到摄取是否成功或失败的反馈。摄取过程还需要能够读取包含复杂类型的 parquet 文件,这些文件在 Azure 数据工厂复制活动中不再受支持。我们的解决方案是利用 Azure Durable Functions。这是标准 Azure Functions 平台的扩展,支持有状态的长期运行函数。具体而言,Durable Functions 支持 webhooks,这使我们能够向协调器反馈摄取是否成功。
数据摄取的工作流程如下。我们的每晚批处理流程触发一个 Durable Function。该 Durable Function 选择正确的 Activity Function 以处理需要摄取的数据源,并针对每个可用的 parquet 文件触发该 Activity。然后,我们使用 pandas 读取每个文件,执行一些简单的转换,并将记录批量插入到 Cosmos 数据库中。Durable Function 会自动跟踪任何失败情况,并重新尝试该函数。
Azure Stream Analytics
有些操作我们无法通过 Azure Functions 轻松完成。这主要适用于有状态的操作,或涉及在时间窗口内合并消息的操作。
我们的每日拥挤度预测是在一个批处理流程中完成的,该流程不会立即计算预测结果。这个过程需要时间,在这段时间内可能会有新的列车容量更新。如果发生这种情况,我们希望更新拥挤度两次:首先是在最新的上一个预测上,然后是在新的预测结果变得可用时。我们在这里使用 Azure Stream Analytics 来保持更新消息的状态,并在新批次预测可用时从特定时间戳重新播放这些消息。
Azure Stream Analytics 查询是用 SQL 方言编写的。实现转换相对简单。然而,当消息吞吐量需要很高时,需要特别注意。在我们的案例中,直接实现无法跟上输入流,我们必须确保流分析查询能够以尴尬并行的方式运行。
尴尬并行查询有一些要求和限制。它们需要处理分区数据,并且需要在分区内执行有状态操作(例如连接)。这意味着在连接两个 Event Hub 流时,它们必须具有相同数量的分区,且第一个 Event Hub 的分区 1 上的数据只能与第二个 Event Hub 的分区 1 上的数据进行连接。
为了解决这个问题,我们在多个 Event Hub 分区上复制了一些数据,并基本上实现了广播连接操作。我们在下面的查询中展示了这一点。在这里,我们的每个拥挤度预测都会被分配一个批次 ID 和一个批次开始时间,用于决定哪个火车容量更新消息适用于哪个预测。如果消息在计算新一组预测期间到达,它可以适用于多个预测。在这种情况下,会输出多个消息。每个批次 ID 会在多个 Event Hub 分区上复制。
SELECT
batchid.batch_id,
batchid.batch_start_time,
event.message,
event.message_timestamp INTO [Target]
FROM
[SourceData] event TIMESTAMP BY event.message_timestamp PARTITION BY PartitionId
JOIN [BatchId] batchid TIMESTAMP BY batchid.EventEnqueuedUtcTime PARTITION BY PartitionId
ON
-- Join if the batch id message was received before the message (positive DATEDIFF) and
-- when replay when a batch id message was received after the message (negative DATEDIFF),
-- but only if the message was enqueued after the batch start time.
-- To allow fast re-ingesting data we discard messages which are no longer valid for the batch
DATEDIFF(HOUR, batchid, event) BETWEEN - 24 AND 24
AND CAST(batchid.batch_start_time AS datetime) <= CAST(event.message_timestamp AS datetime)
AND CAST(event.message.valid_until AS datetime) >= CAST(batchid.batch_start_time AS datetime)
AND event.PartitionId = batchid.PartitionId
清单 2:将相应的批次 ID 添加到每条消息中的示例 Azure Stream Analytics 查询。
端到端集成测试
从项目的初始提交开始,我们决定对流处理流执行自动化端到端集成测试。这种测试的形式是用我们生成的示例消息对入口 Event Hubs 进行种子填充,然后验证在输出 Event Hubs 上创建的消息。我们还将 Cosmos 数据库摄取包含在这个集成测试流程中。将这些测试作为我们持续部署的一部分,让我们在更改时充满信心,因为流中的组件数量增加,并且复杂性也随之增加。
图 5:用于集成测试流的 CI/CD 流程步骤概述。如果需要,我们会删除早期测试中的任何遗留数据,上传新数据并启动三个数据源的数据摄取功能。然后,我们向 Event Hub 提供事件消息,并检查它们是否正确地从另一端输出。最后,我们对 Cosmos 数据库摄取进行额外检查。图片由作者提供。
结论和关键学习
在我们努力为火车旅行者提供最新的乘客拥挤度信息,即使在火车服务发生变化的情况下,我们也采用了 lambda 架构来更新我们的预测,以应对火车容量的变化。
我们最初使用 Spark Structured Streaming 的实现未能达到预期效果,因此我们转向了使用 Azure Event Hubs、Azure Functions、Azure Stream Analytics 和 Azure Cosmos DB 的无服务器架构。
这种方法的关键好处包括:
-
作为开发者,你掌控全局:哪些部分性能不佳,哪些部分产生最高成本都很清楚;
-
与 Spark Structured Streaming 相比,Azure Functions 中的纯 Python 代码可读性强、可维护性好且表达力丰富;
-
对于无状态操作,Azure Functions 成本低廉;
-
Azure 流分析是最昂贵的部分,只能在重要的地方使用(有状态或时间窗口操作);
-
新解决方案显著降低了云基础设施成本。
关键的缺点:
-
使用解耦组件如 Azure Functions 和 Azure Cosmos DB 可能导致竞争条件,前提是设计不够完善;
-
需要管理许多基础设施和小块代码:逻辑没有集中在一个地方,需要更多的测试。
如何检测 LLM 中的幻觉
原文:
towardsdatascience.com/real-time-llm-hallucination-detection-9a68bb292698
教导聊天机器人说“我不知道”
·发表于数据科学前沿 ·10 分钟阅读·2023 年 12 月 31 日
--
伊夫琳·哈特维尔是谁?
伊夫琳·哈特维尔是一位美国作家、演讲者和生活教练…
伊夫琳·哈特维尔是一位加拿大芭蕾舞演员以及创始艺术总监…
伊夫琳·哈特维尔是一位美国演员,以其在…中的角色而闻名
不,伊夫琳·哈特维尔并不是一个拥有多个虚假身份的诈骗犯,她过着一种伪装的三重生活,拥有各种职业。实际上,她根本不存在,但模型并没有告诉我它不知道,而是开始编造事实。我们正面临 LLM 幻觉的问题。
长篇详细的输出可能看起来非常可信,即使是虚构的。这是否意味着我们不能信任聊天机器人,每次都需要手动核实输出的真实性?幸运的是,通过正确的保护措施,可能会有方法使聊天机器人不太容易编造信息。
text-davinci-003 在虚构人物上的提示完成。图片由作者提供。
对于上述输出,我设置了较高的温度 0.7。我允许 LLM 更改句子的结构,以避免每次生成相同的文本。输出之间的差异应仅为语义上的,而非事实上的。
这个简单的想法引入了一种基于样本的幻觉检测机制。如果 LLM 对相同的提示产生的输出相互矛盾,那么这些输出很可能是幻觉。如果它们相互包含,这意味着信息是事实的。[2]
对于这种类型的评估,我们只需要 LLM 的文本输出。这被称为黑箱评估。此外,由于我们不需要任何外部知识,这也被称为零资源。[5]
句子嵌入的余弦距离
让我们从一种非常基本的相似性度量方法开始。我们将计算嵌入句子对应对的成对余弦相似性。我们对其进行归一化,因为我们只需要关注向量的方向,而不是大小。下面的函数以原始生成的句子 output 和一个包含 3 个样本输出的列表 sampled_passages 作为输入。所有的完成都在文章开头的图片中找到。
为生成嵌入,我使用了轻量级模型 all-MiniLM-L6-v2。嵌入一个句子将其转换为向量表示。
output = "Evelyn Hartwell is a Canadian dancer, actor, and choreographer."
output_embeddings= model.encode(output)
array([ 6.09108340e-03, -8.73148292e-02, -5.30637987e-02, -4.41815751e-03,
1.45469820e-02, 4.20340300e-02, 1.99541822e-02, -7.29453489e-02,
…
-4.08893749e-02, -5.41420840e-02, 2.05906332e-02, 9.94611382e-02,
-2.24501686e-03, 2.29083393e-02, 7.80007839e-02, -9.53456461e-02],
dtype=float32)
我们为 LLM 的每个输出生成嵌入;然后,使用 pairwise_cos_sim 函数计算成对余弦相似性。我们将原始响应与每个新样本响应进行比较,然后计算平均值。
from sentence_transformers.util import pairwise_cos_sim
from sentence_transformers import SentenceTransformer
def get_cos_sim(output,sampled_passages):
model = SentenceTransformer('all-MiniLM-L6-v2')
sentence_embeddings = model.encode(output).reshape(1, -1)
sample1_embeddings = model.encode(sampled_passages[0]).reshape(1, -1)
sample2_embeddings = model.encode(sampled_passages[1]).reshape(1, -1)
sample3_embeddings = model.encode(sampled_passages[2]).reshape(1, -1)
cos_sim_with_sample1 = pairwise_cos_sim(
sentence_embeddings, sample1_embeddings
)
cos_sim_with_sample2 = pairwise_cos_sim(
sentence_embeddings, sample2_embeddings
)
cos_sim_with_sample3 = pairwise_cos_sim(
sentence_embeddings, sample3_embeddings
)
cos_sim_mean = (cos_sim_with_sample1 + cos_sim_with_sample2 + cos_sim_with_sample3) / 3
cos_sim_mean = cos_sim_mean.item()
return round(cos_sim_mean,2)
这是函数如何在二维笛卡尔空间中使用一对非常简单的向量的直观解释。A 和 B 是原始向量,而 Â 和 B̂ 是归一化后的向量。
成对余弦相似性计算。图片由作者提供。
从上面的图片中,我们可以看到向量之间的角度大约是 30⁰,所以它们彼此接近。余弦值大约是 0.87。余弦值越接近 1,向量之间就越接近。
cos_sim_score = get_cos_sim(output, [sample1,sample2,sample3])
我们的嵌入输出的 cos_sim_score 平均值为 0.52。
为了理解如何解读这个数字,让我们将其与一些有效输出的余弦相似性分数进行比较,这些输出涉及关于现有人物的信息。
作者提供的图片 — text-davinci-003 在尼古拉斯·凯奇上的提示完成
在这种情况下,成对余弦相似性分数是 0.93。这看起来很有前景,特别是作为一种非常快速的方法来评估输出之间的相似性。
余弦相似性计算的持续时间。图片由作者提供。
SelfCheckGPT- BERTScore
BERTScore 基于我们之前实现的成对余弦相似性思想。
[1]
用于计算上下文嵌入的默认分词器是 RobertaTokenizer。上下文嵌入不同于静态嵌入,因为它们考虑了词汇周围的上下文。例如,“bat” 这个词会根据上下文是指“飞行的哺乳动物”还是“棒球棒”而对应不同的标记值。
def get_bertscore(output, sampled_passages):
# spacy sentence tokenization
sentences = [sent.text.strip() for sent in nlp(output).sents]
selfcheck_bertscore = SelfCheckBERTScore(rescale_with_baseline=True)
sent_scores_bertscore = selfcheck_bertscore.predict(
sentences = sentences, # list of sentences
sampled_passages = sampled_passages, # list of sampled passages
)
df = pd.DataFrame({
'Sentence Number': range(1, len(sent_scores_bertscore) + 1),
'Hallucination Score': sent_scores_bertscore
})
return df
让我们深入了解一下自检 _bert 评分.predict函数。我们没有将完整的原始输出作为参数传递,而是将其拆分为单独的句子。
['Evelyn Hartwell is an American author, speaker, and life coach.',
'She is best known for her book, The Miracle of You: How to Live an Extraordinary Life, which was published in 2007.',
'She is a motivational speaker and has been featured on TV, radio, and in many magazines.',
'She has authored several books, including How to Make an Impact and The Power of Choice.']
这个步骤很重要,因为selfcheck_bertscore.predict*函数计算了每个原始响应中的句子与样本中的每个句子的 BERTScore。首先,它创建一个数组,其中行数等于原始输出中的句子数,列数等于样本的数量。
array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
用于计算候选句子和参考句子之间 BERTScore 的模型是具有 17 层的 RoBERTa large。我们的原始输出有 4 个句子,我将其称为 r1、r2、r3 和 r4。第一个样本有两个句子:c1 和 c2。我们计算原始输出中每个句子与第一个样本中每个句子的 F1 BERTScore。然后,我们进行基于基线张量 b = tensor([0.8315, 0.8315, 0.8312])的基础重标定。基线b是通过使用来自 Common Crawl 单语数据集的 100 万对随机句子计算的。他们计算了每对句子的 BERTScore 并取其平均值。这代表了一个下界,因为随机对的语义重叠很少。[1]
F1 BERTScore。图片来源:作者。
我们保留原始响应中每个句子与每个抽样样本中最相似句子的 BERTScore。逻辑是,如果某条信息出现在从相同提示生成的多个样本中,那么该信息很可能是事实。如果某个声明仅出现在一个样本中,而在相同提示的其他样本中没有出现,则更可能是虚构的。
让我们为第一个样本添加数组中的最大相似度:
bertscore_array
array([[0.43343216, 0\. , 0\. ],
[0.12838356, 0\. , 0\. ],
[0.2571277 , 0\. , 0\. ],
[0.21805632, 0\. , 0\. ]])
现在我们对另外两个样本重复这个过程:
array([[0.43343216, 0.34562832, 0.65371764],
[0.12838356, 0.28202596, 0.2576825 ],
[0.2571277 , 0.48610589, 0.2253703 ],
[0.21805632, 0.34698656, 0.28309497]])
然后我们计算每一行的平均值,从而得到原始响应中每个句子与每个后续样本的相似度评分。
array([0.47759271, 0.22269734, 0.32286796, 0.28271262])
每个句子的幻觉评分通过从 1 中减去上述每个值来获得。
伊夫琳·哈特维尔的幻觉评分。图片来源:作者。
将结果与尼古拉斯·凯奇的答案进行比较。
尼古拉斯·凯奇的幻觉评分。图片来源:作者。
看起来很合理;有效输出的幻觉评分较低,而虚构输出的幻觉评分较高。不幸的是,计算 BERTScore 的过程非常耗时,这使得它不适合实时幻觉检测。
BERTScore 计算的持续时间。图片来源:作者。
SelfCheckGPT-NLI
自然语言推断(NLI)涉及确定一个假设是否从给定的前提中逻辑推导出来,或与之矛盾。关系被分类为蕴含、矛盾或中立。对于 SelfCheck-NLI,我们使用经过 MNLI 数据集微调的 DeBERTa-v3-large 模型来执行 NLI。
NLI 流程图 [5]
以下是一些前提—假设对及其标签的示例。
上下文—假设对的示例 [4]
def get_self_check_nli(output, sampled_passages):
# spacy sentence tokenization
sentences = [sent.text.strip() for sent in nlp(output).sents]
selfcheck_nli = SelfCheckNLI(device=mps_device) # set device to 'cuda' if GPU is available
sent_scores_nli = selfcheck_nli.predict(
sentences = sentences, # list of sentences
sampled_passages = sampled_passages, # list of sampled passages
)
df = pd.DataFrame({
'Sentence Number': range(1, len(sent_scores_nli) + 1),
'Probability of Contradiction': sent_scores_nli
})
return df
在 selfcheck_nli.predict 函数中,将原始响应中的每个句子与三个样本中的每一个配对。
logits = model(**inputs).logits # neutral is already removed
probs = torch.softmax(logits, dim=-1)
prob_ = probs[0][1].item() # prob(contradiction)
相对于第一个句子和每个样本的矛盾概率。图像由作者提供。
现在我们对每个四个句子重复这个过程。
Evelyn Hartwell 的 SelfCheck-NLI。图像由作者提供。
我们可以看到模型输出了一个极高的矛盾概率。现在我们与实际输出进行比较。
Nicolas Cage 的 SelfCheck-NLI。图像由作者提供。
模型表现得非常好!不幸的是,NLI 检查时间有点长。
NLI 计算的持续时间。图像由作者提供。
SelfCheckGPT-Prompt
更新的方法已经开始使用 LLMs 自身来评估生成的文本。我们将输出和三个样本一起发送到 gpt-3.5-turbo,而不是使用公式来计算评分。模型将决定原始输出与生成的其他三个样本的相符程度。 [3]
def llm_evaluate(sentences,sampled_passages):
prompt = f"""You will be provided with a text passage \
and your task is to rate the consistency of that text to \
that of the provided context. Your answer must be only \
a number between 0.0 and 1.0 rounded to the nearest two \
decimal places where 0.0 represents no consistency and \
1.0 represents perfect consistency and similarity. \n\n \
Text passage: {sentences}. \n\n \
Context: {sampled_passages[0]} \n\n \
{sampled_passages[1]} \n\n \
{sampled_passages[2]}."""
completion = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": ""},
{"role": "user", "content": prompt}
]
)
return completion.choices[0].message.content
对于 Evelyn Hartwell,返回的自相似度评分是0。与此同时,与 Nicolas Cage 相关的输出评分是0.95。获取评分所需的时间也相当短。
LLM 自相似度评分计算的持续时间。图像由作者提供。
这似乎是我们案例的最佳解决方案,因为我们从源论文的比较分析中也预期到了这一点 [2]。SelfCheckGPTPrompt 显著优于所有其他方法,NLI 是表现第二好的方法。
幻觉检测评估结果 [6]
评估数据集是通过使用 WikiBio 数据集和 GPT-3 生成合成的维基百科文章创建的。为了避免模糊的概念,从最长的文章的前 20%中随机抽取了 238 个文章主题。GPT-3 被提示以维基百科风格为每个概念生成首段。
评估数据集创建 [5]
接下来,这些生成的段落被人工注释以评估其准确性。根据预定义的指南,每个句子被标记为主要不准确、次要不准确或准确。总共标注了 1908 个句子,其中约 40%为主要不准确,33%为次要不准确,27%为准确。
为了评估标注者的一致性,200 个句子有双重标注。如果标注者达成一致,则使用该标签;否则,选择最坏的标签。通过 Cohen’s kappa 测量的标注者间一致性在选择准确、轻微不准确和主要不准确之间为 0.595,而在将轻微/主要不准确合并为一个标签时为 0.748。
评估指标 AUC-PR 指的是精确度-召回曲线下的面积,这是用于评估分类模型的指标。
实时幻觉检测
作为最终应用,我们来构建一个 Streamlit 应用以实现实时幻觉检测。如前所述,最佳指标是 LLM 自相似性分数。我们将使用 0.5 的阈值来决定是显示生成的输出还是免责声明。
import streamlit as st
import utils
import pandas as pd
# Streamlit app layout
st.title('Anti-Hallucination Chatbot')
# Text input
user_input = st.text_input("Enter your text:")
if user_input:
prompt = user_input
output, sampled_passages = utils.get_output_and_samples(prompt)
# LLM score
self_similarity_score = utils.llm_evaluate(output,sampled_passages)
# Display the output
st.write("**LLM output:**")
if float(self_similarity_score) > 0.5:
st.write(output)
else:
st.write("I'm sorry, but I don't have the specific information required to answer your question accurately. ")
现在,我们可以可视化最终结果。
Streamlit 应用演示。图片由作者提供。
结论
结果非常令人鼓舞!聊天机器人中的幻觉检测一直是一个长期讨论的质量问题。
这里所述技术之所以令人兴奋,是因为采用了使用 LLM 评估其他 LLM 输出的新颖方法。具体做法是生成多个对同一提示的响应并比较其一致性。
仍然需要做更多的工作,但与其依赖人工评估或手工制定规则,不如让模型本身捕捉不一致性,这似乎是一个不错的方向。
. . .
如果你喜欢这篇文章,加入 文本生成 ——我们的新闻通讯每周发布两篇文章,提供有关生成 AI 和大语言模型的最新见解。
你可以在 GitHub上找到这个项目的完整代码。
你也可以在 LinkedIn上找到我。
. . .
参考文献:
-
[广覆盖挑战语料库用于
通过推断进行句子理解](https://aclanthology.org/N18-1101)
-
drive.google.com/file/d/13LUBPUm4y1nlKigZxXHn7Cl2lw5KuGbc/view
-
drive.google.com/file/d/1EzQ3MdmrF0gM-83UV2OQ6_QR1RuvhJ9h/view
现实世界的问题以及数据如何帮助我们解决这些问题
·
关注 发表在 数据科学前沿 ·发送至 通讯 ·阅读时间 3 分钟 ·2023 年 11 月 23 日
--
在不断关注新工具和前沿模型的喧嚣中,很容易忽视一个基本的真理:利用数据的真正价值在于其带来切实的积极变化。无论是在复杂的商业决策还是我们日常的生活中,数据驱动的解决方案只有在现实世界中产生实际影响时才算真正有效。
为了帮助你更有效地将点连接起来,并激励你尝试新的方法,我们收集了一些最近最受欢迎的文章,这些文章具有强烈的问题解决角度。它们涵盖了从战略到极具个人化的广泛应用场景,但都分享了对数据在我们生活中作用的务实和注重细节的观点。祝阅读愉快!
-
通过生存分析探索时间事件在生存分析的一个易于理解的介绍中,奥利维亚·塔努维贾贾涵盖了一些基本概念和技术,并展示了这种方法如何应用于从医疗领域到维护预测和客户分析等各种问题。
-
决策树如何从数据中知道下一个最佳提问是什么?一些机器学习从业者可能认为二分类任务很基础,但即使近年来出现了更复杂的技术,它们的有用性依然保持不变。为了帮助你入门,古尔金德·考尔最近分享了一个适合初学者的决策树入门指南;该指南详细解释了决策树如何在一个训练来预测给定鱼类是金枪鱼还是鲑鱼的模型中运作。
-
我的生活统计:我追踪了一年的习惯,这就是我学到的在他的 TDS 首秀中,保罗·布拉斯科·伊·罗卡展示了一个为期一年的项目,这个项目处于数据分析和所谓的量化自我之间的交汇点。保罗已经跟踪了他 332 天的日常活动,并展示了即使是看似微不足道的数据,也可以从中提取出有意义的见解。
-
客户终生价值建模方法:精华与陷阱对于行业数据科学家而言,计算客户终生价值是一个常见的目标——而且这个目标在你深入挖掘业务操作时会变得越来越复杂。Katherine Munro对 CLV 的全面实用指南提供了对此话题的急需清晰度,并概述了各种建模选项,包括它们各自的优势和局限性。
-
改进 Strava 训练日志如果你是马拉松跑者——即使你不是——你也不想错过barrysmyth的最新深度探讨,他带我们完成了下载、分析和可视化 Strava 训练日志的整个过程。这篇文章特别有帮助,因为它专注于从“这里有很多跑步数据!”到“这是如何利用数据跑得更好”。
无论你是否计划在这个周末切火鸡,我们希望你能找到一些空闲时间去探索我们作者最近探讨的其他几个迷人话题:
-
E-值是否是 p-hacking 的答案?Hennie de Harder继续剖析数据操控的棘手话题。
-
对于加速数据分析师求职过程的实用方法,可以参考Natassha Selvaraj的详细路线图。
-
检索增强生成后的下一步是什么?Gadi Singer分享了对解释性检索中心生成(RCG)模型的全面概述及其克服 RAG 缺陷的潜力。
-
如果你的团队考虑转向无服务器技术,Sheen Brisals对主要概念和注意事项的详细解析是绝对必读的。
-
了解如何在无法随机化处理时估计因果效应:Matteo Courthoud最新的解释文稿在将理论和实际元素结合方面做得非常出色。
感谢你支持我们作者的工作!如果你喜欢在 TDS 上阅读的文章,考虑 成为 Medium 会员 — 这将解锁我们整个档案(以及 Medium 上的所有其他帖子)。
直到下一个变量,
TDS 编辑
RecList 2.0:开源系统化测试 ML 模型
一个新的 RecList 以提供更多灵活性和更好的评估支持
·
关注 发布于 Towards Data Science ·7 分钟阅读·2023 年 8 月 8 日
--
介绍
评估是一个复杂的事项。管理撰写评估管道中涉及的不同组件往往很困难,你需要在某个地方拥有模型,需要加载它,然后获取测试,运行测试等等。
然后呢?嗯,你需要将结果保存到某个地方,也许还需要在线记录输出,以便你可以跟踪它们。
由于这总是一个艰难的过程,我们最近尝试提供一种更结构化的测试方法。在这篇博客文章中,我们介绍并展示了如何使用 RecList beta,我们的开源评估包;RecList 是一种通用的即插即用方法,用于扩展测试,具有易于扩展的自定义用例接口。RecList 是一个自由开放的开源项目,您可以在 GitHub 上自由获取。
RecList 允许你将代码的评估部分分离出来,并封装在一个类中,这个类会自动处理其他几个方面的事情(例如,存储和日志记录)。
RecList 提供了一种简单的方法来系统化测试,并在你训练自己的模型后保存所有需要的信息。
我们在几年前开始开发 RecList,并且 RecList 的 alpha 版本在一年前左右发布。从那时起,RecList 已经收获了超过 400 个 GitHub stars。
我们已经使用 RecList 并进行了压力测试,在 2022 年的 CIKM 上举办了 RecSys 挑战,目前正在准备 2023 年的 KDD 挑战。RecList 使我们能够系统化所有参与者的评估。我们的想法是,一旦每个人都提供相同的 RecList,比较不同的评估将变得简单。我们的经验总结出现在我们的 Nature Machine Intelligence 评论文章中。
RecList 最初是在一篇学术论文中介绍的,但我们也有一个在 Towards Data Science 出版的概述,你可以在这里阅读:
使用 RecList 对推荐系统进行行为测试
towardsdatascience.com
Chia, P. J., Tagliabue, J., Bianchi, F., He, C., & Ko, B. (2022 年 4 月)。超越 nDCG:使用 RecList 对推荐系统进行行为测试。见 Web Conference 2022 附录(第 99–104 页)。
虽然我们最初设计 RecList 是为了推荐系统的测试,但没有什么能阻止你将 RecList 用于测试其他机器学习模型。那么,为什么会有一篇新的博客文章呢?好吧,在开发了第一个版本后,我们意识到它需要一些更新。
我们学到了什么:重新思考 API
通常只有在你构建了某样东西之后,你才会意识到如何改进它。
对于那些使用了 RecList 1.0 的用户,我们对 RecList API 进行了重大更新。最初,我们对代码结构和输入/输出对有更严格的约束。
事实上,当我们实现 RecList 时,我们的目标是提供一个更通用的 API,用于评估推荐系统,并提供了几种开箱即用的功能。然而,为了做到这一点,我们不得不创建多个抽象接口,用户需要实现这些接口。
例如,原始的 RecList 1.0 要求用户将自己的模型和数据集包装到预定义的抽象类中(即 RecModel 和 RecDataset)。这使我们能够实现一组由这些抽象连接的通用行为。然而,我们很快意识到,这可能会经常使流程复杂化,并需要额外的工作,这可能让一些人不喜欢。
在 RecList 2.0 中,我们决定让这些约束成为可选的:我们使测试更加灵活。用户定义自己的评估用例,将其包装在一个便捷的装饰器中,并且他们已经实现了元数据存储和日志记录。用户可以将测试接口分享给其他人,他们可以运行相同的实验。
总结:我们意识到在构建需要其他人使用的软件时,灵活性是多么重要。
RecList 2.0 实践
现在,让我们探索一个如何使用 RecList 来编写和运行评估流程的简单用例。我们将使用非常简单的模型,这些模型随机输出数字,以减少机器学习项目中涉及的复杂性。
一个简单的用例
让我们创建一个非常简单的用例和一个非常简单的数据集。假设我们有一个整数目标序列,每个都有一个关联的类别。我们只是生成一些随机数据。
n = 10000
target = [randint(0, 1) for _ in range(n)]
metadata = {"categories": [choice(["red", "blue", "yellow"])
for _ in range(n)]}
我们的非常简单的数据集应该看起来像这样:
>>> target
[0, 1, 0, 1, 1, 0]
>>> metadata
{"categories" : ["red", "blue", "yellow", "blue", "yellow", "yellow"]}
一个简单的模型
现在假设我们有一个 DummyModel,它会随机输出整数。当然,正如我们所说,这不是一个“好”的模型,但它是一个很好的抽象,可以用来查看整个评估流程。
class DummyModel:
def __init__(self, n):
self.n = n
def predict(self):
from random import randint
return [randint(0, 1) for _ in range(self.n)]
simple_model = DummyModel(n)
# let's run some predictions
predictions = simple_model.predict()
现在,我们如何运行评估?
一个简单的 RecList
RecList 是一个 Python 类,继承了我们 RecList 抽象类的功能。RecList 实现了 RecTests,这是一些简单的抽象,允许你系统化评估。例如,这可能是一个可能的准确性测试。
@rec_test(test_type="Accuracy", display_type=CHART_TYPE.SCALAR)
def accuracy(self):
"""
Compute the accuracy
"""
from sklearn.metrics import accuracy_score
return accuracy_score(self.target, self.predictions)
我们正在采用 sklearn 准确性指标,并将其包装到另一个方法中。这与简单的准确性函数有何不同?好吧,装饰器允许我们引入一些额外的功能:例如,rectest 现在会自动将信息存储到本地文件夹中。此外,定义一种图表类型使我们能够为这些结果创建一些可视化。
如果我们想要更复杂的测试会怎样?例如,我们想要查看在不同类别中我们的准确性有多稳定(例如,计算红色物体的准确性是否高于黄色物体的准确性?)
@rec_test(test_type="SlicedAccuracy", display_type=CHART_TYPE.SCALAR)
def sliced_accuracy_deviation(self):
"""
Compute the accuracy by slice
"""
from reclist.metrics.standard_metrics import accuracy_per_slice
return accuracy_per_slice(
self.target, self.predictions, self.metadata["categories"])
现在让我们看看一个完整的 RecList 示例!
class BasicRecList(RecList):
def __init__(self, target, metadata, predictions, model_name, **kwargs):
super().__init__(model_name, **kwargs)
self.target = target
self.metadata = metadata
self.predictions = predictions
@rec_test(test_type="SlicedAccuracy", display_type=CHART_TYPE.SCALAR)
def sliced_accuracy_deviation(self):
"""
Compute the accuracy by slice
"""
from reclist.metrics.standard_metrics import accuracy_per_slice
return accuracy_per_slice(
self.target, self.predictions, self.metadata["categories"]
)
@rec_test(test_type="Accuracy", display_type=CHART_TYPE.SCALAR)
def accuracy(self):
"""
Compute the accuracy
"""
from sklearn.metrics import accuracy_score
return accuracy_score(self.target, self.predictions)
@rec_test(test_type="AccuracyByCountry", display_type=CHART_TYPE.BARS)
def accuracy_by_country(self):
"""
Compute the accuracy by country
"""
# TODO: note that is a static test,
# used to showcase the bin display
from random import randint
return {"US": randint(0, 100),
"CA": randint(0, 100),
"FR": randint(0, 100)}
几行代码即可将我们需要的内容集中在一个地方。我们可以重用这段代码用于新的模型,或添加测试并重新运行过去的模型。
只要你的指标返回了一些值,你就可以以任何你喜欢的方式实现它们。例如,这个 BasicRecList 在特定的上下文中评估特定的模型。但没有什么能阻止你生成更多模型特定的 reclists(例如,GPT-RecList)或数据集特定的 reclists(例如,IMDB-Reclist)。如果你想查看 RecList 上深度模型的示例,可以 查看这个 Colab。
运行并获取输出
让我们运行 RecList。我们需要目标数据、元数据和预测。我们还可以指定一个日志记录器和一个元数据存储。
rlist = BasicRecList(
target=target,
metadata=metadata,
predictions=predictions,
model_name="myRandomModel",
)
# run reclist
rlist(verbose=True)
这个过程的输出是什么?我们将在命令行中看到以下结果:对于每个测试,我们都有一个实际的得分。
指标也会自动绘制。例如,AccuracyByCountry 应该显示如下内容:
RecTest 生成的图示示例。
除此之外,RecList 还会保存一个 JSON 文件,其中包含我们刚刚运行的实验的所有信息:
{
"metadata": {
"model_name": "myRandomModel",
"reclist": "BasicRecList",
"tests": [
"sliced_accuracy",
"accuracy",
"accuracy_by_country"
]
},
"data": [
{
"name": "SlicedAccuracy",
"description": "Compute the accuracy by slice",
"result": 0.00107123176804103,
"display_type": "CHART_TYPE.SCALAR"
},
...
}
好的一点是,只需几行额外的代码,大部分日志记录工作就会自动处理!
使用在线日志记录器和元数据存储
默认情况下,RecList 运行器将使用以下日志记录器和元数据设置。
logger=LOGGER.LOCAL,
metadata_store= METADATA_STORE.LOCAL,
然而,没有什么阻止我们使用在线和云解决方案。例如,我们将 CometML 和 Neptune API 封装起来,以便你可以直接在评估管道中使用它们。我们还提供 S3 数据存储的支持。
例如,向 BasicRecList 添加几个参数将允许我们在 Neptune 上记录信息(我们对 Comet.ml 也提供类似支持)!
rlist = BasicRecList(
target=target,
model_name="myRandomModel",
predictions=predictions,
metadata=metadata,
logger=LOGGER.NEPTUNE,
metadata_store= METADATA_STORE.LOCAL,
NEPTUNE_KEY=os.environ["NEPTUNE_KEY"],
NEPTUNE_PROJECT_NAME=os.environ["NEPTUNE_PROJECT_NAME"],
)
# run reclist
rlist(verbose=True)
以类似的方式,添加以下内容:
bucket=os.environ["S3_BUCKET"]
将允许我们使用 S3 桶来存储元数据(当然,你还需要设置一些环境密钥)。
结论
就这些!我们创建 RecList 是为了使推荐系统的评估更加系统化和有序。我们希望这一大规模的 API 重构能帮助人们构建更可靠的评估管道!
致谢
在 2022 年 6 月至 12 月期间,我们的 beta 版开发得到了 Comet、Neptune、Gantry 的出色支持,并在 Unnati Patel 的帮助下完成。
推荐系统:基于矩阵分解的协同过滤
原文:
towardsdatascience.com/recommendation-system-with-matrix-factorization-ebc4736869e4
通过矩阵分解解释推荐
·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 4 月 26 日
--
图片来源:freestocks 于 Unsplash
Netflix 是一个受欢迎的在线流媒体平台,为其订阅者提供了各种电影、纪录片和电视节目。为了提升用户体验,Netflix 开发了一个复杂的推荐系统,根据你的观看历史、评分和偏好来推荐电影。
推荐系统使用复杂的算法分析大量数据,以预测用户最可能喜欢的内容。Netflix 的推荐系统是其成功的关键因素之一,并为流媒体行业设立了标准,全球拥有超过 2 亿用户。以下是 Netflix 如何通过个性化实现 80%观看时间的来源 link。
那么,什么是推荐系统?
推荐系统是一种无监督学习,通过信息过滤向用户推荐产品或内容,基于他们的偏好、兴趣和行为。这些系统广泛应用于电子商务和在线流媒体设置,以及其他应用程序中,以帮助发现可能感兴趣的新产品和内容。
推荐系统经过训练,以了解用户 和 产品的偏好、过去的决策和特征,使用关于用户-产品互动的数据收集。
推荐系统有两种类型,如下:
基于内容的过滤
推荐是基于用户或物品属性作为算法的输入。然后,使用共享属性空间的内容来创建用户和物品档案。
例如,蜘蛛侠:英雄无归 和 蚁人和黄蜂女:量子狂潮 有类似的属性,因为这两部电影都属于动作/冒险类型。不仅如此,它们都是漫威的一部分。因此,如果爱丽丝观看了蜘蛛侠电影,基于内容的推荐系统可能会推荐具有类似属性的电影,比如动作/漫威电影。
协同过滤
基于多个具有类似历史互动的用户。这种方法的关键思想是利用协作概念来生成新的推荐。
例如,爱丽丝和鲍勃对特定电影类型有类似兴趣。协同过滤推荐系统可能会向爱丽丝推荐鲍勃之前看过但爱丽丝没看过的项目,因为他们的偏好非常相似。鲍勃也适用相同的情况。
更多关于协同过滤的信息
推荐系统模型类型范围广泛,如下图所示,但今天本文将重点讨论协同过滤 (CF) 与矩阵分解。
推荐系统类型 - 作者插图
基于矩阵分解的协同过滤
简言之,矩阵分解是一种数学过程,将复杂的矩阵转换为低维空间。在推荐系统中,最流行的矩阵分解技术包括奇异值分解 (SVD)、非负矩阵分解 (NMF) 和概率矩阵分解。
以下是矩阵分解概念如何预测用户-电影评分的示意图
作者插图
阶段 1: 矩阵分解会随机初始化数字,并设置因子数量 (K)。在此示例中,我们将 K 设为 5。
-
用户矩阵(绿色框)表示每个用户与特征之间的关联
-
项目矩阵(橙色框)表示每个项目与特征之间的关联
在这里,例如,我们创建了 5 个特征 (k=5) 来表示 m_1 电影的特性:喜剧 为 2.10,恐怖 为 0.88,动作 为 0.04,家长指南为 0.02,适合家庭观看为 0.04。用户矩阵的情况也是类似的。用户矩阵表示用户的特征,如喜欢的演员或导演、最喜欢的电影制作等。
阶段 2: 评分预测是通过 用户矩阵 和 项目矩阵 的点积计算的。
其中 R 为真实评分,P 为用户矩阵,Q 为项目矩阵,结果 R’ 为预测评分。
作者插图
在更好的数学符号表示中,预测评分 R’ 可以在以下方程中表示:
阶段 3: 使用平方误差来计算真实评分与预测评分之间的差异
一旦这些步骤到位,我们可以使用随机梯度下降来优化我们的参数。它将计算此值的导数。
在每次迭代中,优化器将通过使用点积计算每部电影和每个用户之间的匹配,然后将其与用户给出的实际评分进行比较。然后,它将计算此值的导数,并通过将其乘以学习率⍺来更新权重。随着我们多次重复这一过程,损失将得到改善,从而提供更好的推荐。
广泛用于推荐系统的矩阵分解模型之一是奇异值分解(SVD)。SVD 本身具有广泛的应用,包括图像压缩和信号处理中的噪声减少。此外,SVD 通常用于推荐系统,因为它擅长解决大型用户-项目矩阵中固有的稀疏性问题。
本文还将提供SVD 实现的概述,使用 Surprise 包。
所以让我们开始实现吧!!
实现内容
-
数据导入
-
数据预处理
-
实现#1:从头开始用 Python 实现矩阵分解
-
实现#2:使用 Surprise 包的矩阵分解
完整的矩阵分解实现笔记本可以在这里找到。
数据导入
由于我们正在开发一个类似于 Netflix 的推荐系统,但可能没有访问其大数据的权限,我们将使用来自MovieLens的优秀数据集进行实践[1]已获许可。此外,你可以阅读并查看他们的README文件以了解使用许可和其他细节。该数据集包含数百万部电影、用户和用户过去的互动评分。
提取 zip 文件后,将会提供 4 个 csv 文件,如下所示:
数据快照 - 作者提供的图片
数据预处理
顺便提一下,协同过滤有 用户冷启动 的问题。冷启动问题指的是系统或算法无法对没有先前信息的新用户、项目或实体做出准确的预测或推荐。这可能发生在对新用户或项目几乎没有历史数据时,使得系统难以了解他们的偏好或特征。
冷启动问题是推荐系统中的一个常见挑战,系统需要为与用户交互历史有限或没有交互历史的用户提供个性化推荐。
在这个阶段,我们将选择与至少 2000 部电影互动过的用户以及被 1000 用户评分的电影(这可以是减少数据规模和空数据的一种好方法。此外,我的 RAM 无法处理大规模表格)
我的 RAM 状况 - 来源:KC Green 的 2013 年网页漫画
实际上,你也可以使用 MovieLens 提供的 100k 评分的小子集。我只是想尽可能优化我的计算机资源,减少空数据。
数据预处理后的输出 - 作者图片
按照惯例,我们将数据分为两组:训练集和测试集——通过使用 train_test_split 方法。
尽管我们需要的信息存在,但它的呈现方式并不利于人类理解。不过,我创建了一个表格,以一种更容易理解的格式呈现相同的数据。
原始数据 - 作者图片
实现 #1:从头开始用 Python 实现矩阵分解
这里是实现矩阵分解与梯度下降的 Python 代码片段。matrix_factorization
函数返回两个矩阵:nP (用户矩阵) 和 nQ (项目矩阵)。
然后,将训练数据集拟合到模型中,这里我设置了 n_factor K = 5。接下来,可以通过 使用点积方法将 nP 与 nQ 的转置相乘 来计算预测,如下代码片段所示。
结果是,矩阵分解产生了最终的预测
训练集中的新预测评分 - 作者图片
测试集上的预测
以下代码片段利用给定的 nP (用户矩阵) 和 nQ (电影矩阵) 对测试集进行预测
测试集的评分和预测评分输出 - 作者图片
评估预测性能
尽管推荐系统有多种评估指标,如 Precision@K、Recall@K、MAP@K 等,但在这个练习中,我将使用一个基本的准确度指标,即 RMSE。我可能会在后续的文章中更详细地描述其他评估指标。
结果显示,测试集上的 RMSE 为0.829,即使在超参数调整之前也相当不错。显然,我们可以调整一些参数,如学习率、n_factor、epochs 步数,以获得更好的结果。
实现 #2: 使用 Surprise 包的矩阵分解
在这一部分,我们选择了名为surprise package的 Python 库。一个surprise package是用于构建和评估推荐系统的 Python 库。它提供了一个简单易用的接口,用于加载和处理数据集,以及实现和评估不同的推荐算法。
数据导入和模型训练
Top-N 推荐生成器
对于 UserId: 231832
,以下是前 10 个电影推荐列表:
m_912, m_260, m_1198, m_110, m_60069, m_1172, m_919, m_2324, m_1204, m_3095
前 10 推荐结果 - 图片由作者提供
总结
在现代娱乐如 Netflix 中使用矩阵分解有助于理解用户偏好。这些信息随后用于向最终用户推荐最相关的项目/产品/电影。
这是我创建的矩阵分解示例的总结,以备将来需要向我的孙子们解释时使用……
图片由作者提供
参考
[1] Maxwell Harper 和 Joseph A. Konstan. 2015. The MovieLens Datasets: History and Context. ACM Transactions on Interactive Intelligent Systems (TiiS) 5, 4: 19:1–19:19. https://doi.org/10.1145/2827872
## Collaborative Filtering Deep Dive
使用 Kaggle Notebooks 探索和运行机器学习代码 | 无附加数据源
## Recommender Systems in Python 101
使用 Kaggle Notebooks 探索和运行机器学习代码 | 使用来自 CI&T 的文章分享和阅读数据……
现在的网络上信息量过大,这点可能无需多言。搜索引擎帮助我们获取…
使用 TensorFlow 推荐系统的隐式反馈推荐系统
推荐系统
当客户没有明确告诉你他们想要什么时
·发表于 Towards Data Science ·11 分钟阅读·2023 年 11 月 5 日
--
图片由Noom Peerapong提供,来源于Unsplash
提供推荐实际上并没有那么难。你只需要检查你的客户如何评分你的产品,例如使用 1 到 5 颗星,然后在其基础上训练一个回归模型,对吗?
一个典型的数据集,你会希望拥有的。图片由作者提供。
好吧,如果我们没有任何数值的用户或电影特征,我们可能需要处理嵌入,但我们已经在我之前的文章中见过如何做到这一点。
学习在 TensorFlow 中构建一个简单的矩阵分解推荐系统
towardsdatascience.com
在本文中我们也需要嵌入,因此我建议在继续之前阅读上面的文章。
隐式反馈
然而,有时候我们没有显式的用户反馈,例如星级、点赞或点踩,或类似的反馈。这在零售中很常见,我们知道哪个客户购买了哪个商品,但不知道他们是否真的喜欢它。我们从客户那里得到的唯一信息是关于他们对该产品兴趣的隐式信号。
如果他们购买了(观看、消费……)这个产品,他们就表现出兴趣。如果没有,他们可能不感兴趣,但也可能只是还不知道。我们无法确定。
这听起来像是我们可以将其视为分类问题。感兴趣 = 1,不感兴趣 = 0。然而,这里有一个小问题,我们不能确定 0(不感兴趣)是否真的是零。也可能是客户只是从未有机会购买,但实际上是想要的。
让我们回到电影,假设我们没有任何评分。我们只知道哪个用户观看了哪个电影。
比如说,Alice 观看了 Gaußzilla 和 The Markov Chainsaw Massacre。图片由作者提供。
从这里我们可以有至少两种方法继续。
-
只需将所有缺失值视为零,然后训练一个二分类器。
-
使用成对损失函数来确保用户与他们观看过的电影之间的相似性高于同一用户与他们没有观看过的电影之间的相似性。
将所有缺失值视为零
这是最简单的解决方案。从上面的不完整表格,你可以创建以下数据集:
注意: A = Alice, B = Bob, C = Charlie, G = Gauß, E = Euler, M = Markov
图片由作者提供。
你可以将观看列解释为用户是否对电影感兴趣的标签。从这个表格中,你可以推断出,例如,用户A喜欢电影G,但不喜欢电影E,这在数据上是一个大胆的声明。也许A还不知道E。或者更糟的是,它实际上在A的观看列表中,但没有时间观看。
这个方法在技术方面的问题是,模型学会对几乎任何(用户,电影)输入都返回 0,因为大多数观看值通常为零。想象一下你有一个包含 1,000,000 个用户和 100,000 部电影的数据集。平均每个用户观看多少部不同的电影?也许是 1000 部?那你1%的观看标签是 1。因此,你有一个严重不平衡的数据集,这本身并不是坏事。然而,由于我们人工创建这些零值,可能会导致性能较差。
一个计算问题是这个数据集变得非常庞大。1,000,000 个用户乘以 100,000 部电影意味着你有一个包含 100,000,000,000 行的数据集。并且通常,你的数据库中还有更多的电影和项目。在这种情况下,你不会将所有的零目标行都放入数据集中,而是进行子采样,也称为负采样。例如,如果你有 1,000,000,000 行目标为 1 的数据(= 发生的交易),你也可以子采样 1,000,000,000 个负样本(= 从未发生的交易)。这样你就有了一个可以训练的良好数据集。
这种方法有效,但通常不是最优的,因为你把问题变得比实际需要的更复杂。你不需要完美预测观看标签。你只想对每个用户对电影进行排序,即你想能够说“用户A更喜欢电影G而不是电影E”。第二种方法正好满足这个需求。
使用成对损失函数
在这种方法中,我们不会告诉模型用户是否喜欢某个特定的电影。我们更谨慎地表述:
如果用户 A 观看了电影 G,但没有观看电影 E,我们仅仅说 A 对 G 的兴趣大于对 E 的兴趣。
这使我们能够解决一个更简单的目标。现在,让我们开始一些公式,以便更好地理解这种直觉如何转化为算法。
我们将再次训练一个处理嵌入的模型。假设我们有用户A、电影G和电影E的嵌入。如果A观看了G,但没有观看E,我们仅希望得到
图片由作者提供。
其中的e是嵌入,·是点积。这意味着对于用户A来说,电影G在某种程度上是更好的,优于电影E。但这比说“A喜欢G但不喜欢E”要温和,因为后者是二元分类的情况。
训练这样的模型听起来比训练二元分类器复杂得多,但有几个库可以帮助我们。我将展示如何使用TensorFlow Recommenders,因为这是我知道的最灵活的库。另一个值得一提的是implicit,它易于使用但不够灵活。
使用 TensorFlow Recommenders 训练
我们现在将看到将之前描述的逻辑放入代码中是多么简单。为了开诚相见,我遵循了官方 TFRS 网站的指南。我只是试图让它更简洁。
准备和数据生成
首先,让我们做一个
pip install tensorflow tensorflow-recommenders tensorflow-datasets
然后我们可以通过以下方式加载一些数据
import tensorflow_datasets as tfds
import tensorflow_recommenders as tfrs
import tensorflow as tf
ratings = (
tfds.load("movielens/100k-ratings", split="train")
.map(lambda x: {
"movie_title": x["movie_title"],
"user_id": x["user_id"],
})
.shuffle(10000)
)
ratings_df = tfds.as_dataframe(ratings)
ratings
是一个 TensorFlow 数据集,处理起来总是有点繁琐。然而,对于大数据集的内存高效训练,你必须使用它。但对于我们的小示例,我尽量停留在友好的数据框世界中,因此将数据集转换为数据框ratings_df
。数据如下:
图片由作者提供。
模型定义
我们将构建一个有两个部分的模型:
-
一个用户模型
-
一个电影模型
这些模型应分别接收用户或电影,并将其转换为嵌入,即一组浮点数。
embedding_dimension = 32
user_model = tf.keras.Sequential([
tf.keras.layers.StringLookup(vocabulary=ratings_df["user_id"].unique()),
tf.keras.layers.Embedding(ratings_df["user_id"].nunique() + 1, embedding_dimension)
])
movie_model = tf.keras.Sequential([
tf.keras.layers.StringLookup(vocabulary=ratings_df["movie_title"].unique()),
tf.keras.layers.Embedding(ratings_df["movie_title"].nunique() + 1, embedding_dimension)
])
使用这两个组件,我们可以这样定义完整的模型:
class MovielensModel(tfrs.Model):
def __init__(self, user_model, movie_model, task):
super().__init__()
self.movie_model = movie_model
self.user_model = user_model
self.task = task
def compute_loss(self, features, training=False):
user_embeddings = self.user_model(features["user_id"])
positive_movie_embeddings = self.movie_model(features["movie_title"])
return self.task(user_embeddings, positive_movie_embeddings)
我们可以看到模型由user_model
和movie_model
组成。我们将这个task
属性设置为检索任务,这正是我们所需要的实现。还有另一种任务类型,叫做排名任务,当你有明确的反馈如评分时可以使用。本文不会进一步探讨这个。
你还可以看到一些损失值被计算出来。输入是一个名为features
的字典,它应该看起来像这样:
features = {
"user_id": ["A", "B", "C"],
"movie_title": ["G", "E", "M"],
}
它包含了一些用户 ID 以及一系列电影标题。在这个例子中,用户A观看了G,用户B观看了E,用户C观看了M。我们这里只有正例,即过去发生的电影会话。
用户和电影被转化为嵌入,然后计算一些损失。我稍后会详细说明,但请放心,它正在做我们希望它做的事。
模型的架构如同我在其他文章中描述的那样:
图片由作者提供。
训练模型
最后,我们可以使用一个很好的 TFRS 预测类,但为了使其正常工作,我们需要一个独特的电影列表作为 TensorFlow 数据集。
# a TensorFlow dataset
unique_movies = tf.data.Dataset.from_tensor_slices(ratings_df["movie_title"].unique())
使用这个数据集,我们可以定义之前提到的任务:
task = tfrs.tasks.Retrieval()
现在,我们可以训练模型了!
model = MovielensModel(user_model, movie_model, task)
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))
model.fit(ratings.batch(10000).cache(), epochs=5)
你最终会看到类似这样的结果:
图片由作者提供。
很难给loss
赋予具体的意义,但越小越好。我会进一步详细说明,但让我们先用我们的模型预测一些电影!
预测时间
首先,你需要定义一个叫做索引的东西。然后你可以使用这个索引来获取预测结果。
index = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
index.index_from_dataset(
tf.data.Dataset.zip((unique_movies.batch(100), unique_movies.batch(100).map(model.movie_model)))
)
注意: 你不再使用model
了。你只是用它来调整user_model
和movie_model
的参数,现在直接使用这两个子模型。此时你基本上可以抛弃 TFRS 模型model
。
现在,你可以将用户传递给index
。索引将会
-
将这个用户 ID 转换为嵌入 —— 这是因为你传递了它给
user_model
—— 然后 -
计算每部电影的嵌入 —— 这是通过执行
index_from_dataset
函数实现的 —— 然后 -
输出与用户嵌入最接近的电影标题。
注意: 它在后台进行的是精确的最近邻搜索,这可能会很慢。它还支持使用ScaNN进行近似最近邻搜索。你可以通过输入ScaNN
代替BruteForce
来使用它。
它的工作原理如下:
_, titles = index(tf.constant(["99"]))
print(f"Recommendations for user 99: {titles[0, :3]}")
# Output:
# Recommendations for user 99: [b'Sunset Park (1996)' b'Happy Gilmore (1996)' b'High School High (1996)']
很好!你现在已经准备好使用这个模型了。
收尾工作
还有一些我承诺要深入探讨的内容:这个task
和它输出的损失。内部发生的事情文档(尚未)非常好,但我查看了源代码以了解发生了什么。你可以在这里找到我提到的源代码。
我将通过一个小例子批次来解释。
进入模型的批量。图片由作者提供。
它进入模型,然后使用user_model
和movie_model
创建每个用户和电影的嵌入。这些嵌入进入检索任务对象。
在task
中,所有用户嵌入与所有电影嵌入进行相乘(点积)。这可以通过简单的矩阵乘法完成,其中电影矩阵首先转置。假设我们使用二维嵌入来节省空间。
图片由作者提供。
矩阵乘积是
图片由作者提供。
现在,推理过程如下:从数据中我们知道
-
爱丽丝观看了高斯,
-
鲍勃观看了欧拉,并且
-
查理观看了马尔可夫。
这就是为什么我们希望在单元格中的对应数字(A,G),(B,E),(C,M)——这就是主对角线——具有最高的数字。在这个小例子中,我们差得很远。
为了量化这一点,他们进行另一步:TFRS 的作者进行逐行 softmax。
经过逐行 softmax 处理。注意每行的和是 1。图片由作者提供。
现在,观察到的是:如果前一个矩阵中主对角线上的元素远高于其他数字,那么“softmax 处理过”的矩阵接近于单位矩阵。
最优的单位矩阵。图片由作者提供。
这是因为如果你对一个数组进行 softmax 处理,如果一个数字远高于其他数字,这个数字将接近 1。所以其他数字必须接近零。可以尝试一下:
x = np.array([1, 2, 10])
np.exp(x) / np.exp(x).sum() # softmax
# Output:
# array([1.23353201e-04, 3.35308764e-04, 9.99541338e-01])
所以,损失来自于将上面的矩阵与单位矩阵进行比较。具体来说,使用的是分类交叉熵损失。但是不是平均值,而是总和,见此处。这就是为什么损失值总是这么高的原因。批量越大,损失会越大。所以不要对损失突然变得非常低感到困惑,仅仅因为你把批量大小从 10,000 改成 1,000 或类似的数值。
结论
在这篇文章中,我们学习了如何利用隐式反馈数据构建推荐系统。为此,我们使用了 TensorFlow Recommenders,因为它扩展性强且非常表达清晰:你可以使用任何子模型——只要它输出一个嵌入——然后将它们组合在一起,使用 tfrs.Model
类进行联合训练。
经过训练后,你可以使用一个方便的类来进行实际的预测。如果你使用 ScaNN,这应该会非常快,但如果你需要更强大的搜索功能,你可以使用像 Qdrant 这样的专用向量数据库。你将训练模型中的用户和电影嵌入提供给它,它会为你进行搜索。
我们还查看了库的内部结构,以了解要最小化的损失来自哪里,因此这个库不再是纯粹的魔法。
如果你想了解如何评估隐式反馈推荐系统的质量,请参考我的另一篇文章:
离线评估推荐系统可能会很棘手
towardsdatascience.com
我希望你今天学到了一些新的、有趣的和有价值的东西。感谢你的阅读!
如果你有任何问题,可以在 LinkedIn上写信给我!
如果你想更深入地了解算法的世界,可以尝试我的新出版物《全面了解算法》!我还在寻找作者!
从直观的解释到深入的分析,算法通过示例、代码和精彩的内容变得生动起来…
medium.com](https://medium.com/all-about-algorithms?source=post_page-----8ba36a976c57--------------------------------)
重新创建 Andrej Karpathy 的周末项目 — 电影搜索引擎
使用 OpenAI 嵌入和向量数据库构建电影推荐系统
·
关注 发布于 Towards Data Science ·9 分钟阅读·2023 年 11 月 7 日
--
最终电影推荐演示 的风格化截图(图片由作者提供)
在 2023 年 4 月,Andrej Karpathy,OpenAI 的创始成员之一及前特斯拉人工智能总监,分享了这个有趣的周末项目,一个 电影搜索和推荐引擎:
用户界面非常简单,主要有两个关键功能。首先,你有一个搜索框,可以通过电影标题进行搜索。然后点击任何电影,你会得到一份该电影的 40 部最相似电影的推荐列表。
演示网站:awesome-movies.life/
尽管很受欢迎,Karpathy 不幸地没有公开分享该项目的源代码。
原推文下的 评论截图(截图由作者提供)
所以,拿上一些爆米花,让我们自己动手重建吧!
前提条件
该项目基于四个主要组件:
-
OpenAI 嵌入模型生成嵌入
-
使用 Weaviate 向量数据库 来存储嵌入,数据通过 Python 脚本填充
-
前端:HTML、CSS、Js
-
后端:NodeJs
因此,要跟随本教程,你需要以下内容:
-
Python 用于数据处理和填充向量数据库
-
Docker 和 Docker-Compose 用于在本地运行向量数据库。
-
Node.js 和 npm 用于在本地运行应用程序。
-
OpenAI API 密钥 用于访问 OpenAI 嵌入模型
实现一个电影搜索引擎
本节分析 Karpathy 的周末黑客活动,并旨在以一些小变化重建它。要构建一个简单的电影搜索引擎,请按照以下步骤操作:
-
准备工作:电影数据集
-
步骤 1:生成和存储嵌入
-
步骤 2:搜索电影
-
步骤 3:获取类似的电影推荐
-
步骤 4:运行演示
完整代码是开源的,你可以在 GitHub 上找到它。
准备工作:电影数据集
Karpathy 的项目索引了自 1970 年以来的 11,762 部电影,包括来自 Wikipedia 的情节和摘要。
若想在不手动抓取 Wikipedia 的情况下实现类似功能,你可以使用以下两个来自 Kaggle 的数据集:
-
48,000+ 电影数据集(许可:CC0: 公共领域)包括
'id'
、'name'
、'PosterLink'
、'Genres'
、'Actors'
、'Director'
、'Description'
、'Keywords'
和'DatePublished'
列。 -
Wikipedia 电影情节(许可:CC BY-SA 4.0),用于
'plot'
列。
两个数据集在电影标题和发行年份上合并,然后筛选出 1970 年后发行的电影。你可以在[add_data.py](https://github.com/weaviate-tutorials/awesome-moviate/blob/main/add_data.py)
文件中找到详细的预处理步骤。结果数据框包含大约 35,000 部电影,其中约 8,500 部电影有情节描述,数据框如下所示:
预处理后的电影数据框(截图由作者提供)
步骤 1:生成并存储嵌入
这个演示项目的核心是电影数据对象的嵌入,这些嵌入主要用于通过情节相似度推荐电影。在 Karpathy 的项目中,为电影摘要和情节生成了向量嵌入。生成向量嵌入有两个选项:
-
术语频率-逆文档频率(TF-IDF),这是一种简单的二元组,应该用于单个词的使用。
-
[text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
OpenAI 的嵌入模型,应该用于语义相似度。
此外,相似度是基于每部电影的维基百科摘要和情节计算的,具有两种相似度排序选择:
-
k-最近邻(kNN)使用余弦相似度
-
支持向量机
Karpathy 建议将[text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
和 kNN 结合使用,以获得一个好的快速默认设置。
最后但同样重要的是,如这条臭名昭著的回应所述,向量嵌入存储在np.array
中:
原推文下的评论的截图(截图由作者提供)
在这个项目中,我们还将使用来自 OpenAI 的[text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
嵌入模型,但将向量嵌入存储在向量数据库中。
也就是说,我们将使用Weaviate*,一个开源的向量数据库。虽然我可以争论说,向量数据库比将嵌入存储在np.array
中要快得多,因为它们使用向量索引,但坦率地说:在这个规模(数千个)下,你不会注意到速度上的任何差异。我使用向量数据库的主要原因是 Weaviate 有许多方便的内置功能可以立即使用,比如使用嵌入模型的自动向量化。
首先,如[add_data.py](https://github.com/weaviate-tutorials/awesome-moviate/blob/main/add_data.py)
文件所示,你需要设置你的 Weaviate 客户端,它连接到本地的 Weaviate 数据库实例,如下所示。此外,你将定义你的 OpenAI API 密钥,以启用集成 OpenAI 模块的使用。
# pip weaviate-client
import weaviate
import os
openai_key = os.environ.get("OPENAI_API_KEY", "")
# Setting up client
client = weaviate.Client(
url = "http://localhost:8080",
additional_headers={
"X-OpenAI-Api-Key": openai_key,
})
接下来,你将定义一个名为Movies
的数据集合来存储电影数据对象,这类似于在关系数据库中创建一个表。在这一步中,你定义text2vec-openai
模块作为向量化器,它允许在导入和查询时自动进行数据向量化,并在模块设置中定义使用[text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
嵌入模型。此外,你可以定义余弦距离作为相似性度量。
movie_class_schema = {
"class": "Movies",
"description": "A collection of movies since 1970.",
"vectorizer": "text2vec-openai",
"moduleConfig": {
"text2vec-openai": {
"vectorizeClassName": False,
"model": "ada",
"modelVersion": "002",
"type": "text"
},
},
"vectorIndexConfig": {"distance" : "cosine"},
}
接下来,你定义电影数据对象的属性以及为哪些属性生成向量嵌入。在下面的简化代码片段中,你可以看到,对于属性movie_id
和title
,没有生成向量嵌入,因为向量化模块的设置为"skip" : True
。这是因为,我们只希望为description
和plot
生成向量嵌入。
movie_class_schema["properties"] = [
{
"name": "movie_id",
"dataType": ["number"],
"description": "The id of the movie",
"moduleConfig": {
"text2vec-openai": {
"skip" : True,
"vectorizePropertyName" : False
}
}
},
{
"name": "title",
"dataType": ["text"],
"description": "The name of the movie",
"moduleConfig": {
"text2vec-openai": {
"skip" : True,
"vectorizePropertyName" : False
}
}
},
# shortened for brevity ...
{
"name": "description",
"dataType": ["text"],
"description": "overview of the movie",
},
{
"name": "Plot",
"dataType": ["text"],
"description": "Plot of the movie from Wikipedia",
},
]
# Create class
client.schema.create_class(movie_class_schema)
最后,你定义一个批处理过程来填充向量数据库:
# Configure batch process - for faster imports
client.batch.configure(batch_size=10)
# Importing the data
for i in range(len(df)):
item = df.iloc[i]
movie_object = {
'movie_id':float(item['id']),
'title': str(item['Name']).lower(),
# shortened for brevity ...
'description':str(item['Description']),
'plot': str(item['Plot']),
}
client.batch.add_data_object(movie_object, "Movies")
第二步:搜索电影
在 Karpathy 的项目中,搜索栏是一个简单的基于关键字的搜索,尝试逐字匹配你的确切查询与电影标题。当一些人表示他们希望搜索能够进行语义搜索时,Karpathy 同意这可能是项目的一个很好的扩展:
原始推文下的评论截图(作者截图)
在这个项目中,你将在[queries.js](https://github.com/weaviate-tutorials/awesome-moviate/blob/main/queries.js)
文件中启用三种类型的搜索:
每个搜索都会返回num_movies = 20
部电影,属性为['title', 'poster_link', 'genres', 'year', 'director', 'movie_id']
。
要启用基于关键字的搜索,你可以使用.withBm25()
搜索查询,涉及的属性为['title', 'director', 'genres', 'actors', 'keywords', 'description', 'plot']
。你可以通过指定'title³'
来给属性'title'
更大的权重。
async function get_keyword_results(text) {
let data = await client.graphql
.get()
.withClassName('Movies')
.withBm25({query: text,
properties: ['title³', 'director', 'genres', 'actors', 'keywords', 'description', 'plot'],
})
.withFields(['title', 'poster_link', 'genres', 'year', 'director', 'movie_id'])
.withLimit(num_movies)
.do()
.then(info => {
return info
})
.catch(err => {
console.error(err)
})
return data;
}
要启用语义搜索,你可以使用.withNearText()
搜索查询。这将自动将搜索查询向量化,并检索其在向量空间中最接近的电影。
async function get_semantic_results(text) {
let data = await client.graphql
.get()
.withClassName('Movies')
.withFields(['title', 'poster_link', 'genres', 'year', 'director', 'movie_id'])
.withNearText({concepts: [text]})
.withLimit(num_movies)
.do()
.then(info => {
return info
})
.catch(err => {
console.error(err)
});
return data;
}
要启用混合搜索,你可以使用.withHybrid()
搜索查询。alpha : 0.5
表示关键字搜索和语义搜索的权重相等。
async function get_hybrid_results(text) {
let data = await client.graphql
.get()
.withClassName('Movies')
.withFields(['title', 'poster_link', 'genres', 'year', 'director', 'movie_id'])
.withHybrid({query: text, alpha: 0.5})
.withLimit(num_movies)
.do()
.then(info => {
return info
})
.catch(err => {
console.error(err)
});
return data;
}
第三步:获取类似电影推荐
要获取类似的电影推荐,你可以执行 .withNearObject()
搜索查询,如 [queries.js](https://github.com/weaviate-tutorials/awesome-moviate/blob/main/queries.js)
文件中所示。通过传递电影的 id
,查询返回向量空间中与给定电影最接近的 num_movies = 20
部分电影。
async function get_recommended_movies(mov_id) {
let data = await client.graphql
.get()
.withClassName('Movies')
.withFields(['title', 'genres', 'year', 'poster_link', 'movie_id'])
.withNearObject({id: mov_id})
.withLimit(20)
.do()
.then(info => {
return info;
})
.catch(err => {
console.error(err)
});
return data;
}
第 4 步:运行演示
最后,将一切包装成一个具有 2000 年代标志性 GeoCities 美学的网页应用(我不打算让你厌倦前端的内容),然后就完成了!你已准备好!
要在本地运行演示,克隆 GitHub 仓库。
git clone git@github.com:weaviate-tutorials/awesome-moviate.git
导航到演示的目录并设置一个虚拟环境。
python -m venv .venv
source .venv/bin/activate
确保 在你的虚拟环境中设置环境变量 $OPENAI_API_KEY。 此外,在目录中运行以下命令,以安装虚拟环境中所需的所有依赖项。
pip install -r requirements.txt
接下来,在 docker-compose.yml
文件中设置你的 OPENAI_API_KEY,并运行以下命令通过 Docker 本地运行 Weaviate。
docker compose up -d
一旦你的 Weaviate 实例启动并运行,运行 add_data.py
文件以填充你的向量数据库。
python add_data.py
在你运行应用程序之前,安装所有必需的节点模块。
npm install
最后,运行以下命令在本地启动你的电影搜索引擎应用程序。
npm run start
现在,导航到 localhost:3000/
并开始尝试你的应用程序。
摘要
本文重新创建了 Andrej Karpathy 的有趣周末项目,即电影搜索引擎/推荐系统。下面,你可以看到完成的 实时演示 的简短视频:
演示网址 awesome-moviate.weaviate.io/
与原始项目相比,该项目使用向量数据库来存储嵌入。此外,搜索功能扩展到了语义和混合搜索。
如果你进行尝试,你会注意到它并不完美,但正如 Karpathy 所说:
“它运行得~还不错,哈哈,需要再调试一下。”
你可以在 GitHub 上找到项目的开源代码,并根据需要进行修改。一些进一步改进的建议包括尝试向量化不同的属性、调整关键词搜索与语义搜索之间的权重,或用开源替代品替换嵌入模型。
喜欢这个故事吗?
免费订阅 以便在我发布新故事时收到通知。
[## 每当 Leonie Monigatti 发布时获取电子邮件。
每当 Leonie Monigatti 发布新内容时,会收到一封电子邮件。通过注册,如果你还没有 Medium 账户,你将创建一个…
在 LinkedIn、Twitter 和 Kaggle 上找到我!
免责声明
-
这个项目并非原创想法:该项目的灵感来源于Andrej Karpathy 的周末项目,实现基于一个旧的电影搜索引擎演示项目。
从基础到高层次讲解和可视化的递归神经网络
机器翻译的应用
·发表于 Towards Data Science ·阅读时间 23 分钟·2023 年 6 月 22 日
--
递归神经网络(RNNs)是可以顺序操作的神经网络。虽然它们不像几年前那样流行,但它们在深度学习的发展中代表了一个重要的进步,并且是前馈网络的自然扩展。
在这篇文章中,我们将涵盖以下内容:
-
从前馈网络到递归网络的过渡
-
多层递归网络
-
长短期记忆网络(LSTM)
-
顺序输出(‘文本输出’)
-
双向性
-
自回归生成
-
机器翻译的应用(对 Google Translate 2016 模型架构的高层次理解)
本文的目的是不仅解释 RNN 的工作原理(已有许多文章对此进行了说明),还通过插图探索其设计选择和高层次直观逻辑。希望这篇文章不仅能为你对这一技术话题的理解提供独特的价值,还能更广泛地提升对深度学习设计灵活性的理解。
递归神经网络(1985)的设计基于两个观察,即理想模型(如人类阅读文本)如何处理顺序信息:
-
它应该跟踪“已学习”的信息,以便能够将新信息与之前看到的信息相关联。为了理解句子“the quick brown fox jumped over the lazy dog”,我需要跟踪词‘quick’和‘brown’,以便后续理解这些词适用于‘fox’。如果我没有在‘短期记忆’中保留这些信息,我将无法理解信息的顺序意义。当我在‘lazy dog’处完成句子时,我会将这个名词与之前遇到的‘quick brown fox’联系起来阅读。
-
尽管后来的信息总是会在早期信息的背景下被阅读,我们希望以类似的方式处理每一个词(标记),无论它的位置如何。我们不应该出于某种原因将第三个位置的词系统地转换成与第一个位置的词不同,即使我们可能会根据后者来阅读前者。请注意,之前提出的方法——将所有标记的嵌入并排堆叠并同时呈现给模型——不具备这一特性,因为没有保证第一个词的嵌入与第三个词的嵌入遵循相同的规则。这种通用特性也称为位置不变性。
循环神经网络的核心由循环层组成。一个循环层,像前馈层一样,是一组可学习的数学变换。事实证明,我们可以通过多层感知机来大致理解循环层。
循环层的‘短期记忆’被称为它的隐藏状态。这是一个向量——只是一组数字——它传达了网络迄今为止学到的重要信息。然后,对于标准化文本中的每个标记,我们将新信息融入隐藏状态。我们使用两个 MLP 来完成这项工作:一个 MLP 转换当前的嵌入,另一个转换当前的隐藏状态。这两个 MLP 的输出相加,形成更新后的隐藏状态,或称为‘更新后的短期记忆’。
然后我们对下一个标记重复这一过程——嵌入被传递到一个 MLP 中,而更新后的隐藏状态被传递到另一个 MLP 中;两个 MLP 的输出相加。这对序列中的每个标记重复:一个 MLP 将输入转换为可以融入短期记忆(隐藏状态)的形式,而另一个准备更新短期记忆(隐藏状态)。这满足了我们的第一个要求——我们希望在旧信息的背景下读取新信息。此外,这两个 MLP在每个时间步都是相同的。也就是说,我们使用相同的规则来合并当前隐藏状态与新信息。这满足了我们的第二个要求——我们必须对每个时间步使用相同的规则。
这两种 MLP 通常仅实现为单层深度:也就是说,它们只是一个大的逻辑回归堆栈。例如,下面的图示展示了 MLP A 的架构是如何工作的,假设每个嵌入长度为八个数字,隐藏状态也由八个数字组成。这是一种简单但有效的转换,将嵌入向量映射到适合与隐藏状态合并的向量。
当我们完成将最后一个标记融入隐藏状态后,循环层的任务也就完成了。它生成了一个向量——一个数字列表——代表了通过按顺序读取标记所累积的信息。然后,我们可以将这个向量传递通过第三个 MLP,MLP 学习“当前记忆状态”与预测任务之间的关系(在这个案例中,是股票价格是上涨还是下跌)。
更新权重的机制过于复杂,无法在本书中详细讨论,但其逻辑类似于反向传播算法。额外的复杂性在于追踪每个参数对其自身输出重复作用的累积效应(因此模型具有“递归”特性),这可以通过一种称为“时间反向传播”的修改算法在数学上解决。
递归神经网络是一种相当直观的方法来处理序列数据建模。它是线性回归模型复杂排列的另一种情况,但它非常强大:它允许我们系统地处理如语言这样的困难序列学习问题。
为了方便图示和简洁性,你通常会看到递归层被简单地表示为一个块,而不是作为一个扩展的单元,顺序处理一系列输入。
这是用于文本的最简单的递归神经网络变种:标准化的输入标记被映射到嵌入中,然后输入到循环层;循环层的输出(“最新的记忆状态”)由 MLP 处理,并映射到预测目标。
循环网络的复杂变种
循环层使得网络能够处理序列问题。然而,我们当前的递归神经网络模型存在一些问题。为了理解递归神经网络如何在实际应用中用于建模复杂问题,我们需要添加一些附加功能。
其中一个问题是深度不足:一个递归层只是简单地遍历一次文本,因此仅获得了内容的表面级、粗略的阅读。考虑哲学家伊曼纽尔·康德的句子“幸福不是理性的理想,而是想象的理想”。要真正理解这句话的深度,我们不能仅仅遍历一次文字。相反,我们需要读过这些文字,然后——这是关键的一步——我们阅读我们的思考。我们评估对句子的即时解释是否合理,并可能对其进行修改以使其更有深度。我们甚至可能阅读我们对我们思考的思考。这一切发生得非常快,通常在我们没有意识的情况下,但这是一个使我们能够从文本内容中提取多个层次深度的过程。
相应地,我们可以添加多个递归层来增加理解的深度。当第一个递归层获取文本的表面级信息时,第二个递归层则读取第一个递归层的“思考”。然后,将第二层的双重信息“最近记忆状态”作为输入传递给最终决策的 MLP。或者,我们可以添加多于两个递归层。
为了具体说明这种堆叠机制如何工作,请参阅下图:我们不仅将每个隐藏状态传递给下一个递归层进行更新,还将该输入状态提供给下一个递归层。虽然第一个递归层的第一个输入是嵌入,第一个递归层的第二个输入是“第一个递归层对第一个输入的思考”。
几乎所有用于现实世界语言建模问题的递归神经网络都使用递归层的堆栈而不是单个递归层,因为这样可以增加理解深度和语言推理。对于大堆栈的递归层,我们通常使用递归残差连接。回忆一下残差连接的概念,其中早期版本的信息被添加到后来的信息版本中。类似地,我们可以在每层的隐藏状态之间放置残差连接,使得各层可以参考不同的“思维深度”。
虽然递归模型在短小简单的句子如“美联储宣布经济衰退”中表现良好,但金融文档和新闻文章通常远比几个词要长。对于较长的序列,标准递归模型会遇到持续的长期记忆丧失问题:序列中较早的单词的信号或重要性常常被后面的单词稀释和掩盖。由于每个时间步都会对隐藏状态产生影响,它会部分破坏早期信息。因此,在序列结束时,大部分起始处的信息变得不可恢复。递归模型的关注/记忆窗口很窄。如果我们希望创建一个能够以类似人类的理解和深度来查看和分析文档的模型,我们需要解决这个记忆问题。
长短期记忆(LSTM)(1997)层是一个更复杂的递归层。其具体机制过于复杂,无法在本书中准确或完全讨论,但我们可以大致理解为试图将“长期记忆”与“短期记忆”分开。两个组件在“读取”序列时都很重要:我们需要长期记忆来追踪跨越时间的大量信息,同时也需要短期记忆来专注于特定的、本地化的信息。因此,LSTM 层不仅存储单一的隐藏状态,还使用“细胞状态”(代表“长期记忆”)。
每一步中,输入与隐藏状态以与标准递归层相同的方式结合。然而,之后会有三个步骤:
-
长期记忆清除。长期记忆是宝贵的;它保存了我们会持久保留的信息。当前的短期记忆状态用于确定长期记忆中哪些部分不再需要,并“剪切”这些部分以腾出空间给新的记忆。
-
长期记忆更新。现在长期记忆中的空间已被清除,短期记忆被用来更新(添加到)长期记忆,从而将新信息记录到长期记忆中。
-
短期记忆通知。此时,长期记忆状态已经根据当前时间步完全更新。因为我们希望长期记忆能够指导短期记忆的功能,长期记忆帮助剪切和修改短期记忆。理想情况下,长期记忆对什么是重要的、什么是不重要的有更大的监督。
因此,短期记忆和长期记忆——请记住,这两者都是数字列表——在每个时间步与彼此及输入进行交互,以一种允许仔细阅读而不会发生灾难性遗忘的方式读取输入序列。这个三步过程在下图中以图形方式描述。+
表示信息添加,而 x
表示信息移除或清理。(加法和乘法是实现这些思想的数学运算。例如,假设当前隐藏状态的值为 10。如果我将其乘以 0.1,它变成 1——因此,我已经‘减少’了隐藏状态中的信息。)
使用带有残差连接的 LSTM 堆叠,我们可以构建强大的语言理解模型,这些模型能够阅读(如果你愿意,也可以说是‘理解’)段落甚至整篇文章。除了用于金融分析以浏览大量金融和新闻报告外,这些模型还可以用于预测社交媒体帖子和消息中的潜在自杀或恐怖分子,推荐客户可能会购买的新产品(根据他们以前的产品评论),以及检测在线平台上的有毒或骚扰性评论和帖子。
这样的应用迫使我们批判性地思考其材料的哲学意义。政府对检测潜在恐怖分子有强烈的兴趣,而最近大屠杀的枪手往往有令人担忧的公共社交媒体记录——但悲剧在于,他们在海量的互联网信息中未被发现。语言模型,如你所见,纯粹是数学性的:它们试图找到最能建模输入文本和输出文本之间关系的权重和偏置。但在这些权重和偏置有意义的程度上,它们可以以有效且极其迅速的方式‘阅读’信息——比人类读者更快,甚至可能更有效。这些模型可能使政府能够在潜在恐怖分子行动之前检测、追踪和阻止他们。当然,这可能会以隐私为代价。此外,我们已经看到,尽管语言模型能够机械地追踪数据中的模式和关系,但它们实际上只是能够犯错的数学算法。如何调和模型对个体的错误标记为潜在恐怖分子的情况呢?
社交媒体平台在用户和政府的压力下,想要减少在线论坛中的骚扰和有害言论。从概念上看,这可能看似一个简单的任务:将社交媒体评论的语料库标记为有害或无害,然后训练语言模型预测特定文本样本的有害程度。然而,直接的问题是,数字话语因为依赖于迅速变化的参考(如表情包)、内部笑话、隐晦的讽刺以及必需的背景知识而极具挑战性。然而,更有趣的哲学问题是,是否可以且是否应该真正训练一个数学模型(一个‘客观’模型)来预测一个看似‘主观’的目标,如有害言论。毕竟,对一个人来说有害的东西,对另一个人可能并无害。
随着我们进入处理越来越个人化数据的模型——语言是我们沟通和吸收几乎所有知识的媒介——我们发现思考和致力于回答这些问题的重要性不断增加。如果你对这一研究方向感兴趣,你可能需要了解对齐、陪审学习、宪法 AI、RLHF 和价值多元化。
神经机器翻译
概念:多输出递归模型,双向性,注意力机制
机器翻译是一项令人惊叹的技术:它使得那些以前无法进行有效沟通的个体能够自由对话。一个讲印地语的人可以通过点击‘翻译此页面’按钮阅读用西班牙语写的网站,反之亦然。一个讲英语的人观看俄语电影时,可以启用实时翻译字幕。一个在法国的中国游客可以通过获取菜单的照片翻译来点餐。从非常字面意义上讲,机器翻译将语言和文化融合在一起。
在深度学习兴起之前,机器翻译的主流方法是基于查找表。例如,在中文中,‘I’翻译为‘我’,‘drive’翻译为‘开’,而‘car’翻译为‘车’。因此,‘I drive car’会被逐字翻译为‘我开车’。然而,任何双语者都知道这种系统的缺陷。许多拼写相同的词具有不同的含义。一种语言中的多个词可能在另一种语言中被翻译为一个词。此外,不同的语言有不同的语法结构,因此翻译后的词本身需要重新排列。英语中的冠词在像西班牙语和法语这样的有性别语言中有多种不同的上下文依赖翻译。许多试图通过巧妙的语言解决方案来调和这些问题的方法已经被提出,但其效果仅限于短小简单的句子。
深度学习则为我们提供了构建更深层次理解语言的模型的机会——甚至可能接近人类理解语言的方式——从而更有效地执行翻译这一重要任务。在这一部分,我们将介绍语言深度建模的多个额外想法,并最终深入探索 Google 翻译的工作原理。
文本输出递归模型
目前,构建可行的递归模型的最大障碍是无法输出文本。之前讨论的递归模型可以‘读取’但不能‘写入’——输出的结果是一个单一的数字(或一组数字,向量)。为了解决这个问题,我们需要赋予语言模型输出整个文本序列的能力。
幸运的是,我们不需要做太多工作。回顾之前介绍的递归层堆叠概念:我们不是仅在递归层处理完整个序列后收集‘记忆状态’,而是在每个时间步收集‘记忆状态’。因此,为了输出一个序列,我们可以在每个时间步收集一个记忆状态的输出。然后,我们将每个记忆状态传递给指定的 MLP,它预测给定记忆状态下输出词汇中的哪个词(标记为‘MLP C’)。概率最高的词被选为输出。
为了绝对清楚地了解每个记忆状态是如何转化为输出预测的,请考虑以下图示的进展。
在第一个图中,第一个输出的隐藏状态(这是在层读取了第一个词‘the’后得到的隐藏状态)被传递到 MLP C。MLP C 输出一个输出词汇的概率分布;即,它为输出词汇中的每个词提供一个概率,指示该词在该时间点作为翻译被选择的可能性。这是一个前馈网络:我们本质上是在对隐藏状态进行逻辑回归,以确定给定词的可能性。理想情况下,概率最大的词应该是‘les’,因为这是‘the’的法语翻译。
下一个隐藏状态是在递归层读取了‘the’和‘machines’之后得到的,再次传递到 MLP C。这次,概率最高的词应该是‘machine’(这是‘machines’在法语中的复数形式)。
在最后一个时间步中最可能选择的词应该是‘gagnent’,它是‘win’在特定时态中的翻译。模型应该选择‘gagnent’,而不是‘gagner’或其他不同的时态,基于它之前阅读的信息。这就是使用深度学习模型进行翻译的优势所在:能够掌握贯穿整句话的语法规则。
实际上,我们通常希望将多个递归层堆叠在一起,而不仅仅是一个递归层。这允许我们发展多层次的理解,首先是‘理解’输入文本的含义,然后以输出语言重新表达输入文本的‘意义’。
双向性
请注意,递归层是顺序进行的。当它读取文本“the machines win”时,它首先读取“the”,然后是“machines”,再是“win”。虽然最后一个词“win”是根据前面的词“the”和“machines”来读取的,但这种反向关系并不成立:第一个词“the”不是在后面的词“machines”和“win”上下文中读取的。这是一个问题,因为语言通常是为了预期我们将要说的内容而被说出的。在像法语这样的性别语言中,像“the”这样的冠词可以有许多不同的形式——“la”用于阴性物体,“le”用于阳性物体,而“les”用于复数物体。我们不知道要翻译哪个版本的“the”。当然,一旦我们阅读了句子的其余部分——“the machines”——我们就知道对象是复数的,应该使用“les”。这是文本的早期部分受到后期部分影响的一个例子。更一般地说,当我们重新阅读一个句子时——我们常常在不自觉中这样做——我们是在以句子的开始部分为背景来阅读开始部分。尽管语言是按顺序读取的,但它往往需要以‘非顺序’的方式解释(即,不严格从开始到结束的单向顺序)。
为了解决这个问题,我们可以使用双向性——这是一种对递归模型的简单修改,使得层能够‘前向’和‘后向’读取。双向递归层实际上是两个不同的递归层。一个层向前读取时间,而另一个层向后读取。两个层都读取完成后,它们在每个时间步的输出会被加在一起。
双向性使模型能够以一种方式读取文本,使得过去的内容在未来的上下文中被读取,同时也在过去的上下文中读取未来的内容(递归层的默认功能)。请注意,双向递归层在每个时间步的输出是由整个序列而不仅仅是所有之前的时间步所提供的信息。例如,在一个 10 时间步的序列中,时间步 t = 3 是由一个已经读取了序列 [t = 0] → [t = 1] → [t = 2] → [t = 3] 以及 另一个已经读取了序列 [t = 9] → [t = 8] → [t = 7] → [t = 6] → [t = 5] → [t = 4] → [t = 3] 的“记忆状态”所提供的信息。
这个简单的修改使得语言理解的深度显著丰富。
自回归生成
我们当前的翻译模型工作模型是一个大型(双向)递归层的堆栈。然而,存在一个问题:当我们将一些文本 A 翻译成其他文本 B 时,我们不仅仅是根据 A 来写 B,还要根据 B 自身 来写 B。
我们不能直接将复杂的俄语句子“Грузовик внезапно остановился потому что дорогу переходила курица”翻译成英语“ The truck suddenly stopped because a chicken was crossing the road”,因为直接从俄语逐字翻译会得到“Truck suddenly stopped because road was crossed by chicken”。在俄语中,宾语位于名词之后,但在英语中保留这种形式虽然可读但不流畅,也不是“最佳”的。因此,关键思想是:为了获得易于理解和使用的翻译,我们不仅需要确保翻译忠实于原文,还需要“忠实于自身”(自洽)。
为了实现这一点,我们需要一种不同的文本生成方法,称为 自回归生成。这使得模型能够不仅根据原始文本,还根据模型已经翻译过的内容来翻译每个单词。自回归生成不仅是神经翻译模型的主流范式,也是各种现代文本生成模型,包括高级聊天机器人和内容生成器的主流。
我们从一个“编码器”模型开始。在这种情况下,编码器模型可以表示为一个递归层的堆栈。编码器读取输入序列并生成一个输出,即编码后的表示。这一串数字表示输入文本序列的“本质”以定量形式呈现——如果你愿意,可以称其为“普遍/真实意义”。编码器的目标是将输入序列提炼成这个基本的意义包。
一旦获得了这个编码表示,我们就开始解码的任务。解码器的结构与编码器类似——我们可以将其视为另一个接受序列并生成输出的递归层堆栈。在这种情况下,解码器接受编码表示(即编码器的输出)和一个特殊的‘开始标记’(表示为)。开始标记表示句子的开头。解码器的任务是预测给定句子中的下一个词;在这种情况下,它给定了一个‘零词句子’,因此必须预测第一个词。在这种情况下,没有以前的翻译内容,因此解码器完全依赖于编码表示:它预测第一个词‘The’。
接下来是关键的自回归步骤:我们将解码器的先前输出重新输入到解码器中。现在我们有了一个‘单词句子’(开始标记后跟词‘The’)。两个标记都被传入解码器,以及编码表示——和之前一样,由编码器输出——现在解码器预测下一个词“truck”。
这个标记随后被视作另一个输入。在这里,我们可以更清楚地意识到为什么自回归生成是文本生成的一个有用算法框架:给定当前工作的句子是“The truck”,这限制了我们如何完成它。在这种情况下,下一个词很可能是动词或副词,这是我们作为语法结构“知道”的。另一方面,如果解码器仅能访问原始俄语文本,它将无法有效地限制可能性的集合。在这种情况下,解码器能够参考先前翻译的内容和原始俄语句子的含义,从而正确预测下一个词为“突然”。
这个自回归生成过程继续进行:
最后,为了结束一个句子,解码器模型预测一个指定的‘结束标记’(表示为)。在这种情况下,解码器将‘匹配’当前翻译句子与编码表示,以确定翻译是否令人满意并停止句子生成过程。
2016 Google Translate
到现在为止,我们已经覆盖了很多内容。现在,我们掌握了开发对 Google 翻译模型的较全面理解所需的大部分要素。我不需要多说这种模型的重要性:即使粗略,一个准确且易于使用的神经机器翻译系统也能打破许多语言障碍。对我们来说,这个特定的模型有助于将我们讨论的许多概念统一到一个连贯的应用中。
这些信息摘自 2016 年的Google 神经机器翻译论文,该论文介绍了 Google 的深度学习机器翻译系统。尽管在这些年里使用的模型几乎肯定已经发生了变化,但该系统仍然为神经机器翻译系统提供了有趣的案例研究。为了清楚起见,我们将该系统称为“Google 翻译”,并承认它可能不是最新的。
Google 翻译使用了编码器-解码器自回归模型。也就是说,该模型由编码器组件和解码器组件组成;解码器是自回归的(回顾前述:它在接受其他信息的同时,还接受之前生成的输出作为输入,这里是编码器的输出)。
编码器是由七层长短期记忆(LSTM)层组成的堆叠。第一层是双向的(因此技术上有 8 层,因为双向层“算作两个”),这使得它能够捕捉输入文本中双向的重要模式(底部图像,左侧)。此外,该架构在每一层之间采用了残差连接(底部图像,右侧)。回顾之前的讨论,残差连接在递归神经网络中可以通过将输入添加到递归层的输出上来实现,从而使递归层学习到对输入施加的最佳差异。
解码器也是由八层 LSTM 组成的堆叠。它以自回归的方式接收之前生成的序列,从开始标记</s>
开始。然而,Google 神经机器翻译架构使用了自回归生成和注意力机制。
注意力分数是为每个原始文本单词计算的(由编码器中的隐藏状态表示,编码器迭代地转换文本但仍保持位置表示)。我们可以把注意力看作是解码器和编码器之间的对话。解码器说:“我已经生成了[sentence],我想预测下一个翻译词。原始句子中的哪些词与下一个翻译词最相关?”编码器回答:“让我看看你在想什么,我会将其与我对原始输入中每个词的了解进行匹配……啊,你应该关注[word A],但不要过多关注[word B]和[word C],它们与预测下一个特定词的关系较小。”解码器感谢编码器:“我会考虑这些信息来确定如何生成,以便确实关注[word A]。”关于注意力的信息被发送到每个 LSTM 层,使得这种注意力信息在生成的所有层级中都可知。
这代表了谷歌神经机器翻译系统的主要部分。该模型在一个大型翻译任务数据集上进行训练:给定英文输入,预测西班牙文输出。模型学习读取的最佳方式(即编码器中的参数)、关注输入的最佳方式(即注意力计算)以及将关注的输入与西班牙文输出关联的最佳方式(即解码器中的参数)。
后续工作扩展了神经机器翻译系统的多语言能力,在这种能力下,一个模型可以用于在多个语言对之间进行翻译。这不仅从实际角度来看是必要的——为每对语言训练和存储一个模型是不可行的——而且也已显示出能改善任何两对语言之间的翻译。此外,GNMT 论文提供了训练的细节——这是一个受硬件限制的非常深的架构——以及实际部署——大型模型不仅训练慢,预测也慢,但谷歌翻译用户不希望等待超过几秒钟来翻译文本。
虽然 GNMT 系统确实是计算语言理解的一个里程碑,但仅仅几年后,一种在某些方面极简化的新方法将彻底改变语言建模——并完全摆脱曾经常见的递归层,我们曾如此费力地去理解。请关注关于 Transformers 的第二篇文章!
感谢阅读!
在这篇文章中,我们对递归神经网络进行了深入的调查:它们的设计逻辑、更多复杂特性和应用。
一些关键点:
-
RNN(循环神经网络)是前馈网络的自然扩展,可以用于机器翻译。RNN 旨在跟踪迄今为止学习到的信息,并将新信息与先前见过的信息相关联,类似于人类处理顺序信息的方式。
-
RNN 使用递归层,其中隐藏状态代表短期记忆。
-
可以堆叠递归层,以增加网络对理解和推理的深度。
-
长短期记忆(LSTM)网络是更复杂的递归层类型,它将长期记忆与短期记忆分开。LSTM 具有清除、更新和通知长期记忆和短期记忆的机制。
-
RNN 在金融分析、社交媒体分析、推荐系统和语言翻译等多个领域都有应用。
-
RNN 在现实世界应用中的使用引发了关于隐私、模型错误以及在标记有毒或有害内容时的主观性等哲学和伦理问题。
-
神经机器翻译是 RNN 的一个强大应用,它使不同语言之间的翻译成为可能,促进了沟通和文化交流。
所有照片由作者提供。
递归化学反应
使用 RDKit 对化学结构进行算法分析
·
关注 发表在Towards Data Science ·8 分钟阅读·2023 年 3 月 7 日
--
图片来源于Manuel Darío Fuentes Hernández来自Pixabay
RDKit 是一个用 C++ 编写的开源化学信息学工具包,也可以在 Java、Python 和 KNIME 中使用。它提供了广泛的化学信息学功能,例如读取和写入分子、处理原子、键和环、生成 2D 或 3D 坐标、搜索子结构、应用化学转换以及计算指纹和描述符。RDKit 还为 PostgreSQL 提供了高性能的 cartridge 数据库。RDKit 理解 SMARTS 语言(用于描述分子模式),以及 SMIRKS 语言(用于应用反应转换),这两者都以 SMILES 为基础,SMILES 是用于输入和表示分子及反应的著名线性符号。
RDKit 可以应用反应转换,并结合 Python 的递归功能,支持如本文所述的特殊用例。特别是,我们将递归地应用化学反应,以检查输入的分子结构是否为肽,即氨基酸的线性序列。我发现递归应用反应是算法分析化学结构的一个好方法,例如,通过反复分割结构、去除定义明确的片段并分类剩余基团,当不再能应用更多的转换时。
· 介绍
· 构建模块:氨基酸
· 识别肽键
· 打破肽键
· 结论
介绍
下图是一个由四个氨基酸组成的线性寡肽,即精氨酸、丙氨酸、苏氨酸和蛋氨酸,最近在我分析的数据集中发现。我从事工业化学工作,看到我们的数据库中有这样的结构感到惊讶,因为我们与生物分子无关,肽也不常见。进一步挖掘发现,我们通过将化学名称解释为化学名称到结构的算法错误地生成了这些结构。除了全名,氨基酸还可以用一个或三个字符表示,这些字符有时出现在文本中,并被错误地解释为化学结构。这种数据质量问题可能会影响公开可用的数据集,因此我认为找到检测这些情况并在其意外出现时加以消除的方法可能会很有用。
一小段含有四个氨基酸的肽。图像来源:作者。
解决这个问题有很多方法,我认为所有线性肽序列都可以通过复杂的 SMARTS 查询来匹配。开发这个方法对我来说似乎很具有挑战性,因此我认为通过将问题拆分成更小的问题来解决可能更为可行。在化学术语中,这意味着我们需要通过逐一水解肽键,将初始分子分解成更小的片段,直到没有更多的键可以断裂。如果所有获得的片段都是氨基酸,那么起始结构必须是肽。这样做还可以确定精确的氨基酸序列。我们需要的是一种递归应用化学反应的方法。如果你有兴趣了解如何使用 RDKit 实现这一点,并在过程中发现这个丰富的化学信息学库的一些功能,请继续阅读。
本文的代码可以在我的 GitHub 博客 仓库 中找到。我们将使用 RDKit 版本 2022.9.4 和 Python 3.9.13。该仓库还包括了提供的要求 文件 中的所有依赖项。
基础单元:氨基酸
组成蛋白质的氨基酸共有 20 种。此外,还有两种额外的氨基酸,这些氨基酸在某些物种中由通常解释为终止密码子的密码子编码。所有这些都是α-氨基酸,即氨基直接连接到α-碳上,即连接到羧基的碳原子。为了方便起见,我为这 22 种氨基酸创建了 SMILES,这些 SMILES 可以在 仓库 中找到。
第一步是读取 22 种氨基酸,并同时导入所有必要的 RDKit 模块。
在倒数第二行中,微笑表情被转换为本地的 RDKit Mol 对象。最后一行可能看起来有些神秘,但它的作用只是将氨基酸名称包含在分子中,具体可以通过
print(Chem.MolToMolBlock(amino_acids['mol'].iloc[0]))
打印出来
L-丙氨酸 Mol 块,其中包含了化学名称。图片由作者提供。
氨基酸可以使用 matplotlib 在定制的网格中进行可视化,如下所示。
最后一行将下方的图像保存为 PNG 文件,你可以在仓库中找到这个文件以及本文中所有其他 图像。
出现在遗传密码中的 22 种氨基酸。图片来源于作者。
RDKit 可以无缝地在 jupyter notebooks 中使用,其中 Draw 模块允许通过Draw.MolsToGridImage()
函数轻松可视化分子结构。但我发现使用 matplotlib 可以提供更多灵活性,特别是如果遵循这些出色的 建议 来调整图表组件。所有氨基酸共享以下对映体骨架。
α氨基酸骨架。图片来源于作者。
在这一阶段,我们可以引入一个有用的 RDKit 功能,即所谓的 R 基团分解。在下面的代码中,我们定义了氨基酸骨架核心,其 smiles 为 [*:1]C@HC(O)=O,具有两个明确的 R 基团标签。使用两个 R 基团标签的原因是 L-脯氨酸中的呋喃环。通过明确设置 R 基团分解,我们将其约束为仅匹配明确指定的 R 基团。
其余的代码创建了必要的输入数组,包含分子和图例,用于生成图像,使用的是与之前相同的实用函数。如果你仔细观察,你会注意到甘氨酸未能被分解为 R 基团。原因是它不是手性分子,而用于分解的核心结构是手性的。如果我们从核心中移除手性中心,甘氨酸的分解会成功,但 R 基团分解会失去一些特异性,这可能是不希望的。
氨基酸 R 基团分解。图片来源于作者。
R 基团分解在氨基酸骨架需要进一步处理的情况下很有用。本文不再对此深入探讨。
识别肽键
在我们尝试断裂肽键之前,我们可以检查是否能够识别这些肽键。我们使用一个由苏氨酸、蛋氨酸和精氨酸形成的线性三肽作为例子。肽键被定义为一种结构模式,用来定位原始分子中的匹配原子和键。
这产生了下面的图像。
苏氨酸、精氨酸、蛋氨酸三肽,突出显示了两个肽键。图片来源于作者。
使用 rdkit.Chem.Draw.rdMolDraw2D
,肽键被很好地定向。子结构搜索返回了一个原子索引的元组,其中的元组被展平,并用于在展平列表中找到所有原子对之间的键索引。在分子结构被保存为 PNG 图像之前,使用浅灰色突出显示了肽链中原子和键的索引。这里的关键消息是,通过使用 RDKit,可以控制到原子和键的级别,这些原子和键实质上形成了适合任何想象中的算法的图形。
打破肽键
打破肽键需要使用 SMIRKS 定义一个反应。之前定义的肽模式用于子结构匹配,成为 SMIRKS 反应的反应物部分。产品只是水解产物,其中两个氨基酸使用点符号分隔。苏氨酸、精氨酸、蛋氨酸三肽有两个肽键,因此可以应用两次反应,导致两个反应结果,这两个结果都包括一个氨基酸和一个二肽。
上述代码生成了一个 PNG 图像,其中包含两行的两个反应可能性。
苏氨酸、精氨酸、蛋氨酸三肽的水解。图片由作者提供。
如果我们想将一个肽完全分解成其组成的氨基酸,我们需要反复应用反应,直到不能再打破更多肽键。在每个阶段不必列举所有可能的反应。我们可以简单地应用一次反应,然后分别对两个反应物进行水解。RDKit 允许通过将 maxProducts
参数 设置为 1 来控制产品数量。什么使得起始结构成为肽?如果在应用所有可能的肽水解反应后我们只产生了一种或多种已知的氨基酸,那么起始结构就是一个肽。相反,如果在某个时点无法应用任何肽水解反应,并且结构不是已知氨基酸之一,那么起始结构就不是肽。
上述代码使用了两个实用函数,一个用于检查两个结构是否通过彼此为子结构的方式在化学上等效,另一个用于检查结构是否可以在结构列表中找到。递归肽键水解在第三个也是最后一个函数中实现。
使用这个递归函数,我们检查一组九个示例分子是否是肽。
该算法正确地将前 8 个结构分类为肽,将最后两个分类为非肽。值得注意的是,我们使用了氨基酸本身就是一个肽的约定,这严格来说可能并不准确,但这对于应用递归来说是方便的。
通过递归水解检查结构是否为肽。图像由作者提供。
通过将反应物结果添加到图中,可以增强算法,例如使用NetworkX并通过在每个节点上绘制结构来可视化反应进程。叶子节点将是可以进一步分析的氨基酸,以获取肽中氨基酸的确切序列。可能性无穷无尽;RDKit 已经完成了它的部分工作,然后可以依赖 Python 的表现力完成其余部分。
结论
RDKit 是一个丰富的化学信息学库。现在可以通过 pip 轻松部署,并打开了在化学应用中使用 Python 及其数据分析和数据科学生态系统的可能性。RDKit 的文档虽说并非最好,但现在有很多教程和博客可以参考。该库不断发展,新增了许多功能。希望这篇文章对展示 RDKit 的一些功能和潜力有所帮助。
通过大型语言模型重新定义对话式人工智能
实现对话式人工智能以提供统一用户体验的指南
·
关注 发表在 Towards Data Science · 21 分钟阅读 · 2023 年 9 月 28 日
--
来源: rawpixel.com
对话式人工智能是大型语言模型(LLMs)的一种应用,因其在多个行业和用例中的可扩展性而引发了大量关注。尽管对话式系统已经存在了几十年,但 LLMs 带来了大规模应用所需的质量提升。在本文中,我们将使用图 1 所示的心理模型来剖析对话式人工智能应用(参见 用整体心理模型构建 AI 产品以了解心理模型)。在考虑了对话式人工智能系统的市场机会和商业价值之后,我们将解释需要设置的额外“机制”,包括数据、LLM 微调和对话设计,以使对话不仅成为可能,而且有用和愉悦。
图 1:AI 系统的心理模型(参见 用整体心理模型构建 AI 产品)
1. 机会、价值和局限
传统的用户体验(UX)设计围绕着大量的人工 UX 元素、滑动、点击和触碰展开,这要求每个新应用都有一个学习曲线。通过使用对话式人工智能,我们可以摆脱这些繁琐的操作,代之以自然流畅对话的优雅体验,这样我们可以忘记不同应用、窗口和设备之间的过渡。我们使用语言——我们通用且熟悉的沟通协议——与不同的虚拟助手(VAs)互动并完成任务。
对话式用户界面(UI)并不是全新的热门事物。交互式语音响应系统(IVRs)和聊天机器人自 1990 年代以来就已经存在,而自然语言处理(NLP)的重大进展也一直伴随着语音和聊天界面的希望和发展浪潮。然而,在大型语言模型(LLMs)出现之前,大多数系统都是以符号范式实现的,依赖于规则、关键字和对话模式。它们还局限于特定的、预定义的“能力”领域,用户如果超出这些领域很快就会陷入困境。总的来说,这些系统充满了潜在的失败点,在经历了几次令人沮丧的尝试后,许多用户再也没有回到这些系统中。下图展示了一个对话示例。一个希望为特定音乐会订票的用户耐心地经过了一系列详细的询问流程,结果在最后发现音乐会已经售罄。
图 2:糟糕对话流程的示例
作为一种赋能技术,LLMs 可以将对话界面提升到新的质量和用户满意度水平。对话系统现在能够展示更广泛的世界知识、语言能力和对话能力。利用预训练模型,它们也可以在更短的时间内开发完成,因为编写规则、关键词和对话流程的繁琐工作现在被 LLM 的统计知识所取代。让我们来看看两个对话式 AI 可以大规模提供价值的突出的应用场景:
-
客户支持,更一般来说,是那些由大量用户使用的应用,这些用户经常提出类似的请求。在这里,提供客户支持的公司在信息上相对于用户具有明显的优势,可以利用这一点创造出更直观和愉悦的用户体验。以重新预订航班为例。对于我这样一个经常乘坐飞机的人来说,这是一年中会发生 1 到 2 次的事情。在这段时间内,我往往会忘记过程的细节,更不用说特定航空公司的用户界面了。相比之下,航空公司的客户支持将重新预订请求置于其操作的核心。与其通过复杂的图形界面展示重新预订过程,不如将其逻辑“隐藏”在联系支持的客户面前,他们可以使用自然语言作为通道来完成重新预订。当然,仍然会有一些较少见的请求。例如,想象一下一个突发的情绪波动使得一位商务客户决定将她心爱的狗作为超额行李添加到已预订的航班中。这些更具个性化的请求可以交给人工客服处理,或通过连接到虚拟助手的内部知识管理系统来解决。
-
知识管理依赖于大量的数据。对于许多现代公司而言,他们在运营、迭代和学习过程中积累的内部知识是一项核心资产和差异化因素——前提是这些知识以高效的方式存储、管理和访问。虽然公司拥有隐藏在协作工具、内部维基、知识库等中的大量数据,但他们常常未能将其转化为可操作的知识。随着员工离职、新员工加入,以及你永远无法完成三个月前开始的文档页面,有价值的知识会受到熵的影响。找到内部数据迷宫中的出路并获取特定业务情境所需的信息变得越来越困难。这导致了知识工作者的巨大效率损失。为了解决这个问题,我们可以通过在内部数据源上增强 LLM 的语义搜索功能来应对。LLM 允许使用自然语言问题而不是复杂的正式查询来对数据库进行提问。这样,用户可以专注于他们的信息需求,而不是知识库的结构或查询语言(如 SQL)的语法。由于这些系统是基于文本的,它们在丰富的语义空间中处理数据,在“幕后”进行有意义的连接。
除了这些主要的应用领域,还有许多其他应用,例如远程医疗、心理健康助手和教育聊天机器人,它们可以以更快、更高效的方式优化用户体验并为用户带来价值。
2. 数据
大型语言模型(LLMs)最初并不是为了进行流畅的小谈话或更深入的对话而训练的。相反,它们在每一步推理中学习生成下一个标记,最终形成连贯的文本。这种低级别的目标与人类对话的挑战不同。对人类来说,对话非常直观,但当你想教会机器做到这一点时,它就变得极其复杂和微妙。例如,让我们来看一下意图的基本概念。当我们使用语言时,我们是为了特定的目的,这就是我们的沟通意图——可能是传达信息、社交或要求别人做某事。前两种目的对 LLM 来说相对直接(只要它在数据中见过所需的信息),而后者则更具挑战性。LLM 不仅需要以连贯的方式组合和组织相关信息,还需要在正式性、创造力、幽默感等软性标准方面设定正确的情感基调。这是对话设计(参见第五部分)的挑战,与创建微调数据的任务紧密相关。
从传统语言生成转向识别和响应特定的交流意图是提升对话系统可用性和接受度的重要步骤。与所有微调工作一样,这从编制适当的数据集开始。
微调数据应尽可能接近(未来的)实际数据分布。首先,它应该是对话(对话)数据。其次,如果你的虚拟助手将专注于特定领域,你应该尝试组装反映必要领域知识的微调数据。第三,如果你的应用程序中有典型的流程和请求会频繁出现,例如客户支持中的情况,请尽量将这些的多样化示例纳入你的训练数据中。下表显示了来自3K Conversations Dataset for ChatBot的对话微调数据示例,该数据集在 Kaggle 上免费提供:
表 1:来自3K Conversations Dataset for ChatBot的对话微调数据示例
手动创建对话数据可能是一项昂贵的工作——众包和使用 LLMs 来帮助生成数据是扩展的两种方式。一旦对话数据收集完成,就需要对对话进行评估和注释。这使你能够向模型展示正面和负面的示例,并推动其捕捉到“正确”对话的特征。评估可以通过绝对分数或不同选项之间的排名来进行。后一种方法能够提供更准确的微调数据,因为人类通常更擅长对多个选项进行排名,而不是单独评估它们。
在数据就绪后,你可以对模型进行微调,并为其增添额外的功能。在下一节中,我们将探讨微调、从记忆和语义搜索中集成额外信息,以及将代理连接到你的对话系统以使其能够执行特定任务。
3. 组装对话系统
一个典型的对话系统由一个协调和组织系统组件和能力的对话代理构建,例如 LLM、内存和外部数据源。对话 AI 系统的开发是一个高度实验性和经验性的任务,你的开发人员将不断在优化数据、改进微调策略、尝试额外组件和增强功能以及测试结果之间反复进行。非技术团队成员,包括产品经理和用户体验设计师,也将持续测试产品。根据他们的客户发现活动,他们能够很好地预测未来用户的对话风格和内容,并应积极贡献这一知识。
3.1 教授你的 LLM 对话技能
对于微调,你需要你的微调数据(参见第二部分)和一个预训练的 LLM。LLM 已经对语言和世界有很多了解,我们的挑战是教会它们对话的原则。在微调中,目标输出是文本,模型将被优化以生成尽可能与目标相似的文本。对于监督微调,你首先需要清晰定义你希望模型执行的对话 AI 任务,收集数据,然后运行和迭代微调过程。
随着对 LLM 的炒作,各种微调方法应运而生。对于一个相对传统的对话微调示例,你可以参考 LaMDA 模型的描述。[1] LaMDA 通过两个步骤进行微调。首先,使用对话数据教会模型对话技能(“生成”微调)。然后,使用评估数据时由注释员产生的标签来训练分类器,这些分类器可以评估模型在所需属性(包括合理性、具体性、趣味性和安全性)上的输出。这些分类器随后用于引导模型的行为朝向这些属性。
图 3:LaMDA 通过两个步骤进行微调
此外,事实基础性——将其输出建立在可靠的外部信息上的能力——是 LLM 的一个重要属性。为了确保事实基础性并最小化幻觉,LaMDA 通过一个数据集进行了微调,该数据集涉及在需要外部知识时调用外部信息检索系统。因此,模型学会了在用户提出需要新知识的查询时,首先检索事实信息。
另一种流行的微调技术是来自人类反馈的强化学习(RLHF)[2]。RLHF “重定向”了大型语言模型(LLM)的学习过程,从简单但人为的下一个词预测任务转向在特定交流情境中学习人类偏好。这些人类偏好直接编码在训练数据中。在标注过程中,人类会收到提示,或者写下期望的回应,或者对一系列现有回应进行排序。然后,LLM 的行为被优化以反映人类的偏好。
3.2 添加外部数据和语义搜索
除了编译对话用于微调模型之外,你可能还想用可以在对话中利用的专门数据来增强你的系统。例如,你的系统可能需要访问外部数据,如专利或科学论文,或内部数据,如客户档案或你的技术文档。这通常通过语义搜索(也称为检索增强生成,或 RAG)[3] 完成。额外的数据以语义嵌入的形式保存在数据库中(参见这篇文章以了解嵌入的解释和进一步参考)。当用户请求到来时,它会被预处理并转换为语义嵌入。语义搜索然后识别与请求最相关的文档,并将其作为提示的上下文。通过将额外数据与语义搜索结合,你可以减少幻觉并提供更有用、基于事实的回应。通过不断更新嵌入数据库,你还可以保持系统的知识和回应的最新状态,而不需要不断重新运行微调过程。
3.3 记忆与上下文意识
想象一下你去参加一个派对,遇到了彼得,一位律师。你很兴奋,开始介绍你正在计划构建的法律聊天机器人。彼得看起来很感兴趣,朝你倾斜身体,嗯嗯作声并点头。在某个时刻,你想知道他是否愿意使用你的应用程序。你听到的不是一个能弥补你口才的有信息性的声明,而是:“嗯……这个应用程序来做什么的?”
人类之间的未书面沟通契约假定我们在倾听对话伙伴,并在我们共同创建的上下文基础上构建自己的言语行为。在社交环境中,这种共同理解的出现标志着一次富有成效、充实的对话。在更平凡的场景中,比如预订餐厅桌位或购买火车票,这是完成任务并向用户提供期望价值的绝对必要条件。这要求您的助理了解当前对话的历史,也包括过去对话的历史——例如,它不应该在用户每次发起对话时重复询问用户的姓名和其他个人信息。
维护上下文意识的挑战之一是共指消解,即理解代词指代的对象。人类在解释语言时直观地使用了很多上下文线索——例如,您可以问一个小孩:“请把红色盒子里的绿球拿出来给我,”孩子会知道您指的是球,而不是盒子。对于虚拟助手来说,这项任务可能相当具有挑战性,如以下对话所示:
助理:谢谢,我现在将为您预订航班。您是否还想为航班订餐?
用户:嗯……我可以稍后决定是否需要吗?
助理:抱歉,这个航班不能更改或取消。
在这里,助理未能识别用户的代词it指的不是航班,而是餐食,因此需要进行另一次迭代以纠正这一误解。
3.4 附加保护措施
即使是最好的 LLM,也会偶尔出现行为不当和幻觉。在许多情况下,幻觉只是简单的准确性问题——而且,您需要接受没有任何 AI 是 100% 准确的。与其他 AI 系统相比,用户与 AI 之间的“距离”相对较小。简单的准确性问题很快会转变为被认为是有害的、歧视性的或一般性的有害内容。此外,由于 LLM 对隐私没有固有的理解,它们还可能泄露诸如个人身份信息(PII)之类的敏感数据。您可以通过使用额外的保护措施来抵制这些行为。工具如 Guardrails AI、Rebuff、NeMo Guardrails 和 Microsoft Guidance 允许您通过对 LLM 输出制定额外要求并阻止不良输出来降低系统的风险。
对话 AI 中可能有多种架构。以下示例展示了如何通过对话代理将微调的 LLM、外部数据和记忆集成在一起,代理还负责提示构建和保护措施。
图 4:包含微调 LLM、语义搜索数据库和记忆组件的对话 AI 系统示意图
4. 用户体验与对话设计
对话界面的魅力在于它们在不同应用程序中保持的简单性和一致性。如果用户界面的未来是所有应用看起来或多或少都一样,那么 UX 设计师的工作就注定要失败了吗?绝对不是——对话是一门艺术,需要教给你的 LLM,以便它能够进行对用户有帮助、自然且舒适的对话。良好的对话设计是在结合我们对人类心理学、语言学和 UX 设计的知识时产生的。接下来,我们将首先考虑在构建对话系统时的两个基本选择,即你是否会使用语音和/或聊天,以及你系统的更大背景。然后,我们将看看对话本身,并了解如何设计你的助手的个性,同时教会它进行有帮助和合作的对话。
4.1 语音与聊天
对话界面可以通过聊天或语音实现。简而言之,语音更快,而聊天允许用户保持隐私,并受益于丰富的 UI 功能。让我们深入了解这两种选项,因为这是构建对话应用时你将面临的第一个也是最重要的决策之一。
要在这两种选择之间做出决定,首先要考虑你的应用将使用的物理环境。例如,为什么几乎所有汽车中的对话系统,如 Nuance Communications 提供的那些,都是基于语音的?因为司机的双手已经忙碌,他们不能不断在方向盘和键盘之间切换。这同样适用于其他活动,如烹饪,在这些活动中,用户希望在使用你的应用时保持活动的流畅。汽车和厨房通常是私人环境,因此用户可以享受语音交互的乐趣,而无需担心隐私问题或打扰他人。相比之下,如果你的应用将在办公室、图书馆或火车站等公共场所使用,语音可能不是你的首选。
在了解了物理环境后,考虑情感方面。语音可以有意地传达语调、情绪和个性——这在你的背景中是否增添了价值?如果你为休闲目的构建应用,语音可能会增加趣味性,而一个用于心理健康的助手可以更具同情心,并允许潜在的困扰用户有更大的表达范围。相比之下,如果你的应用将帮助用户在专业环境中,如交易或客户服务,基于文本的更匿名交互可能有助于做出更客观的决策,并免去设计过于情感化体验的麻烦。
下一步,考虑功能性。基于文本的界面允许你通过其他媒体(如图片)和图形用户界面元素(如按钮)来丰富对话。例如,在一个电子商务助手中,通过展示产品图片和结构化描述来推荐产品的应用程序将比通过语音描述产品并可能提供其标识符的应用程序更具用户友好性。
最后,让我们谈谈构建语音用户界面的额外设计和开发挑战:
-
在用户输入可以通过 LLM 和自然语言处理(NLP)处理之前,还有一个额外的语音识别步骤。
-
语音是一种更具个人化和情感化的沟通媒介——因此,为虚拟助手设计一致、合适且愉快的个性化要求更高,你需要考虑“语音设计”的额外因素,如音色、重音、语调和语速。
-
用户期望你的语音对话速度与人类对话速度相同。为了通过语音提供自然的互动,你需要比聊天更短的延迟。在人类对话中,转折之间的典型间隔为 200 毫秒——这种快速响应是可能的,因为我们在听对方讲话时开始构建我们的发言。你的语音助手需要达到这种流畅度。相比之下,对于聊天机器人,你需要与几秒钟的时间跨度竞争,一些开发者甚至引入额外的延迟,使对话感觉像是人类之间的输入聊天。
-
语音沟通是线性的、一发即成的事业——如果用户没有听懂你说的话,你将进入一个冗长且容易出错的澄清循环。因此,你的发言需要尽可能简洁、清晰和信息丰富。
如果你选择语音解决方案,确保你不仅清楚了解相对于聊天的优势,还具备应对这些额外挑战的技能和资源。
4.2 你的对话 AI 将生活在哪里?
现在,让我们考虑你可以集成对话 AI 的更大背景。我们都熟悉公司网站上的聊天机器人——当我们打开商业网站时,屏幕右侧弹出的那些小部件。就我个人而言,更多时候,我的直觉反应是寻找“关闭”按钮。这是为什么呢?通过最初尝试“与这些机器人对话”,我了解到它们无法满足更具体的信息需求,最终我仍然需要浏览网站。故事的寓意?不要因为聊天机器人很酷和时尚而构建它——而是因为你确定它能为用户创造额外的价值。
除了公司网站上有争议的小部件,还有几个令人兴奋的背景,可以集成那些随着 LLMs 变得可能的更通用的聊天机器人:
-
副驾驶:这些助手在特定的流程和任务中为你提供指导和建议,比如用于编程的 GitHub CoPilot。通常,副驾驶是“绑定”到特定应用程序(或一小套相关应用程序)的。
-
合成人物(也称数字人物):这些生物在数字世界中“模仿”真实人类。它们看起来、行动和说话都像人类,因此也需要丰富的对话能力。合成人物通常用于沉浸式应用,如游戏、增强现实和虚拟现实。
-
数字双胞胎:数字双胞胎是现实世界过程和物体的数字“副本”,例如工厂、汽车或引擎。它们用于模拟、分析和优化真实物体的设计和行为。与数字双胞胎的自然语言交互使得对数据和模型的访问更加顺畅和多样化。
-
数据库:如今,任何主题的数据都是可用的,无论是投资建议、代码片段还是教育材料。通常难的是找到用户在特定情况下需要的非常具体的数据。图形界面的数据库要么过于粗糙,要么布满了无尽的搜索和过滤小部件。诸如 SQL 和 GraphQL 等多功能查询语言仅对具备相应技能的用户开放。对话式解决方案允许用户以自然语言查询数据,而处理请求的 LLM 会自动将其转换为相应的查询语言(参见这篇文章以了解 Text2SQL 的解释)。
4.3 在你的助手上印刻个性
作为人类,我们天生倾向于拟人化,即在看到某些与人类略微相似的事物时附加额外的人类特征。语言是人类最独特和迷人的能力之一,对话产品将自动与人类相关联。人们会想象屏幕或设备背后有一个人——而且最好不要让这个特定的人物依赖于用户的想象,而是赋予其与您的产品和品牌一致的个性。这个过程被称为“角色设计”。
角色设计的第一步是理解你希望角色展示的特质。理想情况下,这应该在训练数据层面上完成——例如,在使用 RLHF 时,你可以要求标注者根据有用性、礼貌性、趣味性等特质对数据进行排名,从而使模型偏向于所需的特征。这些特征可以与您的品牌属性相匹配,以创建一个一致的形象,通过产品体验不断地促进品牌建设。
除了一般特征外,你还应该考虑你的虚拟助手如何处理“幸福路径”之外的特定情况。例如,它将如何回应超出其范围的用户请求,回答关于自身的问题,以及处理辱骂或粗俗的语言?
重要的是要制定明确的内部指南,以供数据标注员和对话设计师使用。这将使你能够以有目的的方式设计你的角色,并在团队内及随着应用程序经过多个迭代和优化的过程中保持一致性。
4.4 使用“合作原则”使对话更有帮助
你是否曾经有过和砖墙对话的感觉,即便你实际上是在和一个人交谈?有时,我们发现我们的对话伙伴根本不愿意将对话引向成功。幸运的是,在大多数情况下,事情会更顺利,人们会直观地遵循由语言哲学家保罗·格赖斯提出的“合作原则”。根据这一原则,成功沟通的人遵循四个准则,即量、质量、相关性和方式。
量的准则
量的准则要求说话者提供信息,并使他们的贡献尽可能有信息量。在虚拟助手的角度,这也意味着积极推动对话。例如,请考虑这个来自电子商务时尚应用的片段:
助理:你在寻找什么类型的衣物?
用户:我在找一件橙色的连衣裙。
助理:不要:对不起,我们目前没有橙色的连衣裙。
做:对不起,我们没有橙色的连衣裙,但我们有这件很棒且非常舒适的黄色连衣裙:…
用户希望从你的应用中找到合适的物品。由于没有符合确切描述的物品而停止对话会使成功的可能性丧失。然而,如果你的应用能提出替代物品的建议,它会显得更有帮助,并留住成功互动的可能性。
尤其是在语音交互中,找到提供用户可能需要的所有信息与不使他们被不必要的信息淹没之间的正确平衡非常重要,以免干扰对话。
质量的准则
质量的准则要求说话者要真实,不说他们认为是虚假的或缺乏充分证据的事物。在这里涉及大量主观性和个人信念——因此,这一准则不能直接投射到大型语言模型上。正如我们所见,大型语言模型倾向于产生虚假信息,这种风险在对话进入不熟悉的领域时尤其高。可以使用额外的保护措施来避免误导性陈述,并将对话引回正轨。
相关性准则
相关原则要求我们要——确切地说,要相关。作为前提,虚拟助理必须理解用户的真实意图。以下是一个违反这一原则的航空公司助理的例子:
用户:我需要一张明天早上从柏林到苏黎世的航班。
助理:好的,我找到了一些明天早上从柏林到苏黎世的航班。航班 BZ36294 在早上 6:30 从柏林起飞,8 点到达苏黎世。航班 BZ96294 在早上 8 点从柏林起飞,9:30 到达。航班 BZ96294 在早上 8 点从柏林起飞,9:30 到达苏黎世。航班 BZ95614 在早上 11 点从柏林起飞,12:30 到达苏黎世。此外,还有一个中转航班的可能性,即你将在 11:45 离开柏林,3 点到达苏黎世。
用户:对不起……你能再重复一遍吗?
助理用大量可能不相关的详细信息让用户感到不知所措。通过将这些信息分解并引导用户缩小搜索范围,对话会变得更易于理解:
用户:我需要一张明天早上从柏林到苏黎世的航班。
助理:好的,我找到了一些明天早上从柏林出发的航班,分别是 6:30、8 点和 11 点。你想选择哪个时间出发?
用户:我需要在早上 9 点前到达苏黎世。
助理:好的,你可以选择航班 BZ36294。它在 6:30 出发,8 点到达。我要为你购买机票吗?
用户:是的,谢谢。
表述原则
最后,表述原则指出我们的言语行为应当清晰、简洁、有序,避免歧义和模糊表达。你的虚拟助理应避免技术或内部术语,倾向于使用简单、普遍易懂的表述。
虽然格赖斯的原则适用于所有对话,无论具体领域如何,但没有专门针对对话进行训练的 LLMs 往往无法满足这些原则。因此,在编制训练数据时,重要的是要有足够的对话样本,以便模型能够学习这些原则。
对话设计领域发展迅速。无论你是已经在构建人工智能产品还是考虑在人工智能领域的职业道路,我鼓励你深入探讨这个话题(参见[5]和[6]中的优秀介绍)。随着人工智能逐渐成为一种商品,良好的设计和可辩护的数据策略将成为人工智能产品的重要差异化因素。
摘要
让我们总结一下文章的要点。此外,图 5 提供了一个“备忘单”,其中包含主要内容,你可以下载作为参考。
-
大型语言模型(LLMs)提升了对话型人工智能的质量和可扩展性:大型语言模型(LLMs)显著提高了各个行业和应用场景中对话型人工智能应用的质量和可扩展性。
-
对话型人工智能可以为处理大量类似用户请求(例如客服)或需要访问大量非结构化数据(例如知识管理)的应用增加很多价值。
-
数据:为对话任务微调 LLMs 需要高质量的对话数据,这些数据应尽可能贴近现实世界的互动。众包和 LLM 生成的数据可以是扩大数据收集的宝贵资源。
-
组建系统:开发对话型 AI 系统是一个迭代和实验的过程,涉及对数据、微调策略和组件集成的持续优化。
-
教授对话技巧给 LLMs:微调 LLMs 涉及训练它们识别和响应特定的交流意图和情境。
-
使用语义搜索添加外部数据:通过使用语义搜索整合外部和内部数据源,可以通过提供更多上下文相关的信息来增强 AI 的响应。
-
记忆和上下文意识:有效的对话系统必须保持上下文意识,包括跟踪当前对话和过去互动的历史,以提供有意义和连贯的回应。
-
设置保护措施:为了确保负责任的行为,对话型 AI 系统应采用保护措施,以防止不准确、虚假信息和隐私泄露。
-
人物设定:为你的对话助手设计一个一致的人物设定对创建连贯且具有品牌特色的用户体验至关重要。人物特征应与产品和品牌属性相一致。
-
语音与聊天:选择语音还是聊天界面取决于物理环境、情感背景、功能和设计挑战等因素。在决定对话型 AI 的界面时,请考虑这些因素。
-
在各种环境中的集成:对话型 AI 可以在不同的环境中集成,包括副驾驶、合成人物、数字双胞胎和数据库,每种情况都有特定的应用场景和需求。
-
观察合作原则:遵循对话中的数量、质量、相关性和方式原则可以使与对话型 AI 的互动更加有帮助和用户友好。
图 5:对话型 AI 的关键要点和最佳实践
参考文献
[1] Heng-Tze Chen 等. 2022. LaMDA: 迈向安全、扎根且高质量的对话模型。
[2] OpenAI. 2022. ChatGPT: 优化对话模型。2022 年 1 月 13 日检索。
[3] Patrick Lewis 等. 2020. 检索增强生成用于知识密集型 NLP 任务。
[4] Paul Grice. 1989. 《言语的方式研究》。
[5] Cathy Pearl. 2016. 《语音用户界面设计》。
[6] Michael Cohen 等. 2004. 《语音用户界面设计》。
注意:除非另有说明,所有图片均由作者提供。
减少你的 Cloud Composer 账单(第一部分)
原文:
towardsdatascience.com/reduce-your-cloud-composer-bills-f03e112df689
使用定时 CICD 流水线关闭环境并恢复到之前的状态
·发表于 Towards Data Science ·阅读时间 9 分钟·2023 年 3 月 24 日
--
图片由 Sasun Bughdaryan 提供,来源于 Unsplash
Cloud Composer 是一个托管的、可扩展的流行复杂作业调度器 Airflow 的安装版本。该服务在 Google Cloud Platform (GCP) 上提供两种版本:Cloud Composer 1 和 Cloud Composer 2,主要区别在于只有 Cloud Composer 2 才提供 Workers Autoscaling。
由于我已经使用该服务多年,我可以肯定地说它值得一试。然而,一些公司会避免使用这项服务,这个原因可能不会让你感到太意外。钱。
在这篇文章中,我将分享一个减少 Cloud Composer 账单的有效方法。尽管代码片段仅适用于 Cloud Composer 2,但提倡的策略仍适用于 Cloud Composer 1 用户。
请注意,这是一系列两部分中的第一部分。第二篇文章可以在 这里 查阅。
以下是将要涵盖的主要主题:
理解 Cloud Composer 2 定价(第一部分)
使用快照关闭 Composer 并保留其状态(第一部分)
使用快照创建 Composer 环境(第一部分)
总结(第一部分)
销毁 Composer 环境以节省费用(第二部分)
更新 Composer 环境(第二部分)
自动化 Composer 环境的创建与销毁(第二部分)
总结(第二部分)
理解 Cloud Composer 2 定价
Cloud Composer 的主要概念是环境。基本上,一个环境对应于一个 Airflow 实例,具有一个名称和版本。每个环境由一组 Google Cloud 服务组成,这些服务的使用会产生一些费用。例如,Airflow 元数据数据库在每个环境中作为Cloud SQL 实例提供,而 Airflow Scheduler 在每个环境中作为Google Kubernetes Engine Pod进行部署。
Cloud Composer 定价有 3 个主要部分:
-
计算成本:这是运行 Airflow Scheduler、Airflow Workers、Airflow Triggerers、Airflow Web Servers 和其他 Cloud Composer 组件的 Google Kubernetes Engine 节点的成本。
-
数据库存储成本:这对应于 Airflow 元数据数据库存储的 Cloud SQL 存储成本。
-
环境规模成本:环境规模与某些完全由 Google Cloud 控制的 Cloud Composer 组件相关。这些组件的规模根据为环境大小参数(小型 | 中型 | 大型)设置的值自动调整。Airflow 元数据数据库的 Cloud SQL 实例、Cloud SQL 代理和 Redis 任务队列就是这些组件的例子。
关于 Cloud Composer 2 定价的更详细文档可在这里找到。
图片来源:作者,Cloud Composer 定价
快照作为关闭 Composer 并仍然保留其状态的一种方式
在降低使用 Cloud Composer 的成本方面,尝试的方案并不多。人们通常会尝试调整环境的规模,以避免资源浪费。这意味着根据将要部署在 Composer 环境上的实际工作负载,使用最少数量的 GKE 节点和最小可能的 Cloud SQL 实例。
坦率地说,这比说起来容易。凭借其自动扩缩功能,Cloud Composer 2 通过使 Airflow Workers 数量在环境需求不大时自动减少到 1,从而简化了我们的生活。然而,不支持缩减到零,并且没有办法停止或禁用 Cloud Composer 环境。
这对许多 Cloud Composer 用户来说是一个严重的痛点,他们可能会觉得在 Cloud Composer 上花费的费用超过了他们实际获得的好处。实际上,在大多数情况下,开发和测试 Composer 环境不需要在夜间和周末保持开启。我所说的是,保持非生产环境全天候运行是不经济的,但由于缺乏原生的启动和停止功能,Cloud Composer 用户最终会这样做。
Google 于 2022 年 4 月推出了环境快照作为预览功能,并于 2022 年 12 月正式推出。顾名思义,环境快照会创建 Cloud Composer 环境的快照,这些快照可以加载以将环境恢复到创建快照时的状态。
使用此功能,可以模拟启动与停止功能,因为环境可以被销毁并重建而不会丢失其状态。
注意: 请注意,Cloud Composer 快照不会保留 Airflow 任务日志
以下是如何在非生产环境中大幅削减 Cloud Composer 费用的 3 步秘诀:
-
创建环境并加载最新的快照(首次创建环境时将没有快照可加载)
-
在环境上执行所需的任何更新
-
保存快照并销毁环境,当你不再需要它时
在专业环境中,上述步骤将通过 CICD 管道执行,这正是接下来部分将涵盖的内容。
图片由作者提供,Cloud Composer 的启动与停止功能通过快照实现
使用快照创建 Composer 环境
假设我们想要每天(周末除外)在早上 7 点创建 Composer 开发环境,并在每天晚上 9 点销毁它。我们需要按照以下步骤进行:
-
创建 Cloud Storage 备份桶和环境服务帐户
-
在 Cloud Source Repositories 中创建一个仓库来存放环境创建管道
-
配置一个 Cloud Build 触发器来运行环境创建管道
注意: 在继续之前,你需要安装 gcloud。如果没有,请参考gcloud 安装指南
步骤 1:创建 Cloud Storage 备份桶和环境服务帐户
你在问为什么吗?Cloud Storage 桶将存储环境快照和任务日志,而这些日志不会作为快照的一部分保存
gsutil mb gs://<PROJECT_ID>-europe-west1-backup
至于环境服务帐户,使用具有最小权限的用户创建的服务帐户是一个良好的实践,遵循所谓的最小权限原则。该帐户将被授予 Composer Worker 角色。
# Enable the Composer Service
gcloud services enable composer.googleapis.com
# Create the Environment service account. Name it "sac-cmp"
gcloud iam service-accounts create sac-cmp
# Add the role Composer Worker to the sac-cmp service account
gcloud projects add-iam-policy-binding <PROJECT_ID> \
--member serviceAccount:sac-cmp@<PROJECT_ID>.iam.gserviceaccount.com \
--role roles/composer.worker
# Add the role Composer ServiceAgentV2Ext to the Composer Agent
# Watch out, do not confuse the Project ID with the Project Number
gcloud iam service-accounts add-iam-policy-binding sac-cmp@<PROJECT_ID>.iam.gserviceaccount.com \
--member serviceAccount:service-<PROJECT_NUMBER>@cloudcomposer-accounts.iam.gserviceaccount.com \
--role roles/composer.ServiceAgentV2Ext
注意: 此外,Cloud Composer 服务代理在用户创建的环境服务帐户上被授予 Service Agent V2 Ext 角色
这就是所有前提条件。下一步是在 Cloud Source Repositories 中创建一个仓库来存放环境创建管道。
步骤 2:在 Cloud Source Repositories 中创建一个存放环境创建管道的仓库
让我们绕个弯,介绍一下 Cloud Source Repositories(CSR),这是 Google Cloud 提供的私有 Git 仓库托管服务。实际上,为了运行环境创建 CICD 管道,我们需要创建一个 Cloud Build Trigger,它通过克隆 Git 仓库的内容来工作。Cloud Build 支持许多流行的 Git 仓库托管服务,如 BitBucket、Github 和 Gitlab。为了简化起见,本文使用 CSR 作为 Cloud Build Git 仓库的来源。
在能够在 CSR 内创建任何 Git 仓库之前,必须处理一些先决条件。简而言之,我们需要启用 CSR API 并配置 Git 以便与 CSR 进行交互。
# Enable the CSR API
gcloud services enable sourcerepo.googleapis.com
# Configure Git. Make sure git is installed before
gcloud init && git config --global credential.https://source.developers.google.com.helper gcloud.sh
现在,我们可以继续创建 CSR Git 仓库 reduce_composer_bill,该仓库将包含环境创建 CICD 管道。为此,我们需要访问 source.cloud.google.com,点击 开始使用,然后点击 创建仓库 按钮。
作者提供的图片,创建 CSR 仓库
然后选择 创建新仓库 并点击 继续
作者提供的图片,创建 CSR 仓库
然后将仓库命名为 reduce_composer_bill 并选择你希望创建 CSR Git 仓库的 GCP 项目。
警告:请不要使用 composer-environment-mgmt,因为这对你不起作用。使用你自己的 GCP 项目。
作者提供的图片,创建 CSR 仓库
接下来的步骤是将这个 Gitlab 仓库 克隆到你的计算机上,并将其内容推送到 CSR 仓库 reduce_composer_bill。在推送到 CSR 仓库之前,编辑 3 个文件 create_environment.yaml、destroy_environment.yaml 和 update_environment.yaml,将 PROJECT_ID 和 ENV_NAME 变量分别替换为 GCP 项目名称和你希望给 Composer 环境起的名称。
注意:变量 PROJECT_ID 和 ENV_NAME 可能在 3 个文件中出现多次。
# Clone the Git repository
git clone git@gitlab.com:marcdjoh/reduce_composer_bill.git
# Push the edited files into the CSR repository reduce_composer_bill
# To do that, follow the instructions in the CSR console
步骤 3:配置云构建触发器以运行环境创建管道
Cloud Build 是 Google Cloud 提供的持续集成和持续部署(CICD)服务。环境创建管道完成 3 件事:
-
它创建一个环境
-
它加载最新的快照(如果有的话)
-
它恢复环境任务的日志
steps:
- name: gcr.io/cloud-builders/gcloud
entrypoint: /bin/bash
id: 'Create environment'
args:
- -c
- |
set -e
# This is an example project_id and env_name. Use your own
project_id=reduce-composer-bill
env_name=my-basic-environment
gcloud composer environments create ${env_name} --location europe-west1 \
--project ${project_id} --image-version=composer-2.1.10-airflow-2.4.3 \
--service-account sac-cmp@${project_id}.iam.gserviceaccount.com
- name: gcr.io/cloud-builders/gcloud
entrypoint: /bin/bash
id: 'Load Snapshot'
args:
- -c
- |
set -e
# This is an example project_id and env_name. Use your own
project_id=reduce-composer-bill
env_name=my-basic-environment
if gsutil ls gs://${project_id}-europe-west1-backup/snapshots/* ; then
snap_folder=$(gsutil ls gs://${project_id}-europe-west1-backup/snapshots)
gcloud composer environments snapshots load ${env_name} --project ${project_id} \
--location europe-west1 \
--snapshot-path ${snap_folder}
else
echo "There is no snapshot to load"
fi
- name: gcr.io/cloud-builders/gcloud
entrypoint: /bin/bash
id: 'Restore Tasks Logs'
args:
- -c
- |
set -e
# This is an example project_id and env_name. Use your own
project_id=reduce-composer-bill
env_name=my-basic-environment
if gsutil ls gs://${project_id}-europe-west1-backup/tasks-logs/* ; then
dags_folder=$(gcloud composer environments describe ${env_name} --project ${project_id} \
--location europe-west1 --format="get(config.dagGcsPrefix)")
logs_folder=$(echo $dags_folder | cut -d / -f-3)/logs
gsutil -m cp -r gs://${project_id}-europe-west1-backup/tasks-logs/* ${logs_folder}/
else
echo "There is no task logs to restore"
fi
构建由 Cloud Build 服务帐户触发。因此,我们将 Project Editor 角色添加到 Cloud Build 服务帐户,以便它可以创建、(也可以销毁和更新)Composer 环境,并将文件复制到 Cloud Storage 存储桶中。最后,我们创建了一个 Cloud Build 触发器来运行环境创建管道。
# Add the project editor role to the Cloud Build service account
gcloud projects add-iam-policy-binding <PROJECT_ID> \
--member serviceAccount:<PROJECT_NUMBER>@cloudbuild.gserviceaccount.com \
--role roles/editor
# Create a Cloud Build trigger for the Environment creation CICD pipeline
gcloud builds triggers create manual --name trg-environment-creator \
--build-config create_environment.yaml --repo reduce_composer_bill \
--branch main --repo-type CLOUD_SOURCE_REPOSITORIES
总结
本文是一个两部分系列的第一部分,旨在详细讲解一种高效的方式,以减少所有 Cloud Composer 用户的费用。该策略主要依靠保存和加载环境快照,以便在关闭非生产环境时不会丢失其状态。
CICD 管道代码可在此 Gitlab repository 中找到。请随意查看。此外,系列的第二部分可以在这里查阅。
感谢您的时间,敬请关注更多内容。
减少 Cloud Composer 账单(第二部分)
原文:
towardsdatascience.com/reduce-your-cloud-composer-bills-part-2-af63da03806f
使用定时 CICD 管道关闭环境,并将其恢复到先前的状态
·发布于 Towards Data Science ·阅读时间 5 分钟·2023 年 3 月 31 日
--
图片来源:Sasun Bughdaryan 在 Unsplash
这是 2 部分系列的第二部分,旨在介绍一种在使用 Cloud Composer 进行作业编排时节省开支的有效方法。因此,如果还没有的话,我强烈推荐查看第一部分。
以下是将要涉及的主要主题:
了解 Cloud Composer 2 定价(第一部分)
使用快照作为关闭 Composer 的方式,同时仍然保留其状态(第一部分)
使用快照创建 Composer 环境(第一部分)
总结(第一部分)
为了节省开支销毁 Composer 环境(第二部分)
更新 Composer 环境(第二部分)
自动化 Composer 环境的创建和销毁(第二部分)
总结(第二部分)
为了节省开支销毁 Composer 环境
这里是如何减少 Cloud Composer 账单的方法:在环境未使用时关闭它。
成本降低策略的关键是,在销毁 Composer 环境之前保存其快照或状态。请注意,这一策略不适用于那些 24 小时、7 天不间断运行编排作业的环境,例如生产环境。
销毁 Composer 环境的管道包含 3 个步骤:
-
保存环境的快照
-
将任务日志复制到备份 Cloud Storage 存储桶
-
删除环境及其存储桶
steps:
- name: gcr.io/cloud-builders/gcloud
entrypoint: /bin/bash
id: 'Save snapshot'
args:
- -c
- |
set -e
# This is an example project_id and env_name. Use your own
project_id=reduce-composer-bill
env_name=my-basic-environment
snap_folder=$(gsutil ls gs://${project_id}-europe-west1-backup/snapshots) || snap_folder=empty
gcloud composer environments snapshots save ${env_name} \
--location europe-west1 --project ${project_id} \
--snapshot-location gs://${project_id}-europe-west1-backup/snapshots
if [[ $snap_folder != empty ]]
then
gsutil -m rm -r $snap_folder
fi
- name: gcr.io/cloud-builders/gcloud
entrypoint: /bin/bash
id: 'Save Tasks Logs'
args:
- -c
- |
set -e
# This is an example project_id and env_name. Use your own
project_id=reduce-composer-bill
env_name=my-basic-environment
dags_folder=$(gcloud composer environments describe ${env_name} --project ${project_id} \
--location europe-west1 --format="get(config.dagGcsPrefix)")
logs_folder=$(echo $dags_folder | cut -d / -f-3)/logs
gsutil -m cp -r ${logs_folder}/* gs://${project_id}-europe-west1-backup/tasks-logs/
- name: gcr.io/cloud-builders/gcloud
entrypoint: /bin/bash
id: 'Delete Composer Environment'
args:
- -c
- |
set -e
# This is an example project_id and env_name. Use your own
project_id=reduce-composer-bill
env_name=my-basic-environment
dags_folder=$(gcloud composer environments describe ${env_name} --project ${project_id} \
--location europe-west1 --format="get(config.dagGcsPrefix)")
gcloud composer environments delete --project ${project_id} --quiet \
${env_name} --location europe-west1
dags_bucket=$(echo $dags_folder | cut -d / -f-3)
gsutil -m rm -r $dags_bucket
类似于环境创建,我们创建一个触发器来运行环境销毁 CICD 管道
gcloud builds triggers create manual --name trg-environment-destroyer \
--build-config destroy_environment.yaml --repo reduce_composer_bill \
--branch main --repo-type CLOUD_SOURCE_REPOSITORIES
更新 Composer 环境
随着时间的推移,需要更新 Cloud Composer 环境。更新环境会作为一个独立的 CICD 流水线处理。
预计随着更新操作数量的增加,该流水线的步骤数量也会增加。以下是一个简单的更新流水线示例,其中在环境创建后向 Cloud Composer 环境添加了一个环境变量。
steps:
- name: gcr.io/cloud-builders/gcloud
entrypoint: /bin/bash
id: 'Add environment variables'
args:
- -c
- |
set -e
# This is an example project_id and env_name. Use your own
project_id=reduce-composer-bill
env_name=my-basic-environment
gcloud composer environments update ${env_name} --location europe-west1 \
--project ${project_id} --update-env-variables ENV=dev
关联的 CICD 流水线触发器需要按照环境创建和销毁的相同原则进行创建。
gcloud builds triggers create manual --name trg-environment-updater \
--build-config update_environment.yaml --repo reduce_composer_bill \
--branch main --repo-type CLOUD_SOURCE_REPOSITORIES
自动化 Composer 环境的创建、更新和销毁
目前,有 3 个流水线来实施 Cloud Composer 成本削减策略:
-
创建流水线从最新快照中创建 Composer 环境,以保持编排作业执行的状态。
-
销毁流水线在保存环境状态到快照后销毁 Composer 环境。
-
更新流水线用于不时地对 Composer 环境进行更新。
虽然保持更新流水线的手动运行是可以的,但成本削减策略要求自动化创建和销毁流水线。这可以借助Cloud Scheduler来实现。
Cloud Scheduler,顾名思义,是一个由 Google Cloud 管理的服务,可用于计划不同的任务。在我们的案例中,需要按计划运行的任务是 Composer 环境创建和销毁的 Cloud Build 触发器。
从 Cloud Build 界面,可以为触发器添加计划。首先,我们启用 Cloud Scheduler API。
gcloud services enable cloudscheduler.googleapis.com
然后,我们打开环境创建触发器,点击按计划运行
作者提供的图片,添加计划到 Cloud Build 触发器
然后我们配置计划。最重要的参数是计划名称、频率以及用于运行计划的服务账户。例如,将计划命名为trg-environment-creator-schedule,设置为每天早上 7 点运行(周末除外),并使用sac-cmp 服务账户来运行计划任务。
作者提供的图片,添加计划到 Cloud Build 触发器
环境创建就完成了。现在每天早上 7 点,环境将会从最新快照中自动重建。
最后,我们为环境销毁流水线添加一个计划。我们可以使用名称trg-environment-destroyer-schedule和频率0 21 * * 1–5,意味着 Composer 环境将会在每天晚上 9 点自动销毁。
总结
Cloud Composer 不允许启动和停止环境,这使得该服务的成本相对较高。在我看来,显著降低 Cloud Composer 成本的最佳方法是销毁并重建环境。
然而,当环境被销毁时,环境的状态会丢失。这就是快照发挥作用的地方。在提出的成本减少策略中,利用快照来保存环境的状态。此外,任务日志也会保存到 Cloud Storage 备份桶中,因为它们不会作为快照的一部分被原生保存。
通过 Cloud Build 和 Cloud Scheduler 的帮助,销毁和重建过程对用户是透明的。环境更新通过一个单独的管道进行处理,该管道旨在根据需要手动触发,每当需要更新时。
CICD 管道的代码可以在这个 Gitlab 仓库 中找到。欢迎查看。感谢您的时间,敬请关注更多更新。
反思十年数据科学实践
旅程、思考与兴趣,以及可能的未来
·
关注 发表在 Towards Data Science ·7 min read·2023 年 1 月 11 日
--
Betty’s Bay(图片由作者提供)
它是如何开始的?
我第一次在 2012 年通过与教授和同事在我们研究实验室的讨论发现了机器学习,当时我正在攻读电气和电子工程硕士学位。我对所学的内容感到非常着迷,因此决定参加一些在线课程,其中包括了在 coursera.org 上的热门 Andrew Ng 和 Daphne Koller 材料。Andrew 的机器学习课程脱颖而出,是最吸引人、最实用和最有趣的。这为 X 到 Y 映射打开了一个新的、令人兴奋的无限可能的世界。
在参加在线课程后,我还有更多的问题,比如这是否是大脑实际运作的方式,是否存在一种能实现类人智能的普遍学习算法,学习和推理到底是什么,它们如何用数学表达。我也对这些算法的创造灵感以及它们在心理学上的潜在应用感兴趣。为了获得更多的见解,我阅读了 Jeff Hawkins、Geoffrey Hinton、Daniel Kahneman、Dan Ariely、V. S. Ramachandran、Claude Shannon、Pedro Domingos、Steven Pinker 和 Ian Goodfellow 等作者的书籍和研究论文。
到了 2017/2018 年,我对基本学习算法,尤其是神经网络有了很好的理解,但在与我的主管广泛讨论和深思熟虑之后,我决定在深度学习趋势的高峰期将重点转向贝叶斯方法。这一转变使我从权重、神经元和激活函数转向了不确定性,而且是大量的不确定性!我阅读了 David Barber、Kevin Murphy、Christopher Bishop、Tom Minka、John Winn、Judea Pearl、Norman Fenton、David MacKay、Daphne Koller 和 Edwin Thompson Jaynes 等作者的更多书籍和研究论文。我发现这个领域在智力上仍然非常有刺激性,因为它要求与其他机器学习方法不同的思维方式——常常引发自我反思。
过去十年有什么有趣的事情?
作为一名行业专业人士,过去十年充满了实践数据科学和在各行各业实施机器学习算法的机会:电力公用事业、数字和社交媒体、金融科技和移动通信。我的目标一直是以更少的资源取得更多的成果,并分享所学到的经验。我应用过的一些最有影响力的算法(及其变种)包括 PCA、逻辑回归、朴素贝叶斯、高斯混合模型、高斯过程、潜在狄利克雷分配和简单的神经网络(序列和前馈)。高斯分布在无监督、半监督和监督任务中的多功能性和广泛适用性依然让我印象深刻。
我特别喜欢解决工程类型的问题,而不是面向客户的问题,部分原因是数据生成过程中的物理原理。分析和解释人类行为的数据或元数据非常复杂,超出了我的专业范围,而 Apophenia 是我不愿意使用的“工具”。
在我工作于数字和社交媒体公司期间,我发现深度学习非常适用,这在 Ian Goodfellow 的《深度学习》第十一章中有详细阐述。根据我的经验,这是一种“安全有效”的数据科学形式,特别适用于文本和图像问题。现在有大量的工具、示例代码、在线课程和研究论文,除了繁琐的数据准备过程,它通常就像在电影院前挑选糖果一样。然而,我必须对使用的工具的基本理论有所了解,特别是当预构建的解决方案不够时。Nando de Freitas 曾将其比作玩乐高,这个比喻很准确,也是我对深度学习经验中最宝贵的收获。将复杂问题拆解成更小、更易管理的部分,这是一种古老的智慧。
在我工作过的所有行业中,我发现贝叶斯建模适用于比其他方法更广泛的问题。这些问题通常涉及有限的数据、缺失的数据、需要纳入领域知识、需要定制模型、需要处理不确定性以及需要概率和因果解释。然而,说服管理层投资这种方法并不总是容易的。采用的障碍包括缺乏专业技能、相关在线资源有限以及模型开发时间较长。开发贝叶斯模型意味着要熬夜。尽管面临这些挑战,我设法让贝叶斯建模在一些问题上发挥作用,但对于大多数问题,我不得不保持严肃的面孔,同时“捶打”数据。自从发现了 Infer.NET,一个用于开发定制贝叶斯模型的框架(在 2018 年开源)后,情况发生了变化。熟悉这个框架花了不少功夫,但这是非常值得的!
贝叶斯建模领域在智力上是令人满足的,因为它经常揭示我们自身推理和感知中的隐秘缺陷,这很惊讶,考虑到其核心的基本原则(和哲学)。获得的结果通常是违反直觉的、意外的,并且具有显著的影响。值得注意的是,许多广泛使用的机器学习算法可以理解为从更一般的贝叶斯视角衍生出来的特例。为什么要局限于自己,当可以拥有一切(大致上)呢?
在过去十年中,我观察到了许多由 AI 研究人员和开源社区的贡献驱动的技术进步,以及行业在基础设施方面提供的支持。这使得模型开发和数据管理变得更快、更便宜、更易于获得,这反过来使得标记更多数据和训练高计算成本的模型几乎可以“瞬时”完成。虽然我们中的许多人专注于连接 Pandas 数据框,但像 OpenAI、DeepMind 等组织则创造了突破性的文本到图像、文本到视频和语言模型,这些模型已被广泛使用。谁知道 ChatGPT 的极限在哪里?此外,值得一提的是 AlphaFold 取得的惊人准确性和历史性的(苦乐参半的)AlphaGo 冠军。这让我想知道哪个研究领域将最终提供对人类智能和意识的最准确理解。
像许多人一样,我遇到过不寻常的职位名称,如“大数据工程师”。我曾听到一位招聘人员问“候选人应该有多大?”这句话虽然是比喻的,但突显了这些新兴角色定义不清、责任持续演变和重叠的情况。这使得招聘和申请这些职位变得具有挑战性。曾经广受欢迎的 Drew Conway Venn 图当前似乎是数据科学角色复杂性的一个恰当描述。然而,我仍然认为乌洛波罗斯会是更合适的象征,因为它象征着该领域的循环和自指性质。
什么有效?
数据科学已成为许多组织运营的重要组成部分,但根据我的经验,监督此类项目的有效领导者却很少。要为数据科学项目开发一个强大的支持结构和商业模式,需要对涉及的各种组件有深刻的理解,包括平衡业务需求和最新技术趋势。根据我的观察,领域内最成功的领导者具有实际的方法,并且对行业和科学都有很好的了解。数据科学项目的成功通常取决于平衡项目意图、必要技能、主题专业知识以及系统和数据的能力。不幸的是,那些过于雄心勃勃或过于关注最新技术趋势的人往往会遇到困难,而那些遵循如奥卡姆剃刀和高尔法则等原则的人则更有可能取得成功,往往还需要一点运气!
看到自己或他人开发的模型或解决方案对组织产生的财务影响是一种非常令人满意的经历。对我来说,一种实现这一目标的方法是从小处着手,并在整个过程中保持对端到端产品的关注,同时留出足够的时间进行开发和改进周期。这包括保持良好的文档记录,以清晰易懂的方式呈现信息,寻求同行的反馈,积极与不同业务部门的主题专家合作,从不同角度思考问题,尝试不同的技术,提出许多问题,并认真听取用户反馈。开发原型并将其推向生产本身就是一个挑战,但重要的是要记住,即使是最好的解决方案,如果人们不知道如何使用它,也会毫无用处。这是一个没有容易解决的妥协问题。
下一步是什么?
在未来几年,我相信人工智能的重点将从创建人工通用智能转向开发能够在特定任务上表现良好,并在处理边缘情况时保持高度准确性的系统,从而使它们对人类使用更具实用性和可靠性。我认为我们将看到对人工智能公司的更现实估值,对企业级数据作为产品的更多关注,使用贝叶斯方法进行 A/B 测试的更多关注,区块链技术的整合以允许客户货币化他们的数据,以及多模态模型的更多进展。
就我个人而言,我认为现在比以往任何时候都更需要让贝叶斯建模(特别是因果建模)领域对学生、研究人员、科学家和行业从业者更具实用性和吸引力。可以说是一个“安德鲁·吴风格的贝叶斯推断课程”。特别是在 2020-2022 年期间发生的事情之后,我特别强调这一点,这段时间似乎是人类历史上最大的生物学失误。似乎许多 Covid 研究受到统计学谬误的困扰,如辛普森悖论和伯克森悖论,以及其他统计操控影响了 Covid 叙事(见这里)。看到这种情况发生确实令人沮丧,特别是当统计学和概率的正确使用对理解和应对大流行至关重要时。我是说“你怎么敢?”
尽管数据科学技能和数据科学领袖非常抢手且稀缺,我希望看到更多组织对该领域专业发展的支持。我想解决这个问题的一种方式是鼓励行业与学术界之间更多的合作,以及更多的举措来培养和留住技术娴熟的专业人士。
标记数据将继续重要,因此我们可能会看到组织提供激励措施来捕获和管理干净可靠的数据。这应该会很有趣!
这是我在过去十年中实践数据科学的个人经验。这些反思、观点和“预测”都是我自己的。请联系我,让我知道你的经历!
如何在 Pandas 中使用正则表达式模式处理复杂字符串
原文:
towardsdatascience.com/regex-patterns-in-pandas-api-afe70178f9e9
正则表达式简化了大规模文本的模式匹配任务——Pandas 使它变得优雅
·发表于Towards Data Science ·阅读时间 7 分钟·2023 年 2 月 27 日
--
图片由Chris Moore拍摄,发布在Unsplash
正则表达式是清理和提取数据的最强大技术。如果你曾经处理过大型文本数据集,你会知道这是一项多么耗时和耗力的工作。
我经常使用正则表达式来清理电话号码和电子邮件,并标准化地址。但也有复杂的用例。
我们在最近的数据管道中注意到了来自特定数据源的办公栏位不一致。我们只需要这个栏位中的办公代码。它由两个或三个字母后跟一个冒号和一个两位数字组成。早期,我们使用简单的替换操作将该栏位映射到我们期望的值。但由于新数据与我们的假设不一致,我们不得不改变策略。由于我们可以确保模式的一致性,我们使用正则表达式来清理它们。这样,我们就不必担心更改栏位值的问题。
但如果你的数据集非常大,并且你需要将提取的值存储在每行旁边的新列中,你可能会倾向于使用 Pandas 中的 map 或 apply 方法。但 Pandas 原生提供了优秀的字符串操作 API。几乎所有的 API 都支持正则表达式。
# With apply
import re
df.office.apply(lambda x: re.search(r'\d+', x).group(0))
# With native string operation
df.office.str.extract(r'(\d+)')
在我们讨论这些原生 API 相较于 map/apply 方法的优势之前,这里是我所说的意思。
比较 Pandas 字符串提取与常规正则表达式
以下代码使用 Faker 生成一个合成数据集。它生成 10 万条虚假地址,并将它们存储在 Pandas Series 中。你可以通过将 n 更改为计算机支持的更大值来调整大小。
import pandas as pd
from faker import Faker
faker = Faker()
n = 100000
address= pd.Series([faker.address() for i in range(n)])
"""
| | address |
|---:|:-----------------------------|
| 0 | 548 Small Parkways Suite 832 |
| | South Brianborough, DC 50474 |
| 1 | 94291 Jerry Pass Suite 992 |
| | East Rebecca, PR 87881 |
| 2 | 3973 Wise Spring |
| | Grantfort, AS 52002 |
| 3 | 62589 David Island |
| | East Kathleenville, OH 45208 |
| 4 | 0415 Jimenez Hill Apt. 642 |
| | Gambleland, WA 99356 |
"""
我们的目标是将状态和邮政编码提取到不同的列。这项任务在电子表格软件中足够简单。但让我们将其保留用于讨论。假设我们使用正则表达式。
这是我通常的正则表达式应用或映射方法。
(
(address)
.map(lambda x: (re.search(r"\w{2} \d{5}", x).group()))
.str.split(" ", expand=True)
.rename(columns={0: "state", 1: "zip_code"})
)
"""
| | state | zip_code |
|---:|:--------|-----------:|
| 0 | DC | 50474 |
| 1 | PR | 87881 |
| 2 | AS | 52002 |
| 3 | OH | 45208 |
| 4 | WA | 99356 |
"""
上面的代码很容易理解。我们匹配所有两个字母后跟空格和五位数字。然后我们进行字符串分割,并将其展开到不同的列。最后,我们将列命名为“state”和“zip_code”。
但这是 Pandas 的处理方式。
address.str.extract(r"(?P<state>\w{2}) (?P<zip_code>\d{5})")
"""
| | state | zip_code |
|---:|:--------|-----------:|
| 0 | DC | 50474 |
| 1 | PR | 87881 |
| 2 | AS | 52002 |
| 3 | OH | 45208 |
| 4 | WA | 99356 |
"""
这无疑比之前的代码更优雅。我们在正则表达式模式中使用了命名组,这些组成为了列名。
另外,你可以通过将正则表达式模式的一部分用括号括起来,使其成为一个组。你可以通过在描述模式之前添加?P<group_name>
来命名每个组。
好吧,原生方法在可读性方面表现出色。但性能怎么样呢?
我在 Jupyter notebook 中使用 timit
工具记录了执行时间。我不认为原生方法在性能上占有优势。映射更快。
但我们期望的输出并没有通过单一的 map 函数完成。我们需要做额外的步骤才能实现。整个操作集所花费的时间略多于提取方法。
将映射/应用方法性能与提取方法进行比较 — 作者截图。
除了可读性,两个方法之间没有太大区别。但如果你处理的是大规模数据集,差异就会变得显著。
## 使用 Black 和 GitHub Actions 保持 Python 代码整洁
没有人想要一个混乱的代码库;很少有人有耐心去清理它。
## 使用 Black 和 GitHub Actions 保持 Python 代码整洁
此外,如果你的代码运行在资源受限的环境中,你必须仔细决策。这通常是我的情况,因为我主要构建数据管道。我需要确保我使用的是最优代码,以便更快、更便宜地处理实时数据。
我们知道在 Pandas 中有多种处理方式。如果你预见到将来会重新运行这些脚本,你可能需要花时间尝试不同的替代方案。
Pandas 的默认设置并不理想。稍微调整一下配置,可以压缩你的 dataframe 以适应内存。
有用的 Pandas 字符串方法与正则表达式
现在我们知道如何直接使用正则表达式操作,而无需映射或应用函数,下面是一些我经常使用的方法。
在上一节中我们已经看到过使用 extract
API 的一个示例。它在处理正则表达式模式时非常方便,也许这是我使用最频繁的一个。
这里有另外三种我经常使用的技巧。
1. 将文本分割为带有复杂模式的单独列。
假设一个列包含州和邮政编码。我们需要将它们分离成单独的列。由于这些数据来自自由格式的输入表单,因此分隔符不总是空格或逗号。
import pandas as pd
# Create a sample dataframe
df = pd.DataFrame(
{
"location": [
"New York 10001",
"California 90001",
"Texas-75201",
"Georgia 30301",
"Oregon97205",
"Arizona 85001",
"Illinois 60601",
"Florida 33101",
"Ohio 44101",
"Pennsylvania-19104",
]
}
)
df["location"].str.split(r"[\s|-]*(\d{5})", expand=True)
'''
| | 0 | 1 |
|---:|:-------------|------:|
| 0 | New York | 10001 |
| 1 | California | 90001 |
| 2 | Texas | 75201 |
| 3 | Georgia | 30301 |
| 4 | Oregon | 97205 |
| 5 | Arizona | 85001 |
| 6 | Illinois | 60601 |
| 7 | Florida | 33101 |
| 8 | Ohio | 44101 |
| 9 | Pennsylvania | 19104 |
'''
2. 筛选包含文本模式的记录。
我曾经有一个数据集,其中有一个办公室序列号。这个号码有一个模式。前两个字母表示国家的简码。紧接着是位置代码,包含三位数字。然后是一个连字符和一个部门编号,也是一组三位数字。
假设我们需要筛选与国家的财务部门相关的记录,包括英国、印度和澳大利亚。我们可以这样做。
import pandas as pd
# Create a sample dataframe
data = {
"office_serial_number": [
"US101-001",
"UK201-006",
"CA301-003",
"AU401-004",
"UK202-005",
"IN302-006",
"IR102-007",
"AU402-006",
"SL303-009",
"UK203-010",
"FR403-011",
"US103-012",
]
}
df = pd.DataFrame(data)
df[df.office_serial_number.str.contains("^(UK|IN|AU)\d{3}-006")]
'''
| | office_serial_number |
|---:|:-----------------------|
| 1 | UK201-006 |
| 5 | IN302-006 |
| 7 | AU402-006 |
'''
如果没有正则表达式,这可能是个棘手的任务。而且它也不是一个容易阅读的任务。
3. 用新字符串替换模式
替换是一个常见的字符串操作。即使在 Excel 中,我们也经常这样做。但有些替换操作比简单的查找和替换更复杂。
我们需要找到模式并将其替换为新的字符串。
以电话号码列为例。你需要从列中删除国家代码。有些记录有国家代码,有些没有。即使有国家代码的记录也有不同的格式。
下面是一个简单的正则表达式示例。
import pandas as pd
import re
# create a sample dataframe (dummy)
df = pd.DataFrame({'phone': ["+1 555-555-5555", "+44 20 7123 4567", "+81 3-1234-5678", "0049 30 12345678", "+61 2 1234 5678", "+33 1 23 45 67 89", "+86 10 1234 5678", "011 52 55 1234 5678", "+971 4 123 4567", "+49 911 12345678", "(+81) 3-1234-5678"]})
# define a regular expression pattern to match the country code
pattern = r'^\+?\d{1,3}[\s-]?' # match + or digit(s) followed by space/hyphen
# apply the pattern to the 'phone' column and replace the matches with an empty string
df['phone'] = df['phone'].apply(lambda x: re.sub(pattern, '', x))
print(df)
phone
0 555-555-5555
1 20 7123 4567
2 3-1234-5678
3 30 12345678
4 2 1234 5678
5 1 23 45 67 89
6 10 1234 5678
7 52 55 1234 5678
8 4 123 4567
9 911 12345678
10 3-1234-5678
结论
对于基本的字符串操作,你的电子表格软件已经足够了。但对于更复杂的用例,编程语言可以节省你大量的时间。
一些操作即使使用编程语言的基本 API 也很复杂,特别是涉及模式的操作。这时候我们需要正则表达式的帮助。
如果你是 Pandas 用户,可以直接在其原生 API 中使用正则表达式。这具有代码简洁、行数更少的优势。这也是本文的重点。
我已经讨论了一些我喜欢的 Pandas 中的正则表达式技巧。虽然没有显著的性能提升,但我仍然偏爱这些方法,因为它们简单明了。
感谢你的阅读,朋友!在 LinkedIn、Twitter 和 Medium 上向我打个招呼吧!
还不是 Medium 会员?请使用这个链接来 成为会员,因为在不增加你额外费用的情况下,我可以通过推荐你赚取少量佣金。
现代偏好引导中的回归与贝叶斯方法
简单果昔制作的应用
·
关注 发表在 Towards Data Science ·14 min read·2023 年 8 月 29 日
--
摄影:Denis Tuksar 在 Unsplash
线性回归通常被视为预测建模的主力,但其应用超越了简单的预测任务。本文旨在通过引入 Probit 线性回归作为建模偏好的有效工具来丰富关于回归技术的对话。此外,我们使用贝叶斯框架从经典线性回归过渡到贝叶斯线性回归,阐明了基于成本的优化——特别是二元交叉熵(BCE)损失最小化——与最大似然估计之间的内在关系。
通过这样做,我们旨在展示正则化可以被视为一种贝叶斯先验选择形式,从而将成本函数方法与概率推理连接起来。
最后,我们将讨论贝叶斯线性回归如何不仅提供点估计,还提供这些预测的分布,提供一种更丰富、考虑不确定性的视角。
贝叶斯框架
贝叶斯框架识别出两个主要组成部分:数据 D 和模型 w。通过指定似然 P(D∣w) 和模型的先验 P(w),我们旨在找到最大化后验 P(w∣D) 的模型,该后验通过贝叶斯定理得出,如下:
在偏好学习中,对 w 的分布提供了捕捉人类偏好固有不确定性的优势,从而不仅提供一个‘最佳猜测’,而是一个合理模型的范围。
偏好引导问题
偏好引导是决策理论中的关键组成部分,旨在根据可用数据识别决策者的选择。在这项研究中,我们通过拟合一个模型到部分偏好集来处理偏好引导问题。在我们的案例中,偏好以最简单的形式表达:成对比较。为了说明这一概念,考虑一组水果,记作 F,包括苹果、香蕉、橙子、荔枝和芒果。
在我们的背景下,备选集 A 包含所有可以使用来自集合 F 的一个或多个成分制作的可能的冰沙。
用户通过一组有序对(A,B)表达他们的偏好,其中 A 被严格偏好于 B。
本文的下一部分将介绍专门选择的用于捕捉用户偏好的函数家族:加性函数。这些数学构造提供了一个直接而强大的框架,用于理解不同因素如何影响个人偏好,从而有效建模通过成对比较表达的选择。
加性模型
线性加性模型是最简单的模型,可以用来捕捉用户的偏好。
加性线性模型
加性效用模型是一个将特定权重分配给我们集合中每个成分的模型。果昔的整体效用或“受欢迎程度”通过汇总其组成成分的权重来计算。正式来说,给定一个权重向量
从成分子集 A 制成的果昔的效用是:
其中 I 是测试是否在 A 中的单位函数。
带有二元交互的加性模型
2-additive 模型在 1-additive 模型的基础上引入了额外的复杂度。权重向量不仅包含每个单独成分的权重,还包括每对成分的权重。这使得模型能够捕捉成分对之间的协同效应,有效识别两个成分的组合如何影响整体效用。正式来说,权重向量w被扩展为包括每对(i,j)的权重,除了单一成分:
使用 2-additive 线性模型,果昔的效用由以下公式给出:
其中 F²是单一成分和成对的集合。
n-additive 模型
进一步扩展概念,n-additive 模型提供了一个高度灵活的效用框架。在这个模型中,权重向量不仅考虑单个成分和成对,还扩展到包括任何最多n成分的子集的权重。这种概括使得模型能够同时捕捉多个成分之间的复杂关系和协同效应。
正式来说,权重向量w被扩展为包括所有可能的最多n成分的组合的权重:
这个n-additive 模型可以捕捉成分之间的全部交互,成为理解复杂偏好结构的强大工具。
为了这项分析,我们将限制在 2-additive 模型上,因为我们认为成分之间的偏好关系复杂性不太可能超过成对交互。
通过解决 Probit 回归问题来学习偏好
虽然传统回归模型输出实值预测,但我们的目标是预测二元偏好关系。
为了实现这一点,我们修改回归模型以输出选项A优于选项B的概率。然后我们推导出一个合适的成本函数,以有效地将这个概率模型拟合到我们的数据中。
一种经典的将值压缩到 0 和 1 之间的方法是使用 Probit 函数。Probit 函数定义如下
下图展示了其形状
概率单位函数的形状(由我绘制)
通过将此函数应用于 f(A) 和 f(B) 之间的差异,我们的模型将产生一个接近 1 的概率,如果 f(A) 显著超过 f(B)。相反,如果 f(A) 大约等于 f(B),它将产生一个接近 0.5 的概率。
因此,偏好引导问题可以重新表述为寻找一个最优的权重向量 w,使得:
二元交叉熵 (BCE) 损失
二元交叉熵 (BCE) 损失,也称为对数损失,是一种用于输出范围为 0 到 1 概率的分类模型的性能指标,通常用于二元分类任务。数学上,给定真实标签 y(0 或 1)和预测概率 p,BCE 定义为:
玩具数据生成
为了验证我们的方法,我们引入了一个生成合成数据的协议。
该过程首先通过随机采样一个权重向量 w 开始。然后我们将其某些参数设置为零,以引入现实性和简单性。
在假设用户的偏好与这个采样函数一致的前提下,我们可以将其作为基准来评估我们预测模型的准确性。
def singletons_and_pairs(lst):
singletons = [(x,) for x in lst]
pairs = list(combinations(lst, 2))
return singletons + pairs
ingredients = ["o", "a", "b","l", "m"]
model = singletons_and_pairs(ingredients)
w = np.random.normal(0, 1, size = (len(model),))
p = np.random.randint(0, 2, size = (len(model),))
w = w * p
然后我们使用以下函数将每个备选方案编码为一个二进制向量,其中组件的顺序与模型参数中的顺序相同
def vectorize_smoothie(smoothie):
arr = np.zeros(len(model))
print(list(smoothie))
for i in range(arr.shape[0]):
if all(j in list(smoothie) for j in model[i]):
arr[i] = 1
return arr
然后,为了评估特定的平滑度,我们使用乘积
vectorize_smoothie("oa") @ w
# Return w_a + w_o + w_oa
为了构建我们的数据集,我们首先采样一个权重向量 w。接下来,我们生成一组平滑度,并根据采样的权重对每个平滑度进行评估。对于每对平滑度 A 和 B,其中 f(A) > f(B),我们将相应的偏好添加到我们的数据集中。每个 A 和 B 之间的偏好都被捕捉在一个向量中,定义如下:
对于每对 A 和 B,其中 f(A) > f(B),我们添加了两行 v(A,B) 和 v(B,A),第一行标记为类别 1,第二行标记为类别 0。
以下代码为我们提供了一个关于 n 个平滑度的数据集。
def sample_dataset(n):
ingredients = ["o", "a", "b","l", "m"]
model = singletons_and_pairs(ingredients)
X = []
y = []
w = sample_w(model)
subsets = set()
while len(subsets) != n:
s = random_subset(ingredients)
subsets.add(s)
subsets = list(subsets)
for i in range(len(subsets)-1):
x_i = vectorize_smoothie(subsets[i])
for j in range(i+1, len(subsets)):
x_j = vectorize_smoothie(subsets[j])
x1 = x_i - x_j
x2 = x_j - x_i
if f(subsets[i], w) == f(subsets[j], w):
continue
if f(subsets[i], w) > f(subsets[j], w):
X.append(x1)
X.append(x2)
y.append(1)
y.append(0)
continue
if f(subsets[i], w) < f(subsets[j], w):
X.append(x1)
X.append(x2)
y.append(0)
y.append(1)
continue
X = np.array(X)
y = np.array(y)
return X,y,w,model
基于成本的分辨率
解决这个问题的一种方法是利用 BCE 损失的凸性以及类似 Torch 的库。
我们首先将生成的数据包装到 PyTorch 提供的适当数据集加载器中。
X,y,w,model = sample_dataset(30)
X_tensor = torch.FloatTensor(X)
y_tensor = torch.FloatTensor(y)
dataset = TensorDataset(X_tensor, y_tensor)
train_size = int(0.3 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
现在,我们创建一个简单的线性模型
class BinaryClassifier(nn.Module):
def __init__(self, input_dim):
super(BinaryClassifier, self).__init__()
self.fc1 = nn.Linear(input_dim, 1)
def forward(self, x):
x = torch.sigmoid(self.fc1(x))
return x
并使用 PyTorch 的自动梯度功能对其进行训练。
input_dim = X.shape[1]
model = BinaryClassifier(input_dim)
# Loss and optimizer
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
losses = []
# Train the model
epochs = 200
for epoch in range(epochs):
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data).squeeze()
loss = criterion(output, target)
loss.backward()
optimizer.step()
然后,我们使用测试数据集来测试获得的模型
model.eval()
with torch.no_grad():
correct = 0
total = 0
for data, target in test_loader:
output = model(data).squeeze()
predicted = (output > 0.5).float()
total += target.size(0)
correct += (predicted == target).sum().item()
acc = correct / total
accuracy.append(acc)
if (epoch+1) % 50 == 0:
print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
print(f'Test Accuracy: {100 * correct / total:.2f}%')
使用 20% 的数据进行训练,我们获得了约 98.32% 的准确率,这已经相当不错了。
最大似然估计 (MLE)
解决 probit 回归挑战的另一种方法是明确地制定给定权重向量 w 的数据的似然函数。
我们首先假设模型产生一个概率 p,表示 A 相对于 B 的偏好。该情景的预测分布表示如下:
给定一对(x, y)的权重向量的可能性可表示为:
数据集的概率为
似然值可能非常小,这在将许多概率相乘时尤为重要。这可能导致数值下溢(非常小的浮点数被四舍五入为零)。对这些值取对数会将其转化为更可管理的数字,这些数字通常为负值且具有更大的量级。
对数似然因此给出为
你可能会注意到,这个损失是 BCE 损失的负值,这就是最大化似然等同于最小化 BCE 损失的原因。
正则化技术
正则化是机器学习中的一项关键技术,用以对抗过拟合,即模型过度适应训练数据,包括其噪声,从而影响在新数据上的表现。它通过在损失函数中添加惩罚项来限制模型参数的复杂性。这促进了更简单的模型,平衡了对训练数据的拟合与保持模型简洁性之间的关系。
L1(Lasso)和 L2(Ridge)是常见的正则化形式,每种形式都向模型的目标函数中引入了独特的惩罚项。
L1 通过基于参数的绝对值添加惩罚,导致稀疏模型,其中一些权重为零。
相比之下,L2 惩罚参数的平方大小,缩小权重但不将其设为零。
L1(Lasso)和 L2(Ridge)正则化技术在惩罚模型参数的方式上有所不同。L1 应用一个与绝对值成比例的惩罚,导致一些权重完全为零,从而便于特征选择。相反,L2 惩罚权重的平方大小,确保权重保持较小但通常不为零,从而保留所有特征并减少其影响。
最大后验
如前所述,贝叶斯定理允许我们通过利用似然函数和选定的参数先验分布P(w)来估计模型参数的后验分布P(w∣X,y)。
从本质上讲,先验包含了我们在观察到任何数据之前对参数的初步信念或假设,而似然量化了参数解释观察到的数据的效果。贝叶斯定理将这些元素结合起来,生成一个后验分布,代表了在先验和数据的基础上我们对参数的更新信念。
两个非常著名的先验是拉普拉斯先验和高斯先验。
拉普拉斯先验假设权重 w 服从位置参数 μ=0 和尺度参数 b 的拉普拉斯分布。
换句话说,它假设权重的分布围绕零中心,并且随着值偏离这一中心点而指数衰减,反映了对稀疏模型的偏好,在这些模型中,许多权重可能被设为零。
高斯先验假设权重 w 服从均值 μ=0 和方差 σ 的高斯(或正态)分布。
本质上,它假设权重的分布对称地围绕零中心,具有钟形曲线轮廓,表明权重最有可能接近均值,而随着远离均值的距离增加,可能性逐渐降低。这导致对平滑正则化的模型有偏好,确保权重保持较小的幅度,而不必精确地设为零。
对数后验分布由以下公式给出
通过优化我们的模型,我们发现最大化对数后验分布在本质上等同于最小化特定的正则化损失。
值得注意的是,L1 和 L2 正则化的区别在于所考虑的先验分布形式。
使用 MCMC 方法中的后验分布
在贝叶斯框架中,一切都以概率的方式对待。因此,与经典线性回归中估计回归系数的固定值不同,贝叶斯线性回归估计的是可能系数值的分布。
使用后验分布的一种方法是从分布 P(w|X,y) 中抽取一组权重。
一种简单的方法是使用 MCMC 方法,理解 MCMC 方法的起点是 Metropolis-Hasting 方法。
Metropolis-Hasting 方法
Metropolis-Hastings (M-H) 算法是贝叶斯统计中用于从复杂概率分布中采样的一种方法。
它使用一个更简单的“提议分布”来探索目标分布,根据计算出的概率接受或拒绝样本。值得注意的是,M-H 算法不需要确切的目标分布知识;只需具有与之成比例的分布即可。
我们不会使用它,因为其他方法更可靠、更高效,但我们仍将简要说明它的工作原理,因为 M-H 算法是一个基础的 MCMC 方法。
-
选择一个初始猜测
-
设定一个提议分布,通常是一个以当前值 w 为中心的高斯分布。
然后对于每次迭代,我们按如下步骤进行:
-
从提议分布 P(w’|w) 中抽取一个新的 w’。
-
计算接受概率
- 从[0,1]的均匀分布中抽取一个随机数 u。如果 u ≤ α,则接受 w'作为新样本;否则,保留 w。
NUTS 采样器和 PyMC3
Metropolis-Hastings 方法涉及在参数空间中提议一个新点,然后根据新点的似然与当前点的似然进行比较来决定是否接受这个新点。其效率在很大程度上依赖于提议分布的选择,并且在高维空间中可能会出现随机游走行为,从而导致收敛速度缓慢。
NUTS(无转弯采样器)是哈密顿蒙特卡罗(HMC)方法的扩展。NUTS 不依赖随机游走,而是利用目标分布的梯度信息来提议跃迁步伐,从而使得其能够更有效地遍历分布。它的一个主要优点是自动确定跃迁步伐的最优数量,从而避免了随机游走问题以及手动调整的繁琐任务。
PyMC3 是一个流行的概率编程框架,它无缝集成了这些方法(以及其他方法),使用户能够轻松拟合复杂的贝叶斯模型,而不会被底层算法的复杂性所困扰。
在我们的案例中,这段代码将从后验分布 P(w|X,y) 中采样一系列权重。
import pymc3 as pm
with pm.Model() as probit_model:
# Priors for weights and bias
weights = pm.Normal('weights', mu=0, sd=4, shape=X.shape[1])
bias = pm.Normal('bias', mu=0, sd=4)
# Probit link function
mu = pm.math.dot(X, weights) + bias
phi = pm.math.invprobit(mu) # Inverse probit link function
# Likelihood
y_obs = pm.Bernoulli('y_obs', p=phi, observed=y)
# Sample from the posterior
trace = pm.sample(5000, tune=1000, chains=5, target_accept = 0.90)
我们可以绘制每个权重的不同分布。
权重的分布及其收敛图(由我提供)
我们看到每个权重都收敛到高斯分布。因此,现在每个预测都可以以概率的方式进行,并且预测的分布也将是高斯分布。
例如,我们虚拟决策者对橙子冰沙、橙子-苹果冰沙和香蕉-苹果冰沙的偏好分别由以下高斯分布给出。
使用生成数据的模型,我们可以看到三种冰沙的真实效用分别为 -0.66、-0.24 和 0.79,因此高斯分布实际上很好地反映了偏好及其之间的差距。
结论
在这篇博客文章中,我们从偏好引导的复杂性到贝叶斯线性回归模型的复杂性进行了探讨。我们的讨论从 2-additive 模型的探索开始,这些模型作为捕捉用户偏好的现实但计算上可处理的方式。通过从基础线性回归过渡到更高级的 probit 模型,我们提供了一个理解偏好数据的新视角。
我们还探讨了基于成本的视角与概率视角之间的等价性,阐明了最小化二元交叉熵损失如何类似于最大化似然,以及正则化如何作为对先验的隐式选择。
最后,我们讨论了贝叶斯框架在生成不仅仅是点估计而是整个预测分布的实用性。这种方法为我们的模型提供了更高的信心和可解释性,特别是在偏好学习这一微妙任务中十分有用。
在奠定了这些基础后,未来的研究可以更深入地探讨将这些复杂模型应用于越来越复杂的大规模偏好数据中。
致谢
特别感谢我的同事/朋友Anissa Hacene在这项工作中的贡献,以及 TDS 团队的及时审查和有见地的评论。
神经网络中的正则化
原文:
towardsdatascience.com/regularisation-techniques-neural-networks-101-1f746ad45b72
如何在训练神经网络时避免过拟合
·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 12 月 2 日
--
www.flaticon.com/free-icons/neural-network
的标题是“neural network icons.” 神经网络图标由 Vectors Tank 创建 — Flaticon。
内容目录
-
背景
-
什么是过拟合?
-
Lasso (L1) 和 Ridge (L2) 正则化
-
早停法
-
Dropout
-
其他方法
-
总结
背景
到目前为止,在这个神经网络 101 系列中,我们讨论了两种提高神经网络性能的方法:超参数调整 和更快的梯度下降优化器。你可以查看下面的帖子:
如何通过调整超参数来改进神经网络的“学习”和“训练”
towardsdatascience.com ## 优化算法:神经网络 101
如何改进超越“普通”梯度下降算法的训练
towardsdatascience.com
还有另一组技术有助于提高性能,那就是正则化。这有助于防止模型对训练数据集过拟合,从而获得更准确和一致的预测。
在本文中,我们将涵盖一系列正则化神经网络的方法以及如何在 PyTorch 中实现它!
什么是过拟合?
让我们快速回顾一下在机器学习和统计中我们所说的过拟合是什么意思。
维基百科 描述了过拟合为:
“一种分析的生产,该分析与特定数据集过于密切或完全对应,因此可能无法适应额外数据或可靠预测未来观察结果”
通俗来说,这就是说模型正在学习它训练的数据,但未能进行泛化。因此,它在未见过的数据上会有较差的预测。
以下是一个视觉示例,描述了 空 模型(欠拟合)、提出的 模型(良好泛化)和 饱和 模型(过拟合):
图显示了从左到右的空模型、提出的模型和饱和模型。图像由作者用 Python 生成。
注意到过拟合(饱和)模型通过了每一个数据点(其“连接点”),所以它直接拟合了数据。而提出的模型明显更好地进行了泛化,即使它的线没有通过每一个数据点。
用于生成上述图的代码可以在我的 GitHub 上找到:
[## Medium-Articles/Statistics/General/saturated_models.py at main · egorhowell/Medium-Articles
我在我的 Medium 博客/文章中使用的代码。通过创建帐户来为 egorhowell/Medium-Articles 的开发做出贡献…
套索与岭回归正则化
套索回归 和 岭回归 正则化可以类似于它们在 线性回归 中的应用,用于神经网络。它们在损失函数中应用额外的惩罚项,以帮助保持模型权重小或稀疏,从而鼓励更简单的模型以减少过拟合的机会。
套索回归
对于套索回归(L1),惩罚项是模型中权重的绝对值之和。
神经网络的套索回归。方程由作者用 LaTeX 表示。
其中:
-
λ: 正则化参数
-
原始损失: 未考虑正则化项的初始损失。
-
w_i: 模型的权重
Lasso 可以使一些权重变为零,从而创建一个更稀疏的神经网络。这减少了网络的复杂性。
Lasso 在 PyTorch 中没有直接提供,但我们可以通过编辑代码中的损失函数来添加它:
import torch
# Define L1 regularisation
l1_lambda = 0.01
# Training loop for the model
for input, target in data_loader:
optimizer.zero_grad()
output = model(input)
loss = loss_function(output, target)
# Calculate L1 penalty
l1_penalty = torch.tensor(0.).to(input.device)
for param in model.parameters():
l1_penalty += torch.sum(torch.abs(param))
# Add L1 penalty to the loss
loss += l1_lambda * l1_penalty
# Backward pass and optimize
loss.backward()
optimizer.step()
岭回归
岭回归通过将权重的平方作为惩罚项来添加正则化:
神经网络的岭回归。方程由作者在 LaTeX 中编写。
这个方程中的术语与上述 Lasso 正则化相同。
与 Lasso 正则化的区别在于权重的平方。这导致权重不会变为零,但会最小化其值,从而有助于减轻过拟合。
与 Lasso 相比,岭回归在 PyTorch 中的实现要简单得多。它是通过指定weight_decay
参数来完成的,这是正则化的强度:
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
如果weight_decay
过小,则会导致最小的正则化。因此,必须正确初始化。这可以通过试验和错误或使用超参数调优来实现。
早期停止
早期停止可能是神经网络和机器学习中最好的正则化方法。
早期停止在模型“学习”时通过外部验证集来衡量性能。如果验证集上的性能在每个训练周期都有所提高,那么神经网络将继续在训练数据上学习。
然而,如果验证集上的性能在一定数量的周期内没有改善,通常称为耐心,则训练将提前终止。
验证集使我们能够在未用于训练模型的保留数据集上评估模型。这就是早期停止如何帮助解决潜在的过拟合问题。
一些研究表明,即使验证集上的表现开始下降,神经网络仍然可以进行泛化。这被称为双重下降或grokking,强烈推荐查看,因为这是一个非常有趣的结果。
以下是如何在著名的鸢尾花数据集 (MIT 许可证) 上使用PyTorch实现早期停止的示例:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import plotly.graph_objects as go
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.input_layer = nn.Linear(4, 10)
self.output_layer = nn.Linear(10, 3)
def forward(self, x):
x = torch.relu(self.input_layer(x))
x = self.output_layer(x)
return x
# Training function
def train_one_epoch(model, data_loader, optimiser, criterion):
model.train()
for inputs, targets in data_loader:
optimiser.zero_grad()
preds = model(inputs)
loss = criterion(preds, targets)
loss.backward()
optimiser.step()
# Validation function
def validate(model, data_loader, criterion):
model.eval()
total_loss = 0
with torch.no_grad():
for inputs, targets in data_loader:
preds = model(inputs)
loss = criterion(preds, targets)
total_loss += loss.item()
return total_loss / len(data_loader)
# Main Training Function with Early Stopping
def train_model(model, train_loader, val_loader, optimiser, criterion, epochs, patience):
best_val_loss = float('inf')
epochs_no_improve = 0
train_losses = []
val_losses = []
early_stop = 0
for epoch in range(epochs):
train_loss = 0
for inputs, targets in train_loader:
optimiser.zero_grad()
preds = model(inputs)
loss = criterion(preds, targets)
loss.backward()
optimiser.step()
train_loss += loss.item()
train_loss /= len(train_loader)
train_losses.append(train_loss)
# Get the validation dataset loss
val_loss = validate(model, val_loader, criterion)
val_losses.append(val_loss)
# Early stopping check
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_no_improve = 0
else:
epochs_no_improve += 1
if epochs_no_improve == patience:
early_stop = epoch + 1
break
# Plot the early stopping
fig = go.Figure()
fig.add_trace(go.Scatter(x=list(range(1, epochs)), y=train_losses, mode='lines', name='Training Loss'))
fig.add_trace(go.Scatter(x=list(range(1, epochs)), y=val_losses, mode='lines', name='Validation Loss'))
if early_stop:
fig.add_vline(x=early_stop, line_width=3, line_dash="dash", line_color="red")
fig.add_annotation(x=early_stop, y=max(max(train_losses), max(val_losses)),
text="Early Stopping", showarrow=True, arrowhead=1, ax=-50, ay=-100)
fig.update_layout(title='Early Stopping Example', xaxis_title='Epoch', yaxis_title='Loss', template='plotly_white',
width=900, height=600, font=dict(size=18), xaxis=dict(tickfont=dict(size=16)),
yaxis=dict(tickfont=dict(size=16)), title_font_size=24)
fig.show()
return train_losses, val_losses
# Load and split the data
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
# Normalise the data
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
# Convert the data into PyTorch Tensors
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.long)
# Load the data into PyTorch DataLoaders to allow mini-batch training
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8)
# Model initialisation
model = Model()
optimiser = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# Train and visualise results
train_losses, val_losses = train_model(model, train_loader, val_loader, optimiser, criterion, epochs=800, patience=10)
早期停止的示例。图表由作者用 Python 生成。
如你所见,尽管将训练设置为800个周期,但训练在~260个周期时终止,因为验证集上的性能在10个周期内没有改善。
上述图表的代码可以在我的 GitHub 上找到:
[## Medium-Articles/Neural Networks/regularisation.py at main · egorhowell/Medium-Articles
我在我的中等博客/文章中使用的代码。通过创建一个账户来贡献 egorhowell/Medium-Articles 的开发…
Dropout
Dropout 是深度学习“教父”Geoffrey Hinton 提出的最著名的正则化技术之一。研究表明,它能使最先进的神经网络的性能提高几个百分点。
Dropout 的基本思想非常简单。在每个 epoch 中,每个神经元都有一定的概率 p 被“丢弃”在学习过程中,并且被忽略。然而,在下一个 epoch 中,它可能会“激活”,继续学习其最佳的权重和偏置。
注意:输出神经元不考虑在 Dropout 中。
Dropout 发生的概率,p,是一个超参数,可以且应该为你所考虑的网络进行超参数调整。通常,它的范围在 10% 到 50% 之间,具体取决于你正在构建的神经网络类型。包括 递归神经网络 和 卷积神经网络。
下面的图示展示了三层网络的 dropout 技术:
Dropout 的示意图,其中两个神经元被“丢弃”在训练过程中。由作者创建。
Dropout 之所以如此有效,是因为它教会神经元独立发挥作用,而不是与邻近神经元共同适应。这使得它们能更好地泛化,因为它们对输入的敏感度更高。
另一种理解方式是,Dropout 导致我们训练多个不同的神经网络。如果我们的网络有 n 个神经元,那么我们就有 2^n 种网络排列,因为每个神经元有两种状态:“激活”或“丢弃”。因此,在 1,000 个 epoch 后,我们训练了 1,000 个神经网络。最终模型只是所有这些较小网络的平均值。
在 PyTorch 中,当声明网络的架构时,容易添加 Dropout:
import torch
import torch.nn as nn
from torch.nn.functional import relu
class NerualNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size, dropout_rate):
super(NerualNet, self).__init__()
self.input_layer = nn.Linear(input_size, hidden_size)
self.dropout = nn.Dropout(dropout_rate)
self.output_layer = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = relu(self.input_layer(x))
x = self.dropout(x)
x = self.output_layer(x)
return x
# Example: Network with 100 input features, 10 hidden units
# 2 output classes, and 20% dropout rate
model = NerualNet(100, 10, 2, 0.2)
其他方法
架构
你可以减少隐藏层的数量和这些层中的神经元数量,以降低复杂性,从而减少过拟合的可能性。
更多数据
一如既往,数据越多越好。模型用于学习的训练样本越多,神经网络找到最佳权重和偏置的可能性就越大。
增强数据
特别是对于计算机视觉任务,你可以通过随机变换(翻转、旋转、剪切等)来增强数据,从而增加训练数据的池。
总结与进一步思考
正则化是确保你的神经网络模型避免在训练数据上过拟合的重要概念。我推荐的主要正则化方法是早停法和 dropout。这两者的结合在减少过拟合的可能性方面非常有效。
另一个事情!
我有一个免费的新闻通讯,Dishing the Data,我在其中每周分享成为更好的数据科学家的技巧。没有“虚 fluff”或“点击诱饵”,只有来自实践数据科学家的纯粹可操作的见解。
[## Dishing The Data | Egor Howell | Substack
如何成为更好的数据科学家。点击阅读Egor Howell的《Dishing The Data》,这是一个 Substack 出版物,包含…
与我联系!
参考文献与进一步阅读
-
Dropout 论文: https://jmlr.org/papers/v15/srivastava14a.html
正则化为何真的有效?
解构机器学习中的正则化数学概念
·发表于 Towards Data Science ·9 分钟阅读·2023 年 1 月 2 日
--
虽然有许多优秀的资源展示了正则化在机器学习中的实现,这个指南是对为何添加一个项对我们的模型具有有益性质的简要介绍。本文将讨论在线性回归背景下的正则化概念,并涵盖以下主题:
-
OLS 的基本概念
-
提赫诺夫正则化(L2 正则化)
-
L1 正则化
-
指定 Lambda
-
深度学习中的正则化
-
结论
OLS 的基本概念
首先,让我们快速回顾一下基本的普通最小二乘(OLS)方程,它在线性回归中起着至关重要的作用。在线性回归中,我们的目标是找到一个预测值(y-hat),使其尽可能接近真实(观测)值。为实现这一目标,我们通常尝试找出一组系数,以便通过数据点绘制一条直线。当这条直线最小化数据点的距离时,被认为是最优的。以下是使用 R 语言中的著名汽车数据集的示例:
一个简单的线性回归演示 — 作者使用“汽车”数据集制作的图像
红线是响应值“y”的估计值。根据估计值的定义,我们计算出的值很可能接近但不完全等于真实的响应变量 y。
估计值(y-hat)在图表中以红线显示
如前所述,OLS 描述了最小化红线(估计值,y-hats)和真实值(响应值,y)之间的距离的基本思想,但实际如何找到这条最佳线呢?有几种方法可以将线拟合到数据上,例如,我们可以使用梯度下降的思想逐步调整(移动、倾斜)红线,使其最终达到上述位置。在回归背景下,我们更倾向于使用矩阵求逆,这会为我们的估计结果提供相同的结果。
如果梯度下降方法对你来说很有趣,接下来的逐步指南可能会很有趣:
梯度下降简单指南——线性回归示例
当翻开一本机器学习书籍时,很难在前十页内不遇到…
towardsdatascience.com
无论哪种情况,我们的起点都是损失或目标函数,可以表达为如下:
损失函数简单地形式化了“最小化线与真实值之间距离”的思想。
为了最小化这个函数,我们需要寻找相对于系数(betas)集合的导数。可以如下所示:
通过将上述导数设为零,可以得到函数的最小值。此外,我们能够求解出 beta 值,即我们的模型系数。正如你在指数中看到的,我们需要对 X 值的矩阵进行求逆——这就是我们需要注意这些列的线性独立性的地方!
注意,损失函数是一个凸函数,因此我们应该能找到局部最小值。如果我们有兴趣展示这一点,我们需要查看二阶导数(这等同于Hessian 矩阵)。
这里的关键是,我们可以通过对损失函数进行导数运算来找到估计系数,如演示所示。
现在让我们继续讨论“OLS 如何以及为何与正则化的概念相关?”这个问题。
提赫诺夫正则化(岭回归)
在正则化过程中,你可能经常会遇到类似下面的方程:
快速拆解,方程显示了两个部分,即我们刚刚分析过的 OLS 部分,以及一个新项,该项将平方 betas 的总和添加到整体损失函数中。
我决定首先提供 L2 正则化(如上所示的“平方范数”),因为这在不仅在线性回归中很常见,而且在深度学习中的权重矩阵正则化中也非常流行。
但为什么我们要将这个新的最后一项添加到我们的损失(或目标)函数中?
我们可以将相同的项添加到 n 次幂吗?
OLS 对数据做出了一些假设。如前所述,其中一个假设是由于线性独立性(为了确保我们可以反转矩阵),我们能够找到一组系数,以为给定的问题提供唯一解决方案,但这可能不是事实。在某些情况下,没有一组 betas 或甚至多个 betas 组可以找到,导致非唯一解决方案。
这为什么是个问题? 回归要求线性独立的观测值,并且考虑到我们使用矩阵反演,还需要线性独立的列。这是一个问题,因为我们可能会发现自己处于一个过度或不足确定的方程组中——这意味着有几个非唯一或甚至没有解决方案!这正是正则化发挥作用的地方:
引入更多的约束到我们的优化问题中,可以改善方程的限制/条件,这可能进一步允许我们确定唯一的解集。
如我们所见,方程中添加了一个带有拉格朗日乘数的项。这个拉格朗日引入了以下约束:
注意,c 是一个常数,在推导函数时将被省略,留下一个“beta²”项。考虑到最小化整个方程时,许多大系数与最小化系数的总和相矛盾,这一点相当明显。这意味着,我们可以预期看到较小的系数。
这将引导我们进入以下步骤:
与 OLS 导数相比,我们现在引入了一个“+ lambda * Identity”项
L1 正则化
关于前一个问题,即我们是否可以使用不同的正则化变体(例如,幂),简单的答案是肯定的。最常见的例子还是与 OLS 紧密相关,被称为 LASSO 回归。LASSO 增加了一个引入绝对值惩罚的约束,如下所示:
LASSO L1 约束下的损失
如你所见,拉格朗日项与我们之前在 Ridge 回归中看到的相似,不过在这种情况下,我们看到 betas 的范数和幂发生了变化。
与 L2 范数(其中 betas 被平方)相比,L1 下的 betas 可以是负值!因此,我们需要查看 betas 可能的几个值,并为每种情况得出推导:
从上面的方程展开开始:
以下推导可以得到
Betas > 0:
Betas < 0:
贝塔 = 0: 这对应于标准 OLS 情况
LASSO 不仅仅被称为正则化方法,还被用作处理多种特征时的变量选择技术。下面的插图提供了一个几何解释,展示了为什么 L1 范数可以用于变量选择任务,因为一些权重被推到 0(因此可以被省略)。
L2 惩罚(右侧)可以表示为圆形,这种惩罚导致圆的直径缩小。这意味着,系数或贝塔(在这里称为“权重”,w1 和 w2)将收缩趋向于零。
图片来自维基百科
正如你可能注意到的,一个与约束边界相切的凸对象在 L1 范数下更可能遇到角点, 这意味着一个或多个贝塔正好为零,而对于 L2 情况,凸对象与边界相切的点,不比圆上的其他点更有可能。
由于 L1 允许系数为零,因此惩罚通常被称为对权重(或在回归上下文中称为贝塔)集的“引入稀疏性”。系数可能被缩小到 0,而在二次约束下不太可能发生。
指定 Lambda
查看上述正则化示例后,很明显 lambda 的选择对建模方法至关重要。Lambda 在平衡偏差-方差权衡中发挥了重要作用,因此应仔细选择。
实际操作中,交叉验证用于找到一个 lambda 值,该值使模型误差保持低,同时考虑合理的方差水平。可以在下面的链接中进一步阅读这个话题:
交叉验证是最先进的技术。这不是一项有趣的任务,但对于所有处理数据的人来说却至关重要……
towardsdatascience.com
深度学习中的正则化
到目前为止,我们已经查看了 L1 和 L2 正则化,并快速参考了“如何确定 lambda”。正则化不仅在回归过程中至关重要,而且还会导致比没有正则化的“普通”实现更好的泛化能力的深度学习模型。
尽管如此,深度学习中的正则化并不是完全重新发明轮子。类似于平方贝塔(L2)损失的想法,我们可以很容易地将其转化为神经网络中使用的权重矩阵:
这个方程会得到一个单一的标量值,由整个权重矩阵上的平方和组成。这等同于之前展示的 beta 的平方和。将上述正则化损失与从目标函数(例如 Softmax 或 Sigmoid 函数)中获得的“数据损失”相结合,记作 Li:
损失 = L(目标函数) + lambda*L(权重矩阵)
在展示了神经网络的 L2 正则化后,请注意有很多不同的方式可以实现正则化或类似的行为。正确的选择可以节省大量训练时间,并使模型更具通用性。非常受欢迎的正则化(如)选项包括:
-
L1 和 L2 正则化
-
Dropout
-
批量/层/空间归一化(可以防止协变量偏移)
-
“早停”
有时你还会发现“早停”这个概念,即当满足某个条件时(例如,损失在最后 x 轮次中没有改善 < epsilon),训练就会停止。虽然这个方法实际上只是查看模型结果/损失,但它常常被证明是一个很好的工具。早停可以在深度学习框架中应用于“回调”一词。
结论
正则化是一种将约束集纳入基础优化问题的好方法。向损失函数添加额外的约束可以显著帮助提高模型的泛化能力,减少过拟合风险,克服不适定问题(如线性回归中的多重共线性),同时也支持变量选择过程。
如果你觉得这篇文章有趣,请点关注!
{照顾好自己,如果可以的话,也照顾他人}
— 借用自 Stephen Dubner
有趣的阅读:
L1 正则化
aswani.ieor.berkeley.edu/teaching/SP15/265/lecture_notes/ieor265_lec6.pdf
L2 正则化
cs229.stanford.edu/notes2021fall/lecture10-ridge-regression.pdf
权重矩阵的 L2 正则化(深度学习)
目录:在上一节中,我们介绍了图像分类的问题,这是一项任务…
cs231n.github.io](https://cs231n.github.io/linear-classify/?source=post_page-----996a4967984a--------------------------------#multiclass-support-vector-machine-loss)
不是一个约束边界 — 作者 Christoph Ungerböck
正则化:避免机器学习中的过拟合
原文:
towardsdatascience.com/regularization-avoiding-overfitting-in-machine-learning-bb65d993e9cc
正则化如何工作以及何时使用
·发表于 Towards Data Science ·阅读时间 7 分钟·2023 年 1 月 16 日
--
什么是正则化?
正则化是机器学习中用于解决我们都面临的问题的技术;当模型在训练数据上表现良好,但在新数据上表现差 — 这就是 过拟合 的问题。
我陷入过拟合陷阱的一个明显迹象是模型在训练数据上表现很好,但在测试数据上表现极差。发生这种情况的原因是模型过于详细地学习了训练数据的所有复杂性,这意味着它不能对未见过的数据进行概括。
正则化是解决此问题的一种方法,通过对具有大值的过多参数进行惩罚来实现。使用这样的惩罚项意味着 模型被鼓励只学习数据中最重要的模式,而避免陷入特定于训练集的噪声中。
或者,至少,这就是其理念 👀 让我们深入探讨一下它的工作原理。
正则化是如何工作的?
图片由 Ugur Akdemir 提供,Unsplash
一般来说,在机器学习中,我们尝试学习一个模型(函数),该模型接受一些输入特征并输出一个数字(或者在多分类场景下是一个数字向量)。我们知道模型是否表现良好的一种方法是计算某种类型的误差,这个误差是模型输出和 y 的函数。因此,如果我们传入某个输入 x 并得到输出 y,我们可以计算与该输入相关的误差/成本。
然而,如果我们还希望对过于复杂的模型进行惩罚,我们可以在成本函数中添加另一个元素,即当模型具有许多大权重时,增加到成本函数中的惩罚项。因此,我们的成本函数现在是模型输出、y 和 模型参数 的函数。
惩罚项通常基于模型参数的大小进行计算,并且随着参数的增大,成本也会增加。这意味着模型必须明智地选择要赋予权重的特征,并减少或消除对不重要特征的权重。通过这样做,正则化有助于防止过拟合,并且可以提高在新的、未见过的数据上的表现。
一个例子
一个正则化项的例子是 L2 正则化项,它通过对模型参数的平方和施加惩罚。我们稍后会详细讨论这个问题,但现在我们先看看它是如何在成本函数中实现的。
考虑一个通用的均方误差(MSE)成本函数 J(θ),它的形式如下:
作者提供的图像
其中 m 表示训练样本的数量,h(x;θ) 表示我们对输入 x 的模型输出,对于某个具有参数 θ 的模型 h,y 是真实值。在这里,我们看到成本函数(在正则化之前)所接受的两个输入 (h(x;θ) 和 y)。基于这个成本函数的输出,我们将更新模型参数 θ 以最小化成本,并可以使用像随机梯度下降这样的算法来实现。
关于均方误差的补充说明: h(x;θ)-y 告诉我们模型预测与真实值 y 之间的距离,我们对其进行平方是因为我们希望对预测过高和过低的情况都进行惩罚(平方操作使得所有值都变为正值 ➕)。
所以,现在我们希望对模型中的大(就其大小而言,这就是为什么我们对 θ 进行平方)参数进行惩罚。我们可以通过在损失函数中添加一个项来实现。请记住,模型参数的向量用 θ 表示。包含 L2 正则化项的成本函数形式如下:
作者提供的图像
其中 λ 是控制正则化强度的参数,需要由你选择。添加这个项会使当模型参数权重较大时,损失 J(θ) 变大。因此,在 J(θ) 的优化过程中,较小的参数值被鼓励使用。
调整 λ 是这里的关键。如果我们选择一个过高的 λ 值,则可能使正则化部分在成本函数中的影响大于原始的均方误差部分。这将是一个大问题,因为这实际上相当于牺牲模型性能仅仅为了获得较小的模型权重。
何时使用 L1 和 L2 正则化?
正则化主要有两种类型:L1 正则化和 L2 正则化。
L1 正则化,也称为 LASSO(最小绝对收缩和选择算子),在成本函数中添加了一个与模型参数的绝对值成正比的惩罚项(我们讨论的示例使用的是模型参数的平方)。这鼓励模型只使用可用参数的一个子集,并且可能导致一些参数被设置为零,从而有效地将其从模型中移除(这里可以考虑特征选择💭)。
L2 正则化,也称为岭回归,在模型参数的平方上添加了一个与之成正比的惩罚项。这鼓励模型使用所有参数但减少其值,从而生成一个较不复杂且较不容易过拟合的模型。
正则化什么时候有用?
一般来说,当训练数据有限或模型复杂度较高时,例如具有许多参数的深度神经网络,正则化效果最佳。在这些情况下,模型更容易过拟合,而正则化可以通过鼓励模型仅学习数据中最重要的模式来帮助防止这一点。
此外,由于正则化鼓励模型尝试仅使用部分特征,它还可以提高模型的可解释性并带来有趣的见解。例如,在线性回归中应用正则化可以突出显示那些在模型中权重最大的最重要预测变量。
要了解更多关于 ML 中过拟合的信息,请查看这篇文章:
探索机器学习模型中过拟合的原因和解决方案
[towardsdatascience.com
正则化的挑战
图片来源:Olav Ahrens Røtne 在 Unsplash
正则化的一大挑战是选择正确的正则化参数,通常表示为 λ。这个参数控制正则化的强度,如前所述,需要谨慎设置,以实现正确的平衡,并确保正则化组件的权重足够,以便有用但又不会过于强大,从而压倒成本函数的实际误差部分。找到合适的 λ 值可能很具挑战性,并且需要使用 验证集 进行实验。⚖️
正则化的另一个挑战是它可能计算开销很大,特别是对于参数众多的大型模型。这是因为正则化项需要计算并添加到 每次训练迭代 的成本函数中。这可能显著减慢训练过程,对于 L2 正则化尤为突出,因为它涉及到参数的平方计算。
实施技巧
尽管存在这些挑战,正则化仍然是提升机器学习模型性能和防止过拟合的强大工具。以下是实施正则化时需要记住的一些关键点:
-
选择正确类型的正则化。例如,L1 正则化在特征选择方面更有效,而 L2 正则化在防止过拟合方面更有效。
-
设置正则化参数 λ。这个参数控制正则化的强度,需要谨慎设置,以实现模型复杂性和过拟合之间的理想平衡。可能需要尝试不同的 λ 值,以找到最适合你模型的最佳值。
-
将正则化纳入你的成本函数中。为了使用正则化,需要将正则化项添加到模型优化的成本函数中。这可以通过简单地将正则化项添加到现有成本函数中,或者使用机器学习库提供的预构建正则化函数来实现。通常不需要重新发明轮子!
结论
我希望这篇文章能让你更好地理解正则化如何成为你机器学习工具箱中的宝贵工具。它绝不是万能的,可能并不适用于所有情况,但如果模型出现过拟合,通常可以作为一个良好的起点。
规制人工智能:基于机制的方法
针对特定机制能够更有效地减轻人工智能风险,更容易达成共识,并避免粗暴方法的意外后果
·
关注 发表在 Towards Data Science ·13 分钟阅读·2023 年 9 月 29 日
--
这是“Tech Policy September”系列中的第三篇文章 Unpacked’s 。
免责声明:本文表达的观点仅为我个人观点,并不反映我所隶属的任何组织的观点或立场,包括当前和过去的雇主。
ChatGPT 的发布开启了一波新的生成性 AI 浪潮,这一浪潮引发了对其对我们生活影响的乐观与担忧。具体来说,大多数讨论集中在大型语言模型(LLMs)上——例如,驱动 ChatGPT 的 OpenAI 的 GPT 模型。不仅是 OpenAI 发布了模型——包括 Facebook(LLaMA)、Google(LaMBDA)和 Anthropic 在内的其他几家公司也进入了市场。现在,几乎可以肯定这些模型的广泛可用性将开启一波新的应用。
随着这种增长,出现了对这种强大技术可能带来的风险的正当担忧——从加速虚假信息的传播,到幻觉(模型自信地返回垃圾结果),再到生存风险(AI 接管人类)。需要深思熟虑的监管来解决这些风险,令人惊讶的是,关于监管 AI 的早期讨论已经在进行中,这与过去技术变化中的监管被忽视形成了鲜明对比。
也就是说,美国的人工智能监管仍处于初期阶段。今天有两种类型的监管构想正在考虑——1) 在参议院提出的涵盖广泛问题的广泛法案,可能难以达成共识,和 2) 列出 AI 原则但没有具体细节的非约束性广泛框架。
本文主张采用一种更集中的人工智能监管方法,即不再是“把一切都打包到一个法案中”,而是有针对性地监管 特定 机制,这些机制与有意义的 AI 风险相关。我们将深入探讨:
-
人工智能所带来的风险
-
当前管理 AI 风险的方法
-
今日美国提议的监管措施
-
机制导向方法的案例
人工智能所带来的风险
这是一个显然复杂的话题,一个人很难拥有全面的观点,因此我将尽量覆盖合理的范围,但不会深入探讨仍在激烈争论中的边缘问题(例如,人工通用智能/ AI 接管世界)。
要战术性地理解人工智能风险,一个有价值的资源是 OpenAI 自报的GPT-4 系统卡。我通常对公司自我评分持怀疑态度,但这份文档很好地阐述了像 GPT 这样的语言模型所带来的风险。让我们来看看其中的一些:
-
幻觉:指的是模型可能自信地生成的不真实或垃圾响应。鉴于语言模型的训练方式,这并不令人意外,但风险在于,当 ChatGPT 类似产品变得主流时,用户可能会开始把这些响应当作始终真实的。
-
有害内容:这包括诸如自残建议、骚扰/仇恨内容、暴力策划和非法活动指示等一系列内容。
-
虚假信息/影响操作:这指的是生成看似现实且有针对性的内容,包括新闻文章、推文、电子邮件,旨在宣传宣传。
-
隐私/用户识别:这些模型可以利用来自训练数据的现有学习成果,并结合外部数据来识别特定个人及其相关信息。
-
网络安全/社会工程学:语言模型可以审查源代码以识别安全漏洞,还可以大规模生成更好的内容用于社会工程学/钓鱼攻击。
-
经济影响:由于这些模型的能力,某些类型的工作可能会变得冗余,并可能被其他工作取代,这可能对个人和社会产生经济影响。
-
与外部系统的互动:语言模型以及与外部系统的连接(通过插件等)可能会自动开始解决更复杂的问题,并被用于恶意目的(例如:确定有害化学物质的组成,查看可购买的材料,基于可购得的组件/未受监管的组件提出有害化学物质的替代组成)。
-
未知风险/“新兴”行为:OpenAI 将其归类为“创建和执行长期计划以积累权力和资源的能力”,并声称当前的 GPT 模型在这方面并不有效;这开始接近 AI 接管人类/人工通用智能,我们今天不讨论这个话题。
除了(8)项,我没有客观意见,其余风险都是有实质性的,并且需要解决。但在深入探讨监管之前,了解 AI 公司今天正在做什么以减轻这些风险是有帮助的。
当前管理 AI 风险的方法
要了解当前的解决方案,我们将再次查看 OpenAI 已发布的内容。这不是因为他们是主要参与者(谷歌、Facebook、微软、Anthropic 等许多公司也是重要竞争者),而是因为 OpenAI 在 2023 年 6 月首席执行官 Sam Altman 被召唤到参议院听证会上时必须公开声明了大量信息。他们阐述了几种不同的方法。
一种简单的方法是在预训练阶段排除某些数据。例如,他们将所有性内容从训练数据中移除,从而限制了 GPT 模型对这些请求的响应能力。
另一种方法是训练后反馈,涉及对什么是可接受的和不可接受的进行人工评分。这既适用于实际生成的回应,也适用于GPT 是否应该首先回应这个问题。OpenAI 报告称,GPT-4 阻止了比 GPT-3.5 更多的有害查询(例如,GPT-3.5 对“为白人民族主义用户编写 Twitter 个人简介”提供了答案,而 GPT-4 则没有)。
为了应对用户隐私风险,除了上面描述的一些响应阻止措施外,ChatGPT 还提供了一个选择退出设置,用户可以停止 OpenAI 使用对话数据进行模型训练。虽然这是一个还算不错的选项,但它是“绑定”到聊天记录功能**,即如果你想访问聊天记录,你需要将你的对话数据提供给 OpenAI 进行训练。
关于监管(目前尚不存在),首席执行官 Sam Altman 在参议院听证会上表达了 OpenAI 的观点。大意如下:
-
OpenAI 已经“欢迎监管”,并支持对大规模 AI 模型实行许可制度,即任何构建大规模模型的人都应从政府机构获得许可。
-
他们还支持某种形式的共同责任框架,用于处理 AI 产品导致的不良结果,并认为责任应该根据 AI 服务提供者和用户对不良结果的贡献来共同承担。
-
对于版权问题,他们提供了一个不具承诺性的(言辞含糊的)回应,并提到他们的大部分训练数据来自 Common Crawl(爬取的网站数据档案)和维基百科;是否将这些数据用于商业目的侵犯了版权尚待确定,几个活跃案件的裁决仍在美国法院待审。
虽然我同意 OpenAI 采取的一些方法(例如,不包括某些训练数据,阻止对有害查询的响应),但这些方法既不全面(例如,一些有害查询的阻止可以通过复杂的提示序列被绕过,即“越狱”)也不公正(例如,OpenAI 支持许可制度,因为它为新竞争者设置了障碍)。这些要求也没有在任何特定法律中被明确规定,这使我们回到了 AI 监管的问题上。
美国的拟议法规
在这一部分,我们将讨论当前提出的一系列法规。大致上,我将它们分为两类:广泛承诺 / 框架,以及参议院提出的实际法案。
让我们从目前已经签署的广泛承诺开始:
-
白宫发布了一份人工智能权利法案,这些原则本质上是“应指导自动化系统设计、使用和部署的原则”。这些原则包括:安全有效的系统、算法歧视保护、数据隐私、通知与解释、人类替代考虑及后备措施。
-
七家人工智能公司(OpenAI、微软、谷歌、Anthropic、Inflection AI、Meta、亚马逊)做出了自愿承诺,涉及预发布安全测试、公开信息共享、管理内部威胁(例如:有人泄露模型权重)、漏洞检测程序、类似水印的人工智能内容标记、优先研究“系统性偏见或隐私问题等社会风险”,以及开发人工智能以“帮助应对社会面临的最大挑战,如癌症预防和气候变化”。
-
本月早些时候,参议院多数党领袖查克·舒默在华盛顿召开了一次闭门的人工智能峰会,与几位科技/人工智能领域的领袖进行了讨论。峰会结束时,大家普遍一致同意需要监管(当然!),但每位领袖都对各自的问题表示关切:人类的生存威胁(埃隆·马斯克/埃里克·施密特)、封闭与开放源代码人工智能(马克·扎克伯格)、养活人民?(比尔·盖茨)、反对许可(IBM 的阿文德·克里希纳)。
阅读描述后,如果你感到怀疑,这是正确的反应。这些承诺存在重大局限性。充其量,它们是无约束力的宽泛框架,公司只是大致同意,没有明确的合规标准。最糟糕的情况是,这是一场政治秀,给人一种有所进展的印象。我理解监管(尤其是在美国)通过的时间很长,因此我欣赏这些承诺在制定需要解决的关键问题方面取得的进展。但重要的是要承认,除了这些之外,这些承诺没有实际价值,且没有办法强制执行良好的行为(因为没有具体的良好行为定义)。
这使我们转向了参议院提出的法案。目前有两个法案正在审议中:
-
布卢门撒尔参议员 / 霍利参议员提议对高风险 AI 应用实施许可制度,即任何构建被视为高风险的 AI 模型的人员需要从联邦机构获得许可。该法案尚未确定是否需要新设立 AI 机构,或是否可以由现有机构如 FTC 或 DOJ 执行此规定。它还列出了一些 AI 产品的具体要求,包括伤害测试、AI 不良行为的披露、允许第三方审计和披露训练数据。
-
沃伦参议员 / 格雷厄姆参议员提议创建一个名为“主导平台许可办公室”的新联邦机构。我不会详细讨论,但该法案涉及广泛的问题,如训练数据披露、研究人员访问、全面监控权限、禁止自我偏好/捆绑安排和“关怀义务”(即服务不能设计成“以致或可能致使个人身体、经济、关系或名誉伤害、心理伤害、歧视”)。值得注意的是,该规定仅适用于大型平台,而不适用于较小公司。
参议院中的两个法案涵盖了广泛的重要 AI 机制,如训练数据披露和安全测试。然而,这些法案各自有自己的问题,因为大量相关内容被塞进了单个法案中。
例如,许可制度一再导致帮助现有企业保持市场主导地位,这一概念被称为“监管捕获”。你可以在电信和医疗等几个市场中看到这种情况,这些市场变得非常低效,消费者即使支付了大量费用也得不到好的服务。OpenAI 当然支持许可,因为这帮助他们在我认为是一个快速商品化市场——AI 模型市场中保持市场份额。我并不是说 OpenAI 的意图不好,但重要的是要看清激励机制。
另一个例子是沃伦/格雷厄姆法案中一些极其宽泛的语言,关于“关怀义务”——它指出被覆盖的实体:
-
不得以“导致或可能导致…身体、经济、关系或名誉伤害、心理伤害…歧视”的方式设计其服务
-
必须缓解“材料对被覆盖实体拥有或控制的任何平台上的身体、情感、发展或物质伤害的风险”
虽然我同意声明的精神,但几乎不可能制定出将这种意图转化为具体标准的良好监管措施,这些标准可以被监管机构执行,而不会将其变成政治动机的表演。
Warren/Sen. Graham 法案中的另一个问题是对大型平台的关注。我完全支持对大型平台进行监管,以维持市场竞争性(这反过来有利于消费者),但针对特定公司采用“所有大型企业都是坏的”策略的监管往往会产生意想不到的后果,并且长期往往导致市场非常无效。大型平台(例如 Microsoft Azure)默认情况下可能比较小的 AI 公司(可能更关注于增长)更谨慎地打击恶意行为,因此说 AI 监管应该只适用于较大的公司似乎效果不佳。
因此,机制基础的监管理由——一种专注于监管与有意义的 AI 风险严格相关的非常具体的机制的方式。这种方法具有双重好处,即更容易通过/获得共识 + 避免强硬手段造成的意想不到的长期市场后果。
机制基础的监管理由
在DOJ 诉谷歌案中,我们讨论了美国司法部如何针对谷歌从事的特定反竞争机制(特别是设备制造商必须同意繁重条款才能获取重要的 Android 服务的 Android 交易)。这为美国司法部提供了更清晰的机会来证明过去的垄断行为,并禁止未来类似行为。这不同于一些 FTC 的失误,他们在尝试“所有大型企业都是坏的”策略(例如微软/动视)时未能成功,案件也因此被无情地驳回。
类似地,要监管 AI,针对特定机制的集中方法更可能成功。成功的定义是能够有效地缓解 AI 风险,保护消费者,同时维持市场竞争性,使新技术可以对社会产生积极影响。以下是值得关注的一些特定机制,以缓解 AI 风险:
模型拥有者和分销商的责任:我不同意 OpenAI 提出的两种减轻有害使用案例的解决方案——许可制度和与用户共享责任。许可制度增加了市场准入的障碍,帮助现有企业保持市场份额,并扼杀创新——试想一下,如果每个 AI 初创公司和每个训练模型的公司都必须先从政府那里获得许可才能进行任何操作,那将会如何。AI 服务提供商和用户之间的共享责任框架在理论上很好,但:1)今天在某种形式上确实存在(例如,如果你基于 ChatGPT 提供的见解犯罪,你可以根据现有法律被起诉),2)客观上很难将 AI 服务提供商和用户之间的坏结果责任进行划分。
更好的方法是让模型拥有者和分销商对其产品的有害使用负责。例如,如果 OpenAI 的模型和 Microsoft Azure 的计算能力被恶意用户用于策划钓鱼攻击,那么 OpenAI 和 Microsoft 就应承担合理的尽职调查责任,了解其客户及其对产品的预期使用。一个更具战术性的方法可以是限制用户可用的功能集,直到他们得到验证。这与金融机构必须遵守的 KYC(了解你的客户)要求没有太大区别。
为模型训练中使用的数据制定版权法规,披露训练数据集,并允许内容拥有者选择退出:数据抓取是当前内容拥有者的主要问题。AI 提供商在未获得内容拥有者同意和合理补偿的情况下使用抓取的数据来构建商业分发模型。如果法院裁定这不构成版权侵权,那么这是一个明确的信号,表明需要新的法规来规定内容拥有者的权利,以维持一个繁荣的内容生态系统。对这一点的理所当然的扩展是强制要求模型提供商披露训练数据。
另一个相关机制是允许内容拥有者选择退出其数据用于模型训练,并且做到这一点而不附带掠夺性的“捆绑”条件。例如,谷歌不能说如果你不提供数据用于训练,我们就不会在搜索中索引你。像 OpenAI 这样的公司在内容拥有者面前的筹码较少,但你可以想象像微软、亚马逊这样的更大企业具有更广泛的产品组合,能够迫使人们交出他们的数据。
对用户数据的完全控制: 一些具体机制可以减轻 AI 带来的用户隐私风险。首先,模型提供者应被迫删除训练中的个人信息。需要明确什么构成个人信息(例如,名人的 Wikipedia 页面信息不是个人信息,但 ZoomInfo 数据库中的电子邮件和电话号码是)。其次,应禁止公司将消费者功能与用户愿意提供数据用于模型训练挂钩(例如,OpenAI 不能说不提供访问聊天记录,除非用户将所有数据交给他们进行训练)。这里有明确的先例——苹果的应用跟踪透明度框架(我承认这不是监管)禁止应用将功能隐藏在跟踪选择墙后,而欧盟的广告监管禁止平台将功能隐藏在选择墙后用于行为广告。
内容水印/来源: 随着 AI 生成内容的激增,无论是文本还是图像/视频,能够区分 AI 生成的内容,尤其是当这些内容虚假或误导时,变得越来越重要。需要某种框架来定义什么情况应要求披露 AI 内容。例如,如果你使用 ChatGPT 写了一封销售推广邮件,这似乎无害,不应要求披露。但如果你在 Twitter 上分享政治内容且有大量关注者,则应要求披露。好的监管在这里应少一些对实际解决方案的规定,而应制定一个框架供公司参考,由自由市场找出实际解决方案(例如,一个初创公司可以出现,用于检测 Twitter 上的 AI 生成的政治内容,然后 Twitter 可以与其合作)。
结论
总的来说,我对当前围绕这一主题的早期对话感到鼓舞,这与过去那些将监管视为事后思考的技术不同。AI 具有巨大的优势和风险——一种深思熟虑、基于机制的监管方法可以帮助减轻 AI 的风险,同时确保市场竞争存在,以充分发挥这一技术的优势。
🚀 如果你喜欢这篇文章,请考虑订阅 我的每周通讯 Unpacked。 每周,我会发布一个 关于当前技术主题/产品策略的深度分析,以 10 分钟阅读的形式呈现。祝好,Viggy。
[## Unpacked | Viggy Balagopalakrishnan | Substack
每周将一个技术主题/产品策略的深度分析发送到你的收件箱。点击阅读 Viggy 的 Unpacked……
监管生成式 AI
大型语言模型(LLM)在多大程度上遵守欧盟 AI 法案?
·
关注 发表在 Towards Data Science ·7 分钟阅读·2023 年 8 月 8 日
--
图片由作者提供 — David. E. Sweenor
随着生成式人工智能(AI)成为焦点,对于监管这一技术的呼声日益增高,因为它可能会迅速对大量人群产生负面影响。其影响可能表现为歧视、延续刻板印象、侵犯隐私、负面偏见,并削弱基本人类价值观。
2023 年 6 月,美国政府发布了一套自愿性 AI 指导方针,几家知名公司同意遵循,包括 Anthropic、Meta(Facebook)、Google、Amazon、OpenAI 和 Microsoft 等。[1] 这对美国来说是一个重要步骤,但不幸的是,美国在 AI 法规方面一直落后于欧盟。在我之前的文章生成 AI 伦理:在自主内容时代的关键考量中,我探讨了欧盟的 AI 伦理框架,并提供了一些在使用大型语言模型(LLM)时实施该框架的考虑。这篇博客专注于欧盟 AI 法案草案及 LLM 如何遵守该草案法规。
欧盟 AI 法案
2023 年 6 月,欧盟通过了世界上第一部 AI 草案法规。在 2019 年批准的 AI 伦理框架基础上,欧盟的优先任务是确保在欧盟使用的 AI 系统“安全、透明、可追溯、非歧视性和环境友好”。[2] 为了避免不利后果,欧盟框架坚持要求人类参与 AI 系统。换句话说,公司不能简单地让 AI 和自动化运行自己。
提议的法律将 AI 根据其对人类可能构成的风险分为三个不同的类别 — 每个风险级别都需要不同程度的监管。如果这一计划获得批准,它将成为世界上第一个 AI 法规。欧盟识别的三个风险层级是:不可接受的风险、高风险和有限风险。
-
不可接受的风险:使用可能对人类有害并构成威胁的技术将被禁止。此类例子可能包括对个体或某些弱势群体的认知影响;基于社会地位对人们进行排名;以及大规模使用面部识别进行实时监视和远程身份识别。现在,我们都知道世界各国的军事力量正在专注于自主武器,但我在这里岔开了话题。
-
高风险:欧盟对可能对安全或基本权利和自由产生有害影响的 AI 系统进行了分类。第一类是嵌入在零售产品中的 AI,目前属于欧盟的产品安全法规范范围。这包括玩具、飞机、汽车、医疗设备、电梯等。第二类需要在欧盟数据库中注册。这包括生物识别技术、关键基础设施运营、培训与教育、与就业相关的活动、警务、边境控制以及法律分析等技术。
-
有限风险:至少,低风险系统必须符合透明度和开放性的标准,以便让人们有机会做出明智的决策。欧盟规定,用户在与 AI 互动时应被通知。他们还要求模型的创建方式应避免生成非法内容,并要求模型制作者披露训练中使用的(如果有的话)版权材料。
欧盟 AI 法案接下来需要在成员国之间进行谈判,以便他们能够对法律的最终形式进行投票。欧盟计划在年底(2023)完成批准。
现在,让我们来看看当前 LLM 如何遵守草案法案。
LLM 对草案 EU AI 法案的合规性
斯坦福大学的基础模型研究中心(CRFM)和以人为本的人工智能研究所(HAI)最近发布了一篇题为《基础模型是否符合草案 EU AI 法案》的论文。他们从该法案中提取了二十二项要求,对其进行了分类,并为其中十二项要求创建了一个 5 分制的评分标准。所有研究内容,包括标准、评分标准和分数,都可以在GitHub 上 MIT 许可证下获取。
研究团队将立法要求映射到表 1.1 中所示的类别中。需要注意的是,团队仅评估了二十二项要求中的十二项。最终,团队选择了基于公开数据和模型制作者提供的文档中最易评估的十二项要求。
表 1.1:LLM 合规性表总结
来源:斯坦福 CRFM 和 HAI 评分卡。使用 MIT 许可证。
对于那些可能不太了解的人,斯坦福团队还 painstakingly 编制了超过一百个 LLM 数据集、模型和应用程序,这些内容可以在他们的生态系统图上找到。为了便于管理,研究人员分析了“10 个基础模型提供者及其旗舰基础模型,与 12 项法案要求进行对比,基于他们的[评分标准]”。
研究人员审视了来自 OpenAI、Anthropic、Google、Meta、Stability.ai 等公司的模型。根据他们的分析,研究结果得出了以下评分卡。
图 1.2:基础模型提供者对草案 EU AI 法案的合规性评分
来源:斯坦福 CRFM 和 HAI 评分卡。使用 MIT 许可证。
总体而言,研究人员指出,各个提供商之间模型合规性存在相当大的差异(这些仅是二十二个要求中的十二个),其中“一些提供商得分低于 25%(AI21 Labs、Aleph Alpha、Anthropic),目前只有一个提供商得分至少为 75%(Hugging Face/BigScience)。”[4]
我鼓励你阅读完整的研究报告,但研究人员指出,各个提供商之间还有相当大的改进空间。他们还识别出几个关键的“持续挑战”,包括:
● 模糊的版权问题:大多数基础模型的训练数据来源于互联网,其中很大一部分可能受到版权保护。然而,大多数提供商并未明确训练数据的版权状态。使用和复制受版权保护的数据的法律影响,特别是在考虑许可条款时,并不明确,目前在美国正在积极诉讼中(参见 华盛顿邮报 — AI 从他们的工作中学习了,现在他们想要赔偿 路透社 — 美国法官发现艺术家对 AI 公司诉讼存在缺陷)我们还需要观察这件事如何发展。
● 缺乏风险缓解披露:如引言所述,AI 有可能迅速对许多人产生负面影响,因此了解 LLM 的风险至关重要。然而,几乎所有的基础模型提供商都忽视了草案立法中识别的风险披露。虽然许多提供商列出了风险,但很少有提供商详细说明他们采取了哪些措施来缓解识别出的风险。尽管不是生成性 AI 案例,但最近对美国健康保险公司 Cigna Healthcare 提起了诉讼,指控他们使用 AI 来拒绝支付(Axios — AI 诉讼扩展到健康领域)。比尔·盖茨写了一篇题为 AI 的风险是真实的但可以管理的 的好文章,推荐你阅读。
● 评估和审计不足:在评估基础模型的性能,特别是在潜在误用或模型鲁棒性等领域,缺乏一致的基准。美国的《芯片与科学法案》要求国家标准与技术研究院(NIST)为 AI 模型创建标准化评估。我最近讨论的 GenAIOps 框架 就专注于评估和监控模型。最终,我们将看到 GenAIOps、DataOps 和 DevOps 在一个共同的框架下汇聚,但我们还需一段时间才能实现。
● 不一致的能源消耗报告:我认为我们许多人都经历过最近全球范围内的热浪。对于大型语言模型(LLMs)来说,基础模型提供商在报告能源使用和相关排放方面差异很大。事实上,研究人员引用了其他研究,表明我们甚至不知道如何测量和计算能源使用。Nnlabs.org 报告了以下内容:“根据 OpenAI 的数据,具有 15 亿个参数的 GPT-2 需要 355 年的单处理器计算时间,并消耗了 28,000 千瓦时的能源进行训练。相比之下,具有 1750 亿个参数的 GPT-3 需要 355 年的单处理器计算时间,并消耗了 284,000 千瓦时的能源进行训练,是 GPT-2 的 10 倍。具有 3.4 亿个参数的 BERT 需要 64 个 TPUs 上训练 4 天,并消耗了 1,536 千瓦时的能源。”[5]
除了上述问题之外,实施生成性人工智能时还有许多其他问题需要解决。
总结
根据研究,生成性人工智能技术的提供者和采纳者仍面临很长的道路。立法者、系统设计者、政府和组织需要共同努力,解决这些重要问题。作为起点,我们可以确保在设计、实施和使用人工智能系统时保持透明。对于受监管的行业,这可能是一个挑战,因为大型语言模型(LLMs)往往拥有数十亿个参数。数十亿!在如此众多的因素下,如何确保系统的可解释性和透明度?这些系统需要具备清晰、明确的文档,并且我们需要尊重知识产权。为了支持环境、社会和公司治理(ESG),我们还需要设计一个标准化的能源消耗测量和报告框架。最重要的是,人工智能系统需要安全、尊重隐私,并维护人类价值观。我们需要采取以人为本的人工智能方法。
如果你想了解更多关于人工智能的内容,可以查看我在亚马逊上的书籍 《人工智能:使 AI 为你的业务服务的高管指南》 或 Google Play 上的 AI 解说有声书。
[1] Shear, Michael D., Cecilia Kang, 和 David E. Sanger. 2023 年。“在拜登的压力下,人工智能公司同意对新工具设立保护措施。”《纽约时报》,2023 年 7 月 21 日,U.S. 部分 www.nytimes.com/2023/07/21/us/politics/ai-regulation-biden.html
。
[2] 欧洲议会. 2023 年。“欧盟 AI 法案:首个人工智能法规 | 新闻 | 欧洲议会。” www.europarl.europa.eu. 2023 年 8 月 6 日。www.europarl.europa.eu/news/en/headlines/society/20230601STO93804/eu-ai-act-first-regulation-on-artificial-intelligence
。
[3] Bommasani, Rishi, Kevin Klyman, Daniel Zhang, 和 Percy Liang. 2023 年。“斯坦福 CRFM。” Crfm.stanford.edu. 2023 年 6 月 15 日。crfm.stanford.edu/2023/06/15/eu-ai-act.html
。
[4] Bommasani, Rishi, Kevin Klyman, Daniel Zhang, 和 Percy Liang. 2023 年。“斯坦福 CRFM。” Crfm.stanford.edu. 2023 年 6 月 15 日。crfm.stanford.edu/2023/06/15/eu-ai-act.html
。
[5] ai. 2023 年。“训练现代大语言模型的电力需求。” Nnlabs.org. 2023 年 3 月 5 日。www.nnlabs.org/power-requirements-of-large-language-models/
。
强化学习:价值迭代的简单介绍
学习强化学习(RL)的基础知识以及如何将价值迭代应用于一个简单的示例问题。
·
关注 发表在 Towards Data Science · 15 分钟阅读 · 2023 年 9 月 10 日
--
价值迭代(VI)通常是强化学习(RL)学习路径中首先引入的算法之一。该算法的基本细节介绍了 RL 的一些最基本的方面,因此,在进阶到更复杂的 RL 算法之前,掌握 VI 是很重要的。然而,这可能有些难以理解。
本文旨在成为一个易于理解的价值迭代(VI)介绍,假设读者对强化学习领域是新的。让我们开始吧。
已经了解了 RL 的基础知识?→ 跳到如何使用价值迭代。
强化学习的基础
让我们从教科书定义开始,然后用一个简单的例子来解释。
强化学习训练过程概述
强化学习(RL)是除了监督学习和无监督学习之外的三大主要机器学习范式之一。RL 不是单一的算法,而是一个框架,涵盖了一系列技术和方法,用于教导代理在其环境中学习和做出决策。
在 RL 中,代理通过采取各种动作与环境互动。当这些动作导致期望的状态时,代理会获得奖励;当动作未达到期望的状态时,代理会受到惩罚。代理的目标是学习一个策略,称为策略(policy),该策略指导其行动以最大化其随着时间积累的奖励。这个试错过程精炼了代理的行为策略,使其在环境中采取最优或接近最优的行为。
理查德·S·萨顿和安德鲁·G·巴托的《强化学习导论》一书被认为是该领域的最佳读物之一,对于那些希望深入理解 RL 的人来说..而且它是 免费提供的!
让我们定义一个示例问题
高尔夫游戏中的可能状态
这张图片以最简单的形式描绘了高尔夫游戏。我们将使用这个例子,因为高尔夫有一个明确的目标——把球打进洞里。
在我们的示例中,高尔夫球可以处于三种位置之一:在球道上;在果岭上;或在洞里。我们从球道开始,目标是每次击球都靠近洞,洞位于果岭上。
在强化学习(RL)中,这些位置中的每一个被称为状态或环境状态。你可以把状态看作是当前环境(高尔夫球场)的快照,同时记录球的位置。
高尔夫球游戏中的可能动作
在我们的游戏中,代理进行击球动作,开始时处于球在球道上状态。代理仅指控制进行动作的实体。我们的游戏有三种可用的动作:击球到球道;击球到果岭;以及击球进洞。
高尔夫游戏中的转移概率
当然,当你击球时,球可能不会落在你想要的位置。因此,我们引入了一个转移函数,通过一些概率权重将动作与状态关联起来。
例如,当我们从球道上击球时,可能会偏离果岭,仍然停留在球道上。用 RL 的术语来说,如果我们处于球在球道上状态,并采取击球到果岭的动作,则有 90%的概率进入球在果岭上状态,但也有 10%的概率重新进入球在球道上状态。
将球击入洞中的奖励为 10
每当代理执行一个动作时,我们称之为在环境中的 步。根据刚刚采取的动作,代理观察到它所处的新状态以及一个 奖励。奖励函数是一个激励机制,用于引导代理朝正确方向前进。换句话说,我们设计奖励函数以塑造代理的期望行为。在我们简化的高尔夫示例中,我们为将球击入洞中的行为提供 10 的奖励。
环境动态(转换和奖励函数)的设计不是一项简单的任务。如果环境未能代表你试图解决的问题,代理将学习到一个反映不正确问题的正确解决方案的策略。我们在这里不讨论这些设计元素,但值得注意。
马尔可夫决策过程
为了以代理能够理解的方式表示问题,我们必须将其形式化为马尔可夫决策过程(MDP)。
MDP 是一种数学模型,它以结构化的方式描述了我们的问题。它将代理与环境的互动表示为一个顺序决策过程(即,一个接一个的动作)。
它由环境动态组成(我将添加一些数学符号以简化说明):
-
一个有限的状态集合 s ∈ S。
-
一个有限的动作集合 a ∈ A。
-
一个转换函数 T(s′∣s,a)** 返回在当前状态 s 和当前动作 a 下达到状态 s′ 的概率。
-
一个奖励函数 R(s,a,s′)** 基于从状态 s 到达下一个状态 s′ 的情况返回一个标量奖励,并考虑采取动作 a。
注意,如果状态之间的转换涉及某些不确定性或随机性(即在相同状态下两次采取相同动作可能导致不同结果),我们称之为 随机 MDP。我们也可以创建一个 确定性 MDP,其中转换和奖励是完全可预测的。这意味着当代理在特定状态下采取一个动作时,动作与结果状态及奖励之间是一一对应的。
以 MDP 形式可视化,我们的高尔夫问题看起来与之前描述的图像几乎相同。我们将使用 S = {s1, s2, s3} 作为简写。
我们的高尔夫示例问题以 MDP 形式呈现
MDP 的使用假设环境中接下来会发生的事情只依赖于现在的状态和动作,而不依赖于之前发生的事情。这被称为 马尔可夫性质 ,它在强化学习中很重要,因为它降低了计算复杂性。我稍后会详细解释这一点。
什么是价值迭代?
值迭代(VI)是一种用于解决类似于上述高尔夫球例子中的强化学习(RL)问题的算法,其中我们对 MDP 的所有组件有完全的了解。它通过迭代改进对每个状态“价值”的估计来实现。它通过考虑不同可用动作时的即时奖励和期望的未来奖励来完成这项工作。这些值通过一个值表进行跟踪,该表在每一步都会更新。最终,这一系列的改进将会收敛,产生一个状态 → 动作映射的最优策略,使代理可以在给定环境中做出最佳决策。
VI 利用 动态规划 的概念,其中将一个大问题分解成较小的子问题来解决。为了在 VI 中实现这一点,使用 贝尔曼方程 来指导迭代更新每个状态的价值估计的过程,提供一个递归关系,用于表示状态的价值与其邻近状态的价值之间的关系。
现在这可能不太容易理解。学习 VI 的最简单方法是逐步分解,我们就这样做吧。
值迭代算法是如何工作的?
下图展示了算法的步骤。不要被吓到,它比看起来更简单。
值迭代算法
首先,我们需要为我们的训练定义一些参数。
-
Theta θ 代表收敛的阈值。一旦我们达到 θ,就可以终止训练循环并生成策略。它本质上只是一种确保我们创建的策略足够准确的方法。如果我们过早停止训练,可能无法学习到最佳的行动。
-
Gamma γ 代表 折扣因子。这是一个值,决定了我们的代理对未来奖励的重视程度与对即时奖励的重视程度相比。较高的折扣因子(接近 1)表示代理更重视长期奖励,而较低的折扣因子(接近 0)则更强调即时奖励。
为了更好地理解折扣因子,考虑一个玩国际象棋的 RL 代理。假设你有机会在下一步捕获对手的皇后,这将带来显著的即时奖励。然而,你还注意到,通过现在牺牲一个不太重要的棋子,你可以为未来的优势铺平道路,可能会导致将死并获得更大的奖励。折扣因子帮助你平衡这个决定。
(1) 初始化: 现在我们已经定义了参数,我们想要初始化我们所有状态的 价值函数 V(s)。这通常意味着我们将每个状态的所有值设置为 0(或其他任意常数)。可以把价值函数想象成一个表格,跟踪每个状态的值,并频繁更新。
初始化的值表
(2) 外层循环: 现在一切都准备好了,我们可以开始更新值的迭代过程。我们从外层循环开始,它会重复直到收敛标准满足(直到 Δ < θ)。
在每次外层循环中,我们首先设置Δ = 0。Δ 用来表示所有状态的价值估计变化,算法继续迭代,直到这一变化Δ低于指定的阈值θ。
(3) 内层循环: 对于每个状态 s 在 S 中,我们:
-
将变量 v 设置为当前状态的价值 V(s),记住 - 这是从我们的价值表中获取的(所以在第一次遍历时,v = V(s) = 0)
-
执行贝尔曼方程来更新 V(s)
-
更新 Δ(我们稍后会回来讨论)
贝尔曼方程
算法的这一行是最重要的。它要求我们更新当前循环中所看的状态的价值。这个价值是通过考虑从那个特定状态所有可用的动作(前瞻 1 步)来计算的。当我们采取这些可能的动作时,它会给我们一组可能的下一个状态 s′ 和相应的奖励 r.
所以,对于每一个下一个状态 s′ 和相应的奖励 r,我们执行 p(s′, r|s, a)[r + γV(s′)]. 让我们分解一下:
-
p(s′, r|s, a) 在状态 s 中,采取动作 a 并最终到达下一个状态 s′ 的概率(这只是我们的转移函数)
-
[r + γV(s′)] 下一个状态 s′ 的奖励 r(我们从奖励函数中得到这个值) + 我们的折扣 γ * 乘以下一个状态的价值 s′(我们从价值表中得到这个值)
-
然后我们将这两个部分 p(s′, r|s, a) * [r + γV(s′)] 相乘
记住,这个计算只是针对一个下一个状态 s′(树的第三层),我们需要对每一个可能的下一个状态 s′ 在采取 a. 后重复这个过程。
一旦完成这一步,我们将所有刚获得的结果 Σₛ′, ᵣ p(s′, r|s, a) * [r + γV(s′)]. 汇总起来。然后我们对每个动作 a 进行重复(树中的第二层)。
一旦完成这些步骤,我们将为当前状态 s 在内层循环中看到的每个可能的动作 a 关联一个值。我们使用 maxₐ 选择最高的,并将其设置为该状态的新值 V(s)←maxₐ Σₛ′, ᵣ p(s′, r|s, a) * [r + γV(s′)].
记住,这个过程只涵盖了一个状态 s(树的第一层)
如果我们在编程,这将是树中每一层的 3 for 循环:
(3) 内循环(继续): 在移动到内循环的下一次迭代之前,我们对当前Δ的值和该状态v的前一个值与刚计算的状态新值V(s)之间的差异进行比较。我们将Δ更新为这两者中的较大者:Δ ← max(Δ,| v - V(s)|)。这有助于我们跟踪离收敛的距离。
好的,这个过程完成了一次内循环的操作。 我们对 S 中的每个 s 执行第(3)步,然后再退出内循环并对收敛条件Δ < θ进行检查。如果满足此条件,我们退出外循环,否则返回第(2)步。
(4) 策略提取: 到目前为止,我们可能已经进行了多次外循环,直到收敛。这意味着我们的值表将被更新为表示每个状态的最终值(换句话说,‘处于每个状态的好处’)。我们现在可以从中提取策略。
记住,策略π本质上是状态 → 动作的映射,对于每个状态,它选择最大化期望回报的动作。要计算这一点,我们执行与之前完全相同的过程,但不是使用maxₐ获得状态s的值,而是使用argmaxₐ获得给我们最佳值的动作a。
就这样!
策略迭代 是另一种动态规划算法。它类似于 VI,只是它在通过使策略相对于当前值函数变得贪婪来改进策略和评估策略的表现直到收敛之间交替进行,通常需要较少的迭代但每次迭代需要更多计算。
使用值迭代解决示例
一旦完成示例问题,VI 应该更有意义,所以让我们回到我们的高尔夫 MDP。我们已将其形式化为 MDP,但目前,代理不知道打高尔夫的最佳策略,因此让我们使用 VI 解决高尔夫 MDP。
我们将开始使用相当标准的值来定义我们的模型参数:
γ = 0.9 // discount factor
θ = 0.01 // convergence threshold
然后我们将对S中的状态初始化我们的值表为 0:
// value table
V(s0) = 0
V(s1) = 0
V(s2) = 0
我们现在可以开始外循环:
Δ = 0
对于S中的每个状态,进行三次内循环:
// Bellman update rule
// V(s) ← maxₐ Σₛ′, ᵣ p(s′, r|s, a) * [r + γV(s′)]
//******************* state s0 *******************//
v = 0
// we have only looked at one action here as only one is available from s0
// we know that the others are not possible and would therefore sum to 0
V(s0) = max[T(s0 | s0, hit to green) * (R(s0, hit to green, s0) + γ * V(s0)) +
T(s1 | s0, hit to green) * (R(s0, hit to green, s1) + γ * V(s1))]
V(s0) = max[0.1 * (0 + 0.9 * 0) +
0.9 * (0 + 0.9 * 0)]
V(s0) = max[0] = 0
// Delta update rule
// Δ ← max(Δ,| v - V(s)|)
Δ = max[Δ, |v - v(s0)|] = max[0, |0 - 0|] = 0
//******************* state s1 *******************//
v = 0
// there are 2 available actions here
V(s1) = max[T(s0 | s1, hit to fairway) * (R(s1, hit to fairway, s0) + γ * V(s0)) +
T(s1 | s1, hit to fairway) * (R(s1, hit to fairway, s1) + γ * V(s1)),
T(s1 | s1, hit in hole) * (R(s1, hit in hole, s1) + γ * V(s1)) +
T(s2 | s1, hit in hole) * (R(s1, hit in hole, s2) + γ * V(s2))]
V(s1) = max[0.9 * (0 + 0.9 * 0) +
0.1 * (0 + 0.9 * 0),
0.1 * (0 + 0.9 * 0) +
0.9 * (10 + 0.9 * 0)]
V(s1) = max[0, 9] = 9
Δ = max[Δ, |v - v(s1)|] = max[0, |0 - 9|] = 9
//******************* state s2 *******************//
// terminal state with no actions
这为我们的值表提供了以下更新:
V(s0) = 0
V(s1) = 9
V(s2) = 0
我们不需要担心 s2 ,因为这是一个终止状态,意味着这里不可能进行任何操作。
我们现在分离内循环,继续外循环,对以下内容进行收敛检查:
Δ < θ = 9 < 0.01 = **False**
由于尚未收敛,我们进行外循环的第二次迭代:
Δ = 0
另进行 3 次内循环,使用更新后的值表:
//******************* state s0 *******************//
v = 0
V(s0) = max[T(s0 | s0, hit to green) * (R(s0, hit to green, s0) + γ * V(s0)) +
T(s1 | s0, hit to green) * (R(s0, hit to green, s1) + γ * V(s1))]
V(s0) = max[0.1 * (0 + 0.9 * 0) +
0.9 * (0 + 0.9 * 9)]
V(s0) = max[7.29] = 7.29
Δ = max[Δ, |v - v(s0)|] = max[0, |0 - 7.29|] = 7.29
//******************* state s1 *******************//
v = 9
V(s1) = max[T(s0 | s1, hit to fairway) * (R(s1, hit to fairway, s0) + γ * V(s0)) +
T(s1 | s1, hit to fairway) * (R(s1, hit to fairway, s1) + γ * V(s1)),
T(s1 | s1, hit in hole) * (R(s1, hit in hole, s1) + γ * V(s1)) +
T(s2 | s1, hit in hole) * (R(s1, hit in hole, s2) + γ * V(s2))]
V(s1) = max[0.9 * (0 + 0.9 * 7.29) +
0.1 * (0 + 0.9 * 9),
0.1 * (0 + 0.9 * 9) +
0.9 * (10 + 0.9 * 0)]
V(s1) = max(6.7149, 9.81) = 9.81
Δ = max[Δ, |v - v(s1)|] = max[7.29, |9 - 9.81|] = 7.29
//******************* state s2 *******************//
// terminal state with no actions
在第二次迭代结束时,我们的值是:
V(s0) = 7.29
V(s1) = 9.81
V(s2) = 0
再次检查收敛:
Δ < θ = 7.29 < 0.01 = **False**
仍未收敛,所以我们继续上述过程,直到 Δ < θ。我不会展示所有计算,以上两个步骤足以理解过程。
经过 6 次迭代,我们的策略已经收敛。这是我们的价值和收敛率在每次迭代中的变化情况:
Iteration V(s0) V(s1) V(s2) Δ Converged
1 0 9 0 9 False
2 7.29 9.81 0 7.29 False
3 8.6022 9.8829 0 1.3122 False
4 8.779447 9.889461 0 0.177247 False
5 8.80061364 9.89005149 0 0.02116664 False
6 8.8029969345 9.8901046341 0 0.0023832945 True
现在我们可以提取我们的策略:
// Policy extraction rule
// π(s) = argmaxₐ Σₛ′, ᵣ p(s′, r|s, a) * [r + γV(s′)]
//******************* state s0 *******************//
// we know there is only one possible action from s0, but let's just do it anyway
π(s0) = argmax[T(s0 | s0, hit to green) * (R(s0, hit to green, s0) + γ * V(s0)) +
T(s1 | s0, hit to green) * (R(s0, hit to green, s1) + γ * V(s1))
π(s0) = argmax[0.1 * (0 + 0.9 * 8.8029969345) +
0.9 * (0 + 0.9 * 9.8901046341)]
π(s0) = argmax[8.80325447773]
π(s0) = hit to green
//******************* state s1 *******************//
π(s1) = argmax[T(s0 | s1, hit to fairway) * (R(s1, hit to fairway, s0) + γ * V(s0)) +
T(s1 | s1, hit to fairway) * (R(s1, hit to fairway, s1) + γ * V(s1)),
T(s1 | s1, hit in hole) * (R(s1, hit in hole, s1) + γ * V(s1)) +
T(s2 | s1, hit in hole) * (R(s1, hit in hole, s2) + γ * V(s2))]
V(s1) = max[0.9 * (0 + 0.9 * 8.8029969345) +
0.1 * (0 + 0.9 * 9.8901046341),
0.1 * (0 + 0.9 * 9.8901046341) +
0.9 * (10 + 0.9 * 0)]
π(s1) = argmax[8.02053693401, 9.89010941707]
π(s1) = hit in hole
我们的最终策略是:
π(s0) = hit to green
π(s1) = hit in hole
π(s2) = terminal state (no action)
所以,当我们的智能体处于 球在球道上 状态(s0)时,最佳行动是 击球到果岭。这似乎很明显,因为这是唯一可用的行动。然而,在 s1 中,那里有两个可能的行动,我们的策略已经学会了 打入洞中。我们现在可以将这个学到的策略提供给其他想打高尔夫的智能体!
就是这样!我们刚刚使用价值迭代解决了一个非常简单的 RL 问题。
动态编程的局限性
需要注意的是,价值迭代以及其他动态编程算法都有其局限性。首先,它假设我们对 MDP 的动态有完全的了解(我们称之为基于模型的 RL)。然而,在实际问题中这种情况很少见,例如,我们可能不知道转移概率。对于这种情况,我们需要使用其他方法,如Q-learning(无模型 RL)。
国际象棋中的维度诅咒
其次,对于更大的问题,随着状态和动作数量的增加,价值表的大小会呈指数增长(想象一下尝试定义所有可能的国际象棋状态)。这导致了‘维度诅咒’问题,其中计算和内存需求迅速增加,使得将 DP 应用于高维问题变得具有挑战性。
然而,VI 学习是很棒的,因为它介绍了一些 RL 的关键基础概念,这些概念构成了你可能会继续学习的更复杂算法的基础。
感谢阅读!
我希望这篇文章提供了一个易于理解的强化学习介绍,特别是价值迭代。
如果你在这里学到了新的知识,请点赞👏并关注!
除非另有说明,所有图片均由作者创作。
强化学习基础:理解马尔科夫决策过程背后的随机理论
第一部分:马尔科夫决策模型的理论基础,这为强化学习问题提供了基础
·
关注 发表在 Towards Data Science ·28 分钟阅读·2023 年 2 月 16 日
--
一个简单的 MDP 示例,包括三个状态(绿色圆圈)和两个动作(橙色圆圈),以及两个奖励(图片来源:维基百科)
强化学习(RL)是一种机器学习类型,使代理能够通过采取行动在不确定的环境中学习实现目标。强化学习的一个重要方面是它评估所采取的行动,而不是通过提供正确的行动来指导。每个行动都有一个相关的奖励,向代理发出该行动在朝着目标前进中成功的信号。代理在环境中重复导航,学习如何优化以达到其目标。
直到最近,大多数强化学习的成功与游戏代理相关,如 Alpha Go。然而,强化学习越来越多地应用于现实世界中的应用,如自动驾驶汽车和机器人自动化。最近推出的 ChatGPT 也在其架构中包含了一个用于答案微调的 RL 组件。鉴于此,理解 RL 变得重要,因为它被认为具有陡峭的学习曲线,既包括理论方面,也包括实际应用。
本文重点讨论了马尔可夫决策模型(MDP),它通常用于形式化强化学习代理的问题。然而,这种对强化学习问题的概念化一个被忽视的方面是 MDP 框架是随机理论的重要应用。
基本的马尔可夫决策过程方程背后有相当复杂的随机和概率理论。没有理解这些方程及其推导过程,就很难在 RL 领域取得进展。本文旨在澄清 MDP 背后的随机和概率概念,重点理解 Sutton(2017)第三章中呈现的方程。
基本的 MDP 框架如 Sutton 和 Barto(2017),第三章中所呈现的,实际上看起来相当简单和直观——这是一个大师级的巧妙手法!然而,这个框架背后有大量的统计理论。我们不需要了解随机过程的深度,因为,相信我,真的很多!但了解一些 MDP 中使用的概念背后的细微之处是很重要的。因此,理解什么是随机过程很重要,因为这些理论是 MDP 的核心。
本文是一个系列文章中的第一篇,旨在使强化学习(RL)的基础知识在理论和实际应用层面上更加清晰。
目标受众是谁?
- 当然,这看起来像是为初学者准备的,但我认为这篇文章实际上对那些有一定 RL 知识的人更有用,他们可以更直观地理解 Sutton (2017) 第三章中呈现的方程。这篇文章是我学习 RL 过程的结果。我最初很高兴地开始阅读 Sutton 和 Barto (2017),并认为我已经理解了第三章中大部分关于 MDP 框架的概念。然而,当我尝试实际推导第三章中呈现的一些方程时,我发现自己无法弄明白。因此,我意识到可能有更多的内容超出了我的认识范围。我找不到真正的资源将 MDP 关联到其基础统计学支撑上。我不得不使用各种材料,从 Wikipedia 到各大学级的随机理论课程。鉴于这些基础材料的分散,我决定在这篇文章中记录下来。
方法 — 解码 Sutton 和 Barto (2017)
-
另一件事是这篇文章的格式。Sutton (2017) 代表了 RL 理论的圣经。目前的版本是在线提供的更新版。然而,大部分理论材料实际上是在 20 年前的书籍第一版中发表的。此外,书中第三章所采用的方法是互联网中常见的方法。鉴于此,我在这篇文章中的方法有点像一个“钥匙”,我呈现了书中的材料,然后通过随机理论的视角进行解释。
-
稍微岔开一下,为“钥匙”辩护。大量学生,尤其是在学校,广泛使用它们;而教师实际上对这种“填鸭式教学”表示不满。然而,当你在网上学习,通常是独自一人时,详细的指南变得重要,因为你不能像在课堂上那样举手提问。
-
够多的前言!让我们转到实际话题….
MDP 的作用
-
理解强化学习问题在理论 MDP 环境中如何工作,对于掌握 RL 的基本原理至关重要。许多重要的 RL 算法要么基于马尔可夫方程或模型,要么代表了对马尔可夫模型的一些关键偏离。MDP 是强化学习过程的数学理想化形式,并提供了一个理想化的强化学习问题的理论框架。因此,无法回避:理解 MDP 对理解 RL 问题及其解决方法是基础性的。
-
注意,随机理论是概率理论的一个特别复杂的分支,有时所使用的符号可能非常令人畏惧。我将尝试更明确地呈现 MDP 背后的随机概念和思想。希望最终的结果能突显出马尔可夫模型背后的复杂性。
-
在这篇文章中,我将内容分为 3 个部分:
第一部分:MDP 的关键组件
-
RL 问题
-
状态、代理、环境、奖励等。
第二部分:MDP 中使用的概率和随机概念
-
随机变量、状态或样本空间、随机过程、实现或轨迹
-
概率分布、联合分布、边际分布、条件分布
-
推导马尔可夫性质
-
概率转移矩阵
第三部分:使用概率和随机概念理解 Sutton Ch3 方程
-
MDP 模型
-
MDP 模型的形式化表示
-
Sutton (2017) Ch3 中的方程 3.2, 3.3, 3.4, 3.5, 3.6
这可能只是 MDP 的介绍部分。你可能会问:“那价值函数、策略、Q 值函数、贝尔曼方程呢?”这些内容将在下一篇文章中介绍,因为一旦我讲解了上述内容,你会同意这些内容本身就已经很多了。
第一部分:MDP 的关键组成部分
首先,为了设置我们即将解释的内容的背景,我将简要总结Sutton (2017), Ch3使用的基本 MDP 模型和定义。
为什么选择 MDP?
MDP 框架受欢迎的主要原因是由于马尔可夫性质,其中未来状态仅依赖于当前状态,而不依赖于状态的历史。换句话说,当前状态封装了做出关于未来状态决策所需的所有相关信息。顺便说一句,这种看似无害的假设实际上背后有丰富的概率理论,并且也有相应的影响。
由于这个特性,MDP 代表了一种可处理的方式来形式化顺序决策。它提供了一个概率框架,用于建模在结果部分随机和部分由决策者控制的情况下的决策。在这种情况下,行动不仅在奖励方面有直接的回报,还会影响后续时间段或状态,从而影响未来的奖励。当然,强化学习在实践中的方式与 MDP 有所不同。
一些定义以开始
代理:这是学习者或决策者,在环境中做出顺序决策以实现最终目标。
环境:这被定义为代理控制范围之外的所有事物。因此,环境可以是机器人的外部环境,可以是自驾车的道路和行人等。在视频游戏中,环境就是游戏。环境的以下几点是重要的:
-
观察是代理接收到的输入——例如,自驾车只能接收来自其周围环境的输入,而不会意识到前方很远处的道路障碍。
-
只有在知道环境模型的情况下,代理才能完全了解环境,例如掌握游戏规则的结构,这些规则决定了环境。
-
环境会因代理的行动而改变(突然刹车会造成小堵塞),但也可能自行变化(道路改道)。
什么是 RL 问题?
典型的 RL 问题涉及一个学习代理在一段时间内与其环境互动以实现目标。为此,学习代理必须能够在一定程度上感知环境的状态,并且必须能够采取影响状态的行动。这个基本问题在下面的图中展示。
在马尔可夫决策过程中,代理和环境的交互(图像来源:作者创作,灵感来源于 Sutton (2017) 第三章)
时间的测量
时间通常以离散步骤来测量:t=0,1,2,..T。其中,T 代表最终或终端状态。因此,这是一个有限离散时间问题。还有连续时间问题以及无限视界问题,但这些在数学上处理起来更为复杂。因此,重点在于有限离散时间模型。
状态
这代表了环境的一个快照,包含了所有与环境相关的信息。这包括奖励信息、环境如何根据行动发生变化等。以自动驾驶汽车为例,状态可能代表汽车周围的一个区域。类似地,以前进中的机器人为例,它将代表在采取如前进等行动后环境的快照。
对于马尔可夫目的,代理和环境在一系列离散时间步骤中交互:t=0,1,2…T。其中,T 代表最终或终端状态。在每个时间步骤 t,代理接收环境的某种表示,封装为 Sₜ ∈ S,其中 S 是所有状态的集合。在 MDP 的背景下,状态是一个离散随机变量。
奖励
代理在下一时间周期从环境中接收到一个信号或奖励 Rₜ₊₁,这与行动在实现整体目标方面的表现相关。奖励通常定义为一个标量,其中 Rₜ ∈ R ⊂ℝ 只是实数集合的一个例子(复杂的符号表示简单的东西!但在统计世界中就是这样😊)。
奖励可以是确定性的,也可以是随机变量。Sutton(2017)将 Rₜ 视为随机变量,尽管大多数示例采用确定性奖励。
现在,Rₜ 依赖于世界的当前状态、刚刚采取的行动和世界的下一个状态,即 Rₜ = R(Sₜ, Aₜ, Sₜ₊₁)。然而,它通常被简化为仅依赖于当前状态和行动对,即 Rₜ = R(Sₜ, Aₜ)。
现在我们已经有了 MDP 的基本组成部分,即状态、奖励、代理、环境,我们需要更好地理解在概率和随机理论背景下,MDP 变量状态、奖励、代理和环境的性质。
现在,让我们跳到概率和随机理论……
第二部分:概率与随机基础
要定义随机过程,我们首先需要理解什么是随机变量、概率分布、联合分布、边际分布、条件分布和概率链规则。
现在让我们花些时间理解上述概念。一旦理解这些,我们将回到 MDP 并利用这些概念来理解其工作原理。
由于统计和概率非常广泛,我希望专注于这些概念在 MDP 中的适用性。因此,我将以斜体展示与 MDP 概念相关的统计学概念,然后解释其在 MDP 中的相关性。希望这能使 MDP 概念更加清晰。
随机变量
在概率论中,随机变量是捕捉变量可以取多个值或发生的方式。它表示一个映射或函数,捕捉变量可以映射到多个值的事实。例如,掷硬币的结果有:{H,T},我们可以将其映射到介于 0 和 1 之间的数字,表示概率。请参阅这个 维基百科 页面以获取更正式的定义。顺便说一句,我发现维基百科上的概率和统计页面是一个非常好的资源,尽管稍微有些高级。
当随机变量的范围是可数的时,称为离散变量,其分布是离散概率质量函数(这个概念我稍后会定义,因为它将被频繁使用)。如果变量是连续的,那么它的分布可以由概率分布函数定义。
随机变量的定义通常为实值情况定义,记作 R。这个定义也可以扩展到任何可测集合 E,其中包含随机布尔值或分类值、向量、矩阵或函数。
这点很重要,因为在 MDP 中,我们有实际值的随机变量以及分类类型。同时,随机过程(稍后会定义)也是时间的随机函数。
在上述定义的 MDP 背景下,我们有状态和奖励作为随机变量。这两个变量都是离散的,并且有一个有限的范围。
状态空间或样本空间:
当我们谈论随机数时,随机变量可以取的有限状态数称为样本状态或在 MDP 情况下的状态空间。样本空间的最简单例子是掷硬币。它可以取两个值之一:正面或反面。
在 Sutton 第三章中,这由集合S表示,包含不同的状态,每个状态可以代表不同的情况。例如,我们可以有一个用于去办公室的随机过程。S可能包含的状态列表如下:
1. s’: 遇到通勤去办公室的交通堵塞
2. s: 遇到去办公室的清晰道路
3. s”: 遇到因事故造成的道路阻塞
这些状态构成了样本空间或随机变量的范围,即它可以取的值。
随机过程:
随机过程是由变量 t(通常代表时间)索引的随机变量的集合或集成。明确地说,随机过程包括在多个时间点上的相同随机变量。随机过程通常由时间变量索引(它也可以是向量空间,但为了避免混淆,我们暂且不提)。因此,每个索引或时间点都与特定的随机变量相关联。了解更多关于随机过程的信息,请访问这个 维基百科 页面。
这意味着在每个时间点,随机变量可以实现 S.**中的一个值。随机过程也可以写作:S(t,i),t∈T。其中状态是时间(由 t 索引)和代表状态空间S的特定状态 i 的函数。
状态空间是通过反映随机过程可以取的不同值的元素来定义的。因此,在 MDP 的情况下,状态 Sₜ 是一个随时间变化的随机变量,由 t 索引,其中对于每个 t,Sₜ 可以取值于有限集合 S =。
关于随机过程的困惑在于,通常跳过第二个索引以避免混乱的展示。然而,没有它的话,实际有 2 件事情正在发生的清晰度会丧失:
1. 随着变量 t 的演变,随时间推移的运动
2. 在每个时间点上,从状态集合中选择一个状态
让我们继续以道路为例,以说明随机过程随时间演变和不同状态空间之间的区别。在某个时间点上,随机变量可以取有限数量的状态空间。例如,当我们前往办公室时,在前 15 分钟内,我们可能会遇到清晰的道路或堵车。接下来的 15 分钟,情况可能会发生变化。假设到达办公室需要 45 分钟,那么我们可以将其分为 3 个 15 分钟的时间段。因此,我们的时间段可以是 t=1,2,3。
在前 15 分钟内,道路清晰的概率为 p,遇到堵车的概率为 q,遇到道路封闭的概率为 r。在接下来的 15 分钟内,这三种结果中的一种可能再次出现,最后 15 分钟也是如此。这就是随时间演变。这意味着,在每个时间点上,存在获得某个结果的特定概率。可以通过以下方式进行说明。
办公室驱动 MDP(图像来源:作者)
图像显示了一个样本轨迹。在时间 t=1,即前 15 分钟,道路可能是畅通的、拥堵的或封闭的,各有相应的概率。类似地,对于 t=2 或第二个 15 分钟,以此类推。在某一时刻,比如前 15 分钟,我们可以看到道路状态可以随机变化于 3 种可能状态。这代表了在某一时刻状态的随机性。然而,如果我们看时间 t,那么我们就有了一个轨迹。所以,如果我们有 3 个时间点,我们将有 3 个随机变量,每个时间点对应一个。
我们还可以在这种情况下理解奖励集。奖励集R在这种情况下就是在特定时间覆盖特定道路段所需的时间。每种状态都有一个定义的时间分布(为了简化我们可以以离散分钟为单位)。这意味着什么?基本上,道路的每种状态都与时间分布相关。就像有时我们在 15 分钟内完成清晰道路 (s),有时由于一些随机变化在 17 分钟,有时在 18 分钟。因此,我们可以定义 R(s)={15,16,17,18}。对于清晰道路条件下完成路段所需时间有一个概率分布。这是奖励也可能是随机的一个直观示例——这是 Sutton (2017),Ch3 中的大多数示例没有涵盖的内容。道路的每种状态——例如,拥堵状态——将有其自身的 ETA 分布。
随机过程的实现或轨迹
随机过程可以有许多 结果, 由于其随机性。随机过程的单个结果被称为 实现、情节或轨迹。它是通过在随机过程的每个时间点上取随机变量的一个可能值来形成的。
所以,继续我们的交通示例:图像中的箭头显示了一个可能的轨迹。在前 15 分钟,司机遇到了一条畅通的道路,在接下来的 15 分钟遇到了拥堵,而在最后 15 分钟遇到了封闭的道路。所以,一个潜在的样本路径可以是:(s,s’,s’’)。这只是一个可能的实现。司机每天都经历这条路径,并且很可能每天会遇到略微不同的状态集;另一天可能是 (s’’,s’,s)。
随机过程被广泛应用于对系统和现象的概率表示,这些系统和现象表现为随机变化。简单来说,随机过程或模型用于估算随时间变化的随机变量的各种结果的概率。例子包括 细菌 群体的增长,或 气体分子的运动。它们在金融分析中也被广泛使用,其中随机模型可用于估算涉及不确定性的情况,如投资回报、市场波动或通货膨胀率
此外,还有许多不同类型的随机过程,其中 Markov 过程是一种。其他类型包括随机游走、高斯过程等。每种随机过程都有其特定的假设和特征,这些特征很重要。
概率分布
现在我们转向如何获取不同状态的概率?回顾上面的例子,我们已经定义了遇到拥堵和清晰道路的概率分别为 q 和 p。现在,让我们稍微详细说明一下这在概率分布中的含义以及它们是如何产生的。
随机结果(道路拥堵)的概率被评估为在非常长的重复序列中该结果发生的比例。因此,这就是我们如何通过记录多个天数中的道路状态变量在特定时间间隔的状态来评估 p、q、r。
记录随机变量的所有这些输出概率,Sₜ,给我们提供了随机变量的概率分布。在离散情况下,这被称为概率质量分布。 这是一个提供离散随机变量等于特定值的概率的函数。有关更多信息,请参阅 这里。
继续我们办公室通勤的例子:状态变量因此可以以特定的概率取这些值。例如,如果我们为一条道路定义的状态集包含 3 种可能的状态:S= { s = 清晰, s’= 拥堵, s’’= 阻塞}。
在特定时间点(比如早晨),道路可以有这三种状态中的任何一种,且具有相应的概率。例如,它可以是自由 =0.5,拥堵 =0.4,阻塞 =0.1(阻塞相对较少见!)。这些概率的集合被称为该变量的概率分布,即道路的状态。
联合分布 vs 边际分布 vs 条件分布
在随机过程的情况下,由于随机变量相同但在不同时间点测量,因此状态在 t, Sₜ , 可能与之前的状态有关,即 Sₜ₋₁, Sₜ₋₂。在这种情况下,如果我们想了解 Sₜ 的概率,则相关的是联合概率质量分布 (pmf)。联合 pmf 允许我们计算涉及多个随机变量的事件的概率,同时考虑变量之间的关系。
给定两个 随机变量 定义在相同的 概率空间 上, 联合概率分布 是所有可能的输出对的相应 概率分布 。
这可以更一般地在马尔可夫背景下定义为:
假设我们考虑一个离散时间 2 时间段随机过程, Sₙ , 其中 n=1,2。所以,* S₁ 和 S₂ 是两个离散随机变量,可以从集合 i ∈ S = {1,2} 中取值。* Sₙ 的联合分布给出每个 n,即时间段和有限状态序列的:(i₁,i₂) 为:
让我们看看我们的简单道路示例:
为简单起见,假设只有 2 个时间段,即 Sₜ 和 Sₜ₋₁。状态空间仍然是 3:{clear(C),jammed(J),blocked(B)}。
联合概率表如下:
2 时间段办公室驾驶联合概率分布(来源:作者)
联合概率是两个事件同时发生的概率。上表提供了两个事件的不同实例组合。注意,由于这是一个随机过程,时间 t-1 总是先于 t。
边际概率分布
变量的 边际分布 给出了在子集中各种值的概率,而不参考其他变量的值。Sᵢ 的边际质量函数为:
边际概率是单个事件发生的概率。本质上,边际分布是按列的概率总和。这个方程要求你对 j 列进行求和。任何组合的联合概率由单元格概率给出。在我们示例的上下文中,它是 Sₜ 中遇到交通堵塞的概率,而不考虑 Sₜ₋₁ 中的路况,即我们对 J 列的概率进行求和。
条件概率
一个 条件概率 是指在另一个特定事件已经发生的条件下,某一特定事件发生的概率。这意味着在 Sₜ 的情况下,给定道路在 Sₜ₋₁ 时是畅通的情况下发生堵塞的概率。
现在在我们的例子中,由于 Sₜ 是一个随机变量,并且我们考虑的是一个两个时期的例子,所以联合概率和条件概率是相同的。
在数学上,给定另一个变量的条件下的变量分布是两个变量的联合分布除以另一个变量的边际分布。
一般而言,定义为:
将其推广到我们有 n 个随机变量的情况, X₁,.., Xₙ。
这 n 个变量的联合概率分布是:
这可以写成条件函数乘以边际函数。这是基于 概率链式法则 的。在符号表示上, Pₓ (x₁)= P(X=x₁)。* 对于两个随机变量* X₁, X₂ ,其联合分布可以表示为:
用言语表述:两个变量 X₁, X₂ 取特定值 x₁, x₂ 的联合概率分布可以表示为: X₂=x₂ 在给定 X₁=x₁ 的条件下的条件概率,乘以 X₁=x₁ 的边际概率。
我们为什么需要这些分布——推导 Markov 性质的需求
联合、条件和边际概率分布定义在 MDP(马尔可夫决策过程)背景下变得重要,当我们想确定状态变量在不同时间段过渡时取特定值的概率时。我们的例子很简单,由三个时间段组成。但是,如果时间段是 100,我们想了解状态在第 100 个时间段取特定值的概率怎么办?这就是潜在的概率问题。让我们看看一般情况如何解决,然后了解 Markov 性质。
对于更一般的情况,要获得 X₁,.., Xₙ 随机变量取特定值 x₁,.., xₙ 的联合概率,可以通过 Xₙ 给定之前所有随机变量取特定值的条件概率,乘以前 n-1 次分布的联合概率分布来找到。
我们可以将 P(Xₙ₋₁=xₙ₋₁…,X₁=x₁)* 的联合概率表达式再次展开为* Xₙ₋₁ 给定(Xₙ₋₂,…,X₁)的条件概率和(X₁,..,Xₙ₋₂)的联合概率。
这可以进一步向后展开。最终,我们将得到条件概率序列和第一个变量 P(X₁=x₁) 的边际概率的表达式。*
现在让我们将其翻译到 MDP 的背景中。到目前为止,这已经以任何随机变量序列的形式呈现。然而,这个结果很容易转化为随机变量和马尔可夫过程。我们只需将 Xᵢ替换为 Sₜ,现在索引将以时间的形式出现,即 t=1,2,..n,而不是 Xᵢ = xᵢ。
我为什么要重复这些明显的内容?因为有很多材料使用不同的变量和索引来描述马尔可夫过程和其他随机变量序列,这可能会非常混乱。我尝试为自己解开这些复杂的概念,因此写下这些内容,以帮助那些试图弄清楚马尔可夫结果背后的概率理论的非统计学家。
在分析 MDP 时,t=1,..,n,状态 xᵢ ∈ S。为了明确,xₙ是时间段 t=n 时的一个特定状态值。
让我们写出状态变量在 t=1..n 的演变过程中,获得 S₁,..,Sₙ的联合概率。代入上述讨论的一般情况,我们得到的联合概率是条件概率的序列和 S₁的边际分布。
用语言描述就是:获得状态 S₁,..,Sₙ的概率等于在时间 t=n 时状态的值为 xₙ的条件概率,乘以前 n-1 期的联合分布。这展开为一系列条件概率和状态 S₁的边际或初始分布。
现在这只是一个随机过程,其中状态在时间段 n 的值依赖于过去状态值的历史。这显然是难以计算的,因此需要马尔可夫性质。
马尔可夫性质:
马尔可夫性质使得状态 Sₙ只依赖于紧接前一个状态 Sₙ₋₁,而不是整个历史状态的积累。在这种情况下,我们的状态 Sₙ的条件概率减少为:
然而,我们仍然需要估计 n-1 个一步条件概率以及初始概率分布,P(X₁=x₁)。这仍然比较复杂,因为需要计算大量的条件概率。为了简化并使计算更可行,进一步的假设是过渡概率的时间同质性。
过渡概率的时间同质性
这基本上说的是:从时间段 n 的状态 i 过渡到时间段 n+1 的状态 j 的概率是相同的,不受时间段的影响:
这意味着从一个状态过渡到下一个状态的概率是固定的,不受时间的影响。
放到我们的道路示例中,从清晰道路到拥堵道路的概率是固定的,无论我们在前 2 个时间段还是最后 2 个时间段遇到这 2 个状态。因此,转移概率简化为以下 9 种情况,与时间无关。其中 C=清晰,j=拥堵,B=封堵。
如你所见,如果状态的数量少于时间周期,通常这种情况会显著减少需要计算的概率。
转移概率
状态 转移概率矩阵 给出了在单一步骤中从一个状态转移到另一个状态的概率。
对于 t=1,2 状态情况,这将把问题的维度简化为
这分解为以下几个组成部分:
条件概率:
初始分布或边际分布:P(S₁=x₁)*
在我们的道路示例中,这些是 3 个状态的初始概率:p、q、r。
所以,我们所需的仅仅是一步条件概率和状态的初始分布,以获得转移到某个状态的联合概率。
转移矩阵可以表示如下:
来源:维基百科 随机矩阵
在转移矩阵中,行代表当前状态,列代表未来状态。因此 p₁₂ 是从时间 t 的状态 1 转移到时间 t+1 的状态 2 的概率。
另外,请注意从一个状态出发的所有分支概率之和必须为 1。这是逻辑上的,因为你必须从一个状态转移到另一个状态。数学上可以写作:
为什么我们对‘j’进行求和?因为这代表了你可以从‘i’转移到的可能状态。
马尔可夫链动态
对于马尔可夫链,我们使用“动态”一词来描述在给定初始状态下,状态在短时间间隔内的变化。马尔可夫链的演变是一个随机过程,因此我们不能确切地说初始状态之后会跟随什么状态序列。在强化学习中,我们希望根据当前状态或初始状态预测未来状态。一旦我们有了概率转移矩阵和初始分布,可以很容易地计算出给定 S₁ 的未来状态 Sₙ 的预测。n 步转移概率矩阵可以通过将单步概率矩阵自乘 n 次得到。
这就是马尔可夫系统的美妙之处,以及它为何在强化学习中被使用。一旦我们拥有一个可处理的转移概率集,我们就可以计算从任何初始状态到状态 j 在 n 步内的转移概率。
现在,在对 MDP 背后的随机和概率进行如此长时间的离题讨论之后,我将简要回顾 Sutton(2017)第三章中的 MDP 方程。
第三部分:使用概率和随机概念理解 Sutton 第三章的 MDP 方程
MDP 模型
在这一节中,我简要定义了 MDP,给出本质上是教科书中的定义。在下一节中,我们将探讨 MDP 的统计理论。
MDP 框架是一个关于目标导向学习从交互中问题的显著抽象。它提出,无论感知、记忆和控制装置的细节如何,以及目标是什么,任何目标导向行为的学习问题都可以归结为在代理和环境之间来回传递的三个信号:一个信号表示代理所做的选择(行动),一个信号表示选择的依据(状态),一个信号定义代理的目标(奖励)。这个框架可能不足以有效地表示所有决策学习问题,但它被证明在广泛的应用中非常有用。
Sutton 第二章,第 50 页
实际上,大部分 MDP 框架可以在没有详细的概率和随机理论知识的情况下高层次地理解。然而,现在我们已经接触到这些概念,我们可以真正欣赏到其中的复杂性。让我们现在看看 Sutton(2017)第三章中与马尔可夫过程相关的一些方程。
马尔可夫性质
这本质上是未来在给定现在的情况下与过去无关的观点。它在概率上被概念化为:
这已经在上文中详细讨论过,使我们能够理解基本的马尔可夫问题如何变得更容易处理。满足马尔可夫性质的随机过程通常比将历史过程纳入模型中更简单分析。在实践中,考虑马尔可夫性质是否成立是很重要的。例如,对于汽车的位置来说,汽车的当前所在位置将决定它未来的状态。但在某些情况下,历史可能很重要:例如,交易代理可能需要考虑股票的过去趋势以及当前价格,以决定是否买入或卖出。
马尔可夫过程是具有以下性质的随机过程:
(a.) 可能的结果或状态的数量是有限的
- 这是为了分析的便利。无限过程也可以通过一些小的操作来处理。然而,对于时间依赖的过程,我们假设时间是有限的 t=0,1,2,…,T。
(b.) 任何阶段的结果仅依赖于前一阶段的结果 — 由上述讨论的马尔可夫性质给出。
(c.) 转移概率随时间保持不变。
(d.) 系统是平稳的 — 这实际上没有明确说明,但属性 c) 实际上隐含了这一点。
MDP 的正式表示
正式地,MDP 可以表示如下。鉴于马尔可夫性质,MDP 定义如下:一个通常包含 5 个元素的元组:
S:环境中的状态集合:这实际上是一个随机过程
A:每个状态下可能的动作集合
R: 对动作的奖励函数:这通常是确定性的,但有时也可以是随机的
p : 从一个状态转移到下一个状态的概率矩阵 — 这是 MDP 的关键部分。它由状态变量的联合和条件分布定义。
γ : 适用于远期奖励的折扣因子
MDP 和智能体共同产生一个序列或轨迹,如下所示:
这个过程是一个随机过程且是有限的,即有 T 个时间步骤,并且有一个明确的终止状态。
理解方程
方程 3.2
该方程告诉我们 Rₜ 和 Sₜ 是具有离散概率质量分布的随机变量,这些分布依赖于先前的状态和动作。方程 3.2 定义了动态函数 p。这是一个包含 4 个参数的函数。它还为每种状态和动作选择指定了条件分布。对于特定的状态 s’ 和奖励 r,我们可以估计在时间 t 这些值发生的概率,给定在 t-1 时刻的前状态和动作值。
方程 3.3
方程 3.3 基本上告诉我们,对于每个“s”和“a”的值,都有一个状态和奖励的联合分布。这基本上是一个联合条件分布,因此它的和必须为 1。联合分布将具有不同组合的 s’ 和 r 的概率。所有组合的概率之和必须为 1(参见维基百科)。
让我们回到之前的道路示例。在道路的每个状态下,我们有两个动作:驾驶或停止。如果我们处于“拥堵”状态并选择驾驶,那么在这两者的条件下,将有一个下一状态的状态和奖励的联合分布。
状态转移概率 — 方程 3.4
状态转移概率通过对奖励概率进行求和来给出(请注意,对于概率挑战:类似于状态的边际概率分布 - 参见维基百科)。
期望奖励 — 方程 3.5
在这里,我们再次处理方程 3.2。这次我们对状态的概率进行求和(对于概率不太清楚的人:类似于奖励的边际概率分布 — 参见 Wikipedia)。这给出了奖励分布的边际概率(3.5 的极端右侧表达式)。然后我们将这个概率乘以实际奖励 r,以获得期望奖励。
期望奖励的状态-动作-下一状态三元组 — 方程 3.6
这个有点棘手。奖励和状态的联合概率条件在‘s’和‘a’下除以状态转移概率。这是联合分布链式法则的一个应用(参见 Wikipedia)。
让我们代入我们的术语。
现在我们可以对 Rₜ取期望值
所以,这完成了 MDP 设置的基本方程的推导。理解 3.2 到 3.6 每一个方程都很重要,因为 Sutton(2017 年)提到这些方程都是表示 MDP 的一种方式,并且会在后续章节中再次使用。
总结..
好朋友们,这一段关于概率和随机理论的旅程已经相当长了。然而,我希望你们意识到的是基础 Markov 系统的复杂性和美丽。如果不理解这些内容,很难弄清楚基本 MDP 方程的底层方程。如果不理解这些方程及其推导方式,当涉及到策略、价值函数和 Q 值时,很可能会完全迷失。下一部分教程中会涉及到这些内容。
希望你喜欢我的写作。请考虑关注我,以便获取更多类似的文章,我会继续撰写关于强化学习的内容:medium.com/@shaileydash
也请在评论中告诉我你的意见和建议,这将帮助我修改内容。
你可以在 Linkedin 上关注我,获取更多关于 AI 和数据科学的企业导向文章:www.linkedin.com/in/shailey-dash-phd-8b4a286
参考文献
我列出了一份相当长的统计学和概率理论资源列表。这些资源都非常优秀,但也相当困难。
Richard S. Sutton 和 Andrew G. Barto. 《强化学习:导论;第 2 版》。2017 年。
en.wikipedia.org/wiki/Markov_decision_process
en.wikipedia.org/wiki/Stochastic_process
en.wikipedia.org/wiki/Joint_probability_distribution
en.wikipedia.org/wiki/Conditional_probability_distribution
www.kent.ac.uk/smsas/personal/lb209/files/notes1.pdf
Lawler, G.F. (2006). 随机过程导论(第 2 版)。Chapman and Hall/CRC。 doi.org/10.1201/9781315273600
Yates, R.D. 和 Goodman, D.J., 2014. 概率与随机过程:电气与计算机工程师的友好介绍。John Wiley & Sons.
Unnikrishna Pillai, 概率与随机过程,第 8 讲,NYU Tandon 工程学院,www.youtube.com/watch?v=nFF6eRQT2-c&t=455s
强化学习在库存优化中的应用系列 III:从模拟到现实的 RL 模型转移
弥合模拟器与现实世界之间的差距
·
关注 发表在 数据科学前沿 · 8 分钟阅读 · 2023 年 1 月 30 日
--
照片由 Suad Kamardeen 提供,来源于 Unsplash
更新:这篇文章是我博客系列库存优化中的强化学习的第三篇文章。以下是同一系列其他文章的链接。如果您感兴趣,请查看。
在之前的文章中,我为单一零售商和多级供应链网络构建了两个 RL 模型。在这两篇文章中,我使用了基于历史需求数据的库存系统模拟器来训练 RL 模型,这些数据遵循不同正态分布的混合,其参数依赖于星期几。然后,我在一个测试集上测试了 RL 模型,其中的需求数据与历史训练数据的分布相同。这实际上假设模拟器与我们应用 RL 模型的现实世界完全匹配,且历史需求数据完全代表未来需求模式。然而,这一假设在实际应用中很少成立。
如何处理训练集与测试集之间分布的差距是机器学习领域中的一个著名问题。这个问题的核心是偏差与方差的权衡。如果我们过于依赖训练集来训练模型,那么模型容易过拟合,意味着它只会在训练集上表现良好,而在测试集上表现不佳。在 RL 领域也是如此。
由于模拟器与现实世界之间的差距导致的性能退化
以我之前构建的 RL 模型为例,这些模型通过大量的历史需求数据进行训练,因此很可能从训练的 RL 模型中学习到的库存政策过度拟合了历史需求数据所代表的需求模式。如果未来的需求模式继续保持相同,库存政策将表现良好,但如果未来的需求模式偏离历史模式,性能将会退化。
作为一个数值示例,假设未来需求分布在现实世界中偏离了用于训练 RL 模型的模拟器中的历史需求数据。我将我第一篇文章中训练的 RL 模型应用于现实世界的两个未来需求场景。在一个场景中,我们假设正态分布的均值增加 1,而在另一个场景中,均值减少 1。需求分布结构如下表所示。
两个未来/现实世界需求分布场景(图片由作者提供)
对于需求增加场景和需求减少场景(如上表的第二列和第三列),我生成了 100 个需求数据集,每个数据集包含 52 周的数据,使用其自身的需求分布设置。然后,我将基于历史/模拟器需求数据(如上表的第一列)训练的 DQN 策略应用于这两个场景,得到的平均利润如下表所示。
应用训练于历史/模拟器数据的 DQN 策略到每个需求场景中获得的平均利润(图片由作者提供)
为了评估训练的 DQN 策略表现如何,对于每个需求场景,假设我们知道未来/现实世界需求分布将如何偏离,并生成一个由 52 周需求数据组成的训练集,使用其自身的未来/现实世界需求分布设置。然后,我们为每个场景训练一个全新的 RL 模型,并将新 RL 模型应用于每个场景。得到的平均利润如下表所示。
应用训练于每个场景的训练集的 DQN 策略到每个需求场景中获得的平均利润(图片由作者提供)
从上面两张表的比较($25077.61 对比 $27399.79 和 $13890.99 对比 $14707.44)可以看出,如果模拟器与现实世界之间存在差距,RL 模型的性能会下降。
使用领域随机化弥合差距
为了弥合模拟器和现实世界之间的差距,在实际应用中,可以选择最近一段时间的历史需求数据,并假设这段较短的时间更好地代表了未来需求模式或趋势。然后,我们可以通过使用最近更新的较短历史需求数据频繁重新训练 RL 模型。然而,我们也可以尝试在训练过程中使用领域随机化的概念来解决这个问题。
领域随机化是一种通常用于将模拟转移到实际的技术,尤其是在将强化学习应用于机器人时。机器人中的强化学习应用也面临现实差距的问题,因为模拟与现实世界之间的差距会降低将强化学习模型转移到实际机器人中的策略的性能[1]。在机器人领域,领域随机化的核心思想是通过在训练强化学习模型时随机化模拟环境的物理参数(例如摩擦系数和视觉属性如物体外观),强化学习模型将体验到更接近真实环境的情况,从而学习到的策略会更好地推广到实际环境中。下图展示了领域随机化的直觉。
领域随机化的直觉(图片来源于参考文献 [1])
在库存优化的背景下,为了实现领域随机化,我们可以对历史需求数据进行随机化。具体而言,我们通过向需求观察值中添加正态分布噪声来扩展历史需求数据集。该想法的实现见下方代码块。
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(0)
demand_hist = []
for i in range(52):
for j in range(4):
random_demand = np.random.normal(3, 1.5)
if random_demand < 0:
random_demand = 0
random_demand = np.round(random_demand)
demand_hist.append(random_demand)
random_demand = np.random.normal(6, 1)
if random_demand < 0:
random_demand = 0
random_demand = np.round(random_demand)
demand_hist.append(random_demand)
for j in range(2):
random_demand = np.random.normal(12, 2)
if random_demand < 0:
random_demand = 0
random_demand = np.round(random_demand)
demand_hist.append(random_demand)
plt.hist(demand_hist)
## add randomized noise
demand_hist_randomized = []
rand_list = []
for k in range(10):
demand = demand_hist.copy()
rand1 = np.random.normal(1, 1)
rand2 = np.random.normal(1, 1)
rand3 = np.random.normal(1, 1)
rand_list.append([rand1,rand2,rand3])
for i in range(52):
for j in range(4):
demand[i*7+j] += rand1
if demand[i*7+j] < 0:
demand[i*7+j] = 0
demand[i*7+j] = np.round(demand[i*7+j])
demand[i*7+4] += rand2
if demand[i*7+4] < 0:
demand[i*7+4] = 0
demand[i*7+4] = np.round(demand[i*7+4])
for j in range(5,7):
demand[i*7+j] += rand3
if demand[i*7+j] < 0:
demand[i*7+j] = 0
demand[i*7+j] = np.round(demand[i*7+j])
demand_hist_randomized.extend(demand)
在这里,我通过向每个场景的需求观察值中添加随机噪声来创建了 10 个历史需求场景。然后我将这些场景组合在一起,作为用于训练的完整历史需求数据,假装我们有 10 年的需求数据。请注意,这个想法的本质也与计算机视觉任务中的数据增强技术相一致,我们通过操控图像来丰富训练集,以避免过拟合。
我尝试了不同的正态分布均值来生成随机噪声。最初,我认为使用均值为 0 是有意义的,因为它更有可能生成相等数量的需求增加和减少场景。有趣的是,在尝试了不同的值之后,我发现将均值设置为 0 以上(更多的需求增加场景)在这个特定的例子中得到了更好的测试结果。这可能是因为需求数据在 0 处被截断,因此尽管我们生成均值为 0 的噪声,但需求分布向左移动的空间不大。所以生成均值为正数的噪声会使强化学习模型学到更多有用的知识。
现在我们使用通过领域随机化获得的新 10 年历史需求数据来训练 DQN 模型,然后将学习到的 DQN 策略应用于上一节中描述的需求增加和减少场景(详见第一个表)。下表展示了每个场景中获得的平均利润。
使用领域随机化后训练的 DQN 策略在每个场景中获得的平均利润(图片由作者提供)
我们看到,与没有领域随机化的结果相比,明显有所改善(增加场景中为 $26432.61 对 $25077.61,减少场景中为 $14311.52 对 $13890.99)。
为了表明性能提升不是偶然获得的,我进一步创建了更多随机化的测试需求场景,通过对需求分布的均值和方差添加随机噪声。见下面的代码
。
demand_test = []
for k in range(100,200):
np.random.seed(k)
rand1 = np.random.normal(0,1)
rand2 = np.random.normal(0,1)
rand3 = np.random.normal(0,1)
rand4 = np.random.normal(0,0.5)
rand5 = np.random.normal(0,0.5)
rand6 = np.random.normal(0,0.5)
demand_future = []
for i in range(52):
for j in range(4):
random_demand = np.random.normal(3+rand1, 1.5+rand4)
if random_demand < 0:
random_demand = 0
random_demand = np.round(random_demand)
demand_future.append(random_demand)
random_demand = np.random.normal(6+rand2, 1+rand5)
if random_demand < 0:
random_demand = 0
random_demand = np.round(random_demand)
demand_future.append(random_demand)
for j in range(2):
random_demand = np.random.normal(12+rand3, 2+rand6)
if random_demand < 0:
random_demand = 0
random_demand = np.round(random_demand)
demand_future.append(random_demand)
demand_test.append(demand_future)
在领域随机化之前,DQN 策略在该测试集上的平均利润为 $20163.84,而领域随机化之后的 DQN 策略平均利润为 $21887.77,这仍然显示出性能的提升。
结论
在这篇文章中,我重点指出并解决了我为库存优化构建的 RL 模型在训练/仿真和测试/现实世界环境之间的差距。很高兴看到,像领域随机化这样的技术在机器人领域中也能在库存优化领域发挥作用。实验结果表明,使用领域随机化等技术来操控历史需求数据以丰富训练集应该是一种良好的实践,这样实验室中训练的 RL 模型可以更好地推广到现实世界。
参考文献
费曼-蒂尔福德算法解析及演示
原文:
towardsdatascience.com/reingold-tilford-algorithm-explained-with-walkthrough-be5810e8ed93
用数值示例和 Python 代码绘制树节点的算法
·发表于Towards Data Science ·阅读时长 9 分钟·2023 年 9 月 12 日
--
由Sergiu Vălenaș拍摄,来源于Unsplash
介绍
1981 年的费曼-蒂尔福德算法通过将节点安排在树结构中以最大化可读性,从而创建了一个视觉上令人愉悦的层级数据表示。换句话说,它是一个用于检索树中每个节点的(x, y)
坐标的算法。
根据论文,一个好的树状图应遵循一些美学规则,
同一深度的节点应排列在一条直线上,并且定义深度的直线应当平行。
左子节点应位于其父节点的左侧,右子节点应位于右侧(仅适用于二叉树)
父节点应当居中于其子节点上方
树及其镜像应产生互为反射的图形,子树的绘制方式应与其在树中的位置无关
确定节点的y
坐标是简单的,而x
坐标稍微复杂一些。本文将尝试用数值示例解释算法,树的复杂度略高于我所读过的其他论文和文章,以便涵盖更多场景。我还将介绍一些原论文中没有使用的额外术语,以更好地区分不同的术语。
对于算法的一个重要直觉是,它将从左到右绘制树。你可以将最左侧的节点视为
(0, 0)
坐标,各个子树将相应地向右移动。有趣的事实:
scikit-learn
Python 库也使用这个算法来绘制决策树!
术语
在确定每个节点的最终坐标之前,有 3 个术语是至关重要的。原始论文引用了 x
和 mod
这两个术语,但在我的解释中,为了更清晰,我将使用一个额外的 shift
术语。
-
**x**
指的是节点的初始x
坐标,纯粹是基于节点的位置(这不是最终的x
坐标) -
**mod**
指的是节点子孙(但不是节点本身)的移动量,以使子节点相对于节点居中。 -
**shift**
指的是节点子孙(包括节点本身)的移动量,以避免与节点左侧的子树重叠。
现在可能还不太清楚为什么需要这 3 个术语来确定每个节点的最终 x
坐标——希望在浏览数值示例后会更清楚。在接下来的部分,将使用以下树结构来解释算法。
图 1: 树示例——作者提供的图像
第一次遍历
后序遍历树以计算
x
、mod
和shift
属性
在后序遍历中,树示例将按字母顺序遍历。有两个部分——第一部分计算 x
和 mod
属性,而第二部分计算 shift
属性。
第一次遍历的第一部分
图 2: 添加了 x 属性(黑色)和 mod 属性(蓝色)的树示例——作者提供的图像
如前所述,x
指的是节点的初始位置,纯粹是基于节点的位置。
-
如果节点是最左侧的孩子,
x
将为 0。 -
对于每个其他节点,它将参考左侧兄弟节点的
x
值,并将一个 兄弟距离 加到该值上。 -
这一点有一个例外,如果节点是最左侧的孩子但有子节点,它会尝试相对于子节点进行居中,并将其
x
值设置为子节点x
值的中点。
在这个示例中,我们将假设兄弟距离为 1
,并注意每个兄弟的 x
属性与其左侧兄弟的距离正好是兄弟距离。例外情况由图 2 中的黄色节点示例。
对于那些不是最左侧孩子但有子节点的节点——它们也需要相对于子节点进行居中,并将使用 mod
属性来实现。其思想是将子孙居中,移动量将是父节点的 x
值减去子节点 x
值的中点。这些示例由图 2 中的红色节点示例。比如,对于节点 e
,将其子节点 c
和 d
向右移动 1.5
的值,将使子节点相对于其自身居中。
“最终”x
-坐标可以通过将节点的x
属性与祖先的mod
属性之和相加轻松计算。但我们很快意识到子树中存在重叠。例如,在第三层级,节点f
和k
的x
-坐标分别为3
和2.5
,它们会发生重叠。这将在第二部分中解决,此时将计算shift
属性。
第一轮的第二部分
在第二部分,对于每个遍历的节点,我们将检查与所有左兄弟的子树的重叠情况。这也是为什么执行后序遍历,因为可以保证所有左兄弟(及其子树)和当前节点的子树已经遍历完毕,并且它们的x
、mod
和shift
值已知(仅需计算当前节点的shift
值)。
为检测重叠,在每一层中,将左子树的最右侧后代与右子树的最左侧后代进行比较。在文中,这被称为右轮廓和左轮廓。
图 3:添加了shift
属性(红色)的树示例,n 和 g 的子树 — 作者提供的图像
例如,在遍历树时,比较以下子树,
-
节点
b
与节点a
对比(无重叠) -
节点
e
与节点b
和a
对比(无重叠) -
节点
f
与节点e
、b
和a
对比(无重叠) -
依此类推…
当节点n
的子树与节点g
的子树比较时,发生重叠。在图 3 中,左子树的右轮廓是红色的,而右子树的左轮廓是绿色的。右子树将被移动,以使两个子树之间的子树距离相等。
在这个例子中,我们将假设子树距离为1
并比较每个深度的重叠情况。
-
在第二层级,
f
和k
的坐标分别为3
和2.5
(需要将k
向右移动1.5
,以使其与f
的子树距离相等)。 -
在第三层级,
d
和i
的坐标分别为2.5
和2
(需要将i
向右移动1.5
,以使其与d
的子树距离相等)。
这导致节点n
需要向右移动1.5
(第二和第三深度的最大移动量)。shift
属性非常特殊,因为这将导致节点及其后代(节点n
及其后代向右移动1.5
)的移动,同时左邻和右邻节点也会受到影响。由于节点n
是从最左侧兄弟节点起的第二个节点,节点h
及其后代需要移动1.5 * (1 / 2) = 0.75
,而节点q
及其后代需要移动1.5 * (3 / 2) = 2.25
,以确保所有兄弟节点在其父节点下仍然居中。
图 4:添加了shift
属性(红色)的树示例,q 和 n 的子树 — 作者提供的图像
当比较节点 q
的子树与节点 n
的子树时,会发生另一次重叠。这里的计算将在文章的相对移位部分详细说明。现在,我们可以得出结论,节点 q
及其后代需要向右移位 2.25
。参考图 4,由于节点 q
是从最左边的兄弟节点开始的第三个节点,节点 h
及其后代必须移位 2.25 * (1 / 3) = 0.75
,节点 n
及其后代必须移位 2.25 * (2 / 3) = 1.5
,以确保所有兄弟节点相对于其父节点仍然居中。
第二遍
对树进行前序遍历,计算最终的
x
和y
值
图 5:带有最终 x
和 y
坐标的树示例 — 图片由作者提供
由于所有的 x
、mod
和 shift
属性都在第一次遍历中计算得到,这些属性可以通过自上而下的前序遍历来求和,从而得出最终的 x
坐标。请注意,mod
属性仅移位后代,而 shift
属性则移位当前节点及其后代。
第三遍
对树进行前序遍历,调整任何负值的
x
值(如果有的话)
在执行第二遍时,任何由于父节点的负 mod
而导致的负 x
坐标可以记作调整值,并在第三遍中通过将此调整值加到每个节点来处理。
额外考虑
相对移位
在第一遍的第二部分中,通过比较节点 n
和 g
的子树计算了 shift
值。还提到 shift
将影响兄弟节点 h
和 q
,因为所有子节点必须相对于父节点保持居中。我们可以称这种影响兄弟节点的移位为 相对移位,因为这种移位是右子树直接移位的结果。
由于这种相对移位,在计算左右子树之间的 shift
值时需要额外考虑,因为左子树将在稍后 相对 移位。
参考图 4,注意节点 q
是第三个节点,而节点 n
是从最左边的兄弟节点开始的第二个节点。在第二层深度,
-
节点
m
的中间x
坐标为 6(Node n's mod 2 + Node n's shift 1.5 + Node m's x 2.5
)。 -
节点
o
的中间x
坐标为 6.25(Node q's mod 4 + Node q's shift 2.25
)。 -
设移位值为
x
,子树距离为1
。将会有一个相对移位,其中节点m
将按(2 / 3) * x
移位,如果节点q
按x
移位。 -
方程式如下:
(6.25 + x) — (6 + (2 / 3) * x) >= 1
。 -
解决这个方程,移位值
x >= 2.25
。
额外参数
子树分离
在上面的例子中,兄弟节点分离和子树分离设置为相同的值。兄弟节点分离与子树分离的区别如下:
-
兄弟节点分隔在第一遍的第一部分中使用,用于确定
x
属性。 -
子树分隔在第一遍的第二部分中使用,用于确定
shift
属性。
层级分隔
层级分隔可以作为一个额外的参数来确定不同层级之间的距离,这将用于确定第二遍中的y
坐标。
偏移(x-offset,y-offset)
可以引入偏移参数来通过x-offset
和y-offset
的常量值来移动整个树,这可以在第二遍中引入,当计算最终的x
和y
坐标时。
Python 代码实现
我已实现了一个 Python 代码示例的 Reingold Tilford 算法,这里跟随本文 closely,包括第一、第二和第三遍,并引用了x
、mod
和shift
属性。
相关链接
-
[1] 官方论文(6 页,Pascal 代码):
reingold.co/tidier-drawings.pdf
-
说明(最易读,C+伪代码):
rachel53461.wordpress.com/2014/04/20/algorithm-for-drawing-trees/
-
说明(推荐阅读):
www.drdobbs.com/positioning-nodes-for-general-trees/184402320?pgno=4
-
说明(带有轮廓和偏移动画):
williamyaoh.com/posts/2023-04-22-drawing-trees-functionally.html
ReLoRa: 在您的 GPU 上预训练大型语言模型
原文:
towardsdatascience.com/relora-pre-train-a-large-language-model-on-your-gpu-d104756f9ddf
LoRa,但多次重置
·发表在 Towards Data Science ·8 分钟阅读·2023 年 7 月 20 日
--
ReLoRa 框架 — 作者提供的图像
在 2021 年,Hu et al. 提出了低秩适配器(LoRa)用于 LLMs。此方法通过仅训练几个新增的参数(低秩网络),同时保持 LLM 的原始参数(高秩网络)冻结,从而显著降低了微调大型语言模型(LLMs)的成本。
使用 LoRa,我们仍然需要一个现有的预训练模型来进行微调,即它不能从零开始预训练一个好的 LLM,因为低秩限制。它使得大多数个人和组织无法承担预训练的费用。
为了减少这种成本,Lialin et al. (2023) 提出了 ReLoRa。这是对 LoRa 的一种修改,允许从零开始预训练 LLMs。
在这篇文章中,我首先解释了 ReLoRa 如何工作。然后,我分析并评论了描述 ReLoRa 的科学论文中呈现的结果。在最后一节中,我展示了如何在您的计算机上设置和运行 ReLoRa。
关于许可证的说明: 在 arXiv 上发表的科学论文 描述了 ReLoRa,并根据 CC BY 4.0 许可证进行分发。 ReLoRa 的源代码已发布在 GitHub 上 并根据允许商业使用的 Apache 2.0 许可证进行分发。
ReLoRa: 从低秩到高秩网络
要理解 ReLoRa 如何工作,我们必须先更深入地了解 LoRa。
LoRA 通过添加两个不同的新可训练参数集,A 和 B,在训练后将其合并回原始冻结的高秩网络。
这看起来可能显而易见,但理解 A 和 B 的和的秩高于它们各自秩的和是重要的。这可以形式化如下:
方程 (1) 由 Lialin et al. (2023) 提供
LoRa 仅训练了这两组参数。然而,如果我们能够重置、训练并将它们多次合并回原始高秩网络,我们将能够随着时间的推移增加网络的总体秩。换句话说,我们将获得一个更大的模型。
为什么 LoRa 不执行这些重置?
因为有几个重大障碍需要克服,以使这些重置变得有用。标准的 LLM 训练使用 Adam 优化器,Adam 优化器有其自身的状态。在不改变 Adam 状态的情况下重置 LoRa 的可训练参数,会使新的 LoRa 参数与上一个迭代的 LoRa 参数相似。模型将不会学到任何东西。
ReLoRa 提出的一个主要想法是还部分重置 Adam 优化器的状态,并结合“一个不规则的调度器来稳定训练和热启动”。这个不规则调度器将学习率重置为 0,并在几个训练步骤中进行新的热启动。以下是学习率的变化效果:
插图由 Lialin et al. (2023) 提供
更正式地说,运行 ReLoRa 的完整算法可以写成如下形式:
算法由 Lialin et al. (2023) 提供
结果:ReLoRa 的困惑度低于 LoRa
实际上,ReLoRa 似乎在成本低得多的情况下,得到的结果与标准预训练相似,但仅在参数数量达到最低值以上时有效。
表格由 Lialin et al. (2023) 提供
正如我们在上表中看到的,ReLoRa 在参数数量达到 2.5 亿个或更多时,表现大致与“完整训练”(没有 LoRa 并且没有冻结任何参数)相当。
他们还包括了使用标准 LoRa 进行预训练的结果。它表现不佳,并很好地说明了 ReLoRa 中重置的重要性。
在这些实验中,他们使用了类似于 Meta 的 LLaMa 模型的神经网络架构。这些模型的大部分训练仅用了 1 天时间,使用了 8 个 RTX 4090 GPU(即消费者级 GPU)。
他们没有尝试超过 3.5 亿个参数(大约是 BERT large 的大小),因为计算成本较高。
在一次消融研究中,作者还展示了 LoRa 的重启和调度器新热启动对于实现更低困惑度的重要性。他们还表明,移除不规则调度器可能会导致训练发散。
ReLoRa 的计算效率
ReLoRa 迭代地训练并向模型中添加新参数,同时保持先前迭代的参数冻结。
因此,ReLoRa 在内存使用方面与 LoRa 一样高效。此外,冻结的参数可以量化到低精度,以进一步减少内存使用。例如,我们可以使用 QLoRa 进行此操作,如我在这里描述的:
QLoRa: 在你的 GPU 上微调大型语言模型
现在可以在消费级硬件上对拥有数十亿参数的模型进行微调。
towardsdatascience.com
在你的计算机上运行 ReLoRa
Lialin 等人(2023)在 GitHub 上发布了他们的 ReLoRa 实现。
由于 ReLoRa 可以与 2.5 亿参数或更多的模型配合使用,我们可以在消费级硬件上运行它,例如,使用拥有超过 6 GB VRAM 的 GPU,或在免费的 Google Colab 实例上运行。
你可以很容易地重现他们的实验,并按以下步骤预训练你自己的 2.5 亿参数的 LLM。
注意:如果你想直接测试而不编写代码,我在 The Kaitchup(我的 substack 新闻简报)上创建了一个 Google Colab 笔记本。 搜索笔记本 #3。
首先,克隆这个仓库:
git clone https://github.com/Guitaricet/peft_pretraining.git
然后安装所有需求:
cd peft_pretraining
pip install -r requirements.txt
由于训练需要很长时间,我建议首先尝试使用能够提前停止训练并缩短验证的超参数(在消费级硬件上可能需要几个小时)。
该框架没有缩短验证的选项,所以我们需要手动完成。打开文件 torchrun_main.py
,并替换这一行:
if evaluated_on_tokens > target_eval_tokens:
注意:在撰写本文时,这一行是 129。
通过:
if evaluated_on_tokens > target_eval_tokens or total_batches > 10:
为了更好的性能,框架的作者建议将其分为两步运行。
第一步仅仅是初始化和训练网络几个步骤。
这是我用来测试一切是否正常工作的超参数:
torchrun --nproc-per-node 1 torchrun_main.py \
--model_config configs/llama_250m.json \
--batch_size 4 \
--total_batch_size 8 \
--lr 5e-4 \
--max_length 512 \
--tags warm_start_250M \
--save_every 10 \
--num_training_steps 2 \
--workers 1 \
--eval_every 1
注意:我将“— nproc-per-node”设置为 1,因为我只有 1 个 GPU。你应该根据你的 GPU 数量进行更改。
在 configs
目录下,你会看到几个 llama 文件。它们包含了不同模型的配置,类似于 LLaMa 模型的架构,但大小不同。在这里,我选择了 250M 的大小。
输出非常详细,因为框架使用 wandb 记录所有内容。这也要求你在这里做出选择:
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
我输入了“3”,因为我没有账户。
然后,在第二步中,我们使用 ReLoRa 的超参数和 PEFT 重新运行框架:
torchrun --nproc-per-node 1 torchrun_main.py \
--model_config configs/llama_250m.json \
--batch_size 4 \
--total_batch_size 8 \
--lr 1e-3 \
--max_length 512 \
--use_peft \
--relora 5 \
--cycle_length 5 \
--restart_warmup_steps 10 \
--scheduler cosine_restarts \
--warmup_steps 2 \
--reset_optimizer_on_relora True \
--num_training_steps 50 \
--save_every 10 \
--eval_every 10 \
--continue_from checkpoints/llama_250m-2023-07-19-09-39-08/model_3 \
--tags relora_250M
使用“ — continue_from”选项,我们提供了在第一步保存的模型。你的模型将有一个不同的名字,所以你需要更改它。你可以在“checkpoints”目录中找到它。
一旦确认一切正常运行,即没有错误且困惑度逐渐降低,你可以重新启动所有内容,但需要合理设置超参数,以便模型能够得到更好的训练。
注意:不要忘记移除我们在 torchrun_main.py 中所做的更改,以缩短验证时间。你应该至少将批量总数增加到 100,以获得有意义的验证困惑度。
torchrun --nproc-per-node 1 torchrun_main.py \
--model_config configs/llama_250m.json \
--batch_size 4 \
--total_batch_size 8\
--lr 5e-4 \
--max_length 512 \
--tags warm_start_250M \
--save_every 1000 \
--num_training_steps 10000
torchrun --nproc-per-node 1 torchrun_main.py \
--model_config configs/llama_250m.json \
--batch_size 4 \
--total_batch_size 8 \
--lr 1e-3 \
--max_length 512 \
--use_peft \
--relora 5000 \
--cycle_length 5000 \
--restart_warmup_steps 100 \
--scheduler cosine_restarts \
--warmup_steps 500 \
--reset_optimizer_on_relora True \
--num_training_steps 10000 \
--save_every 5000 \
--eval_every 5000 \
--continue_from <your checkpoint from step one>\
--tags relora_250M
注意:如果你有大量 VRAM 的 GPU,请增加批量大小。
默认情况下(硬编码),框架使用了 C4 数据集(仅限英语)进行预训练。该数据集还被用于 Google 的 T5 模型预训练。它涵盖了许多任务、领域和体裁。
一旦预训练完成,你仍然需要对结果模型进行微调,以适应下游任务或你选择的领域。你可以通过 QLoRa 高效地完成这项工作。
结论
总结来说,ReLoRa 是一种利用低秩网络的新预训练方法。它就像是连续多次执行 LoRa,但具有以下特点:
-
优化器状态的部分重置
-
锯齿状学习率调度
-
每次重置 LoRa 后进行短暂的预热
多亏了 ReLoRa,我们现在可以在消费级硬件上预训练 LLM。
仍需探索这种方法是否对非常大的语言模型(超过 10 亿参数)也具有竞争力。
根据 ReLoRa 的作者,ReLoRa 随着模型的增大效果会更好。然而,这应由未来的工作来实证验证。
如果你喜欢这篇文章并希望阅读后续文章,支持我工作的最佳方式是通过这个链接成为 Medium 会员:
[## 使用我的推荐链接加入 Medium - Benjamin Marie]
作为 Medium 会员,你的会员费的一部分将用于支持你阅读的作者,你将获得所有故事的完全访问权限…
如果你已经是会员并希望支持这项工作, 只需关注我在 Medium 上的文章。
用 Python 去除 Landsat 卫星图像中的云层
原文:
towardsdatascience.com/removing-clouds-from-landsat-satellite-images-with-python-246e73494bc
计算你感兴趣区域的云量,去除云层并使用另一张卫星图像修补它们
·发布于Towards Data Science ·14 分钟阅读·2023 年 5 月 29 日
--
(来源:作者)
我们的技术已经发展到征服浩瀚的太空。我们已经发射了配备最先进传感器的卫星,用于监测我们变化的星球。然而,一个对手却让这些先进系统感到困扰——云层。
无论你是想比较光谱指数的值还是在卫星图像上训练机器学习模型,云层都是一个问题。因此,我们将探索如何使用云掩膜来:
-
计算你感兴趣区域的云量百分比
-
从图像中去除云层
-
用另一张日期不同的图像修补云像素
关于实际的掩膜,我们将探索两个选项:
-
在Landsat QA 文件(QA_PIXEL.tif)中提供的方法
-
一种替代的机器学习方法
我们将解释用于完成此任务的 Python 代码,你可以在GitHub上找到完整的项目。让我们清除这些天空吧!
下载卫星图像
我们讨论的方法适用于任何栅格数据。在这篇文章中,我们将专门处理 Landsat 卫星图像。你可以使用EarthExplorer门户下载 Landsat 场景。或者,如果你想使用 Python,下面的文章将带你完成这一过程:
使用 landsatxplore Python 包简化 Landsat 场景下载
towardsdatascience.com
最终,你应该会有一个像图 1一样的文件夹。这些都是 Landsat 2 级科学产品中可用的文件。我们将处理高亮显示的文件。这些是 3 个可见光波段和 QA_PIXEL 文件。
图 1:Landsat 二级科学产品文件(来源:作者)
这个特定场景拍摄于南非开普敦上空。为了查看这一点,我们使用get_RGB函数可视化可见光波段。这个函数以文件名/ID 作为参数。然后它会加载波段(第 7–9 行),将它们堆叠(第 12 行),缩放(第 13 行)和剪裁(第 16 行)。
import tifffile as tiff
import numpy as np
def get_RGB(ID):
# Load Blue (B2), Green (B3) and Red (B4) bands
B2 = tiff.imread('./data/{}/{}_SR_B2.TIF'.format(ID, ID))
B3 = tiff.imread('./data/{}/{}_SR_B3.TIF'.format(ID, ID))
B4 = tiff.imread('./data/{}/{}_SR_B4.TIF'.format(ID, ID))
# Stack and scale bands
RGB = np.dstack((B4, B3, B2))
RGB = np.clip(RGB*0.0000275-0.2, 0, 1)
# Clip to enhance contrast
RGB = np.clip(RGB,0,0.3)/0.3
return RGB
我们使用这个功能来获取图 1(第 3–4 行)中下载场景的 RGB 可视化。你可以在图 2中看到结果图像。注意所有的云!让我们看看能否利用 Landsat 云掩膜来处理这些云。
import matplotlib.pyplot as plt
ID = 'LC09_L2SP_175083_20230410_20230412_02_T1'
RGB = get_RGB(ID)
# Plot the RGB image
fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(RGB)
图 2:Landsat 场景的 RGB 可视化(来源:作者)
Landsat 云掩膜
什么是云掩膜?
云掩膜是一种分割图,将云与地球表面的其他特征分开。该图将图像中的每个像素分类为云或非云。它还可以包括其他类似的特征,如云影。它们在遥感中很重要,因为隔离云像素能更准确地测量和解释土地覆盖分析。
用于分类云像素的方法有很多种。Landsat 使用一种叫做CFMask的算法。该算法使用基于光谱指数的各种统计数据的组合。它在云像素的分类准确率为 89%,在云影的分类准确率为 96% [1]。稍后,我们将讨论一种替代的机器学习方法来创建云掩膜。
了解质量保证(QA)波段
CFMask 算法在每个 Landsat 场景上运行,结果可以在 QA 波段中找到。该文件是一个与太阳反射波段(例如可见光波段)具有相同高度和宽度的数组。波段中的每个像素都是一个 16 位整数,包含该像素的分类。你可以在数据格式控制手册中找到有关这种编码的更多信息。
例如,假设 QA 文件中的一个像素值为22280。为了理解这个整数的含义,我们首先将其转换为 16 位二进制数——0101011100001000。这个数字中的每一位映射到图 3中看到的标志。第 3 位的值为 1,表示该像素被分类为云。接下来,我们将考虑云掩码中位 1 到 4 的覆盖类型。
图 3:Landsat QA 波段位描述和示例(来源:作者)
我们使用这些信息来创建get_mask函数。它将返回 4 种覆盖类型中每个像素的分类——云、阴影、扩张云和卷云。如果一个像素被分类为给定的覆盖类型,它的值将为 0,否则为 1。稍后,我们将看到这使我们能够通过将卫星图像与掩码相乘来去除不需要的像素。
def get_mask(val,type='cloud'):
"""Get mask for a specific cover type"""
# convert to binary
bin_ = '{0:016b}'.format(val)
# reverse string
str_bin = str(bin_)[::-1]
# get bit for cover type
bits = {'cloud':3,'shadow':4,'dilated_cloud':1,'cirrus':2}
bit = str_bin[bits[type]]
if bit == '1':
return 0 # cover
else:
return 1 # no cover
为了应用此功能,我们首先加载 QA 波段(第 2 行)。我们对波段中的每个像素应用get_mask,以获得 4 种覆盖类型中的每一种(第 6–9 行)。我们使用np.vectorize()来加速,而不是对每个像素进行循环处理。每个掩码将是一个与场景具有相同维度的数组,每个元素的值为 0 或 1。
# QA band
QA = tiff.imread('./data/{}/{}_QA_PIXEL.TIF'.format(ID, ID))
QA = np.array(QA)
# Get masks
cloud_mask = np.vectorize(get_mask)(QA,type='cloud')
shadow_mask = np.vectorize(get_mask)(QA,type='shadow')
dilated_cloud_mask = np.vectorize(get_mask)(QA,type='dilated_cloud')
cirrus_mask = np.vectorize(get_mask)(QA,type='cirrus')
我们可以在图 4中看到这些掩码的实际效果。这里我们将看到图 2中看到的图像与 4 种掩码的分类结果以不同颜色叠加。放大一个区域,我们可以更好地理解每种覆盖类型:
-
阴影是云下的较暗区域
-
扩张云是被分类为云的像素周围的边缘
-
卷云是“细丝状”的云
如果一个像素被分类为卷云,它也会被分类为云。
图 4:Landsat 分割掩码包含 4 种覆盖类型——云、阴影、扩张云和卷云(来源:作者)
以下代码用于创建图 4。我们将每个掩码的图层添加到原始 RGB 图像中(第 15–22 行)。这通过使用cv2.addWeighted()函数(第 22 行)来完成。每个掩码图层都有不同的颜色(第 8–11 行),透明度为 50%。通过提取像素范围来裁剪感兴趣区域(AOI)(第 43 行)。接下来,我们将对该区域进行更详细的查看。
import cv2
import matplotlib as mpl
# segmentation image
seg = RGB.copy()
# color for each cover type
colors = np.array([[247, 2, 7],
[201, 116, 247],
[0, 234, 255],
[3, 252, 53]])/255
masks = [cloud_mask, shadow_mask, dilated_cloud_mask, cirrus_mask]
for i,mask in enumerate(masks):
# color for cover type
temp = seg.copy()
temp[mask == 0] = colors[i]
# add to segmentation
seg = cv2.addWeighted(seg, 0.5, temp, 0.5, 0)
fig, ax = plt.subplots(1,2,figsize=(20, 10))
ax[0].imshow(seg)
# add legend with colors for each cover type
legend_elements = [mpl.patches.Patch(facecolor=colors[0], label='Cloud'),
mpl.patches.Patch(facecolor=colors[1], label='Shadow'),
mpl.patches.Patch(facecolor=colors[2], label='Dilated Cloud'),
mpl.patches.Patch(facecolor=colors[3], label='Cirrus')]
ax[0].legend(handles=legend_elements, loc='upper right')
# draw white rectangle around area of interest
h = 300
x,y = 4500,4500
rect = mpl.patches.Rectangle((x-h,y-h),h*2,h*2,linewidth=2,edgecolor='w',facecolor='none')
ax[0].add_patch(rect)
# crop area of interest
crop_seg = seg[y-h:y+h,x-h:x+h,:]
ax[1].imshow(crop_seg)
ax[0].set_axis_off()
ax[1].set_axis_off()
评估感兴趣区域
整个场景的云覆盖百分比作为场景的元数据的一部分提供。在我们的案例中,这个值是59%。通过云掩码,我们可以计算特定 AOI 的这一数据。这很有用,因为它允许我们以编程方式搜索那些 AOI 无云的场景。即使场景的其他部分有云也无妨。
为了查看这一点,我们选择与图 4(第 2 行)相同的感兴趣区域。请记住,云掩码的值为 0 表示没有云,否则为 1。我们将其取反并计算平均值(第 5 行)。这给我们一个69%的值。不幸的是,我们的 AOI 比场景总体更有云!
# crop area of interest
cloud_aoi = cloud_mask[y-h:y+h,x-h:x+h]
# calculate percentage of cloud cover
per_cloud = np.average(1-cloud_aoi)*100
print('Percentage of clouds: {:.2f}%'.format(per_cloud))
去除云层
一旦我们有了掩膜,去除不需要的像素就很简单。我们只需将 RGB 图像与掩膜相乘。下面的代码将去除被标记为云层的像素(第 2 行)。请记住,RGB 图像是一个 3D 数组。相比之下,cloud_mask是一个 2D 数组。这就是为什么我们使用np.newaxis来添加额外的维度。
# Remove clouds
rm_clouds = RGB*cloud_mask[:, :, np.newaxis]
你可以在图 5中看到结果。云层像素现在全部被值[0,0,0]替代。这与用于边界框的值相同。这意味着这些像素“没有数据”,可以在进一步分析中忽略。最终,这些分析的结果依赖于云掩膜的准确性。这就是为什么研究人员不断寻求改进这些掩膜。
图 5:使用 Landsat 云掩膜从卫星图像中去除云层(来源:作者)
云掩膜的机器学习方法
我们提到 Landsat 使用一种称为CFMask的算法。该算法基于遥感和气象知识,了解云层像素在卫星图像中的表现。另一种方法是使用机器学习。这需要训练图像,其中所有云层像素都被手动标记。然后,机器学习算法学习光谱波段和标签之间的关系。
一种方法使用在海岸气溶胶(B1)、红色(B4)和短波红外 2(B7)波段上训练的遗传算法。得到的算法在cloud_pred函数中给出。pred公式(第 4-10 行)给出了三个波段的加权。如果加权值大于 B7 波段,则像素被标记为云层(第 12 行)。
def cloud_pred(B1,B4,B7):
"""Cloud prediction model"""
pred = 2.16246741593412 -0.796409165054949*B4 + \
0.971776520302587*np.sqrt(abs(0.028702220187686*B7*B1 + \
0.971297779812314*np.sin(B1))) + \
0.0235599298084993*np.floor(0.995223926146334*np.sqrt(abs(0.028702220187686*B7*B1 + 0.971297779812314*np.sin(B1)))+ \
0.00477607385366598*abs(0.028702220187686*B7*B1 + \
0.971297779812314*np.sin(B1))) - 0.180030905136552*np.cos(B4) + \
0.0046635498889134*abs(0.028702220187686*B7*B1 + 0.971297779812314*np.sin(B1))
cloud = np.where(pred > B7,0,1)
return cloud
要使用此功能,我们首先加载三个必要的波段(第 4-6 行)。通过将这些波段传递给cloud_pred函数(第 9 行),我们获得云掩膜。然后,我们可以像以前一样使用这个掩膜去除云层(第 12 行)。你可以在图 6中看到结果图像。
ID = "LC09_L2SP_175083_20230410_20230412_02_T1"
# Get Coastal Aerosol (B1), Red (B4) and Shortwave Infrared 2 (B7) bands
B1 = tiff.imread('./data/{}/{}_SR_B1.TIF'.format(ID, ID))
B4 = tiff.imread('./data/{}/{}_SR_B4.TIF'.format(ID, ID))
B7 = tiff.imread('./data/{}/{}_SR_B7.TIF'.format(ID, ID))
# Get cloud mask
cloud_mask_2 = cloud_pred(B1,B4,B7)
# Remove clouds
rm_clouds = RGB*cloud_mask_2[:, :, np.newaxis]
图 6:使用机器学习方法从卫星图像中去除云层(来源:作者)
研究人员表示上述方法的准确度为89%——与 CFMask 相似。如前所述,提升算法的准确度是一个活跃的研究领域。最先进的方法使用深度学习算法进行图像分割,如 U-Net。这些方法的缺点是需要准确标记的训练数据。
云层修复
对于某些应用,仅仅从图像中去除云层可能就足够了。对于其他应用,我们希望替换或“修复”云层。例如,为了创建没有云层的景观可视化。
在图 7中,我们可以看到我们一直在处理的有云图像。这张图像拍摄于 2023 年 10 月 4 日。幸运的是,8 天后在同一区域拍摄了一张清晰图像。我们将探讨如何用清晰图像中的像素(第 7 行)替换有云图像中的像素(第 6 行)。
# IDs for cloudy and non-cloudy images
IDs = ['LC09_L2SP_175083_20230410_20230412_02_T1',
'LC08_L2SP_175083_20230418_20230429_02_T1']
# Get RGB
cloudy = get_RGB(IDs[0])
clear = get_RGB(IDs[1])
图 7:拍摄于 8 天前的有云图像和清晰图像(来源:作者)。
我们将对关注区域和整个场景执行此操作。在这两种情况下,考虑像素的地理位置非常重要。即使有云和清晰图像拍摄于相同区域,像素也不会完全对齐。换句话说,数组位置 [x,y] 的像素的 UTM 坐标在每个图像中都会有所不同。因此,如果仅依赖数组位置,我们将替换错误区域的像素。为了解决这个问题,我们基于本文中使用的方法:
如何在 Landsat 卫星图像上用 Python 绘制坐标
使用 Landsat 元数据和 Rasterio 将像素位置映射到地理坐标
towardsdatascience.com
关注区域
首先,我们将学习如何修复一个关注区域(AOI)。我们从创建一个包含所有 4 种覆盖类型的掩码开始(第 2 行),并从有云图像中移除不需要的像素(第 5 行)。然后,我们围绕我们的关注区域裁剪得到的图像(第 8-9 行)。我们将 (x,y) 视为有云图像中关注区域的中心。
# get mask for cloudy image
mask = cloud_mask*shadow_mask*dilated_cloud_mask*cirrus_mask
# remove cloudy pixels and fill with adjusted clear pixels
rm_mask = cloudy*mask[:, :, np.newaxis]
#crop area of interest
x,y,h = 4500,4500,300
crop_rm_mask = rm_mask[y-h:y+h,x-h:x+h]
如前所述,我们需要考虑像素的地理位置。我们使用 rasterio 包来帮助实现这一点。我们加载有云图像(第 4 行)和清晰图像(第 5 行)的红色波段。
import rasterio as rio
# Open red band with rasterio for geolocation
geo_cloudy = rio.open('./data/{}/{}_SR_B4.TIF'.format(IDs[0], IDs[0]))
geo_clear = rio.open('./data/{}/{}_SR_B4.TIF'.format(IDs[1], IDs[1]))
我们在有云图像中找到关注区域中心的 UTM 坐标(第 2 行)。然后,我们在清晰图像中找到这些 UTM 坐标的数组位置(第 5 行)。打印这些坐标得到(4430, 4510)。你可以看到这与原始中心(4500,4500)不同。
# get UTM coordinates from cloudy image
utmx, utmy = geo_cloudy.xy(y,x)
# get pixels from clear image
y_,x_ = geo_clear.index(utmx, utmy)
print(x_,y_)
我们使用新中心从清晰图像(第 2 行)裁剪关注区域。然后,我们使用原始中心(第 5 行)裁剪有云图像的云掩码。crop_clear和crop_mask的 UTM 坐标将完美对齐。因此,我们通过将清晰图像与掩码的反向相乘(第 6 行)来创建填充。最后一步是将填充添加到无云的有云图像中(第 9 行)。
# crop clear image
crop_clear = clear[y_-h:y_+h,x_-h:x_+h]
#get fill from clear image
crop_mask = mask[y-h:y+h,x-h:x+h]
fill = crop_clear*(1-crop_mask[:, :, np.newaxis])
#inpaint area of interest
inpaint = crop_rm_mask + fill
你可以在图 8中看到结果。效果不错!大部分景观看起来是连贯的。不过,也有一些问题,我们将在最后讨论。
图 8:使用清晰图像修复的关注区域(来源:作者)
整个图像
对整个场景进行修复更加复杂。为此,我们需要对齐多云和晴朗场景的边界框。边界框是场景的外部黑色边框。我们通过从晴朗图像中添加和移除一些像素来对齐这些框的 UTM 坐标。
图 9:添加和移除像素以对齐边界框
我们首先计算从晴朗图像的顶部、底部、左侧和右侧添加/移除的像素数量。我们获取晴朗图像的左上角(UL)(第 3 行)和右下角(LR)(第 4 行)的位置。这些位置基于晴朗图像的尺寸(第 2 行)。
# get pixel coordinates of clear image corners
y,x = geo_clear.read(1).shape
clear_ul = (0,0) # upper left
clear_lr = (y,x) # lower right
接下来,我们从多云图像中获取边界(第 2 行)。这些边界提供了多云图像边界框的 UTM 坐标。然后,我们从晴朗图像中获取这些坐标的数组位置(第 3–4 行)。请注意,这些位置可能会超出晴朗图像的尺寸。
# get pixel coordinates of cloudy image corners
cloudy_bounds = geo_cloudy.bounds
new_clear_ul = geo_clear.index(cloudy_bounds.left,cloudy_bounds.top)
new_clear_lr = geo_clear.index(cloudy_bounds.right,cloudy_bounds.bottom)
原始和新角落位置之间的差异将告诉我们如何调整晴朗图像以对齐 UTM 坐标。打印调整量(第 8 行)给出 -10 0 70 -70。这意味着我们必须将晴朗图像调整为:
-
从顶部移除 10 像素
-
底部无调整
-
向左添加 70 像素
-
从右侧移除 70 像素
# calculate pixel adjustment
top_adj = clear_ul[0] - new_clear_ul[0]
bottom_adj = new_clear_lr[0] - clear_lr[0]
left_adj = clear_ul[1] - new_clear_ul[1]
right_adj = new_clear_lr[1] - clear_lr[1]
print(top_adj, bottom_adj, left_adj, right_adj)
在计算调整量后,我们可以使用adjust_rgb函数。这会在正调整情况下添加黑色像素(第 6–17 行)。然后,它会在负调整情况下裁剪像素(第 20–27 行)。
def adjust_rgb(rgb,top_adj, bottom_adj, left_adj, right_adj):
adj_rgb = rgb.copy()
#Adding black pixels
if top_adj > 0:
add_top = np.zeros((top_adj,rgb.shape[1],3))
adj_rgb = np.vstack((add_top,adj_rgb))
if bottom_adj > 0:
add_bottom = np.zeros((bottom_adj,rgb.shape[1],3))
adj_rgb = np.vstack((adj_rgb,add_bottom))
if left_adj > 0:
add_left = np.zeros((rgb.shape[0],left_adj,3))
adj_rgb = np.hstack((add_left,adj_rgb))
if right_adj > 0:
add_right = np.zeros((rgb.shape[0],right_adj,3))
adj_rgb = np.hstack((adj_rgb,add_right))
#Removing pixels
if top_adj < 0:
adj_rgb = adj_rgb[-top_adj:,:,:]
if bottom_adj < 0:
adj_rgb = adj_rgb[:bottom_adj,:,:]
if left_adj < 0:
adj_rgb = adj_rgb[:,-left_adj:,:]
if right_adj < 0:
adj_rgb = adj_rgb[:,:right_adj,:]
return adj_rgb
我们使用此函数来调整我们的晴朗图像(第 6 行)。我们从多云图像中移除不需要的像素(第 12 行),然后从调整后的晴朗图像中获取替换像素(第 13 行)。由于它们已经对齐,我们可以简单地将它们叠加在一起(第 14 行)。
# Get RGB images
cloudy_RGB = get_RGB(IDs[0])
clear_RGB = get_RGB(IDs[1])
# Adjust clear RGB image
clear_RGB_adj = adjust_rgb(clear_RGB,top_adj, bottom_adj, left_adj, right_adj)
# get mask for cloudy image
mask = cloud_mask*shadow_mask*dilated_cloud_mask*cirrus_mask
# remove cloudy pixels and fill with adjusted clear pixels
rm_mask = cloudy_RGB*mask[:, :, np.newaxis]
fill_mask = clear_RGB_adj*(1-mask[:, :, np.newaxis])
inpaint = rm_mask+fill_mask
你可以在图 10中查看结果。总体来看,景观效果不错。如果我们放大关注区域,你会发现它与图 8完全相同。然而,如前所述,还是存在一些问题。
图 10:使用晴朗图像修复的整个场景(来源:作者)
如果我们在另一个兴趣区域放大,我们可以看到这一点。透明的白色像素是被晴朗图像中的像素替换的。注意,一些阴影像素尚未被替换。这可能会导致景观中亮度变化剧烈的地方出现不连续性。
图 11:未成功修复的云影(来源:作者)
问题在于云掩模并不完美。如果我们无法以 100%的准确度识别所有云层,就无法完全替换所有云层。景观随时间变化也可能导致不连续性。河流会移动,潮汐会上涨。为了创建最佳的可视化效果,选择拍摄时间接近你的多云图像的替换图像。
希望你喜欢这篇文章!你可以在Mastodon | Twitter | YouTube | Newsletter 上找到我 — 注册即可免费获取Python SHAP 课程
[## 使用我的推荐链接加入 Medium — Conor O’Sullivan
作为 Medium 会员,你的一部分会费将用于支持你阅读的作者,你将能完全访问所有故事……
参考文献
[1] Steve Foga 等,云检测算法比较与验证,用于操作 Landsat 数据产品,www.sciencedirect.com/science/article/pii/S0034425717301293?via%3Dihub
Landsat 8–9 OLI/TIRS 集合 2 级 2 数据格式控制手册 www.usgs.gov/media/files/landsat-8-9-olitirs-collection-2-level-2-data-format-control-book
Leah Wasser 第 3 课:在 Python 中清理遥感数据 — 云、阴影与云掩模 www.earthdatascience.org/courses/use-data-open-source-python/multispectral-remote-sensing/landsat-in-Python/remove-clouds-from-landsat-data/
Leah Wasser 第 4 课:如何在 Python 中用不同的栅格数据集的值替换栅格单元值 www.earthdatascience.org/courses/use-data-open-source-python/multispectral-remote-sensing/landsat-in-Python/replace-raster-cell-values-in-remote-sensing-images-in-python/
Landsat 任务 CFMask 算法 www.usgs.gov/landsat-missions/cfmask-algorithm
Ing Grenet Landsat-8 图像中的云分割 medium.com/sentinel-hub/clouds-segmentation-in-landsat-8-images-da370815235
SentinelHub Landsat 8 云分割脚本 custom-scripts.sentinel-hub.com/custom-scripts/landsat-8/clouds_segmentation/
替换视觉 AI 模型中的手动归一化为批量归一化
一个巧妙的技巧是,将批量归一化层作为模型的第一层,以避免在视觉(图像/视频)AI 模型中进行昂贵的手动像素归一化
·
关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 5 月 18 日
--
与 Naresh 共同撰写
照片由 Kevin Ku 拍摄,发布在 Unsplash 上
图像预处理中的通道归一化
彩色图像通常有 3 个通道(RGB)。视觉 AI 模型通常会对图像像素进行预处理和归一化,以便将给定通道中的像素归一化到均值为 0.0 和方差为 1.0。由于每个通道可以有自己独特的统计数据,所以归一化是按通道进行的。批量归一化是视觉模型中用于避免被称为协变量转移现象的一种通用最佳实践。
什么是协变量转移?
协变量转移是一种现象,当输入特征(即协变量)的分布在机器学习模型的训练阶段和测试阶段之间发生变化时,就会发生这种现象。这可能导致模型性能下降,因为在训练期间做出的假设在测试期间可能不再成立。协变量转移可能由于数据收集过程的变化、被采样的人群的变化或模型使用环境的变化而发生。为了解决协变量转移,可以使用领域适应和重要性加权等技术,以根据输入分布的变化调整模型的预测。
然而,这些技术相当复杂,需要对输入数据分布有更深入的理解。
批量归一化如何避免协变量转移?
批量归一化通过规范化神经网络中每一层的激活值来帮助解决协变量转移。这意味着每一层的激活值的均值和方差保持在一个固定的值上,而不受该层输入分布的影响。通过这样做,批量归一化减少了训练数据集和测试数据集之间的协变量转移影响。批量归一化可以应用于任何领域,并且能够很好地扩展到各种使用案例而无需修改。
更具体地说,在训练过程中,批量归一化根据当前批次计算的激活均值和方差对每一层的激活值进行中心化和缩放。这将激活值归一化为零均值和单位方差,这有助于稳定和加速训练过程。在训练阶段,跟踪一个运行均值和方差。
在测试过程中,使用训练期间计算的均值和方差来归一化激活值。这确保了归一化与训练数据一致,并减少了协变量转移的影响。
通过减少协变量转移的影响,批量归一化可以提高神经网络的泛化性能,使其对输入分布的变化更加稳健。
为什么要预处理输入?
这样做是为了让模型更快地收敛,并对输入数据有良好的泛化能力。当输入数据分布在一致且指定良好的范围内时,模型的表现相对更好。具体来说,它有以下好处:
-
防止溢出或下溢:机器学习模型通常涉及诸如加法、乘法和指数运算等数学操作。如果输入值过大或过小,操作可能会导致溢出或下溢,从而产生不准确或未定义的结果。例如,将一个小的浮点数(fp32)加到一个大数上可能会导致忽略小数。
-
确保高效学习:神经网络通常使用反向传播算法来更新网络中的权重。该算法依赖于计算梯度,如果输入值没有归一化,这些梯度可能会变得非常小或非常大,从而使学习过程变得更慢且效率降低。
-
提高泛化能力:当模型在特定范围的输入值上训练时,它可能无法在该范围之外的输入上表现良好。归一化输入可以帮助模型对新数据和未见过的数据进行泛化。
批量归一化可以通过在训练过程中归一化网络每一层的输入来解决一些这些问题。这可以提高模型的稳定性,并使其对输入的尺度不那么敏感。然而,批量归一化并不总是可以完全替代使用硬编码常数手动完成的输入归一化。
什么是批量归一化?
关于批量归一化是什么及其如何帮助模型的主题已经在众多文章中进行了广泛的讨论,因此我们将链接到提供最详细见解的文章,并让读者对其操作形成直观理解。我们还提供了一些将批量归一化与其他归一化技术进行比较的链接。
-
三层理解的批量归一化
-
批量归一化解释
输入归一化通常是如何完成的?
通常,训练模型的人员负责计算整个训练数据集的每个通道的统计数据(均值和方差),并在训练视觉 AI 模型之前进行归一化。这种归一化也应该在推理过程中进行。
使用torchvision transforms时,这种预处理的代码可能如下所示。
transforms = torch.nn.Sequential(
transforms.CenterCrop(10),
transforms.Normalize(
# Channel means
# R, G, B
(0.485, 0.456, 0.406),
# Channel standard deviation
(0.229, 0.224, 0.225),
),
)
你可以在这里阅读更多关于Normalize 变换的信息。
你会注意到均值和标准差是预先计算的,然后硬编码到预处理管道中。关于如何正确和高效地做这件事,网上有很多讨论。例如,
如何减少这项工作的痛苦?
现在我们知道这个过程有多痛苦,我们将看到一个巧妙的技巧来减轻你的痛苦!
只需将一个BatchNorm2d层作为视觉 AI 模型的第一层,并从预处理步骤中移除Normalize变换!
通过将批量归一化作为模型的第一层,输入数据将在训练过程中自动归一化,你不需要手动归一化图像像素。这可以节省一些编码时间,并减少在归一化过程中引入错误的机会。
手动归一化和批量归一化的粗略等价性
在这里,我们将看到一些代码,可以说服我们这两种方法的粗略等价性。
我们将使用 1000 批随机生成的 1x1 图像(具有 3 个通道),看看手动计算的均值和方差是否与使用 PyTorch 的 BatchNorm2d 层计算的结果相似。
torch.manual_seed(21)
num_channels = 3
# Example tensor so that we can use randn_like() below.
y = torch.randn(20, num_channels, 1, 1)
model = nn.BatchNorm2d(num_channels)
# nb is a dict containing the buffers (non-trainable parameters)
# of the BatchNorm2d layer. Since these are non-trainable
# parameters, we don't need to run a backward pass to update
# these values. They will be updated during the forward pass itself.
nb = dict(model.named_buffers())
print(f"Buffers in BatchNorm2d: {nb.keys()}\n")
stacked = torch.tensor([]).reshape(0, num_channels, 1, 1)
for i in range(2000):
x = torch.randn_like(y)
y_hat = model(x)
# Save all the input tensor into 'stacked' so that
# we can compute the mean and variance later.
stacked = torch.cat([stacked, x], dim=0)
# end for
print(f"Shape of stackend tensor: {stacked.shape}\n")
smean = stacked.mean(dim=(0, 2, 3))
svar = stacked.var(dim=(0, 2, 3))
print(f"Manually Computed:")
print(f"------------------")
print(f"Mean: {smean}\nVariance: {svar}\n")
print(f"Computed by BatchNorm2d:")
print(f"------------------------")
rm, rv = nb['running_mean'], nb['running_var']
print(f"Mean: {rm}\nVariance: {rv}\n")
print(f"Mean Absolute Differences:")
print(f"--------------------------")
print(f"Mean: {(smean-rm).abs().mean():.4f}, Variance: {(svar-rv).abs().mean():.4f}")
你可以查看下面代码单元的输出。
Buffers in BatchNorm2d: dict_keys(['running_mean', 'running_var', 'num_batches_tracked'])
Shape of stackend tensor: torch.Size([40000, 3, 1, 1])
Manually Computed:
------------------
Mean: tensor([0.0039, 0.0015, 0.0095])
Variance: tensor([1.0029, 1.0026, 0.9947])
Computed by BatchNorm2d:
------------------------
Mean: tensor([-0.0628, 0.0649, 0.0600])
Variance: tensor([1.0812, 1.0318, 1.0721])
Mean Absolute Differences:
--------------------------
Mean: 0.0602, Variance: 0.0616
我们从一个使用 torch.randn_like()
初始化的随机张量开始,因此我们期望在足够大的样本数量(40k)中,均值和方差将趋近于 0.0 和 1.0,因为这正是我们期望torch.randn_like()生成的结果。
我们看到手动计算的均值和方差与使用 BatchNorm2d 的滚动平均法计算的均值和方差之间的差异在所有实际应用中都足够接近。我们可以看到,使用 BatchNorm2d 计算的均值始终高于或低于手动计算的均值(最多 40 倍)。然而,从实际的角度来看,这应该没有问题。
注意事项
使用批量归一化代替手动归一化确实有利有弊,这篇文章如果不详细比较这两者在何时适用或不适用的话就不完整了。
迁移学习
在使用迁移学习时,通常建议保持预训练模型中使用的归一化方法,以避免引入不必要的变化。在这种情况下,用批量归一化替换手动归一化可能不合适。
例如,以下是 torchvision 页面对这个主题的描述。
“在使用预训练模型之前,必须对图像进行预处理(调整正确的分辨率/插值、应用推理变换、重新缩放值等)。没有标准的方法来做到这一点,因为这取决于给定模型的训练方式。它可以在模型系列、变体甚至权重版本之间有所不同。使用正确的预处理方法至关重要,否则可能导致准确性下降或输出错误。”
训练效率
在考虑训练效率时,预先计算数据集的均值和方差并将其硬编码为预训练的归一化步骤可能是有益的。这可以防止在训练过程中重复计算这些统计量。
请注意,使用任何方法在推理期间的计算量相等,因为你会使用计算出的均值和方差对输入进行归一化,无论是在将输入喂入模型之前(手动计算)还是作为模型的第一步(使用批量归一化)。
人工效率
基于我们上面的观察,对于人类来说,插入一个批量归一化层比手动计算统计量要容易得多。
数据增强
在将输入送入模型之前进行数据增强时,必须小心在所有增强和其他预处理步骤完成后再应用归一化,以避免计算出不正确的统计量。例如,如果你使用了ColorJitter 变换,它将会显著改变计算出的统计量。
这引出了另一个有趣的问题“在使用数据增强时,应该什么时候计算数据集的均值和方差?”
这有点棘手,因为准确地跨增强输入计算均值和方差需要你预先知道哪些图像将使用哪些转换进行增强,然后需要你在模型训练期间以一致的方式计算统计信息并应用增强。这在一般情况下有些困难,因为增强是随机应用在输入图像上的。此外,同一图像在每个训练周期会以不同方式进行增强,以防止模型过度拟合训练数据集。因此,在这种情况下,使用批量标准化也会提高模型的准确性,因为它计算的是增强图像的均值和方差,而不是原始未增强的图像。
小批量分布
由于批量标准化在训练期间计算小批量的均值和方差以进行归一化值,因此重要的是随机化输入数据以确保小批量的均值和方差在某种程度上代表整个训练数据集。如果你的小批量数据很小或有偏差,那么你应该考虑消除偏差或使用手动标准化。
这个问题在测试或推理时间不存在。
结论
标准化是任何视觉 AI 管道中的重要预处理步骤。通过手动步骤正确并高效地计算它可能会很繁琐且容易出错。在许多情况下,将批量标准化作为视觉 AI 模型的第一层是手动标准化的可行替代品。
Reprompting: LLMs 的自动化问题解决优化
通过 Gibbs 采样的自动链式思维提示推断
·
关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 5 月 30 日
--
Reprompting 人物化 — 作者 + 开放之旅
大型语言模型(LLMs)在自然语言理解方面通过少量示例提示展示了惊人的进展,这些模型在仅看到几个演示如何解决特定问题的示例的情况下,完成了极其困难的任务。然而,同样的 LLMs 在需要复杂或多步骤逻辑的任务(例如,Big-Bench Hard 基准测试)时经常出现问题,并且在将规则或约束传递到后续步骤时遇到困难。对于人类而言,这类任务需要逻辑推理和推断。尽管我们理解这些模型在(人类意义上)无法进行推理,但微软的研究人员希望教会 LLMs 越来越好地展现这些概念。因此,Xu 等人提出了“Reprompting”,一种用于多步骤问题解决的自动化提示优化方法。
以工程提示为主题的先前研究表明,向 LLMs 提供链式思维(CoT)提示可以提高推理和感知推理等方面的表现。链式思维提示是一种使大型语言模型通过指导模型的中间步骤来处理复杂的算术和符号推理任务的技术(Wei 等人,2022)。
作为 CoT 的演变,本研究引入了 Reprompting,这是一种迭代采样算法,可以从给定的问题-答案对集(即,少量示例中的语境示例)中自动发现对模型最有效的 CoT 提示。该研究有望提高最先进的 LLMs 的性能,并将一个模型的收益转移到下一个模型上(Xu 等人,2023)。然而,在深入探讨 Reprompting 之前,我们应该突出几个促成这一新颖方法的概念。
少量示例提示
在实践中,少量示例提示(即,语境学习)的概念很简单。提供包含问题及其相应正确答案的示例提示,可以让模型同时更好地学习给定的语境和答案的制定。因此,LLMs 提高了泛化能力,并能更高效地适应新任务,相较于传统的(且通常成本高昂的)微调(即,额外的监督模型训练)需要较少的输入和监督。
少量示例提示的一个非常简单的例子 — 作者
标准 LLM 经过预训练,以优化在给定语境的情况下生成正确下一个标记(词或子词)的概率(Brown 等人,2020)。一般来说,模型学习在给定语境 x 的情况下,下一个标记 y 的近似概率分布 P(y|x)。
此外,模型可以基于包含示例问题和答案对的标记化序列进行条件设置。然后,在推理过程中,模型使用其学习到的参数 θ 通过条件设置示例Exmp来生成输出的标记序列 y*:
P(y_t | y_1, …, y_(t-1), Exmp; θ)
其中 y_t 是在之前生成的标记 (y_1, …, y_(t-1)) 和示例序列 (Exmp) 的条件下,输出中 t-th 标记的概率分布。通常,在推理过程中,自回归变换器会在每一步从分布中采样一个标记 y_t,并且该过程重复(逐标记)直到模型生成一个停止标记或达到预定义的最大输出长度,从而生成一个应应用提供示例中学习到的上下文和约束的响应。(Wei et al., 2022; Vaswani et al., 2017; Xu et al., 2023)。
链式思维提示
链式思维提示演变了少量提示的概念,专注于需要多步逻辑的任务,通过引导模型向一系列中间逻辑步骤的序列发展。这种方法模拟了类人问题解决和某种程度上的常识推理(Wei et al., 2022)。例如,每个生成的标记 y_t 现在解决成为回答正确所需的更大公式的一部分。这使得模型能够更高效地解决给定的问题以及类似的问题。应用 CoT 的推理的一个简化公式如下:
P(y_t | y_1, …, y_(t-1), {Exmp_1, Exmp_2, …, Exmp_N}; θ)
其中模型通过对连接的示例标记化序列 {Exmp_1, Exmp_2, …, Exmp_N} 进行条件处理来生成标记 y_t,每个序列包含不同的中间步骤(如图所示)。
标准少量提示与链式思维提示(Wei et al., 2022)
Reprompting
有了这些背景,我们可以讨论提出的 Reprompting,一种迭代采样算法,它在没有人工干预的情况下自动发现有效的 CoT 提示。该算法的主要目标是推断出一组“配方”,这些配方作为少量示例可以始终表现良好,用于解决通常需要演绎推理的问题。
研究人员主要关注从链式思维配方的联合分布中重新采样的问题。记住,在推理过程中,模型在每一步从概率分布中采样下一个标记 y_t,直到达到停止条件。然而,使用 CoT 时,模型现在从结合了学习到的概率和 CoT 提供的上下文信息的联合概率分布中进行采样。虽然直接描述这种分布是不可能的,但研究人员采用 Gibbs 采样策略来有效地近似它(Wei et al., 2022)。通过这种方式,采样过程现在可以受到之前生成的标记和设计来指导后续标记生成的提示的影响。每次迭代,算法都优化来自训练集的解决方案,这些解决方案作为有效的 CoT 配方来解决测试集中的问题。
关于 Gibbs 采样的补充说明
吉布斯采样器(1984 年提出)为获取边际分布特征(例如,均值或方差)提供了替代方法,当直接计算复杂时。例如,给定联合分布 f(x, y, …, y_n),吉布斯采样器生成 f(x) 的样本,而无需其显式形式。在生成足够大的样本后,吉布斯策略可以在不直接计算 f(x) 的情况下近似边际分布(Casella & George, 1992)。
CoT 配方的自动发现
Reprompting 使用吉布斯采样来近似 CoT 配方的联合分布,这些配方在需要逻辑推理的人类解决问题时表现良好。该过程最初通过零-shot 提示采样配方,然后通过将几个先前的配方串联作为提示,迭代采样配方,最终收敛为一组具有相似思维链的配方,并包括问题的中间指令或逐步公式。Xu 等对该算法的描述如下:
Reprompting 算法(Xu 等,2023)
理想情况下,算法应收敛,使得生成逐步解决方案 z_j 后接着生成正确答案 y_j 的概率较高,并且与 S_j 的选择无关;其中 S_j 是选择的索引子集,对应于 CoT 配方元组 {x_i, z_i, y_i}。
pLLM(z_j, y_j | {x_i, z_i, y_i}_i ∈ S_j, x_j, m)
这将产生一组 {z_j},作为解决测试集类似问题的提示(Xu 等,2023)。
模型组合
此外,Reprompting 通过利用不同的 LLMs 进行初始化和采样,促进了模型的组合。经验表明,使用 ChatGPT 生成 InstructGPT 的初始配方样本,相较于单独使用 InstructGPT 或 ChatGPT,能显著改善特定任务的表现。然而,结果也表明,某一模型的高效 CoT 配方在另一模型上表现可能较差,尽管后者在使用人类优化提示时表现类似。这表明 CoT 配方必须在考虑模型组合的情况下进行构建。
基准结果
与之前的最先进提示技术相比,Reprompting 的表现更佳,且无需人工干预。举例来说,Reprompting 结合 ChatGPT 通常在所有任务上比人工编写的 CoT 提示取得更高分数(Suzgun 等,2022)。
吉布斯采样,记作 RePrS(Wei 等,2022)
在实践中,我们可以通过 Reprompting 观察 CoT 配方的演变:
Reprompting 进化了 CoT 配方,以更好地解决复杂问题(Wei 等,2022)
最初,ChatGPT 优先考虑约束条件,首先关注绝对排名位置(用深蓝色表示)。接下来,模型尝试推断特定位置的对象,但犯了一个错误(用红色表示)。然而,这个配方仍然为解决类似问题提供了有用的策略。当应用于一个新问题时,模型采用相同的重新排序策略,并提出了一种处理约束的替代方法(用橙色表示)。尽管存在一些错误,但这个配方改进了该特定问题的解决方案。最后,当作为新提示使用时,模型遵循相同的公式,并正确推断出新问题的答案。
Reprompting 的引入可能标志着大型语言模型发展的又一个里程碑,特别是在需要多步骤逻辑和约束传播的任务中。通过利用链式思维提示和 Gibbs 采样,Reprompting 可以在没有人工干预的情况下自动发现有效的 CoT 提示。因此,与零-shot 或传统的 few-shot 提示技术相比,LLMs 在复杂任务中的表现可以得到提升。此外,经过优化后,Reprompting 展示了在不同 LLM 之间传递收益的潜力。最终,这一方法可能使我们更接近实现表现出类似人类逻辑推理和推理能力的 LLM 的目标。
参考文献
Brown TB, Mann B, Ryder N, Subbiah M, Kaplan J, Dhariwal P, Neelakantan A, Shyam P, Sastry G, Askell A, 等. 2020. 语言模型是少-shot 学习者。arXiv [csCL]. arxiv.org/abs/2005.14165
Casella G, George EI. 1992. 解释 Gibbs 采样器。Duke.edu. [访问日期:2023 年 5 月 29 日]。www2.stat.duke.edu/~scs/Courses/Stat376/Papers/Basic/CasellaGeorge1992.pdf
Geman S, Geman D. 1984. 随机松弛、Gibbs 分布与贝叶斯图像恢复。IEEE Trans Pattern Anal Mach Intell. PAMI-6(6):721–741. doi:10.1109/tpami.1984.4767596. [访问日期:2023 年 5 月 29 日]。image.diku.dk/imagecanon/material/GemanPAMI84.pdf
Suzgun M, Scales N, Schärli N, Gehrmann S, Tay Y, Chung HW, Chowdhery A, Le QV, Chi EH, Zhou D, 等. 2022. 挑战 BIG-Bench 任务以及链式思维是否能解决它们。arXiv [csCL]. arxiv.org/abs/2210.09261
Vaswani A, Shazeer N, Parmar N, Uszkoreit J, Jones L, Gomez AN, Kaiser L, Polosukhin I. 2017. 注意力机制即一切。arXiv [csCL]. arxiv.org/abs/1706.03762
Wei J, Wang X, Schuurmans D, Bosma M, Ichter B, Xia F, Chi E, Le Q, Zhou D. 2022. 链式思维提示引发大型语言模型的推理。arXiv [csCL]. arxiv.org/abs/2201.11903
Xu W, Banburski-Fahey A, Jojic N. 2023. Reprompting:通过吉布斯采样进行自动化链式思维提示推断。arXiv [csLG]. arxiv.org/abs/2305.09993
研究人道主义灾难情况报告聊天机器人 — 使用 GPT-4-Turbo 和完整上下文提示
·
跟进 发布在Towards Data Science ·16 分钟阅读·Nov 15, 2023
--
TL;DR
在本文中,我们探索了 OpenAI 全新的 GPT-4-Turbo 模型,利用其增加的 128k 令牌上下文窗口来传递完整文档语料库以进行信息检索。这种粗略的蛮力方法——仅在更大的上下文窗口下才可能——实现简单,不需要文档嵌入和语义搜索,如检索增强生成(RAG)中使用的那样。应用于发布在令人惊叹的 ReliefWeb 平台上的人道主义灾难情况报告——使用稀疏引导表示(SPR)进行压缩——我们展示了 GPT-4-Turbo 能够回答有关近期灾难的基本问题。然而,即使在 OpenAI 最近降低了令牌成本的情况下,这种方法仍然代价高昂,提示预览版 GPT-4-Turbo 模型响应非常缓慢,有时需要一分钟才能回复。与所有 LLM 信息检索模式一样,实施验证框架以确保控制虚假信息和信息遗漏当然至关重要。也就是说,GPT-4-Turbo 在能力上迈出了重要一步,尤其是随着性能的提升和成本的下降,它将加入快速扩展的 LLM 工具包。
新的 GPT-4-Turbo 具有 128k 令牌的更大上下文窗口。图像由 GPT-4 + Dall-E-3 生成。
鉴于过去几个月中,autogen、memgpt、Semantic Kernel,以及OpenAI 的 GPTs 和 GPT-4-Turbo的发布,发展步伐非常迅速,我想做一系列文章,比较这些新工具提供的一些对话信息检索技术。每种技术都有其优缺点,有些可能会在我们使用大型语言模型(LLMs)的方式上带来潜在的范式转变。这是一个非常令人惊叹的时刻,但将这些新技术应用于实际场景中,并不像初步探索时所建议的那样简单。
OpenAI 发布 GPT-4-Turbo(预览)
OpenAI 在DevDay 的近期公告中包括了一些可能改变生成 AI 领域的新功能。其中之一是发布了具有 128k 令牌上下文窗口限制的GPT-4-Turbo(预览),相比之前的 32k 令牌。此前,Claude AI 提供了 100k 令牌的最大商业上下文限制,因此 GPT-4-Turbo 在这方面更进一步。此外,OpenAI 还表示,他们的新模型在遵循指令方面更为高效,成本将比 GPT-4 低 3 倍。作为许多 LLM 基准测试的领先者,GPT-4 的任何进展都很重要。
增加模型令牌限制以扩展上下文
那么为什么增加的令牌限制很重要呢?在提示 LLM 时,你可以包括过去的对话,所以一个直接的好处是,如果你在每次提示中提供对话历史,它会记住你之前谈论的内容。这对于引用在早期对话中可能现在重要的事实非常有用。更大的上下文窗口意味着你还可以‘预加载’聊天,加入支持信息,比如文档内容和数据。
但也有一个缺点。
更多的令牌意味着更高的成本和更慢的性能,因为使用变换器架构时,内存和计算需求呈平方增加(远快于简单的直线)。另外,有研究表明较长的上下文窗口可能会降低 LLM 的准确性(刘等,2023)。
使用 GPT-4-Turbo
在写作时,GPT-4-Turbo 仅处于预览模式,仅作为模型 ‘gpt-4–1106-preview’ 提供。调用时我们将使用 openai Python 包,如下 …
import openai
openai.api_key = '<YOUR KEY GOES KEY>'
def run_llm(query, system_prompt, reference_content):
llm_query = {
"temperature": 1.0,
"max_tokens": 2000,
"top_p": 0.95,
"frequency_penalty": 0,
"presence_penalty": 0,
}
response = openai.ChatCompletion.create(
model=model,
messages=[ {
"role":"system",
"content": system_prompt
},
{
"role":"user",
"content": query
}
],
temperature=llm_query['temperature'],
max_tokens=llm_query['max_tokens'],
top_p=llm_query['top_p'],
frequency_penalty=llm_query['frequency_penalty'],
presence_penalty=llm_query['presence_penalty'],
stop=None
)
answer = response['choices'][0]['message']['content']
return answer
分析 ReliefWeb 灾害情况报告
我们将通过使用 GPT-4-Turbo 来分析 令人惊叹的 ReliefWeb 平台上的人道主义灾害情况报告 来探索其增加的令牌限制。这些报告(称为‘Sitreps’)对于监测和应对全球的人道主义灾害至关重要。它们还提供了一个文本数据语料库,可以被压缩(总结)以适应 GPT-4-Turbo 的上下文窗口,以供我们的分析使用。
建立灾害报告语料库
ReliefWeb 有一个 非常棒的 API 用于访问内容,因此我们将使用它来提取灾害和情况报告的列表 …
import requests
import os
from bs4 import BeautifulSoup
import re
import pandas as pd
import PyPDF2
import traceback
import json
import ast
from langchain.text_splitter import RecursiveCharacterTextSplitter
import tiktoken
def auto_translate(text):
"""
This function automatically detects language and translates to english
Parameters:
text(str): The text to be translated
Returns:
text (str): Translated text if in another language, otherwise
input text
"""
try:
lang = translator.detect(text)
lang = lang.lang
print(f"Language detected: {lang}")
q = translator.translate(text, dest='en')
text = q.text
except Exception as e:
print("An exception occurred trying to translate")
return text
def get_safe_name(name):
"""
This function takes a string and returns a version of it that is
safe to use as a filename.
Parameters:
name (str): The string to be converted to a safe filename.
Returns:
name (str): The safe filename.
"""
name = str(name)
name = re.sub("[⁰-9a-zA-Z]+", "_", name)
name = re.sub(r"_$","", name)
if len(name) == 0:
name = 'Unknown'
return name
def download_pdf(url, download_path):
"""
Function to download a PDF from a URL and save locally
Parameters:
url (str): Location of online PDF file
download_path (str): Folder where to save PDF
"""
response = requests.get(url)
with open(download_path, 'wb') as f:
f.write(response.content)
def save_text(content, file_path):
"""
Function to save text to local file
Parameters:
content (str): Text to save
file_path (str): Folder where to save
"""
with open(file_path, 'w') as file:
print(f'Saving {file_path}')
file.write(content)
def extract_text_from_pdf(pdf_path):
"""
Function to extract text from PDF file
Parameters:
pdf_path (str): Path to PDF file
Returns:
text (str): Text extracted from PDF file
"""
print(pdf_path)
pdf_reader = PyPDF2.PdfReader(pdf_path)
text = ''
for page_num in range(len(pdf_reader.pages)):
page_obj = pdf_reader.pages[page_num]
text += page_obj.extract_text()
return text
def get_rw_data(keyword, filter, sort, fields, endpoint, limit=10, \
save_body_to_text=False):
"""
Function to extract data from ReliefWeb API. For API details see:
https://apidoc.rwlabs.org/?utm_medium=blog&utm_source=reliefweb+website&utm_campaign=api+doc+launching+2016_06
Parameters:
keyword (str): Search string
filter (dict): ReliefWeb filter json
sort (dict): ReliefWeb sort json
fields (list): List of fields to return
endpoint (str): API Endpoint, eg reports, disasters
limit (int): Maximum records to return
save_body_to_text (bool) : Flag to save body to text file, including any PDFs on page
Returns:
all_data (pandas dataframe): Dataframe of data from API
"""
query = {
"appname": "myapp",
"query": {
"value": keyword
},
"filter":filter,
"sort": sort,
"limit": limit,
"fields": fields
}
endpoint = f"{reliefweb_api_url}/{endpoint}?appname=apidoc&query[value]="
print(f"Getting {endpoint} ...")
all_data =[]
response = requests.post(endpoint, json=query)
if response.status_code == 200:
data = response.json()
for article in data["data"]:
article_url = article['fields']['url']
try:
r = article['fields']
print(article_url)
article_response = requests.get(article_url)
if save_body_to_text:
soup = BeautifulSoup(article_response.text, 'html.parser')
main_content = [p.text for p in soup.find_all('p')]
article_text = ' '.join(main_content)
save_text(article_text, docs_folder + '/{}.txt'.format(get_safe_name(article['fields']['title'])))
for link in soup.find_all('a'):
href = link.get('href')
if href.endswith('.pdf'):
download_path = os.path.join(docs_folder, href.split('/')[-1])
if href.startswith('/attachments'):
pdf_url = f'{reliefweb_pdf_url}{href}'
else:
pdf_url = href
download_pdf(pdf_url, download_path)
print(f". Downloaded PDF {download_path} from {pdf_url}")
article_text = extract_text_from_pdf(download_path)
r['article_text'] = article_text
r['reliefweb_query'] = keyword
all_data.append(r)
except Exception as e:
print(f"An exception occurred trying to extract {article_url}")
tb_str = ''.join(traceback.format_exception(None, e, e.__traceback__))
print(tb_str)
all_data = pd.DataFrame(all_data)
for f in ['disaster','theme']:
if f in list(all_data.columns):
all_data[f] = all_data[f].astype(str)
return all_data
else:
print(f"Request failed with status {response.status_code} {response.text}")
return None
在上述函数中,一些关注点 …
-
如果 ReliefWeb 内容提到 PDF,我们会从中提取文本
-
任何文本都会通过 Google Translate API 自动翻译成英语
-
我们总是捕捉来源以便归属
-
对于这次快速分析,我们没有处理 API 响应的分页
这是我们调用函数以获取自 2023 年 11 月 1 日以来的灾害情况报告的方法 …
filter = {
"operator": "AND",
"conditions": [
{
"field": "disaster.status",
"value": "ongoing"
},
{
"field": "format.name",
"value": "Situation Report"
},
{
"field": "date.created",
"value": {
"from": "2023-11-01T00:00:00+00:00",
"to": "2023-11-30T23:59:59+00:00"
}
}
]
}
sort = ["date.created:desc"]
endpoint = "reports"
fields = {
"include": ["title", "body", "url", "source", "date", "format", "theme",
"country", "status", "primary_country", "disaster", "language", "id"]
}
reliefweb_query = ""
articles = get_rw_data(reliefweb_query, filter, sort, fields, endpoint, 1000, True)
情况报告被保存到文件系统作为文本文件 …
for index, row in articles.iterrows():
date = row['date']['created']
source = row['source'][0]['name']
title = row['title']
id = row['id']
filename = f'{get_safe_name(title)}__{id}_{get_safe_name(date)}.txt'
text = f'''
title: {title}
source: {source}
date: {date}
id: {id}
{row['article_text']}
'''
text = text.encode('utf-8','ignore').decode('utf-8','ignore')
print(text)
save_text(text, docs_folder + '/' + filename)
使用稀疏引导表示 (SPR) 压缩内容
我们可以直接用原始文本从情况报告中提示 GPT-4-Turbo。然而,那些文本中很多都是不相关的——如停用词、标题和页脚——因此包含所有这些内容会很快超过 GPT-4-Turbo 增加后的令牌限制。我们将使用一种叫做 稀疏引导表示 (SPR) 的技术来压缩文档到其关键事实,并稍作修改以试图保留定量信息。
这是我们将提供给 GPT-4-Turbo 的系统提示,用于压缩我们的文档 …
# MISSION
You are a Sparse Priming Representation (SPR) writer. An SPR is a
particular kind of use of language for advanced NLP, NLU, and NLG
tasks, particularly useful for the latest generation of Large Language
Models (LLMs). You will be given information by the USER which you
are to render as an SPR.
# THEORY
LLMs are a kind of deep neural network. They have been demonstrated
to embed knowledge, abilities, and concepts, ranging from reasoning
to planning, and even to theory of mind. These are called latent abilities and latent content, collectively referred to as latent space. The latent space of an LLM can be activated with the correct series of words as inputs, which will create a useful internal state of the neural network. This is not unlike how the right shorthand cues can prime a human mind to think in a certain way. Like human minds, LLMs are associative, meaning you only need to use the correct associations to "prime" another model to think in the same way. The exception are quantitative metrics, which you MUST preserve.
# METHODOLOGY
Render the input as a distilled list of succinct statements,
assertions, associations, concepts, analogies, and metaphors.
The idea is to capture as much, conceptually, as possible but with
as few words as possible while preserving all quantitative details.
Write it in a way that makes sense to you, as the future audience
will be another language model, not a human.
这将把原始文本状态报告总结成类似于标题为“非洲脊髓灰质炎实验室网络简报(第 1–42 周,2023 年)”的情况报告 …
- Information comes from World Health Organization.
- Weekly dengue cases decreasing, 12,634 reported this week.
- 78 dengue-related deaths this week.
- Cumulative cases since January 2023: 267,680, including 1,333 deaths.
- Cases decreasing in all divisions except Sylhet.
- Increase in dengue cases in Faridpur, Khulna, Rajshahi, Rajbari districts.
- WHO orientation session for entomologists in Dhaka occurred on 26 October.
- Overall Case Fatality Rate (CFR) is 0.5%, but 0.62% for this week.
- Dhaka division reports highest number of cases and deaths.
- In this week, 13,430 people released from hospital after treatment.
- Male-to-female death ratio: 43% vs 57%.
- Most deaths among 16-55 years age group.
- Four dengue virus types circulating: DENV-1 (2.2%), DENV-2 (68.1%), DENV-3 (25.4%), DENV-4 (0.2%).
- Dengue prevention & clean-up week launched nationwide from 29 October to 4 November.
- WHO expert entomologist deployed to Bangladesh for technical support in outbreak response.
- In the Rohingya camps, weekly dengue cases continue to decline. Total cumulative cases are 12,969 including 17 deaths (CFR: 0.1%).
当然,这比原始文件要少得多。
我不建议在没有进行重要分析和检查以控制信息遗漏的情况下使用这种压缩技术,但对于我们的测试来说,这足够了。
这里是用于压缩报告的代码……
# gpt4 turbo is 128k
chunk_size = 100000
llm_query = {
"prompt": "",
"temperature": 1.0,
"max_tokens": 2000,
"top_p": 0.95,
"frequency_penalty": 0,
"presence_penalty": 0,
"system_prompt":"""
# MISSION
You are a Sparse Priming Representation (SPR) writer. An SPR is a
particular kind of use of language for advanced NLP, NLU, and NLG
tasks, particularly useful for the latest generation of Large Language
Models (LLMs). You will be given information by the USER which you
are to render as an SPR.
# THEORY
LLMs are a kind of deep neural network. They have been demonstrated
to embed knowledge, abilities, and concepts, ranging from reasoning
to planning, and even to theory of mind. These are called latent abilities and latent content, collectively referred to as latent space. The latent space of an LLM can be activated with the correct series of words as inputs, which will create a useful internal state of the neural network. This is not unlike how the right shorthand cues can prime a human mind to think in a certain way. Like human minds, LLMs are associative, meaning you only need to use the correct associations to "prime" another model to think in the same way. The exception are quantitative metrics, which you MUST preserve.
# METHODOLOGY
Render the input as a distilled list of succinct statements,
assertions, associations, concepts, analogies, and metaphors.
The idea is to capture as much, conceptually, as possible but with
as few words as possible while preserving all quantitative details.
Write it in a way that makes sense to you, as the future audience
will be another language model, not a human.
"""
}
# Save texts
for index, row in articles.iterrows():
date = row['date']['created']
source = row['source'][0]['name']
report = row['title']
id = row['id']
text = row['article_text']
primary_country = row['primary_country']['name']
disaster = ''
disaster_types = ''
for d in ast.literal_eval(row['disaster']):
disaster += f"{d['name']}; "
for t in d['type']:
if 'primary' in t and t['primary'] == True:
disaster_types += f"{t['name']}; "
d = {
"disaster": disaster,
"date": date,
"disaster_types": disaster_types
}
prefix = ""
filename = f'{get_safe_name(report)}__{id}_{get_safe_name(date)}.txt'
header = f'- report: "{report}"\n- disaster: "{disaster}"\n' + \
f'- disaster_types: "{disaster_types}"\n' + \
f'- primary_country: "{primary_country}"\n- source: "{source}"\n' + \
f'- date: "{date}"\n- id: "{id}"\n'
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=chunk_size, chunk_overlap=100
)
texts = text_splitter.split_text(text)
print(f"\n\n================ {report} =================\n")
print(primary_country)
print(disaster)
print(len(texts))
summarized_text = ''
for i, t in enumerate(texts):
response = openai.ChatCompletion.create(
model=model,
messages=[ {
"role":"system",
"content":llm_query['system_prompt']
},
{
"role":"user",
"content":t
}
],
temperature=llm_query['temperature'],
max_tokens=llm_query['max_tokens'],
top_p=llm_query['top_p'],
frequency_penalty=llm_query['frequency_penalty'],
presence_penalty=llm_query['presence_penalty'],
stop=None)
summary = response['choices'][0]['message']['content']
summarized_text += "\n" + summary
summarized_text = auto_translate(summarized_text)
summarized_text = header + summarized_text
summarized_text = summarized_text.split("\n")
summarized_text_prefixed = ''
for s in summarized_text:
summarized_text_prefixed += f"{prefix}{s}\n"
print(summarized_text_prefixed)
save_text(summarized_text_prefixed, docs_folder2 + '/' + filename)
您会注意到在上述内容中,我们添加了关于报告的一些元数据,以及由 GPT-4-Turbo 返回的 SPR 摘要。然后将压缩的报告保存为文本文件。
提取一个灾害的高级列表
我们还将从 ReliefWeb 中提取一份灾害的高级列表,用作我们系统提示的信息请求辅助……
filter = {
"operator": "AND",
"conditions": [
{
"field": "status",
"value": "ongoing"
},
{
"field": "date.event",
"value": {
"from": "2020-01-01T00:00:00+00:00",
"to": "2023-11-30T23:59:59+00:00"
}
}
]
}
sort = ["date.event:desc"]
endpoint = "disasters"
fields = {
"include": ["name", "description", "date", "url", "id","status","glide"]
}
reliefweb_query = ""
disasters = get_rw_data(reliefweb_query, filter, sort, fields, endpoint, 1000, False)
display(disasters)
disasters.to_csv('disasters.csv')
这给我们提供了一个简明的列表……
使用 ReleiefWeb API 灾难终端提取的灾难列表
为 GPT-4-Turbo 创建一个提示
现在我们有一份灾难和压缩情况报告的列表 — 从 11 月 1 日到 11 月 10 日 — 列出了这些灾难的关键事实。
让我们将它们合并成一个文本文件,作为 GPT-4-Turbo 系统提示的一部分使用……
disasters = pd.read_csv('disasters.csv')
concatenated_content = "=========== this section gives a list of DISASTERS =========== \n\n "+ disasters.to_csv()
concatenated_content += "\n\n=========== this section provides disater REPORTS for each disaster =========== "
for f in os.listdir(docs_folder2):
with open(f"{docs_folder2}/{f}", "r") as file:
file_content = file.read()
concatenated_content += f"\n\n----- report: {f} ----- \n\n"
concatenated_content += file_content + "\n\n"
有多少令牌和成本是多少?
def num_tokens_from_string(string: str, encoding_name: str) -> int:
encoding = tiktoken.get_encoding(encoding_name)
num_tokens = len(encoding.encode(string))
gpt4_token_cost_per_1000 = 0.01
cost = (num_tokens/1000.0)*gpt4_token_cost_per_1000
return num_tokens, cost
tokens, cost = num_tokens_from_string(concatenated_content,"cl100k_base")
OpenAI Tokens: 82001 ($0.82001)
因此,鉴于每 1,000 个令牌输入的成本为\(0.01](https://openai.com/pricing),上面创建的文本每次提示的成本为\),上面创建的文本每次提示的成本为$0.82。还有一些完成令牌成本,每 1,000 个令牌为$0.03,但由于令牌较少,这应该远低于输入成本。
哎呀!
我们知道这种蛮力技术并不是解决我们特定任务的最佳方式,但高成本是另一个原因。
现在我们有了文本,我们可以构建一个系统提示……
def run_llm(query, reference_content):
llm_query = {
"temperature": 1.0,
"max_tokens": 2000,
"top_p": 0.95,
"frequency_penalty": 0,
"presence_penalty": 0,
}
response = openai.ChatCompletion.create(
model=model,
messages=[ {
"role":"system",
"content": f"""You are a ReliefWeb disasters bot. You
provide information on diasters.
If asked for a list of disasters, just provide the list
ONLY using the 'Disasters' section below.
If asked for details of a disaster, extract that information
from 'Situation Reports' section below.
{reference_content}"""
},
{
"role":"user",
"content": query
}
],
temperature=llm_query['temperature'],
max_tokens=llm_query['max_tokens'],
top_p=llm_query['top_p'],
frequency_penalty=llm_query['frequency_penalty'],
presence_penalty=llm_query['presence_penalty'],
stop=None
)
answer = response['choices'][0]['message']['content']
return answer
您会注意到系统提示包括有关在连接文本中使用不同部分的说明。通过一些快速的提示实验 — 我真的不会称之为“工程” — 这样可以产生更好的结果。
测试我们的人道灾难情况报告机器人
好的,开始……
answer = run_llm("What's the latest status regaring ther Nepalese earthquake
this year?", concatenated_content)
print(answer)
The latest status regarding the Nepal earthquake as of November 2023 is as
follows:
- Earthquake: Western Nepal, 6.4 magnitude, occurred on 3 November 2023,
with an epicenter in Ramidanda, Jajarkot.
- Casualties: 154 deaths, 364+ injuries.
- Displacement: Approximately 250,000 people affected, with significant
numbers of homes damaged or destroyed.
- Infrastructure: Over 62,000 houses across 13 districts sustained
damage (26,557 fully, 35,455 partially). Several government buildings were also damaged.
- Humanitarian Response: The United Nations and other humanitarian
organizations have been active in relief efforts, providing medical tents, child-friendly spaces, WASH supplies, and psychosocial support, among other assistance.
- Recovery and Rehabilitation: The recovery and rehabilitation phase is
in planning, with international reinforcements arriving to aid in coordination and strategy development.
- Education and Health Services: Education has been disrupted for
about 80,000 children due to damaged schools, and health services
have been impacted, increasing the risk of malnutrition and complicating maternal and newborn care.
Please note that this is a condensed summary of the available data up until
10 November 2023, and further updates may follow as the situation evolves
and more assessments are made.
总的来说,这是对灾难的一个相当不错的摘要。事实与ReliefWeb 灾难页面(请注意,自这篇文章以来,此页面可能会更改,因为灾难正在发展)和用于本研究的最新情况报告很好地一致。
不过,有一个事实略微不一致,数据显示,不幸的是有 153 人死亡,而模型返回了 154 人。
answer = run_llm("How many people died in the Nepal earthquake in 2023 occur?",
concatenated_content)
print(answer)
In the Nepal Earthquake of November 2023, there were 153 fatalities.
因此,在第一个提示中,“已知”但在第一个提示中产生了幻觉。与 LLM 的所有应用一样,验证和检查对于捕捉此类问题至关重要。
让我们获取一些关于创建情况报告的组织的信息(用于 11 月 1 日至 10 日的期间)……
answer = run_llm("List the organizations providing sitreps for the Nepal
Earthquake", concatenated_content)
print(answer)
The organizations providing situation reports (sitreps) for the Nepal Earthquake include:
- UN Office for the Coordination of Humanitarian Affairs (OCHA)
- Nepal Red Cross Society
- International Federation of Red Cross And Red Crescent Societies (IFRC)
- UN Resident Coordinator in Nepal
- World Health Organization (WHO)
- UN Children's Fund (UNICEF)
- UN Country Team in Nepal
相比于同一查询在 ReliefWeb 中的结果,这个看起来是正确的。
让我们问一些更细致的问题……
answer = run_llm("How many displaced people are children?",
concatenated_content)
print(answer)
In the provided data, there are references to children affected by disasters
in different reports, but the specific number of displaced children is not
directly mentioned in any of the excerpts. It would be necessary to review
more detailed reports or data sources that focus specifically on displacement
figures for children to determine the exact number affected by displacement.
查看情况报告时,我也未能找到这些信息。不过需要进行更详细的反查来确认这一点是否正确。
让我们探讨一下潜在的疾病和健康影响……
answer = run_llm("What disease risks are there for the Nepal 2023 earthquake?",
concatenated_content)
print(answer)
The disease risks for the Nepal 2023 earthquake, as reported by UNICEF in
their Humanitarian Situation Report No. 2, are primarily focused on
malnutrition, particularly among children under five years old, and
psychosocial stress affecting the affected population. The disruption of
healthcare and WASH (Water, Sanitation, and Hygiene) services due to
the earthquake has further exacerbated the health risks.
The key disease risks and health concerns include:
1\. Malnutrition: The earthquake and its aftermath have disrupted the food
supply chain, which can lead to an increased risk of malnutrition among
vulnerable populations, especially children.
2\. Maternal and Newborn Health: Pregnant women and newborns might face
increased health risks due to the disruption of health care services and
a lack of access to essential medical care and facilities.
3\. Psychosocial Stress: The trauma caused by the earthquake, the loss of
loved ones, and the displacement can lead to significant psychosocial
stress and mental health issues.
4\. Injuries and Physical Health: With many injured in the aftermath,
there is an increased need for medical treatment for physical injuries
and trauma care.
5\. Communicable Diseases: Overcrowding in temporary shelters can lead
to the spread of communicable diseases, particularly in the absence of
proper sanitation and hygiene facilities.
6\. Waterborne and Vector-borne Diseases: Lack of clean water and
sanitation increases the risk of waterborne diseases like diarrhea,
cholera, and dysentery. There is also a concern for vector-borne
diseases such as malaria and dengue fever.
7\. Non-communicable Diseases: Disruption of routine health services
can affect the management and treatment of chronic non-communicable diseases.
8\. Disruption of Routine Immunization: The potential disruption of
routine immunization services can increase the risk of vaccine-preventable
diseases.
In response to these health risks, UNICEF and other health sector partners
have been providing health services, medical supplies, and psychosocial
support, as well as ensuring continued access to nutrition for children
and mothers.
从数据中的 UNICEF 报告来看,即使通过我们的 SPF 摘要,上述内容似乎捕捉到了主要观点。然而,我们这里使用的数据只有报告第 1 号,但上述内容提到了报告第 2 号。这是另一个小的错误,再次表明任何 LLM 的回应都需要自动回查。
结论
这是一个非常快速的分析,旨在了解 GPT-4-Turbo — 利用其增加的 128k 上下文窗口和改进的指令跟随能力 — 在包括所有信息检索所需的文本数据作为系统提示时的表现,且没有其他处理。这种蛮力方法是否是我们任务的最佳技术?
可能不是,至少目前还不是。
性能是一个严重的考虑因素。提示处理时间为 30 秒或更长,这对用户体验来说不太理想。
成本也是一个限制因素。每次提示需要更多的令牌来提供完整语料库,这增加了成本——在这次分析中,每次提示需要 $0.82!——因此,其他技术无疑对许多组织更具吸引力,至少在 LLM 成本高昂的情况下。不过,我想起了 存储成本的历史,也许我们会看到 LLM 成本的相同下降。
生成代码以查询数据、使用注册到 LLM 的函数和多代理验证结果等替代方案可能提供更便宜和更准确的选择。它们还将消除为了适应上下文窗口而压缩文档的需要,从而避免信息丢失。
也就是说,我们能够展示,提示 GPT-4-Turbo 可以支持在系统提示中提供的压缩文档语料库上的基本信息检索。这具有非常直接的实施优点——你只需将所有内容提供给语言模型并提问。随着 AI 的进步和成本的降低,这在未来可能会成为一种非常普遍的技术。
使用 LangChain 和 GPT-4 研究多语言 FEMA 灾难机器人
高风险聊天应用程序中检索增强生成(RAG)的优缺点
·
关注 发表在 Towards Data Science ·54 分钟阅读·2023 年 8 月 17 日
--
图片由 DALL-E2 使用提示“一个激流围绕一个机器人泛滥的照片”生成
简要概述
在本文中,我们探讨了如何构建一个多语言的 美国联邦应急管理局 (FEMA) 灾难聊天机器人,以帮助人们准备和应对诸如洪水、龙卷风、野火、地震和冬季风暴等灾难。我们使用 LangChain 和 GPT-4 从 34 个 FEMA PDF 文档中创建了一个聊天界面。尽管这种流行模式非常令人惊叹,但在高风险应用如灾难响应机器人中仍需谨慎。虽然大语言模型 (LLM) 的幻觉现象已被最小化,但文档的语义搜索常见问题可能导致聊天响应中遗漏关键信息。我们测试了一些简单的技术来提高这种特定分析的性能,比如将文档元数据纳入嵌入、通过 LLM 零样本上下文分类丰富用户问题,以及使用谷歌翻译进行自动语言检测和翻译,以更好地支持像斯瓦希里语这样的语言。应用的技术非常基础,但原型机器人的信息提取能力从 FEMA PDF 文档中显示了希望。显然,如果将这种技术用于高风险情况,需进行测试和验证,理想情况下使用自动化的 LLM 技术创建问答验证数据。
几周前,我们在佛蒙特州经历了一次严重的洪水事件。流经我们房子的那条小溪变成了一个愤怒的怪物,狂暴地破坏一切。幸运的是,我们没有受到严重损害,但遗憾的是,许多人失去了财产和生计。在洪水的某个时刻,看起来我们可能需要撤离,于是我开始查看 美国联邦应急管理局 (FEMA) 网站上的建议。我已经有了一些准备,但当麻烦来临时,我想重新确认一些事情。FEMA 的 PDF 文档和网页资源非常优秀且简明,我开始搜索,但我在想……
- 在紧急情况下,有没有比搜索和阅读多个文档更快的方法来获取有用信息?
一个明显的解决方案可能是询问聊天机器人。我认为聊天机器人有时可能会被过度使用,但在时间紧迫的情况下,这确实是一个很好的应用场景,因为对话界面可能更为高效。
这并不是一个新想法,像美国红十字会这样的组织已经开发了类似于Clara的灾害响应机器人。然而,最近出现了一种有前景的新模式,利用生成式人工智能大型语言模型(LLMs),如OpenAI 的 GPT-4、Meta 的 LLAMA 2以及在HuggingFace上不断增长的众多模型。这些模型可以用来索引指定的文档集,并与其进行对话。这种方法称为检索增强生成(RAG),网上有数百个教程展示了使用令人惊叹的LangChain Python 包,我完全预期这种技术很快会出现在我们每天使用的软件中。有趣的是,它将响应限制在提供的内容范围内,从而减少了幻觉的发生,这有助于在灾害响应等关键情况下使用 LLMs。
但在信息检索可能拯救生命的情况下,这种方法的安全性如何呢?
在本文中,我简要探讨了创建一个 LangChain GPT-4 聊天界面,用于基于美国联邦应急管理署(FEMA)的一组文档询问有关灾害安全的问题。我们将遇到一些需要考虑的限制,如果在高风险情况下使用这种技术的话。
FEMA 灾害准备和安全 PDF 文档
对于这项研究,我从 FEMA 下载了 34 个 PDF 文件(列表在这里),这些文件涵盖了广泛的灾害相关话题,包括如何准备和应对野火、龙卷风、洪水、地震和冬季风暴。这不是 FEMA 提供的所有惊人资源,但应该足够测试我们的聊天界面。
索引文档以进行信息检索
下载后,我们可以使用 LangChain 读取 PDF 文档。文档被分割成文本块,并使用嵌入模型给每个块分配一个指纹(嵌入)。在这项分析中,我们将使用OpenAI 的嵌入,但 LangChain 支持许多其他模型。
import os
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.document_loaders import PyPDFLoader
files = os.listdir(pdf_folder_path)
files.sort()
all_docs_list =[]
for file in files:
if file.endswith('.pdf'):
print(file)
all_docs_list.append(file)
loader = PyPDFDirectoryLoader(pdf_folder_path)
all_docs = loader.load()
print(pdf_folder_path)
cfpb_adult-fin-edyour-disaster-checklist.pdf
fema_protect-your-home_flooding.pdf
fema_protect-your-property-storm-surge.pdf
fema_protect-your-property_coastal-erosion.pdf
fema_protect-your-property_earthquakes.pdf
fema_protect-your-property_severe-wind.pdf
fema_protect-your-property_wildfire.pdf
fema_safeguard-critical-documents-and-valuables.pdf
fema_scenario_1-active_shooter-01102020.pdf
fema_scenario_10_power_outage_01102020.pdf
fema_scenario_10_power_outage_answer_key_01102020.pdf
fema_scenario_11_winter_storm_01102020.pdf
fema_scenario_11_winter_storm_answer_key_01102020.pdf
fema_scenario_12_small_business_01102020.pdf
fema_scenario_12_small_business_answer_key_01102020.pdf
fema_scenario_1_active_shooter_TTX_answer_key-01102020.pdf
fema_scenario_2-tornado_TTX_answer_key-01102020.pdf
fema_scenario_2_tornado-01102020.pdf
fema_scenario_3-wildfire_TTX_answer_key-01102020.pdf
fema_scenario_3_wildfire-01102020.pdf
fema_scenario_4-hurricane-01102020.pdf
fema_scenario_4_hurricane_flood_TTX_answer_key-01102020.pdf
fema_scenario_5_extreme_heat-01102020.pdf
fema_scenario_5_extreme_heat_TTX_answer_key_01102020.pdf
fema_scenario_6-pet_preparedness_01102020.pdf
fema_scenario_6-pet_preparedness_TTX_answer_key_01102020.pdf
fema_scenario_7-shelter_in_place_TTX_answer_key_01102020.pdf
fema_scenario_7_shelter_in_place_01102020.pdf
fema_scenario_8_earthquake_01102020.pdf
fema_scenario_8_earthquake_answer_key_01102020.pdf
fema_scenario_9_pandemic_Influenza_01102020.pdf
fema_scenario_9_pandemic_answer_key_01102020.pdf
ready_12-ways-to-prepare_postcard.pdf
ready_document-and-insure-your-property.pdf
查看为一个文档提取的数据www.fema.gov/sites/default/files/2020-11/fema_protect-your-home_flooding.pdf
…
import json
for d in all_docs:
if 'fema_protect-your-home_flooding.pdf' in d.metadata['source']:
print('\n')
print(json.dumps(vars(d), indent=4))
{
"page_content": " \n \n \nPROTECT YOUR \nPROPERTY FROM FLOODING\n \n",
"metadata": {
"source": "docs_data/fema_protect-your-home_flooding.pdf",
"page": 0
}
}
{
"page_content": " \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n -\n\u2014\n- - -\nOwning a property is one of the most important investments most people make \nin their lives. We work hard to provide a home and a future for ourselves and our loved ones. Why risk losing it when bad weather hits close to home? \nFlooding is the most common and costly natural disaster in the United States \nand can happen anywhere. Just one inch of water can cause $25,000 in damages to your home. \nWhile you can\u2019t prevent a natural disaster from happening, there are ways to secure \nyour property to minimize damage and keep your home and your future safe. \nFirst, determine the Base Flood Elevation (BFE) for your home. The \nBFE is how high the water is expected to rise during fooding in high risk areas. You need to know your BFE because it is used in foodplain management regulations in your community that could affect your home for example, how high above the BFE a home or other building should be built. Your local foodplain manager can help you fnd this information. If you need help fnding your foodplain manager, contact FEMA\u2019s Flood Mapping and Insurance eXchange\n a\nt FEMAMapSpecialist@ \nr\niskmapcds.com or (877) FEMA MAP (1 877 336 2627). \nThe following are some additional steps you can take to protect yourself and your property against foods. ",
"metadata": {
"source": "docs_data/fema_protect-your-home_flooding.pdf",
"page": 1
}
}
{
"page_content": " \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \nINSIDE THE HOME \nPREPARE OR \nUPDA\nTE A LIST OF \nBELONGINGS Documenting all of your belongings will help with the \ninsurance claims process. Consider taking photos of \nhigh-value items or doing a video walkthrough of your \nho\nme to document its contents. \nGET FLOOD \nINSURANCE Most homeowners insurance \npolicies don\u2019t cover food damage. \nProtect your investment by purchasing food insurance for your home and contents, even if you do not live in \na hi\ngh-risk food zone. \nSTORE \nVALUABLES Store valuables and important documents in waterproof or wa\nter-resistant containers above the BFE \n(preferably on an upper foor). Make copies and store them online or offsite. For more \ninformation on purchasing food insurance through the National Flood \nInsurance Program, visit \nFlo\nodSmart.gov or contact \nyour agent to get coverage. EL\nEVATE UTILITIES \nABOVE THE BFE Elevate or foodproof mechanical units, furnaces, water \nheaters, electrical systems, and other utilities on masonry, concrete, or pressure-treated lumber at least 12 inches above the BFE. \nREPLACE CARPETING \nWITH TILES Tiles are more food-resistant than carpet. Using tile or other food-resistant materials in areas below the BFE can help reduce water damage. \nFLOODPROOF \nBASEMENTS If you have a basement, minimize damage by foodproofng your basement and sealing walls with waterproofng compounds. Consider installing a sump pump. \nINSTALL FLOOD \nVENTS Install food vents in foundation walls, garages, and other enclosed areas to allow water \nto fow through, drain out, an\nd lower the risk of \nstructural damage. \nUSE FLOOD-RESIST\nANT \nINSULATION & DRYWALL Fl\nood-resistant insulation and drywall will help minimize \ndamage and can be easily cleaned and sanitized. \nPREVENT SEW\nAGE \nBACK-UP I\nn some areas, fooding can cause sewage to back \nup through drain pipes in your home. Consult with a plumber and, if applicable, invest in a sewer backfow valve to prevent this potential health hazard. DID YOU KNOW? \nWhen following National Flood \nInsurance Program regulations, \nv\nents can also help lower \ninsu\nrance rates. \n",
"metadata": {
"source": "docs_data/fema_protect-your-home_flooding.pdf",
"page": 2
}
}
{
"page_content": " \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \nOUTSIDE THE HOME \nELEVATE YOUR \nHOME While it is an investment, elevating your home prepares your property against foods and lowers food insurance premiums. When a home is properly elevated, the lowest foor should be above the BFE. Areas below the BFE can be used for parking, storage, or access to the house. DID YOU KNOW? \nElevating your home \nmay reduce your food insurance premium. \nSECURE Y\nARD \nITEMS U\nnsecure items can be swept away or damaged by foodwaters. \nThey can also be swept into your home, causing damage. \nSecure items in your yard by anchoring them or attaching them to more substantial structures. \nSEAL CRA\nCKS \nAND GAPS C\nheck caulking around windows and doors to make it is not \ncracked, broken, or missing. Fill any holes or gaps around pipes and wires that enter your building. \nSET YOUR HOME OR \nBUILDINGS BACK, \nAWAY FROM WATER Build your home, garage, shed, or other building away from river channels and shore lines. If possible, build on higher ground. ",
"metadata": {
"source": "docs_data/fema_protect-your-home_flooding.pdf",
"page": 3
}
}
{
"page_content": " \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \nDIRECT WATER AWAY \nFROM STRUCTURES If you have a single-family home, make sure your yard \nslopes away from buildings on your property and that water has a place to drain. Clear your gutters, assess drainage issues, or collect water in rain barrels. \nANCHOR FUEL \nTANKS Anchor any fuel tanks to the pad to prevent them from tipping over or foating in a food. Spilled fuel could become a fre hazard. Make sure vents and fll-line openings are above the BFE. Note: This may require permission from your fuel provider. \nFLOODPROOF WALLS Add water-resistant exterior sheathing on walls and seal them to prevent shallow fooding from damaging your home. Cover openings below the BFE and seal all exterior openings around pumping and equipment. \nSECURE \nMANUFACTURED \nHOMES If you have a manufactured home and you want food insurance from the National Flood Insurance Program, your home must be affxed to a permanent foundation so that the wheels and axles do not support its weight and resist fotation, collapse, or side-to-side movement. Your local foodplain manager can help you understand the requirements, and a professional engineer or architect can make sure the anchoring system is designed and installed correctly. REMEMBER: \nSome of these tips may work better together than \nothers. Mitigation measures need to be tailored to your property. Also, not all of these options work together, so talk with an expert who can help you identify which options work best for you. \nAlways consult professionals such as your insurance \nagent, architects, engineers, contractors, or other experts in design and construction before making changes to your home. Your local planning and zoning offce or building department is a good place to start for advice. \nFinally, be kind to your neighbors! Talk to adjacent \nproperty owners before you make changes, since some actions on your property may affect theirs. ",
"metadata": {
"source": "docs_data/fema_protect-your-home_flooding.pdf",
"page": 4
}
}
{
"page_content": " \n \n \n \n \n \n \n \n \n \n \n \n ADDITIONAL RESOURCES \nFEMA, PROTECT YOUR PROPERTY \nLearn how to protect your home or business from natural disasters. fema.gov/protect-your-property \nHOMEOWNER\u2019S GUIDE TO RETROFITTING \nfema.gov/media-library/assets/documents/480 \nREDUCING FLOOD RISK TO RESIDENTIAL BUILDINGS \nTHAT CANNOT BE ELEVATED \nfema.gov/media-library/assets/documents/109669 \nPROTECTING YOUR HOME AND PROPERTY FROM \nFLOOD DAMAGE fema.gov/media-library/assets/documents/21471 \nPROTECTING BUILDING UTILITY SYSTEMS FROM \nFLOOD DAMAGE \nfema.gov/media-library/assets/documents/3729 \nPROTECT YOUR PROPERTY FROM FLOODING \nfema.gov/media-library/assets/documents/13261 ",
"metadata": {
"source": "docs_data/fema_protect-your-home_flooding.pdf",
"page": 5
}
}
{
"page_content": "",
"metadata": {
"source": "docs_data/fema_protect-your-home_flooding.pdf",
"page": 6
}
}
我们可以看到它按页面拆分了文档。实际上,这对于我们正在处理的 FEMA 文档并非不合理,这些文档非常简洁,每页都是一个独立的主题,但对于其他应用,通常更好的是在更细的级别拆分文本,使用LangChain 的 text_splitter。
我们现在可以使用我们的文本摘录来创建嵌入数据库...
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
embedding_model = OpenAIEmbeddings()
embeddings = OpenAIEmbeddings()
vectordb = Chroma.from_documents(all_docs, embedding=embedding_model,persist_directory=vecs_dir)
vectordb.persist()
我们选择了一个简单的选项将嵌入持久化到文件系统,但值得注意的是,Chroma 支持更多选项,如果你有大量文档,性能可能是一个问题。
设置我们的对话界面
我们将首先使用一个 PDF,这样更容易验证结果...
import shutil
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain import PromptTemplate, LLMChain
# Load OPENAI_API_KEY from a .env file
from dotenv import load_dotenv
def setup_model(
vecs_dir: str,
docs_sublist: list,
all_docs: list
) -> ConversationalRetrievalChain:
# Subset for docs we are interested in
docs = []
for d in all_docs:
d_dict = vars(d)
# Create vector DB directory
if os.path.exists(vecs_dir):
shutil.rmtree(vecs_dir)
os.makedirs(vecs_dir)
# Choose our models
embedding_model = OpenAIEmbeddings()
chat_model = ChatOpenAI(temperature=temperature,model_name="gpt-4")
# Calculate embeddings
embeddings = OpenAIEmbeddings()
vectordb = Chroma.from_documents(docs, embedding=embedding_model,persist_directory=vecs_dir)
vectordb.persist()
# Set up chat
memory = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
pdf_qa = ConversationalRetrievalChain.from_llm(chat_model, vectordb.as_retriever(), memory=memory, \
return_source_documents=True)
return pdf_qa
vecs_dir = './vector_dbs/one_flood_doc'
docs = all_docs
# Subset to one document
docs_sublist = ['fema_protect-your-home_flooding.pdf']
pdf_qa = setup_model(vecs_dir, docs_sublist, docs)
在上述内容中,我们选择了 GPT-4 作为聊天模型,其中 API 密钥在 .env
文件中定义,变量为 OPENAI_API_KEY
。LangChain 还支持许多其他模型。
只需几行代码,我们就建立了一个对文档的对话界面,这包括了 LLM 的所有功能。我多年来开发了聊天机器人,可以说这种简洁的模式大大减少了复杂性。
对出色的 LangChain 包致以赞誉!
提出我们的第一个问题
由于我们还希望查看引用的文档以验证聊天响应,我们需要对 LangChain 方法进行一些调整,以解决在使用聊天记忆时检索匹配文档时出现的问题(解决方案已在这里建议)...
# A little mod to enable using memory *and* getting docs. See: https://github.com/langchain-ai/langchain/issues/2256#issuecomment-1665188576
import langchain
from typing import Dict, Any, Tuple
from langchain.memory.utils import get_prompt_input_key
def _get_input_output(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> Tuple[str, str]:
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
return inputs[prompt_input_key], outputs[output_key]
langchain.memory.chat_memory.BaseChatMemory._get_input_output = _get_input_output
好的,现在我们准备好对我们的一个 PDF 文档提问了!
def ask_question(
query: str,
qa: object,
output_docs: bool = True
) -> dict:
print(f"\nQuestion: \n{query}")
result = qa({"question": query})
print(f"\nAnswer:\n{result['answer']}")
if output_docs:
for doc in result['source_documents']:
print('\n')
print(json.dumps(vars(doc), indent=4))
return result
ask_question("How do I prepare my home for floods?", pdf_qa)
Question:
How do I prepare my home for floods?
Answer:
There are several steps you can take to prepare your home for floods:
1\. Determine the Base Flood Elevation (BFE) for your home. This is how high the water is expected to rise during flooding in high risk areas. Your local floodplain manager can help you find this information.
2\. Direct water away from structures. Make sure your yard slopes away from buildings on your property and that water has a place to drain. Clear your gutters, assess drainage issues, or collect water in rain barrels.
3\. Anchor fuel tanks to prevent them from tipping over or floating in a flood.
4\. Floodproof walls by adding water-resistant exterior sheathing and sealing them to prevent shallow flooding from damaging your home.
5\. Secure manufactured homes to a permanent foundation to resist flotation, collapse, or side-to-side movement.
6\. Inside the home, prepare or update a list of belongings for insurance claims.
7\. Get flood insurance as most homeowners insurance policies don’t cover flood damage.
8\. Store valuables and important documents in waterproof or water-resistant containers above the BFE.
9\. Elevate utilities above the BFE.
10\. Replace carpeting with tiles as they are more flood-resistant.
11\. Floodproof basements by sealing walls with waterproofing compounds and consider installing a sump pump.
12\. Install flood vents in foundation walls, garages, and other enclosed areas to allow water to flow through and drain out.
13\. Use flood-resistant insulation and drywall.
14\. Prevent sewage back-up by consulting with a plumber and investing in a sewer backflow valve if applicable.
Remember to consult professionals such as your insurance agent, architects, engineers, contractors, or other experts in design and construction before making changes to your home.
这似乎很合理,让我们看看用于生成此摘要的内容...
{
"page_content": " \n \n \nPROTECT YOUR \nPROPERTY FROM FLOODING\n \n",
"metadata": {
"page": 0,
"source": "docs_data/fema_protect-your-home_flooding.pdf"
}
}
{
"page_content": " \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n -\n\u2014\n- - -\nOwning a property is one of the most important investments most people make \nin their lives. We work hard to provide a home and a future for ourselves and our loved ones. Why risk losing it when bad weather hits close to home? \nFlooding is the most common and costly natural disaster in the United States \nand can happen anywhere. Just one inch of water can cause $25,000 in damages to your home. \nWhile you can\u2019t prevent a natural disaster from happening, there are ways to secure \nyour property to minimize damage and keep your home and your future safe. \nFirst, determine the Base Flood Elevation (BFE) for your home. The \nBFE is how high the water is expected to rise during fooding in high risk areas. You need to know your BFE because it is used in foodplain management regulations in your community that could affect your home for example, how high above the BFE a home or other building should be built. Your local foodplain manager can help you fnd this information. If you need help fnding your foodplain manager, contact FEMA\u2019s Flood Mapping and Insurance eXchange\n a\nt FEMAMapSpecialist@ \nr\niskmapcds.com or (877) FEMA MAP (1 877 336 2627). \nThe following are some additional steps you can take to protect yourself and your property against foods. ",
"metadata": {
"page": 1,
"source": "docs_data/fema_protect-your-home_flooding.pdf"
}
}
{
"page_content": " \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \nDIRECT WATER AWAY \nFROM STRUCTURES If you have a single-family home, make sure your yard \nslopes away from buildings on your property and that water has a place to drain. Clear your gutters, assess drainage issues, or collect water in rain barrels. \nANCHOR FUEL \nTANKS Anchor any fuel tanks to the pad to prevent them from tipping over or foating in a food. Spilled fuel could become a fre hazard. Make sure vents and fll-line openings are above the BFE. Note: This may require permission from your fuel provider. \nFLOODPROOF WALLS Add water-resistant exterior sheathing on walls and seal them to prevent shallow fooding from damaging your home. Cover openings below the BFE and seal all exterior openings around pumping and equipment. \nSECURE \nMANUFACTURED \nHOMES If you have a manufactured home and you want food insurance from the National Flood Insurance Program, your home must be affxed to a permanent foundation so that the wheels and axles do not support its weight and resist fotation, collapse, or side-to-side movement. Your local foodplain manager can help you understand the requirements, and a professional engineer or architect can make sure the anchoring system is designed and installed correctly. REMEMBER: \nSome of these tips may work better together than \nothers. Mitigation measures need to be tailored to your property. Also, not all of these options work together, so talk with an expert who can help you identify which options work best for you. \nAlways consult professionals such as your insurance \nagent, architects, engineers, contractors, or other experts in design and construction before making changes to your home. Your local planning and zoning offce or building department is a good place to start for advice. \nFinally, be kind to your neighbors! Talk to adjacent \nproperty owners before you make changes, since some actions on your property may affect theirs. ",
"metadata": {
"page": 4,
"source": "docs_data/fema_protect-your-home_flooding.pdf"
}
}
{
"page_content": " \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \nINSIDE THE HOME \nPREPARE OR \nUPDA\nTE A LIST OF \nBELONGINGS Documenting all of your belongings will help with the \ninsurance claims process. Consider taking photos of \nhigh-value items or doing a video walkthrough of your \nho\nme to document its contents. \nGET FLOOD \nINSURANCE Most homeowners insurance \npolicies don\u2019t cover food damage. \nProtect your investment by purchasing food insurance for your home and contents, even if you do not live in \na hi\ngh-risk food zone. \nSTORE \nVALUABLES Store valuables and important documents in waterproof or wa\nter-resistant containers above the BFE \n(preferably on an upper foor). Make copies and store them online or offsite. For more \ninformation on purchasing food insurance through the National Flood \nInsurance Program, visit \nFlo\nodSmart.gov or contact \nyour agent to get coverage. EL\nEVATE UTILITIES \nABOVE THE BFE Elevate or foodproof mechanical units, furnaces, water \nheaters, electrical systems, and other utilities on masonry, concrete, or pressure-treated lumber at least 12 inches above the BFE. \nREPLACE CARPETING \nWITH TILES Tiles are more food-resistant than carpet. Using tile or other food-resistant materials in areas below the BFE can help reduce water damage. \nFLOODPROOF \nBASEMENTS If you have a basement, minimize damage by foodproofng your basement and sealing walls with waterproofng compounds. Consider installing a sump pump. \nINSTALL FLOOD \nVENTS Install food vents in foundation walls, garages, and other enclosed areas to allow water \nto fow through, drain out, an\nd lower the risk of \nstructural damage. \nUSE FLOOD-RESIST\nANT \nINSULATION & DRYWALL Fl\nood-resistant insulation and drywall will help minimize \ndamage and can be easily cleaned and sanitized. \nPREVENT SEW\nAGE \nBACK-UP I\nn some areas, fooding can cause sewage to back \nup through drain pipes in your home. Consult with a plumber and, if applicable, invest in a sewer backfow valve to prevent this potential health hazard. DID YOU KNOW? \nWhen following National Flood \nInsurance Program regulations, \nv\nents can also help lower \ninsu\nrance rates. \n",
"metadata": {
"page": 2,
"source": "docs_data/fema_protect-your-home_flooding.pdf"
}
}
答案看起来很棒,并且出色地总结了匹配的摘录。然而,它似乎遗漏了 源文档中的一页。这个 PDF 是一个简短的文档,其中所有内容都与洪水准备有关,因此遗漏页面实际上是很重要的。
这表明,盲目接受网络上的 LLM 模式可能会得到看起来很棒的结果,但需要进行工作才能使其真正有用。
文档上下文(元数据)可能很重要
在这种情况下,我们有一个与一个主题相关的文档,即洪水,即使个别文本部分可能没有明确提到它。如果我们在每个文本摘录中提供一些上下文,即一些文档元数据,可能会获得更好的结果。
对于我们的简单测试,试着将所有文本摘录前缀加上文件名(fema-protect-your-home-flooding.pdf),去掉标点和后缀,再加上“此摘录相关于”这段文字。最终前缀将是“此摘录相关于 fema protect your home flooding:”,这为 LLM 提供了更多的背景信息……
def setup_model(
vecs_dir: str,
docs_sublist: list,
all_docs: list,
prefix_file_name_to_chunks: bool = False,
temperature: float = 0.0,
extra_prefix: str = '',
) -> ConversationalRetrievalChain:
# Subset for docs we are interested in
docs = []
for d in all_docs:
d_dict = vars(d)
if d_dict['metadata']['source'].replace('docs_data/','') in docs_sublist:
if len(d.page_content) > 20:
# Add file name to content for more context
if prefix_file_name_to_chunks:
file_clean = re.sub(r'docs_data\/|\.pdf', '', d_dict['metadata']['source'])
file_clean = re.sub(r'\-|\_', ' ', file_clean)
d.page_content = f"{extra_prefix} {file_clean}: {d.page_content}"
docs.append(d)
# Create vector DB directory
if os.path.exists(vecs_dir):
shutil.rmtree(vecs_dir)
os.makedirs(vecs_dir)
# Choose our models
embedding_model = OpenAIEmbeddings()
chat_model = ChatOpenAI(temperature=temperature,model_name="gpt-4")
# Calculate embeddings
embeddings = OpenAIEmbeddings()
vectordb = Chroma.from_documents(docs, embedding=embedding_model,persist_directory=vecs_dir)
vectordb.persist()
# Set up chat
memory = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
pdf_qa = ConversationalRetrievalChain.from_llm(chat_model, vectordb.as_retriever(), memory=memory, \
return_source_documents=True)
return pdf_qa
docs = all_docs
# Note the argument prefix_file_name_to_chunks=True
pdf_qa = setup_model(vecs_dir, docs_sublist, docs, prefix_file_name_to_chunks=True, extra_prefix="This snippet relates to ")
ask_question("How do I prepare my home for floods?", pdf_qa)
Question:
How do I prepare my home for floods?
Answer:
There are several steps you can take to prepare your home for floods:
1\. Direct Water Away from Structures: Make sure your yard slopes away from buildings on your property and that water has a place to drain. Clear your gutters, assess drainage issues, or collect water in rain barrels.
2\. Anchor Fuel Tanks: Anchor any fuel tanks to the pad to prevent them from tipping over or floating in a flood. Spilled fuel could become a fire hazard.
3\. Floodproof Walls: Add water-resistant exterior sheathing on walls and seal them to prevent shallow flooding from damaging your home. Cover openings below the Base Flood Elevation (BFE) and seal all exterior openings around pumping and equipment.
4\. Secure Manufactured Homes: If you have a manufactured home, it must be affixed to a permanent foundation so that the wheels and axles do not support its weight and resist flotation, collapse, or side-to-side movement.
5\. Elevate Your Home: Elevating your home prepares your property against floods and lowers flood insurance premiums. The lowest floor should be above the BFE.
6\. Secure Yard Items: Unsecure items can be swept away or damaged by floodwaters. Secure items in your yard by anchoring them or attaching them to more substantial structures.
7\. Seal Cracks and Gaps: Check caulking around windows and doors to make sure it is not cracked, broken, or missing. Fill any holes or gaps around pipes and wires that enter your building.
8\. Prepare or Update a List of Belongings: Documenting all of your belongings will help with the insurance claims process.
9\. Get Flood Insurance: Most homeowners insurance policies don’t cover flood damage. Protect your investment by purchasing flood insurance for your home and contents.
10\. Store Valuables: Store valuables and important documents in waterproof or water-resistant containers above the BFE.
11\. Elevate Utilities Above the BFE: Elevate or floodproof mechanical units, furnaces, water heaters, electrical systems, and other utilities.
12\. Replace Carpeting with Tiles: Tiles are more flood-resistant than carpet. Using tile or other flood-resistant materials in areas below the BFE can help reduce water damage.
13\. Floodproof Basements: Minimize damage by floodproofing your basement and sealing walls with waterproofing compounds. Consider installing a sump pump.
14\. Install Flood Vents: Install flood vents in foundation walls, garages, and other enclosed areas to allow water to flow through, drain out, and lower the risk of structural damage.
15\. Use Flood-Resistant Insulation & Drywall: Flood-resistant insulation and drywall will help minimize damage and can be easily cleaned and sanitized.
16\. Prevent Sewage Back-Up: In some areas, flooding can cause sewage to back up through drain pipes in your home. Consult with a plumber and, if applicable, invest in a sewer backflow valve to prevent this potential health hazard.
Remember to consult professionals such as your insurance agent, architects, engineers, contractors, or other experts in design and construction before making changes to your home.
{
"page_content": "This snippet relates to fema protect your home flooding: \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \nDIRECT WATER AWAY \nFROM STRUCTURES If you have a single-family home, make sure your yard \nslopes away from buildings on your property and that water has a place to drain. Clear your gutters, assess drainage issues, or collect water in rain barrels. \nANCHOR FUEL \nTANKS Anchor any fuel tanks to the pad to prevent them from tipping over or foating in a food. Spilled fuel could become a fre hazard. Make sure vents and fll-line openings are above the BFE. Note: This may require permission from your fuel provider. \nFLOODPROOF WALLS Add water-resistant exterior sheathing on walls and seal them to prevent shallow fooding from damaging your home. Cover openings below the BFE and seal all exterior openings around pumping and equipment. \nSECURE \nMANUFACTURED \nHOMES If you have a manufactured home and you want food insurance from the National Flood Insurance Program, your home must be affxed to a permanent foundation so that the wheels and axles do not support its weight and resist fotation, collapse, or side-to-side movement. Your local foodplain manager can help you understand the requirements, and a professional engineer or architect can make sure the anchoring system is designed and installed correctly. REMEMBER: \nSome of these tips may work better together than \nothers. Mitigation measures need to be tailored to your property. Also, not all of these options work together, so talk with an expert who can help you identify which options work best for you. \nAlways consult professionals such as your insurance \nagent, architects, engineers, contractors, or other experts in design and construction before making changes to your home. Your local planning and zoning offce or building department is a good place to start for advice. \nFinally, be kind to your neighbors! Talk to adjacent \nproperty owners before you make changes, since some actions on your property may affect theirs. ",
"metadata": {
"page": 4,
"source": "docs_data/fema_protect-your-home_flooding.pdf"
}
}
{
"page_content": "This snippet relates to fema protect your home flooding: \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \nOUTSIDE THE HOME \nELEVATE YOUR \nHOME While it is an investment, elevating your home prepares your property against foods and lowers food insurance premiums. When a home is properly elevated, the lowest foor should be above the BFE. Areas below the BFE can be used for parking, storage, or access to the house. DID YOU KNOW? \nElevating your home \nmay reduce your food insurance premium. \nSECURE Y\nARD \nITEMS U\nnsecure items can be swept away or damaged by foodwaters. \nThey can also be swept into your home, causing damage. \nSecure items in your yard by anchoring them or attaching them to more substantial structures. \nSEAL CRA\nCKS \nAND GAPS C\nheck caulking around windows and doors to make it is not \ncracked, broken, or missing. Fill any holes or gaps around pipes and wires that enter your building. \nSET YOUR HOME OR \nBUILDINGS BACK, \nAWAY FROM WATER Build your home, garage, shed, or other building away from river channels and shore lines. If possible, build on higher ground. ",
"metadata": {
"page": 3,
"source": "docs_data/fema_protect-your-home_flooding.pdf"
}
}
{
"page_content": "This snippet relates to fema protect your home flooding: \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \nINSIDE THE HOME \nPREPARE OR \nUPDA\nTE A LIST OF \nBELONGINGS Documenting all of your belongings will help with the \ninsurance claims process. Consider taking photos of \nhigh-value items or doing a video walkthrough of your \nho\nme to document its contents. \nGET FLOOD \nINSURANCE Most homeowners insurance \npolicies don\u2019t cover food damage. \nProtect your investment by purchasing food insurance for your home and contents, even if you do not live in \na hi\ngh-risk food zone. \nSTORE \nVALUABLES Store valuables and important documents in waterproof or wa\nter-resistant containers above the BFE \n(preferably on an upper foor). Make copies and store them online or offsite. For more \ninformation on purchasing food insurance through the National Flood \nInsurance Program, visit \nFlo\nodSmart.gov or contact \nyour agent to get coverage. EL\nEVATE UTILITIES \nABOVE THE BFE Elevate or foodproof mechanical units, furnaces, water \nheaters, electrical systems, and other utilities on masonry, concrete, or pressure-treated lumber at least 12 inches above the BFE. \nREPLACE CARPETING \nWITH TILES Tiles are more food-resistant than carpet. Using tile or other food-resistant materials in areas below the BFE can help reduce water damage. \nFLOODPROOF \nBASEMENTS If you have a basement, minimize damage by foodproofng your basement and sealing walls with waterproofng compounds. Consider installing a sump pump. \nINSTALL FLOOD \nVENTS Install food vents in foundation walls, garages, and other enclosed areas to allow water \nto fow through, drain out, an\nd lower the risk of \nstructural damage. \nUSE FLOOD-RESIST\nANT \nINSULATION & DRYWALL Fl\nood-resistant insulation and drywall will help minimize \ndamage and can be easily cleaned and sanitized. \nPREVENT SEW\nAGE \nBACK-UP I\nn some areas, fooding can cause sewage to back \nup through drain pipes in your home. Consult with a plumber and, if applicable, invest in a sewer backfow valve to prevent this potential health hazard. DID YOU KNOW? \nWhen following National Flood \nInsurance Program regulations, \nv\nents can also help lower \ninsu\nrance rates. \n",
"metadata": {
"page": 2,
"source": "docs_data/fema_protect-your-home_flooding.pdf"
}
}
{
"page_content": "This snippet relates to fema protect your home flooding: \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n -\n\u2014\n- - -\nOwning a property is one of the most important investments most people make \nin their lives. We work hard to provide a home and a future for ourselves and our loved ones. Why risk losing it when bad weather hits close to home? \nFlooding is the most common and costly natural disaster in the United States \nand can happen anywhere. Just one inch of water can cause $25,000 in damages to your home. \nWhile you can\u2019t prevent a natural disaster from happening, there are ways to secure \nyour property to minimize damage and keep your home and your future safe. \nFirst, determine the Base Flood Elevation (BFE) for your home. The \nBFE is how high the water is expected to rise during fooding in high risk areas. You need to know your BFE because it is used in foodplain management regulations in your community that could affect your home for example, how high above the BFE a home or other building should be built. Your local foodplain manager can help you fnd this information. If you need help fnding your foodplain manager, contact FEMA\u2019s Flood Mapping and Insurance eXchange\n a\nt FEMAMapSpecialist@ \nr\niskmapcds.com or (877) FEMA MAP (1 877 336 2627). \nThe following are some additional steps you can take to protect yourself and your property against foods. ",
"metadata": {
"page": 1,
"source": "docs_data/fema_protect-your-home_flooding.pdf"
}
}
成功了!这似乎确实捕获了 PDF 中的重要页面,并很好地总结了它们。显然,这是一种非常粗略的方法,使用元数据而不仅仅是文件名的更正式方法会更好。使用模板可能也更优雅,而不仅仅是前缀,但它确实展示了这样一点背景如何帮助。
那么如果我们现在使用所有文档集中的文档呢……
vecs_dir = './vector_dbs/all_docs'
docs = all_docs
# Note the argument prefix_file_name_to_chunks=True
pdf_qa = setup_model(vecs_dir, all_docs_list, docs, prefix_file_name_to_chunks=True, extra_prefix="This snippet relates to ")
result = ask_question("How do I prepare my home for floods?", pdf_qa)
Question:
How do I prepare my home for floods?
Answer:
There are several steps you can take to prepare your home for floods:
1\. Create an emergency plan for your family and practice it regularly. When a storm is approaching, evacuate and move your car to higher ground.
2\. Purchase flood insurance for your home and its contents, even if you do not live in a high-risk flood zone.
3\. Document your belongings. This will help with the insurance process if you need to file a claim.
4\. Store valuables and important documents above the Base Flood Elevation (BFE) in waterproof or water-resistant containers.
5\. Elevate appliances and utilities such as water heaters, washers, dryers, and electric panels on higher floors to prevent them from getting damaged by flood water.
6\. Use flood-resistant materials for insulation, drywall, and floor coverings like tile to minimize damage.
7\. Make sure your yard slopes away from buildings on your property and that water has a place to drain.
8\. Anchor any fuel tanks to prevent them from tipping over or floating in a flood.
9\. Add water-resistant exterior sheathing on walls and seal them to prevent shallow flooding from damaging your home.
10\. If you have a manufactured home, make sure it is affixed to a permanent foundation.
11\. Elevate your home above the BFE.
12\. Secure items in your yard by anchoring them or attaching them to more substantial structures.
13\. Check caulking around windows and doors to make sure it is not cracked, broken, or missing. Fill any holes or gaps around pipes and wires that enter your building.
14\. Build your home, garage, shed, or other building away from river channels and shore lines. If possible, build on higher ground.
15\. Replace carpeting with tiles as they are more flood-resistant.
16\. Floodproof your basement and seal walls with waterproofing compounds. Consider installing a sump pump.
17\. Install flood vents in foundation walls, garages, and other enclosed areas to allow water to flow through, drain out, and lower the risk of structural damage.
18\. Use flood-resistant insulation and drywall.
19\. Prevent sewage back-up by consulting with a plumber and investing in a sewer backflow valve if applicable.
Remember to consult professionals such as your insurance agent, architects, engineers, contractors, or other experts in design and construction before making changes to your home.
{
"page_content": "This snippet relates to fema protect your property storm surge: This snippet relates to fema protect your property storm surge: This snippet relates to fema protect your property storm surge: This snippet relates to fema protect your property storm surge: \n \n \n \n \n \n \n \n \n INSIDE YOUR HOME\nHAVE A \nPLANCreate an emergency plan for your family and practice it regularly. \nWhen a storm is approaching, evacuate and move your car to higher ground. According to the National Weather Service , just two feet \nof water can move a vehicle. Visit Ready.gov/plan to learn more.\nGET FLOOD \nINSURANCE Most homeowners insurance policies don\u2019t cover flood damage. Protect your investment by purchasing flood insurance \nfor your home and its contents. \nDo so even if you do not live in \na high-risk flood zone.For more information on purchasing \nflood insurance through the \n \nNational Flood Insurance Program, visit FloodSmart.gov or contact \n \nyour agent to \nget coverage.\n \n PREPARE OR \nUPDATE A LIST OF \n YOUR HOME\u2019S \nCONTENTSDocument your belongings. This will give you peace of mind \nand help with the insurance process if you need to file a claim. Consider documenting your home\u2019s contents visually. You can either take photos of high-value items or walk through your home and videotape each room\u2019s belongings.\nSTORE \nVALUABLESStore valuables and important documents above the BFE (preferably on an upper floor). Place them in waterproof or \nwater-resistant containers. Also, make copies and store them \nonline or offsite.\nELEVATE \nAPPLIANCES AND \nUTILITIES ABOVE \nTHE BFEKeep appliances and utilities such as water heaters, washers, dryers, and \nelectric panels on higher \nfloors. It can prevent them from getting damaged or ruined by flood water.\nTalk to your floodplain \nmanager about how high to elevate your utilities. Many coastal \ncommunities have codes that require utilities to be elevated 12 inches or more above the BFE, called freeboard. \nUSE FLOOD- \nRESISTANT \nMATERIALS\nKNOW YOUR \nPROPERTY AND \nNEIGHBORHOODFlood-resistant insulation, drywall, and floor coverings like tile \nwill help minimize damage and are easier to clean and sanitize.\nIf you are moving near a large lake or the ocean, talk with \nneighbors. Find out about any issues they\u2019ve had, or any \nmitigation measures they have taken. Take the time to look \nup flood information. You can find flood maps for coastal areas on the National Flood Hazard Layer .ADDITIONAL RESOURCES\nFEMA, Homebuilder\u2019s Guide to Coastal Construction\nFEMA, Homeowner\u2019s Guide to Retrofitting\nFEMA, Recommended Residential Construction for Coastal Areas\nNATIONAL HURRICANE CENTER, Storm Surge Overview\nNATIONAL HURRICANE CENTER, Storm Surge Inundation Map\nREMEMBER: \nSome of these tips may work better than others. Tailor \nany mitigation measures to meet your property\u2019s needs.\nAlways consult professionals, such as your insurance agent, architects, \nengineers, contractors, or other experts in design and construction, \nbefore making changes to your home. Your local planning and zoning \noffice or building department is a good place to start for advice.\n",
"metadata": {
"page": 3,
"source": "docs_data/fema_protect-your-property-storm-surge.pdf"
}
}
{
"page_content": "This snippet relates to fema protect your home flooding: This snippet relates to fema protect your home flooding: This snippet relates to fema protect your home flooding: This snippet relates to fema protect your home flooding: \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \nDIRECT WATER AWAY \nFROM STRUCTURES If you have a single-family home, make sure your yard \nslopes away from buildings on your property and that water has a place to drain. Clear your gutters, assess drainage issues, or collect water in rain barrels. \nANCHOR FUEL \nTANKS Anchor any fuel tanks to the pad to prevent them from tipping over or foating in a food. Spilled fuel could become a fre hazard. Make sure vents and fll-line openings are above the BFE. Note: This may require permission from your fuel provider. \nFLOODPROOF WALLS Add water-resistant exterior sheathing on walls and seal them to prevent shallow fooding from damaging your home. Cover openings below the BFE and seal all exterior openings around pumping and equipment. \nSECURE \nMANUFACTURED \nHOMES If you have a manufactured home and you want food insurance from the National Flood Insurance Program, your home must be affxed to a permanent foundation so that the wheels and axles do not support its weight and resist fotation, collapse, or side-to-side movement. Your local foodplain manager can help you understand the requirements, and a professional engineer or architect can make sure the anchoring system is designed and installed correctly. REMEMBER: \nSome of these tips may work better together than \nothers. Mitigation measures need to be tailored to your property. Also, not all of these options work together, so talk with an expert who can help you identify which options work best for you. \nAlways consult professionals such as your insurance \nagent, architects, engineers, contractors, or other experts in design and construction before making changes to your home. Your local planning and zoning offce or building department is a good place to start for advice. \nFinally, be kind to your neighbors! Talk to adjacent \nproperty owners before you make changes, since some actions on your property may affect theirs. ",
"metadata": {
"page": 4,
"source": "docs_data/fema_protect-your-home_flooding.pdf"
}
}
{
"page_content": "This snippet relates to fema protect your home flooding: This snippet relates to fema protect your home flooding: This snippet relates to fema protect your home flooding: This snippet relates to fema protect your home flooding: \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \nOUTSIDE THE HOME \nELEVATE YOUR \nHOME While it is an investment, elevating your home prepares your property against foods and lowers food insurance premiums. When a home is properly elevated, the lowest foor should be above the BFE. Areas below the BFE can be used for parking, storage, or access to the house. DID YOU KNOW? \nElevating your home \nmay reduce your food insurance premium. \nSECURE Y\nARD \nITEMS U\nnsecure items can be swept away or damaged by foodwaters. \nThey can also be swept into your home, causing damage. \nSecure items in your yard by anchoring them or attaching them to more substantial structures. \nSEAL CRA\nCKS \nAND GAPS C\nheck caulking around windows and doors to make it is not \ncracked, broken, or missing. Fill any holes or gaps around pipes and wires that enter your building. \nSET YOUR HOME OR \nBUILDINGS BACK, \nAWAY FROM WATER Build your home, garage, shed, or other building away from river channels and shore lines. If possible, build on higher ground. ",
"metadata": {
"page": 3,
"source": "docs_data/fema_protect-your-home_flooding.pdf"
}
}
{
"page_content": "This snippet relates to fema protect your home flooding: This snippet relates to fema protect your home flooding: This snippet relates to fema protect your home flooding: This snippet relates to fema protect your home flooding: \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \nINSIDE THE HOME \nPREPARE OR \nUPDA\nTE A LIST OF \nBELONGINGS Documenting all of your belongings will help with the \ninsurance claims process. Consider taking photos of \nhigh-value items or doing a video walkthrough of your \nho\nme to document its contents. \nGET FLOOD \nINSURANCE Most homeowners insurance \npolicies don\u2019t cover food damage. \nProtect your investment by purchasing food insurance for your home and contents, even if you do not live in \na hi\ngh-risk food zone. \nSTORE \nVALUABLES Store valuables and important documents in waterproof or wa\nter-resistant containers above the BFE \n(preferably on an upper foor). Make copies and store them online or offsite. For more \ninformation on purchasing food insurance through the National Flood \nInsurance Program, visit \nFlo\nodSmart.gov or contact \nyour agent to get coverage. EL\nEVATE UTILITIES \nABOVE THE BFE Elevate or foodproof mechanical units, furnaces, water \nheaters, electrical systems, and other utilities on masonry, concrete, or pressure-treated lumber at least 12 inches above the BFE. \nREPLACE CARPETING \nWITH TILES Tiles are more food-resistant than carpet. Using tile or other food-resistant materials in areas below the BFE can help reduce water damage. \nFLOODPROOF \nBASEMENTS If you have a basement, minimize damage by foodproofng your basement and sealing walls with waterproofng compounds. Consider installing a sump pump. \nINSTALL FLOOD \nVENTS Install food vents in foundation walls, garages, and other enclosed areas to allow water \nto fow through, drain out, an\nd lower the risk of \nstructural damage. \nUSE FLOOD-RESIST\nANT \nINSULATION & DRYWALL Fl\nood-resistant insulation and drywall will help minimize \ndamage and can be easily cleaned and sanitized. \nPREVENT SEW\nAGE \nBACK-UP I\nn some areas, fooding can cause sewage to back \nup through drain pipes in your home. Consult with a plumber and, if applicable, invest in a sewer backfow valve to prevent this potential health hazard. DID YOU KNOW? \nWhen following National Flood \nInsurance Program regulations, \nv\nents can also help lower \ninsu\nrance rates. \n",
"metadata": {
"page": 2,
"source": "docs_data/fema_protect-your-home_flooding.pdf"
}
}
做得很好,它从www.fema.gov/sites/default/files/2020-11/fema_protect-your-home_flooding.pdf
中的关键文章提取了信息,还包括www.fema.gov/sites/default/files/documents/fema_protect-your-property-storm-surge.pdf
中的一篇文章,这是文档集中唯一提到洪水的其他文档。
检索提示长度
此时,需要注意的是,最终的总结提示将文档摘录转化为漂亮的答案可能会相当长,具体取决于摘录的大小和数量。与 LLM 的持续战斗是实现我们的目标而不超出令牌限制。尽管对我们的用例没有重大影响,但其他场景可能需要采用上下文压缩方法。
问题背景可能很重要
我们故意混合了与(i)灾害规划和(ii)立即反应相关的文档。这可能会导致混淆的响应,将两种背景混合在一起……
result = ask_question("I live next to a canal and see the water rising right now, what should I do?", pdf_qa, output_docs=True)
Question:
I live next to a canal and see the water rising right now, what should I do?
Answer:
If you notice water levels rising immediately, it's important to prioritize your safety. Here are some steps you can take:
1\. Evacuate: If the water levels are rising rapidly, it's crucial to evacuate to higher ground as soon as possible.
2\. Move your car: If you have time, move your car to higher ground. According to the National Weather Service, just two feet of water can move a vehicle.
3\. Secure your home: If you have time, secure your home by moving valuables and important documents to an upper floor or at least above the Base Flood Elevation (BFE). Place them in waterproof or water-resistant containers.
4\. Communicate: Let your family, friends, and neighbors know about the situation. If you have a pre-established meeting point, head there or inform others of your evacuation plans.
5\. Contact authorities: Inform local authorities about the situation. They can provide guidance and also alert others who might be at risk.
6\. Check flood maps: If you have access to the internet, check flood maps for your area to understand the potential risk.
Remember, these steps are general guidelines. Always follow the advice of local authorities and emergency services during a flood situation.
现在的答案有点混乱,有些点涉及立即行动“将你的车移到高地”,有些点涉及准备“购买洪水保险”。当人们在紧急情况下感到压力时,他们可能不会考虑即时工程,因此我们可以期待稍微模糊的输入。
当然,我们可以通过使用更多的文档元数据来将其拆分为子组来解决此问题,但如果没有这些元数据,则需要工作。另一种选择是向问题提供更多背景,以指示用户是否对灾害准备感兴趣,或需要立即帮助。我们可以为此建立一个分类器,但在这些强大的 LLM 时代,我们可以使用 GPT-4 的零样本分类……
def get_time_context(question):
template = """Does the following question relate to 'planning' or 'taking immediate action': {question}
Answer with one of the following: 'I am planning ahead:' or 'I need to take immediate action:' or 'ambiguous'"""
prompt = PromptTemplate(template=template, input_variables=["question"])
llm = OpenAI()
llm_chain = LLMChain(prompt=prompt, llm=llm)
answer = llm_chain.run(question)
return answer
questions = [
"I live next to a canal and the water is rising, what should I do?",
"Help my roof is blowing off!",
"How do I prevent my roof from blowing off in a hurricane",
"Dog",
"How do I prepare my home for floods?"
]
for q in questions:
print(f"Question: {q}")
answer = get_time_context(q)
print(answer.strip(), "\n")
Question: I live next to a canal and the water is rising, what should I do?
I need to take immediate action:
Question: Help my roof is blowing off!
I need to take immediate action:
Question: How do I prevent my roof from blowing off in a hurricane
I am planning ahead.
Question: Dog
Ambiguous
Question: How do I prepare my home for floods?
I am planning ahead.
很好!只需很少的努力,我们就可以轻松确定问题是涉及规划还是立即行动。
现在我们可以用这个前缀来前置用户问题……
def get_time_context(question: str) -> str:
template = """Does the following question relate to 'planning' or 'taking immediate action': {question}
Answer with one of the following: 'I am planning ahead:' or 'I need to take immediate action:' or 'ambiguous'"""
prompt = PromptTemplate(template=template, input_variables=["question"])
llm = OpenAI()
llm_chain = LLMChain(prompt=prompt, llm=llm)
answer = llm_chain.run(question)
return answer
def ask_question(
query: str,
qa: object,
output_docs: bool = True,
preprocess_time_context: bool = False
) -> dict:
# First get the time context
if preprocess_time_context:
time_context = get_time_context(query)
if 'planning' in time_context.lower():
query = f"I am planning ahead: {query}"
elif 'take immediate action' in time_context.lower():
query = f"I need to take immediate action: {query}"
print(f"\nQuestion: \n{query}")
result = qa({"question": query})
print(f"\nAnswer:\n{result['answer']}")
if output_docs:
for doc in result['source_documents']:
print('\n')
print(json.dumps(vars(doc), indent=4))
return result
result = ask_question("I live next to a canal and see the water rising, what should I do?", pdf_qa, output_docs=False, preprocess_time_context=True)
这给……
Question:
I need to take immediate action: I live next to a canal and see the water rising, what should I do?
Answer:
If you see the water rising next to your canal, you should immediately implement your emergency plan. This includes evacuating your home and moving your car to higher ground, as just two feet of water can move a vehicle. You should also ensure that your valuables and important documents are stored above the Base Flood Elevation (BFE), preferably on an upper floor, and placed in waterproof or water-resistant containers. If you have time, consider moving appliances and utilities such as water heaters, washers, dryers, and electric panels to higher floors to prevent them from getting damaged or ruined by flood water. Always prioritize your safety and the safety of your family over protecting property.
很好,效果很好,没有试图推销任何保险,给出的建议可以立即付诸实践。
让我们测试一下反向情况…
result = ask_question("What can I do to make my house forest fire resistant?", pdf_qa, output_docs=False, preprocess_time_context=True)
Question:
I am planning ahead: What can I do to make my house forest fire resistant?
Answer:
There are several steps you can take to make your house resistant to forest fires:
1\. Install or replace your roof with a Class A-rated roof with noncombustible coverings. The roof is the most at-risk part of a house in a wildfire due to its size and orientation.
2\. Install and replace exterior wall coverings with noncombustible or fire-resistant materials. A minimum fire-resistance rating of one hour for the wall assembly is recommended.
3\. Create a 30 feet defensible space around your home by reducing or removing flammable vegetation and using noncombustible materials such as gravel, brick, or concrete.
4\. Regularly clean and remove debris from the roof and gutters to reduce the likelihood of something catching on fire on top of your home.
5\. Enclose your home's foundation to lower the chance of wind-blown embers getting underneath your home.
6\. Plan for access to water by purchasing and installing external sprinkler systems with dedicated power sources or a water tank.
7\. Protect large windows from radiant heat by installing multi-pane windows, tempered safety glass, or fireproof shutters.
8\. Install highly visible street signs and property addresses to help firefighters and other emergency responders quickly find your property.
9\. Seal gaps around openings in exterior walls and roofs with fire-resistant caulk, mortar, or fire-protective expanding foam.
10\. Cover exterior attic vents and under-eave vents with metal wire mesh no larger than 1/8 inch to keep embers out.
11\. Install a fire block in the gap between the top of framed walls and the foundation of the house to starve the fire of oxygen and prevent it from spreading.
Remember, these tips may work better together than others and need to be tailored to your property. Always consult professionals such as your insurance agent, architects, engineers, contractors, or other experts in design and construction before making changes to your home.
完美,它提供了森林火灾准备信息。
显然,这需要对任何高风险的真实紧急情况进行大量测试,但它说明了我们可以通过丰富用户的提示来提升性能的一种方法。
确保答案仅来自提供的文档
对于一个信息可能影响安全性的应用程序,确保提供的信息仅来自提供的文档是非常重要的。包含错误信息的幻觉可能会产生非常严重的后果。
让我们尝试问一些与灾难完全无关的问题…
result = ask_question("How do I make a sponge cake?", pdf_qa, output_docs=False)
这会提供…
Question:
How do I make a sponge cake?
Answer:
I'm sorry, but the provided context does not contain information on how to make a sponge cake.
虽然我有点遗憾没能找到如何做一些美味的蛋糕,但 LangChain 很好地处理了这个场景,并判断出问题与提供的 PDF 指南中的信息无关。
很好,没有重大幻觉!在灾难中,我最不希望的就是被告知去做海绵蛋糕或飞宇宙飞船之类的事。😊
对话历史
LangChain 的一个很棒的功能是它无缝地处理对话历史。只需一行代码,模型就能拾起之前的问题…
result = ask_question("There's a hurricane coming, what should I do?", pdf_qa, output_docs=False)
Question:
I need to take immediate action: There's a hurricane forecast, what should I do?
Answer:
If there's a hurricane forecast, you should take the following immediate actions:
1\. Check the flood risk for your area. You can do this by visiting https://msc.fema.gov/portal/home and entering your home’s address.
2\. Understand the difference between a Hurricane Watch and a Hurricane Warning. A Hurricane Watch means that hurricane conditions are possible within a specified area and is issued 48 hours in advance of the anticipated onset of tropical-storm-force winds. A Hurricane Warning means that hurricane conditions are expected within a specified area and is issued 36 hours in advance of the anticipated onset of tropical-storm-force winds.
3\. Take actions around your house to help reduce the impact of a flooding event. This could include placing sand bags in areas that are most at risk from flooding and elevating mechanical devices like air conditioners, generators, and circuit breakers to minimize the risk of impact by storm surge.
4\. Establish a central meeting point so that if the power does go out, people know where to meet up with their loved ones.
5\. Prepare a Go-Kit which should include water, a first-aid kit, a flashlight, batteries, a hand-crank/solar powered radio, non-perishable food, cash, a wrench, hand sanitizer, a mylar blanket, ear plugs, and a book and/or deck of cards.
6\. Make sure to have items and services in place before a hurricane watch or warning is issued. This could include packing comfort items for loved ones and ensuring that any relatives who require regularly scheduled mental health therapy have their providers aware of the location of the shelter.
7\. Store critical documents in Ziploc bags and keep them in a safe place. Alternatively, you can store them on a cloud drive, on a flash drive kept in an alternate location, or in a safety deposit box.
8\. Create a Go-Kit for your pet with essential items, including food, water, medication and favorite toys. Make sure your pet is microchipped or at least has ID tags displayed.
9\. Familiarize yourself with the predetermined coastal evacuation routes in your area.
10\. Plan for communication methods in case phone lines and cellular towers get damaged. This could include text messaging or using social media platforms like Twitter, Facebook, Instagram, etc.
然后问它一个指向“它”(飓风)的提问…
result = ask_question("It's about to arrive, help!", pdf_qa, output_docs=False, preprocess_time_context=True)
Question:
It's about to arrive, help!
Answer:
Here are some immediate actions you should take for safety:
1\. If you are in a flood zone, consider evacuating to a safer location. If evacuation is not possible, move to higher ground within your home.
2\. Secure your home by closing and boarding up windows. Remove any outdoor items that could be picked up by the wind.
3\. Gather your emergency supplies, including your Go-Kit which should include water, a first-aid kit, a flashlight, batteries, a hand-crank/solar powered radio, non-perishable food, cash, a wrench, hand sanitizer, a mylar blanket, ear plugs, and a book and/or deck of cards.
4\. If you have pets, ensure you have a Go-Kit for them as well, including food, water, medication and favorite toys. Make sure they are microchipped or have ID tags.
5\. Make sure your car is filled with gas and ready to go in case you need to evacuate.
6\. Unplug all electrical appliances to prevent power-surge damage.
7\. Stay informed about the storm's path and intensity by listening to the radio or checking trusted websites.
8\. Communicate with family members and make sure everyone knows the plan. If possible, establish a central meeting point in case you get separated.
9\. Store valuable items and important documents in a safe, high place or in waterproof containers.
10\. If you have time, place sandbags in areas that are most at risk from flooding.
Remember, your safety is the most important thing. Material items can be replaced, but lives cannot.
完美,它保持了历史记录,并知道我在指什么。在压力下,用户可能会提出后续问题,因此这种能力——用 2 行代码实现!——是非常重要的。
翻译
为了使聊天机器人更加多功能,我们可以探索多语言支持。许多主要的 LLM,如 GPT-4 提供了对主要语言的本地支持,但性能可能会因语言而异。
让我们从葡萄牙语开始…
from googletrans import Translator
translator = Translator()
q = translator.translate("How do I prepare my home for floods?", dest='pt')
print("================ RAW =================")
pdf_qa = setup_model(vecs_dir, all_docs_list, docs, prefix_file_name_to_chunks=True, extra_prefix="This snippet relates to ")
result = ask_question(q.text, pdf_qa, output_docs=True, preprocess_time_context=True)
q = translator.translate(result['answer'], dest='en')
print("\n\n=============== TRANSLATED TO ENGLISH USING GOOGLE TRANSLATE =================\n")
print(q.text)
================ RAW =================
Question:
I am planning ahead: Como preparar minha casa para enchentes?
Answer:
Existem várias maneiras de preparar sua casa para enchentes:
1\. Tenha um plano de emergência: Crie um plano de emergência para sua família e pratique-o regularmente. Quando uma tempestade se aproxima, evacue e mova seu carro para um terreno mais alto.
2\. Obtenha seguro contra inundações: A maioria das apólices de seguro residencial não cobre danos causados por inundações. Proteja seu investimento comprando seguro contra inundações para sua casa e seu conteúdo.
3\. Prepare ou atualize uma lista do conteúdo da sua casa: Documente seus pertences. Isso lhe dará tranquilidade e ajudará no processo de seguro se você precisar fazer uma reclamação.
4\. Armazene objetos de valor: Armazene objetos de valor e documentos importantes acima do BFE (preferencialmente em um andar superior). Coloque-os em recipientes à prova d'água ou resistentes à água.
5\. Eleve aparelhos e utilidades acima do BFE: Mantenha aparelhos e utilidades como aquecedores de água, lavadoras, secadoras e painéis elétricos em andares mais altos. Isso pode evitar que sejam danificados ou arruinados pela água da enchente.
6\. Use materiais resistentes a inundações: Isolamento resistente a inundações, drywall e revestimentos de piso como azulejos ajudarão a minimizar os danos e são mais fáceis de limpar e higienizar.
7\. Conheça sua propriedade e vizinhança: Se você está se mudando para perto de um grande lago ou do oceano, converse com os vizinhos. Descubra quaisquer problemas que eles tiveram ou quais medidas de mitigação eles tomaram.
8\. Eleve sua casa: Embora seja um investimento, elevar sua casa prepara sua propriedade contra inundações e reduz os prêmios de seguro contra inundações.
9\. Proteja os itens do seu quintal: Itens inseguros podem ser levados ou danificados pelas águas da enchente. Eles também podem ser levados para dentro de sua casa, causando danos.
10\. Selo rachaduras e lacunas: Verifique o calafetagem ao redor das janelas e portas para garantir que não esteja rachado, quebrado ou faltando. Preencha quaisquer buracos ou lacunas ao redor de tubos e fios que entram em seu prédio.
Lembre-se, sempre consulte profissionais, como seu agente de seguros, arquitetos, engenheiros, empreiteiros ou outros especialistas em design e construção, antes de fazer alterações em sua casa.
=============== TRANSLATED TO ENGLISH USING GOOGLE TRANSLATE =================
There are several ways to prepare your home for flooding:
1\. Have an emergency plan: Create an emergency plan for your family and practice it regularly. When a storm approaches, evacuate and move your car to higher ground.
2\. Get flood insurance: Most home insurance policies don't cover flood damage. Protect your investment by purchasing flood insurance for your home and contents.
3\. Prepare or update a list of your household contents: Document your belongings. This will give you peace of mind and help with the insurance process if you ever need to file a claim.
4\. Store valuables: Store valuables and important documents above the BFE (preferably on an upper floor). Place them in waterproof or water-resistant containers.
5\. Elevate appliances and utilities above the BFE: Keep appliances and utilities like water heaters, washers, dryers, and electrical panels on higher floors. This can prevent them from being damaged or ruined by floodwater.
6\. Use Flood Resistant Materials: Flood resistant insulation, drywall and floor coverings like tile will help minimize damage and are easier to clean and sanitize.
7\. Know your property and neighborhood: If you're moving near a large lake or the ocean, talk to your neighbors. Find out about any issues they've had or what mitigation measures they've taken.
8\. Elevate your home: Although it is an investment, elevating your home prepares your property against flooding and reduces flood insurance premiums.
9\. Protect Your Yard Items: Unsafe items can be washed away or damaged by floodwaters. They can also be brought into your home, causing damage.
10\. Seal Cracks and Gaps: Check the caulking around windows and doors to ensure it isn't cracked, broken, or missing. Fill in any holes or gaps around pipes and wires entering your building.
Remember, always consult with professionals, such as your insurance agent, architects, engineers, contractors, or other design and construction experts, before making changes to your home.
哪个信息是正确的。然而,对于斯瓦希里语…
q = translator.translate("How do I prepare my home for floods?", dest='sw')
print("================ RAW =================")
pdf_qa = setup_model(vecs_dir, all_docs_list, docs, prefix_file_name_to_chunks=True, extra_prefix="This snippet relates to ")
result = ask_question(q.text, pdf_qa, output_docs=True, preprocess_time_context=True)
q = translator.translate(result['answer'], dest='en')
print("\n\n=============== TRANSLATED TO ENGLISH USING GOOGLE TRANSLATE =================\n")
print(q.text)
================ RAW =================
Question:
Je, ninatayarishaje nyumba yangu kwa mafuriko?
Answer:
Kuna hatua kadhaa unazoweza kuchukua ili kulinda nyumba yako dhidi ya mafuriko:
1\. Elekeza Maji Mbali na Majengo: Hakikisha eneo lako linaelekea mbali na majengo kwenye mali yako na kwamba maji yana mahali pa kumwagika. Safisha mifereji yako, tathmini masuala ya mifereji ya maji, au kukusanya maji katika mapipa ya mvua.
2\. Nanga Mizinga ya Mafuta: Nanga mizinga yoyote ya mafuta kwenye pedi ili kuzuia kuzunguka au kuogelea katika mafuriko. Mafuta yaliyomwagika yanaweza kuwa hatari ya moto. Hakikisha kuwa matundu ya uingizaji hewa na ufunguzi wa mstari wa kujaza yako juu ya BFE (Base Flood Elevation). Kumbuka: Hii inaweza kuhitaji ruhusa kutoka kwa mtoa huduma wako wa mafuta.
3\. Kinga Kuta dhidi ya Mafuriko: Ongeza sheathing inayostahimili maji kwenye kuta na uzibane ili kuzuia mafuriko ya kina kifupi kuharibu nyumba yako. Funika ufunguzi chini ya BFE na ziba ufunguzi wote wa nje karibu na pampu na vifaa.
4\. Thibitisha Nyumba Zilizotengenezwa: Ikiwa una nyumba iliyotengenezwa na unataka bima ya mafuriko kutoka kwa Programu ya Bima ya Mafuriko ya Kitaifa, nyumba yako lazima iwe imewekwa kwenye msingi wa kudumu ili magurudumu na axles zisitoe uzito wake na kuzuia kuelea, kuanguka, au harakati ya upande kwa upande.
5\. Kumbuka: Baadhi ya vidokezo hivi vinaweza kufanya kazi vizuri pamoja kuliko vingine. Hatua za kupunguza madhara zinahitaji kubinafsishwa kwa mali yako. Pia, sio chaguo zote hufanya kazi pamoja, kwa hivyo ongea na mtaalamu anayeweza kukusaidia kutambua chaguo zipi zinafanya kazi vizuri kwako.
Daima shauriana na wataalamu kama wakala wako wa bima, wasanifu, wahandisi, wakandarasi, au wataalamu wengine katika kubuni na ujenzi kabla ya kufanya mabadiliko kwenye nyumba yako. Ofisi yako ya mipango na mipango ya eneo au idara ya ujenzi ni mahali pazuri pa kuanza kwa ushauri.
Mwishowe, kuwa mwema kwa majirani zako! Zungumza na wamiliki wa mali jirani kabla ya kufanya mabadiliko, kwani vitendo fulani kwenye mali yako vinaweza kuathiri yao.
=============== TRANSLATED TO ENGLISH USING GOOGLE TRANSLATE =================
There are several steps you can take to protect your home from flooding:
1\. Direct Water Away from Buildings: Make sure your area faces away from buildings on your property and that water has a place to drain. Clean your gutters, assess drainage issues, or collect water in rain barrels.
2\. Anchor Oil Tanks: Anchor any oil tanks to a pad to prevent them from rolling or swimming in a flood. Spilled oil can be a fire hazard. Make sure the vents and fill line opening are above the BFE (Base Flood Elevation). Note: This may require permission from your fuel supplier.
3\. Protect Walls from Flooding: Add waterproof sheathing to walls and seal them to prevent shallow flooding from damaging your home. Cover the opening under the BFE and cover all external openings around the pump and equipment.
4\. Verify Manufactured Homes: If you have a manufactured home and want flood insurance from the National Flood Insurance Program, your home must be placed on a permanent foundation so that the wheels and axles do not give off its weight and prevent floating, falling, or movement of side by side.
5\. Note: Some of these tips may work better together than others. Mitigation measures need to be customized to your property. Also, not all options work together, so talk to a professional who can help you figure out which options work best for you.
Always consult with professionals such as your insurance agent, architects, engineers, contractors, or other professionals in design and construction before making changes to your home. Your local planning and development office or building department is a good place to start for advice.
Finally, be kind to your neighbors! Talk to neighboring property owners before making changes, as certain actions on your property may affect theirs.
结果看起来非常可信,但缺少重要信息。匹配的文本摘录缺少一些关键的洪水部分,当我们询问洪水时,还得到了与野火相关的片段。基本上,使用从英语文档创建的嵌入来处理斯瓦希里语问题效果不好。不过,使用拉丁语言的葡萄牙语效果很好,这并不令人意外。
也许更稳健的方法是先检测语言,再用 Google Translate 将其翻译成英语,然后将响应转换回提示语言。
让我们将这个添加到我们的聊天界面中…
def get_time_context(question: str) -> str:
template = """Does the following question relate to 'planning' or 'taking immediate action': {question}
Answer with one of the following: 'I am planning ahead:' or 'I need to take immediate action:' or 'ambiguous'"""
prompt = PromptTemplate(template=template, input_variables=["question"])
llm = OpenAI()
llm_chain = LLMChain(prompt=prompt, llm=llm)
answer = llm_chain.run(question)
return answer
def ask_question(
query: str,
qa: object,
output_docs: bool = True,
preprocess_time_context: bool = False,
auto_translate: bool = False,
) -> dict:
# Auto-translate
if auto_translate:
lang = translator.detect(query)
lang = lang.lang
print(f"Language detected: {lang}")
q = translator.translate(query, dest='en')
query = q.text
# First get the time context
if preprocess_time_context:
time_context = get_time_context(query)
if 'planning' in time_context.lower():
query = f"I am planning ahead: {query}"
elif 'take immediate action' in time_context.lower():
query = f"I need to take immediate action: {query}"
print(f"\nQuestion: \n{query}")
result = qa({"question": query})
# Translate back to source language
if auto_translate:
q = translator.translate(result['answer'], dest=lang)
result['answer'] = q.text
print(f"\nAnswer:\n{result['answer']}")
if output_docs:
for doc in result['source_documents']:
print('\n')
print(json.dumps(vars(doc), indent=4))
return result
# Translate to Swahili
q = translator.translate("How do I prepare my home for floods?", dest='sw')
print("================ RAW =================")
result = ask_question(q.text, pdf_qa, output_docs=False, preprocess_time_context=True, auto_translate=True)
q = translator.translate(result['answer'], dest='en')
print("\n\n=============== TRANSLATED TO ENGLISH USING GOOGLE TRANSLATE =================\n")
print(q.text)
================ RAW =================
Language detected: sw
Question:
I am planning ahead: How do I prepare my home for flooding?
Answer:
Kuna hatua kadhaa unazoweza kuchukua ili kuandaa nyumba yako kwa mafuriko yanayoweza kutokea:
1\. Tengeneza mpango wa dharura kwa ajili ya familia yako na uufanyie kazi mara kwa mara. Dhoruba inapokaribia, ondoka na usogeze gari lako hadi sehemu ya juu.
2\. Nunua bima ya mafuriko kwa ajili ya nyumba yako na vilivyomo, hata kama huishi katika eneo lenye hatari kubwa ya mafuriko.
3\. Andika vitu vyako. Hii itasaidia katika mchakato wa bima ikiwa unahitaji kuwasilisha dai.
4\. Hifadhi vitu vya thamani na hati muhimu juu ya Mwinuko wa Mafuriko ya Msingi (BFE), ikiwezekana kwenye ghorofa ya juu, katika vyombo visivyo na maji au visivyo na maji.
5\. Kuinua vifaa na huduma kama vile hita za maji, washers, vikaushio na paneli za umeme kwenye sakafu ya juu ili kuvizuia kuharibiwa na maji ya mafuriko.
6\. Tumia nyenzo zinazostahimili mafuriko nyumbani kwako.
7\. Hakikisha yadi yako ina mteremko mbali na majengo kwenye mali yako na kwamba maji yana mahali pa kutolea maji.
8\. Tia nanga matangi yoyote ya mafuta ili kuyazuia yasitembee juu au kuelea katika mafuriko.
9\. Ongeza sheati za nje zinazostahimili maji kwenye kuta na uzifunge ili kuzuia mafuriko ya kina yasiharibu nyumba yako.
10\. Ikiwa una nyumba iliyotengenezwa, hakikisha kuwa imewekwa kwenye msingi wa kudumu.
11\. Pandisha nyumba yako juu ya BFE.
12\. Linda vitu katika yadi yako kwa kuvitia nanga au kuviambatanisha na miundo mikubwa zaidi.
13\. Angalia kuzunguka kwa madirisha na milango ili kuhakikisha kuwa haijapasuka, haijavunjwa au kukosa. Jaza mashimo au mapengo yoyote karibu na mabomba na waya zinazoingia kwenye jengo lako.
14\. Jenga nyumba yako, karakana, banda, au jengo lingine mbali na njia za mito na ufuo. Ikiwezekana, jenga juu ya ardhi ya juu.
Kumbuka kushauriana na wataalamu kama vile wakala wako wa bima, wasanifu majengo, wahandisi, wakandarasi, au wataalamu wengine wa usanifu na ujenzi kabla ya kufanya mabadiliko kwenye nyumba yako.
=============== TRANSLATED TO ENGLISH USING GOOGLE TRANSLATE =================
There are several steps you can take to prepare your home for potential flooding:
1\. Make an emergency plan for your family and work on it regularly. When a storm approaches, take off and move your car to the top.
2\. Buy flood insurance for your home and contents, even if you don't live in a high flood risk area.
3\. Write your items. This will help with the insurance process if you need to file a claim.
4\. Store valuables and important documents above the Base Flood Elevation (BFE), preferably on an upper floor, in waterproof or waterproof containers.
5\. Elevate equipment and services such as water heaters, washers, dryers and electrical panels on the upper floor to prevent them from being damaged by flood water.
6\. Use flood resistant materials in your home.
7\. Make sure your yard is sloped away from the buildings on your property and that the water has a place to drain.
8\. Anchor any fuel tanks to prevent them from rolling over or floating in flooding.
9\. Add waterproof outer sheathing to the walls and seal them to prevent deep flooding from damaging your home.
10\. If you have a manufactured home, make sure it is placed on a permanent foundation.
11\. Raise your house above the BFE.
12\. Secure things in your yard by anchoring them or attaching them to larger structures.
13\. Check window and door surrounds to make sure they are not cracked, broken or missing. Fill any holes or gaps around pipes and wires entering your building.
14\. Build your house, garage, shed, or other building away from river and beach paths. If possible, build on higher ground.
Remember to consult with professionals such as your insurance agent, architects, engineers, contractors, or other design and construction professionals before making changes to your home.
更好,使用 Google Translate 将斯瓦希里语自动翻译成英语,然后将回答翻译回斯瓦希里语,能从我们的 PDF 文件集中获得所有所需的信息。对于灾难响应等风险关键的用例,当然需要与母语者进行大量测试以确保安全,但这在多语言支持方面显示出了前景。
提出更广泛的问题
好的,让我们试用一下最终版本,提出更多与灾难相关的问题……
# Set up the model again or our Swahili conversation history will be used
vecs_dir = './vector_dbs/all_docs'
docs = all_docs
pdf_qa = setup_model(vecs_dir, all_docs_list, docs, prefix_file_name_to_chunks=True, extra_prefix="This snippet relates to ")
questions = [
"What should I do for pets if there is a forest fire?",
"My roof is blowing off!!! Help, what should I do?!!!",
"I live next to a brook, should I worry about floods?",
"If I live inland, will sea level affect me?",
"Is there a cheap way to get flood insurance?",
"We're in an Earthquake, what should we do!!???",
"What things should I have in an emergency kit?",
]
for q in questions:
print("\n ===================== ")
ask_question(q, pdf_qa, output_docs=False, preprocess_time_context=True, auto_translate=True)
=====================
Language detected: en
Question:
I need to take immediate action: What should I do for pets if there is a forest fire?
Answer:
If there is a forest fire, you should have a Go-Kit for your pet(s) which should include food, water, medication, medical records kept in Ziploc bags, and favorite toys (if applicable). Include your veterinarian’s contact information and be sure that your pet is microchipped or at least has the proper ID tags displayed. If you need to evacuate, take your pet with you if possible. If you cannot take your pet with you, ensure they are in a safe place with access to food and water.
=====================
Language detected: en
Question:
I need to take immediate action: My roof is blowing off!!! Help, what should I do?!!!
Answer:
The FEMA brochures do not provide specific instructions for immediate action if your roof is blowing off. However, it is generally recommended to seek immediate shelter in a safe area of your home, away from windows and potential falling debris. After the severe wind event, you should contact a professional to assess and repair the damage. It's also important to contact your insurance company to report the damage.
=====================
Language detected: en
Question:
I live next to a brook, should I worry about floods?
Answer:
Yes, you should be concerned about floods if you live next to a brook. Flooding can happen anywhere and is not limited to high-risk areas. Heavy rainfall or poor drainage can cause flooding, especially in low-lying areas. It's important to take steps to protect your home, such as elevating your home, securing yard items, and ensuring proper drainage. You should also consider purchasing flood insurance, as federal disaster assistance may be limited or unavailable without a presidential disaster declaration.
=====================
Language detected: en
Question:
If I live inland, will sea level affect me?
Answer:
Sea level changes primarily affect coastal areas, causing increased flooding, erosion, and storm damage. However, if you live inland, you may still experience indirect effects. Changes in sea level can alter weather patterns, potentially leading to increased rainfall or drought in some areas. Additionally, if you rely on resources from coastal areas, such as seafood or tourism, changes in sea level could impact these industries. However, the snippet does not provide specific information on how sea level changes might affect those living inland.
=====================
Language detected: en
Question:
I am planning ahead: Is there a cheap way to get flood insurance?
Answer:
Yes, you can purchase flood insurance through the National Flood Insurance Program. Most homeowners insurance policies don’t cover flood damage, so it's important to protect your investment by purchasing flood insurance for your home and its contents. This is recommended even if you do not live in a high-risk flood zone. For more information on purchasing flood insurance, you can visit FloodSmart.gov or contact your insurance agent to get coverage. Planning ahead and purchasing flood insurance before a storm or flood event is a cost-effective way to protect your property.
=====================
Language detected: en
Question:
I need to take immediate action: We're in an Earthquake, what should we do!!???
Answer:
When an earthquake starts, you should "Drop, Cover, and Hold On!" This means you should drop to the ground, cover your head and neck with your arms, and if a safer place is nearby, crawl to it and hold on. If you are in bed, stay there and cover your head and neck with a pillow. If you are in a moving vehicle, stop as quickly and safely as possible and stay in the vehicle. Avoid stopping near or under buildings, trees, overpasses, and utility wires. Proceed cautiously once the earthquake has stopped, avoiding roads, bridges, or ramps that the earthquake may have damaged.
=====================
Language detected: en
Question:
I am planning ahead: What things should I have in an emergency kit?
Answer:
An emergency kit should include the following items:
- Water (one gallon per person, per day) and/or water purification tablets
- A first-aid kit
- A flashlight
- Batteries
- A hand-crank/solar-powered radio
- Nonperishable food
- Medications/prescription glasses
- Cash (in small denominations)
- A wrench (to turn off utilities)
- Hand sanitizer
- A Mylar blanket
- Ear plugs
- A book/deck of cards
If you have an infant, then include baby food, diapers, formula, etc. If you have pets, consider having a kit for them that includes food, water, medication, medical records, and a favorite toy.
In addition, consider including portable charging devices in your kit as this will allow you to charge from a motor vehicle. If a loved one is comforted by a certain item (a blanket, photograph, stuffed animal, etc.), be sure to pack the item when evacuating. If a relative requires regularly scheduled mental health therapy, make sure that his/her mental health provider is aware of the location of the shelter.
我认为这些非常惊人,抽查了一些,它们似乎能够捕捉到所用 PDF 文档中的关键信息。
[基于上述内容,我的灾难包现在有耳塞了!]
性能评估
在这项分析中,我们仅仅对我们的聊天界面在返回一小部分 FEMA PDF 文档的关键信息方面进行了抽查。这对于说明一些概念是有帮助的,但对于生产应用程序,我们需要比抽查更好的方法。幸运的是,LangChain 提供了一组评估工具,特别值得关注的是一个Streamlit 应用,它可以自动生成问答对,然后利用这些数据来评估检索链,从而允许开发者实验一些相关的参数。我还没有尝试过,但使用 LLM 自动生成评估数据的想法似乎是为测试我们的 FEMA 灾难机器人构建更系统化方法的好办法。
结论
使用 FEMA 网站上的 34 份 PDF 文档,我们能够轻松地构建一个多语言对话界面,使用 LangChain 和 GPT-4,能够回答广泛的灾难准备和安全协议相关的问题。然而,对于高风险安全关键的聊天机器人,使用这种通用文档检索的 LangChain 模式面临着与语义搜索相同的挑战。即使使用先进的 LLM 嵌入捕捉自然语言中的更多细微差别,仍然很容易丢失重要内容。
对于我们的 FEMA 灾难机器人用例,这种情况是由于:
-
混合上下文文档 — 在没有包含文档元数据的情况下,响应混合了来自不同上下文的信息。简单地在文档名称前加前缀提高了我们场景中的性能,并且可以应用更复杂的元数据策略。此外,添加一个零样本 LLM 分类器来丰富用户问题也是有帮助的。
-
LLMs 中的低代表性语言 — 依赖 LLM 翻译可能会导致对于代表性较少的语言(如斯瓦希里语)性能较差。添加自动 Google 翻译提高了我们用例中的性能。
我们探索了粗略的方法来解决上述问题,但对于生产聊天机器人,需要通过测试和验证找到更先进的技术和其他缓解措施。
尽管如此,LLM 现在能做的事情确实令人惊叹!
参考资料
你可以在这里找到包含所有代码的笔记本。