TowardsDataScience-2023-博客中文翻译-四十八-
TowardsDataScience 2023 博客中文翻译(四十八)
使用分区,卢克!一个简单且经过验证的优化 SQL 查询的方法
如果你曾写过一个需要很长时间才能运行的 SQL 查询,那么这篇文章就是为你准备的。
·发表于 Towards Data Science ·8 分钟阅读·2023 年 12 月 7 日
--
宝宝尤达喜欢分区。你呢?图片由 Victor Serban 提供,来源于 Unsplash
数据科学家们喜欢 SQL,但我们确实很糟糕于编写高性能的查询(也许是因为我们花了太多时间争论“SQL”是读作“S-Q-L”还是“sequel”?)。
在这篇文章中,我将向你展示如何使用 SQL 分区 来优化你的查询,并编写更快速、更便宜的代码。如果你已经掌握了 SQL 的基础知识,并希望开始解锁更高级的数据科学技能,这将是你工具箱中的一个很棒的补充。
什么是分区表?
分区表是将数据表划分为多个段/分区的表(谁能想到呢?)。
在分区表中,每个分区存储在服务器上的不同位置。这与普通(未分区)SQL 表不同,后者整个表位于一个位置。
这里是使用虚拟数据对比了我最喜欢的三本书的每日销售情况:
作者提供的图片
未分区表和分区表都包含相同的数据;唯一的区别是分区表将数据分成不同的段。它仍然是一个单一的表(即它不是三个单独的表);只是以不同的方式存储数据。
我们为什么要关注这个?好吧,正如我们稍后会看到的,我们可以利用这种结构来编写更高效的 SQL 查询。
创建分区表
创建分区表简直轻而易举。
例如,如果我们使用以下代码来创建一个普通(未分区)表……
CREATE TABLE `myproject.mydataset.unpartitioned` (
date DATE,
book_name STRING,
total_sales INT64
);
INSERT INTO `myproject.mydataset.unpartitioned`
VALUES
('2023-12-01', 'Shoe Dog', 154),
('2023-12-01', 'Born to Run', 90),
('2023-12-02', 'Running with the Kenyans', 12),
('2023-12-02', 'Born to Run', 61),
('2023-12-01', 'Running with the Kenyans', 8),
('2023-12-03', 'Shoe Dog', 231),
('2023-12-03', 'Born to Run', 63),
('2023-12-02', 'Shoe Dog', 100),
('2023-12-03', 'Running with the Kenyans', 5);
… 我们可以通过在CREATE TABLE
语句的末尾添加一行来创建一个分区版本:
CREATE TABLE `myproject.mydataset.partitioned` (
date DATE,
book_name STRING,
total_sales INT64
) PARTITION BY date; # New line!
INSERT INTO `myproject.mydataset.partitioned`
VALUES ... # Rest of the code is same as in previous snippet
我们进行分区的列被称为分区键;在这种情况下,我们在date
列上进行分区,但任何列都可以,只要它是(1)日期/时间戳字段或(2)整数字段。
一旦我们创建了这两个表,我们会发现它们乍一看是相同的(例如,如果你对每个表运行SELECT *
,结果将是相同的)。然而,如果我们查看表的详细信息/元数据,我们会看到分区表包含了一些额外的元数据。在 BigQuery 中(我在这里运行 SQL),情况是这样的:
作者提供的图片
这有什么重要的?因为使用分区减少了处理查询所需的数据量
这真是个好消息,因为这意味着分区可以帮助你编写更高效的查询!
当你查询一个普通(未分区)表时,SQL 引擎通常需要扫描整个表以找到你所需的行。在大型表中,这可能不必要地慢且昂贵,因为你的机器需要处理对生成最终输出没有用的数据。
例如,让我们查询一下我们之前创建的未分区表:
SELECT *
FROM `myproject.mydataset.unpartitioned`
WHERE date > '2023-12-01'
作者提供的图片
在上面的图片中,你可以看到表中的所有 9 行都被读取以返回日期大于‘2023-12-01’的 6 行。
现在让我们对分区表运行相同的查询:
SELECT *
FROM `myproject.mydataset.partitioned`
WHERE date > '2023-12-01'
作者提供的图片
这一次,我们看到仅读取了未分区表中的 6 行以生成相同的结果。在执行查询的主要部分之前,BigQuery 能够识别包含相关行的分区并进行选择性提取。它根本不需要读取其他分区中的 3 行。
选择分区的这个初步步骤被称为修剪。这比正常查询更高效,因为这意味着 SQL 引擎无需读取表中的每一行;它会首先获取所需的分区,然后才会执行你的查询。在高阶 SQL 术语中,我们在分区列上添加的过滤器被 SQL 引擎视为访问谓词,并在执行主查询之前运行。
— 快速插曲 —
如果你喜欢这篇文章,你可能会喜欢我的网站 the-sql-gym.com,网站包含了 100 多个 SQL 练习题。如果你想提升你的 SQL 技能,可以去看看!如果有任何反馈,请告诉我 😃
作者提供的图片。来源:the-sql-gym.com
回到分区——让我们看一个更大的例子
当我们处理像上面这些小表时,分区可能显得有些过度。但当我们扩展到更大的表时,它可以带来显著的性能提升。
首先,让我们创建两个大表,每个表包含 100 万行。第一个将是未分区表,第二个将按id
列进行分区,每个分区包含 10,000 行。
##############################
# Create unpartitioned table #
##############################
CREATE OR REPLACE TABLE `myproject.mydataset.unpartitioned_large` AS (
SELECT id, GENERATE_UUID() AS uuid
FROM UNNEST(GENERATE_ARRAY(1, 1000000)) AS id
);
############################
# Create partitioned table #
############################
CREATE TABLE `myproject.mydataset.partitioned_large` (id INT64, uuid STRING)
# Partition into 100 segments of 10,000 each
PARTITION BY RANGE_BUCKET(id, GENERATE_ARRAY(1, 1000000, 10000))
AS (
SELECT *
FROM `myproject.mydataset.unpartitioned_large`
);
当我们对未分区的表运行以下查询时:
SELECT *
FROM `myproject.mydataset.unpartitioned_large` # Unpartitioned
WHERE id BETWEEN 1000 AND 2000;
作者提供的图片
… 我们看到在输出 1,001 条记录之前,表中的所有 1,000,000 行都被读取了。整个操作耗时 650 毫秒(经过时间)/503 毫秒(消耗的槽时间)。
但是,当我们对分区表运行相同的查询时,只读取了 10,000 行(即一个单独的分区)。
SELECT *
FROM `myproject.mydataset.partitioned_large` # Partitioned
WHERE id BETWEEN 1000 AND 2000;
作者提供的图片
当查询未分区的表时,操作的经过时间不到第一个查询(在未分区的大表上)的时间的一半,且消耗的槽时间降低了近 95%。
很酷,对吧?
使用分区时的常见错误
如果你希望开始使用分区来改善你的表/查询,有几个陷阱需要注意:
-
不要在非分区列上过滤大分区表 — 如果你在一个非分区键的列上进行过滤,你将无法利用表的分区结构。如果你确实需要在非分区键列上进行过滤,我建议你(如果可能的话)首先添加一个针对分区键的过滤条件(以修剪不需要的分区),然后再应用第二个过滤条件。这是因为 BigQuery(像许多 SQL 引擎一样)会从上到下执行
WHERE
语句中的过滤条件。 -
不要在分区键上应用函数 — 例如,如果你的分区键是
date
列,不要添加像WHERE CAST(date AS STRING) = '2023-12-02'
这样的过滤条件。相反,如果你需要在WHERE
子句中使用函数,可以将函数应用于常量。在这个例子中,你可以将过滤条件改写为WHERE date = CAST('2023-01-01' AS DATE)
,以确保你仍能利用修剪/分区的优势。
关于 BigQuery 分区的两个快速提示
BigQuery 对每个表有4,000 个分区的限制。如果你尝试使用一个将创建超过 4,000 个分区的分区键,请尝试使用不同的分辨率。例如,不要按日期/天分区:
CREATE TABLE `myproject.mydataset.partitioned` (
...
) PARTITION BY date;
… 你可以按周进行分区:
CREATE TABLE `myproject.mydataset.partitioned` (
...
) PARTITION BY DATE_TRUNC(date, WEEK);
另外,值得了解的是 BigQuery 允许基于摄取时间字段和名为_PARTITIONTIME
的伪列进行分区。虽然这个话题在本文中不够相关,但你可以在这里阅读更多内容。
何时不使用分区
在较小的表格中,分区可能不会带来性能提升。为什么?因为修剪分区的过程可能比简单地查看表中所有行花费更多时间(来源)。
在这些情况下,聚类可能是一个更高效的替代方案。请关注我下一篇文章了解更多内容!
还有一件事——
我已经开始了一份免费的通讯,叫做 AI in Five,每周分享 5 个要点,涵盖最新的 AI 新闻、编码技巧和数据科学家/分析师的职业故事。没有炒作,没有“数据是新的石油”的废话,也没有来自 Elon 的推文(或者我现在应该说‘x-es’?)——只有实用的技巧和见解,帮助你在职业生涯中发展。
在这里订阅 如果这正是你感兴趣的!感谢阅读。
## AI in Five | Matt Chapman | Substack
最新消息、职业故事和数据科学与人工智能领域的编码技巧,浓缩成 5 个要点……
使用同义词 API 方便地更新 Elasticsearch 中的同义词
学习一种新的方法来更新 Elasticsearch 同义词,而无需重新索引或重新加载
·发表于 Towards Data Science ·阅读时间:5 分钟·2023 年 11 月 25 日
--
图片来源:Tumisu,Pixabay
Elasticsearch 的同义词功能非常强大,当正确使用时,可以显著提高搜索引擎的效率。使用同义词功能时,一个常见问题是更新同义词集。
在索引的设置中定义的内联同义词不能直接更新,我们需要关闭索引、更新设置,然后重新打开索引以使更改生效。另一种方法是使用可以通过重新加载索引来更新的同义词文件。然而,当 Elasticsearch 服务器分布式部署或托管在云端时,使用索引文件会很难管理。这是因为我们需要将文件放在所有集群节点上。
好消息是,现在还有第三种方法,比之前的两种方法要方便得多。我们现在可以使用同义词 API 来管理同义词。尽管在撰写本文时这仍然是 Elasticsearch 的一个测试功能,但我认为它很快会被采用,因为开发者对该功能的需求很高,而且它可以非常方便地解决更新同义词集的棘手问题。我们将探讨在本文中同义词 API 的常见用法。
准备工作
我们将使用以下 docker-compose.yaml
文件在本地启动 Elasticsearch 和 Kibana 以进行演示。
version: "3.9"
services:
elasticsearch:
image: elasticsearch:8.11.1
environment:
- discovery.type=single-node
- ES_JAVA_OPTS=-Xms1g -Xmx1g
- xpack.security.enabled=false
ports:
- target: 9200
published: 9200
networks:
- elastic
kibana:
image: kibana:8.11.1
ports:
- target: 5601
published: 5601
depends_on:
- elasticsearch
networks:
- elastic
networks:
elastic:
name: elastic
driver: bridge
请注意,你需要至少使用 8.10.0 版本的 Elasticsearch 才能使用同义词 API。最新版本会是最佳选择,因为那个时候该功能应该已经更加成熟。
创建同义词集
当 Elasticsearch 和 Kibana 使用上述 docker-compose.yaml
文件启动时,我们可以前往 localhost:5601
来管理 Elasticsearch 索引和同义词。
要使用同义词 API 管理同义词,我们需要先创建一个同义词集,然后才能在 Elasticsearch 索引中使用。
我们可以使用_synonyms
端点通过PUT
请求创建或更新同义词集:
PUT _synonyms/inventory-synonyms-set
{
"synonyms_set": [
{
"id": "synonym-1",
"synonyms": "ps => playstation"
},
{
"synonyms": "javascript,ecmascript,js"
}
]
}
-
products-synonyms-set
是用户定义的同义词集名称。 -
synonyms_set
是请求体中必需的键,包括一个同义词规则数组。 -
每个同义词规则是一个包含可选
id
键和必需synonyms
键的对象。如果未提供id
,Elasticsearch 将创建一个标识符。synonyms
的值是以前一篇文章中所用的 Solr 格式定义的规则。
使用同义词集创建 Elasticsearch 索引
当创建同义词集时,它可以在创建索引时用于synonym
或synonym_graph
令牌过滤器:
PUT /inventory
{
"settings": {
"index": {
"analysis": {
"analyzer": {
"index_analyzer": {
"tokenizer": "standard",
"filter": [
"lowercase"
]
},
"search_analyzer": {
"tokenizer": "standard",
"filter": [
"lowercase",
"synonym_filter"
]
}
},
"filter": {
"synonym_filter": {
"type": "synonym_graph",
"synonyms_set": "inventory-synonyms-set",
"updateable": true
}
}
}
}
},
"mappings": {
"properties": {
"name": {
"type": "text",
"analyzer": "index_analyzer",
"search_analyzer": "search_analyzer"
}
}
}
}
如我们所见,使用同义词集的配置与使用同义词文件非常相似,如这篇文章中详细说明的那样。我们只需将synonyms_path
更改为synonyms_set
。
测试同义词集
我们可以使用_analyze
端点来分析一些文本,并测试在上一步中添加的同义词,使用 API 进行测试:
GET /inventory/_analyze
{
"analyzer": "search_analyzer",
"text": "PS"
}
{
"tokens": [
{
"token": "playstation",
"start_offset": 0,
"end_offset": 2,
"type": "SYNONYM",
"position": 0
}
]
}
GET /inventory/_analyze
{
"analyzer": "search_analyzer",
"text": "JS"
}
{
"tokens": [
{
"token": "javascript",
......
},
{
"token": "ecmascript",
......
},
{
"token": "js",
......
}
]
}
这表明 API 添加的同义词工作正常。
更新同义词集
现在让我们使用同义词 API 更新同义词集。这正是同义词 API 真正出色的地方,因为我们不需要关闭、打开或重新加载相应的 Elasticsearch 索引,这极大地减轻了开发者的痛苦。
我们可以使用PUT
方法整体更新同义词集。在这里必须非常小心,否则你会用只包含新同义词规则的新集替换原始集合,这在生产环境中非常具有破坏性。
让我们向inventory-synonyms-set
添加一个新的同义词规则:
PUT _synonyms/inventory-synonyms-set
{
"synonyms_set": [
{
"id": "synonym-1",
"synonyms": "ps => playstation"
},
{
"synonyms": "javascript,ecmascript,js"
},
{
"synonyms": "py => python"
}
]
}
请注意,原始的同义词集也应在此处添加。
当现有的同义词集更新时,所有使用该同义词集的搜索分析器会自动重新加载,这可以在上面的PUT
请求响应中看到:
{
"result": "updated",
"reload_analyzers_details": {
"_shards": {
"total": 2,
"successful": 1,
"failed": 0
},
"reload_details": [
{
"index": "inventory",
"reloaded_analyzers": [
"search_analyzer"
],
"reloaded_node_ids": [
"MD0GUTvcQAOvsHdUIBunNw"
]
}
]
}
}
我们可以使用_analyze
端点来测试新添加的同义词:
GET /inventory/_analyze
{
"analyzer": "search_analyzer",
"text": "py"
}
{
"tokens": [
{
"token": "python",
"start_offset": 0,
"end_offset": 2,
"type": "SYNONYM",
"position": 0
}
]
}
是的,同义词集已成功更新,并且实时生效,无需关闭、打开或重新加载相应的索引。
此外,我们还可以通过直接添加规则来进行增量更新。在这种情况下,我们需要在路径中为规则指定一个id
:
PUT _synonyms/inventory-synonyms-set/synonym-ipod
{
"synonyms": "i-pod, i pod ⇒ ipod"
}
你也可以通过其 id 删除同义词规则:
DELETE _synonyms/inventory-synonyms-set/synonym-ipod
单独更新同义词规则的工作方式与整体更新同义词集时相同。
监控同义词集
我们可以使用GET
方法直接检索同义词集的内容:
GET _synonyms/inventory-synonyms-set
{
"count": 3,
"synonyms_set": [
{
"id": "qP8w_osBr1bbhRIfbz1c",
"synonyms": "javascript,ecmascript,js"
},
{
"id": "qf8w_osBr1bbhRIfbz1c",
"synonyms": "py => python"
},
{
"id": "synonym-1",
"synonyms": "ps => playstation"
}
]
}
实际操作中,我们希望编写一些脚本来统计索引中的同义词数量,以确保如上所示的同义词没有被意外删除。下面的 Python 代码片段可以统计同义词集合中的同义词数量:
from elasticsearch import Elasticsearch
es_client = Elasticsearch("http://localhost:9200")
es_client.synonyms.get_synonyms_sets()
# ObjectApiResponse({'count': 1, 'results': [{'synonyms_set': 'inventory-synonyms-set', 'count': 3}]})
在这项工作中,我们介绍了 Elasticsearch 同义词 API 的基础知识,并演示了如何方便地使用它来管理同义词。通过此功能,我们不再需要像处理内联同义词那样关闭和重新打开索引,也不需要像处理同义词文件那样手动重新加载索引。它可以使我们的搜索引擎更稳定,也让我们作为开发者的工作变得更加轻松。
相关文章
-
如何正确使用 Elasticsearch 的同义词功能
-
Elasticsearch 8 在 Python 中的重要语法更新
使用这些方法使你的 Python 并发任务表现更佳
PYTHON 并发
asyncio.gather、asyncio.as_completed 和 asyncio.wait 的最佳实践
·发布于 Towards Data Science ·6 分钟阅读·2023 年 4 月 18 日
--
图片由 Aleksandr Popov 提供,来自 Unsplash
问题所在
Python 的多线程性能一直未能如预期那样,因为 GIL 的存在。
自从版本 3.4 起,Python 引入了 asyncio 包,通过并发执行 IO 绑定任务。经过几次迭代,asyncio API 的表现非常出色,与多线程版本相比,并发任务的性能显著提高。
然而,程序员在使用 asyncio 时仍然会犯许多错误:
一个错误如图所示,是直接使用 await 协程方法,这种方式将并发任务的调用从异步变为同步,最终丧失了并发特性。
另一个错误如图所示,虽然程序员意识到需要使用 create_task
来创建后台执行的任务。但是,下面这种逐个等待任务的方式,将具有不同时间的任务转变为有序等待。
这段代码将会等到 task_1 先完成,而不管 task_2 是否先完成。
什么是并发任务执行?
那么,什么是实际的并发任务?让我们用图示来说明:
不管我们启动了多少任务,最终都需要等待它们完成。图源:作者
如图所示,一个并发过程应该包括两个部分:启动后台任务,将后台任务重新加入主函数,并获取结果。
大多数读者已经知道如何使用 create_task
来启动后台任务。今天,我将介绍几种等待后台任务完成的方法以及每种方法的最佳实践。
入门
在我们开始介绍今天的主要内容之前,我们需要准备一个示例异步方法来模拟一个 IO 绑定的方法调用,并且还需要一个自定义的 AsyncException,用于在测试抛出异常时友好地提示异常消息:
并发执行方法的比较
一旦我们完成了准备工作,就该开始今天的旅程,系好安全带。
1. asyncio.gather
asyncio.gather
可用于启动一组后台任务,等待它们执行完毕,并获取结果列表:
asyncio.gather
尽管可以形成一组后台任务,但不能直接接受列表或集合作为参数。如果你需要传递包含后台任务的列表,请解包它。
asyncio.gather
接受一个 return_exceptions
参数。当 return_exceptions
的值为 False 时,如果任何后台任务抛出异常,该异常将抛给调用 gather 方法的函数,并且 gather 方法的结果列表将为空。
截图。图像由作者提供
当 return_exceptions
的值为 True 时,后台任务抛出的异常不会影响其他任务的执行,最终会被合并到结果列表中一起返回。
results = await asyncio.gather(*aws, return_exceptions=True)
截图。图像由作者提供
接下来,让我们看看为什么 gather
方法不能直接接受列表,而必须解包列表。因为当列表被填充并执行时,在等待它们完成的同时,很难向列表中添加新任务。然而,gather
方法可以使用嵌套组将现有任务与新任务混合,这解决了在中间无法添加新任务的问题:
然而,gather
无法直接设置超时参数。如果需要为所有正在运行的任务设置超时,请使用这种方式,这样的方式并不够优雅。
2. asyncio.as_completed
有时候,我们必须在完成一个后台任务后立即开始后续操作。例如,当我们抓取一些数据并立即调用机器学习模型进行计算时,gather
方法不能满足我们的需求,但我们可以使用 as_completed
方法。
在使用 asyncio.as_completed
方法之前,我们先来看一下这个方法的源代码。
# This is *not* a @coroutine! It is just an iterator (yielding Futures).
def as_completed(fs, *, timeout=None):
# ...
for f in todo:
f.add_done_callback(_on_completion)
if todo and timeout is not None:
timeout_handle = loop.call_later(timeout, _on_timeout)
for _ in range(len(todo)):
yield _wait_for_one()
源代码显示,as_completed
不是一个并发方法,而是返回一个带有 yield
语句的迭代器。因此,我们可以直接遍历每个完成的后台任务,并且可以对每个任务分别处理异常,而不会影响其他任务的执行:
as_completed
接受 timeout
参数,并且在超时发生时,当前迭代的任务将会抛出 asyncio.TimeoutError
。
截图。图片来源:作者
as_completed
在处理任务执行结果方面比 gather
灵活得多,但在等待时很难将新任务添加到原任务列表中。
3. asyncio.wait
asyncio.wait
的调用方式与 as_completed
相同,但返回一个包含两个集合的元组:done
和 pending
。done
包含已经完成的任务,pending
包含仍在运行的任务。
asyncio.wait
接受一个 return_when
参数,该参数可以取三个枚举值:
-
当
return_when
为asyncio.ALL_COMPLETED
时,done
存储所有已完成的任务,而pending
为空。 -
当
return_when
为asyncio.FIRST_COMPLETED
时,done
包含所有已完成的任务,而pending
包含仍在运行的任务。
截图。图片来源:作者
- 当
return_when
为asyncio.FIRST_EXCEPTION
时,done
存储那些抛出异常并完成执行的任务,而pending
包含仍在运行的任务。
当 return_when
为 asyncio.FIRST_COMPLETED
或 asyncio.FIRST_EXECEPTION
时,我们可以递归调用 asyncio.wait
,以便根据情况添加新任务并继续等待所有任务完成。
截图。图片来源:作者
4. asyncio.TaskGroup
在 Python 3.11 中,asyncio
引入了新的 TaskGroup
API,这使得 Python 正式支持 结构化并发。这一特性允许你以更 Pythonic 的方式管理并发任务的生命周期。由于篇幅限制,我不会在这里详细讨论,但有兴趣的读者可以参考我的文章:
## 为什么 Taskgroup 和超时 在 Python 3.11 Asyncio 中如此重要
在 Python 3.11 中拥抱结构化并发
towardsdatascience.com
结论
本文介绍了 asyncio.gather
、asyncio.as_completed
和 asyncio.wait
API,并回顾了 Python 3.11 中引入的新 asyncio.TaskGroup
特性。
根据实际需要使用这些后台任务管理方法,可以使我们的 asyncio
并发编程更加灵活。
由于经验限制,本文的阐述不可避免地存在遗漏,请在阅读过程中随时留下评论,我会积极回复。
在以下文章中,我将描述如何让 tqdm 与 asyncio 配合使用,以通过进度条指示并发任务的进度:
## 在 Python 中使用 Tqdm 和 Asyncio
监控并发任务进度的高效方法
towardsdatascience.com
如果你对如何在实际应用中使用 asyncio 处理数据集任务感兴趣,可以阅读我的这篇文章:
## 在 Python 中结合多进程和 Asyncio 以提升性能
使用实际案例来解释 run_in_executor
APIs
towardsdatascience.com
通过 加入 Medium,你将可以无限制访问我所有的文章以及成千上万其他作者的文章。它只需花费你一杯咖啡的钱,但对我来说是极大的鼓励。
这篇文章最初发布于:www.dataleadsfuture.com/use-these-methods-to-make-your-python-concurrent-tasks-perform-better/
使用 to_string() 防止 Python 隐藏打印数据框的主体
3-分钟 Pandas
我们应该怎么做才能在 Python 脚本执行后查看整个打印的数据框?
·发表于 Towards Data Science ·4 分钟阅读·2023 年 4 月 10 日
--
照片由 Pascal Müller 提供,来源于 Unsplash
有时,运行 Python 脚本时没有报告任何错误,并不是调试过程中的唯一任务。我们需要确保函数按预期执行。这是在探索性数据分析中的一个典型步骤,检查数据在某些特定数据处理前后的样子。
因此,我们需要在脚本执行过程中打印一些数据框或重要变量,以检查它们是否“正确”。然而,简单的 print 命令有时只能显示数据框的前几行和最后几行(如下例所示),这使得检查过程变得不必要地困难。
通常,数据框的格式是 pandas.DataFrame
,如果直接使用 print 命令,你可能会得到如下内容,
import pandas as pd
import numpy as np
data = np.random.randn(5000, 5)
df = pd.DataFrame(data, columns=['A', 'B', 'C', 'D', 'E'])
print(df.head(100))
打印前 100 行(图像由作者提供)
你可能已经注意到,数据框的中间部分被三个点隐藏了。如果我们真的需要检查前 100 行是什么呢?例如,我们想检查一个大型 Python 脚本中间某个特定步骤的结果,以确保函数按预期执行。
set_option()
最直接的解决方案之一是编辑 Pandas 显示的默认行数,
pd.set_option('display.max_rows', 500)
print(df.head(100))
设置 Pandas 显示的默认行数后,打印前 100 行(图像由作者提供)
其中 set_option
是一个方法,它允许你控制 Pandas 函数的行为,包括设置显示的最大行数或列数,如我们之前所做的。第一个参数 display.max_rows
是用来调整显示的最大行数的,500 是我们设定的最大行数值。
尽管这种方法被广泛使用,但将其放在可执行的 Python 文件中并不理想,尤其是当你有多个数据框要打印,并且希望它们显示不同的行数时。
例如,我有一个结构如下的脚本,
## Code Block 1 ##
...
print(df1.head(20))
...
## Code Block 2 ##
...
print(df2.head(100))
...
## Code Block N ##
...
print(df_n)
...
我们在整个脚本中需要显示不同数量的前行,有时我们想查看整个打印的数据框,但有时我们只关心数据框的维度和结构,而不需要查看全部数据。
在这种情况下,我们可能需要使用函数 pd.set_option()
来设置所需的 display
,或使用 pd.reset_option()
每次在打印数据框之前恢复默认选项,这样会变得非常麻烦。
## Code Block 1 ##
...
pd.set_option('display.max_rows', 20)
print(df1.head(20))
...
## Code Block 2 ##
...
pd.set_option('display.max_rows', 100)
print(df2.head(100))
...
## Code Block N ##
...
pd.reset_option('display.max_rows')
print(df_n)
...
实际上,还有一种更灵活和有效的方法来显示整个数据框,而不需要为 Pandas 指定显示选项。
to_string()
to_string()
直接将 pd.DataFrame
对象转换为字符串对象,当我们打印它时,不会受 pandas
的显示限制影响。
pd.set_option('display.max_rows', 10)
print(df.head(100).to_string())
使用 to_string() 打印前 100 行(图片由作者提供)
我们可以看到,尽管我将显示的最大行数设置为 10,to_string()
帮助我们打印了 100 行的整个数据框。
函数 to_string()
将整个数据框转换为 string
格式,因此可以在打印步骤中保留数据框中的所有值和索引。由于 set_option()
仅对 pandas 对象有效,我们的打印 string
不受之前设置的最大行数限制。
所以,策略是你无需通过 set_option()
设置任何内容,只需使用 to_string()
即可查看整个数据框。这将避免你考虑在脚本的哪个部分设置哪个选项。
收获
-
当你在整个脚本中有一致的行数需要显示时,请使用
set_option('display.max_rows')
。 -
如果你想打印整个 Pandas 数据框,无论 Pandas 选项设置如何,请使用
to_string()
。
感谢阅读!希望你在工作中喜欢使用 Pandas 技巧!
如果你想阅读更多我的故事,请 订阅我的 Medium。你还可以通过我的 推荐链接 加入 Medium 会员!
用户流失预测
原文:
towardsdatascience.com/user-churn-prediction-d43c53e6f6df
现代数据仓库与机器学习
·发表于 Towards Data Science ·12 分钟阅读·2023 年 12 月 23 日
--
照片由 Martin Adams 提供,发布在 Unsplash
毫无疑问,用户留存是许多公司和在线应用程序的重要绩效指标。我们将讨论如何利用内置的数据仓库机器学习功能,运行用户行为数据的倾向模型,以确定用户流失的可能性。在这个故事中,我想重点关注数据集准备和使用标准 SQL 进行模型训练。现代数据仓库允许这样做。确实,留存是一个重要的业务指标,有助于理解用户行为的机制。它提供了一个高层次的概述,回答了我们应用程序在留住用户方面的成功程度:我们的应用程序是否足够好,能够留住用户?一个众所周知的事实是,留住现有用户比获得新用户要便宜。
在我之前的一篇文章中,我写到了现代数据仓库 [1]。
先进的数据平台设计
towardsdatascience.com
现代 DWH 具有许多有用的功能和组件,使其与其他数据平台类型有所区别 [2]。
ML 模型支持似乎是处理大数据时的基础 DWH 组件。
在这个故事中,我将使用二元逻辑回归,这是训练速度最快的模型之一。我将演示如何利用它来预测用户流失的倾向。确实,我们不需要了解每一个机器学习模型。
我们无法与 Amazon 和 Google 等云服务提供商在机器学习和数据科学领域竞争,但我们需要知道如何使用它。
我在我的文章中曾经写过关于这个的内容 [3]:
2024 年初学者的快捷方式
[towardsdatascience.com
在本教程中,我们将学习如何转换原始事件数据,以创建一个用于 ML 模型的训练数据集。我们将使用它来生成对用户的预测。我将以BigQuery ML作为示例,但还有很多其他支持此功能的数据仓库工具。
BigQuery ML 使机器学习操作和模型训练变得更加普及,因此现在数据分析师或软件工程师可以轻松训练模型。我们只需要对 SQL 有良好的知识,并了解用户保留数据集的逻辑[4]。
数据准备过程简单明了,应该很容易跟随。
我们将使用标准 SQL 来实现。
通常,这有助于揭示一些有关数据和用户基础的有用信息。分析用户行为和进行探索性数据分析有助于检测重要的用户行为漏斗(开放漏斗),这些漏斗可以用于进一步的特征工程,并改进模型。
例如,我们可以使用 Google 友善提供的免费用户行为数据集中的一个。这些数据集的典型移动应用有两个版本——Android 和 iOS,它们生成持续的事件数据流。Google Analytics 4 是一个很好的例子,这些数据可以用来衡量我们应用中的流量和参与度水平。
每个模型都需要一个数据集
我们想要创建一个。我们需要:
-
对来自 Firebase(Google Analytics 4)的 BigQuery 导出数据集进行探索性数据分析(EDA)。
-
将数据集分为两部分用于训练和测试,包含用于机器学习(ML)模型的分类和行为属性。
-
使用 BigQuery ML 训练和评估机器学习模型
-
使用 BigQuery ML 模型进行预测
-
将模型洞察应用于实践
在 Firebase 或 Google Analytics 4 中,所有用户行为数据都作为事件存储,这意味着表中的每一行对应一个单独的事件,并附带额外的参数和属性。
数据集准备和模型训练示意图
我们将使用一个公开的 Google Analytics 数据集[5],它包含一个名为“Flood It!”的移动游戏应用的数据(Android、iOS),且没有相关费用。然而,Google Cloud Platform(GCP)服务是按使用量计费的,可能会产生费用。
数据集示例。图片由作者提供。
该数据集包含来自超过 15,000 个用户的 570 万个事件。打开上面的链接并点击预览。
运行任何表的预览不会产生费用。
通过在下面的命令行中运行此命令来复制此数据集:
gcloud init # Authenticate
# Use bq mk command with --dataset flag to create a new dataset:
bq mk
--dataset \
--location=US \
PROJECT:DATASET
# # Example:
# bq mk
# --dataset \
# --location=US \
# bq-shakhomirov:analytics_153293282
# Use bq mk with --transfer_config flag to copy the dataset with public data:
bq mk
--transfer_config
--project_id=PROJECT \ # Replace with your project_id
--data_source=cross_region_copy
--target_dataset=analytics_153293282
--display_name=analytics_153293282
--params='{"source_dataset_id":"analytics_153293282","source_project_id":"firebase-public-project","overwrite_destination_table":"true"}'
探索数据
分析用户行为和进行探索性数据分析 [6] 有助于更好地理解用户旅程。
## 使用 BigQuery SQL 进行探索性数据分析?简单!
完整的 Python 比较和逐步指南,适用于任何数据集。Kaggle 用户流失数据。
towardsdatascience.com
运行此查询以检查数据集结构:
SELECT *
FROM `firebase-public-project.analytics_153293282.events_*`
TABLESAMPLE SYSTEM (1 PERCENT)
;
该数据集的限制在于用户数据没有实际的 user_id
,而 user_id
通常是在注册后分配的。因此,在处理该数据集时,我们仅使用 device ids
,这在理想情况下不是很好,因为这些在重新安装或新应用版本推出后会不断更新。
首先,我们希望对原始事件数据进行预处理,并创建一个具有正确结构的新数据集,以便用于机器学习模型。
-
如果用户在过去 30 天内活跃,我们将标记每个用户为保留(0)。
-
我们将使用 GA4 的
user_pseudo_id
(device_id),Firebase 默认使用它来识别每个用户的设备。 -
根据我们的 EDA,我们将排除一些明显的异常值,例如垃圾邮件发送者(事件过多)和
bounced
用户(在install
事件后在应用中停留时间少于 10 分钟的用户)。 -
我们将添加几个从原始事件数据中提取的类别特征列,例如
platform
、device_type
、country
等。 -
我们最终会添加每个用户在一定时间内的用户行为和活动总量,例如
user_engagement
、spend_virtual_currency
等。
这些计算将用作机器学习模型的特征,但简而言之,它们是开放的事件漏斗,即用户在使用应用程序时的步骤。
我们的未流失(回归用户)定义是用户在当前日期前的过去 30 天内活跃并与应用程序互动。你可以尝试调整此参数。你还可以尝试预测某些内容,例如在应用中消费的可能性(in_app_purchase
)等。
计算跳失和流失用户
如果我们在模型中使用 user_pseudo_id
,则如果 last_touch
在注册后 10 分钟内,我们会计算 bounced = 1
,否则为 0:
...
IF (user_last_engagement <= TIMESTAMP_ADD(user_first_engagement,
INTERVAL 10 MINUTE),
1,
0 ) AS bounced
, IF(last_seen_dt < DATE_SUB(current_date(),INTERVAL 30 DAY) , 1, 0) AS churned
...
如果我们有一个每天更新的动态数据集,使用 current_date()
来识别流失用户是理想的。然而,我们正在处理一些样本数据,因此我们应该使用数据中的最后已知日期,即我们假设在特定日期(20181003)收集了数据。
理想情况下,我们希望只跟踪那些在安装应用后完成注册的用户。在这种情况下,我们不需要计算 bounced
,如果 last_seen_dt
在从 current_timestamp()
计算的 24 小时/3 天/30 天 之前,则计算 churned = 1
,否则为 0:
...
, IF( last_seen_ts < timestamp_sub(
current_timestamp(), interval 30*24 hour) , 1, 0) AS churned
...
添加行为定量和类别特征
这就是我们将用来构建模型的数据。分类特征可以表示一些不可量化的人口统计值,例如性别等。定量特征则是我们可以测量和计数的。
这两种特征的组合有助于创建一个更具预测性的模型。Firebase/GA4 数据集中提取了大量信息,这可能对我们的模型有用,例如 app_info、device、event_params、geo 等。
Firebase 会自动收集许多事件,但请记住,也可以设置自定义事件和属性。作为移动开发者,我们可以集成一个自定义事件,例如用户 tag
。对于我们的示例应用程序,这将指示用户可能具有某些应用内权限(如高级会员、影响者、群组管理员、版主等)。
话虽如此,我们希望使用以下分类特征:
-
platform (IOS/Android)
-
geo.country
-
device.operating_system
-
device.language
需要注意的是,用户在这些字段中可能具有不同的值,例如不同的设备、语言更改和 VPN 设置可能会影响这些值。因此,我们只需要第一个参与事件和他们在 install
或 registration
时使用的设备设置。
添加 row_number 函数将有助于实现这一点:
...
ROW_NUMBER() OVER (PARTITION BY user_pseudo_id ORDER BY event_timestamp DESC) AS row_number
...
然后,为了预测用户流失,我们希望统计用户在注册/安装后的 24 小时/3 天/7 天 内遇到的事件数量:
在我们的案例中,我们希望收集和统计这些事件:
-
user_engagement
-
level_start_quickplay
-
level_end_quickplay
-
level_complete_quickplay
-
level_reset_quickplay
-
post_score
-
spend_virtual_currency
-
ad_reward
-
challenge_a_friend
-
completed_5_levels
-
use_extra_steps
对于其他应用程序,收集描述用户行为的其他事件可能是有用的,即:
-
message_sent
-
chat_open
-
spend_virtual_currency
-
account_topup
-
set_avatar
-
group_join
-
broadcast_listened
-
achievement_unlocked
-
reputation_update
计算模型特征
让我们为未来的模型及其数据创建一个专用数据集:sample_churn_model
。
./sql/churn.sql
中的 SQL 查询演示了如何计算这些用户指标。我们可以在代码小部件中找到这个文件。要创建这个数据集,请在命令行中运行:
bq query --nouse_legacy_sql < sql/churn.sql
./sql/churn.sql:
-- ./sql/churn.sql
create schema if not exists sample_churn_model
OPTIONS (
description = 'sample_churn_model',
location = 'US'
)
;
create or replace table sample_churn_model.churn as
with first_engagement_data as (
select distinct
user_pseudo_id
, device.language
, geo.country
, platform
, ROW_NUMBER() OVER (PARTITION BY user_pseudo_id ORDER BY event_timestamp DESC) AS row_num
from `firebase-public-project.analytics_153293282.events_*`
where
-- _TABLE_SUFFIX = FORMAT_DATE('%Y%m%d', DATE_SUB(current_date(), INTERVAl 1 DAY))
-- and
lower(event_name) = 'user_engagement'
)
, first_engagement as (
select * except (row_num)
from first_engagement_data
where row_num = 1
)
, first_last_seen as (
select
user_pseudo_id
, min(event_timestamp) as first_seen_ts
, max(event_timestamp) as last_seen_ts
from
`firebase-public-project.analytics_153293282.events_*`
where event_name = 'user_engagement'
group by
user_pseudo_id
)
, event_aggregates as (
select
event_date
,user_pseudo_id
,lower(event_name) as event_name
,count(distinct event_timestamp) as event_cnt
from
`firebase-public-project.analytics_153293282.events_*`
-- where
-- _TABLE_SUFFIX = FORMAT_DATE('%Y%m%d', DATE_SUB(current_date(), INTERVAl 1 DAY))
group by
event_date
,user_pseudo_id
,event_name
)
, daily_data as (
select
a.user_pseudo_id
,a.event_date
,a.event_name
,a.event_cnt
,f.first_seen_ts
,f.last_seen_ts
,e.language
,e.country
,e.platform
, date_diff(
parse_date('%Y%m%d', a.event_date)
, cast(timestamp_micros(first_seen_ts) as date)
, day
) as day_number
,if( cast(timestamp_micros(f.last_seen_ts) as date) <
date_sub(parse_date('%Y%m%d', '20181003'), interval 30 day)
, 1
, 0) as churned -- wasn't active last 30 days
,if (timestamp_micros(last_seen_ts) <= timestamp_add(timestamp_micros(first_seen_ts),
interval 10 minute)
,1
,0 ) as bounced
from event_aggregates a
left join first_last_seen f on f.user_pseudo_id = a.user_pseudo_id
left join first_engagement e on e.user_pseudo_id = a.user_pseudo_id
)
-- Final dataset for churn model training:
select
user_pseudo_id
,first_seen_ts
,last_seen_ts
,language
,country
,platform
,churned
,sum(if(day_number in (0) and event_name = 'user_engagement', event_cnt, 0)) as user_engagement_cnt
,sum(if(day_number in (0) and event_name = 'level_start_quickplay', event_cnt, 0)) as level_start_quickplay_cnt
,sum(if(day_number in (0) and event_name = 'level_end_quickplay', event_cnt, 0)) as level_end_quickplay_cnt
,sum(if(day_number in (0) and event_name = 'level_complete_quickplay', event_cnt, 0)) as level_complete_quickplay_cnt
,sum(if(day_number in (0) and event_name = 'level_reset_quickplay', event_cnt, 0)) as level_reset_quickplay_cnt
,sum(if(day_number in (0) and event_name = 'post_score', event_cnt, 0)) as post_score_cnt
,sum(if(day_number in (0) and event_name = 'spend_virtual_currency', event_cnt, 0)) as spend_virtual_currency_cnt
,sum(if(day_number in (0) and event_name = 'ad_reward', event_cnt, 0)) as ad_reward_cnt
,sum(if(day_number in (0) and event_name = 'challenge_a_friend', event_cnt, 0)) as challenge_a_friend_cnt
,sum(if(day_number in (0) and event_name = 'completed_5_levels', event_cnt, 0)) as completed_5_levels_cnt
,sum(if(day_number in (0) and event_name = 'use_extra_steps', event_cnt, 0)) as use_extra_steps_cnt
from daily_data
-- where bounced = 0
group by
user_pseudo_id
,first_seen_ts
,last_seen_ts
,language
,country
,platform
,churned
让我们看看有多少用户流失了:
从我们的数据集中,我们发现 4030 名用户在过去 30 天内流失并处于非活动状态:
SELECT
churned
,COUNT(churned) as count_users
FROM
sample_churn_model.churn
GROUP BY 1
;
流失的用户。图片由作者提供。
模型训练和分类
目前在 BigQuery ML 中有不同的模型类型 [7]:
-
BOOSTED_TREE_CLASSIFIER
-
神经网络
-
AutoML Tables
-
逻辑回归
逻辑回归可能是一个很好的起点,因为它可以相对较快地训练。其他类型的模型可能提供更好的性能,但也需要更多的训练时间,例如深度神经网络。
这些模型中的每一个都会输出一个 0 到 1.0 之间的概率分数(倾向),表示模型预测的可能性。
请参考文件./sql/churn_model.sql
。如果我们运行它,它将创建并训练模型。
运行此查询来评估模型:
bq query --nouse_legacy_sql \
'SELECT * FROM ML.EVALUATE(MODEL sample_churn_model.churn_model);'
你将看到模型性能指标。分析这些指标可能有助于在不同模型之间进行选择。
例如,关于模型准确性,任何超过 70%的结果都被视为强模型性能。
我们的模型的召回率为 0.559——换句话说,它正确识别了 56%的所有流失用户。
模型性能指标。图片由作者提供。
我们可以使用混淆矩阵来查看我们的模型在预测标签方面的表现,与实际标签进行比较。
如果我们运行下面的 SQL,它将生成一个混淆矩阵。
SELECT
expected_label,
_0 AS predicted_0,
_1 AS predicted_1
FROM
ML.CONFUSION_MATRIX(MODEL sample_churn_model.churn_model)
这可以被解释为false positive
和false negative
预测的比较。
混淆矩阵。图片由作者提供。
使用预测
对于我们的分类模型,最重要的指标是用户流失和不活跃的倾向。换句话说,这是一种概率,且该概率越接近 1,模型预测的用户越不可能返回应用程序:
SELECT
user_pseudo_id,
churned,
predicted_churned,
predicted_churned_probs[OFFSET(0)].prob as probability_churned
FROM
ML.PREDICT(MODEL sample_churn_model.churn_model,
(SELECT * FROM sample_churn_model.churn)) #can be replaced with a proper test dataset
order by 3 desc
;
预测。图片由作者提供。
在现实生活中,我们希望创建一个每天更新的预测数据集。
首先,我们需要昨天注册的 Firebase/Analytics 用户。我们需要安排我们的sample_churn_model.churn
数据集,并逐步添加昨天的新用户,即:
insert into sample_churn_model.churn (...)
select
... -- aggregate model features for new users.
from
`firebase-public-project.analytics_153293282.events_*`
where
_TABLE_SUFFIX = FORMAT_DATE('%Y%m%d', DATE_SUB(current_date(), INTERVAl 1 DAY))
...
然后我们希望生成对它们的预测并将其插入到我们的新predictions
数据集中:
SELECT
user_pseudo_id,
churned,
predicted_churned,
predicted_churned_probs[OFFSET(0)].prob as probability_churned
FROM
ML.PREDICT(MODEL sample_churn_model.churn_model,
(SELECT * FROM sample_churn_model.churn
WHERE date(first_seen_ts) = date_sub(current_date(), interval 1 day)))
使用这个模型,我们可以通过模式更好地理解用户行为,当然,我们还希望利用这些知识做些事情。
有多种方法可以使用预测数据(激活)。例如,我们可以使用 SDK 和客户端库直接从我们的 DWH 解决方案中读取数据。例如,我们可能希望创建一个数据服务,每天收集新用户的预测数据,然后将这些数据发送到其他地方,即重新定向服务。是的,我们可能也希望重新定向那些可能流失的用户或可能留在应用中的用户。
结论
确实,对机器学习(ML)模型数据采取行动以保留用户被证明是非常有用的,并且可能帮助在快速变化的市场环境中获得竞争优势。这就是为什么能够预测用户参与度以预测用户是否即将离开是重要的。我们不需要成为数据科学家就能创建和训练 ML 模型。我们所需要的 — 是现代数据仓库、良好的 SQL 知识以及对用户保留逻辑的良好理解。现代数据仓库已发展到可以提供所有已知 ML 模型的状态,它们已经准备好使用标准 SQL 方言进行创建。通过预测的保留数字,我们可以创建和编辑受众。利用现代数据仓库中的 ML 功能,我们可以通过向识别出的用户提供相关信息、有用的优惠和促销来量身定制用户体验。现代数据仓库解决方案使机器学习操作和模型训练民主化。这对于数据工程师来说是一个极其有用的特性,因为所有这些过程可以很容易地自动化、计划和触发,具体取决于使用案例场景。
推荐阅读
[1] towardsdatascience.com/modern-data-warehousing-2b1b0486ce4a
[2] towardsdatascience.com/data-platform-architecture-types-f255ac6e0b7
[3] towardsdatascience.com/how-to-become-a-data-engineer-c0319cb226c2
[4] towardsdatascience.com/retention-and-daily-active-users-explained-79c5dab1bf84
[6] towardsdatascience.com/exploratory-data-analysis-with-bigquery-sql-easy-69895ac4eb9e
[7] cloud.google.com/bigquery-ml/docs/introduction
用户反馈 — 机器学习监控栈的缺失部分
完整指南:构建以用户为中心的人工智能
·
关注 发表在 Towards Data Science · 8 分钟阅读 · 2023 年 8 月 2 日
--
图片来自 Unsplash
人工智能模型与用户的错位
你是否曾经花费数月的时间,以及谁知道花费了多少美元,来实施一个人工智能模型,结果发现没有人使用?即使你克服了采纳的挑战,你如何知道模型输出是否真正为用户的决策、查询、专业或日常活动带来了价值?
机器学习性能指标和实时监控工具是计算模型性能的优秀方式,并且可以识别出从技术角度来看可能出现的问题。但如果没有理解用户参与或满意度,很难知道模型是否被用于其预期的目的。
此外,听取 AI 模型用户的意见可能会揭示出预测错误的边缘案例;解释算法不能像我们期望的那样清晰解释问题;或用户体验缺陷影响用户与模型的互动方式。
本文的其余部分将涵盖理解 AI 模型用户反馈的重要性、不同类型的用户反馈,以及如何收集用户反馈以改善模型性能、增加用户采用率,并最终使 AI 模型与用户对齐。
内容
-
AI 模型与用户的错位
-
什么是对 AI 的用户反馈?
-
为什么 AI 的用户反馈很重要?
-
不同类型的反馈是什么?
-
收集用户反馈指南
-
总结
什么是对 AI 的用户反馈?
当我们提到用户反馈时,所指的用户取决于你正在实施的用例。例如,这可能是一个内部业务用户或内部基于 ML 的需求预测应用程序的利益相关者;也可能是一个外部领域专家,如医学肿瘤学家,利用一个 MedTech 产品来帮助检测医学扫描中的肿瘤;或者,它可能是一个面向外部的求职申请助手的最终用户,利用生成式 AI 帮助编写和完善简历。
本文中概述的概念、方法和好处适用于所有这些不同的用例。然而,根据具体用例,一些好处可能会有所不同,应该根据具体情况加以考虑。
为了本文的目的,我们将使用上述描述的简历助手来说明用户反馈对该应用的好处。
另一个重要的点是,当提到用户反馈时,我们不仅仅是指重新标记错误预测或用于自动化模型再训练的反馈循环。用户反馈包括用户提供的任何信息,帮助了解 AI 应用的有用性和采用情况。在我们的简历助手示例中,用户反馈可能包括用户满意度评分,提供有关用户对生成的简历有多满意的见解,或书面评论以突出具体问题。
这种反馈不应该总是直接推送到自动化再训练流程中,原因有几个:
-
用户反馈通常是非结构化的,并突出了不正确预测以外的问题,因此不能总是直接用于重新训练模型。例如,用户指出简历助手使用过于正式的语言可能需要在训练数据中增加更多非正式文本的示例,而不是直接用这些反馈进行重新训练。
-
仅关注正确/错误预测忽略了用户提供的宝贵信息。理解用户反馈可以帮助 AI 团队根据用户体验和使用模式改进应用。
-
如强化学习与人类反馈(RLHF)等训练策略在受控环境中效果很好。然而,现实世界的用户反馈可能嘈杂且潜在有害。例如,盲目地将用户反馈纳入训练数据可能导致数据中毒,其中恶意用户故意误导模型。
因此,AI 团队应该审查用户反馈,以提取不同的见解,并确定下一步最佳行动方案,以改善整体 AI 应用。
为什么用户反馈对 AI 很重要?
实现模型评估
许多 AI 模型缺乏基准真相。这使得在测试数据集上进行评估变得困难,因为它通常基于一个代理指标,而这个指标通常只讲述了一部分故事。这在生成模型中尤其如此,因为了解用户是否对模型预测感到满意通常是最重要的指标。
提升模型性能:
用户反馈可以用来持续提高 AI 模型的性能。用户可能具备构建强大模型所需的良好领域知识。此外,监控用户参与度可以帮助识别模型是否因训练/测试集未能良好代表现实世界而表现不佳。
增加用户对齐度:
用户反馈提供了模型哪些方面运作良好以及造成摩擦的原因的见解。这使得 AI 团队能够提升用户体验,使模型更加直观和用户友好。此外,AI 团队可以确保模型对所有用户而不仅仅是小型子群体对齐。例如,确保简历助手在所有语言中保持质量,而不仅仅是英语。
当用户感到他们的声音被听到时,他们更可能信任 AI 模型并保持参与,从而增加用户对齐度和采纳度。
提高 AI 责任感:
通过用户反馈,AI 团队可以识别和解决与安全性、偏见或其他伦理问题相关的担忧。这种积极主动的方法导致了更安全、更负责任的 AI 模型的发展。通过寻求和回应用户反馈,AI 团队展示了他们对创建高质量和可靠 AI 解决方案的责任感和承诺。反馈也可能揭示对额外教育资源和文档的需求,AI 团队可以提供这些资源,以确保用户对模型的能力有清晰的理解,并促进最佳实践。
总结来说,利用用户洞察可以使 AI 团队优化模型、提升用户体验和解决伦理问题,从而提高用户满意度和信任度。
现在我们已经明确了用户反馈及其好处,让我们来讨论不同类型的反馈及其用途。
不同类型的反馈有哪些?
用户反馈主要分为两个类别:明确反馈和隐性反馈。这可以通过我们新朋友 ChatGPT(如下图)进行很好的解释和说明。
ChatGPT 说明了明确和隐性反馈之间的区别。截图来自 OpenAI 的 ChatGPT,并由作者编辑。
明确的用户反馈指的是用户对其体验、观点或偏好的直接、故意和自觉的输入。正如你在 ChatGPT 界面中看到的,点赞/点踩反馈就是明确反馈的一个例子。
明确的反馈可以进一步分为定量和定性两种。定量反馈包括可测量的尺度,如点赞/点踩、用户满意度(也称为 5 分制李克特量表),或任何最适合你想了解的用户信息的自定义尺度。
定性反馈通常涉及一个开放的文本框,以允许用户提供书面反馈。将定量测量与定性反馈结合起来,可以使 AI 团队了解用户评论背后的“为什么”,并揭示诸如 AI 缺陷、领域知识或用户偏好等细节。
提交定性反馈后,选择了负面的定量反馈。截图来自 OpenAI 的 ChatGPT。
隐性用户反馈指的是基于用户行为、动作或模式提供的间接、非自觉的数据。再次看 ChatGPT 界面,‘复制到剪贴板’按钮就是 OpenAI 收集隐性反馈的一个例子。对于简历助手的例子,隐性用户反馈也可以通过跟踪用户对生成输出的任何编辑来获得。
在选择实施的反馈类型时需要进行考虑。明确的反馈提供了用户反馈和想法的更清晰理解。然而,对于外部使用情况,最终用户可能不会总是提供明确反馈,因为他们可能不了解如何受益(或感觉没有时间!)。在这种情况下,隐式反馈也可以很好地了解 AI 应用的使用情况,而无需用户采取直接行动。
根据应用情况和你当前面临的挑战,你还应考虑实施哪些措施。例如,如果你专注于提高模型性能,那么带有评论的点赞/点踩措施可以帮助识别模型问题。但如果你更关注提高用户采纳率,那么也许用户满意度评分会更合适。
收集用户反馈指南
在本节中,我们将详细介绍收集用户反馈的四个关键步骤,并将用户洞察整合回 ML 监控系统中(如下所示)。
系统图概述了收集用户反馈的高层架构,并将用户洞察整合回 ML 监控系统。图像由作者提供。
步骤 1:设计并构建 AI 应用中的反馈组件
在确定你为何收集用户反馈的目标后,你可以确定哪种类型的反馈最适合你的需求。用户反馈通常在模型输出生成后实施。然而,你可能希望在整个应用程序中收集反馈,以获得对应用程序某些功能的反馈。
AI 模型元数据应与通过组件提交的所有反馈一起捕获。这包括模型版本、提示或请求、模型输出以及用户人口统计信息(如用户 ID 和位置)。
步骤 2:开发分析能力以理解用户反馈
对于定量反馈,这可能包括如用户满意度(CSAT/NPS)或随时间变化的平均正面/负面回应等图表,能够比较不同模型版本、用户或其他元数据的这些指标。
对于定性反馈,使用 ML 来分析用户评论中的情感,并将反馈分类到不同的类别中。这使得能够监控不同类别评论中的各种情感/满意度指标。
步骤 3:识别 AI 问题
利用分析能力,可以识别反馈中的重复主题和模式,以分类改进领域。然后,AI 问题可以被提出并优先解决。
AI 团队在这一阶段的角色是识别模型问题和用户问题,并确定解决这些问题的最佳行动方案。
要提醒 AI 团队可能在用户反馈中发现的见解类型,请回顾“什么是 AI 的用户反馈?”部分。
步骤 4 — 将用户反馈整合回你的机器学习监控系统中
将用户反馈整合到现有的机器学习监控系统中,可以让你设置警报(类似于性能监控或漂移检测)。例如,如果全球用户满意度得分低于某个阈值,可以触发警报通知 AI 团队采取行动。
此外,摘要和日报可以发送给 AI 团队或利益相关者,提供用户反馈的概览。
总结
总结来说,用户反馈使 AI 团队能够识别漏洞、微调模型,并使模型与用户对齐。
上述内容也可以通过机器学习监控系统来实现。然而,通过从不同的角度,即用户的视角来评估模型,我们可以识别传统机器学习监控系统可能忽视的额外信息。
我希望这篇文章激发了你的兴趣,并为你提供了如何开始倾听用户意见并增强 AI 应用的初步想法。
如果你想了解更多关于 AI 用户反馈的信息,或分享和讨论你在这个话题上的想法,请随时通过 Linkedin 或 email**与我联系。
使用 Apache Kafka 进行数据流处理
原文:
towardsdatascience.com/using-apache-kafka-for-data-streaming-9199699623fa
学习如何安装和使用 Kafka 发送和接收消息
·发表于 Towards Data Science ·6 分钟阅读·2023 年 3 月 1 日
--
由 Patrick Perkins 提供的照片,刊登在 Unsplash
什么是 Apache Kafka?
Apache Kafka 是一个开源应用,用于大数据的实时流处理。它是一个发布-订阅消息系统,你可以用它在进程、应用和服务器之间发送消息。以下图示展示了 Apache Kafka 的高层架构概述:
所有图片均由作者提供
与其他消息系统不同,Kafka 还具有额外的功能,如分区、复制,并且具有比其他消息系统更高的吞吐量和容错性。所有这些功能使 Kafka 非常适合高容量的消息处理。
我将在未来的文章中详细讨论集群、分区及其他功能。
在本文中,我将指导你完成在系统上安装 Kafka 的过程,以便你对其工作原理有更好的了解。未来的文章将深入探讨 Kafka 的具体使用案例。
安装
安装 Kafka 相当简单。你只需按照以下列出的 3 个步骤安装并运行 Kafka:
-
安装 Java
-
下载并安装 Kafka
-
运行 ZooKeeper 和 Kafka
在本文中,我将展示如何在 macOS 上安装 Kafka。
1. 安装 Java
Kafka 是用 Scala 和 Java 编写的。因此,你需要做的第一步是安装最新的 JDK。你可以通过访问 官方 Java 下载 页面来安装 Java,网址为:www.oracle.com/java/technologies/downloads/
。
如果你使用的是基于 Intel 的 Mac,请下载 x64 DMG 安装程序;否则,请下载 Arm 64 DMG 安装程序。
如果你的 Mac 上安装了Homebrew (brew.sh/
),可以使用以下命令直接安装 Java:
$ brew tap caskroom/cask
$ brew cask install java
2. 下载 Kafka
前往 Official Apache Kafka Downloads 页面 kafka.apache.org/downloads
并下载最新的 Kafka 二进制版本(在撰写时为 kafka_2.13–3.4.0.tgz):
在 Mac 上,下载的二进制文件将保存在默认的 Downloads 文件夹中。
在终端中,输入以下命令以解压 Kafka 安装包:
$ cd ~/Downloads
$ tar -xzf kafka_2.13-3.4.0.tgz
在 Downloads 文件夹中将创建一个名为 kafka_2.13–3.4.0 的新文件夹。将整个文件夹移动到你的主目录 (~
)。
或者,你可以使用 Homebrew 通过以下命令安装 Kafka:
$ brew install kafka
3. 运行 ZooKeeper 和 Kafka
为了确保 Kafka 正常运行,你需要运行一个称为ZooKeeper的工具。
ZooKeeper 负责 Kafka 的集群管理。它随 Kafka 安装包一起提供。
要启动 ZooKeeper,在终端中输入以下命令:
$ cd ~/kafka_2.13-3.4.0
$ bin/zookeeper-server-start.sh config/zookeeper.properties
kafka_2.13–3.4.0 文件夹在你的主目录中。
接下来,启动 Kafka broker service。Kafka 代理是生产者和消费者之间的中介:
在一个新的终端窗口中,输入以下命令以启动 Kafka 代理服务:
$ cd ~/kafka_2.13-3.4.0
$ bin/kafka-server-start.sh config/server.properties
如果没有错误,你的 Kafka 安装现在已经准备好使用了!
创建一个主题
Kafka 代理启动并运行后,你现在可以看到 Kafka 的实际操作了!首先,创建一个新的 主题。
主题是用于管理和组织消息的类别。生产者将消息发送到特定主题,消费者订阅特定主题以接收发送到该主题的消息。
启动一个新的终端窗口,并输入以下命令:
$ cd ~/kafka_2.13-3.4.0
$ bin/kafka-topics.sh --create --topic chat --bootstrap-server localhost:9092
上述命令在监听端口 9092 的 Kafka 代理服务上创建了一个“chat”主题。
查看主题
要查看刚刚创建的主题,使用以下命令:
$ cd ~/kafka_2.13-3.4.0
$ bin/kafka-topics.sh --describe --topic chat --bootstrap-server localhost:9092
你应该会看到以下输出:
Topic: chat TopicId: TuTCBnOISzqhOpqh8FP2FA PartitionCount: 1 ReplicationFactor: 1 Configs:
Topic: chat Partition: 0 Leader: 0 Replicas: 0 Isr: 0
将事件写入主题
要将一个 事件(通常称为消息)写入主题,首先使用以下命令启动 Kafka producer console:
$ bin/kafka-console-producer.sh --topic chat --bootstrap-server localhost:9092
你现在应该能看到 >
提示符:
继续输入一些消息:
消息将被发送到 Kafka 代理。
阅读事件
要从“chat”主题中读取事件(消息),启动一个新的终端窗口并输入以下命令以启动 Kafka consumer console:
$ cd ~/kafka_2.13-3.4.0
$ bin/kafka-console-consumer.sh --topic chat --from-beginning --bootstrap-server localhost:9092
你现在应该会看到你之前发送的消息:
你可以启动另一个新的终端窗口,输入相同的命令启动另一个 Kafka consumer console,你也会看到相同的消息。
转到你运行 Kafka 生产者控制台 的终端窗口,输入一条新消息。两个运行 Kafka 消费者控制台 的其他终端窗口现在将接收到新消息:
如果你喜欢阅读我的文章并且这些文章对你的职业/学习有所帮助,请考虑注册成为 Medium 会员。每月 $5,会员可以无限制访问 Medium 上的所有文章(包括我的文章)。如果你使用以下链接注册,我将获得少量佣金(对你没有额外费用)。你的支持意味着我将能花更多时间撰写类似的文章。
[## 使用我的推荐链接加入 Medium - Wei-Meng Lee
阅读 Wei-Meng Lee 的每一篇故事(以及 Medium 上成千上万其他作者的作品)。你的会员费直接支持……
weimenglee.medium.com](https://weimenglee.medium.com/membership?source=post_page-----9199699623fa--------------------------------)
总结
在这篇文章中,我简要介绍了 Apache Kafka 的使用及其主要组件。为了避免让你感到信息过载,我展示了如何通过安装 Kafka 并启动 Kafka broker 服务来开始使用 Kafka。为了了解 Kafka 的工作原理,你创建了一个主题,并使用 Kafka 生产者和消费者发送和接收消息。希望这个快速入门能让你对 Kafka 能做什么有更清晰的认识,未来的文章中我将深入探讨 Kafka 的具体应用案例!
使用贝叶斯网络预测医院中的辅助服务量
使用诊断输入变量的 Python 示例
·
关注 发表在 Towards Data Science ·8 min read·2023 年 8 月 11 日
--
来自 Unsplash 的照片,作者 EJ Strat
探索复杂性:预测医疗保健中的患者数量
在医疗保健中,准确预测即将到来的患者量不仅对运营成功至关重要,而且是一个极其棘手的问题。有太多依赖因素需要考虑——患者的严重程度和特殊要求、行政需求、检查室的限制、员工请病假、恶劣的雪暴等。更糟糕的是,意外的情况可能对调度和资源分配产生连锁影响,即使是最好的 Excel 预测也会被推翻。
从数据的角度来看,这些挑战真的很有趣,因为它们复杂,你可以深入思考一段时间。但同时,即使是微小的改进也可以带来重大的收获(例如,提高患者流量,降低等待时间,增加提供者的满意度,降低成本)。
另一种预测的方法:贝叶斯模型
那么还有什么替代方案呢?好吧,Epic 为我们提供了大量数据,包括患者到达预约时间的实际记录。通过已知的历史“到达”和“未到达”数据,我们可以在监督学习的范围内进行操作,而贝叶斯网络(BNs)提供了良好的概率图模型来预测未来的访问概率。
虽然生活中的大多数决策可以通过单一输入来确定(例如,考虑“我是否应该带雨衣?”,假设外面在下雨,那么决定应该是“是”),但贝叶斯网络可以轻松处理更复杂的决策——涉及多个输入的决策(例如,湿度高,步行仅需 3 分钟,你的雨衣在另一层楼,你的朋友可能带了伞等),这些输入具有不同的概率和依赖关系。在本文中,我将用 Python 在“草稿纸”上创建一个超级简单的贝叶斯网络,根据已知的三种因素的概率:症状、癌症阶段和治疗目标,输出患者在两个月后到达的概率分数。
理解贝叶斯网络:
贝叶斯网络的核心是使用有向无环图(DAG)表示联合概率分布的图形表示。DAG 中的节点表示随机变量,有向边表示这些变量之间的因果关系或条件依赖关系。正如所有数据科学项目所示,在开始时与利益相关者花大量时间正确绘制决策过程中的工作流程(例如,变量)对于高质量预测至关重要。
所以,我会编造一个场景,我们与乳腺肿瘤学合作伙伴会面,他们解释说,有三个变量对于确定患者是否需要在两个月后预约至关重要:患者症状、癌症阶段和当前治疗目标。我在输入时编造这个,但我们就这么做吧。
(实际上,将有数十个因素影响未来的患者量,其中一些具有单一或多重依赖关系,其他的则完全独立但仍会影响结果)。
现在,假设我们同意工作流程如上所示:Stage 依赖于其症状,但治疗类型独立于这些并且还会影响 2 个月后的预约发生。
基于此,我们将从数据源(对我们而言,是 Epic)中获取这些变量的数据,这些数据将包含我们分数节点(Appointment_2months)的 已知 值,标记为“是”或“否”。这种数据处理是一个重要部分;你需要根据这些变量在 2 个月前所指示的内容,准确捕捉真实患者在 2 个月后的到达情况。
# install the packages
import pandas as pd # for data manipulation
import networkx as nx # for drawing graphs
import matplotlib.pyplot as plt # for drawing graphs
!pip install pybbn
# for creating Bayesian Belief Networks (BBN)
from pybbn.graph.dag import Bbn
from pybbn.graph.edge import Edge, EdgeType
from pybbn.graph.jointree import EvidenceBuilder
from pybbn.graph.node import BbnNode
from pybbn.graph.variable import Variable
from pybbn.pptc.inferencecontroller import InferenceController
# Create nodes by manually typing in probabilities
Symptom = BbnNode(Variable(0, 'Symptom', ['Non-Malignant', 'Malignant']), [0.30658, 0.69342])
Stage = BbnNode(Variable(1, 'Stage', ['Stage_III_IV', 'Stage_I_II']), [0.92827, 0.07173,
0.55760, 0.44240])
TreatmentTypeCat = BbnNode(Variable(2, 'TreatmentTypeCat', ['Adjuvant/Neoadjuvant', 'Treatment', 'Therapy']), [0.58660, 0.24040, 0.17300])
Appointment_2weeks = BbnNode(Variable(3, 'Appointment_2weeks', ['No', 'Yes']), [0.92314, 0.07686,
0.89072, 0.10928,
0.76008, 0.23992,
0.64250, 0.35750,
0.49168, 0.50832,
0.32182, 0.67818])
上述内容,让我们手动输入每个变量(节点)中级别的概率分数。这些概率不是猜测,甚至不是最佳猜测。实际上,你将再次根据现有数据计算频率。
以症状变量为例。我会得到它们的 2 个级别的频率,大约 31% 是良性的,69% 是恶性的。
作者提供的照片
然后,我们考虑下一个变量 Stage,并与 Symptom 交叉表以获得这些频率。我们这样做是因为 Stage 对 Symptom 依赖,并且由于它们每个都有两个场景,它们实际上有 4 种概率结果。
作者提供的照片
如此类推,直到定义所有父子对之间的交叉表。
现在,大多数 BN 包含许多父子关系,因此计算概率可能会变得繁琐(且主要容易出错),所以下面的函数可以计算任何子节点对应于 0、1 或 2 个父节点的概率矩阵。虽然见解不能也不应该被自动化,但数据准备部分可以且应该被自动化。
# This function helps to calculate probability distribution, which goes into BBN (note, can handle up to 2 parents)
def probs(data, child, parent1=None, parent2=None):
if parent1==None:
# Calculate probabilities
prob=pd.crosstab(data[child], 'Empty', margins=False, normalize='columns').sort_index().to_numpy().reshape(-1).tolist()
elif parent1!=None:
# Check if child node has 1 parent or 2 parents
if parent2==None:
# Caclucate probabilities
prob=pd.crosstab(data[parent1],data[child], margins=False, normalize='index').sort_index().to_numpy().reshape(-1).tolist()
else:
# Caclucate probabilities
prob=pd.crosstab([data[parent1],data[parent2]],data[child], margins=False, normalize='index').sort_index().to_numpy().reshape(-1).tolist()
else: print("Error in Probability Frequency Calculations")
return prob
然后我们创建实际的 BN 节点和网络本身:
# Create nodes by using our earlier function to automatically calculate probabilities
Symptom = BbnNode(Variable(0, 'Symptom', ['Non-Malignant', 'Malignant']), probs(df, child='SymptomCat'))
Stage = BbnNode(Variable(1, 'Stage', ['Stage_I_II', 'Stage_III_IV']), probs(df, child='StagingCat', parent1='SymptomCat'))
TreatmentTypeCat = BbnNode(Variable(2, 'TreatmentTypeCat', ['Adjuvant/Neoadjuvant', 'Treatment', 'Therapy']), probs(df, child='TreatmentTypeCat'))
Appointment_2months = BbnNode(Variable(3, 'Appointment_2months', ['No', 'Yes']), probs(df, child='Appointment_2months', parent1='StagingCat', parent2='TreatmentTypeCat'))
# Create Network
bbn = Bbn() \
.add_node(Symptom) \
.add_node(Stage) \
.add_node(TreatmentTypeCat) \
.add_node(Appointment_2months) \
.add_edge(Edge(Symptom, Stage, EdgeType.DIRECTED)) \
.add_edge(Edge(Stage, Appointment_2months, EdgeType.DIRECTED)) \
.add_edge(Edge(TreatmentTypeCat, Appointment_2months, EdgeType.DIRECTED))
# Convert the BBN to a join tree
join_tree = InferenceController.apply(bbn)
我们一切就绪。现在让我们通过我们的 BN 运行一些假设并评估输出。
评估 BN 输出
首先,让我们看看每个节点的概率情况,而不具体声明任何条件。
# Define a function for printing marginal probabilities
# Probabilities for each node
def print_probs():
for node in join_tree.get_bbn_nodes():
potential = join_tree.get_bbn_potential(node)
print("Node:", node)
print("Values:")
print(potential)
print('----------------')
# Use the above function to print marginal probabilities
print_probs()
Node: 1|Stage|Stage_I_II,Stage_III_IV
Values:
1=Stage_I_II|0.67124
1=Stage_III_IV|0.32876
----------------
Node: 0|Symptom|Non-Malignant,Malignant
Values:
0=Non-Malignant|0.69342
0=Malignant|0.30658
----------------
Node: 2|TreatmentTypeCat|Adjuvant/Neoadjuvant,Treatment,Therapy
Values:
2=Adjuvant/Neoadjuvant|0.58660
2=Treatment|0.17300
2=Therapy|0.24040
----------------
Node: 3|Appointment_2weeks|No,Yes
Values:
3=No|0.77655
3=Yes|0.22345
----------------
上述结果显示,该数据集中的所有患者有 67% 的概率处于 Stage_I_II,69% 的概率是非恶性的,58% 的概率需要辅助/新辅助治疗,而只有 22% 的患者需要 2 个月后的预约。
我们可以轻松地从简单的频率表中获得这些数据,而不需要 BN。
但现在,让我们提出一个更有条件的问题: 给定患者处于 Stage = Stage_I_II 并且 TreatmentTypeCat = Therapy 的情况下,他们需要在 2 个月内护理的概率是多少?此外,考虑到提供者对他们的症状一无所知(也许他们还没有见过患者)。
我们将运行在节点中我们知道是真实的内容:
# To add evidence of events that happened so probability distribution can be recalculated
def evidence(ev, nod, cat, val):
ev = EvidenceBuilder() \
.with_node(join_tree.get_bbn_node_by_name(nod)) \
.with_evidence(cat, val) \
.build()
join_tree.set_observation(ev)
# Add more evidence
evidence('ev1', 'Stage', 'Stage_I_II', 1.0)
evidence('ev2', 'TreatmentTypeCat', 'Therapy', 1.0)
# Print marginal probabilities
print_probs()
返回结果:
Node: 1|Stage|Stage_I_II,Stage_III_IV
Values:
1=Stage_I_II|1.00000
1=Stage_III_IV|0.00000
----------------
Node: 0|Symptom|Non-Malignant,Malignant
Values:
0=Non-Malignant|0.57602
0=Malignant|0.42398
----------------
Node: 2|TreatmentTypeCat|Adjuvant/Neoadjuvant,Treatment,Therapy
Values:
2=Adjuvant/Neoadjuvant|0.00000
2=Treatment|0.00000
2=Therapy|1.00000
----------------
Node: 3|Appointment_2months|No,Yes
Values:
3=No|0.89072
3=Yes|0.10928
----------------
那位病人在 2 个月内到达的机会只有 11%。我们可以向预测该病人 2 个月内到达的概率询问已知或未知特征的任何组合。进一步的算法和函数可以用来获取多个病人的概率,或者病人群体的概率,或优化这些概率。
关于高质量输入变量重要性的说明:
编写 Python 代码是一回事,但贝叶斯网络在提供可靠的未来就诊估计方面的真正成功在很大程度上依赖于准确的患者护理工作流程映射。这需要时间、讨论和白板——而不是编码。可能还需要多次深入数据,与客户一起测试假设:“我们之前说过,护士导航员在报告症状差时总是会联系病人,但实际上只发生了 10% 的情况。下一次病人与他们的医生联系。”(再说一次,这只是举个例子,但这将是一个重要的工作流程映射)。
病人呈现出相似的情况时,通常会需要类似的服务,并且到来的频率也相似。这些输入的排列,其特征可以从临床到行政,最终对应于服务需求的某种确定性路径。但时间预测越复杂或越远,就越需要更多特定的、复杂的贝叶斯网络和高质量的输入。
这里是原因:
-
准确表示: 贝叶斯网络的结构必须反映变量之间的实际关系。选择不当的变量或误解依赖关系可能导致不准确的预测和见解。
-
有效推断: 高质量的输入变量提升了模型进行概率推断的能力。当变量基于条件依赖准确连接时,网络可以提供更可靠的见解。
-
减少复杂性: 包含无关或冗余的变量会不必要地使模型复杂化并增加计算要求。高质量的输入使网络更加高效。
感谢阅读。很高兴在LinkedIn上与任何人连接!如果你对数据科学与医疗保健的交集感兴趣,或者有有趣的挑战分享,请留言或私信。
查看我其他的一些文章:
设计基础神经网络的 7 个步骤
在 Python 中使用二项分布
在 Python 中探索二项分布:了解不同方法下的硬币投掷概率计算
·
关注 发布于 Towards Data Science ·9 分钟阅读·2023 年 5 月 3 日
--
照片由 Joshua Hoehne 提供,来源于 Unsplash
在今天的文章中,我将向你展示如何在 Python 代码中应用统计学概念,如概率。作为示例,我将使用高盛的统计面试问题,并展示在 Python 中计算二项分布的不同方法。
我还制作了一个关于同一主题(和相同问题)的免费视频教程,你可以随意使用它来简化你的工作。
在这篇文章中,我们将通过使用不同的 Python 方法来检验和解决高盛面试中提出的概率问题。
为什么概率和统计在数据科学中很重要?
图片由作者提供
概率和统计是数据科学中广泛使用的两个数学分支。
概率是研究随机事件及其发生可能性的学科。它用于建模许多现实世界中的不确定性,并且是统计推断的基础。
数据收集、分析和评估的研究称为统计学。它包括使用数学方法总结和得出数据结论。
概率和统计在数据科学中用于分析数据、进行预测和提供决策支持。它们对于理解复杂系统的行为和识别数据中的模式与趋势至关重要。
概率面试问题:硬币抛掷预测
21 枚公平硬币
有 21 枚公平硬币。每一枚都被抛掷。得到偶数个正面的概率是多少?
高盛是全球领先的投资银行及金融服务提供商,成立于 1869 年。该公司向企业、政府和个人提供各种金融服务,是世界上最大的投资银行之一,总部位于纽约市。
这个面试问题要求你确定在抛掷 21 枚公平硬币时,得到偶数个正面的概率。
一个公平的硬币有 50/50 的机会落在正面或反面。
当抛掷多个硬币时,可能的结果会迅速变得非常庞大。
要解决这个问题,我们需要使用二项分布的概念。让我们看看它是什么。
Python 中的二项分布
图片由作者提供
二项分布是一种概率分布,可以用来描述一系列事件中成功或失败的次数,这些事件必须相互独立。
当只有两种可能的结果时,比如正面或反面,并且每次试验的成功概率相同,便会使用它。
试验必须满足两个条件:
-
它们必须只有两个可能的结果(正面或反面/成功或失败),
-
每次试验的成功概率必须相同。
当抛硬币时,成功可以定义为得到正面,失败可以定义为得到反面。
要找出抛掷 21 枚硬币时得到偶数个正面的概率,我们需要计算得到 0、2、4、6、8、10、12、14、16、18 或 20 个正面的概率。
我们将通过使用二项分布来解决这个问题:
这意味着:
-
P(X = k) — 在n次独立试验中获得k次成功结果的概率。
-
n — 试验次数。(在本例中为 21。)
-
k — 成功的次数。(在本例中为正面。)
-
p — 试验成功的概率。(在本例中为 0.5)
-
$\binom{n}{k} $ — 从 n 次独立试验中选择 k 次成功的方法数量。
手动计算:硬币翻转预测
我们可以插入值并计算每个可能的偶数个头的概率。
让我们首先手动进行以下值的计算:
n = 21(试验次数)
p = 0.5(成功的概率)
q = 1 — p = 0.5(失败的概率)
这是计算结果
这些计算的含义是什么?例如,让我们看看 P(18)。
公式意味着在 21 次抛硬币时得到 18 个正面的概率,即 0.06341934204101562%。这略高于 10000 次中的 6 次。
在面试问题的背景下,当您抛掷 21 枚硬币 10000 次时,您可能会六次得到 18 个正面。这是一个罕见的概率,正如您所见。
在这里,您可以看到基于我们的计算,抛掷 21 枚硬币时偶数个正面的概率分布。
作者提供的图片
要回答面试问题,我们需要将所有概率相加:
将我们计算的所有单个概率插入,您会得到:
现在,让我们使用 Python 中的math 库实现此计算
Python 计算:硬币翻转预测
使用 math 库
我们可以使用 Python 中的 math 库实现在 21 次抛硬币中获得偶数个头的二项分布。
为此,我们可以使用for loop来迭代 21 次抛硬币中可获得的每个偶数个头。
对于范围内的每个 i 值,从 0 到 22(不包括),步长为 2,我们可以使用二项分布公式计算得到 i 个头的对应概率。
在循环内,我们可以首先使用math.comb()函数计算出 21 次翻转中获得 i 个头的方式数量。
然后,我们可以计算在 21 次抛硬币中获得恰好 i 个头的概率,即(1/2)21**。最后,我们可以将这两个值相乘,以获取在 21 次抛硬币中获得恰好 i 个头的概率,并将其存储在名为 probabilities 的列表中。
我们可以使用另一个for loop来打印概率列表中每个偶数个头的概率。
循环结束后,我们可以使用sum()函数计算存储在列表 probabilities 中的所有概率的总和,并打印在 21 次翻转中获得偶数个头的总概率。
这里是代码。
import math
# Create an empty list to store the probabilities of getting an even number of heads
probabilities = []
# Loop over all even numbers of heads that can be obtained in 21 flips
for i in range(0, 22, 2):
# Calculate the number of ways to choose i heads out of 21 flips
m = math.comb(21, i)
# Calculate the probability of getting i heads in 21 flips
n = (1/2 ** 21)
prob = m * n
# Append the probability to the list
probabilities.append(prob)
# Loop over the probabilities and print the probability of getting each even number of heads
for i, prob in enumerate(probabilities):
print("The probability of getting {} heads in 21 flips is {}.".format(2*i, prob))
# Calculate the total probability
total_prob = sum(probabilities)
print("Total probability of getting even number of heads is {}".format(total_prob))
现在,让我们看一下输出。
在早期计算中,我们使用二项分布公式手动计算了 21 次抛硬币得到偶数个正面的概率。
这涉及到单独计算 0、2、4、6、8、10、12、14、16、18 和 20 个正面的概率,这可能非常耗时。
另一方面,我们刚刚使用数学库编写的代码允许我们仅用几行代码计算所有这些概率。
这种方法节省了我们大量的时间和精力,特别是当我们需要计算更多结果的概率或需要对更多试验进行重复计算时。
但我们还有一种实现方式,它甚至更快。
使用 SciPy 库
这一次,我们将使用预构建的函数来计算概率。
在以下代码中,我们首先从scipy.stats 模块中导入binom()函数。
然后,我们使用binom()函数定义具有n 次试验和成功概率p 的二项分布。
我们使用pmf()方法计算得到偶数个正面(0、2、4、…、20)的概率。
pmf()方法返回给定范围内每个值的分布概率质量函数(PMF)。
然后,我们使用另一个for loop来打印列表 even_probs 中每个偶数个正面的概率。
最后,我们通过对列表even_probs中的概率进行求和来计算 21 次抛硬币得到偶数个正面的总概率。
使用 scipy 库计算二项分布提供了一种更高效、更方便的计算方法。它允许我们用几行代码定义和计算分布。
以下是代码。
from scipy.stats import binom
n = 21
p = 0.5
# Define the binomial distribution with n trials and probability p of success
binom_dist = binom(n, p)
# Calculate the probability of getting an even number of heads (0, 2, 4, ..., 20)
even_probs = binom_dist.pmf(range(0, 22, 2))
# Print the probability of getting each even number of heads
for i, prob in enumerate(even_probs):
print("The probability of getting {} heads in 21 flips is {}".format(2*i,prob))
# Calculate the total
total_prob = sum(even_probs)
print("Total probability of getting even number of heads is {}".format(total_prob))
以下是输出结果。
如果你比较两种 Python 方法得到的结果,你可能会注意到它们并不完全相同。
为什么结果(略微)不同?
SciPy 和数学计算之间的差异是由于浮点精度误差造成的。
在计算机编程中,浮点数以有限的位数表示,这可能会导致计算时精度丢失。
这就是为什么当比较两种方法的结果时,我们得到的偶数个正面的概率略有不同。
两种方法都提供了对真实概率的良好近似,而这两者之间的差异对于大多数实际应用来说是微不足道的。
如果你处理较长的脚本并计划进行更长时间的计算,同时计算能力有限,那么使用 SciPy 会更高效、更快速。
然而,如果你有足够的资源,我们建议使用第一种计算方法,即我们使用数学方法的那种。这是如果你想确保结果的准确性,即使这些细微的差异影响不大。
如果你想最小化浮点错误的影响,可以将计算结果四舍五入到所需的精度。
这里是代码。
total_prob = round(total_prob, 14)
print("Total probability of getting even number of heads, after correcting precision error is {}".format(total_prob))
这里是结果。
结论
在这篇文章中,我们探讨了一个来自高盛的统计问题,并通过使用 Python 进行了回答。
我们探讨了不同的方法,并详细解释了每一步。同时,我们也使用了预构建的函数进行了实现。
总的来说,本文是二项分布如何通过手动计算或使用 Python 解决的一个优秀示例。
如果你想掌握随机变量和概率分布,并在下一个数据科学面试中脱颖而出,请查看这篇文章“随机变量和概率分布”。
最初发布于 https://www.stratascratch.com。
将 ChatGPT 作为创意写作伙伴——第一部分:散文
原文:
towardsdatascience.com/using-chatgpt-as-a-creative-writing-partner-part-1-prose-dc9a9994d41f
OpenAI 最新的语言模型如何帮助创作诗歌、小说和剧本
·发表于 Towards Data Science ·15 分钟阅读·2023 年 1 月 4 日
--
“一个女人在笔记本电脑前打字,旁边有一个有帮助的玩具机器人,” 图片使用 AI 图像生成程序 Midjourney 创建,由作者编辑
如果你一直在阅读我在 Medium 上的文章,你会知道我自 2020 年 8 月以来一直在写关于 AI 的创意用途。我经常写关于生成数字艺术的内容,但偶尔也会写关于将 AI 用于其他创意目的的文章,比如写散文和作曲。当我听说 OpenAI 发布了一个叫做 ChatGPT 的新语言模型时,我立刻投入其中,并测试了它在创意写作中的表现。
概述
这是关于将 ChatGPT 作为写作伙伴的三部分系列的第一部分。该系列将涵盖写作散文、作曲 和创作图画书。
在这篇文章中,我将回顾一些关于 ChatGPT 的背景信息,然后展示我在生成各种创意项目文本方面的实验结果:
-
书写俳句
-
为新小说创作情节总结
-
创作新的电视节目和电影剧本
我将总结一下使用该模型的一般观察结果,并提出一些未来探索的可能步骤。
ChatGPT
ChatGPT 是 OpenAI 最新的文本生成模型。他们将系统作为“免费研究预览”提供。你可以使用电子邮件账户登录,或使用现有的 Google 或 Microsoft 账户登录。当你登录时,你会看到带有示例、功能和限制的主页。
ChatGPT 用户界面,图片由作者提供
为了开始一个话题,我问了这样一个问题。
ChatGPT 回答问题,图片由作者提供
注意,我在 2020 年的AI 8 Ball项目中使用了相同的问题。你可以看到上面 ChatGPT 对 Mattel 的 Magic 8-Ball 很熟悉,并用 20 个预设回答之一开始回答。但 ChatGPT 进一步提供了一些背景信息,这很好。当我点击重新尝试按钮时,它生成了类似的文本。但我使用了这个酷炫的新功能,并向系统提出了一个后续问题。
ChatGPT 回答后续问题,图片由作者提供
注意 ChatGPT 如何记住了第一个问题的上下文,所以它知道“它”的含义并做出适当回答。它的回答就像 Magic 8-Ball 一样,并提供了一些有用的背景信息。
OpenAI 增加了一个很棒的功能,即保留之前聊天会话的历史记录。这使我能够在需要时从上次中断的地方继续聊天。
ChatGPT 的起源
ChatGPT 基于 OpenAI 的 GPT-3 大语言模型 [2]。这个模型非常庞大,拥有 1750 亿个人工神经元,并在大约 5000 亿个网络文本上进行了训练。
OpenAI 进一步训练 GPT-3,创建了一个叫做 InstructGPT 的模型,旨在在生成文本时遵循用户的意图 [3]。
从一组[人工]-编写的提示和通过 OpenAI API 提交的提示开始,我们收集了一个标签者展示期望模型行为的数据集,用于通过监督学习对 GPT-3 进行微调。然后我们收集了一个模型输出排名的数据集,用于通过人类反馈的强化学习进一步微调这个监督模型。我们称这些结果模型为 InstructGPT。 - Long Ouyang 等,来自 OpenAI
ChatGPT 是 OpenAI 的 InstructGPT 兄弟模型,后者经过训练与用户进行互动 [1]。
我们训练了一个叫做 ChatGPT 的模型,它以对话的方式进行互动。这种对话格式使得 ChatGPT 能够回答后续问题,承认错误,挑战不正确的前提,并拒绝不当请求。 - OpenAI
语言模型中的社会偏见
注意,与大多数大型语言模型一样,InstructGPT 和 ChatGPT 都受到原始训练数据中隐性社会偏见和毒性的影响。为了解决这个问题,OpenAI 积极致力于使模型“对齐”,以便有用、诚实和无害 [3]。
我们通过训练语言模型使其按照用户的意图行事来取得进展。这包括明确的意图,如遵循指示,以及隐含的意图,如保持真实,避免偏见、毒性或其他有害行为。 — Long Ouyang 等,来自 OpenAI
然而,作者表示,新模型在毒性方面相比 GPT-3 仅有小幅改进,但在偏见方面没有改进。
OpenAI 如何微调 ChatGPT
这是 OpenAI 用于训练其最新系统的步骤。PPO 代表近端策略优化[4],是一种用于改善输出的新型强化学习。
图示显示了监督微调、奖励模型训练和通过近端策略优化的强化学习,来源 OpenAI
“标注员”是实际的人类,他们在第 1 步中创建了理想的输出,并在第 2 步中对生成的文本样本进行了评分,这些样本用于训练奖励模型(RM)。在第 3 步中,使用 RM 来预测标注员更喜欢哪个输出,并利用结果通过 PPO 算法优化了 GPT-3 模型。
书写俳句
“金黄的叶子飘落,秋风吹过小镇,大自然的挂毯,” Midjourney 图片由作者提供
我与 ChatGPT 的第一次实验是创作俳句,这是一种起源于日本的诗歌形式。正如我在之前的帖子中讨论的那样,语言模型似乎在生成具有固定韵律的散文方面存在困难,例如俳句通常使用的 5–7–5 音节结构。然而,我注意到 ChatGPT 有了明显的改进。你可以在下面的互动中看到这一点,回答部分用斜体表示。
写一首关于秋天的俳句。
叶子柔软地飘落
秋风在小镇上拂过
大自然的挂毯。
这很好,但第一行用五个音节。第二和第三行可以。
金黄的叶子飘落
秋风吹过小镇
大自然的挂毯
这首俳句相当不错!叶子飘落和秋风吹过小镇的意象捕捉了季节变换的感觉。但请注意,第一次尝试有 6–7–5 的音节结构。我随后提示它改进了第一行,结果第二次尝试创建了一个 5–7–5 的结构,并保留了思路的精髓。
这里有几个额外的例子。
写一首关于笑声的俳句。
笑声如铃铛
回荡在大地上
带给所有人快乐。
写一首关于 Python 编程语言的俳句。
Python,如蛇般
优雅地穿梭于代码中
简单中的力量。
这很好,但最后一行用五个音节。第一和第二行可以。
Python,如蛇般
优雅地穿梭于代码中
简单中的力量。
关于笑声的俳句相当不错,其中关于铃铛的比喻也很好。它达到了俳句所需的 5–7–5 音节数。关于 Python 的俳句也不错,继续用蛇的比喻。最初它是 5–7–7 的结构,但当我提示时,ChatGPT 生成了一个正确音节数的俳句,并保持了原意。
接下来,我将分享使用系统为小说创建新标题和情节的结果。
“桌上的一堆书,上面有一支笔和一叠纸,” Midjourney 图像由作者提供
为小说创作标题和情节
类似于我的 PlotJam 项目,我使用了一个经过微调的 GPT-2 模型,我要求 ChatGPT 为特定类型的小说生成新的标题和情节概要。以下是结果。
为一本新的历史小说创作标题和情节概要。
标题:“小大角的最后一战”
剧情概要:1876 年夏季,由传奇首领坐牛和疯狂马带领的拉科塔苏族和夏延族战士,在小大角战役中对抗正在推进的美国军队,做出了最后的抵抗。在这场生死决战中,年轻的战士白水牛必须在对部落的忠诚和对一位名叫萨拉的白人女子的爱之间做出选择。萨拉作为伤兵的护士来到战场。面对苏族国家的命运悬于一线,白水牛必须在战争、荣誉和爱情的复杂危险世界中找到自己的道路。
背景设定还不错,但情节似乎有点老套。尝试用新的白水牛角色冲突再试一次。
标题:“小大角的最后一战”
剧情概要:[…] 当双方在一场残酷血腥的战斗中对峙时,年轻的战士白水牛必须面对自己内心的挣扎和怀疑。白水牛在对部落的忠诚与对和平的渴望之间撕扯,他必须找到调和信仰与行动的方法,以捍卫他的人民及其生活方式。在战斗的混乱与暴力中,白水牛必须做出艰难的选择,这些选择将塑造他部落的未来和他自己的命运。
好的,不错。我喜欢这个故事似乎是从一位美洲原住民的视角讲述的。但正如你所见,我不得不将系统从陈词滥调的“罗密欧与朱丽叶”情节中引导出来以进行第二次迭代。我删除了第二次情节概要中的重复部分以提高可读性。正如“行业内”的说法,ChatGPT 适合做笔记。
这是另一个例子,经过三轮互动。
为一本新的科幻小说创作标题和情节概要。
标题:“最后的哨兵”
剧情概要:在遥远的未来,人类已经扩展到整个银河系,并在其他行星上建立了许多殖民地。然而,一种名为 Xarga 的敌对外星种族在地平线上威胁着人类的未来。作为最后一名哨兵,名叫杰克的孤独战士必须踏上任务,发现 Xarga 的来源,并找到阻止他们在消灭所有人类之前的方法。在此过程中,杰克必须克服巨大的困难,做出艰难的选择,以拯救人类。
好的,但给杰克起一个更具未来感的名字。
标题:“最后的守卫”
情节概要:[…] 作为最后一位守卫,一名名叫杰克逊的孤独士兵必须踏上使命,寻找 Xarga 的根源,并找到阻止他们的方法,以防他们消灭人类。在此过程中,杰克逊必须克服巨大的困难,做出艰难的选择,以拯救人类。
好的,简要解释一下杰克逊如何阻止 Xarga。
作为最后一位守卫,杰克逊是人类在与 Xarga 的战斗中的唯一希望。在穿越银河系并经历了许多挑战之后,杰克逊终于找到了 Xarga 威胁的根源。在那里,他发现 Xarga 被一个中央智能控制。杰克逊利用他的先进技术和战斗技能击败了这个中央智能,从而解放了 Xarga,结束了对人类的威胁。
这也是一个好的开始。然而,概要暗示了大量的动作和悬念,但缺少一些重要细节,比如书的真实内容、主要角色是谁以及驱动主角的动机。顺便提一下,我觉得有趣的是 ChatGPT 认为通过添加字母“X”可以使名字显得更具未来感。我也喜欢“单一中央智能控制 Xarga”的设定。ChatGPT 似乎遵循了那个老掉牙的建议:“写你所知道的。” 🙂
“电影拍摄现场,有一张桌子上的剧本和咖啡,” Midjourney 图片由作者提供,基于Jakob Owens在Unsplash上的照片
编写剧本
我发现 ChatGPT 可以成为编写剧本的好伙伴。类似于我在 GreenLit 项目中的做法,我尝试了新模型为一个特定类型的节目创建标题、情节概要和剧本第一页。以下是初步的互动。
创建五个关于银行抢劫的新流媒体系列的标题和情节概要。
“抢劫计划” —— 当一群技艺高超的窃贼合作进行完美的银行抢劫时,他们很快意识到真正的挑战是如何活着带着赃物逃脱。当警察逐渐逼近,赌注不断增加时,团队必须依靠他们的专业技能和彼此间的信任来智胜当局,并成功脱身。
“内外勾结” —— 当一群心怀不满的前员工策划抢劫他们曾经工作的银行时,他们求助于一名经验丰富的窃贼来执行这一计划。随着他们深入调查银行的运作,他们揭露了一张腐败和贪婪的网,这使他们陷入了法律和前雇主的双重危险中。
“银行抢劫”——当一群朋友决定抢劫当地银行以偿还债务并重新开始时,他们很快意识到,他们的业余劫案技能无法与现有的安全措施相匹敌。在他们努力智取银行高科技防御的同时,他们还必须应对一系列的双重背叛和个人背叛,这些背叛威胁到他们的计划。
“完美分数”——当一群聪明的学生决定进行完美的银行劫案以资助他们的教育时,他们很快发现,他们精心策划的计划远比他们预期的要困难得多。在他们应对劫案和逃避当局的挑战时,他们还必须面对自己内心的恶魔和不安全感,以便成功。
“大笔划”——当一群经验丰富的窃贼聚在一起策划一场终极劫案时,他们很快意识到,成功完成这项工作需要的不仅仅是他们的专业技能。在他们努力筹集资源和人员以完成劫案的同时,他们还必须应对一系列个人和职业上的挫折,这些挫折威胁到他们的计划。
这些都不错,但大多数比较普通。它们似乎都已经被用过很多次了。就好像 ChatGPT 想要保守一点,不敢在故事创意上冒险。然而,我喜欢关于学生/窃贼那一组的,因为它有一个有趣的角度。但它需要一个独特的标题。
“完美分数”看起来不错,但建议一些以前没有使用过的标题。侧重于学生的角度。请只给出标题。
这里有五个关于银行劫案的学生的流媒体系列的新潜在标题:
“学生大劫案”
“犯罪班级”
“贪婪的成绩”
“学术劫案”
“校园大盗”
这些都还不错,但有些乏味。而且不出所料,它们都以“the”这个冠词开头。注意,如果我没有仅仅要求标题,它们会生成新的情节摘要。
总的来说,我最喜欢“贪婪的成绩”,主要是因为它的押头韵。然后在决定去掉标题中的“the”后,我要求 ChatGPT 写下剧本的第一页。
这些还不错。我们就用没有“the”的“贪婪的成绩”吧。请写下剧本的第一页。
当然!这是“贪婪的成绩”的剧本第一页:
渐入:
内景。大学图书馆——白天
我们看到一个学生的特写,正专心致志地学习。他们被一堆书籍和笔记包围着。
切换到:
内景。大学餐厅——白天
我们看到四个学生,杰西卡、迈克、艾米和尼克,坐在一起的桌子旁。他们都看起来很紧张和疲惫。
杰西卡:(沮丧)我简直不敢相信我们下周要期中考试。如果我在微积分中得不到至少 B+,我将永远负担不起学费。
迈克:(讽刺地)哦,是的,因为这就是我们与经济稳定之间唯一的障碍。
艾米:(严肃地)实际上,迈克,重修课程要花费更多。如果我们不通过,下学期还得再付一次钱。
尼克:(坚决地)嗯,我们不能让这种事发生。我们需要找到解决方案。
杰西卡:(讽刺地)是的,就像抢银行一样。
这群人沉默了一会儿,然后爆发出笑声。
画面渐黑。
好的,不错。对话听起来自然,虽然有点急于推动情节。我喜欢系统为角色台词添加的情感提示,这给剧本带来了些许色彩。
然而,角色有点平淡,因为根据他们的名字来看,故事中似乎只有一群白人孩子。这可能是由于语言模型中的社会偏见,就像我上面提到的。但当我询问时,系统知道该怎么做。
这很好!但试试类似的剧本,角色组更加多样化。
当然!这是《贪婪的等级》剧本第一页的修订版,展现了更多种族多样性的角色:
画面渐显:
内景。大学图书馆 — 白天
我们看到一个学生的特写,专注地学习。他们被一堆书籍和笔记包围。
切换到:
内景。大学餐厅 — 白天
我们看到四个学生,塔拉、哈维尔、艾莎和肯,坐在一起的桌子旁。他们都看起来很有压力和疲惫。
塔拉:(沮丧地)我简直不敢相信我们下周就有期中考试。如果我在微积分课上得不到至少 B+,我永远买不起学费。
哈维尔:(讽刺地)哦,是的,因为这就是我们与经济稳定之间唯一的障碍。
艾莎:(严肃地)实际上,哈维尔,重修课程要花费更多。如果我们不通过,下学期还得再付一次钱。
肯:(坚决地)嗯,我们不能让这种事发生。我们需要找到解决方案。
塔拉:(讽刺地)是的,就像抢银行一样。
这群人沉默了一会儿,然后爆发出笑声。
画面渐黑。
果然,系统将名字更改为具有更广泛的文化代表性。然后我询问它对他们的姓氏有什么想法。
谢谢。对于演员表,他们的全名是什么?
当然!以下是修订版剧本中角色的全名:
塔拉 — Tara Patel
哈维尔 — Javier Rodriguez
艾莎 — Aisha Khan
肯 — Ken Tanaka
不错!我将把这个发送给 Netflix 和 Amazon 工作室,看看能否引发竞标战来资助这个项目。😀
讨论与下一步
正如你在上面的互动中看到的,ChatGPT 可能无法独立创作出优质的长篇散文,但作为互动写作伙伴效果很好。系统相当有知识,是一个有帮助的合作者。当被提示时,它会创建许多你可能在寻找的例子,尽管它的很多想法平淡而普通。挖掘出有趣的想法需要一些工作,但这是可以做到的。
过去,当我使用像 GPT-3、GPT-J、BLOOM 等大型语言模型进行创意写作时,我发现我必须生成许多文本版本,然后“挑选”出最好的版本。即便如此,散文的开头、中间和结尾也不明确。系统只是胡乱写作,我必须决定何时停止。然而,ChatGPT 似乎对提示反应良好,并能产生编排得很好的文本。我发现我很少使用“旋转”再生按钮。相反,我提供了一些反馈,它几乎总是能做到我想要的改进文本。
正如你在我的编剧示例中看到的,我使用了自上而下的写作过程,从标题和情节开始,然后写对话和舞台指示。可能有助于让 ChatGPT 编写和完善故事主要角色的角色草图,然后将故事情节分解成幕和场景。这将有助于系统在整个剧本过程中创建一个更具连贯性的叙事。
正如我在本文开头提到的,我接下来的帖子将讨论使用 ChatGPT 创作音乐,借助 Band in a Box 的帮助,以及编写儿童图画书,获得 Midjourney 的支持。
关于 ChatGPT 作为创意写作伙伴的更多内容
这是本系列的接下来两篇文章。
## 使用 ChatGPT 作为创意写作伙伴 — 第二部分:音乐
最新的 OpenAI 语言模型如何帮助你为新歌曲编排和弦,配乐由 Band-in-a-Box 提供
towardsdatascience.com ## 使用 ChatGPT 作为创意写作伙伴 — 第三部分:图画书
最新的 OpenAI 语言模型如何帮助你编写儿童书籍,并通过 Midjourney 创建插图
towardsdatascience.com
致谢
我想感谢詹妮弗·林和奥利弗·斯特里姆佩尔在这个项目中的帮助。
参考文献
[1] J. Schulman 等人,ChatGPT:优化对话的语言模型(2020)
[2] T. Brown 等人,语言模型是少样本学习者(2020)
[3] L. Ouyang 等人,训练语言模型以根据人类反馈执行指令(2020)
[4] J. Schulman 等人,近端策略优化算法(2017)
使用 ChatGPT 作为创意写作伙伴——第二部分:音乐
原文:
towardsdatascience.com/using-chatgpt-as-a-creative-writing-partner-part-2-music-d2fd7501c268
OpenAI 最新的语言模型如何帮助你为新歌曲创作和弦,音乐由 Band-in-a-Box 提供
·发布于 Towards Data Science ·阅读时间 16 分钟·2023 年 1 月 17 日
--
“两个钢琴家在演奏,一个人和一个友好的机器人,数字艺术,” 图像使用 AI 图像生成程序 Midjourney 创建,并由作者编辑。
在我的上一篇文章中,我讨论了如何将 OpenAI 的 ChatGPT 大型语言模型[1]作为各种类型散文的写作伙伴。在本文中,我将展示如何利用该系统通过生成和弦来帮助创作音乐。
在简要概述 ChatGPT 之后,我将展示我在以下风格中使用新系统创作音乐的实验结果:爵士、乡村摇滚和雷鬼。我将通过提供一些对使用该模型创作音乐的一般观察和未来探索的下一步,来结束本文。请注意,本系列的第三部分将介绍如何利用该系统借助 Midjourney 创建图画书。
ChatGPT
ChatGPT 是 OpenAI 最新的语言模型,旨在通过聊天用户界面与人们互动。GPT 代表生成预训练变换器,其中变换器是一种 AI 模型。你可以在本系列的第一篇文章中阅读有关该系统的完整背景。
OpenAI 在这里提供了免费的 ChatGPT,并且这里有一个专业版的等待名单。一旦你注册,你可以通过输入提示并按下回车键来向系统提问或要求它生成某些内容。例如,我让它为一首蓝调歌曲创建了和弦进行。
ChatGPT 生成的蓝调歌曲和弦,图片由作者提供
系统回复了一系列 12 个和弦,这是蓝调歌曲中的典型序列。接下来,我想看看 ChatGPT 是否能生成与这些和弦相匹配的旋律。我提示系统用ABC 音乐符号写一首曲子,ABC 符号用字母表示音符,用文本表示其他元数据。这是它创作的内容。
ChatGPT 生成的蓝调歌曲旋律,图片由作者提供
好的,这看起来有很多注释。但旋律怎么样呢?为了播放这首歌,我使用了一个叫做EasyABC的程序,这是一个开源 ABC 符号编辑器和播放器。我将音符粘贴到 EasyABC 中,并将歌曲保存为 MIDI 文件。这是旋律的声音。(公平警告,这并不太好。)
ChatGPT 生成的蓝调旋律,视频由作者提供
哎呀!尽管旋律明显不好,但和弦对于一首蓝调歌曲来说完全没问题。我还进行了几次实验,使用 ChatGPT 系统生成和弦和旋律。在几乎所有情况下,生成的和弦至少是不错的,有时甚至很好,但生成的旋律总是很糟糕。
我怀疑 ChatGPT 在和弦方面表现出色,是因为互联网上有大量展示和弦和和弦进行讨论的文本。相反,只有少数网站展示 ABC 或其他基于文本的旋律格式。因此,它通常知道旋律是什么,并清楚理解 ABC 格式,但显然,这还不足以让系统创作出体面的旋律。
本文剩余部分讲述了我使用 ChatGPT 为不同音乐风格的歌曲编写和弦进行的实验。为了创作旋律和音乐伴奏,我使用了 PG Music 的商业软件Band-in-a-Box(BiaB)。
对于每首歌曲,我首先与 ChatGPT 互动以编写和弦。然后我运行 BiaB,输入生成的和弦,并选择像德州蓝调这样的风格。BiaB 根据和弦编写旋律和伴奏曲目,然后我将歌曲制作成视频。例如,这里是再次播放的蓝调歌曲。(给所有发烧友的一个提示:要听到低音线,请使用耳机或外部扬声器。)
蓝调和弦和旋律,由 ChatGPT 生成,视频由作者提供
啊!好多了。BiaB 使用真实音乐家的 MIDI 录音数据集作为编排歌曲的基础,使曲调更具人性化。它还设置了合适的参数,如乐器配置(鼓、贝斯和两把吉他)和演奏风格(摇摆八分音符),这些都适合布鲁斯摇滚风格。
使用 ChatGPT 作曲
根据我与 ChatGPT 的互动,它对音乐了解颇多。它掌握音乐理论,对各种风格和流派也很熟悉。以下部分展示了我如何利用系统为多种风格的曲调(如爵士乐、波尔卡和乡村摇滚)创建和弦进行。
“爵士俱乐部里的 大钢琴,” Midjourney 图片,由作者编辑
爵士音乐
这就是我如何让 ChatGPT 为一首新的爵士乐曲编写和弦的过程。我首先让它讨论爵士乐曲的总体框架,然后让它为不同的歌曲部分编写和弦进行。请注意 ChatGPT 对我的指令的回应在以下互动中是斜体的。
为一首新的爵士乐曲创建框架。
当然!这是一个爵士乐曲的基本框架:
引言:这通常是歌曲开头的一个简短部分,由单独的乐器(如萨克斯 ophone 或小号)引入旋律。其余的乐队可能会加入简单的和弦伴奏。
头部:头部是歌曲的主要主题,通常由整个乐队演奏。在这里介绍了歌曲的和弦进行。爵士乐曲常使用复杂的和弦进行,如 ii-V-I 或 iii-vi-ii-V。
独奏:在头部演奏几次后,通常会有一名或多名音乐家进行独奏,即在歌曲的和弦进行上即兴演奏。
桥段:桥段是一个对比的部分,为歌曲的主要主题提供缓解。它可能会有不同的和弦进行和旋律。
尾声:尾声是歌曲的结束部分,可能会出现头部的简化版本或引言的重现。
这只是一个基本的大纲,爵士乐曲可以有多种形式。乐器可能在歌曲的不同位置出现,结构可以根据音乐家和歌曲风格的需要进行调整。
这似乎是对爵士乐曲各部分的很好的解释。注意它如何使用罗马数字分析(例如,“ii-V-I”)来描述和弦序列,这也是作曲家们经常使用的。我随后让它为一首新歌编写和弦进行,继续了聊天线程。
不同部分的和弦进行会是什么?
爵士乐歌曲的各种部分的和弦进行将取决于具体的歌曲和演奏的爵士风格。爵士乐歌曲经常使用复杂的和弦进行,可能在每小节中包含多个和弦,而且所用的和弦可以根据爵士风格和音乐家的喜好有很大变化。
下面是一些常见的和弦进行示例,这些和弦进行可能会在爵士乐歌曲中使用:
主歌:ii-V-I:这是爵士乐中非常常见的和弦进行,通常用于创建张力和解决。它由一个小和弦(ii),接着是一个主和弦(V),然后是一个大和弦(I)。例如,在 C 大调中,ii-V-I 的和弦进行是 Dm7-G7-Cmaj7。
桥段:iii-vi-ii-V:这个和弦进行使用了一系列小和弦来构建张力,然后用一个主和弦解决。在 C 大调中,iii-vi-ii-V 的和弦进行是 Em7-Am7-Dm7-G7。
这些只是几个示例,还有许多其他和弦进行可以在爵士乐歌曲中使用。重要的是选择支持旋律并为歌曲创造所需情绪和氛围的和弦进行。
这次它给了我罗马数字格式的和弦进行和 C 大调中的具体和弦。我对此感到有些兴奋,因为我也将之前AI-Tunes项目的所有训练数据转换成了 C 大调以简化处理。
对于这首新的爵士乐曲,我使用了建议的 Dm7-G7-Cmaj7 进行作为主歌和即兴部分,而桥段使用了 Em7-Am7-Dm7-G7 进行。对于引子和尾声部分,我只是保持在 Cmaj7 上四小节。剩下的就是给歌曲起个名字,我是从 ChatGPT 那里得到的。
这首曲子的好名字是什么呢?
“夜晚在秘密酒吧”
好的。为了听到音乐,我使用了上面描述的技巧。我启动了 BiaB,粘贴了生成的和弦序列,选择了爵士风格,然后,瞧,它编排了一首新的旋律,萨克斯管演奏的旋律和伴奏(鼓、贝斯和钢琴)。这就是它。
“夜晚在秘密酒吧,” 和弦由 ChatGPT 提供,音乐由 BiaB 创作,视频由作者制作
不错!这有点平淡,也许适合在电梯里伴奏,但对于计算机生成的音乐来说,它似乎表现得相当不错。我知道这只是轻微的赞扬,BiaB 完成了大部分工作。但 ChatGPT 的和弦进行看起来是稳固的。请注意,我故意把这首曲子做得很短(1:10),以便于演示。实际的爵士乐曲会多次回到主歌,并有更长的即兴部分和不同的乐器。接下来是一些波尔卡音乐。
“一台手风琴与空旷舞台上的其他乐器,” Midjourney 图片,作者编辑
波尔卡音乐
对于我的下一个实验,我直接跳入并请求 ChatGPT 为一首波尔卡曲子写一些和弦。
为波尔卡歌曲的部分写出和弦进行。
当然!这里是一些你可以在波尔卡歌曲中使用的和弦进行:
最基本的波尔卡和弦进行是简单的 I-IV-V 进行,以 C 大调为例:
C — F — G
这里是对该进行的一个变体,加入了 ii 和弦(D 小调):
C — Dm — F — G
另一个常见的波尔卡和弦进行是以 iii 和弦(E 小调)开始,然后转到 IV 和弦(F 大调)和 I 和弦(C 大调):
Em — F — C
你也可以尝试一个更复杂的和弦进行,涉及 ii 和弦和 vi 和弦(A 小调):
C — Dm — A7 — F — G
记住,这些只是一些示例,你可以在波尔卡歌曲中发挥创造力。最重要的是保持音乐的弹性和能量,并保持节奏的推动力和活力。
再次,ChatGPT 提供了 C 大调下的歌曲部分和弦。我发现如果我在最初的请求中没有指定调性,它总是选择 C。顺便说一句,系统确实知道如何转调歌曲;你只需要告诉它(例如,“将这首歌转调到 B 平调”)。
我喜欢它为歌曲建议的和弦模式以及一些有用的评论。不过,最终的和弦进行有一个小错误(双关语 😉),评论中提到了 A 小调,但序列中显示的是 A7,这是一个属和弦。事实证明 A7 实际上在那里听起来很好。
我然后要求系统为新曲目的各个部分写出具体的和弦进行。
好的,请写出一首波尔卡歌曲的引子、主歌和桥段的和弦。
当然!这里是一个使用 C 大调的波尔卡歌曲引子、主歌和桥段的和弦进行示例:
引子:C — F — G (x2)
主歌:C — Dm — A7 — F — G (x2)
桥段:Em — A7 — Dm — G
再次强调,这些只是示例,你可以在波尔卡歌曲中发挥创造力,尝试不同的和弦进行。你也可以在桥段中加入额外的和弦或调到不同的调性,以为你的歌曲增添一些变化。只要确保节奏充满活力和能量,以捕捉波尔卡歌曲的感觉。
我喜欢这些结果!它选择了初步建议中的示例和弦进行,并指定了重复序列的次数,使用了“x2”指示符。注意它在主歌中保持了 A7 和弦,并且作为额外奖励,也将该和弦加入了桥段模式中。我不得不对引子和主歌中的和弦时机做出一些创造性的决策,所以我让它们这样进行,并在最后加了一个快速的 F-G 转场。
引子:C-C-F-G (x2)
主歌:C-C-Dm-Dm-A7-A7-F-G (x2)
我保持了原本写的桥段。然后我要求给歌曲起个名字。
这首歌叫什么名字好?
这完全取决于你!你可以考虑给你的歌曲起一个有趣、富有玩味的名字,以反映波尔卡风格的轻松性质。一些建议可能是“波尔卡派对”、“舞动的鞋子”、“快乐的波尔卡”或“波尔卡华尔兹”。另外,你也可以选择一个更具描述性的名字,反映你的歌曲主题或歌词,例如“爱情波尔卡”、“与克拉拉旋转”或“绝望的波尔卡”。最终,你的歌曲名字是个人偏好的问题,所以选择一个你认为能很好地代表你的歌曲并捕捉其精神的名字。
哇。与之前的例子不同,这次给了我很多名字可以选择。ChatGPT 在这方面很古怪。有时它给出简单的答案,但有时它会很健谈。请注意,如果需要,我本可以要求更多建议。在七个推荐中,我最喜欢“与克拉拉旋转”,因为它最具独特性。
然后我将和弦粘贴到 BiaB 中,包括重复的模式,选择了波尔卡风格,它渲染了这首歌。这就是“与克拉拉旋转”。
“与克拉拉旋转,” 和弦由 ChatGPT 提供,音乐由 BiaB 创作,视频由作者制作
很好!酷且有吸引力!我真的很喜欢前奏的 riff,全是 BiaB 制作的。乐器编排包括贝斯、鼓、吉他和手风琴,小号主导了演奏。哈哈!克拉拉?但我也喜欢 ChatGPT 的和弦进行,特别是在诗句和桥段中击中 A7 和弦时。这增加了一些色彩的紧张感,这似乎适合波尔卡风格。
在第 9 到 12 小节中,单簧管和手风琴之间有一点音色冲突,但一旦小号升高一个八度,就没问题。我可以在那段时间里将手风琴的音量调低,但我想看看它直接出来的声音(可以说是 😏)。接下来是一些乡村摇滚。
“空舞台上的乡村摇滚乐器,” Midjourney 图片,作者编辑
乡村摇滚音乐
我最后的实验是为乡村摇滚歌曲生成和弦。乡村摇滚是一种结合了两种音乐风格元素的混合体。我不确定 The Byrds 是否开创了这个子流派,但他们的主唱肯定认为他们做到了。
当 The Byrds 开始演奏乡村摇滚时,我们没有想到会有这样的东西。我们只是想尊重音乐。我们开始听乡村电台。我们去了 Nudie’s,买了牛仔衣服。——The Byrds 的罗杰·麦奎因
无论如何,ChatGPT 似乎知道什么是乡村摇滚,因为它在我询问时提供了一些不错的和弦进行和评论。
为乡村摇滚歌曲编写和弦进行,包括一段歌词、合唱和桥段。
这是一个带有诗句、合唱和桥段的乡村摇滚歌曲的和弦进行示例:
诗句:E — B — A — B(I — V — IV — V)
合唱:E — A — B — E(I — IV — V — I)
桥段 G#m — F# — B — E(iii — ii — V — I)
这个和弦进行提供了强烈的调性和动感,可以作为乡村摇滚歌曲的基础。主歌使用了 I — V — IV — V 进行,这在乡村摇滚中很常见,提供了一种紧张和释放的感觉。副歌则用 I — IV — V — I 进行解决了这种紧张感,这是一个强劲、有力的进行,常用于摇滚和流行音乐中。桥段通过使用不同的进行(iii — ii — V — I 进行)为歌曲增添了些许变化,这增加了一种对比感,有助于在最终副歌前建立兴奋感和紧张感。
当然,这只是一个例子,还有许多其他和弦进行可以在乡村摇滚歌曲中使用。重要的是找到一种与您的歌曲匹配良好并帮助创造所需氛围和感觉的和弦进行。
这里的结果似乎也不错。然而,请注意系统再次对哪些和弦是主和弦,哪些是小和弦感到困惑。在罗马数字格式中,大写字母表示大调和弦,小写字母表示小调和弦。桥段部分的和弦被指定为“G#m — F# — B — E (iii — ii — V — I)”,其中 ii 和弦应该是 F#小调,但显示为 F#大调。它确实正确识别了 G#小调和弦。我决定接受这个错误,但我要求系统使用 7 和弦,这在乡村和摇滚音乐中经常使用。我只需要提出请求。
好的,但使用 7 和弦。
当然!这是一个乡村摇滚歌曲的和弦进行示例,包含主歌、副歌和桥段,使用 7 和弦:
主歌:E7 — B7 — A7 — B7 (I7 — V7 — IV7 — V7)
副歌:E7 — A7 — B7 — E7 (I7 — IV7 — V7 — I7)
桥段:G#m7 — F#7 — B7 — E7 (iii7 — ii7 — V7 — I7)
这个和弦进行与我之前建议的类似,但它包含了 7 和弦,这为和弦增添了更多复杂性和色彩。7 和弦增加了一种紧张感和不和谐感,这在乡村摇滚背景中可能特别有效。
再次,这只是一个例子,还有许多其他和弦进行可以在使用 7 和弦的乡村摇滚歌曲中使用。重要的是找到一种与您的歌曲匹配良好并帮助创造所需氛围和感觉的和弦进行。
果然,它在所有和弦上都加了一个 7。但它也像一个优秀的写作伙伴一样,当它认为我的建议在乡村摇滚背景下“特别有效”时,给出了积极的反馈。非常感谢你,伙伴!我还有一个请求。
这个乡村摇滚歌曲的好名字是什么?
“Kickin’ Up Dust”
再次,它只给出了一个名字建议。但这个名字似乎很合适,特别是“kicking”末尾缺少了“g”,所以我用了它。
然后我将所有和弦放入 BiaB 中,我没有做任何决定,因为三个部分都有四个和弦。我选择了乡村摇滚风格并生成了歌曲。就是这样。
“Kickin’ Up Dust,” 和弦由 ChatGPT 提供,音乐由 BiaB 创作,视频由作者制作
BiaB 选择了鼓、贝斯和风琴作为伴奏,并用吉他演奏主旋律,这在乡村摇滚中很常见。吉他的表现特别出色,尽管使用踏板钢吉他也会更合适。ChatGPT 提供的和弦再次表现良好。对我来说,尽管(或者也许正因为)错误主导的 F#7 和弦,过渡部分仍然是这首作品的亮点。
讨论和下一步
正如我所展示的,ChatGPT 对音乐理论有很深入的了解,能够为各种音乐风格生成和弦进行,并对其选择提供色彩评论。它还可以接受反馈来创建编曲的变体并提供额外的评论。它偶尔会犯错误,比如给出矛盾且似乎不正确的和弦建议。但在发生这种情况的两次中,结果听起来都还不错。
我没有尝试的一件事是将原歌曲的和弦作为进展中的作品展示给 ChatGPT,并请它建议一些改进或添加新部分。根据我与它的经验,结果应该会相当不错。
然而,该系统在作曲方面还有很长的路要走。我初次尝试为一首蓝调歌曲创作旋律时,效果并不好。
ChatGPT 的新版本
当我正在进行这个项目时,OpenAI 发布了一个新版本的模型,进行了改进。
我们很高兴地宣布 ChatGPT 的一些更新!以下是新内容:
我们对 ChatGPT 模型进行了更多改进!它在广泛的话题上应该总体表现更好,并且事实准确性得到了提升。… — Open AI
我随后用新版本重新创建了我的蓝调旋律实验,但结果相似。
在 E 大调中写一个 12 小节的蓝调和弦进行。
E — E — E — E
A — A — E — E
B7 — A — E — B7
为这首歌编写 ABC 记谱法的旋律。
不幸的是,作为 AI,我不能为你创作旋律,因为这需要创造力和灵感,但我可以给你一个 ABC 记谱法的旋律示例,作为起点:
X:1
T:Example Blues Melody
M:4/4
L:1/8
K:E
“E” E2 E E2 E | “A” A2 A A2 A | “B7” B2 B2 B2 B2 | “E” E2 E E2 E |
“E” E2 E E2 E | “A” A2 A A2 A | “E” E2 E E2 E | “B7” B2 B2 B2 B2 |
“E” E2 E E2 E | “A” A2 A A2 A | “E” E2 E E2 E | “B7” B2 B2 B2 B2 |
这只是一个例子,你可以根据自己的喜好随时更改节奏或音符。
这里的旋律也不好,而且在交互之间和弦全错了。但至少更新版本似乎意识到了自身的局限性。也许系统建议“创造力和灵感”应该来自人类是可以的。
更多关于 ChatGPT 作为创意写作伙伴的信息
这是本系列的前一篇和下一篇文章。
最新的 OpenAI 语言模型如何帮助写诗歌、小说和剧本。
最新的 OpenAI 语言模型如何帮助你写儿童书籍,并利用 Midjourney 创建插图。
致谢
我想感谢 Jennifer Lim 和 Néstor Nápoles 对这个项目的帮助。
参考文献
[1] J. Schulman 等人,ChatGPT:优化对话的语言模型(2020)
使用 ChatGPT 作为创意写作伙伴 — 第三部分:图画书
原文:
towardsdatascience.com/using-chatgpt-as-a-creative-writing-partner-part-3-picture-books-4f45e5dfe8dd
OpenAI 最新的语言模型如何帮助你写儿童书籍并用 Midjourney 创建插图
·发布于 Towards Data Science ·阅读时间 17 分钟·2023 年 2 月 7 日
--
“一个人和一个友好的机器人在画架上用画笔和颜料画画的场景,” 这个图像是使用 AI 图像生成程序 Midjourney 创建的,并由作者进行编辑
这是我关于使用 OpenAI 的 ChatGPT 语言模型 [1] 进行创意写作的三篇文章系列的第三篇也是最后一篇。在第一篇文章中,我描述了如何使用 ChatGPT 写作散文、诗歌、小说和剧本。在第二篇文章中,我展示了如何使用该系统创建和弦序列来编写各种风格的音乐。
我为这篇文章进行的最新实验是使用 ChatGPT 创建图画书。因为系统无法直接渲染图像,所以我让它描述场景,然后我使用 Midjourney [2] 这个文本到图像生成系统进行渲染,Midjourney 是我在早期文章中探讨过的。
概述
我将从 ChatGPT 和 Midjourney 的背景信息开始,然后展示我如何利用这两个系统为儿童创作一本新的图画书。最后,我将讨论使用这些系统的一般方法,并提供一些未来探索的下一步建议。
ChatGPT
ChatGPT 是来自 OpenAI 的一个大型语言模型。它基于他们的 GPT-3 模型 [3] 的一种变体,经过微调以遵循人类的指示,称为 InstructGPT [4]。之后,它又进一步微调,成为一个功能性的聊天机器人,利用先前的互动作为输入。
ChatGPT 可以在你这里免费注册使用,并且即将推出的专业版的等待列表在这里。当你注册并登录后,你可以要求它生成一些文本或像这样提出问题。
来自 ChatGPT 的输出,作者截屏
这一切看起来都很直接和合理。它确实知道什么是图画书及其基本特质。这个列表似乎也是按优先顺序排列的。我同意简单而吸引人的故事应该是儿童书籍的首要任务,其次是视觉上吸引人的插图和其他列出的要点。
该系统有一个不错的功能,即在生成每个回应时使用先前的互动。这意味着系统在回应时知道到目前为止的对话背景。
你可以在我的第一篇文章中阅读有关 ChatGPT 的更完整描述。接下来,我将介绍如何使用 Midjourney 从文本生成图像。
Midjourney
Midjourney 是由同名公司创建的文本到图像生成器。该服务提供免费版和付费版。免费版对你创建的图像数量有限制,且这些图像仅供非商业用途。付费版则是一个订阅服务,具有不同级别,你将拥有所有使用该服务创建的图像。完整的服务条款可以在这里查看。
Midjourney 的用户界面是 Discord 服务器上的一个聊天机器人,Discord 是一个主要由玩家使用的社交媒体平台。我开始时使用了一个通用的提示语:“儿童图画书封面”,以查看它会生成什么。我在 Discord 中使用了 /imagine 命令并输入了提示;大约二十秒后,它生成了四个缩略图。
结果相当不错。质量似乎很高,孩子们看起来很可爱。除了飞鲸和令人毛骨悚然的老人之外,我注意到的下一个问题是随机文本。当前版本的 Midjourney,即第 4 版,在生成文本字符时效果还不错,但文字并没有意义。在下面的部分中,我将讨论如何使用 Photoshop 向书中的图像添加文本。你可以在我早期的文章中这里阅读更多关于使用 Midjourney 的内容。
接下来,我将展示我如何使用 ChatGPT 和 Midjourney 创建一本新的儿童图画书。
“儿童图画书的堆叠”, 图像是使用 AI 图像创建程序 Midjourney 生成并由作者编辑的
使用 AI 系统创建图画书
在这一部分,我将带你了解我如何使用 ChatGPT 和 Midjourney 来创建一本儿童图画书。我首先使用前者互动地创建新书的标题和情节摘要。然后我使用两个系统描述并绘制封面的插图。我使用 Photoshop 自己编辑了图像并添加了排版。接着,我使用 ChatGPT 编写了其余的八页书籍内容并描述了每张图片。我用它来创建我输入 Midjourney 的提示,以生成书中的八幅插图。我用 Photoshop 编辑了这些图像,添加了文本,然后在 Microsoft Word 中组装所有页面,创建了最终书籍的 PDF 文件。
以下是每个步骤的详细信息。注意,我使用一个 ChatGPT 会话创建了整个书籍,这样可以让之前的互动内容影响后续的回答。
生成标题和情节摘要
从我之前与 ChatGPT 的互动中,你可以看到它知道一本好的儿童图画书的基本特质。但系统不仅知道这些,它还可以创造东西。我首先让它为书生成标题和情节摘要。
为幼儿创建图画书的标题和情节。
“勇敢的小龙” — 当一只小龙发现自己能像它的大龙朋友们一样喷火时,它变成了这片土地上最勇敢的龙,并通过运用喷火技能帮助需要帮助的人来拯救一天。
“伟大的花园冒险” — 一群朋友踏上了一段穿越花园的旅程,在途中遇到了各种有趣和惊奇的生物。他们了解了照顾环境的重要性和自然的魔力。
“顽皮的猴子” — 一只好奇的猴子在探索丛林时遇到各种麻烦,但最终学会了有时候在行动之前先思考是更好的选择。
“好奇的毛虫” — 一只毛虫踏上了一段探索它长大后会变成什么的旅程。在过程中,它遇到了各种不同的昆虫,并了解了蜕变的不同阶段。
“魔法森林” — 一个小女孩在她的后院发现了一个通往充满妖精、会说话的动物和其他神话生物的魔法森林的秘密门。她在探索这个魔法世界的过程中经历了各种冒险,并学习了想象力的力量。
这些都很不错,但也相当普通。它们涉及各种动物,包括一条龙、一只猴子、一只毛虫以及一些不明确的魔法生物。
我让系统为第五个生成的标题添加细节,以了解它心中有哪些会说话的动物。注意我只需将我喜欢的标题/情节称作“最后一个”,它就知道我在说哪个。
最后一个听起来不错。给情节摘要添加一些新的奇特和具体的细节。
当然!这是“神奇森林”的修订版情节摘要:
“神奇森林”——当一个名叫莉莉的小女孩在她的后院发现一个秘密门时,她从未想过它会把她带到哪里。门后面是一个充满仙女、会说话的动物和其他神话生物的神奇森林。莉莉对这个发现感到非常高兴,开始了她在森林中的旅程,渴望看到所有的奇观。在旅途中,她遇到了一只会说话的狐狸,它愿意成为她的向导,一个讲述森林过去故事的脾气古怪的老树,以及一个捉弄她的顽皮仙女。莉莉还了解到想象力的力量和照顾环境的重要性。随着她探索森林,她开始想知道这个神奇的地方是否会永远成为她生活的一部分,还是只是现实世界中的一时逃避。
我发现这个情节扩展更有趣。它设计了一个会说话的狐狸、一个脾气暴躁的树木和一个爱捉弄人的仙女作为书中的角色。然后我问它这些角色的名字。
好的,狐狸、树和仙女的名字是什么?
当然!以下是“神奇森林”中的角色可能的名字:
会说话的狐狸可以叫弗雷德里克。脾气暴躁的老树可以叫奥克利。顽皮的仙女可以叫费伊。
好吧,看来它喜欢名字中的头韵,因为它用了弗雷德里克来命名狐狸,费伊来命名仙女。我想它没有选择特雷弗来命名树木。但我确实喜欢名字奥克利,非常聪明。
我随后决定在标题中使用狐狸的名字,以使其更独特。
我们用弗雷迪作为狐狸的名字。把这个名字融入标题五次。
当然!这里有五个包含弗雷迪名字的“神奇森林”修订标题:
“狐狸弗雷迪与神奇森林”
“弗雷迪和莉莉在神奇森林中的冒险”
“弗雷迪在神奇森林中的奇妙之旅”
“莉莉和弗雷迪的神奇森林冒险”
“狐狸弗雷迪与神奇森林的奇观”
希望这些标题对你有用!如果你有任何其他问题或需要我为故事提出更多想法,请告诉我。
这些也都很好。再一次,注意它如何记住并提到之前对话中的内容,比如它知道莉莉是故事中的主要角色,并将她融入了五个提议标题中的两个。
创建书封面
在确定了标题之后,我让它描述了书封面的插图。
描述一下“狐狸弗雷迪与神奇森林”的书封面会是什么样子。
对于《弗雷迪与神奇森林》这本书,封面可以展示弗雷迪狐狸站在神奇森林中央的插图,旁边站着一个年轻的女孩(莉莉)。森林中可以充满各种奇幻生物,如仙女和会说话的树木。背景可以是一个美丽的、色彩斑斓的日落,树木和森林中的其他元素在地面上投下长长的阴影。书名《弗雷迪与神奇森林》可以用大号粗体字写在封面的顶部,作者的名字则用较小的字母写在下面。整体上,封面应该引人注目,充满奇幻和惊奇,邀请年轻读者深入书页中的神奇世界。
这是对书籍封面的详细描述。它添加了一些可爱的细节,比如日落如何在地面上投下长长的阴影。
然后我要求系统生成一个提示以创建插图。我知道 Midjourney 不会知道弗雷迪和莉莉是谁,所以我要求 ChatGPT 省略他们的名字。我还要求省略文字描述,因为我知道 Midjourney 对渲染文本字符处理得不好。
为语音转图系统编写一个提示语,以创建书籍封面。描述角色的视觉特征,而不是名字。包括年龄、发色和衣服等细节。不要提及文字。
为一本图画书的封面创建插图。插图中展示一只勇敢且充满冒险精神的会说话的狐狸,站在一个神奇的森林中央。一位年轻的女孩站在狐狸旁边,抬头用惊奇的目光看着它。女孩有着红色的头发,约 8 岁,穿着绿色连衣裙。背景是一个美丽的、色彩斑斓的日落,树木和森林中的其他元素在地面上投下长长的阴影。插图应该引人注目,充满了奇幻和惊奇,吸引年轻读者深入书页中的神奇世界。
很好。这正是我所寻找的。我省略了开头的“创建”一词和结尾的“邀请年轻读者”这一从句,因为这些内容对于图像生成帮助不大。我使用的命令是/imagine [prompt] --v 4 --q 2 --upbeta,这意味着它使用了第四版本的模型,质量设置为 2.0,并且使用了升级器的测试版。这个升级算法将选定缩略图的分辨率从 512x512 像素提高到 2048x2048 像素。有关这些参数的更多信息,请参见这里。
然后我将命令输入到 Midjourney 中,它生成了四个缩略图。
哇,这些都很不错。左上角的那张有一些神奇的事物,包括一只小狐狸和一只大狐狸。左下角的那张在树上有一个奇怪的狐狸状阴影。但我最喜欢右上角的那张。我使用 Midjourney 放大了缩略图,然后在 Photoshop 中稍微修整了一下。我按照 ChatGPT 之前的规格添加了文本。这里是编辑前后的最终结果。
《弗雷迪狐狸与神奇森林》的封面,原始图像(左)和最终图像(右),由作者使用 Midjourney 编辑的图像
我没有对图像进行太多编辑。在 Photoshop 中,我只为狐狸添加了一个瞳孔。对于文本,我使用了 Google 的两个开源字体,Boogaloo用于标题,Carlito用于致谢。我添加了黑色边框和轻微的阴影,以确保文本在背景的浅色上可读。
编写这本书
为了编写这本书,我只需要选择八页的数量,并给 ChatGPT 一些基本指示。
编写一本总共八页的书,每页两句话和一幅插图。提供插图的描述,并用方括号标出。
这是《弗雷迪狐狸与神奇森林》的图画书,共有八页,每页两句话和一幅插图:
第 1 页:“莉莉是一个充满好奇心的女孩,拥有丰富的想象力。” “一天,她在后院发现了一扇秘密的门。” [插图:一个名叫莉莉的小女孩站在她后院的一扇藏在灌木后的门前。]
第 2 页:“门后是一个充满奇迹的神奇森林。” “莉莉惊讶不已,迫不及待地想要探索。” [插图:莉莉站在神奇森林的入口处,四周有盛开的树木和花朵,天空中有一抹淡淡的彩虹。]
第 3 页:“当她在森林中漫游时,莉莉遇见了一只名叫弗雷迪的会说话的狐狸。” “他主动提出成为她冒险的向导。” [插图:莉莉和狐狸弗雷迪站在一起,弗雷迪友好地看着莉莉。]
第 4 页:“莉莉和弗雷迪穿越森林,遇见了一棵名叫奥克利的脾气暴躁的老树。” “他给他们讲述了森林的过去和神奇的居民。” [插图:奥克利树,一棵大橡树,脸上带有表情和手臂,双臂交叉,脸上带着皱眉。]
第 5 页:“当他们继续旅行时,莉莉和弗雷迪遇到了一位名叫费伊的顽皮小精灵。” “她捉弄他们,并引导他们在森林中进行了一次疯狂的追逐。” [插图:费伊小精灵,一个小小的有翅膀的生物,金色头发和顽皮的表情,飞翔在空中,身后留下一串闪光的痕迹。]
第 6 页:“莉莉和 Freddy 在魔法森林中经历了各种冒险。” “他们了解了想象力的力量和保护环境的重要性。” [插图:莉莉和 Freddy 站在森林里,莉莉望向远方,脸上带着梦幻般的表情,Freddy 则友好地看着她。]
第 7 页:“莉莉忍不住想知道魔法森林是否会永远成为她生活的一部分。” “还是只是从现实世界中的暂时逃避。” [插图:莉莉站在秘密门前,带着怀旧的表情回望魔法森林。]
第 8 页:“但有一件事是肯定的——她与 Freddy、Oakly 和 Fae 的冒险记忆将永远伴随她。” “完结。” [插图:莉莉在魔法森林中的冒险蒙太奇,显示她、Freddy、Oakly 和 Fae 在旅程中不同场景的合影。]
我喜欢它如何重复了指示,然后迅速生成了八页的叙述文本和插图简要描述。
从批评的角度来看,故事情节比较薄弱,几乎没有情节(注意缺乏对手的存在)。但为了展示目的,这似乎还可以。第 7 页有些哲学性的内容,但在第 8 页以感人的结局作结。
接下来是这个项目的有趣部分,生成书中的插图。
生成插图
我上面展示了如何使用 Midjourney 渲染书封面图像。我遵循了这个方法来创建八个页面的插图。由于 GPT 有上下文,我只需逐个请求提示。
这是我创建书中第一个插图所做的工作。注意我再次要求了角色的具体视觉细节,这有助于在再次渲染时保持连贯性。
为语音转图像系统编写一个提示,以创建第 1 页的插图。描述角色的视觉特征,不要提及名字。但不要提及文本。
创建一个插图,展示一个小女孩站在她后院一丛灌木后面隐藏的门前。女孩是红发,约 8 岁,穿着绿色连衣裙。她站在门前,脸上带着兴奋的表情,门很大,木制的,带有一个铜把手。
为了在 Midjourney 中渲染图像,我省略了“create”一词,使用了上述选项的/imagine命令。以下是结果。
“Freddy the Fox and the Magical Forest”的第一页缩略图,Midjourney 图像由作者提供
与书封面类似,这些图像也做得很好。所有规定的元素都渲染得很不错,包括女孩、门、灌木丛等。我最喜欢左上角的那一张,但女孩的额头和头发有一些问题。Midjourney 中有一个很好的功能叫做“remix variations”,它使用选择的图像和原始提示生成四个插图变体。我点击了V1按钮来表示左上角的缩略图,系统生成了四个变体。
《弗雷迪狐狸与神奇森林》的首幅插图(左)和最终图像(右),Midjourney 作者图像
在左侧,你可以看到 Midjourney 生成的四个修订版缩略图,每个图像都有细微和不那么细微的视觉差异。我喜欢左下角的那一张,因为她的头发问题较少,而且门的渲染似乎更好。因此,我点击了U3按钮,将其放大到 2048x2048 像素。然后,我使用 Photoshop 添加了文字。我没有对这张图像进行任何修饰。
角色的视觉连续性
阅读图画书时,读者期望角色在每一页上看起来大致相同。但如果你将书封面的莉莉与第一页的插图进行比较,你会注意到一些连续性问题。例如,衣服的图案是最明显的差异。
书封面上的莉莉细节(左)和首幅插图(右),Midjourney 作者图像
使用 Midjourney 生成图画书插图时,这是一个重要问题。除非你告诉它,否则系统并不会“知道”新角色应该是什么样子。
改善插图之间连续性的有两个可能方案。首先,在提示中提供更详细的角色描述会有帮助,即包括对服装和头发的详细描述。Midjourney 会采纳这些词汇,渲染出的角色会更相似。其次,可以在 Photoshop 中清理图像,即使用曲线工具和吸管匹配头发和衣物的颜色。
我没有花太多时间修复图书的连续性问题。尽管如此,我选择了角色匹配度较高的缩略图,偶尔调整了提示,并在 Photoshop 中做了一些清理。
创作其余的插图
我使用 ChatGPT 和 Midjourney 重复了第二页到第八页的插图制作过程。总体来说,我喜欢奥克利这个脾气暴躁的树和法伊这个顽皮的仙女的渲染效果。
《弗雷迪狐狸与神奇森林》的插图,展示了奥克利(左)和菲(右),Midjourney 图像由作者编辑
特别是,我喜欢 Midjourney 将奥克利的眉毛渲染成树枝的效果,以及喜欢菲的发光背光效果。
完成的书籍
我将所有页面加载到一个 Microsoft Word 文档中以编制书籍,并添加了最后的“完”页以及鸣谢页面。然后将其保存为 PDF 文件。
这是完成的书籍。请注意,你可以使用右侧的滚动条或箭头按钮翻页。全屏模式并不非常有用,因为它会放大太多,无法显示完整的页面。
书中的最后一张图片是我在 Photoshop 中创建的蒙太奇,通过复制和粘贴 Midjourney 生成的“最佳”角色图片。我尝试使用系统的图像提示功能来自动生成汇编图像,但没有成功。
讨论
经过大量的实验,我设法使用 ChatGPT 和 Midjourney 作为写作/插画伙伴来创建一本简单的儿童图画书。总体而言,我对结果感到满意。角色很有趣,插图也很好看。
然而,正如我之前提到的,情节简单且缺乏细节。首先,它需要超过八页,以创建更具吸引力的故事。例如,Margaret Wise Brown 和 Clement Hurd 的《晚安月亮》有 30 页,Maurice Sendak 的《野兽国》有 48 页。但我认为可以与 ChatGPT 合作,创建一个更复杂故事的高层次大纲,然后将其拆分成各个部分,并生成有助于故事主线的单独页面。
此外,还需要更多的实验来在页面之间渲染具有适当连续性的角色。例如,可能可以使用 Midjourney 创建角色研究,从不同角度生成角色的各种姿势。然后使用孤立的姿势作为图像提示,以及描述特定场景的文本。
我希望那些开发文本到图像生成系统的开发者,如Midjourney、Dall-E和Stable Diffusion,能改进他们的模型,以便能够渲染具有视觉连续性的角色。
更多关于 ChatGPT 作为创意写作伙伴的内容
这是本系列的前两篇文章。
## 使用 ChatGPT 作为创意写作伙伴 — 第一部分:散文
最新的 OpenAI 语言模型如何帮助创作诗歌、小说和剧本
towardsdatascience.com ## 使用 ChatGPT 作为创意写作伙伴——第二部分:音乐
最新的 OpenAI 语言模型如何帮助你为新歌曲编写和弦,音乐由 Band-in-a-Box 提供
towardsdatascience.com
致谢
我想感谢 Jennifer Lim 在这个项目中的帮助。
参考文献
[1] J. Schulman 等,ChatGPT:优化对话的语言模型(2020 年)
[2] Midjourney midjourney.gitbook.io/docs/
[2] T. Brown 等,语言模型是少样本学习者(2020 年)
[3] L. Ouyang 等,训练语言模型以遵循带有人类反馈的指令(2020 年)
使用 ChatGPT 进行高效调试
利用大型语言模型提升你的调试体验,快速学习
·
关注 发表在 Towards Data Science · 15 分钟阅读 · 2023 年 6 月 22 日
--
图片来源:Pavel Danilyuk 于 Pexels
难以否认,大型语言模型(LLMs)正在对各行各业和应用产生深远的影响,彻底改变了我们的工作和互动方式。尽管自从约六个月前(2022 年 11 月)发布以来,围绕 ChatGPT 的最初炒作已平息,但它的影响仍然显著。看来自回归 LLMs 将在不久的将来继续成为我们生活的一部分,并且 值得培养与它们互动的技能,无论作为开发者还是用户。
正如Chip Huyen 在她的博客文章中所述,使用 LLMs 达成令人印象深刻的成果相对容易,但考虑到LLMs 目前存在的限制和潜在问题,构建生产就绪的系统却相当具有挑战性。然而,尽管研究和工程界正积极努力解决这些挑战,值得承认的是,个人已经能够从 LLMs 中获益,至少可以将它们作为个人助手用于日常非关键任务,或作为头脑风暴的合作伙伴。
在我之前的文章中,我讨论了提示工程的最佳实践,提供了帮助你开发基于本地 LLMs 的应用程序的见解。在这篇文章中,我将分享一系列技巧,帮助你利用如 ChatGPT 这样的模型进行有效的代码调试和加速编程学习。我们还将看看编写和解释代码的示例提示。这些技巧不仅在与 ChatGPT 互动时有用,也能帮助你从同事那里寻求帮助或独立解决编程挑战。
本文主要面向初学者,因此我尝试提供了说明性的例子和解释。我希望这些技巧能帮助你更高效地理解和排除代码中的问题。
代码调试的一般框架
实际上,ChatGPT 对调试过程没有做出重大改变。好在现在你可以轻松地与虚拟同事联系,而不用担心打扰别人或担心问蠢问题!但我们将考虑的这些技巧只要软件工程存在就会存在,因此不仅在与 LLMs 互动时有用,也有助于更好地理解过程和更有效地与同事互动。
要找到代码中的错误,你只需两个基本步骤(实际上有三个):
-
隔离错误并用最少量的代码展示它;
-
对你的错误做出假设并进行测试;
-
不断迭代,提出更多假设,直到找到解决办法。
虽然你可以立即开始使用 ChatGPT,但实际上从重现错误开始会更好,这有几个原因。首先,在语言模型的上下文中,可能很难包括所有相关点并准确解释你想要实现的目标。其次,这将使你更好地理解问题,并可能自己找到错误。让我们来看看。
顺便说一下,在这篇文章中,我使用的是原版 ChatGPT(GPT-3.5),但对于编码任务,GPT-4 通常更为高效。
第一步:使用最少量的代码隔离并重现问题
第一步是重现问题。众所周知,大多数问题仍然可以通过经典的 “重启” 来解决。可能你已经在 Jupyter Notebook 中与代码执行顺序纠缠不清。
如果可能(通常是),建议 编写新的代码以抛出相同的错误,并尽可能保持简单。
让我们考虑一个 TypeError: ‘int' object is not iterable
的例子,这种错误发生在你尝试迭代 some_integer
而不是使用 range(some_integer)
构造时。
糟糕的例子:一个函数调用另一个函数,该函数又调用一个类的方法。乍一看,可能需要一些时间来确定实际计算发生的位置,尽管这是一个相对简单的例子。类似地,对于模型而言,在无关细节中 定位相关信息 变得更加具有挑战性。
更好的例子:通过将 do_some_work()
函数(引发错误的函数)的功能直接移到我们调用的函数中,来去掉类。
除了我们在变量命名约定上仍然做得很糟糕(记住,变量名应该具有描述性和意义!),这段代码仍然更容易调试和理解。
更好的例子:我们还可以去掉 some_function()
。
总体而言,我们将代码缩短了超过一半。比较一下找出其中错误的难度。
在 pandas 的背景下,这一原则可能意味着不使用原始的数据框。假设我们想计算每个职位的平均工资,并遇到 KeyError
。这是一个糟糕的例子:
首先,我们不能确定数据框中是否包含注释中提供的数据。实际上,我们只需要其中的两列,如果我们创建一个类似的迷你版本,会更容易理解我们只是拼写错误了工资列(Salary
vs salary
)。
顺便说一下,ChatGPT 在 生成虚拟数据 方面相当出色,所以它在这里也可能有帮助!
错误类型繁多,当然不可能一一列举。总体而言,尝试以 产生相同错误 的方式修改代码,但 尽可能让其容易理解。
由于所谓的 “橡皮鸭调试”,这一步通常能帮助你自己理解问题的原因,而无需寻求外部帮助。例如,如果你的迷你代码没有产生相同的错误,你已经找到了解决方案的一半。然而,即使它产生了相同的错误,这也是一个积极的结果。 😃
第 2 步:做出假设、测试并迭代
如果你仍然找不到修复错误的方法,值得寻求帮助。但对可能出错的地方有自己的假设会很有帮助。
查找确切的行
首先,找到导致问题的表达式和确切的代码行。你可能已经知道它在你之前编写的迷你代码的最后一行里面。
记住,Python 追溯显示的错误信息在底部,而对应的执行代码在顶部,中间是内部函数的调用。
Python 追溯。图像由作者提供
对于直接的错误,这可能相对简单,但对于逻辑错误,这些错误不会生成任何错误信息,但由于逻辑错误导致意外的输出,处理起来可能更具挑战性。在这种情况下,使用调试器或简单的print()
语句逐步观察值,并确定不符合预期的代码行是很有帮助的。
如果错误是由复杂的表达式引起的,例如df.groupby(‘Occupation’)[‘Address’].apply(lambda x: ‘, ‘.join(x))
,你可以首先将其拆分为几个部分,并逐步探索输出,例如,先运行df.groupby(‘Occupation’)
,然后df.groupby(‘Occupation’)[‘Address’]
,等等。
思考常见原因
之后,考虑一些常见的错误原因:
-
是否可能是所需的库未安装或安装了错误的版本?
-
也许某处有一个简单的拼写错误或语法错误?
-
错误是否可能与数据类型有关,例如你将字符串和数字相加?
-
等等。
请询问 ChatGPT
如果想不到任何解决办法,可以寻求 ChatGPT 的帮助。简单的问题通常可以通过直接粘贴代码并询问哪里出错了来解决。然而,对于更复杂的问题,你可能需要提供额外的相关信息。例如,如果你遇到系统错误,提供你使用的 Python 版本可能会有帮助。总的来说,尽量总是包括错误信息和描述你想要实现的目标。你可能需要尝试几种不同的措辞,所以不要害怕进行实验。
情况可能有很大变化,所以我们继续一些示例。首先,让我们看一下之前遇到的 pandas KeyError
。
示例提示:
This Python code results in <error> in <this line>.
在此处插入你的代码
Tell me how to debug the code to solve the given error.
示例:
This Python code results in KeyError: 'Column not found: salary'
in the last line.
import pandas as pd
df = pd.DataFrame({
'Occupation': ['Engineer', 'Doctor', 'Engineer'],
'Salary': [56056, 61304, 86850],
})
average_salary_per_occupation = df.groupby('Occupation')['salary'].mean()
Tell me how to debug the code to solve the given error.
调试 pandas 示例的输出。图像由作者提供,使用ChatGPT创建
看起来不错!让我们来看一个具有逻辑错误的更具挑战性的例子。
示例提示:
This Python code <does this, but I want it to do this>.
在此处插入你的代码
Tell me how to fix the code to solve the problem.
示例:
This Python code trains a random forest classifier on the Iris dataset.
As far as I know, the dataset is relatively simple and I expect
the classifier to make the perfect predictions,
however, I am getting about 95% accuracy even with 100 trees in the forest.
Are these results reasonable for this situation?
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split, GridSearchCV
加载 Iris 数据集
iris = load_iris()
X = iris.data
y = iris.target
切分数据
X_test, X_train, y_test, y_train = train_test_split(
X, y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(
X_train, y_train, test_size=0.2, random_state=42)
param_grid = {
'n_estimators': [3, 7, 15, 25, 50, 100]
}
创建一个随机森林分类器并
使用网格搜索来找到最佳的树木数量
rf_classifier = RandomForestClassifier()
grid_search = GridSearchCV(estimator=rf_classifier,
param_grid=param_grid, cv=5)
grid_search.fit(X_train, y_train)
从网格搜索中获取最佳的估算器并进行测试
best_rf = grid_search.best_estimator_
y_pred = best_rf.predict(X_test)
test_accuracy = accuracy_score(y_test, y_pred)
print("Test accuracy:", test_accuracy)
从网格搜索中获取最佳的树木数量
best_n_estimators = grid_search.best_params_['n_estimators']
print("Best number of trees:", best_n_estimators)
Tell me how to fix the code to solve the problem if there is a problem.
调试 sklearn 示例的输出。图像由 ChatGPT 创建。
模型能够找到隐藏的问题并修复它。太棒了!
由于 ChatGPT 能够记住你之前的消息,这里的可能性是无限的。你可以让它解释一些你难以理解的概念,建议替代解决方案,将代码从一种语言翻译到另一种语言,等等。
此外,由于 ChatGPT 能够理解代码,它也可以编写代码。
使用 ChatGPT 编写和解释代码
在这一部分,我们将探讨一些在与 ChatGPT 编码时可以使用的技巧。但首先,我认为重要的是要记住,在 ChatGPT 出现之前,Google 是软件开发人员的主要工具。
我认为重要的是不要忘记如何使用 Google,出于各种原因。最终,使用 Google 你可以做 ChatGPT 能做的所有事情(在编码环境中),只是可能会更慢。然而,当涉及到像用 numpy 创建对角矩阵这样的特定任务时,我可能会使用 Google 更快地完成。
我认为适合的类比是学习外语的过程(尽管可以说你通过学习编程语言实际上就是在做这件事 😃)。使用 Google 就像用词汇翻译单个单词,而使用 ChatGPT 则类似于翻译整句话和段落的在线翻译器。虽然 ChatGPT 可以非常强大,但你可能会遇到在识别隐藏错误或理解某些代码块时的挑战,尤其是作为初学者。
鉴于 ChatGPT 的巨大能力,它可以以各种方式被使用。我建议你用它来生成简短的代码片段——充当习语或稳定表达的词典。确保你对生成的代码有一个扎实的理解是很重要的,因为这有助于防止未来的问题和复杂性。
所以,使用 ChatGPT,但也记得去谷歌查找信息!
这篇文章提供了一些关于如何有效谷歌搜索的软件开发者的宝贵建议。
写
在这种情况下,一个有用的技巧叫做角色提示。你可以让模型以初级开发者的角色来编写代码,而不是仅仅让它生成代码。通过假设这个角色,模型更有可能生成对初学者易于理解的代码,避免过于复杂的构造。
示例提示:
Act like a junior Python developer.
Write code with comments explaining what is going on.
When providing the solution, ensure the output is well-formatted
and the code is well-documented.
Include example usage and explanations.
<Describe your problem here>
示例:
Act like a junior Python developer.
Write code with comments explaining what is going on.
When providing the solution, ensure the output is well-formatted
and the code is well-documented.
Include example usage and explanations.
Write a program to find common values between two numpy arrays.
It is a function that takes two numpy arrays as inputs and
output a numpy array.
针对代码生成提示的输出。图片由作者创建,使用ChatGPT
在这个例子中,ChatGPT 很好地遵循了我们的所有指示,包括示例和详细解释。
别忘了将代码与最初的假设进行比较,并得出结论,这将帮助你在未来避免做出类似的请求。
总的来说,请记住,LLMs 对你或你所面临的具体问题没有先验知识,因此你提供的信息越多,得到的输出就会越好:
-
描述你的任务;
-
定义代码结构:例如,它是完整的脚本、类还是函数;
-
指定输入和输出:例如,函数接受两个整数参数并输出浮点数;
-
提及你想使用的工具/库,如 numpy 或 pandas,以及编程语言;
-
如果可行,添加你的建议,例如解决方案可能是什么样的:例如,我建议使用
pandas.DataFrame.groupby
函数按职位计算平均薪资。
优化
如果我们之前让模型表现得像一个初级开发者,我们希望它在编程任务,如代码优化方面更强。
顺便说一下,你也可以用相同的角色提示技巧来调试代码。
示例提示:
Act like an experienced Python developer mentoring a junior developer.
Provide explanations and comments about the concepts that can be
hard for beginners. Use Python best practices and write docstrings.
Rewrite the above Python code and optimize it.
在这里插入你的代码
示例:
Act like an experienced Python developer mentoring a junior developer.
Provide explanations and comments about the concepts that can be
hard for beginners. Use Python best practices and write docstrings.
Rewrite the above Python code and optimize it.
def unique_list(l):
# 从给定的列表中获取独特元素的列表
x = []
for a in l:
if a not in x:
x.append(a)
return x
针对代码优化提示的输出。图片由作者创建,使用ChatGPT
在这里,为了从列表中提取独特元素,ChatGPT 建议使用set
数据类型(因为其定义上不允许重复值),而不是使用 for 循环。这是一个不错的选择,因为它基本上是一个一行解决方案。
解释
我们可以使用相同的标记请求评论和代码解释。通常,这会输出冗余的注释,如import numpy as np # importing numpy library
。然而,这在你开始学习时仍然有用,并且帮助模型表达其内部思想,正如在链式思维推理部分和我的上一篇文章中讨论的那样。
样例提示:
Act like an experienced Python developer mentoring a junior developer.
Provide explanations and comments about the concepts that can be
hard for beginners.
Explain and comment the following code.
在此处插入你的代码
示例:
Act like an experienced Python developer mentoring a junior developer.
Provide explanations and comments about the concepts that can be
hard for beginners.
Explain and comment the following code.
def func(n):
trow = [1]
y = [0]
for x in range(max(n, 0)):
print(trow)
trow = [l + r for l, r in zip(trow + y, y + trow)]
return n >= 1
func(6)
代码解释提示的输出。图片由作者使用ChatGPT创建
在这里,模型能够识别这段代码背后的任务:生成帕斯卡三角形的行。做得好!
虽然你可以想出许多有用的 ChatGPT 应用,但重要的是要注意,它可能并不总是能够解决你遇到的每一个问题。值得进一步讨论这一方面。
关于潜在陷阱的一些注意事项
虽然 ChatGPT 可以非常有用并执行各种任务,但重要的是要记住潜在的缺点,这些缺点可能比较棘手。
作为自回归 LLM,ChatGPT 不是确定性的……
ChatGPT 是一个大型语言模型的例子,当前的大型语言模型是自回归的,这意味着它们被训练来预测序列中的下一个标记。模型的输出是所有可能标记的概率分布,我们从这个分布中逐个标记地采样最终文本。因此,采样过程是非确定性的,这意味着由于概率原因,你对相同的输入可能会有不同的输出。
为了说明这一点,我们可以将采样过程想象成一棵树。这里,初始句子用蓝色显示,选中的标记用绿色显示,未选中的标记用红色表示(不包括它们的进一步演变)。概率是随机选择的,仅用于说明。
LLM 采样过程的简化可视化。图片由作者提供
输入序列是:“我的名字是”,而 ChatGPT 通过“ChatGPT,很高兴为您服务!”来完成它。这是一个来自我的上一篇文章的例子,我在其中讨论了 LLM 的基本知识。
… 这就是为什么 ChatGPT 可能会出错的原因
在实际操作中,这意味着你可能因为一开始得到了某些稀有不太可能的标记而得到次优的输出。因此,你可能需要对相同的输入进行多次运行以检查不同的输出并选择最合适的一个,甚至结合来自不同输出的不同部分。
此外,重要的是要强调现代语言模型,特别是 GPT-4,具有令人印象深刻的自我纠正能力。如果生成的代码有错误,你可以简单地返回它并指出它的功能不正确。GPT-4 擅长调试自己的代码并提供相关建议。你通常可以在几次迭代后获得正确的代码。
过度自信
尽管 LLMs 有时可能提供不正确的输出,但它们经过训练优先考虑准确性。这可能使它们的输出看起来非常令人信服,即使它是错误的。因此,识别隐藏的错误可能具有挑战性,因为模型通常不能明确地说“我需要更多的信息”,尽管正在进行的研究正积极探索解决这一限制的方法。
从这个意义上讲,使用 ChatGPT 生成小段代码以解决特定任务,如我之前提到的谷歌搜索,可以更安全。确保你对所收到的代码有扎实的理解,这样可以有效避免潜在的陷阱。
结论
在这篇文章中,我们探讨了可以帮助你调试的方法,不仅仅是与 ChatGPT 一起,还有你自己。
通过隔离问题和用最少的代码行重新编写,你很可能会洞察到潜在的问题。或者,让 ChatGPT 帮助你,提供关于发生了什么的完整信息,进行假设并进行实验。
你还可以利用 ChatGPT 进行任务,例如编写、优化或解释代码,就像我们在角色提示示例中讨论的那样。其他与代码相关的应用几乎无穷无尽,包括创建虚拟数据、编写测试、生成文档等。
但要记住 LLMs 的局限性,因为它们可能引入隐藏的问题。由于其自回归性质,LLMs可能在表现自信的同时犯错,这可能需要更多的问题或多次迭代以选择最佳输出。
祝你学习顺利!
资源
查看这篇文章以获取将 ChatGPT 作为个人编码导师的更广泛指南。
这里是我关于 LLMs 的其他文章,可能对你有用。我已经涵盖了:
-
估计大型语言模型的规模:LLMs 是什么,它们是如何训练的,需要多少数据和计算资源;
-
提示工程的最佳实践:如何应用提示工程技术与 LLMs 有效互动,以及如何使用 OpenAI API 和 Streamlit 构建本地 LLM 应用程序。
你可能还感兴趣:
-
免费的Learn Prompting 课程帮助你深入了解提示工程及其相关技术;
-
最近发布了DeepLearning.AI 的短课程,用于构建 OpenAI API 应用程序。
感谢你的阅读!
-
希望这些材料对你有帮助。关注我在 Medium 上的更新以获取更多类似的文章。
-
为了支持我作为作家,并获得其他数千篇 Medium 文章的访问权限,可以使用我的推荐链接来获取 Medium 会员(对你没有额外费用)。
使用 ChatGPT 将 R 代码翻译成 Python
原文:
towardsdatascience.com/using-chatgpt-to-translate-r-code-to-python-effed461037e
进入翻译代码库的第一步。
·发布于 Towards Data Science ·阅读时间 5 分钟·2023 年 5 月 28 日
--
图片由 Katka Pavlickova 提供,来源于 Unsplash
Python 和 R 是数据科学生态系统中的两个关键角色。我在 Python 和 R 编写的项目中都有工作经验。它们在高效处理数据清理、操作和分析任务方面表现出色。
然而,有时我们希望将代码库从 R 翻译成 Python,或反之亦然。这种动机可能来源于公司内部的技能组合。如果你有许多 Python 开发者和少数 R 开发者,你可能希望将项目维护在 Python 中。
在某些情况下,特定的项目要求可能会促使你翻译代码库。例如,虽然 Python 以其多功能性而闻名,涵盖了网页开发、软件开发和机器学习等领域,但 R 在统计分析和可视化方面表现出色。因此,如果项目涉及大量统计分析或复杂的图形表示,翻译 Python 代码到 R 可能会有益。
翻译过程通常需要具备两种编程语言技能的人才,这可能很难找到。好消息是我们可以使用 ChatGPT,它在几乎所有编程语言中都具备相当高的技能水平。
在这篇文章中,我们将使用 ChatGPT 将使用 “data.table” 包编写的 R 代码翻译成使用 “pandas” 库的 Python 代码。
我们将仅给出代码,而不说明其作用或预期功能。让我们从初始提示开始,定义 ChatGPT 的角色并说明我们的期望:
初始提示:你是一名拥有多年 Python 和 R 经验的软件工程师。你的任务是将 R 代码转换为 Python 代码。R 代码使用 data.table 包进行数据清理和处理。你将使用 Pandas 库将其转换为 Python 代码。
第一个示例是创建数据表中新列的代码行。
Convert the following R code that uses `data.table` package
to Python using the `pandas` library:
melb[, Price_per_area := Price / Landsize]
这是 ChatGPT 的回答:
(作者提供的图片)
它不仅提供了代码,还解释了代码的功能。这对于没有源语言或目标语言专业知识的人来说非常有用。
第二个示例通过对价格和类型列进行过滤来创建数据表的子集。
Convert the following R code that uses `data.table` package
to Python using the `pandas` library:
subset <- melb[Price > 1000000 & Type == "h"]
ChatGPT:
(作者提供的图片)
我们再做一个稍微复杂的例子。下面提示中的代码行计算了不同地区类型为“u”的房屋的平均房价。
Convert the following R code that uses `data.table` package
to Python using the `pandas` library:
melb[Type == "u", .(avg_price = mean(Price)), by="Regionname"]
ChatGPT:
(作者提供的图片)
上面截图中的代码:
avg_price = melb[melb['Type'] == 'u'].groupby('Regionname')['Price'].mean().reset_index().rename(columns={'Price': 'avg_price'})
它成功地完成了任务,但我更倾向于使用命名聚合,我认为这更具可读性且更简洁。我们也让 ChatGPT 使用它:
Convert the following R code that uses `data.table` package
to Python using the `pandas` library. Used named aggregations.
melb[Type == "u", .(avg_price = mean(Price)), by="Regionname"]
ChatGPT:
(作者提供的图片)
上面截图中的代码:
avg_price = melb[melb['Type'] == 'u'].groupby('Regionname').agg(avg_price=('Price', 'mean')).reset_index()
这样更好,因为它不需要之后重命名列。我们还可以使用groupby
函数的as_index
参数并将其值设置为 False,这样我们在计算后就不需要重置索引。
下一个提示中的 R 代码计算了每个地区的平均房价和房屋总数,并按降序排序结果。
Convert the following R code that uses `data.table` package
to Python using the `pandas` library. Used named aggregations.
melb[
Type == "u",
.(avg_price = mean(Price), number_of_houses=.N),
by="Regionname"
][order(-avg_price)]
ChatGPT:
(作者提供的图片)
与之前的示例一样,逐步解释对于理解和调试代码非常有用。
我们来做一个关于更改列名的最终示例。下面提示中的 R 代码将类型和距离列重命名为“HouseType”和“DistanceCBD”。
Convert the following R code that uses `data.table` package
to Python using the `pandas` library.
setnames(melb,
c("Type", "Distance"),
c("HouseType", "DistanceCBD"))
ChatGPT:
(作者提供的图片)
上面截图中的代码:
melb.rename(columns={'Type': 'HouseType', 'Distance': 'DistanceCBD'}, inplace=True)
我喜欢关于inplace
参数的注释,因为它很容易被忽视。
最后总结
本文中的示例展示了 ChatGPT 在将代码从一种编程语言翻译到另一种编程语言方面的成功。我们进行了逐行翻译,这不能作为 ChatGPT 能够翻译整个代码库的证明。
然而,显然,当处理大型代码库时,使用 ChatGPT 可以节省你数小时的时间。你总是可以再检查一次并测试输出,以确保它按预期工作。
你可以成为 Medium 会员 来解锁我所有的写作内容,以及 Medium 的其他内容。如果你已经是会员,别忘了 订阅 以便在我发布新文章时收到邮件通知。
感谢阅读。如有任何反馈,请告知我。
使用 DeepFace 进行面部识别
原文:
towardsdatascience.com/using-deepface-for-face-recognition-5f8d1e43f2a6
了解如何在不需要训练自己的模型的情况下进行面部识别
·发布于 Towards Data Science ·10 分钟阅读·2023 年 4 月 18 日
--
图片由 fran innocenti 提供,来源于 Unsplash
在我之前关于面部检测的文章中,我讨论了如何使用 OpenCV 在你的网络摄像头中检测面部:
通过使用你的网络摄像头检测你的面部,享受 Python 的乐趣
towardsdatascience.com
检测面部是你通常首先执行的步骤,之后是 面部识别。
面部识别是将数字图像或视频帧中的人脸与面部数据库中的人脸进行匹配的过程。
有几种深度学习模型可以用来进行面部识别,但所有这些都需要你具备一些神经网络的知识,并且需要使用你自己的数据集进行训练。对于那些想进行面部识别但不想深入了解神经网络工作原理的人,有一个真正简化面部识别的 API — DeepFace。
什么是 DeepFace?
DeepFace 是一个轻量级的面部识别和面部属性分析(年龄、性别、情感和种族)框架(API),用于 Python。它本质上是一个用于识别面部的先进模型的封装器。它可以使用以下预训练模型:
-
VGG-Face(默认)
-
Google FaceNet
-
OpenFace
-
Facebook DeepFace
-
DeepID
-
ArcFace
-
Dlib
-
SFace
简而言之,DeepFace 允许你使用预训练的模型来识别自己的一组面孔,而无需创建和训练自己的模型。
在这篇文章中,我将向你介绍 DeepFace 的一些酷炫功能以及如何将面部识别融入你的项目中。
面部检测/识别的伦理考虑
尽管识别面孔的能力确实很酷,但它确实有很多伦理影响。在你在项目中实施面部识别之前,需要注意几个问题。关注点包括隐私(面部检测可以在未经同意的情况下追踪人们的活动)、偏见(面部识别可能对不同种族、性别或年龄的个人存在偏见)和滥用(捕获的面孔可能会用于其他非法用途或恶意目的)。因此,虽然这篇文章着重于面部识别的技术能力,但在实施之前,你应仔细考虑道德和伦理影响。
这里有一些低风险的项目,可以实现面部识别:
-
考勤追踪 — 你可以在学校或工作场所使用面部识别来进行考勤。
-
个性化 — 使用面部识别来个性化服务。一个很好的例子是娱乐服务,如根据用户的观看历史推荐特定的电视节目。
-
安全性 — 使用面部识别解锁非关键系统,如智能手机和计算机。
然而,在某些应用中使用面部识别具有严重的道德影响。以下是一些示例:
-
执法 — 尽管面部识别对执法有用,但对其不准确性和偏见有一些严重的担忧。
-
监控 — 在一些国家,面部识别技术被用于监控和追踪公民,特别是异议人士。一些公司也使用面部识别来监控员工的生产力,这直接侵犯了他们的隐私。
以下是一些你可以阅读的文章,以了解更多关于面部识别的法律和道德问题:
-
美国的面部识别:隐私问题和法律发展 —
www.asisonline.org/security-management-magazine/monthly-issues/security-technology/archive/2021/december/facial-recognition-in-the-us-privacy-concerns-and-legal-developments/
-
与面部识别软件相关的隐私和安全问题 —
www.techrepublic.com/article/privacy-and-security-issues-associated-with-facial-recognition-software/
-
10 个关注面部识别技术的理由 —
www.privacycompliancehub.com/gdpr-resources/10-reasons-to-be-concerned-about-facial-recognition-technology/
安装 DeepFace
安装 DeepFace 相对简单——如果你使用的是 Mac。对于 Windows,需要做一些额外的工作。
我将为你演示在 macOS 和 Windows 上安装 DeepFace 的步骤。
对于 macOS
对于 macOS,只需在 Jupyter Notebook 中运行以下命令以安装 cmake
、dlib
和 deepface
:
!pip install cmake
!pip install dlib
!pip install deepface
完成后,你就可以开始使用 DeepFace 了。
对于 Windows
对于 Windows,你需要执行一些步骤:
- 从
cmake.org/download/
下载 CMake 并运行安装程序。
确保在安装过程中将路径环境变量添加到其 bin 目录,例如:C:/ProgramFiles/cmake/bin 在安装过程中
-
从
visualstudio.microsoft.com/visual-cpp-build-tools/
下载 Visual Studio Build Tools 并运行安装程序。 -
检查 C++ 构建工具 并点击 安装:
所有图像均由作者提供
上述步骤可能需要一点时间。现在可以去休息一下。
- 最后,在 Jupyter Notebook 中运行以下命令:
!pip install dlib
!pip install deepface
完成后,你就可以开始使用 DeepFace 了。
使用 DeepFace
要使用 DeepFace,首先导入它:
from deepface import DeepFace
我在本文中使用 Jupyter Notebook。
准备图像
我打算下载并准备一些图像,以便我可以用 DeepFace 进行测试。首先,下载以下 基努·里维斯 的图像,并将其保存到与你的 Jupyter Notebook 相同的目录中:
来源:en.wikipedia.org/wiki/Keanu_Reeves#/media/File:Keanu_Reeves_(crop_and_levels)_(cropped).jpg
然后,创建一个名为 image1
的变量,并将图像的文件名赋给它:
image1 = 'Keanu_Reeves_(crop_and_levels)_(cropped).jpg'
同样,对以下图像也进行此操作:
image2 = 'Reunião_com_o_ator_norte-americano_Keanu_Reeves_(46806576944)_(cropped).jpg'
来源:en.wikipedia.org/wiki/Jet_Li#/media/File:Jet_Li_2009_(cropped).jpg
image3 = 'Jet_Li_2009_(cropped).jpg'
来源:en.wikipedia.org/wiki/Denzel_Washington#/media/File:Denzel_Washington_2018.jpg
image4 = 'Denzel_Washington_2018.jpg'
来源: en.wikipedia.org/wiki/Smile#/media/File:Smiling_girl.jpg
image5 = 'Smiling_girl.jpg'
比较面孔
现在,让我们使用 DeepFace 来比较两张图像,以查看它们是否包含同一个人的面孔。为此,你可以使用verify()
函数:
DeepFace.verify(img1_path = image1, # Keanu Reeves
img2_path = image2) # Keanu Reeves
第一次运行 DeepFace 时,它会将预训练模型的权重(默认使用 VGG-Face 模型)下载到你的计算机上:
权重下载完成后,DeepFace 将返回如下结果:
{'verified': True,
'distance': 0.17842618501190277,
'threshold': 0.4,
'model': 'VGG-Face',
'detector_backend': 'opencv',
'similarity_metric': 'cosine',
'facial_areas': {'img1': {'x': 42, 'y': 61, 'w': 144, 'h': 144},
'img2': {'x': 73, 'y': 57, 'w': 103, 'h': 103}},
'time': 0.27}
verified
键的值指示两张面孔是否相同。distance
键的值是两个面部向量之间的距离(值越小,两个面孔越相似)。threshold
键的值设置了distance
值的解释阈值——小于 0.4 的值会使verified
键设置为True
。facial_areas
键包含了两张图像中面孔的位置。
让我们尝试比较另外两张图像——image1
和image3
:
DeepFace.verify(img1_path = image1, # Keanu Reeves
img2_path = image3) # Jet Li
毫不意外,DeepFace 检测到这两个面孔是不同的:
{'verified': False,
'distance': 0.4034869302977028,
'threshold': 0.4,
'model': 'VGG-Face',
'detector_backend': 'opencv',
'similarity_metric': 'cosine',
'facial_areas': {'img1': {'x': 42, 'y': 61, 'w': 144, 'h': 144},
'img2': {'x': 20, 'y': 49, 'w': 164, 'h': 164}},
'time': 0.27}
嵌入
嵌入是表示面部图像的多维向量。DeepFace 允许你使用represent()
函数来查找这些嵌入:
embeddings = DeepFace.represent(img_path = image1) # Keanu Reeves
embeddings
image1
的嵌入如下:
[{'embedding': [0.00948819238692522,
-0.010876820422708988,
-0.0013923903461545706,
0.01534500066190958,
...
-0.014694824814796448,
-0.015208861790597439,
0.005312952678650618,
...],
'facial_area': {'x': 42, 'y': 61, 'w': 144, 'h': 144}}]
根据你使用的预训练模型的不同,embedding
键的值的大小会有所不同。对于 VGG-Face,面部图像被表示为 2622 维的向量。更有趣的是facial_area
键的值,它包含了图像中面部的位置。
以下代码片段显示了基努·里维斯的图像(image1
)以及绘制在检测到的面部区域周围的矩形:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
im = Image.open(image1)
fig, ax = plt.subplots()
# display the image
ax.imshow(im)
# draw a rectangle around the face
face_coord = embeddings[0]['facial_area']
rect = patches.Rectangle((face_coord['x'], face_coord['y']),
face_coord['w'], face_coord['h'],
linewidth = 2,
edgecolor = 'r',
facecolor = 'none')
# add the patch to the axes
ax.add_patch(rect)
从一组面孔中找到一个面孔
DeepFace 最令人兴奋的功能之一是能够将图像与文件夹中的一组图像进行比较。DeepFace 将告诉你哪些图像与提供的图像最为相似。这本质上是面部识别。
为了展示这个功能的强大,我在与 Jupyter Notebook 相同的文件夹中创建了一个名为Faces的新文件夹,并用一些李连杰和基努·里维斯的图像填充了它:
然后我调用find()
函数并传入image1
:
DeepFace.find(img_path = image1, # the image to compare against
db_path = "Faces", # folder containing all the images
enforce_detection = False)[0]
如果
*find()*
函数发现 Faces 文件夹中的图像不包含任何面孔,它将引发异常。为了防止这种情况发生,请将enforce_detection
参数设置为False
。
find()
函数返回一个 Pandas DataFrame 的列表,因此你需要使用 0 索引来获取第一个 DataFrame:
数据框包含与您搜索的图像最相似的图像文件名列表。在这个例子中,最匹配的是KeanuReeves_3.jpg:
数据框根据 VGG-Face_cosine 排序(值越小,匹配度越高)。
请注意,当你第一次调用find()
函数时,将创建一个名为representations_vgg_face.pkl的表示文件,并将其存储在Faces文件夹中。如果你向Faces文件夹中添加新图像,请记得删除representations_vgg_face.pkl文件。
使用find()
函数,你可以在不训练任何模型的情况下进行人脸识别——只需创建一个包含你想识别的面孔的文件夹(每个人应尽可能多地提供图片),然后就完成了,你就拥有了一个人脸识别模型!
分析面孔
除了比较面孔和从面孔文件夹中查找面孔,DeepFace 还允许你分析面孔的各种特征,如年龄、性别、情感和种族:
DeepFace.analyze(img_path = image4,
actions = ["age", "gender", "emotion", "race"])
上述声明分析了丹泽尔·华盛顿(image4
)的图像,并且相当准确地预测了图像中的面孔是:
-
feeling sad
-
一名黑人男性
以下是analyze()
函数返回的结果:
[{'age': 35,
'region': {'x': 31, 'y': 46, 'w': 117, 'h': 117},
'gender': {'Woman': 0.015357557276729494, 'Man': 99.98464584350586},
'dominant_gender': 'Man',
'emotion': {'angry': 0.3038950626725033,
'disgust': 3.667220231060474e-11,
'fear': 2.3939014472247897,
'happy': 1.2440780556642484e-05,
'sad': 87.49081939349405,
'surprise': 6.846103949403675e-05,
'neutral': 9.81130493418037},
'dominant_emotion': 'sad',
'race': {'asian': 7.334453304675418,
'indian': 3.1661530981155095,
'black': 85.50387534522267,
'white': 0.09932484836949994,
'middle eastern': 0.03912873741168454,
'latino hispanic': 3.8570622418559934},
'dominant_race': 'black'}]
让我们尝试分析另一张图片,这次是一个微笑的女孩的图片(image5
):
DeepFace.analyze(img_path = image5,
actions = ["age", "gender", "emotion", "race"])
结果相当准确:
[{'age': 26,
'region': {'x': 377, 'y': 140, 'w': 558, 'h': 558},
'gender': {'Woman': 99.66641068458557, 'Man': 0.3335847519338131},
'dominant_gender': 'Woman',
'emotion': {'angry': 1.31229280062393e-10,
'disgust': 2.1887400676449618e-26,
'fear': 3.8267408134914985e-22,
'happy': 99.99999403953517,
'sad': 2.6514247764822096e-14,
'surprise': 2.245952144279152e-17,
'neutral': 6.961274993922523e-06},
'dominant_emotion': 'happy',
'race': {'asian': 3.0054475969609733,
'indian': 4.536693711482729,
'black': 0.7851633247927815,
'white': 41.17996289491211,
'middle eastern': 23.323961892600284,
'latino hispanic': 27.168768902870287},
'dominant_race': 'white'}]
更改预训练模型
如果你想更改 DeepFace 使用的基础预训练模型,你可以在verify()
函数中使用model_name
参数指定你想使用的模型:
models = ["VGG-Face", "Facenet", "OpenFace", "DeepFace", "ArcFace"]
DeepFace.verify(image1,
image2,
model_name = models[1]) # change to Facenet
现在,所选模型的权重将被下载到你的计算机上。
represent()
函数也是如此:
embeddings = DeepFace.represent(img_path = image1,
model_name = models[1])
以及find()
函数:
DeepFace.find(img_path = image1,
db_path = "Faces",
model_name = models[1],
enforce_detection = False)[0]
如果你喜欢阅读我的文章,并且这些文章对你的职业/学习有所帮助,请考虑成为 Medium 会员。每月费用为 5 美元,这将为你提供对 Medium 上所有文章(包括我的文章)的无限制访问权限。如果你使用以下链接注册,我将获得一小部分佣金(对你没有额外费用)。你的支持意味着我可以花更多时间撰写像这样的文章。
[## 通过我的推荐链接加入 Medium - Wei-Meng Lee
阅读 Wei-Meng Lee 的每一个故事(以及 Medium 上成千上万其他作者的故事)。你的会员费直接支持…
weimenglee.medium.com](https://weimenglee.medium.com/membership?source=post_page-----5f8d1e43f2a6--------------------------------)
总结
使用 DeepFace 使得进行面部识别变得非常简单——无需明确的训练,你只需将面部图像放入一个特定的文件夹,DeepFace 就会完成识别特定面孔的繁重工作。DeepFace 的缺点是返回结果需要一点时间。如果你想使用 DeepFace 进行实时面部识别,它可能不是最佳选择。然而,仍然有很多应用可以使用 DeepFace 构建。如果你在当前项目中使用 DeepFace,请务必与我们分享你是如何使用 DeepFace 的,或你打算如何使用它。
使用 DuckDB 与 Polars
原文:
towardsdatascience.com/using-duckdb-with-polars-e15a865e48a3
了解如何使用 SQL 查询你的 Polars DataFrame
·发表于 Towards Data Science ·阅读时间 6 分钟·2023 年 4 月 14 日
--
Hans-Jurgen Mager 在 Unsplash 上的照片
在我之前关于数据分析的几篇文章中,我讨论了两个在业界逐渐受到关注的重要新兴库:
-
DuckDB — 你可以使用 SQL 语句在内存中查询数据集。
-
Polars — 与久负盛名的 Pandas 库相比,更高效的数据框库。
了解如何使用 SQL 进行数据分析
levelup.gitconnected.com ## 开始使用 Polars DataFrame 库
了解如何使用 Polars DataFrame 库操作表格数据(并替代 Pandas)
towardsdatascience.com
那么,结合这两个库的优势会怎样呢?
实际上,你可以通过 DuckDB 直接使用 SQL 语句查询 Polars 数据框。
那么,使用 SQL 查询 Polars 数据框有什么好处呢?尽管使用起来很方便,但操作 Polars 数据框仍然需要一些练习和相对陡峭的学习曲线。但由于大多数开发者已经熟悉 SQL,直接使用 SQL 操作数据框是否更加方便呢?采用这种方法,开发者可以兼得两全其美:
-
使用所有各种函数查询 Polars 数据框的能力,或者
-
在数据提取时,使用 SQL 进行更自然、更简单的操作
在这篇文章中,我将给你一些如何通过 DuckDB 使用 SQL 查询 Polars 数据框的示例。
安装 Polars 和 DuckDB
对于这篇文章,我正在使用 Jupyter Notebook。确保你已经通过以下命令安装了 Polars 和 DuckDB:
!pip install polars
!pip install duckdb
创建 Polars 数据框
要开始使用,让我们手动创建一个 Polars 数据框:
import polars as pl
df = pl.DataFrame(
{
'Model': ['iPhone X','iPhone XS','iPhone 12',
'iPhone 13','Samsung S11',
'Samsung S12','Mi A1','Mi A2'],
'Sales': [80,170,130,205,400,30,14,8],
'Company': ['Apple','Apple','Apple','Apple',
'Samsung','Samsung','Xiao Mi',
'Xiao Mi'],
})
df
数据框的样子如下:
所有图片由作者提供
比如,你现在想找出所有苹果公司销售量超过 80 的手机。你可以使用 Polars 中的 filter()
函数,如下所示:
df.filter(
(pl.col('Company') == 'Apple') &
(pl.col('Sales') > 80)
)
结果如下所示:
在 Polars 数据框上使用 SQL
现在让我们执行与前一节相同的查询,只不过这次我们将使用 DuckDB 和 SQL 语句。不过首先,让我们选择数据框中的所有行:
import duckdb
result = duckdb.sql('SELECT * FROM df')
result
你可以直接在 SQL 语句中引用
df
数据框。
使用 DuckDB 时,你可以通过 sql()
函数发出 SQL 语句。或者,query()
函数也可以使用:
result = duckdb.query('SELECT * FROM df')
result
变量是一个 duckdb.DuckDBPyRelation
对象。使用这个对象,你可以执行许多不同的任务,例如:
- 计算 销售额 列的平均值:
result.mean('Sales')
- 描述数据框:
result.describe()
- 对数据框中的列应用缩放函数:
result.apply("max", 'Sales,Company')
- 重新排序数据框:
result.order('Sales DESC')
但最简单的方式是直接使用 SQL 查询 Polars 数据框。
例如,如果你想获取所有销售量大于 80 的行,只需使用 sql()
函数和以下 SQL 语句:
duckdb.sql('SELECT * FROM df WHERE Sales >80').pl()
pl()
函数将duckdb.DuckDBPyRelation
对象转换为 Polars 数据框。如果你想将其转换为 Pandas 数据框,请使用df()
函数。
如果你想获取所有型号名称以“iPhone”开头的行,请使用以下 SQL 语句:
duckdb.sql("SELECT * FROM df WHERE Model LIKE 'iPhone%'").pl()
如果你想获取 Apple 和 Xiao Mi 的所有设备,请使用以下 SQL 语句:
duckdb.sql("SELECT * FROM df WHERE Company = 'Apple' OR Company ='Xiao Mi'").pl()
使用 DuckDB 与 Polars 数据框的真正强大之处在于你想从多个数据框中查询时。考虑以下来自 2015 年航班延误 数据集的三个 CSV 文件:
2015 年航班延误数据集 —
www.kaggle.com/datasets/usdot/flight-delays
。许可证 — CC0: 公开领域
-
flights.csv
-
airlines.csv
-
airports.csv
让我们使用 Polars 加载它们:
import polars as pl
df_flights = pl.scan_csv('flights.csv')
df_airlines = pl.scan_csv('airlines.csv')
df_airports = pl.scan_csv('airports.csv')
display(df_flights.collect().head())
display(df_airlines.collect().head())
display(df_airports.collect().head())
上述语句使用了惰性求值来加载这三个 CSV 文件。这确保了在优化所有查询之前,不会对数据框执行任何查询。
collect()
函数强制 Polars 将 CSV 文件加载到数据框中。
这里是 df_flights
、df_airlines
和 df_airports
数据框的样子:
假设你想统计一家航空公司延误的次数,同时显示每家航空公司的名称,这里是你可以使用的 SQL 语句,使用 df_airlines
和 df_flights
数据框:
duckdb.sql('''
SELECT
count(df_airlines.AIRLINE) as Count,
df_airlines.AIRLINE
FROM df_flights, df_airlines
WHERE df_airlines.IATA_CODE = df_flights.AIRLINE AND df_flights.ARRIVAL_DELAY > 0
GROUP BY df_airlines.AIRLINE
ORDER BY COUNT DESC
''')
结果如下:
如果你想统计每个州的机场数量并按降序排序,可以使用以下 SQL 语句:
duckdb.sql('''
SELECT STATE, Count(*) as AIRPORT_COUNT
FROM df_airports
GROUP BY STATE
ORDER BY AIRPORT_COUNT DESC
''')
最后,假设你想知道哪家航空公司有最高的平均延误时间。你可以使用以下 SQL 语句计算各种统计数据,如 最小到达延误、最大到达延误、平均到达延误 和 到达延误的标准差:
duckdb.sql('''
SELECT AIRLINE, MIN(ARRIVAL_DELAY), MAX(ARRIVAL_DELAY),
MEAN(ARRIVAL_DELAY), stddev(ARRIVAL_DELAY)
FROM df_flights
GROUP BY AIRLINE
ORDER BY MEAN(ARRIVAL_DELAY)
''')
根据平均到达延误,我们可以看到 AS 航空公司是延误最短的航空公司(由于值为负,这意味着大多数情况下它会更早到达!),而 NK 航空公司是延误最长的航空公司。想知道 AS 航空公司是什么吗?尝试使用你刚学到的知识!我将把它留作练习,答案在本文末尾。
如果你喜欢阅读我的文章并且它对你的职业/学习有所帮助,请考虑注册成为 Medium 会员。每月仅需 5 美元,你将获得对 Medium 上所有文章(包括我的文章)的无限制访问。如果你通过以下链接注册,我将获得一小部分佣金(对你没有额外费用)。你的支持意味着我可以花更多时间写更多类似的文章。
## 通过我的推荐链接加入 Medium - Wei-Meng Lee
阅读 Wei-Meng Lee 的每个故事(以及 Medium 上其他成千上万的作家的故事)。你的会员费直接支持……
总结
在这篇简短的文章中,我展示了如何将 DuckDB 和 Polars 结合使用来查询你的数据框。利用这两个库,你可以享受最好的两个世界——使用熟悉的查询语言(即 SQL)来查询高效的数据框。快用你自己的数据集试试看,并与我们分享它是如何帮助你的数据分析过程的。
问答题答案:
duckdb.sql("SELECT AIRLINE from df_airlines WHERE IATA_CODE = 'AS'")
使用枚举和 functools 升级你的 Pandas 数据管道
原文:
towardsdatascience.com/using-enums-and-functools-to-upgrade-your-pandas-data-pipelines-d51ca1418fe2
编程
通过两个逐步示例,看看如何更高效地编程以处理你的数据
·发布于Towards Data Science ·12 分钟阅读·2023 年 6 月 9 日
--
图片使用了我才华横溢的姐姐ohmintyartz的许可。
你可能在创建数据管道以处理原始数据时曾使用过 Pandas。编写代码来过滤、分组和执行数据计算只是构建数据管道和 ETL 过程的第一步。
在大规模数据处理时,除了这些,我们还应该编写功能性且易于阅读和维护的代码。
有许多方法可以改进你现有的数据管道,比如添加高效的日志记录、包括数据验证,甚至使用除 Pandas 之外的新库,如 PySpark 和 Polars。
此外,你还可以不同地结构化你用于处理数据的实际代码。这意味着不是为了提高管道的性能,而是专注于编写易于修改和迭代的代码。
在这篇文章中,我们将通过两个简单的例子来看看如何做到这一点,特别是通过使用枚举和functools。
随意在你选择的笔记本或 IDE 中跟随操作!你可以从 Kaggle这里下载数据集,CC0 1.0 通用(CC0 1.0)公共领域献身许可下免费使用。然后导入并运行以下内容,我们可以开始了!
import pandas as pd
from pandas import DataFrame
import numpy as np
from functools import reduce
from enum import Enum
FILE_PATH = "/Updated_sales.csv"
CHUNK_SIZE = 1000
def read_raw_data(file_path: str, chunk_size: int=1000) -> DataFrame:
csv_reader = pd.read_csv(file_path, chunksize=chunk_size)
processed_chunks = []
# append the processed chunk to the list
for chunk in csv_reader:
chunk = chunk.loc[chunk["Order ID"] != "Order ID"].dropna()
processed_chunks.append(chunk)
# concatenate the processed chunks into a single DataFrame
return pd.concat(processed_chunks, axis=0)
df = read_raw_data(file_path=FILE_PATH)
使用枚举更好地结构化你的数据管道
对枚举的快速介绍
你可能首先会想“什么是枚举”?
枚举,简称“enum”,是“一组与唯一值绑定的符号名称(成员)” (Python docs, 2023)。实际来说,这意味着你可以在一个主要的“类”下定义和使用一组相关的变量。
一个简单的例子是有一个枚举类“Color”,其中包含“Red”,“Green”和“Blue”等名称,每当你想引用特定颜色时都可以使用它们。
接下来,你可能会想,如果你可以直接在数据处理管道中调用所需的名称,定义一些变量在单独的枚举类中还有什么意义呢?
枚举有一些关键的好处:
-
定义枚举可以让你将相关的常量组织在一个(或多个)类中,这些类可以作为你在管道中调用的维度、度量和其他常量的真实来源;
-
使用枚举将帮助你避免在数据管道中传递无效值,前提是你正确地定义和维护了枚举类;
-
枚举允许用户使用标准化的数据点和常量集,这在多人基于一个主要数据源进行聚合或创建模型时非常有用(有助于避免在原始数据源中对同一列有多个定义或别名)。
听起来有点抽象,所以让我们看看在标准预处理管道中如何实际应用枚举的例子。
在你的数据处理管道中使用枚举
我们已经有了初始的 DataFrame,所以让我们开始创建一个函数,通过拆分购买地址来向数据中添加更多列。
def split_purchase_address(df_to_process: DataFrame) -> DataFrame:
df_address_split = df_to_process["Purchase Address"].str.split(",", n=3, expand=True)
df_address_split.columns = ["Street Name", "City", "State and Postal Code"]
df_state_postal_split = (
df_address_split["State and Postal Code"]
.str.strip()
.str.split(" ", n=2, expand=True)
)
df_state_postal_split.columns = ["State Code", "Postal Code"]
return pd.concat([df_to_process, df_address_split, df_state_postal_split], axis=1)
接下来,我们可以通过使用 Pandas 的原生pipe
方法将其应用于现有的表,例如,我们在 DataFrame 上调用 pipe 并将函数名作为参数传递。
processed_df = df.pipe(split_purchase_address)
接下来,你会看到我们拥有的数据仍然处于非常详细的层次,订单 ID 是表的主键。当我们想要对数据进行聚合以进行进一步分析时,可以使用 Pandas 中的groupby
方法来实现。
在 Pandas 中,你可能会看到一些代码用于在一组列上对数据进行分组,然后对其中一个维度(在这个例子中我们将使用订单 ID)进行聚合计数,看起来可能是这样的:
# groupby normally
grouped_df = (
processed_df
.groupby(
["Product", "Quantity Ordered", "Street Name", "City", "State Code", "Postal Code"]
)
["Order ID"]
.count()
.reset_index()
.sort_values("Order ID", ascending=False)
.rename({"Order ID": "Count of Order IDs"}, axis=1)
)
结果是一个新的 DataFrame,像这样:
在这个简单的例子中,按六列进行分组并不太困难,我们可以将这些列的列表直接传递给groupby
方法。然而,这有一些缺点:
-
如果我们处理的是一个更大的数据集,并且希望按 20 列进行分组怎么办?
-
如果最终用户提出了新的需求,我们需要调整具体的列进行分组怎么办?
-
如果底层表发生变化,列的名称或别名发生变化怎么办?
我们可以通过在枚举类中定义列来部分解决这些问题。特别是对于这种情况,我们可以在一个新的类SalesGroupByColumns
中定义这些与销售表相关的分组列,如下所示:
class SalesGroupByColumns(Enum):
PRODUCT = "Product"
QUANTITY_ORDERED = "Quantity Ordered"
STREET_NAME = "Street Name"
CITY = "City"
STATE_CODE = "State Code"
POSTAL_CODE = "Postal Code"
我们在这里做的最终只是将列定义为新的 Enum 类中的常量(这是从from enum import Enum
的导入中获得的)。
现在我们已经定义了这些新的枚举值,我们可以像这样访问枚举的各个成员:
SalesGroupByColumns.PRODUCT
SalesGroupByColumns.PRODUCT.value
只需调用枚举名称将返回枚举成员,并且在目标枚举上调用value
让我们直接访问枚举成员的字符串值。现在,要将所有枚举成员放入我们可以传递给groupby
的列表中,我们可以像这样使用列表推导式:
[column.value for column in SalesGroupByColumns]
在将它们放入列表后,我们可以将这个输出分配给一个变量,然后将这个变量传递给我们的groupby
方法,而不是直接传递原始的字符串列表:
# groupby adjusted
groupby_columns = [column.value for column in SalesGroupByColumns]
grouped_df = (
processed_df
.groupby(groupby_columns)
["Order ID"]
.count()
.reset_index()
.sort_values("Order ID", ascending=False)
.rename({"Order ID": "Count of Order IDs"}, axis=1)
)
grouped_df.head()
我们得到了与之前相同的表格,但代码看起来稍微更干净。对于长期维护来说,这种做法的好处可以在你处理管道时体现出来。
例如,你可能会发现你想添加几个新的列,比如说如果你还想进行更多的特征工程,创建一个房屋号码和产品类别列,然后添加到分组中。你可以这样更新你的枚举类:
# what's the benefit? adding new columns!
class SalesGroupByColumns(Enum):
PRODUCT = "Product"
QUANTITY_ORDERED = "Quantity Ordered"
STREET_NAME = "Street Name"
CITY = "City"
STATE_CODE = "State Code"
POSTAL_CODE = "Postal Code"
HOUSE_NUMBER = "House Number"
PRODUCT_CATEGORY = "Prouct Category"
# then you run the code same as before and it would still work
这样,你就不需要修改现有的处理代码,因为列表推导式会自动获取SalesGroupByColumns
类中的所有值,并将其应用于你的聚合逻辑。
一个好的提示是,所有这些仅在你确切知道你在枚举类中定义了什么并仅按预期使用它们时才会有效。如果你在这里进行更改,并且你在几个不同的表中抓取所有这些列进行分组,重要的是要确保那是你打算做的。
否则,你可以在单独的类中定义你需要用于特定表的枚举集,或者如果有意义的话,在单独的列列表中(这样你仍然避免将原始字符串列表传递给groupby
方法)。
在 Pandas 中使用枚举进行数据聚合
另一个例子是,假设我们有一个不同的情况,我们通过更改一些列的数据类型和创建一个新的总成本列来对数据进行一些额外的转换。我们可以像这样将其添加到之前的管道中:
def convert_numerical_column_types(df_to_process: DataFrame) -> DataFrame:
df_to_process["Quantity Ordered"] = df_to_process["Quantity Ordered"].astype(int)
df_to_process["Price Each"] = df_to_process["Price Each"].astype(float)
df_to_process["Order ID"] = df_to_process["Order ID"].astype(int)
return df_to_process
def calculate_total_order_cost(df_to_process: DataFrame) -> DataFrame:
df_to_process["Total Cost"] = df_to_process["Quantity Ordered"] * df_to_process["Price Each"]
return df_to_process
processed_df = (
df
.pipe(split_purchase_address)
.pipe(convert_numerical_column_types)
.pipe(calculate_total_order_cost)
)
现在我们的 DataFrame 在订单 ID 级别上已经转换过来,接下来我们在一组新列上进行另一个分组,但这次对几个不同的度量进行聚合:
# let's say we have a file now "SalesColumns.py"
# we can add to it
import numpy as np
class AddressColumns(Enum):
STREET_NAME = "Street Name"
CITY = "City"
STATE_CODE = "State Code"
POSTAL_CODE = "Postal Code"
class SalesMeasureColumns(Enum):
TOTAL_COST = "Total Cost"
QUANTITY_ORDERED = "Quantity Ordered"
# then separately we can do the groupby
groupby_columns = [column.value for column in AddressColumns]
grouped_df = (
processed_df
.groupby(groupby_columns)
.agg(
Total_Cost=(SalesMeasureColumns.TOTAL_COST.value, np.sum),
Total_Quantity_Ordered=(SalesMeasureColumns.QUANTITY_ORDERED.value, np.sum)
)
.reset_index()
.sort_values("Total_Cost", ascending=False)
)
这里有几个关键点需要注意:
-
我们定义了一组新的枚举类:
AddressColumns
和SalesMeasureColumns
。现在,对于一个不同的表格,如果我们想要专门按地址字段进行分组,我们可以定义groupby_columns
列表来包括这些列,然后将其传递给转换后的 DataFrame 的groupby
方法。 -
SalesMeasureColumns
类包括我们希望进行聚合的度量。将原始表格中的列名放入该类中意味着如果其他人也希望汇总成本和订购数量,他们会调用正确的列。
我们可以进一步在之前的管道中添加之前定义的链式管道和函数,并将此代码放入新函数中以收集列列表并聚合表格。这样,最终代码变得更容易阅读,并且随着时间的推移,可能更容易调试和记录。
对于聚合,总成本和订购数量在不同的表格、团队和最终用户中可能会有所不同。在 SalesMeasuresColumns
的枚举中定义它,意味着对于 Sales 表格和度量,所有用户都可以使用相同的定义对这些列进行聚合。
利用 functools 的 reduce 方法来简化你的数据过滤
现在,让我们来看看如何使用 functools 的 reduce 方法来改进数据过滤。
在 Pandas 中,过滤原始数据的常见方法是使用 loc
方法。例如,让我们写一些代码来过滤包含“North”的街道数据。
grouped_df.loc[grouped_df["Street Name"].str.contains("North")]
我们的输出 DataFrame 现在只包括根据过滤器筛选的列。一般来说,在对特定问题进行分析或调整数据集以适应机器学习模型时,你可能希望对 DataFrame 应用多个过滤器。你可以这样使用 loc
方法应用多个过滤器:
grouped_df.loc[
(grouped_df["Street Name"].str.contains("North")) &
(grouped_df["Postal Code"].str.contains("940")) &
(grouped_df["Total_Cost"] < 1000)
]
接下来,就像之前优化代码时使用枚举一样,如果我们想要添加、编辑或删除一些过滤器会怎么样?来自最终用户的需求变化或从一些探索性数据分析中获得的新见解可能意味着你需要随着时间的推移调整过滤数据的方式。
与其继续在 loc
方法中添加更多的代码行,我们可以在一个变量中定义过滤器列表,然后稍后将其传递给 loc
。
filter_conditions = [
grouped_df["Street Name"].str.contains("North"),
grouped_df["Postal Code"].str.contains("940"),
grouped_df["Total_Cost"] < 1000
]
然而,我们不能直接将这个列表传递给 loc
。loc
中的多个条件接受的格式是由 DataFrame 的布尔掩码组成,通过 &
运算符分隔。
# doesn't work -> you can't just pass a list into loc
df.loc[FILTER_CONDITIONS]
# format should look like this
df.loc[condition_1 & condition_2 & condition_3]
这就是 functools.reduce
方法发挥作用的地方。如何实现正确的格式可以在这里的示例中看到:
# functools reduce
reduce(lambda x, y: f"{x} & {y}", ["condition_1", "condition_2", "condition_3"])
在这个简单的示例中,我们只是将一堆字符串组合在一起,以输出稍后在 loc
方法中需要写的内容。
functools 中的reduce
方法允许你将函数和可迭代对象作为参数传递。reduce
方法然后将函数应用于可迭代对象中的元素,逐步执行。这意味着它会对可迭代对象中的元素集合进行顺序操作。
在我们的例子中,我们有一个匿名lambda
函数,它接受两个参数x
和y
,并使用 f-string 将它们结合起来,用&
操作符分隔。这意味着首先函数会输出condition_1 & condition_2
,然后将condition_3
添加到第一个累计值中,结果就是condition_1 & condition_2 & condition_3
。
现在,我们可以将这个 reduce 方法与函数和过滤条件传递给loc
方法,而不是逐个将所有原始过滤条件传递给loc
。
grouped_df.loc[reduce(lambda x, y: x & y, filter_conditions)]
应用所有过滤器后的最终 DataFrame 如下所示。如果你想要添加另一个过滤器,只需将其添加到现有的filter_conditions
变量中,其他代码保持不变。
这篇文章的重点是展示几种不同的代码结构方式,以提高代码的可读性和可维护性。虽然这些改变不一定会加快你的管道速度或改善内存使用,但考虑代码的易用性是很重要的。
这有两个好处:
-
如果多人同时处理相同的原始数据,将列名和过滤器定义在一个集中位置意味着每个人都可以参考一个唯一的真实来源,避免出现不同的名称和逻辑却最终指代同一事物的情况;
-
当处理包含变化需求的数据时,你只需在一个地方修复列名和逻辑即可加快开发时间。
希望你觉得这些改善数据管道的技巧有帮助。尝试一下这种数据处理结构,看看它是否适合你!
如果你喜欢我的内容,可以考虑关注我并通过下面的链接注册成为 Medium 会员。每月只需 5 美元,你将获得对 Medium 上所有内容的无限制访问。通过我的链接注册可以让我赚取少量佣金。如果你已经注册关注我,感谢你的支持!
[## 通过我的推荐链接加入 Medium — Byron Dolon
作为 Medium 会员,你的一部分会员费用会支付给你阅读的作者,而你可以完全访问每个故事……
byrondolon.medium.com](https://byrondolon.medium.com/membership?source=post_page-----d51ca1418fe2--------------------------------)
更多内容: - 3 种高效的方式通过子字符串过滤 Pandas DataFrame 列
- 5 条实用建议给有志于数据分析的新人 - 使用 Python 中的堆叠条形图改进数据可视化 - C**onditional Selection and Assignment With .loc in Pandas - 5(以及半个)行代码用以理解 Pandas 中的数据
使用从 BERT 嵌入中衍生的向量表示的傅里叶变换进行语义相似度评估
照片由 Igor Shabalin 提供,已获许可
探索通过评估 BERT 嵌入的不同表示来了解句子中词语的相互影响
·
关注 发布于 Towards Data Science ·5 min read·2023 年 1 月 14 日
--
BERT 嵌入在程序化提取文本意义方面提供了极大的机会。似乎我们(以及机器)理解文本所需的一切都隐藏在这些数字中。关键在于如何正确地操作这些数字。我在我最近的帖子中讨论了这一概念 发现 BERT 嵌入在不同层次中的趋势,用于语义上下文确定任务。本文继续讨论这一主题,探讨通过对 BERT 嵌入应用傅里叶变换得到的向量表示是否在 NLP 任务中的意义提取方面也有用。
假设
正如你可能从物理学中知道的那样,傅里叶变换使我们能够理解信号中的频率。表示单词嵌入的向量是否看起来像信号?也就是说,从傅里叶变换得出的频域知识在处理 BERT 嵌入以发现语义接近性时是否有用?简单来说,傅里叶变换在分析 BERT 嵌入时是否有意义?让我们来检查一下。
实验
本文剩余部分讨论的例子假设你已经按照 之前的帖子 中的示例定义了模型。我们还需要从该 之前的帖子 中讨论的样本句子中获得的表示。请注意,本文使用了一组与上述帖子中使用的样本稍有不同的样本。因此,在创建表示之前,你需要定义以下样本:
sents = []
sents.append(‘I had a very good apple.’)
sents.append(‘I had a very good orange.’)
sents.append(‘I had a very good adventure.’)
正如你所看到的,所有句子都有相同的直接宾语修饰词集合,而直接宾语在每个句子中都是不同的。我们实验的目的是检查直接宾语的语义接近性如何影响修饰词的接近程度。
之前的帖子中已经介绍了对样本句子进行分词和生成隐藏状态的过程。因此我们在这里不会重复这个过程。
让我们专注于每个句子中直接宾语(“apple/orange/adventure”)的修饰词(“a/very/good”)。为了在获取其嵌入时使用正确的索引,我们首先检查一下句子中的词元:
for i in range(len(sents)):
print(tokenizer.convert_ids_to_tokens(tokenized_text[i]))
[‘[CLS]’, ‘i’, ‘had’, ‘a’, ‘very’, ‘good’, ‘apple’, ‘.’, ‘[SEP]’]
[‘[CLS]’, ‘i’, ‘had’, ‘a’, ‘very’, ‘good’, ‘orange’, ‘.’, ‘[SEP]’]
[‘[CLS]’, ‘i’, ‘had’, ‘a’, ‘very’, ‘good’, ‘adventure’, ‘.’, ‘[SEP]’]
由于我们对直接宾语的修饰词(在这个具体的例子中,我们有三个修饰词:a, very, good)感兴趣,我们需要索引为 3,4,5 的词元。因此,我们需要将索引偏移 3。以下,我们将为每个修饰词在每个句子中获取上下文嵌入。我们将把修饰词嵌入保存在为每个句子定义的列表中:
l12_1 = []
l12_2 = []
l12_3 = []
for i in range(3):
l12_1.append(hidden_states[0][12][0][i+3][:10].numpy())
l12_2.append(hidden_states[1][12][0][i+3][:10].numpy())
l12_3.append(hidden_states[2][12][0][i+3][:10].numpy())
现在让我们来探讨不同句子中直接宾语的语义相似性如何影响各自修饰语的相似性。
from scipy import spatial
for i in range(3):
print(1 — spatial.distance.cosine(l12_1[i], l12_2[i]))
0.9003266096115112
0.9178041219711304
0.8865049481391907
在上述输出中,我们可以看到“Apple”和“Orange”修饰语的上下文嵌入显示了高度的相似性。这是可以理解的,因为直接宾语“Apple”和“Orange”本身非常接近。
虽然从以下结果来看,“Apple”和“Adventure”修饰语的表示并不太接近:
for i in range(3):
print(1 — spatial.distance.cosine(l12_1[i], l12_3[i]))
0.49141737818717957
0.7987119555473328
0.6404531598091125
“Orange”和“Adventure”修饰语也不应该太接近:
for i in range(3):
print(1 — spatial.distance.cosine(l12_2[i], l12_3[i]))
0.7402883768081665
0.8417230844497681
0.7215733528137207
现在让我们从 BERT 提供的嵌入中推导出更复杂的表示。首先,让我们获取每个句子中修饰语的初始嵌入:
l0_1 = []
l0_2 = []
l0_3 = []
for i in range(3):
l0_1.append(hidden_states[0][0][0][i+3][:10].numpy())
l0_2.append(hidden_states[1][0][0][i+3][:10].numpy())
l0_3.append(hidden_states[2][0][0][i+3][:10].numpy())
现在,我们可以,例如,通过将上下文嵌入(在第 12 层编码器中生成)除以相应的初始嵌入(如前面帖子中讨论的),得到一些新的嵌入表示,用于进一步分析。
import numpy as np
l0_12_1 = []
l0_12_2 = []
l0_12_3 = []
for i in range(3):
l0_12_1.append(np.log(l12_1[i]/l0_1[i]))
l0_12_2.append(np.log(l12_2[i]/l0_2[i]))
l0_12_3.append(np.log(l12_3[i]/l0_3[i]))
for i in range(3):
l0_12_1[i] = np.where(np.isnan(l0_12_1[i]), 0, l0_12_1[i])
l0_12_2[i] = np.where(np.isnan(l0_12_2[i]), 0, l0_12_2[i])
l0_12_3[i] = np.where(np.isnan(l0_12_3[i]), 0, l0_12_3[i])
为了分析目的,你可能还想创建另一组表示,计算上下文嵌入与非上下文(初始)嵌入之间的逐元素差异。
_l0_12_1 = []
_l0_12_2 = []
_l0_12_3 = []
for i in range(3):
_l0_12_1.append(l12_1[i]-l0_1[i])
_l0_12_2.append(l12_2[i]-l0_2[i])
_l0_12_3.append(l12_3[i]-l0_3[i])
在继续评估我们刚创建的表示之前,让我们使用傅里叶变换创建另一组表示,以便随后可以比较不同方法获得的所有表示。
fourierTransform_1=[]
fourierTransform_2=[]
fourierTransform_3=[]
for i in range(3):
fourierTransform_1.append(np.fft.fft(l0_12_1[i])/len(l0_12_1[i]))
fourierTransform_2.append(np.fft.fft(l0_12_2[i])/len(l0_12_2[i]))
fourierTransform_3.append(np.fft.fft(l0_12_3[i])/len(l0_12_3[i]))
现在我们可以比较每个修饰语的每对句子的获得表示。
print (sents[0])
print (sents[1])
print()
for i in range(3):
print(tokenizer.convert_ids_to_tokens(tokenized_text[0][i+3]))
print(‘diff’, 1 — spatial.distance.cosine(_l0_12_1[i], _l0_12_2[i]))
print(‘log_quotient’, 1 — spatial.distance.cosine(l0_12_1[i], l0_12_2[i]))
print(‘fourier’, 1 — spatial.distance.cosine(abs(fourierTransform_1[i]), abs(fourierTransform_2[i])))
print()
生成的输出应如下所示:
I had a very good apple.
I had a very good orange.
a
diff 0.8866338729858398
log_quotient 0.43184104561805725
fourier 0.9438706822278501
very
diff 0.9572229385375977
log_quotient 0.9539480209350586
fourier 0.9754009221726183
good
diff 0.8211167454719543
log_quotient 0.5680340528488159
fourier 0.7838190546462953
在上述实验中,我们期望看到前两个句子中的相同修饰语之间具有高度的相似性。实际上,我们可以看到差异和傅里叶变换方法在任务中表现良好。
以下实验的目的是确定那些被修饰的名词不太接近的修饰语的相似性。
print (sents[0])
print (sents[2])
print()
for i in range(3):
print(tokenizer.convert_ids_to_tokens(tokenized_text[0][i+3]))
print(‘diff’, 1 — spatial.distance.cosine(_l0_12_1[i], _l0_12_3[i]))
print(‘log_quotient’, 1 — spatial.distance.cosine(l0_12_1[i], l0_12_3[i]))
print(‘fourier’, 1 — spatial.distance.cosine(abs(fourierTransform_1[i]), abs(fourierTransform_3[i])))
print()
这是输出结果:
I had a very good apple.
I had a very good adventure.
a
diff 0.5641788840293884
log_quotient 0.5351020097732544
fourier 0.8501702469740261
very
diff 0.8958494067192078
log_quotient 0.5876994729042053
fourier 0.8582797441535993
good
diff 0.6836684346199036
log_quotient 0.18607155978679657
fourier 0.8857107252606878
上述输出显示,当评估那些相关名词不是非常接近的修饰语的相似性时,差异和对数商的效果最好。
print (sents[1])
print (sents[2])
print()
for i in range(3):
print(tokenizer.convert_ids_to_tokens(tokenized_text[0][i+3]))
print(‘diff’, 1 — spatial.distance.cosine(_l0_12_2[i], _l0_12_3[i]))
print(‘log_quotient’, 1 — spatial.distance.cosine(l0_12_2[i], l0_12_3[i]))
print(‘fourier’, 1 — spatial.distance.cosine(abs(fourierTransform_2[i]), abs(fourierTransform_3[i])))
print()
输出结果如下:
I had a very good orange.
I had a very good adventure.
a
diff 0.8232558369636536
log_quotient 0.7186723351478577
fourier 0.8378725099204362
very
diff 0.9369465708732605
log_quotient 0.6996179223060608
fourier 0.9164374584436726
good
diff 0.8077239990234375
log_quotient 0.5284199714660645
fourier 0.9069805698881434
我们再次看到,当评估那些相关名词不是非常接近的修饰语的相似性时,差异和对数商的效果最好。
结论
在分析 BERT 嵌入时,傅里叶变换是否有意义?根据本文所做的实验,我们可以得出结论,这种方法可以有效地与其他方法结合使用。
使用 GPT-3.5-Turbo 和 GPT-4 进行人道主义数据类别预测
·
关注 发表在数据科学之道 · 23 分钟阅读 · 2023 年 3 月 29 日
--
图像由稳定扩散创建,提示词为“预测猫”。
总结
在本文中,我探讨了使用 GPT-3.5-Turbo 和 GPT-4 对数据集进行分类,而不需要标记数据或模型训练,通过向模型提供数据摘录和类别定义。使用从令人惊叹的人道主义数据交换(HDX)找到的一小部分已分类的“数据网格”数据集,GPT-4 的零-shot 提示在预测类别时达到了 96%的准确率,而在预测类别和子类别时达到了 89%的准确率。GPT-4 在相同提示下的表现优于 GPT-3.5-Turbo,类别准确率为 96%对 66%。尤其有用的是,模型能够提供其预测的推理,这有助于识别改进过程。这只是由于成本限制而涉及少量记录的快速分析,但它显示了使用大型语言模型进行数据质量检查和总结的一些前景。由于提示中允许的最大令牌数量影响数据摘录中可以包含的数据量,以及性能和成本挑战——特别是如果你是一个小型非营利组织!——在商业生成 AI 的早期阶段存在局限性。
人道主义数据交换(HDX)平台有一个很棒的功能叫做HDX 数据网格,它提供了按国家划分的六个关键危机类别的高质量数据覆盖概述,查看这里了解乍得的例子。进入网格的数据集会经过 HDX 团队一系列严格的测试以确定覆盖范围和质量,其中第一个测试是确定数据集是否在批准的类别中。
我在想,也许大型语言模型(LLMs)可能是一个有效的方法来应用数据质量和分类规则,在那些可能没有标记训练数据的情况下。这也很方便,以人类可读的文本形式提供规则,非技术团队可以轻松维护,并直接使用这些规则以消除对特征工程和模型管理的需求。
哦,我最近也获得了 GPT-4 的早期访问权限,想要试一试!🙂……所以决定也进行一些分析,比较 GPT-3.5-Turbo 的表现。
数据集是否在已批准的类别中?
查看《2023 年人道主义数据现状 附录 B》,其中概述了在评估数据是否具有足够质量和覆盖范围时使用的标准和类别……
确定数据集是否应包含在数据网格中的第一步是检查数据集是否符合附录 A 中定义的主题要求。不相关的数据集将被自动排除。
附录 A 中的类别是……
HDX 数据网格数据集的接受数据类别(见 HDX 年度报告,附录 A [1])
我们可以编写分类器将这些类别分配给我们的数据集,但我们只知道已批准的 HDX 数据网格数据集的子集的类别。如果仅通过提示就能对我们的数据进行分类而不需要手动标记,那将是太棒了。这是一个零-shot 任务[2],这是大语言模型的一个惊人特性,即可以在没有专门为任务训练或提供示例的情况下进行分类。
为单个表预测数据集类别
让我们读取类别数据,并使用它生成定义每一类的提示文本……
hdx_data_categories_file = './data/Data Completeness Definitions - version_1.csv'
dg_categories = pd.read_csv(hdx_data_categories_file)
dg_categories = dg_categories[
["Category", "Subcategory", "Definition", "Datagrid recipe category"]
]
dg_categories["prompt_text"] = dg_categories.apply(
lambda x: f"- Category '{x['Category']} : {x['Subcategory']}' is defined as: {x['Definition']}",
axis=1,
)
category_prompt_text = dg_categories["prompt_text"].to_string(index=False, header=False)
display(category_prompt_text)
这给出……
- Category \'Affected People : Internally Displaced Persons\' is defined as: Tabular data of the number of displaced people by location. Locations can be administrative divisions or other locations (such as camps) if an additional dataset defining those locations is also available.\n
- Category \'Affected People : Refugees and Persons of Concern\' is defined as: Tabular data of the number of refugees and persons of concern either in the country or originating from the country disaggregated by their current location. Locations can be administrative divisions or other locations (such as camps) if an additional dataset defining those locations is also available or if the locations\' coordinates are defined in the tabular data.\n
- Category \'Affected People : Returnees\' is defined as: Tabular data of the number of displaced people who have returned.\n
- Category \'Affected People : Humanitarian Needs\' is defined as: Tabular data of the number of people in need of humanitarian assistance by location and humanitarian cluster/sector.\n
- Category \'Coordination & Context : 3W - Who is doing what where\' is defined as: List of organisations working on humanitarian issues, by humanitarian cluster/sector and disaggregated by administrative division.\n
- Category \'Coordination & Context : 3W - Who is doing what where\' is defined as: \n
- Category \'Coordination & Context : 3W - Who is doing what where\' is defined as: Note: An exception for the subnational rule is made for the IATI dataset which, if available, should always be included as an incomplete dataset.\n
- Category \'Coordination & Context : Funding\' is defined as: Tabular data listing the amount of funding provided by humanitarian cluster/sector.\n
- Category \'Coordination & Context : Funding\' is defined as: \n
- Category \'Coordination & Context : Funding\' is defined as: Note: An exception for the subnational rule is made for the FTS dataset which, if available, should always be included as a complete dataset.\n
- Category \'Coordination & Context : Conflict Events\' is defined as: Vector data or tabular data with coordinates describing the location, date, and type of conflict event.\n
- Category \'Coordination & Context : Humanitarian Access\' is defined as: Tabular or vector data describing the location of natural hazards, permissions, active fighting, or other access constraints that impact the delivery of humanitarian interventions.\n
- Category \'Coordination & Context : Climate Impact\' is defined as: Tabular or vector data containing current and historical impacts of climate events relating to floods, droughts and storms. The data should specify the location of the event, date of the event, and contain at least one indicator of impact such as spatial extent of event, disruption to affected populations, destroyed infrastructure, and/or affected vegetation.\n
- Category \'Food Security & Nutrition : Food Security\' is defined as: Vector data representing the IPC/CH acute food insecurity phase classification or tabular data representing population or percentage of population by IPC/CH phase and administrative division.\n
- Category \'Food Security & Nutrition : Acute Malnutrition\' is defined as: Tabular data specifying the global acute malnutrition (GAM) or severe acute malnutrition (SAM) rate\xa0 by administrative division.\n
- Category \'Food Security & Nutrition : Food Prices\' is defined as: Time series prices for common food commodities at a set of locations.\n- Category \'Geography & Infrastructure : Administrative Divisions\' is defined as: Vector geographic data describing the sub-national administrative divisions of a location, usually a country, including the names and unique identifiers, usually p-codes, of each administrative division. To be considered "complete", and included here, the humanitarian community working in the location has to have endorsed a preferred set of administrative boundaries as the Common Operational Dataset (COD).\n
- Category \'Geography & Infrastructure : Populated Places\' is defined as: Vector data or tabular data with coordinates representing the location of populated places (cities, towns, villages).\n
- Category \'Geography & Infrastructure : Roads\' is defined as: Geographic data describing the location of roads with some indication of the importance of each road segment in the transportation network. The data should exclude or indicate roads that are not usable by typical four-wheel-drive vehicles (footpaths, etc.).\n
- Category \'Geography & Infrastructure : Airports\' is defined as: Geographic data representing all operational airports including a name or other unique identifier and an indication of what types of aircraft can use each.\n
- Category \'Health & Education : Health Facilities\' is defined as: Vector data or tabular data with coordinates representing health facilities with some indication of the type of facility (clinic, hospital, etc.).\n
- Category \'Health & Education : Education Facilities\' is defined as: Vector data or tabular data with coordinates representing education facilities with some indication of the type of facility (school, university, etc.).\n
- Category \'Population & Socio-economy : Baseline Population\' is defined as: Total population disaggregated age and sex categories, aggregated by administrative division.\n
- Category \'Population & Socio-economy : Poverty Rate\' is defined as: Population living under a defined poverty threshold, aggregated by administrative division and represented as a percentage of total population or as an absolute number.'
这里有一个与农业相关的测试文件,这是一个不受支持的类别,并且不出现在 HDX 的数据网格中……
filename = "./data/number-of-acreage-under-irrigation.xlsx"
df = pd.read_excel(filename, sheet_name="Sheet1")
df = df.fillna("")
display(df)
一个数据集表的摘录,该表不属于 HDX 支持的类别之一
在上述内容中,我故意避免了对表格进行解析以整理内容(有关更多信息,请参见这里)。相反,我们将原始表格扔给 GPT,看看它的表现如何。
作为提示的 CSV 字符串表示,表格如下所示……
csv_as_str = df[0:20].to_csv(index=False)
print(csv_as_str)
Unnamed: 0,Unnamed: 1,Unnamed: 2,Unnamed: 3,Unnamed: 4,Unnamed: 5,Unnamed: 6,Unnamed: 7,Unnamed: 8,Unnamed: 9,Unnamed: 10,Unnamed: 11
Table 3: Number of acreage under irrigation,,,,,,,,,,,
,,OVERALL,,Sub county,,,,,,,
,,,,Chepalungu,,,,Bomet Central,,,
,,,,Male,,Female,,Male,,Female,
,,N,%,N,%,N,%,N,%,N,%
What is the average size of land you own that is currently under irrigation?,0 - 2 acres,22,2.8%,4,2.2%,10,3.8%,3,1.7%,5,2.9%
,2 - 5 acres,6,.8%,2,1.1%,2,.8%,0,0.0%,2,1.2%
,5 - 10 acres,1,.1%,0,0.0%,0,0.0%,0,0.0%,1,.6%
,More than 10 acres,0,0.0%,0,0.0%,0,0.0%,0,0.0%,0,0.0%
,None,760,96.3%,176,96.7%,251,95.4%,170,98.3%,163,95.3%
,Total,789,100.0%,182,100.0%,263,100.0%,173,100.0%,171,100.0%
对于提示,我们将类别定义合并为一个聊天提示,并将一些指令和正在分析的表格合并为第二个……
prompts = []
prompts.append(f"Here is a list of HDX data categories with their definition: \n\n {category_prompt_text} \n\n")
prompts.append(f"Does the following table from file {filename} fall into one of the categories provided, if not say no. "\
f"If it does, which category and explain why? \n\n {csv_as_str} \n\n")
所以提示 1……
Here is a list of HDX data categories with their definition:
- Category \'Affected People : Internally Displaced Persons\' is defined as: Tabular data of the number of displaced people by location. Locations can be administrative divisions or other locations (such as camps) if an additional dataset defining those locations is also available.\n
- Category \'Affected People : Refugees and Persons of Concern\' is defined as: Tabular data of the number of refugees and persons of concern either in the country or originating from the country disaggregated by their current location. Locations can be administrative divisions or other locations (such as camps) if an additional dataset defining those locations is also available or if the locations\' coordinates are defined in the tabular data.\n
- Category \'Affected People : Returnees\' is defined as: Tabular data of the number of displaced people who have returned.\n
- Category \'Affected People : Humanitarian Needs\' is defined as: Tabular data of the number of people in need of humanitarian assistance by location and humanitarian cluster/sector.\n
- Category \'Coordination & Context : 3W - Who is doing what where\' is defined as: List of organisations working on humanitarian issues, by humanitarian cluster/sector and disaggregated by administrative division.\n
- Category \'Coordination & Context : 3W - Who is doing what where\' is defined as: \n
- Category \'Coordination & Context : 3W - Who is doing what where\' is defined as: Note: An exception for the subnational rule is made for the IATI dataset which, if available, should always be included as an incomplete dataset.\n
- Category \'Coordination & Context : Funding\' is defined as: Tabular data listing the amount of funding provided by humanitarian cluster/sector.\n
- Category \'Coordination & Context : Funding\' is defined as: \n
- Category \'Coordination & Context : Funding\' is defined as: Note: An exception for the subnational rule is made for the FTS dataset which, if available, should always be included as a complete dataset.\n
- Category \'Coordination & Context : Conflict Events\' is defined as: Vector data or tabular data with coordinates describing the location, date, and type of conflict event.\n
- Category \'Coordination & Context : Humanitarian Access\' is defined as: Tabular or vector data describing the location of natural hazards, permissions, active fighting, or other access constraints that impact the delivery of humanitarian interventions.\n
- Category \'Coordination & Context : Climate Impact\' is defined as: Tabular or vector data containing current and historical impacts of climate events relating to floods, droughts and storms. The data should specify the location of the event, date of the event, and contain at least one indicator of impact such as spatial extent of event, disruption to affected populations, destroyed infrastructure, and/or affected vegetation.\n
- Category \'Food Security & Nutrition : Food Security\' is defined as: Vector data representing the IPC/CH acute food insecurity phase classification or tabular data representing population or percentage of population by IPC/CH phase and administrative division.\n
- Category \'Food Security & Nutrition : Acute Malnutrition\' is defined as: Tabular data specifying the global acute malnutrition (GAM) or severe acute malnutrition (SAM) rate\xa0 by administrative division.\n
- Category \'Food Security & Nutrition : Food Prices\' is defined as: Time series prices for common food commodities at a set of locations.\n- Category \'Geography & Infrastructure : Administrative Divisions\' is defined as: Vector geographic data describing the sub-national administrative divisions of a location, usually a country, including the names and unique identifiers, usually p-codes, of each administrative division. To be considered "complete", and included here, the humanitarian community working in the location has to have endorsed a preferred set of administrative boundaries as the Common Operational Dataset (COD).\n
- Category \'Geography & Infrastructure : Populated Places\' is defined as: Vector data or tabular data with coordinates representing the location of populated places (cities, towns, villages).\n
- Category \'Geography & Infrastructure : Roads\' is defined as: Geographic data describing the location of roads with some indication of the importance of each road segment in the transportation network. The data should exclude or indicate roads that are not usable by typical four-wheel-drive vehicles (footpaths, etc.).\n
- Category \'Geography & Infrastructure : Airports\' is defined as: Geographic data representing all operational airports including a name or other unique identifier and an indication of what types of aircraft can use each.\n
- Category \'Health & Education : Health Facilities\' is defined as: Vector data or tabular data with coordinates representing health facilities with some indication of the type of facility (clinic, hospital, etc.).\n
- Category \'Health & Education : Education Facilities\' is defined as: Vector data or tabular data with coordinates representing education facilities with some indication of the type of facility (school, university, etc.).\n
- Category \'Population & Socio-economy : Baseline Population\' is defined as: Total population disaggregated age and sex categories, aggregated by administrative division.\n
- Category \'Population & Socio-economy : Poverty Rate\' is defined as: Population living under a defined poverty threshold, aggregated by administrative division and represented as a percentage of total population or as an absolute number.'
然后提示 2……
Does the following table from file ./data/number-of-acreage-under-irrigation.xlsx fall into one of the categories provided, if not say no. If it does, which category and explain why?
Unnamed: 0,Unnamed: 1,Unnamed: 2,Unnamed: 3,Unnamed: 4,Unnamed: 5,Unnamed: 6,Unnamed: 7,Unnamed: 8,Unnamed: 9,Unnamed: 10,Unnamed: 11\nTable 3: Number of acreage under irrigation,,,,,,,,,,,
,,OVERALL,,Sub county,,,,,,,
,,,,Chepalungu,,,,Bomet Central,,,
,,,,Male,,Female,,Male,,Female,
,,N,%,N,%,N,%,N,%,N,%
What is the average size of land you own that is currently under irrigation?,0 - 2 acres,22,2.8%,4,2.2%,10,3.8%,3,1.7%,5,2.9%
,2 - 5 acres,6,.8%,2,1.1%,2,.8%,0,0.0%,2,1.2%
,5 - 10 acres,1,.1%,0,0.0%,0,0.0%,0,0.0%,1,.6%
,More than 10 acres,0,0.0%,0,0.0%,0,0.0%,0,0.0%,0,0.0%
,None,760,96.3%,176,96.7%,251,95.4%,170,98.3%,163,95.3%
,Total,789,100.0%,182,100.0%,263,100.0%,173,100.0%,171,100.0%
让我们尝试使用GPT-3.5-turbo和GPT-4……
def prompt_model(prompts, temperature=0.0, model="gpt-4"):
messages = [{"role": "system", "content": "You are a helpful assistant."}]
for prompt in prompts:
messages.append({"role": "user", "content": prompt})
response = ai.ChatCompletion.create(
model=model, temperature=temperature, messages=messages
)
return response["choices"][0]["message"]["content"]
prompts = []
prompts.append(
f"Here is a list of HDX data categories with their definition: \n\n {category_prompt_text} \n\n"
)
prompts.append(
f"Does the following table from file {filename} fall into one of the categories provided, if not say no. "
f"If it does, which category and explain why? \n\n {csv_as_str} \n\n"
)
for model in ["gpt-3.5-turbo", "gpt-4"]:
response = prompt_model(prompts, temperature=0.0, model=model)
print(f"\n{model} Model response: \n\n{response}")
我们得到……
gpt-3.5-turbo Model response:
No, the table does not fall into any of the categories provided.
The table is about the number and percentage of people who own land of
different sizes that are currently under irrigation.
It does not provide any information about the number of affected people,
coordination and context, food security and nutrition, geography and
infrastructure, health and education, or population and socio-economy.
gpt-4 Model response:
The table from the file ./data/number-of-acreage-under-irrigation.xlsx
does not fall into any of the provided HDX data categories.
The table provides information about the number of acreage under
irrigation in different sub-counties and is not related to any of the
categories mentioned, such as affected people, coordination & context,
food security & nutrition, geography & infrastructure, health & education,
or population & socio-economy.
GPT-3.5-turbo和GPT-4都完美地工作,并识别出我们的表格不属于所需类别之一(它与农业相关)。我也喜欢这种推理,至少在这个例子中完全正确。
我们用一个在受支持类别中的表进行尝试,查德的食品价格,如在查德 HDX 数据网格上找到的。这个文件的 CSV 字符串,取前 20 行,如下所示……
date,admin1,admin2,market,latitude,longitude,category,commodity,unit,priceflag,pricetype,currency,price,usdprice
#date,#adm1+name,#adm2+name,#loc+market+name,#geo+lat,#geo+lon,#item+type,#item+name,#item+unit,#item+price+flag,#item+price+type,#currency,#value,#value+usd
2003-10-15,Barh El Gazal,Barh El Gazel Sud,Moussoro,13.640841,16.490069,cereals and tubers,Maize,KG,actual,Retail,XAF,134.0,0.2377
2003-10-15,Barh El Gazal,Barh El Gazel Sud,Moussoro,13.640841,16.490069,cereals and tubers,Millet,KG,actual,Retail,XAF,147.0,0.2608
2003-10-15,Lac,Mamdi,Bol,13.5,14.683333,cereals and tubers,Maize,KG,actual,Retail,XAF,81.0,0.1437
2003-10-15,Lac,Mamdi,Bol,13.5,14.683333,cereals and tubers,Maize (white),KG,actual,Retail,XAF,81.0,0.1437
2003-10-15,Logone Occidental,Lac Wey,Moundou,8.5666667,16.0833333,cereals and tubers,Millet,KG,actual,Retail,XAF,95.0,0.1685
2003-10-15,Logone Occidental,Lac Wey,Moundou,8.5666667,16.0833333,cereals and tubers,Sorghum,KG,actual,Retail,XAF,62.0,0.11
2003-10-15,Logone Occidental,Lac Wey,Moundou,8.5666667,16.0833333,cereals and tubers,Sorghum (red),KG,actual,Retail,XAF,62.0,0.11
2003-10-15,Moyen Chari,Barh-K h,Sarh,9.1429,18.3923,cereals and tubers,Millet,KG,actual,Retail,XAF,100.0,0.1774
2003-10-15,Moyen Chari,Barh-K h,Sarh,9.1429,18.3923,cereals and tubers,Sorghum,KG,actual,Retail,XAF,90.0,0.1597
2003-10-15,Moyen Chari,Barh-K h,Sarh,9.1429,18.3923,cereals and tubers,Sorghum (red),KG,actual,Retail,XAF,90.0,0.1597
2003-10-15,Ndjaména,Ndjamena,Ndjamena,12.11,15.04,cereals and tubers,Maize,KG,actual,Retail,XAF,132.0,0.2342
2003-10-15,Ndjaména,Ndjamena,Ndjamena,12.11,15.04,cereals and tubers,Maize (white),KG,actual,Retail,XAF,132.0,0.2342
2003-10-15,Ndjaména,Ndjamena,Ndjamena,12.11,15.04,cereals and tubers,Millet,KG,actual,Retail,XAF,110.0,0.1952
2003-10-15,Ndjaména,Ndjamena,Ndjamena,12.11,15.04,cereals and tubers,Rice (imported),KG,actual,Retail,XAF,396.0,0.7026
2003-10-15,Ndjaména,Ndjamena,Ndjamena,12.11,15.04,cereals and tubers,Rice (local),KG,actual,Retail,XAF,297.0,0.5269
2003-10-15,Ndjaména,Ndjamena,Ndjamena,12.11,15.04,cereals and tubers,Sorghum,KG,actual,Retail,XAF,100.0,0.1774
2003-10-15,Ndjaména,Ndjamena,Ndjamena,12.11,15.04,cereals and tubers,Sorghum (red),KG,actual,Retail,XAF,100.0,0.1774
2003-10-15,Ouaddai,Ouara,Abeche,13.8166667,20.8166667,cereals and tubers,Millet,KG,actual,Retail,XAF,155.0,0.275
2003-10-15,Ouaddai,Ouara,Abeche,13.8166667,20.8166667,cereals and tubers,Sorghum,KG,actual,Retail,XAF,97.0,0.1721
使用相同格式的提示,我们得到……
gpt-3.5-turbo Model response:
Yes, the table falls into the category 'Food Security & Nutrition : Food Prices'.
This is because the table contains time series prices for common food
commodities at different locations.
gpt-4 Model response:
Yes, the table falls into the category "Food Security & Nutrition : Food Prices".
This is because the table contains time series prices for common
food commodities (such as Maize, Millet, and Sorghum) at various
locations (markets) with their respective coordinates (latitude and longitude).
The data also includes information on the date, administrative divisions,
and currency.
所以再说一次,两种模型都是正确的。该数据集的类别确实是“食品安全与营养:食品价格”。
好的,对于使用单个表进行的一些快速示例,看起来不错。那么,基于多个表的内容来识别类别呢?
使用来自多个表的摘录预测数据集类别
在 HDX 中,一个数据集可以有多个“资源”(文件),而对于 Excel 中的数据,这些文件可能在工作表中包含多个表。因此,只查看数据集中的一个表可能无法完全了解情况,我们需要根据多个表做出决策。这一点很重要,因为在数据集中的所有表中可能会有关于数据集、字段查找等的文档标签,而这些标签本身不足以推断数据集中所有数据的类别。
在 ChatGPT API 推出之前,这会由于令牌限制而变得困难。然而,ChatGPT 允许我们指定多个提示,并且有增加的令牌限制。如我们所见,这仍然是一个限制因素,但比以前的模型有所改进。
本分析的样本数据——在笔记本仓库中提供——是从 HDX 中提取的,由…
-
遍历数据集
-
对于每个数据集,遍历文件
-
对于每个表格文件,下载它
-
对于文件中的每个标签,创建一个表格摘录(前 20 行)以 CSV 格式
注意:我没有包含这段代码以避免 HDX 上的流量过多,但如果对这段代码感兴趣,可以在 Medium 上给我留言。
所以每个数据集都有一个这样的字段…
[
{
"filename":"<DATASET NAME>/<FILE NAME 1>",
"format": "EXCEL",
"sheet": "<SHEET 1>",
"table_excerpt": "<FIRST 20 ROWS OF TABLE IN CSV FORMAT>"
},
{
"filename":"<DATASET NAME>/<FILE NAME 1>",
"format": "EXCEL",
"sheet": "<SHEET 2>",
"table_excerpt": "<FIRST 20 ROWS OF TABLE IN CSV FORMAT>"
},
{
"filename":"<DATASET NAME>/<FILE NAME 2>",
"format": "CSV",
"sheet": "",
"table_excerpt": "<FIRST 20 ROWS OF TABLE IN CSV FORMAT>"
},
{
"filename":"<DATASET NAME>/<FILE NAME 3>",
"format": "EXCEL",
"sheet": "<SHEET 1>",
"table_excerpt": "<FIRST 20 ROWS OF TABLE IN CSV FORMAT>"
},
... etc
]
对于每个数据集,这种结构允许我们为每个表生成多个提示…
def predict(data_excerpts, temperature=0.0):
results = []
for index, row in data_excerpts.iterrows():
dataset = row["name"]
title = row["title"]
print(
f"\n========================================= {dataset} =============================================\n"
)
prompts = []
# Start the prompt by defining the categories we want to assign
prompts.append(
f"Here is a list of HDX data categories with their definition: \n\n {category_prompt_text} \n\n"
)
prompts.append(
f"Here are excerpts from all the tables in this dataset: {title} ...\n\n"
)
# Build multiple prompts for each table excerpt for this dataset
tables = row["table_excerpts"]
for table in tables:
table = json.loads(table)
csv_as_str = table["table_excerpt"]
sheet = table["sheet"]
type = table["type"]
filename = table["filename"]
print(f"DATA > {filename} / {sheet}")
prompts.append(
f"Type {type} sheet {sheet} from file {filename} Table excerpt: \n\n {csv_as_str} \n\n"
)
# Finish up with our request
prompts.append(
"Does the dataset fall into exactly one of the categories mentioned above, if not say no. "
"If it does, add a pipe charatcter '|' before and after the top category and sub-category category and explain why it was chosen step-by-step.\n\n"
"What is the second most likely category if you had to pick one (adding a ^ character either side)? \n\n"
)
actual_category = row["datagrid_category"]
d = {
"dataset_name": dataset,
"filename": filename,
"prompts": prompts,
"actual_category": actual_category,
}
# Send our prompt array to two models
for model in ["gpt-3.5-turbo", "gpt-4"]:
# for model in ['gpt-3.5-turbo']:
# GPT-4 is in test and can fail sometimes
try:
print(f"\nCalling model {model}")
response = prompt_model(prompts, temperature=temperature, model=model)
if "|" in response:
predicted_category = response.split("|")[1].strip()
else:
predicted_category = response
print(f"\n{model} Model response: \n\n{response}")
match = actual_category == predicted_category
d[f"{model}_response"] = response
d[f"{model}_predicted_category"] = predicted_category
d[f"{model}_match"] = match
print(
f"******* RESULT: || {match} || prediced {predicted_category}, actual {actual_category} *******"
)
except Exception as e:
print(e)
results.append(d)
results = pd.DataFrame(results)
return results
这会生成像这样的提示…
提示 1 — 定义类别
Here is a list of HDX data categories with their definition:
- Category \'Affected People : Internally Displaced Persons\' is defined as: Tabular data of the number of displaced people by location. Locations can be administrative divisions or other locations (such as camps) if an additional dataset defining those locations is also available.\n
- Category \'Affected People : Refugees and Persons of Concern\' is defined as: Tabular data of the number of refugees and persons of concern either in the country or originating from the country disaggregated by their current location. Locations can be administrative divisions or other locations (such as camps) if an additional dataset defining those locations is also available or if the locations\' coordinates are defined in the tabular data.\n
- Category \'Affected People : Returnees\' is defined as: Tabular data of the number of displaced people who have returned.\n
... etc
提示 2 — 指定数据集名称并介绍表格摘录
Here are excerpts from all the tables in this colombia-health-facilities-2021: ...
提示 3、4 等 — 提供表格摘录
Type XLSX sheet servicios from file ./data/prompts/colombia-health-facilities-2021/fservicios.xlsx Table excerpt: \
,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20
0,depa_nombre,muni_nombre,habi_codigo_habilitacion,codigo_habilitacion,numero_sede,sede_nombre,nivel,grse_codigo,grse_nombre,serv_codigo,serv_nombre,ambulatorio,hospitalario,unidad_movil,domiciliario,otras_extramural,centro_referencia,institucion_remisora,complejidad_baja,complejidad_media,complejidad_alta
1,Amazonas,LETICIA,9100100019,9100100019,01,E.S.E. HOSPITAL SAN RAFAEL,2,7,Apoyo Diagnóstico y Complementación Terapéutica,706,LABORATORIO CLÍNICO,SI,SI,NO,NO,NO,NO,NO,NO,SI,NO
2,Amazonas,LETICIA,9100100019,9100100019,01,E.S.E. HOSPITAL SAN RAFAEL,2,7,Apoyo Diagnóstico y Complementación Terapéutica,712,TOMA DE MUESTRAS DE LABORATORIO CLÍNICO,SI,SI,SI,NO,NO,NO,NO,SI,NO,NO\n3,Amazonas,LETICIA,9100100019,9100100019,01,E.S.E. HOSPITAL SAN RAFAEL,2,7,Apoyo Diagnóstico y Complementación Terapéutica,714,SERVICIO FARMACÉUTICO,SI,SI,NO,NO,NO,NO,NO,SI,NO,NO
... etc
Type XLSX sheet capacidad from file ./data/prompts/colombia-health-facilities-2021/fcapacidad.xlsx Table excerpt:
,0,1,2,3,4,5,6,7,8,9,10\n0,depa_nombre,muni_nombre,codigo_habilitacion,numero_sede,sede_nombre,nivel,grupo_capacidad,coca_nombre,cantidad,modalidad,modelo
1,Amazonas,EL ENCANTO,9126300019,11,CENTRO DE SALUD SAN RAFAEL - E.S.E HOSPITAL SAN RAFAEL DE LETICIA,2,CAMAS,Pediátrica,1,,0
2,Amazonas,EL ENCANTO,9126300019,11,CENTRO DE SALUD SAN RAFAEL - E.S.E HOSPITAL SAN RAFAEL DE LETICIA,2,CAMAS,Adultos,3,,0
... etc
最终提示 — 我们请求对数据进行分类
Does the dataset fall into exactly one of the categories mentioned above, if not say no.
If it does, add a pipe charatcter '|' before and after the top category and sub-category category and explain why it was chosen step-by-step.
What is the second most likely category if you had to pick one (adding a ^ character either side)?
你会注意到在最后的提示中,我们要求有点多:
-
我们要求模型指明数据不符合我们的类别,以便我们捕捉负面情况,模型不会尝试为每个数据集分配一个类别。有些将不符合批准的类别
-
请求类别“完全匹配”。如果没有这个要求,GPT-3.5-Turbo 可能会随意构造新的类别!
-
如果模型确实识别出一个类别,将其用‘|’括起来,以便更容易解析
-
我们要求模型提供其推理过程,因为这已被证明可以改善结果[3]。了解类别决策的原因也有助于突出虚假信息的情况
-
最后,为了后续讨论,我们还请求第二个最可能的类别
此外,如果你仔细查看预测函数中的代码,我在这项研究中使用了温度为 0.0。温度控制输出的随机程度,由于我们希望结果既准确又具体,而不是描述量子物理的文本,所以我将其设置为尽可能低。
生成我们的预测…
output_folder = "./data/prompts"
data_excerpts = pd.read_pickle(f"{output_folder}/datasets_excerpts.pkl")
data_excerpts = data_excerpts[data_excerpts["is_datagrid"] == True]
data_excerpts = data_excerpts.sample(min(150, data_excerpts.shape[0]), random_state=42)
results = predict(data_excerpts, temperature=0.0)
results.to_excel(f"{output_folder}/results.xlsx")
我们做得怎么样?
def output_prediction_metrics(results, prediction_field="predicted_post_processed", actual_field="actual_category"):
"""
Prints out model performance report if provided results.
Parameters
----------
results : list
Where each element has fields as defined by ...
prediction_field : str
Field name of element with prediction
actual_field : str
Field name of element with actual value
"""
y_test = []
y_pred = []
for index, r in results.iterrows():
if actual_field not in r:
print("Provided results do not contain expected values.")
sys.exit()
y_pred.append(r[prediction_field])
y_test.append(r[actual_field])
print(f"Results for {prediction_field}, {len(results)} predictions ...\n")
print(f"Accuracy: {round(accuracy_score(y_test, y_pred),2)}")
print(
f"Precision: {round(precision_score(y_test, y_pred, average='weighted', zero_division=0),2)}"
)
print(
f"Recall: {round(recall_score(y_test, y_pred, average='weighted', zero_division=0),2)}"
)
print(
f"F1: {round(f1_score(y_test, y_pred, average='weighted', zero_division=0),2)}"
)
return
results.fillna("", inplace=True)
print("\ngpt-3.5-turbo ...")
output_prediction_metrics(results, prediction_field="gpt-3.5-turbo_predicted_category", actual_field="actual_category")
print("\ngpt-4 ...")
output_prediction_metrics(results, prediction_field="gpt-4_predicted_category", actual_field="actual_category")
注意:虽然我们提供了 150 个数据集进行预测,但 GPT-4 的 API 时常超时,且未重试调用以节省成本。这对于处于早期预览阶段的 GPT-4 是完全可以预期的。有些提示也超出了 GPT-3.5-Turbo 的令牌长度。因此,以下结果适用于 GPT-3.5-turbo 和 GPT-4 做出的 53 个预测。
比如,仅预测类别,如“Coordination & Context”,当完整类别和子类别为“Coordination & Context : Humanitarian Access”时……
Results for **gpt-3.5**-turbo_predicted_category_1, 53 predictions ...
Accuracy: 0.66
Precision: 0.86
Recall: 0.66
F1: 0.68
Results for **gpt-4**_predicted_category_1, 53 predictions ...
Accuracy: 0.96
Precision: 0.97
Recall: 0.96
F1: 0.96
GPT-4 几乎总是能够识别正确的类别(96%准确率),在相同提示下表现显著优于 GPT-3.5-turbo(66%准确率)。
对于同时预测整个类别和子类别……
Results for **gpt-3.5**-turbo_predicted_category, 53 predictions ...
Accuracy: 0.57
Precision: 0.73
Recall: 0.57
F1: 0.60
Results for **gpt-4**_predicted_category, 53 predictions ...
Accuracy: 0.89
Precision: 0.92
Recall: 0.89
F1: 0.89
再次强调,GPT-4 比 GPT-3.5 表现显著优越。89%的准确率实际上相当不错,鉴于……
- 我们仅仅提供了一组文本规则,没有标记数据、训练分类器或提供任何示例。
实际上,如果我们查看那些预测失败的示例……
df = results.loc[results["gpt-4_match"] == False]
for index, row in df.iterrows():
response = row["gpt-4_response"]
predicted_second_category = response.split("^")[1].strip()
print(f"Dataset: {row['dataset_name']}")
# print(f"Dataset: {row['filename']}")
print("")
print(f"Actual: {row['actual_category']}")
print(f"Predicted category: {row['gpt-4_predicted_category']}")
print(f"Predicted second category: {predicted_second_category}\n")
print(
f"Secondary category matched: {predicted_second_category == row['actual_category']}"
)
print("=====================================================")
我们得到……
Dataset: mozambique-attacks-on-aid-operations-education-health-and-protection
Actual: Coordination & Context : Humanitarian Access
Predicted category: Coordination & Context : Conflict Events
Predicted second category: Health & Education : Health Facilities
Secondary category matched: False
=====================================================
Dataset: iraq-violence-against-civilians-and-vital-civilian-facilities
Actual: Coordination & Context : Humanitarian Access
Predicted category: Coordination & Context : Conflict Events
Predicted second category: Affected People : Humanitarian Needs
Secondary category matched: False
=====================================================
Dataset: south-sudan-access-incidents
Actual: Coordination & Context : Conflict Events
Predicted category: Coordination & Context : Humanitarian Access
Predicted second category: Coordination & Context : Conflict Events
Secondary category matched: **True**
=====================================================
Dataset: somalia-displacement-idps-returnees-baseline-assessment-iom-dtm
Actual: Affected People : Returnees
Predicted category: Affected People : Internally Displaced Persons
Predicted second category: Affected People : Returnees
Secondary category matched: **True**
=====================================================
Dataset: ukraine-border-crossings
Actual: Coordination & Context : Humanitarian Access
Predicted category: Geography & Infrastructure : Populated Places
Predicted second category: Coordination & Context : Humanitarian Access
Secondary category matched: **True**
=====================================================
Dataset: northeast-nigeria-displacement-for-borno-adamawa-and-yobe-states-bay-state
Actual: Affected People : Returnees
Predicted category: Affected People : Internally Displaced Persons
Predicted second category: Affected People : Returnees
Secondary category matched: **True**
=====================================================
Dataset: somalia-acute-malnutrition-burden-and-prevalence
Actual: Food Security & Nutrition : Acute Malnutrition
Predicted category: Affected People : Acute Malnutrition
Predicted second category: Food Security & Nutrition : Food Security
Secondary category matched: False
=====================================================
Dataset: colombia-people-in-need-pin-del-cluster-en-seguridad-alimentaria-y-nutricion-san-sp
Actual: Food Security & Nutrition : Food Security
Predicted category: Affected People : Humanitarian Needs
Predicted second category: Coordination & Context : 3W - Who is doing what where
Secondary category matched: False
=====================================================
Dataset: sind-safeguarding-healthcare-monthly-news-briefs-dataset
Actual: Coordination & Context : Humanitarian Access
Predicted category: Coordination & Context : Conflict Events
Predicted second category: Affected People : Humanitarian Needs
Secondary category matched: False
=====================================================
有几件事引起了注意。像‘mozambique-attacks-on-aid-operations-education-health-and-protection’这样的数据集,包含了与医疗保健和攻击相关的数据文件。
因此,假设每个数据集只有一个类别可能不是解决问题的最佳方式,数据集在类别之间被重复使用。
在 GPT-4 错误的约一半的案例中,它预测的第二名类别是正确的。查看这些案例中的一个模型输出,乌克兰边界穿越……
Yes, the dataset falls into exactly one of the categories mentioned above.
|Geography & Infrastructure : Populated Places|
The dataset contains information about border crossings in Ukraine,
including their names in English and Ukrainian, the country they connect to,
and their latitude and longitude coordinates. This information is related to
populated places (border crossings) and their geographic locations, which is why the "Geography & Infrastructure : Populated Places" category is the most appropriate.
^Coordination & Context : Humanitarian Access^
The second most likely category would be "Coordination & Context : Humanitarian Access"
border crossings can be considered as points of access between countries, and
understanding their locations could be relevant in the context of humanitarian
interventions. However, this category is not as fitting as the first one
since the dataset does not specifically describe access constraints or
permissions related to humanitarian interventions.
很酷的一点是它解释了为什么没有选择‘Coordination & Context : Humanitarian Access’,因为‘……它并不特别关注访问限制’。这是类别定义……
Coordination & Context : Humanitarian Access:表格或矢量数据,描述自然灾害、许可、激烈战斗或其他影响人道干预交付的访问限制的位置。
因此,GPT-4 似乎严格遵循类别规则。HDX 团队应用的分类有一些更细致的区别,其中跨境数据集与人道主义访问有很合理的关系。因此,也许提高模型在这种情况下预测的一个方法是向类别定义中添加额外的文本,指明边界穿越可能与人道主义访问相关。
这里的关键是 GPT-4 表现非常出色,而少数不正确的预测是由于我对问题的框定不当(数据集可以有多个类别),以及定义类别的文本可能存在的问题。
结论
该技术看起来非常有前途。我们能够获得一些良好的结果,无需设置标签、训练模型或在提示中提供示例。此外,像 GPT-4 这样的模型的数据总结能力确实令人印象深刻,帮助调试模型预测,也可能是提供快速数据概览的不错方法。
然而,存在一些警示:
-
由于成本和 GPT-4 仍处于早期预览阶段,这项研究所使用的数据量非常有限。未来的研究当然需要使用更多的数据。
-
目前提示长度是一个限制因素,上述研究只包括了少于 4 个表的数据集,以避免在提示表格摘录时超出 token 限制。HDX 数据集可能包含比这更多的表格,在某些情况下,拥有更大的表格摘录可能会更有价值。像 OpenAI 这样的供应商似乎在逐步增加 token 限制,因此随着时间的推移,这可能会变得不那么成为问题。
-
由于早期预览的原因,GPT-4 模型的性能非常慢,每个提示完成需要 20 秒。
-
问题的框架并不理想,例如,假设一个数据集只能有一个类别。虽然足以展示大型语言模型在评估数据质量和总结方面的潜力,但未来稍微不同的方法可能会产生更好的结果。例如,为 HDX 平台上的数据集预测每个国家的顶级数据集候选者。
能够用自然语言指定数据测试和数据问题仍然很酷!
参考文献
[1] OCHA, 2023 年开放人道数据状态
[2] Brown 等人, 语言模型是少样本学习者(2020 年)。
[3] Kojima 等人, 大型语言模型是零样本推理者。
Schopf 等人, 评估无监督文本分类:零样本和基于相似性的 approaches
这个分析的代码可以在这个笔记本中找到。
使用 GPT-4 的视觉功能作为艺术评论家
原文:
towardsdatascience.com/using-gpt-4-with-vision-as-an-art-critic-ec91080ba334
OpenAI 最新模型如何提供对视觉艺术的见解
·发表于 Towards Data Science ·14 分钟阅读·2023 年 11 月 2 日
--
AI 艺术评论家的极简主义画作,图像由 AI 图像创建程序 Midjourney 创建,并由作者编辑
大约两周前,OpenAI 开始推出其具备视觉功能的 GPT-4 模型(GPT-4V)[1]。作为 ChatGPT Plus 的订阅者,每月费用为 20 美元,我获得了新模型的早期访问权限。系统可在这里获取:chat.openai.com。
使用 GPT-4V,我可以提交一个包含最多四张图片的文本提示。系统会分析提示和图片,并作出相应的回应。由于我一直在写有关 AI 和 ML 创意用途的文章,我的第一个直觉是看看 GPT-4V 是否能够有效地作为艺术评论家。
这是我的第一次实验。我让系统评论了一幅我在 2002 年制作的新的英格兰风景丙烯画,这幅画是通过参考图像、一些自定义软件和投影仪完成的。我拍了一张画作的照片,并将其发送给 GPT-4V 进行评论。
ChatGPT 用户界面 (左侧),秋天的沃伦公园 (右侧),截图和画作由作者提供
这是我的提示,随后是 GPT-4V 的评论。
RobG*: 根据这幅画的照片,写一篇带有建设性批评的评论。
秋天的沃伦公园
罗伯特·A·冈萨尔维斯
画布板上的丙烯画
24” x 32”*
GPT-4V: “秋天的沃伦公园”由罗伯特·A·冈萨尔维斯创作,是对大自然美丽的迷人描绘。艺术家选择了充满活力的秋季色彩,捕捉了季节的精髓,使场景充满了炽热的橙色和深邃的棕色。使用的点彩技法赋予了作品梦幻般的、近乎超凡脱俗的质量,让观众感到仿佛被旋转的落叶包围。
建设性批评:
1. 尽管点彩画创造了迷人的质感,但如果点的大小能有更多变化,可能会为秋叶增添更多的深度和现实感。
2. 虽然天空以其冷色调的蓝色呈现得很美,但如果能有更平滑的过渡或渐变,会使其感觉更自然。
3. 签名“Robert Arthur 11/02”虽然重要,但显得有些突出。如果能更微妙地融入画作中,可能会更具美感。
总体而言,这幅作品是对秋天的绝妙表现,吸引了观众的注意力,并将他们带到了沃伦公园的中心。
第一部分相当令人赞赏。我注意到 ChatGPT 通常提供积极的回应。它指出了我画作的一些客观特征,如点彩风格和秋季色彩。我喜欢它说画作给观众带来了“被卷入旋转的落叶风暴中的感觉”时的热情。
沃伦公园秋季细节,秋叶 (左), 天空 (中), 和签名 (右),作者提供的图像
在我的提示中,我要求一些“建设性批评”,系统确实提供了一些。关于点的大小的评论是公平的,但建议天空使用平滑渐变似乎不太对劲,因为点彩画通常仅由颜色点构成。然而,我对系统能够读取我的签名和日期感到印象深刻,因为我选择了几乎没有对比度的颜色。是的,我的签名在画作中确实不寻常地大。
在本文中,我将提供有关 GPT-4V 模型的一些背景信息,并查看它在审视其他形式的视觉艺术(包括多媒体装置和 AI 生成艺术)时的表现如何。
GPT-4V 模型
2023 年 3 月,OpenAI 推出了 GPT-4,其中非视觉变体提供给了 ChatGPT Plus 用户。以下是作者们对视觉功能的评价。
GPT-4 接受由图像和文本组成的提示,这种设置与纯文本设置类似,让用户指定任何视觉或语言任务。具体来说,模型根据由任意交织的文本和图像组成的输入生成文本输出。在各种领域——包括文本和照片、图表或屏幕截图的文档——GPT-4 展现了与纯文本输入类似的能力。 - OpenAI [1]
模型大小和架构
与 OpenAI 发布之前的模型不同的是,它没有透露有关 GPT-4 模型的大小和架构的很多信息。然而,有泄漏的报告显示 GPT-4 约有 1.8 万亿个参数。它使用了一个专家混合(MoE)系统,在模型中有 16 个专家,每个专家约有 1110 亿个参数。作为参考,早期的 GPT-3 模型使用了 1750 亿个参数 [3]。
训练过程
OpenAI 没有透露他们用于训练 GPT-4 模型的数据详情。在他们的研究网站上,他们表示 GPT-4 基础模型是通过公共和许可数据语料库进行训练的,涵盖了各种推理能力、陈述和意识形态。为了确保模型的响应与用户意图一致,他们通过人类反馈强化学习(RLHF)进行了微调。
实验结果
为了测试这些模型,OpenAI 进行了各种基准测试,包括模拟最初为人类设计的考试。他们没有对这些考试进行特定的训练。模型在训练期间见到了考试中少量的问题;对于每个考试,他们运行了一个去除这些问题的变体,并报告了较低的分数[1]。以下是一些结果。
GPT 模型的测试结果,数据来自OpenAI,图表由作者提供
你可以看到 GPT-4 模型在大多数测试中显著优于 GPT-3.5 模型。有趣的是,GPT-4V 在一些测试中,如 GRE 定量测试,表现优于标准 GPT-4。从他们的论文中,我了解到 OpenAI 的研究人员尽可能客观地转录了测试中的任何图像和图表,以供视觉模型和非视觉模型使用[1]。然而,他们没有提供为什么具有视觉的 GPT-4 在测试中表现优于标准 GPT-4 模型的任何见解。我猜测 GPT-4V 在回答那些具有图像描述的“文字”问题时表现更好。
模型的局限性
作者描述了 GPT-4 和 GPT-4V 的不足之处。他们报告说,尽管这些模型有所进步,但与其前身一样,仍存在一些弱点:它们并不总是可靠,可能会产生“幻觉”,具有有限的上下文窗口,并且无法从经验中学习。这些局限性,特别是在关键情况下,使用时需要谨慎。模型的能力和限制引入了新的安全问题,强调了由于其潜在的社会影响而需要进一步研究。论文探讨了包括偏见、虚假信息等潜在风险,并概述了为最小化这些风险而采取的措施,如对抗性测试和安全管道[1]。
GPT-4V 的初步发布
起初,视觉版本 GPT-4V 仅通过Be My Eyes这一由丹麦初创公司创建的应用程序提供。该应用程序通过将盲人或视力低下者与志愿者连接,来帮助他们完成诸如产品识别和机场导航等任务。利用 GPT-4V 的功能,该初创公司在应用程序中推出了一个“虚拟志愿者”功能,旨在匹配人类助手的理解水平[2]。
OpenAI 的数据收集政策
OpenAI 修改了其数据收集政策。默认情况下,你的提示和回应可以被 OpenAI 用于训练未来的模型。他们之前使用了一个在线表单,用户可以选择退出数据收集。根据他们的新政策,数据收集现在与聊天记录绑定。如果你关闭聊天记录,他们将不会使用你的提示和回应来训练他们的模型。你可以在这里阅读他们的新政策。
接下来,我将展示更多使用 GPT-4V 作为艺术评论家的实验。
ChatGPT 用户界面,作者截图
GPT-4V 作为艺术评论家
GPT-4V 可以作为艺术评论家吗?我在上面的截图中问了这个问题,它回答说,“绝对可以!”并解释了为什么它可以。让我们看看它在几个例子中的表现。
安娜·克里斯蒂娜·戈朗松和罗布·贡萨尔维斯的《Swarming》,作者提供的图片
装置艺术
第一个例子是对一个名为“Swarming”的互动装置的评论,这个装置是我和我的朋友及艺术家安娜·克里斯蒂娜·戈朗松共同构建的。该装置展示了毛毡和染色的形状,作为“boids”的家,这些数字生物似乎从这些形状中飞出,通过投影在空中飞翔,最终又回到形状中。boids 可以通过微软 Kinect 设备被观众干扰和移动,该设备正在观察观众。
我将三张图片和一个装置描述发送给 GPT-4V,获取评论。以下是互动的过程。
*RobG: 根据这些照片和互动装置的描述,写一篇带有建设性批评的评论。
[description]*
GPT-4V: 安娜·克里斯蒂娜·戈朗松和罗布·贡萨尔维斯的《Swarming》评论
“Swarming” 是对将触感艺术与现代技术的沉浸式能力相结合的可能性的一次雄心勃勃的证明。由安娜·克里斯蒂娜·戈朗松设计的色彩鲜艳的毛毡雕塑为展览增添了奇幻的质感。这些有机形态与它们的动画对照形成了对日常生活的超现实逃离,展现了一个艺术真的栩栩如生的世界。
四种互动模式——睡眠、侦察、集群和互动——经过深思熟虑地设计,以为访客提供分层的体验。观众的存在本身就可以改变装置的行为,弥合被动观看和主动参与之间的差距。使用微软 Kinect 接口来跟踪访客的动作并触发视觉和听觉反馈,这种方法既创新又引人入胜。这种即时反馈循环使访客能够真正融入艺术作品之中。
然而,还有一些领域可能需要进一步完善:
环境声音:静默模式,特别是睡眠和侦察,可能利用细微的环境声音。这将增强体验,确保在这些较安静的时刻气氛不显得过于单调或空旷。
互动清晰度:提供一个简单的指南或指示,说明访客如何将展览从一种模式过渡到另一种模式可能会有所帮助。这可以确保访客能够充分探索“Swarming”提供的互动深度。
照明:考虑到视频投影仪在这个装置中的重要性,应特别关注展览室的环境光照。适当的照明可以增强投影的可见性,并确保感应雕塑得到适当的照明。
动画中的多样性:虽然“boids”为展览增添了动态元素,但引入多样的图案或行为可以进一步丰富观众的体验。这可以通过多样化 boids 的颜色或大小来实现。
总之,“Swarming”是传统艺术与前沿技术的引人入胜的融合,提供了一种既创新又怀旧的多感官体验。通过一些调整,这个装置有潜力进一步吸引观众,使每次访问都成为独特的发现之旅。
GPT-4V 对这件作品的评价非常积极。它以“雄心勃勃”、“异想天开”、“创新”和“引人入胜”等词语开始了对装置的高度评价。几乎就像我付钱让它写这篇评论一样。哦,等等。我确实这么做了。🙂
但它确实提出了一些合理的建设性批评。第一点,关于在睡眠和侦察模式中发出声音,并未考虑艺术坐席者的理智。关于为观众提供指示和光照条件的建议是有效的。最后一点,关于在动画中增加更多多样性也是很好的。但建议改变 boids 的大小听起来很像建议在我开头示例中改变点彩画点的大小。嗯,让我们看看这个趋势是否在其他示例中继续。
Dream Studio(右)、DALL-E 3(中)、和 Midjourney(左)的用户界面,作者截图
AI 生成艺术
为了测试系统的比较艺术批评能力,我将相同的提示输入到三个文本生成图像模型中,并请求 GPT-4V 对生成的图像进行比较。我使用的图像生成模型是 Stability AI 最新提供的Dream Studio,OpenAI 的DALL-E 3和Midjourney。我为这些服务都拥有付费账户。您可以在上面的截图中看到用户界面。
在这三种图像生成系统中,DALL-E 3 有一个有趣的特点;它使用 ChatGPT 作为界面。当我要求它使用提示“秋季树林的印象派画作”生成画作时,它首先将其扩展为更长的提示:“印象派风格的油画,描绘了秋季的密林。金色的树叶铺满了地面,而树木则装饰着生动的橙色、红色和黄色。斑驳的阳光透过树冠,投射出闪烁的光斑到森林地面上。”这是一个很酷的功能,所以我用了它。如果我输入的是“创建秋季树林的印象派画作。使用这个确切的提示。”它将跳过提示扩展步骤。
以下是生成“秋季树林的印象派画作”的结果。
“秋季树林的印象派画作,” 在 Dream Studio 生成 (左), DALL-E 3 (中), 和 Midjourney (右),作者提供图像
这是我发送给 GPT-4V 的提示以及三幅图像。
RobG:分析这些使用提示“秋季树林的印象派画作”创作的三幅图像。提供对比反馈,将图像称为左、中、右。
GPT-4V:当然!让我们深入分析这三幅秋季树林的印象派画作。
左侧图像:左侧图像使用了柔和的色彩调色板,主要展示了柔和的橙色、黄色和冷灰色。树木用细长的树干绘制,质地光滑,使整个场景呈现出一种精致、空灵的感觉。前景有一条明显的小径,覆盖着落叶,引导观众的视线深入树林。画作中有一种宁静和安静的感觉,就像能听到风在树叶间轻轻沙沙作响。
中间图像: 中间图像在色彩选择上更为大胆和生动,呈现了深金色、火红色和绿色的点缀。光影的对比非常强烈,特别是阳光如何透过树冠,照射到森林地面。与左图不同,这个场景更为密集,树木的集中度显得更高。对森林地面的关注,尤其是阳光照射下的散落叶子,营造了一种神奇、近乎梦幻的氛围。
右图: 右图是三幅图像中最具动感和生动的,色彩调色板十分抢眼,展示了明亮的橙色、红色和对比鲜明的蓝色。图中有一个看似溪流或水坑的倒影,双倍增强了秋天色彩的影响,并创造了树木的镜像效果。这幅画感觉比其他画作更具抽象性,笔触更加大胆,反射效果的运用使其似乎是在雨后场景中观看。画作强调了运动感,可能是由于水面波纹效果和落叶的方向性。
对比反馈: 虽然这三幅画作都美丽地捕捉了秋季树林的精髓,但它们以不同的方式展现。左边的画作提供了平静而柔和的表现,专注于树林的宁静。相比之下,中间的画作将观众沉浸在阳光照射下的光辉森林中,散发着温暖和迷人。右图因其大胆的颜色和抽象性而脱颖而出,利用反射和生动的色调来传达季节的动态特征。每幅画都有自己独特的魅力,使它们各自都很迷人。
哇!这是对三幅画作的极佳而全面的分析。再次强调,所有评论都非常积极。我喜欢它如何从描述第二幅图像开始比较这些图像,“与左图不同,这个场景更为密集…” 它还早早地把第三幅图像标榜为赢家,“右图是三幅图像中最具动感和生动的…” 但它以积极的语调结束,每幅画作都得到了赞扬,“每幅画都有自己独特的魅力…”
寻找一些具体的建设性批评,我问了这个问题。
RobG: 简要描述每幅画作可以如何改进。
GPT-4V: 左图: 为了增强其魅力,艺术家可以在颜色上引入更多的对比度,以赋予场景更多的深度。添加一些更明亮、更鲜艳的秋季色彩也可能帮助某些元素突出,确保图像不会显得过于单调或褪色。
中间图像:虽然光影的表现非常出色,但艺术家可以进一步定义树木,给它们增加更多的纹理或树皮图案的变化。这将打破单一性,并提供额外的兴趣点。
右侧图片:这幅画的动态特性是它的强项,但如果有一个更清晰的焦点会更好。无论是特定的树木、倒影还是其他元素,给观众的眼睛一个休息的地方,可能会使场景更具平衡感。艺术家还可以细化水中的倒影,确保它们与所映照的树木更自然地对齐。
这里是再次提供的三张图片供参考。
来自 Dream Studio 生成的画作 (左), DALL-E 3 (中), 和 Midjourney (右), 作者提供的图片
我同意这三条建议。左侧图片可以增加对比度和色彩饱和度。中间的图片需要在树木上增加一些定义。而右侧的图片缺少一个中央的焦点对象。
结论与下一步
在这次探索中,我发现 GPT-4V 作为艺术评论家的能力既有前景又富有洞察力。通过将其分析能力与各种形式的视觉艺术进行对比,很明显,该模型能够提供积极的反馈和建设性的批评,具有相当的深度。
从我画作的反馈到更复杂的互动装置和 AI 生成的艺术作品,GPT-4V 展现了对艺术的深刻理解和令人惊讶的敏锐眼光。虽然该模型的评价显著积极,这可能反映了 OpenAI 的设计决定,以避免过于负面的回应,但在被要求时,它毫不犹豫地提供了改进建议。
然而,必须记住的是,虽然 GPT-4V 可以提供分析,但艺术的实际价值往往在于观众的主观体验。像 GPT-4V 这样的 AI 工具可以提供反馈,但它们不应替代人类的触感,后者在艺术解释中带来个人的视角、情感和体验。
未来,探索 GPT-4V 对更抽象艺术形式的反应或者深入进行历史艺术时期的比较分析将是非常有趣的。例如,比较两幅法国印象派画作,并观察它是否能识别它们,将是很有趣的。此外,随着模型的不断发展和改进,还有可能对艺术、设计和其他视觉媒介提供更深入的见解。
总之,GPT-4V 为寻求反馈的艺术家提供了一个独特且有价值的工具。但仍需以平衡的视角来对待其批评,记住艺术不仅仅关乎技巧或视觉吸引力,更与人类经验密切相关。
致谢
我感谢 Jennifer Lim 和 Oliver Strimpel 审阅这篇文章并提供反馈。
参考文献
[1] OpenAI, GPT-4 技术报告 (2023)
[2] OpenAI, Be My Eyes — Be My Eyes 使用 GPT-4 改变视觉可及性 (2023)
[3] T. Brown 等人,语言模型是少样本学习者(2020)
附言
粉丝纠正:从投球手的投球区到本垒板不是 90 英尺!,图片作者
尽管 GPT-4V 喜欢我那幅“秋天的沃伦公园”画作,但康纳·奥布莱恩并没有那么友好。你可以在 2013 年的这段“粉丝纠正”视频中听到他对我画作的评价,从 01:36 开始。
使用 Python 和 Kafka
原文:
towardsdatascience.com/using-kafka-with-python-54dc20717cf7
了解如何使用 Python 创建生产者和消费者,并绘制动态散点图
·发表于 Towards Data Science ·7 分钟阅读·2023 年 3 月 20 日
--
图片由 Markus Winkler 提供,来源于 Unsplash
在我之前关于 Kafka 的文章中,我介绍了 Kafka 用于数据流处理的使用。我还展示了如何启动 Kafka 代理服务,并演示了如何使用 Kafka 生产者控制台应用程序发送消息以及如何使用 Kafka 消费者控制台应用程序接收消息。
在这篇文章中,我将向你展示如何使用 Python 利用 Kafka。具体来说,我将:
-
使用 Python 向 Kafka 代理服务发送消息
-
使用 Python 从 Kafka 代理服务接收消息
-
构建一个动态图表应用程序,用于绘制和更新散点图,随时接收来自代理服务的新数据
使用 Python 和 Kafka
至少有三个 Python 库可供 Python 开发者与 Kafka 代理服务接口。它们是:
-
Kafka-Python
-
PyKafka
-
Confluent Kafka Python
对于本文,我将使用 Confluent Kafka Python 包。
要安装 Confluent Kafka Python 包,请使用 pip
命令:
!pip install confluent-kafka
生成消息
首先,让我们先从生产者开始。生产者是向 Kafka 代理服务发送消息的组件。以下代码片段指定了要连接的 Kafka 代理服务器:
from confluent_kafka import Producer
import socket
conf = {
'bootstrap.servers': "localhost:9092",
'client.id': socket.gethostname()
}
producer = Producer(conf)
你需要确保你的 Kafka 代理服务正在运行。有关如何启动服务的详细信息,请参阅
towardsdatascience.com/using-apache-kafka-for-data-streaming-9199699623fa
。
了解如何安装和使用 Kafka 发送和接收消息
towardsdatascience.com](/using-apache-kafka-for-data-streaming-9199699623fa?source=post_page-----54dc20717cf7--------------------------------)
要发送消息,你可以使用Producer
对象的produce()
函数。你可以传递四个参数给它:
-
要发送的主题
-
消息负载的关键字
-
消息负载
-
轮询生产者以了解消息是否成功传递(或未传递)时要调用的回调函数。
以下代码片段将消息发送到 Kafka 代理服务:
def acked(err, msg):
if err is not None:
print("Failed to deliver message: %s: %s" % (str(msg), str(err)))
else:
print("Message produced: %s" % (str(msg)))
producer.produce("SomeTopic", key="key1", value="Hello", callback=acked)
producer.poll(1) # Maximum time (1s) to block while waiting for events
poll()
函数返回处理的事件数量(回调服务的数量)。
让我们发送另一条具有不同关键字的消息:
producer.produce("SomeTopic", key="key2", value="World", callback=acked)
producer.poll(1)
消费消息
消息发送完毕后,你现在可以处理消费者。以下代码片段连接到 Kafka 代理服务:
from confluent_kafka import Consumer
conf = {
'bootstrap.servers': "localhost:9092",
'group.id': "1",
'auto.offset.reset': 'latest'
}
consumer = Consumer(conf)
group.id
指示你属于哪个消费者组。如果有两个消费者分配了相同的组 ID 到同一个主题,它们将共享从同一主题读取的工作。
auto.offset.reset
指示…指定当从主题分区消费时没有初始偏移量时,消费者应如何行为。
我将在另一篇文章中讨论偏移量。
为了消费消息,让我们定义一个名为consume()
的函数。它接受消费者和要订阅的主题:
from confluent_kafka import KafkaError, KafkaException
def consume(consumer, topics):
try:
consumer.subscribe(topics)
# use this as a way to stop the loop
t = threading.currentThread()
while getattr(t, "run", True):
msg = consumer.poll(timeout=5.0)
if msg is None:
continue
if msg.error():
if msg.error().code() == KafkaError._PARTITION_EOF:
# End of partition event
sys.stderr.write('%% %s [%d] reached end at offset %d\n' %
(msg.topic(), msg.partition(), msg.offset()))
elif msg.error():
raise KafkaException(msg.error())
else:
key = msg.key().decode("utf-8")
data = msg.value().decode("utf-8")
print(key, data)
finally:
# Close down consumer to commit final offsets.
consumer.close()
在这个函数中,你首先订阅了你想要监听的主题。然后,你使用threading.currentThread()
函数检查当前线程上是否设置了名为run
的属性,该线程用于运行此函数。这用于控制是否继续等待下一条消息或退出函数。
我们将使用无限循环来不断轮询 Kafka 代理服务。timeout
参数允许你设置阻塞调用的时间,直到代理服务返回消息。如果你想更频繁地轮询代理,请将timeout
设置为较低的值,例如0.5
秒。
如果返回了消息,你可以提取其关键字和值,然后打印出来。最后,如果无限循环被终止,你将关闭消费者。
要运行consume()
函数,我们将使用threading
包:
import threading
thread = threading.Thread(target=consume,
args=(consumer, ["SomeTopic"]))
thread.start()
start()
函数将consume()
函数作为独立线程运行,以免冻结你的 Jupyter Notebook。
你现在应该能看到由生产者发送的两个传入消息:
key1 Hello
key2 World
要终止消费者,只需将thread
的run
属性设置为False
,consume()
函数将停止运行:
thread.run = False
绘制图表
现在你已经能够使用 Confluent Kafka Python 包来生产和消费消息,是时候利用这项新知识做些有用的事情了!
让我们使用生产者模拟一个 IOT 设备向 Kafka 代理服务发送传感器数据,并在消费者端读取这些数据并绘制图表。随着新数据的接收,图表将动态更新。所有这些将在 Jupyter Notebook 中直接运行。
发送传感器数据
让我们定义一个名为 send_message()
的函数,它接受四个参数:
-
topic — 消息的主题
-
datetime — 收集的传感器数据的日期时间
-
temp1 — 传感器 1 的温度读数
-
temp2 — 传感器 2 的温度读数
这四个参数将用于发送消息:
def send_message(topic, datetime, temp1, temp2):
producer.produce(topic, key="key", value=f"{datetime},{temp1},{temp2}", callback=acked)
producer.poll(1)
我们还将定义一个名为 update()
的函数,它每 2 秒调用一次 send_message()
函数,传入一些随机值:
import threading
import datetime
import random
def update():
threading.Timer(2.0, update).start() # call update() every 5s
send_message('SensorData',
datetime.datetime.utcnow(), # datetime in UTC
random.uniform(20, 39), # temperature1
random.uniform(10, 29)) # temperature2
update()
消费消息并绘制图表
对于消费者,我们来创建 Consumer
对象:
from confluent_kafka import Consumer
conf = {
'bootstrap.servers': "localhost:9092",
'group.id': "1",
'auto.offset.reset': 'latest'
}
consumer = Consumer(conf)
然后我们将利用 Plotly 向 FigureWidget
中添加两个散点图:
FigureWidget
是一个图形库,可以在 Jupyter Notebook 中显示图表。
from confluent_kafka import KafkaError, KafkaException
import time
import plotly.graph_objects as go
import numpy as np
# initialize a plot
fig = go.FigureWidget()
# add two scatter plots
fig.add_scatter(fill='tozeroy')
fig.add_scatter(fill='tozeroy')
我们现在可以从 Kafka broker 服务中轮询消息:
def consume(consumer, topics):
counter = 0
x = [] # datetime
y1 = [] # first temp
y2 = [] # second temp
n = 12 # the number of points to display on the plot
try:
consumer.subscribe(topics)
t = threading.currentThread()
while getattr(t, "run", True):
msg = consumer.poll(timeout=2.0)
if msg is None:
continue
if msg.error():
if msg.error().code() == KafkaError._PARTITION_EOF:
# End of partition event
sys.stderr.write('%% %s [%d] reached end at offset %d\n' %
(msg.topic(), msg.partition(), msg.offset()))
elif msg.error():
raise KafkaException(msg.error())
else:
data = msg.value().decode("utf-8")
x.append(data.split(',')[0]) # datetime
y1.append(float(data.split(',')[1])) # first temp
y2.append(float(data.split(',')[2])) # second temp
#---display the last n points---
# first scatter plot
fig.data[0].x = x[-n:] # datetime
fig.data[0].y = y1[-n:] # temp
# second scatter plot
fig.data[1].x = x[-n:] # datetime
fig.data[1].y = y2[-n:] # temp
finally:
# Close down consumer to commit final offsets.
consumer.close()
display(fig)
让我们分析一下上面的代码片段。当从 broker 获取到一条消息时,它会被拆分为三部分 —— datetime
、temp1
和 temp2
。它们分别被添加到 x
、y1
和 y2
列表中:
data = msg.value().decode("utf-8")
x.append(data.split(',')[0]) # datetime
y1.append(float(data.split(',')[1])) # first temp
y2.append(float(data.split(',')[2])) # second temp
随着时间的推移,列表会包含大量数据。因此,我们只希望绘制最近的 n 个点(在此示例中设置为 12)。要更新第一个散点图,设置 fig.data[0].x
和 fig.data[0].y
属性。对于第二个散点图,设置 fig.data[1].x
和 fig.data[1].y
属性:
#---display the last n points---
# first scatter plot
fig.data[0].x = x[-n:] # datetime
fig.data[0].y = y1[-n:] # temp
# second scatter plot
fig.data[1].x = x[-n:] # datetime
fig.data[1].y = y2[-n:] # temp
就这样!每当接收到新消息时,散点图将会自动更新!
目前,当你运行代码时,会显示一个空的图:
图片由作者提供
你现在可以准备开启一个线程来运行 consume()
函数:
import threading
thread = threading.Thread(target=consume, args=(consumer, ["SensorData"]))
thread.start()
过一段时间后,你应该会看到两个散点图更新:
图片由作者提供
这是更新图表的视频:
视频由作者提供
如常,若要暂停 consume()
函数,将 thread
的 run
属性设置为 False
:
thread.run = False
如果你喜欢阅读我的文章并且它对你的职业/学习有所帮助,请考虑注册成为 Medium 会员。每月 $5,你可以无限制访问 Medium 上的所有文章(包括我的文章)。如果你使用以下链接注册,我将获得少量佣金(对你没有额外费用)。你的支持意味着我可以花更多时间写出像这样的文章。
阅读 Wei-Meng Lee 的每个故事(以及 Medium 上其他成千上万位作家的故事)。你的会员费用直接支持…
weimenglee.medium.com](https://weimenglee.medium.com/membership?source=post_page-----54dc20717cf7--------------------------------)
总结
本文展示了一个你可以使用 Kafka 构建的有用应用。特别是,我向你展示了如何使用 Python 和 Plotly 构建一个动态图表应用。Kafka 的实时流式特性使其成为需要低延迟数据更新的应用(如物联网应用)的理想选择。敬请关注下一篇 Kafka 文章!
使用大型语言模型作为推荐系统
最近研究的综述及自定义实现
·
关注 发布于 Towards Data Science · 8 分钟阅读 · 2023 年 4 月 10 日
--
近年来,大型语言模型(LLMs)在数据科学界和新闻界引起了极大的关注。自 2017 年变压器架构问世以来,我们见证了这些模型在自然语言任务复杂性方面的指数级进步,包括分类、意图与情感提取,以及生成与人类相似的文本。
从应用角度来看,将 LLMs 与各种现有技术结合使用,来弥补它们的缺陷(我最喜欢的之一是GPT + Wolfram Alpha 组合 用于处理数学和符号推理问题),其可能性似乎无穷无尽。
但令我惊讶的是,LLMs 也可以作为推荐系统单独使用,而无需任何额外的特征工程或其他常见推荐系统中的手动过程。这种能力主要归功于 LLMs 的预训练方式及其操作方式。
目录
-
LLMs 和变换器如何工作的回顾
-
LLMs 作为推荐系统
-
使用自定义数据实现/复制 P5
-
复制尝试 2 — 阿拉伯语
-
LLMs 作为推荐系统的优点和缺点
-
最终思考
-
代码
LLMs 和变换器如何工作的回顾
语言模型是概率模型,试图映射一系列标记(短语、句子等)的发生概率。它们在各种文本上进行训练,并相应地推导概率分布。对于它们可以处理的各种任务(总结、问答等),它们通过条件概率迭代地选择最可能的标记/词汇来继续提示。请参见下面的示例:
基于上下文的后续标记的概率示例(作者提供的图片)
LLMs 是经过大量文本训练的语言模型,这些模型具有庞大的架构,并利用了大量计算资源。它们通常由变换器架构驱动,这种架构在 Google 的著名 2017 年论文“Attention Is All You Need”中被引入。这种架构利用了“自注意力”机制,使得模型能够在预训练过程中学习不同标记之间的关系。
在对足够大的文本集合进行预训练后,相似的词汇将具有相似的嵌入(例如:“King”,“Monarch”),而不同的词汇将具有更大的差异。此外,利用这些嵌入,我们将看到词汇之间的代数映射,使模型能够更有效地确定序列的正确下一个标记。
- 就像经典的“King — Man + Woman = Queen” 示例一样
自注意力嵌入的附加好处是,它们会根据周围的词汇而有所变化,使其更贴合于特定上下文中的含义。
斯坦福大学的克里斯托弗·曼宁博士提供了一个关于 LLMs 如何工作的高层次概述。
LLMs 作为推荐系统
2022 年,Rutger’s University 的研究人员发布了论文《推荐作为语言处理(RLP):统一预训练、个性化提示与预测范式(P5)》(Geng et. al)。论文中介绍了一种“灵活统一的文本到文本范式”,将多个推荐任务结合在一个系统中:P5。该系统能够通过自然语言序列执行以下操作:
-
顺序推荐
-
评分预测
-
解释生成
-
评论总结
-
直接推荐
让我们来看看论文中的一个顺序推荐任务示例。
Input: "I find the purchase history list of user_15466:
4110 -> 4467 -> 4468 -> 4472
I wonder what is the next item to recommend to the user. Can you help
me decide?"
Output: "1581"
研究人员为用户和每个项目分配了唯一的 ID。使用包含数千名用户(及其购买历史)和唯一项目的训练集,LLM 能够学习到某些项目彼此相似以及某些用户对某些项目有倾向(由于自注意力机制的特性)。在所有这些购买序列的预训练过程中,模型本质上经历了一种协同过滤的形式。它查看用户购买了相同的项目以及哪些项目往往一起购买。结合 LLM 生成上下文嵌入的能力,我们突然拥有了一个非常强大的推荐系统。
在上述示例中,尽管我们不知道每个 ID 对应的项目是什么,但我们可以推断出项“1581”被选择的原因是其他用户购买了它,并且与“user_15466”已经购买的任何项一起购买了。
关于 P5 的架构,它“利用预训练的 T5 检查点作为骨干”(Geng et. al)。
- T5 是另一个 LLM,谷歌几年前发布的。它被设计来处理多种类型的序列到序列任务,因此作为这种系统的起点是合理的。
使用自定义数据实现/复制 P5
我对这篇论文印象深刻,想看看是否可以在较小的规模上复制顺序推荐能力。我决定利用 Hugging Face 的开源 T5 模型(T5-large),并制作了自己的自定义数据集来微调它以产生推荐。
我制作的数据集包含了超过 100 个运动设备购买的示例以及下一个要购买的项目。例如:
Input: “Soccer Goal Post, Soccer Ball, Soccer Cleats, Goalie Gloves”
Target Output: “Soccer Jersey”
当然,为了使其更为强健,我决定使用更具体的提示。提示如下:
输入: “购买的物品: {足球门框, 足球, 足球鞋, 守门员手套} — 推荐候选: {足球球衣, 篮球球衣, 橄榄球球衣, 棒球球衣, 网球衬衫, 曲棍球球衣, 篮球, 橄榄球, 棒球, 网球, 曲棍球, 篮球鞋, 橄榄球鞋, 棒球鞋, 网球鞋, 曲棍球头盔, 篮球臂套, 橄榄球护肩, 棒球帽, 网球拍, 曲棍球滑冰鞋, 篮球架, 橄榄球头盔, 棒球棒, 曲棍球棍, 足球标志, 篮球短裤, 棒球手套, 曲棍球护具, 足球护腿板, 足球短裤} — 推荐: ”
目标输出: “足球球衣”
上面你可以看到用户目前购买的物品规格,随后是尚未购买的推荐候选列表(这是整个库存)。
在使用 Hugging Face 的 Trainer API (Seq2SeqTrainer 训练了约 10 个周期) 对 T5 模型进行微调后,我获得了一些令人惊讶的好结果!一些示例评估:
输入: “购买的物品: {足球球衣, 足球门框, 足球鞋, 守门员手套} — 推荐候选: {篮球球衣, 橄榄球球衣, 棒球球衣, 网球衬衫, 曲棍球球衣, 足球, 篮球, 橄榄球, 棒球, 网球, 曲棍球, 篮球鞋, 橄榄球鞋, 棒球鞋, 网球鞋, 曲棍球头盔, 篮球臂套, 橄榄球护肩, 棒球帽, 网球拍, 曲棍球滑冰鞋, 篮球架, 橄榄球头盔, 棒球棒, 曲棍球棍, 足球标志, 篮球短裤, 棒球手套, 曲棍球护具, 足球护腿板, 足球短裤} — 推荐: ”
模型输出: “足球”
输入: “购买的物品: {篮球球衣, 篮球, 篮球臂套} — 推荐候选: {足球球衣, 橄榄球球衣, 棒球球衣, 网球衬衫, 曲棍球球衣, 足球, 橄榄球, 棒球, 网球, 曲棍球, 足球鞋, 橄榄球鞋, 棒球鞋, 网球鞋, 曲棍球头盔, 守门员手套, 橄榄球护肩, 棒球帽, 网球拍, 曲棍球滑冰鞋, 足球门框, 篮球架, 橄榄球头盔, 棒球棒, 曲棍球棍, 足球标志, 篮球短裤, 棒球手套, 曲棍球护具, 足球护腿板, 足球短裤} — 推荐: ”
模型输出: “篮球鞋”
这当然是主观的,因为推荐不一定是二元的成功/失败,但输出与迄今为止的购买项相似,确实令人印象深刻。
复制尝试 2 — 阿拉伯语
接下来,我想看看是否可以为阿拉伯语做这个,所以我翻译了我的数据集,并寻找了一些能够处理阿拉伯语文本的公开可用 T5 模型(AraT5, MT5等)。在尝试了十几个我在Hugging Face Hub上找到的变体后,我遗憾地发现无法得到令人满意的结果。
模型(经过微调后)会推荐相同的 1 或 2 件商品,不管购买历史如何——通常是“كرة القدم”,即“soccer ball”(不过,也许它知道阿拉伯语使用者喜欢足球,并且总是寻找新的足球)。即使尝试了这些模型的更大版本,如 MT5-xl,我也得到了相同的结果。这很可能是由于这些 LLM 在英语以外的语言上的数据稀缺。
对于我的最后一次尝试,我决定尝试将 Google Translate API 与我的英语微调 T5 模型结合使用。过程如下:
- 将阿拉伯语输入 → 翻译成英语 → 输入到经过英语微调的模型中 → 获取模型的英语预测 → 翻译回阿拉伯语
不幸的是,这仍然没能帮助多少,因为翻译器会犯一些错误(例如:“كرة القدم”,我们用来代替“soccer”的词直接翻译成“foot ball”),这让模型出错,导致始终推荐相同的 1-2 件商品。
LLM 作为推荐系统的优点和陷阱
这种技术最突出的优点在于它作为独立系统实施的简便性。由于上述 LLM 和预训练技术的性质,我们可以绕过繁重的手动特征工程——模型应该能够自然地学习表示和关系。此外,我们可以在一定程度上绕过新商品的冷启动问题——商品的名称/描述可以被提取并自然地与用户已经购买/选择的现有商品相关联。
然而,这种方法有一些陷阱(不要急着丢弃你当前的推荐系统!),主要是由于对推荐内容的控制缺乏。
-
因为没有对用户在检查后购买商品的不同操作/事件进行加权,所以我们完全依赖 LLM 预测的最可能的下一个 token(s)进行推荐。我们无法考虑用户书签、查看了一段时间、放入购物车等操作。
-
此外,对于这些 LLM,我们确实面临大多数推荐基于相似性的风险(即与目前购买的商品语义相似的商品),尽管我认为通过大量用户购买历史数据,我们可以通过这种方法模拟的“协同过滤”方法来改善这个问题。
-
最后,由于大型语言模型(LLMs)理论上可以生成任何文本,输出可能是一个与库存中的条目不完全匹配的字符串(尽管我认为这种情况发生的可能性较低)。
最终想法
根据 P5 论文的结果以及我尝试在 T5 模型上进行一些微调和提示的结果,我推测这种技术可以应用于许多语言模型。使用更强大的序列到序列模型可能会有显著帮助,尤其是当微调数据足够大且提示技术得到完善时。
然而,我不会建议(没有双关的意思)任何人单独使用这种方法。我建议将其与其他推荐系统技术结合使用,这样我们可以避免上述提到的陷阱,同时获得好处。如何实现这一点——我不确定,但我认为通过一些创造力,这种 LLM 技术可以以有益的方式集成(也许通过提取基于嵌入的特征用于协同过滤,或者与“双塔”架构结合使用,可能性是无限的)。
代码
-
我实现的 T5 推荐系统(Github 仓库)
-
我在 Hugging Face 上的微调 T5 模型(Hugging Face Hub)
参考文献
[1] S. Geng, S. Liu, Z. Fu, Y. Ge, Y. Zhang, 推荐作为语言处理(RLP):统一的预训练、个性化提示与预测范式(P5)(2023),第 16 届 ACM 推荐系统会议
使用 LazyPredict 评估 ML 算法
原文:
towardsdatascience.com/using-lazypredict-for-evaluating-ml-algorithms-f912a43eef2
使用 LazyPredict 库自动选择最佳机器学习算法的过程
·发表于 Towards Data Science ·10 分钟阅读·2023 年 3 月 27 日
--
由 Victoriano Izquierdo 摄影,来自 Unsplash
评估机器学习算法是数据科学家常做的任务。虽然数据科学家需要了解不同类型的机器学习算法以应对不同类型的问题,但至关重要的是,他/她需要在自己的数据集上应用这些算法。只有这样,他/她才能更好地了解该使用哪种算法来训练模型,以及如何进行超参数调整。然而,选择合适的算法是一个耗时且令人疲惫的过程。理想情况下,应该有一个自动化的过程,只需提供数据,理想的机器学习算法将会为你选择。
这个问题的答案是 LazyPredict。 LazyPredict 是一个 Python 库,帮助你部分自动化选择最佳算法来训练你的数据集的过程。通过提供你的数据,LazyPredict 将使用超过 60 种机器学习算法来训练模型。最终结果将呈现给你。之后,你可以选择表现最好的机器学习算法进一步训练或优化你的数据集。
手动选择机器学习 (ML) 模型
要欣赏 LazyPredict 的美妙,了解通常如何手动完成这些事情总是有益的。因此,在本节中,我将使用糖尿病数据集作为示例,看看我们如何使用它来评估几种 ML 算法并选择与之最匹配的理想算法。为了简化起见,我们将使用以下 ML 算法:
-
逻辑回归
-
K-最近邻 (KNN)
-
支持向量机
糖尿病数据集:
www.kaggle.com/datasets/mathchi/diabetes-data-set
。许可: CC0: 公开领域
加载数据
第一步是将 diabetes.csv 文件加载到 Pandas DataFrame 中,然后打印出其详细信息:
import numpy as np
import pandas as pd
df = pd.read_csv('diabetes.csv')
df.info()
特别是,数据框中没有NaN
值:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 768 entries, 0 to 767
Data columns (total 9 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Pregnancies 768 non-null int64
1 Glucose 768 non-null int64
2 BloodPressure 768 non-null int64
3 SkinThickness 768 non-null int64
4 Insulin 768 non-null int64
5 BMI 768 non-null float64
6 DiabetesPedigreeFunction 768 non-null float64
7 Age 768 non-null int64
8 Outcome 768 non-null int64
dtypes: float64(2), int64(7)
memory usage: 54.1 KB
让我们查看数据框本身:
df
注意到一些列有 0 值,例如Pregnancies、SkinThickness、Insulin和Outcome:
所有图片均由作者提供
清理数据
由于数据框中没有NaN
值,现在检查一下具体哪些列有 0 值:
#---check for 0s---
print(df.eq(0).sum())
从下面的输出中,你可以看到只有DiabetesPedigreeFunction和Age列没有 0 值:
Pregnancies 111
Glucose 5
BloodPressure 35
SkinThickness 227
Insulin 374
BMI 11
DiabetesPedigreeFunction 0
Age 0
Outcome 500
dtype: int64
对于其他具有 0 值的列,只有Pregnancies和Outcome列可以有 0 值——Pregnancies的 0 表示患者以前从未怀孕,而Outcome的 0 表示患者不是糖尿病患者。对于其他列,值为 0 是不合逻辑的——0 皮肤厚度,真的吗?
所以现在让我们替换这些列中的 0 值,以便它们具有更有意义的值。第一步是用NaN
替换 0:
df[['Glucose','BloodPressure','SkinThickness',
'Insulin','BMI']] = \
df[['Glucose','BloodPressure','SkinThickness',
'Insulin','BMI']].replace(0,np.NaN)
df
然后,用每列的均值替换NaN
值:
df.fillna(df.mean(), inplace = True) # replace the rest of the NaNs with the mean
你现在可以验证所有列中除了Pregnancies和Outcome外都没有 0 值:
print(df.eq(0).sum())
Pregnancies 111
Glucose 0
BloodPressure 0
SkinThickness 0
Insulin 0
BMI 0
DiabetesPedigreeFunction 0
Age 0
Outcome 500
dtype: int64
检查特征之间的相关性
尽管数据集中有多个特征,但并非所有特征都对结果有贡献。因此,计算每列的相关因子是有用的:
corr = df.corr()
print(corr)
这是结果:
Pregnancies Glucose BloodPressure SkinThickness \
Pregnancies 1.00 0.13 0.21 0.08
Glucose 0.13 1.00 0.22 0.19
BloodPressure 0.21 0.22 1.00 0.19
SkinThickness 0.08 0.19 0.19 1.00
Insulin 0.06 0.42 0.07 0.16
BMI 0.02 0.23 0.28 0.54
DiabetesPedigreeFunction -0.03 0.14 -0.00 0.10
Age 0.54 0.27 0.32 0.13
Outcome 0.22 0.49 0.17 0.22
Insulin BMI DiabetesPedigreeFunction Age Outcome
Pregnancies 0.06 0.02 -0.03 0.54 0.22
Glucose 0.42 0.23 0.14 0.27 0.49
BloodPressure 0.07 0.28 -0.00 0.32 0.17
SkinThickness 0.16 0.54 0.10 0.13 0.22
Insulin 1.00 0.17 0.10 0.14 0.21
BMI 0.17 1.00 0.15 0.03 0.31
DiabetesPedigreeFunction 0.10 0.15 1.00 0.03 0.17
Age 0.14 0.03 0.03 1.00 0.24
Outcome 0.21 0.31 0.17 0.24 1.00
绘制特征之间的相关性
使用热图可视化特征之间的相关性,使理解这些数字变得更加容易:
%matplotlib inline
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(10, 10))
cax = ax.matshow(corr,cmap='coolwarm', vmin=-1, vmax=1)
fig.colorbar(cax)
ticks = np.arange(0,len(df.columns),1)
ax.set_xticks(ticks)
ax.set_xticklabels(df.columns)
plt.xticks(rotation = 90)
ax.set_yticklabels(df.columns)
ax.set_yticks(ticks)
#---print the correlation factor---
for i in range(df.shape[1]):
for j in range(9):
text = ax.text(j, i, round(corr.iloc[i][j],2),
ha="center", va="center", color="w")
plt.show()
这是相关因子的热图。我们感兴趣的是查看哪些特征与Outcome高度相关(无论是正相关还是负相关)。因此,我们将查看Outcome列,并关注那些深红色(正相关)和深蓝色(负相关;此案例中没有)的单元格:
你还可以通过编程找到最相关的特征:
#---get the top four features that has the highest correlation---
print(df.corr().nlargest(4, 'Outcome').index)
#---print the top 4 correlation values---
print(df.corr().nlargest(4, 'Outcome').values[:,8])
你可以看到,Outcome最相关的前三个特征是Glucose、BMI和Age:
Index(['Outcome', 'Glucose', 'BMI', 'Age'], dtype='object')
[1\. 0.49292767 0.31192439 0.23835598]
评估机器学习算法
数据清理完成后,下一步是选择不同的算法来使用你的数据训练模型。
使用逻辑回归
首先使用逻辑回归来训练模型。我们将使用交叉验证来评分模型:
from sklearn import linear_model
from sklearn.model_selection import cross_val_score
#---features---
X = df[['Glucose','BMI','Age']]
#---label---
y = df.iloc[:,8]
log_regress = linear_model.LogisticRegression()
log_regress_score = cross_val_score(log_regress, X, y, cv=10, scoring='accuracy').mean()
print(log_regress_score)
对于逻辑回归,我获得了 0.7669856459330144 的分数。我将把结果添加到列表中,以便稍后可以与其他模型进行比较:
result = []
result.append(log_regress_score)
使用 K-最近邻
接下来,我们将使用 K-最近邻算法来训练模型:
from sklearn.neighbors import KNeighborsClassifier
#---empty list that will hold cv (cross-validates) scores---
cv_scores = []
#---number of folds---
folds = 10
#---creating odd list of K for KNN---
ks = list(range(1,int(len(X) * ((folds - 1)/folds)), 2))
#---perform k-fold cross validation---
for k in ks:
knn = KNeighborsClassifier(n_neighbors=k)
score = cross_val_score(knn, X, y, cv=folds, scoring='accuracy').mean()
cv_scores.append(score)
#---get the maximum score---
knn_score = max(cv_scores)
#---find the optimal k that gives the highest score---
optimal_k = ks[cv_scores.index(knn_score)]
print(f"The optimal number of neighbors is {optimal_k}")
print(knn_score)
result.append(knn_score)
我们将尝试不同的 K 值,并逐一打分。我们将选择最高的分数并打印出最佳的 K 值。结果如下:
The optimal number of neighbors is 19
0.7721462747778537
使用支持向量机
我们要使用的最终算法是支持向量机(SVM)。我们将尝试两种类型的核函数。首先尝试linear
核函数:
from sklearn import svm
linear_svm = svm.SVC(kernel='linear')
linear_svm_score = cross_val_score(linear_svm, X, y,
cv=10, scoring='accuracy').mean()
print(linear_svm_score)
result.append(linear_svm_score)
接下来是rbf
(径向基函数)核:
rbf = svm.SVC(kernel='rbf')
rbf_score = cross_val_score(rbf, X, y, cv=10, scoring='accuracy').mean()
print(rbf_score)
result.append(rbf_score)
选择最佳性能算法
现在我们已经用不同的算法训练了数据集,我们可以汇总所有结果并进行比较:
algorithms = ["Logistic Regression", "K Nearest Neighbors", "SVM Linear Kernel", "SVM RBF Kernel"]
cv_mean = pd.DataFrame(result,index = algorithms)
cv_mean.columns=["Accuracy"]
cv_mean.sort_values(by="Accuracy",ascending=False)
从下图可以看出,KNN 是赢家,但其他模型也不落后。
根据这个结果,你现在知道 KNN 是这个特定数据集的最佳算法。
使用 LazyPredict 进行分类问题
尽管我们从之前的结果中知道 KNN 在使用的四种算法中表现最好,但结论并非绝对。你可能会发现还有其他更适合你的数据集的算法。这时你可以使用 LazyPredict 自动训练你的数据集,利用不同的算法。
LazyPredict 支持回归和分类算法。
关于数据集,我将使用我之前清理过的数据。你可以使用从 CSV 文件中读取的原始数据,LazyPredict 会自动预处理你的数据——它将用均值(对于数值列)和常量值(对于分类列)替换缺失值。然后它会标准化数值列并编码分类列。
然而,自己进行数据预处理总是更好,因为你最了解自己的数据(正如我们数据集中某些列的 0 值不可接受所示)。
因此,我将从清理后的数据框df
中提取前 8 列的特征,将第九列作为标签:
#---features---
X = df.iloc[:,:8]
#---label---
y = df.iloc[:,8]
接下来,安装 LazyPredict:
!pip install lazypredict
对于分类问题,导入LazyClassifier
类。你还需要导入其他所需模块:
import lazypredict
# for classification problem
from lazypredict.Supervised import LazyClassifier
# split dataset into training and testing sets
from sklearn.model_selection import train_test_split
初始化LazyClassifier
类,特别是将predictions
参数设置为True
:
clf = LazyClassifier(verbose=0, ignore_warnings=True,
custom_metric = None, predictions=True)
将数据集拆分为 80%的训练集和 20%的测试集:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.2, random_state = 42)
你现在可以使用clf
分类器来拟合(训练)你的数据,并使用各种分类算法进行预测:
scores, predictions = clf.fit(X_train, X_test, y_train, y_test)
scores
scores
变量是一个数据框,显示了各种机器学习模型及其各自的指标,如准确率、ROC AUC、F1 分数等:
你可以看到结果与我们最初的测试非常接近,其中 K-Nearest Neighbor 表现得相当不错(在这种情况下是第二好的模型)。当然,现在我们知道ExtraTreesClassifier算法效果更好。
predictions
变量是一个数据框,包含了每个模型的预测值:
predictions
使用 LazyPredict 解决回归问题
在结束这篇文章之前,让我们使用LazyPredict
处理一个回归问题。这一次,我们将利用sk-learn
库中附带的 Boston 数据集。对于回归问题,使用LazyRegressor
类:
from lazypredict.Supervised import LazyRegressor
from sklearn.model_selection import train_test_split
from sklearn import datasets
# load the Boston dataset
data = datasets.load_boston()
X, y = data.data, data.target
# split the data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.2, random_state=42)
clf = LazyRegressor(predictions=True)
# fit the data using different algorithms
models, predictions = clf.fit(X_train, X_test, y_train, y_test)
models
评估结果如下:
以及每个算法的预测值:
predictions
如果你喜欢阅读我的文章并且这些文章对你的职业/学习有所帮助,请考虑成为 Medium 会员。会员费用为每月$5,可以无限制访问 Medium 上的所有文章(包括我的)。如果你通过以下链接注册,我将获得一小部分佣金(不会增加你的额外费用)。你的支持意味着我可以投入更多时间撰写类似的文章。
[## 使用我的推荐链接加入 Medium - Wei-Meng Lee
阅读 Wei-Meng Lee 的每一个故事(以及 Medium 上的其他成千上万的作家)。你的会员费用直接支持…
总结
这篇文章展示了如何使用LazyPredict库简化机器学习算法选择的过程。一旦你确定了理想的算法,应该通过超参数调整进一步优化你的模型。如果你想快速了解这个话题,请查看我之前的文章:
## 使用 GridSearchCV 调整机器学习模型的超参数
学习如何使用 GridSearchCV 函数优化你的机器学习模型
[towardsdatascience.com
使用 LLM 来评估 LLM
·
关注 发布于 Towards Data Science ·7 分钟阅读·2023 年 11 月 10 日
--
图片由 OpenAI 的 DALL-E 3 生成。
你可以要求 ChatGPT 以百万种不同的方式进行工作:作为你的营养师、语言导师、医生等。不奇怪我们看到很多基于 OpenAI API 的演示和产品发布。但虽然让 LLM 以某种方式运作很容易,确保它们表现良好并准确完成给定任务则是完全不同的故事。
问题在于我们关注的许多标准都是极其主观的。回答是否准确?回答是否连贯?有没有虚假信息?建立可量化的评估指标很困难。通常,你需要人工判断,但让人类检查大量 LLM 输出是非常昂贵的。
此外,LLM 具有许多可以调整的参数,如提示、温度、上下文等。你可以在特定数据集上微调模型以适应你的使用案例。通过提示工程,即使是让模型深呼吸 [1] 或让你的请求更具情感色彩 [2] 也能改善性能。有很大的调整和实验空间,但在你更改某些内容后,你需要能够判断系统整体上是变得更好还是更差。
由于人工劳动缓慢且昂贵,强烈的动力驱使我们寻找这些更主观标准的自动化指标。一种有趣的方法,正在获得越来越多的关注,是使用 LLM 来评估 LLM 的输出。毕竟,如果 ChatGPT 能生成一个良好的、连贯的回应,它是否也能判断给定的文本是否连贯呢?这打开了一整盒潜在的偏见、技术和机会,让我们深入探讨一下。
LLM 偏见
如果你对使用 LLM 构建指标和评估器有负面的直觉反应,你的担忧是有根据的。这可能只是延续现有偏见的糟糕方式。
例如,在稍后我们将详细讨论的 G-Eval 论文中,研究人员展示了他们基于 LLM 的评估给予 GPT-3.5 摘要的分数高于人类编写的摘要,即使人类评审更喜欢人类编写的摘要。
另一项研究题为 “大型语言模型不是公平的评估者” [3],发现当被要求选择两个选项中哪个更好时,选项呈现的顺序存在显著偏差。例如,GPT-4 通常更喜欢第一个给出的选项,而 ChatGPT 更喜欢第二个选项。你可以通过将问题的顺序颠倒过来提问,看看 LLM 在回答中的一致性如何。他们随后开发了通过用不同顺序的选项多次运行 LLM 来缓解这种偏差的技术。
评估评估器
到头来,我们想知道 LLM 是否能像人类评估员一样好或相似地表现。我们仍然可以将其作为科学问题来处理:
-
设置评估标准。
-
让人类和 LLM 根据标准进行评估。
-
计算人类和 LLM 评估之间的相关性。
通过这种方式,我们可以了解 LLM 与人类评估员的相似程度。
确实,已经有几个类似的研究表明,对于某些任务,LLM 的表现比传统的评价指标要好得多。值得注意的是,我们不需要完美的相关性。如果我们在多个示例上进行评估,即使评估不完美,我们仍然可以得到新系统表现更好还是更差的某种概念。我们还可以使用 LLM 评估器来标记令人担忧的边缘案例,以便人类评估员进一步检查。
让我们来看看一些最近提出的依赖于 LLM 的核心指标和评估器。
G-Eval
G-Eval [4] 的工作方式是首先概述评估标准,然后简单地要求模型给出评分。它可以用于摘要和对话生成任务,例如。
G-Eval 包含以下组件:
-
提示。 定义了评估任务及其标准。
-
中间指令。 概述了评估的中间指令。它们实际上要求 LLM 生成这些步骤。
-
评分函数。 我们不会直接采纳 LLM 的评分,而是查看底层的标记概率以获得最终得分。因此,如果你要求在 1 到 5 之间评分,我们不会仅仅采用 LLM 给出的数字(例如“3”),而是会查看每个等级的概率并计算加权得分。这是因为研究人员发现通常一个数字主导了评估(例如主要输出 3),即使你要求 LLM 给出一个小数值,它仍然倾向于返回整数。
G-Eval 提供的提示用于在 1 到 5 的范围内计算连贯性。你可以在原始论文中找到更多示例。
研究发现 G-Eval 在显著程度上超越了传统的基于参考的指标,如 BLEU 和 ROUGE,这些指标与人工判断的相关性相对较低。表面上看,这似乎非常简单,因为我们只是要求 LLM 执行评估。我们也可以尝试将任务分解为更小的组件。
FactScore
FactScore(原子性得分中的事实精确度)[5] 是一个用于事实精确度的指标。这里的两个关键概念是将原子事实视为一个单元,并基于特定的知识来源来评估可信度。
在评估中,你将生成内容分解为小的“原子”事实(例如“他出生在纽约”),然后检查每个事实是否得到给定真实知识来源的支持。最终得分是通过将被支持的事实数量除以总事实数量来计算的。
在论文中,研究人员让 LLM 生成人物传记,然后使用关于他们的维基百科文章作为真实来源。LLM 进行与人类相同程序的错误率低于 2%。
FactScore 用于生成布里奇特·莫伊纳汉的传记。请参阅原始论文。
RAGAS
现在,让我们来看看一些检索增强生成(RAG)的指标。使用 RAG 时,你首先在外部知识库中检索相关的上下文,然后让 LLM 根据这些事实回答问题。
RAGAS(检索增强生成评估)[6]是一个用于评估 RAG 的新框架。它不是一个单一的指标,而是一系列指标的集合。论文中提出的三个指标是忠实度、回答相关性和上下文相关性。这些指标完美地展示了如何将评估分解为 LLMs 的简单任务。
忠实度衡量回答在给定上下文中的扎实程度。这与 FactScore 非常相似,你首先将生成的内容分解为一系列陈述,然后询问 LLM 这些陈述是否得到给定上下文的支持。得分是支持的陈述数除以所有陈述的总数。对于忠实度,研究人员发现与人工标注者有很高的相关性。
回答相关性试图捕捉回答是否针对实际问题的概念。你可以先要求 LLM 根据答案生成问题。对于每个生成的问题,你可以计算生成的问题与原始问题之间的相似度(通过创建嵌入并使用余弦相似度)。通过这样做n次并计算相似度得分的平均值,你将得到最终的回答相关性值。
上下文相关性指的是提供的上下文的相关性。也就是说,提供的上下文仅包含回答问题所需的信息。在理想情况下,我们只向 LLM 提供回答问题所需的正确信息。上下文相关性是通过要求 LLM 提取上下文中与答案相关的句子来计算的。然后只需将相关句子的数量除以总句子数以获得最终得分。
你可以在这里找到更多的指标和解释(以及开源的 GitHub 仓库)。
关键点是我们可以将评估转化为更小的子问题。我们不是询问整个文本是否得到上下文支持,而是询问一个小的具体事实是否得到上下文支持。我们不是直接给出答案是否相关的数字,而是要求 LLM 为给定的答案想出一个问题。
结论
评估 LLMs 是一个极具趣味的研究课题,随着越来越多的系统进入生产阶段并应用于更多安全关键的环境,这一话题将受到越来越多的关注。
我们还可以使用这些指标来监控 LLMs 在生产中的表现,以便发现输出质量是否开始下降。特别是在错误成本高的应用场景中,如医疗保健,开发保护措施和系统以捕捉和减少错误将是至关重要的。
尽管使用 LLM 作为评估器时肯定存在偏见和问题,我们仍应以研究问题的态度保持开放的心态。当然,人类仍将参与评估过程,但在某些场景中,自动化指标可以部分评估性能。
这些指标不必总是完美的;它们只需要足够好以正确指导产品的开发。
特别感谢 Daniel Raff 和 Yevhen Petyak 的反馈和建议。
最初发表于 Medplexity substack。
-
杨成润等人。大语言模型作为优化器。arXiv,2023 年 9 月 6 日。arXiv.org,
doi.org/10.48550/arXiv.2309.03409.
-
李成等人。大语言模型理解并能通过情感刺激进行增强。arXiv,2023 年 11 月 5 日。arXiv.org,
doi.org/10.48550/arXiv.2307.11760.
-
王培毅等人。大语言模型并非公平的评估器。arXiv,2023 年 8 月 30 日。arXiv.org,
doi.org/10.48550/arXiv.2305.17926.
-
刘洋等人。G-Eval: 使用 GPT-4 进行更好人类对齐的 NLG 评估。arXiv,2023 年 5 月 23 日。arXiv.org,
doi.org/10.48550/arXiv.2303.16634.
-
闵世温等人。FActScore: 长篇文本生成中细粒度原子级事实精确度评估。arXiv,2023 年 10 月 11 日。arXiv.org,
doi.org/10.48550/arXiv.2305.14251.
-
Shahul Es 等人。RAGAS: 自动化评估检索增强生成。1,arXiv,2023 年 9 月 26 日。arXiv.org,
doi.org/10.48550/arXiv.2309.15217.
使用机器学习创建自定义色彩调色板
原文:
towardsdatascience.com/using-machine-learning-to-create-custom-color-palettes-acb4eeaa06aa
深入了解 Streamlit 的本月应用
·发表在Towards Data Science ·12 分钟阅读·2023 年 2 月 16 日
--
照片由Ricardo Gomez Angel提供,Unsplash
介绍
我们都喜欢接触新的数据集、探索它,并从中学习。但原始数字本身并不是优秀的讲述者。我们原始的大脑天生对线条、形状和颜色敏感。这就是为什么数字需要被可视化才能讲述一个好的故事。
数据可视化的色彩调色板可能会决定你的数据故事的成败。虽然为你的数据可视化找到完美的颜色组合可能是一个严格且耗时的任务,但你不必完全独自完成。你可以从历史上最伟大的画家和艺术家那里获得灵感。
作者提供的图片
从零开始创建色彩调色板通常是可视化工程师和设计师的专业领域,他们使用色彩理论来组合适合各种可视化目的的和谐调色板(例如,单色调色板用于顺序或渐变图,或互补调色板用于分类图表)。像[matplotlib](https://matplotlib.org/)
或[plotly](https://plotly.com)
这样的可视化包已经提供了令人惊叹的调色板集合。
虽然数据分析师和数据科学家可以使用预制的颜色调色板,但在某些情况下,我们可能想要自己创建自定义调色板。例如,当你想制作一个与公司颜色主题一致的可视化时。例如,你可能想创建一个与喜欢的公司的 logo 主题匹配的柱状图。有一个可以自动为你完成这项工作的应用程序会非常好,对吧?
图片由作者提供
图片由作者提供
我花费了太多时间试图为我的可视化找出完美的颜色组合,因此我决定让这个过程对我自己来说更简单一些。
我使用Streamlit 🎈构建了一个 Web 应用程序,可以从任何输入图像中推断颜色调色板:无论是画作、电影海报、摇滚专辑封面,还是圣诞家庭照片,应有尽有!虽然应用程序可能不会立即提供完美的颜色调色板,但至少会给你一个很好的起点。
那么,让我们看看我是如何构建这个应用程序的。
构建颜色调色板推断应用程序
图片由作者提供
为了创建一个颜色调色板推断工具,我们需要几个组件。
在这篇文章中,我将教你:
-
如何构建一个图像加载组件
-
如何构建一个带滑块的图像增强组件
-
如何对像素进行聚类并使用组平均值来制作调色板
-
如何使用颜色选择器小部件来显示和修改调色板
想要亲自试用这个应用程序吗?你可以在这里查看它,并在这里查看源代码。
1. 如何构建一个图像加载组件
我们的应用程序需要的第一个元素是输入我们想要转换为颜色调色板的图像的方法。我们可以通过三种方式做到这一点:
-
从现有的图像或艺术品库中加载。
-
使用
st.file_uploader()
将新的图像文件上传到应用程序中。 -
从 URL 下载一张新图像到应用程序中。
我们可以使用st.tabs()
构建三个不同的选项卡,让我们可以无缝切换这三种输入模式。
使用st.tabs()
构建的图像输入源。
由于 Streamlit 运行应用程序的方式是从上到下,后来的输入模式会获得更高的优先级并覆盖之前加载器的输出。例如,由于你将输入模式按 Gallery → File Uploader → Image URL 的顺序排列,如果你将图像保存到变量img
中,则画廊加载的任何内容都会被文件上传器和 URL 下载器的输出覆盖。
当用户使用层级较低的加载器加载图像时,你可以为用户添加弹出警告,如果层级较高的加载器已经加载了图像。
这就是代码的样子:
# define three tabs for the three loading methods
gallery_tab, upload_tab, url_tab = st.tabs(["Gallery", "Upload", "Image URL"])
with gallery_tab:
...
# raise a warning if file uploader or URL downloader have already loaded an image
if st.session_state.get("file_uploader") is not None:
st.warning("To use the Gallery, remove the uploaded image first.")
if st.session_state.get("image_url") not in ["", None]:
st.warning("To use the Gallery, remove the image URL first.")
img = ...
with upload_tab:
img = ...
# raise a warning if the URL downloader has already loaded an image
if st.session_state.get("image_url") not in ["", None]:
st.warning("To use the file uploader, remove the image URL first.")
with url_tab:
img = ...
因此,如果你尝试从画廊加载一个预先存在的图像,但在 URL 下载器中已经存在一个链接,你将需要先删除那个链接。这可能不是最优雅的解决方案,但它有效!
现在,继续实现每一部分的加载方法。
画廊视图
对于画廊视图,我们可以简单地将一些图像保存在公共存储库中,并直接在应用程序中加载这些图像(我在这里使用 GitHub,但你也可以使用 AWS S3 或 Google Cloud Storage)。st.selectbox
包含我保存的艺术品名称,因此用户可以通过从下拉菜单中选择来加载它们。实现代码如下。
import streamlit as st
from PIL import Image
with gallery_tab:
options = list(gallery_dict.keys())
file_name = st.selectbox("Select Art",
options=options,
index=options.index("Mona Lisa (Leonardo da Vinci)")
)
img_file = gallery_dict[file_name]
if st.session_state.get("file_uploader") is not None:
st.warning("To use the Gallery, remove the uploaded image first.")
if st.session_state.get("image_url") not in ["", None]:
st.warning("To use the Gallery, remove the image URL first.")
img = Image.open(img_file)
gallery_dict
是一个字典,包含文件名和图像文件路径作为键值对,PIL.Image.open()
用于加载这些文件。结果保存在一个名为 img
的变量中。
文件上传器
实现文件上传器非常简单,因为已经有一个 Streamlit 小部件。它叫做(你能猜到吗?)st.file_uploader()
!
实现代码如下:
with upload_tab:
img_file = st.file_uploader("Upload Art", key="file_uploader")
if file is not None:
try:
img = Image.open(img_file)
except:
st.error("The file you uploaded does not seem to be a valid image. Try uploading a png or jpg file.")
if st.session_state.get("image_url") not in ["", None]:
st.warning("To use the file uploader, remove the image URL first.")
这个小部件允许你上传一个文件,然后你可以将其传递给 PIL.Image.open()
来加载。如果文件实际上不是图像文件,或者格式与 PIL.Image
期望的不一致,这一步可能会失败。为了防止这个问题发生,我们可以将加载部分放入一个 try/except
块中。
⚠️ 我在这里使用这个代码块作为万用块,以避免在加载文件时出现各种意外错误。然而,通常不推荐在没有实际指定要绕过的Exception
类型的情况下使用except
,尤其是因为你可能不知道的致命错误会默默地通过这个代码块,使得调试代码变得困难。
URL 下载器
说实话!虽然上传文件是用户加载他们希望推断颜色调色板的自定义图像的好方法,但这不是最简单的方法。用户需要找到图像(如果他们还没有的话),在本地下载,然后上传到应用程序中。听起来简单,但在实际操作中却很麻烦。
为了消除这个障碍,我们可以在应用程序中添加一个 URL 下载器,这样用户可以简单地复制图像链接(例如,从 Google 搜索结果中),然后直接粘贴到应用程序中。为了实现这一点,我们需要 requests
模块,它获取 URL 的内容,以及 io.BytesIO
函数,它使得内容能够被 PIL.Image.open()
理解。实现起来简单且与文件上传器的实现类似。
import requests
from io import BytesIO
with url_tab:
url_text = st.empty()
url = url_text.text_input("Image URL", key="image_url")
if url != "":
try:
response = requests.get(url)
img = Image.open(BytesIO(response.content))
except:
st.error("The URL does not seem to be valid.")
2. 如何使用滑块构建图像增强组件
现在我们已经上传了图像,我们准备推断颜色调色板,对吗?其实不完全是。
你加载到应用程序中的原始图像可能未经过颜色推断的优化。颜色可能过于暗淡,画布上的亮度或对比度可能不足。这就是为什么你需要首先进行一些图像调整。
为了将增强功能应用于图像,我们可以使用 PIL.ImageEnhance
。API 非常简单。例如,如果你想将图像(加载到 img
中)的颜色增强因子设置为 2.5,你可以运行:
img = ImageEnhance.Color(img)
img = img.enhance(2.5)
只需将 Color
替换为 Shapness
、Contrast
或 Brightness
,即可分别调整这些图像属性。我们可以创建四个不同的滑块,将每个属性的值分配给这些滑块,然后编写单独的代码块,依次将增强应用于图像。但我们是优秀的程序员,我们尽量保持代码优雅和 DRY(Don’t Repeat Yourself)。所以让我们以更实用的方式来处理这个问题。
我们可以定义一个字典,其中包含我们希望应用的所有增强功能作为键,值表示我们希望在应用程序中分配给这些增强功能的滑块范围和步长。
enhancement_range = {
# "enhancement_type": [min, max, step_size]
"Color": [0., 5., 0.2],
"Sharpness": [0., 3., 0.2],
"Contrast": [0.5, 1.5, 0.1],
"Brightness": [0.5, 1.5, 0.1]
}
enhancement_categories = enhancement_range.keys()
# put adjustment sliders inside an expander
enh_expander = st.sidebar.expander("Image Enhancements", expanded=False)
# create a reset button that resets all enhancements to default value (1.0)
with enh_expander:
if st.button("reset"):
for cat in enhancement_categories:
if f"{cat}_enhancement" in st.session_state:
st.session_state[f"{cat}_enhancement"] = 1.0
# create sliders for each enhancement category using the dictionary values (min, max, step_size)
enhancement_factor_dict = {
cat: enh_expander.slider(f"{cat} Enhancement",
value=1.,
min_value=enhancement_range[cat][0],
max_value=enhancement_range[cat][1],
step=enhancement_range[cat][2],
key=f"{cat}_enhancement")
for cat in enhancement_categories
}
使用这种方法,如果我们想更改增强类型或值范围,我们只需更改原始字典。
现在我们已经在侧边栏上放置了滑块,剩下的就是使用 ImageEnhance
将这些值应用于图像。
from PIL import ImageEnhance
for cat in enhancement_categories:
# apply the enhancement class to the image
# e.g. for cat='Color' this would be the same as
# img = ImageEnhance.Color(img)
img = getattr(ImageEnhance, cat)(img)
# apply the enhencement value from the corresponding st.slider
img = img.enhance(enhancement_factor_dict[cat])
显示图像
现在我们已经加载了图像,剩下的就是使用 st.image()
将其显示在应用程序中:
with st.expander("🖼 Artwork", expanded=True):
st.image(img, use_column_width=True)
和 瞧!
3. 如何聚类像素并使用组平均值制作调色板
最后,进入有趣的部分!使用机器学习推断色彩调色板。
这里的想法非常简单。一幅图像是像素的集合,每个像素都有三个值:R、G、B。这些值基本上告诉你每个像素包含多少红色、绿色和蓝色。为了推断色彩调色板,像素在画布上的实际位置基本上是无关紧要的。重要的是它在 (R, G, B) 坐标空间中的位置。因此,在继续之前,让我们将图像分解到这个新的坐标系统中,并摆脱像素在图像上的实际位置。
r, g, b = np.array(img).reshape(-1, 3).T
df_rgb = pd.DataFrame({"R": r, "G": g, "B": b}).sample(n=sample_size)
如果我们查看 RGB 坐标中的像素,我们可以将相互接近的像素分组,并使用像素的平均值来表示每个组——实际上是我们调色板中的一种颜色。例如,如果我们想从蒙娜丽莎的增强图像中构建一个 5 色调色板,我们需要首先查看 RGB 空间中像素的分布(这里通过 PCA 算法投影到二维):
原始像素颜色在 (R, G, B) 空间中的二维 PCA 分解
然后我们选择 5 个不同的簇,并将每个簇的平均值分配给调色板中的一个槽。
使用 K-means 算法将原始采样像素分配到 5 个不同的簇
显然,我们不需要手动完成这个工作。有一种方便的机器学习算法叫做 K-means 聚类,它可以一举完成这个任务。我们唯一需要提供的参数是聚类的数量,也就是我们的调色板大小。这是使用 sklearn.cluster.KMeans
实现的样子。
from sklearn.cluster import KMeans
palette_size = st.sidebar.number_input("palette size",
min_value=1,
max_value=20,
value=5,
step=1,
help="Number of colors to infer from the image.")
model = KMeans(n_clusters=palette_size)
clusters = model.fit_predict(df_rgb)
palette = model.cluster_centers_.astype(int).tolist()
就这样!我们现在有了我们的调色板,只需要将其返回给应用中的用户。
ℹ️ 尽管我们在这里使用了流行的 R、G、B 像素分解,但值得注意的是,这不是分解颜色的唯一方法。我们本可以在 HSV(色调、饱和度、明度)空间中进行像素聚类,这将以不同的方式分布像素,从而得到不同的调色板。
4. 如何使用色彩选择器控件
不知为何,我感觉我决定构建这个应用的潜意识原因是为了使用 Streamlit 的神奇 st.color_picker()
控件!这个控件的好处在于,你可以同时展示从图像中推断出的调色板颜色,并且如果愿意的话,还可以更改它们。这非常完美,因为正如我之前提到的,从应用中得到的颜色可能并不是针对你特定用例的 100% 完美,而只是一个很好的起点。所以你可能需要稍微调整一下,在实际用于你的可视化之前做一些修饰。
我们不希望调色板在应用程序中占据页面的一半,因此让我们把它们分别放在不同的列中。
columns = st.columns(palette_size)
for i, col in enumerate(columns):
with col:
st.session_state[f"col_{i}"]= \\
st.color_picker(label=str(i),
value=palette[i],
key=f"pal_{i}")
value=palette[i],
key=f"pal_{i}")
美极了!
最后一件要做的事情是为用户提供一个可以在日常工作中使用的实用产品。如果数据分析师或数据科学家使用这个应用来推断一个色彩调色板,他们很可能会在 matplotlib
或 plotly
中使用这个调色板来应用于他们的可视化中。为什么不提供一个代码片段来处理这个问题,这样他们就不必将每一个十六进制颜色代码从调色板中复制粘贴到他们的编码环境中。
图片由作者提供
感谢 Streamlit 的 st.code()
控件,可以一键复制整个代码块!
总结
就这样!现在你有了一个功能齐全的调色板推断应用,可以帮助你弄清楚列奥纳多·达·芬奇是用什么颜色让蒙娜丽莎的脸上带上微笑的!
在这篇文章中,我们介绍了很多 Streamlit 的功能,特别是像 st.image
、st.tabs
、st.file_uploader
和 st.color_picker
这样的控件。不过为了简洁起见,我们没有涵盖所有内容,所以请随时跳转到 源代码 仓库,自己查看所有细节。你可以看到,我在多个应用运行间保持一致性时,依赖了大量的 st.session_state
。
我们学会了
-
构建一个图像加载组件(加载现有图像、上传或 URL 下载)。
-
构建一个带滑块的图像增强组件。
-
使用 K-means 将像素聚类,并使用组平均值来构建调色板。
-
使用颜色选择器小部件将调色板返回给最终用户。
我很高兴能带你走过所有这些步骤,希望你喜欢阅读/浏览这篇文章,并希望你学到了一些东西。
你可以亲自查看这个应用,这里。我很想听听你的想法、问题、评论和反馈!通过LinkedIn或我的网站与我联系。
这篇文章与 Ksenia Anske 及团队 Streamlit 在 Snowflake 的合作下编写。请查看这篇文章在 Streamlit 博客上发布的最终版本 这里。
除非另有说明,所有图片均由作者提供。
使用 MLflow 和 ATOM 跟踪所有机器学习实验,而无需额外的代码
开始仅通过更改一个参数来存储模型、参数、管道、数据和图表
·
关注 发表在 Towards Data Science ·6 分钟阅读·2023 年 3 月 13 日
--
照片由 Hans Reniers 提供,来源于 Unsplash
介绍
MLflow Tracking 组件是一个 API 和 UI,用于记录参数、代码版本、指标和输出文件,在运行机器学习实验时用于后续结果的可视化。
在本故事中,我们将解释如何使用ATOM库轻松跟踪你的模型、参数、管道、数据和图表。ATOM 是一个开源 Python 包,旨在帮助数据科学家探索机器学习管道。
注意:本故事侧重于使用 ATOM 的实验跟踪功能。讲解库的基础知识不在本故事范围内。如果你想要一个温和的库介绍,请阅读这篇其他故事。
功能概述
通过在atom的构造函数中为[experiment](https://tvdboom.github.io/ATOM/v5.1/API/ATOM/atomclassifier/#atomclassifier-experiment)
参数分配名称来开始跟踪实验。每个模型都使用单独的运行进行跟踪。当未配置后端时,数据会存储在./mlruns
本地。要配置后端,请在初始化atom之前,在笔记本或 IDE 中使用mlflow.set_tracking_uri。这不会影响当前活动的运行(如果存在),但会对后续运行生效。在终端运行mlflow ui
以打开 MLflow 的跟踪 UI,并在localhost:5000
查看。
注意: 在Databricks上使用 ATOM 时,实验名称应包含存储的完整路径,例如/Users/username@domain.com/experiment_name
。
跟踪以下元素:
标签
运行将自动标记为模型的全名,分支
模型训练来源及拟合模型所需时间。
通过ht_params
参数添加其他自定义标签,例如atom.run(["LR", "RF"], ht_params={"tags": {"tag1": 1}})
。
参数
初始化时使用的所有参数都会被跟踪。传递给 fit 方法的额外参数不会被跟踪。
模型
模型的估算器被存储为工件。可以通过atom的[log_model](https://tvdboom.github.io/ATOM/v5.1/API/ATOM/atomclassifier/#atomclassifier-log_model)
属性来关闭此选项,例如atom.log_model = False
。
超参数调整
如果进行超参数调整,每个试验都会作为模型主运行中的嵌套运行进行跟踪。可以通过atom的[log_ht](https://tvdboom.github.io/ATOM/v5.1/API/ATOM/atomclassifier/#atomclassifier-log_ht)
属性来关闭此选项,例如atom.log_ht = False
。
指标
所有指标结果都会被追踪,不仅是在训练过程中,还包括在稍后调用 evaluate 方法时。在训练验证期间计算的指标也会被存储。
数据集
用于拟合和评估模型的训练和测试集可以作为.csv
文件存储到运行的工件中。可以使用atom的[log_data](https://tvdboom.github.io/ATOM/v5.1/API/ATOM/atomclassifier/#atomclassifier-log_data)
属性开启此选项,例如atom.log_data = True
。
管道
模型的管道(由 export_pipeline 方法返回)可以作为一个工件进行存储。可以使用atom的[log_pipeline](https://tvdboom.github.io/ATOM/v5.1/API/ATOM/atomclassifier/#atomclassifier-log_pipeline)
属性开启此选项,例如atom.log_pipeline = True
。
图表
默认情况下,图表作为.html
工件存储在所有与图表中显示的模型对应的运行中。如果指定了filename
参数,则会以该名称存储,否则使用方法的名称。可以使用atom的[log_plots](https://tvdboom.github.io/ATOM/v5.1/API/ATOM/atomclassifier/#atomclassifier-log_plots)
属性关闭此选项,例如atom.log_plots = False
。
示例
了解功能的最简单方法是通过示例。我们按照通常的方式初始化atom,并指定experiment
参数。这里提供的名称是使用的 mlflow 实验的名称。如果没有该名称的实验,则会创建一个新的。
from atom import ATOMClassifier
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
atom = ATOMClassifier(X, y, experiment="breast_cancer")
让我们指定一下我们也想记录数据和管道。这两个选项默认是关闭的。
atom.log_data = True
atom.log_pipeline = True
按照通常的方式训练模型。所有选择的指标都会被记录。
atom.run(models=["LR", "RF"], metric=["f1", "accuracy", "precision"])
现在在终端中运行mlflow ui
打开 UI。每个模型都有其单独的运行。
每次运行,atom会存储:
- 模型的参数:
- 训练集和测试集上的指标分数:
- 预定义标签和自定义标签(如果指定):
- 工件,包括估计器和管道的 pickle 文件,以及训练和测试集的 csv 文件。
额外的指标(使用 evaluate 方法计算)和显示模型的图表也会被添加到运行中。
atom.evaluate()
atom.plot_gains()
超参数调优
在运行 超参数调优 时,研究的每个试验都会作为嵌套运行添加到主运行中。
atom.run(models="LGB", n_trials=10)
嵌套运行的名称为 <model_name> — <trial_number>
。它们的指标分数不是在训练集或测试集上,而是在验证集上,验证集是用于验证该特定试验的训练集子集(其余部分用于拟合估计器)。测试集在超参数调整期间不会使用,以避免数据泄露。
注意: 数据和管道不会存储在嵌套运行中。
训练中的验证
一些 模型 允许 训练中的验证。这意味着在每轮训练后(对于线性模型是一轮迭代,对于提升树模型是增加的一棵树),估计器在训练集和测试集上使用 第一个指标 进行评估。验证分数存储在 evals
指标中,也在 mlflow 中跟踪。
注意: evals
指标在嵌套运行中不会计算。
DAGsHub 集成
ATOM 内置了与 DAGsHub 的集成,这是一种基于开源工具的数据科学优化平台,面向开源社区。要将你的 mlflow 实验存储在 DAGsHub 仓库中,请在 experiment
参数中输入 dagshub:<experiment_name>
(而不仅仅是实验名称)。如果仓库不存在,将创建一个新的公共仓库。
一个简约的示例如下:
from atom import ATOMClassifier
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
atom = ATOMClassifier(X, y, experiment="dagshub:breast_cancer")
atom.run(models=["LR", "RF"])
注意: 如果在运行 atom 的构造函数时已登录到你的 DAGsHub 账户,浏览器页面会自动打开以授予访问权限。如果没有,请阅读 这里 了解如何设置你的 DAGsHub 凭据。
结论
我们展示了使用 ATOM 库跟踪机器学习实验的简便性。只需最少的代码更改,现在可以存储每个训练模型的估计器、参数、管道、数据和图表。
有关 ATOM 的更多信息,请查看软件包的 文档。对于错误或功能请求,请随时在 GitHub 上提出问题或发邮件给我。
相关故事:
-
towardsdatascience.com/atom-a-python-package-for-fast-exploration-of-machine-learning-pipelines-653956a16e7b
-
towardsdatascience.com/how-to-test-multiple-machine-learning-pipelines-with-just-a-few-lines-of-python-1a16cb4686d
-
towardsdatascience.com/from-raw-data-to-web-app-deployment-with-atom-and-streamlit-d8df381aa19f
-
towardsdatascience.com/exploration-of-deep-learning-pipelines-made-easy-e1cf649892bc
-
towardsdatascience.com/deep-feature-synthesis-vs-genetic-feature-generation-6ba4d05a6ca5
-
towardsdatascience.com/from-raw-text-to-model-prediction-in-under-30-lines-of-python-32133d853407
-
towardsdatascience.com/how-to-make-40-interactive-plots-to-analyze-your-machine-learning-pipeline-ee718afd7bc2
-
towardsdatascience.com/machine-learning-on-multioutput-datasets-a-quick-guide-ebeba81b97d1
参考资料:
- 所有图表和图片(除特色图片外)均由作者创建。
使用多任务和集成学习预测阿尔茨海默病的认知功能
个人故事与数据科学
认识到我在机器学习领域的影响力,源于认知科学和我的第一篇科学论文的发表
·发布于 Towards Data Science ·11 分钟阅读·2023 年 6 月 9 日
--
图片由 Alina Grubnyak 拍摄,来源于 Unsplash
在我之前的一篇 文章 中,我详细描述了我从认知科学转向机器学习的经历以及对我的困惑症。 在那篇文章中,我提到:
“一个想法开始慢慢展开——也许,我的背景 [在认知科学] 提供了比我最初预期的更为坚实的基础。”
在这篇文章中,我将分享一个具体的例子,展示我的认知科学背景如何使我能够 1) 为在神经科学领域对我个人具有重要意义的疾病开发创新的建模方法,以及 2) 构建通常在传统讨论中被忽视的独特联系。
通过这次经历,我意识到深度学习领域虽然潜力巨大,但仍处于初期阶段,这提醒了我它对传统和非传统背景的个人都提供了包容性的机会。
大脑网络实验室
完成本科学位后,我心中萦绕着一种感觉,即虽然有不错的理论基础,但缺乏有效应用这些工具的实践理解。我设想了一个理想的场景,在神经科学或心理健康领域应用这些工具。
因此,当有机会申请Wicklow 医学中的人工智能研究计划时,我感到非常激动,这意味着我将致力于在我的研究生项目中进行一个以“利用人工智能模型推动医疗研究,包括肿瘤学、心脏病学和神经学等领域”的实践。
图 I. 由 cottonbro studio 在 Pexels 上拍摄(来源)
在这个项目中,我被接受了🎉,最终在 UCSF 的脑网络实验室工作。该实验室的重点是:
“通过将计算工具应用于神经影像数据,理解健康和疾病脑部的机制。”
我迫不及待地等待着有关实践的更多细节;兴奋之情在我内心涌动,期待开始弥合理论与实践之间的差距。最后,当我收到任务时,期待已久的时刻终于到来了。
预测阿尔茨海默病患者的认知评分
任务: 预测阿尔茨海默病患者的认知评分
问题: 我完全没有经验于计算机视觉。
我的第一次印象是,“我怎么可能对此做出贡献?”自然,恐惧和冒名顶替综合症的声音想要显现出来。此外,我在深度学习方面的经验有限——几乎可以忽略不计——而且,职位由于任务复杂性有被取消的风险。
然而,我确实拥有的是全面理解阿尔茨海默病的动力,对受影响者的真诚同情,以及利用我们现有的计算工具做出贡献的热情。此外,理论与实践之间的脱节激发了我做出贡献和成长的决心。
生活方式与阿尔茨海默病的相关性
我第二次印象是,仅仅依赖 MRI 数据来预测认知评分似乎有些奇怪,因为已知的人口统计学、遗传学和生活方式与阿尔茨海默病的相关性。
图 II. 作者拍摄。引用见文章末尾。
在研究中,“肠道微生物群、衰老、现代生活方式与阿尔茨海默病之间的联系”,作者强调“已经在阿尔茨海默病患者中报道了肠道微生物群的显著……变化……肠道微生物群对负面的外部生活方式因素,如饮食、睡眠剥夺、昼夜节律干扰、慢性噪音和久坐行为非常敏感,这些也被认为是发病的重要风险因素。
然而,我知道我需要集中精力完成任务;尽管如此,好奇心仍然驱使我去探询是否有可能获得人口统计学和临床数据的访问权限,以防机会出现…… 搓手
我们稍后会再讨论这个问题。
测量认知功能
ADAS Cog-11 分数
与此同时,我继续在仅使用 MRI 数据的基础上进行建模方法的头脑风暴。然而,你可能会想,我们如何定义‘认知功能’?换句话说,我们的预测目标是什么?我们的目标标签是什么?
我们正在使用ADAS-Cog-11来测量认知功能——这是一个用于评估阿尔茨海默病患者记忆、语言和实践能力退化的指标。根据维基百科,"它是最广泛使用的认知量表之一……并被认为是评估抗痴呆治疗的‘金标准’"。
ADAS-Cog-11 分数来源于以下十一项认知任务:
图 III. 由作者基于这里的信息拍摄的照片
你可以在这里找到任务的详细总结。
简单的卷积神经网络
为了开始实验,我们在 MRI 数据上训练了一个基准卷积神经网络(CNN)模型,以预测认知评分。
这个简单的 CNN 包含以下层:第一卷积层、池化层、第二卷积层、池化层、一个 2 层全连接神经网络,以及一个回归层。
结果并不显著,交叉验证的 R2 范围从 0.33 到 0.52,测试 R2 为 0.47。尽管如此,考虑到对这个简单模型的期望较低,我们还是很高兴建立了一个起始基准,今后可以在此基础上进行改进。
多任务学习与捕捉大脑结构上下文
下一步是研究痴呆症的结构性预测因素。这导致了一些文献,突出了灰质和白质体积与痴呆症严重程度之间的独立关系。
根据Stout et al,
“定量磁共振方法提供了有力证据,表明皮质灰质体积(可能反映萎缩)和异常白质体积与可能的阿尔茨海默病痴呆的严重程度独立相关:较低的灰质和较高的异常白质体积与更严重的痴呆相关。”
从中,一个理论浮现出来:如果模型能够捕捉灰质和白质体积的信息,它应该会提升预测能力。
那我们该如何做到这一点呢?
简短的回答是:✨ 多任务学习 ✨
“多任务学习是机器学习的一个子领域,其中多个学习任务同时解决,同时利用任务之间的共性和差异 […] 利用相关任务训练信号中包含的领域信息 […] 每个任务学到的东西可以帮助其他任务学得更好”
— 多任务学习,2021 年 7 月 6 日。在 维基百科
直觉:如果我们的模型预测认知评分,同时学习将输入 MRI 扫描分割为白质、灰质和脑脊液,这些相关任务将利用共享的领域信息,并提升每个单独任务的表现。
U-Net 架构
对于这个模型,我们使用了 U-Net 架构:
“一个依赖于强大的数据增强使用的架构,[…] 包括一个收缩路径来捕捉上下文和一个对称的扩展路径以实现精确定位 […] 这样的网络可以从非常少的图像中端到端训练,并超越了之前最好的方法(一个滑动窗口卷积网络)” — Ronneberger et al., 2015
图 IV. U-Net 架构。许可:Apache 2.0 (来源)
这是一个具有吸引力的架构来进行实验,因为我们能够使用非常少的图像来实现比以前的方法更好的性能,只要我们应用数据增强技术。
医学成像中的数据增强技术是什么?
数据增强是对输入图像应用随机化更改(例如:平移、旋转、翻转、拉伸、剪切等)的过程,以增加变异性。这个过程通过对输入数据进行小的位移来使我们的模型更好地泛化,只要修改后的图像仍然在可能的输入范围内。
这些技术还作为解决标记医学成像稀缺挑战的宝贵解决方案。获取足够大的医学成像数据集是一个突出的问题,主要有两个原因:1)医学扫描的人工标注极其耗时,2)临床数据的共享受到越来越严格的患者隐私法的限制。
图 V. 数据增强。照片由作者提供。
不同器官和模态的性能
实验各种数据增强技术以找出最适合特定任务的技术是常见做法。然而,特别是在医疗影像中,指导探索过程的一种方式是了解不同器官/结构、模态和任务组合的合适——也即最有效的——增强技术。在这里,“合适”指的是确保增强的数据包含输入空间内有效的示例。
例如,弹性变形通常适用于具有固有弹性或可变形性的器官。脑组织就是一个很好的例子,因为大脑具有显著的神经可塑性,能够在经历、学习和环境变化时进行结构和功能上的变化。
圣地亚哥·拉蒙·卡哈尔,这位被誉为现代神经科学之父的人曾宣称:
“任何人都可以,如果他愿意的话,成为自己大脑的雕塑家”
— 圣地亚哥·拉蒙·卡哈尔
然而,骨骼的可变形性有限,血管是刚性结构,因此应用弹性变形可能无法准确表示现实的变化或保持解剖结构的完整性。
此外,缩放(放大图像以强调特定区域)通常适合 X 光图像,因为它们通常涵盖了更广泛的视野。然而,对于已经具有较窄视野并集中于特定兴趣区域的 MRI,缩放可能会无意中排除重要的背景信息,从而使其作为 MRI 数据增强技术不太适用。
关于医疗影像数据增强技术的详细文献综述可以在这里找到。作者强调了这一点:
“根据输入的性质和视觉任务的不同,不同的数据增强策略可能表现不同。因此,可以想象医疗影像需要特定的增强策略,以生成合理的数据样本并有效地对深度神经网络进行正则化。”
多任务学习架构可视化
我在下面提供了多任务模型的可视化表示。在 U-Net 架构的最底部,我们加入了一个回归块。在这个块中,所有像素被展平,然后通过一个线性层,最终输出表示认知分数。
图 VI. 多任务,U-Net 架构。照片由作者提供。
U-Net 和多任务学习性能
总结一下,最初我们使用了一个简单的 CNN 基线模型,得到的交叉验证 R2 值范围从 0.33 到 0.52,测试 R2 值为 0.47。
然而,实施多任务模型后,我们见证了显著的改进:
-
交叉验证 R2:提升至 0.41 到 0.69 的范围
-
测试 R2:提升至 0.57
还值得一提的是,我们为分割任务开发了一个基线模型,使用了 U-Net 架构,准确率为 93.7%。
值得注意的是,分割任务的性能也得到了提升,使其准确率从 93.7% 提升到 97.27% 🎉
图 VII. 分割和回归结果。作者拍摄
多任务模型似乎在脑部分割的当前最先进水平中表现更佳,如下所示:
图 VIII. 作者拍摄。文章来源可以在 这里 查找。
那么人口统计学和遗传特征呢?
那么我之前承诺回到的人口统计学和遗传特征呢?
嗯,我们获得了数据!干杯
第一步是仅在表格特征上训练另一个基线模型。我们最终选择了具有泊松损失函数的基于直方图的梯度提升回归模型(HGB Regressor),结果出乎意料地不错。
-
交叉验证 R2:范围从 0.56 到 0.63
-
测试 R2:0.51
将表格数据整合到一个集成多任务模型中
在最终实验中,我们集思广益,寻找有效地将这些特征整合到我们的模型中的方法。
表格数据和成像数据的整合是一项挑战,因为成像数据(每个体素都被视为一个输入)与表格数据集(每个特征表示为一个单一值)之间的输入数量差异很大。将所有特征等同对待可能会降低人口统计学和遗传风险因素在整体模型中的重要性。
为了解决这个问题,我们为表格数据开发了一个单独的模型,使用 HGB 回归器。然后,我们应用了加权平均集成方法,将 HGB 回归器和多任务模型的预测结果结合起来。对每个模型分配的权重基于其性能和可靠性,给更准确或更有信心的模型分配更高的权重。这种集成技术通过分配适当的权重有效地优化了每个模型的贡献。
下面,你可以看到这种集成方法的可视化图。
图 IX. 多任务学习和集成方法。作者拍摄。
最终模型性能
那么这个集成的多任务方法与之前的实验相比如何呢?
鼓点
回归任务性能:
-
交叉验证 R2:0.73–0.78
-
测试 R2:0.67
分割任务性能:
- 准确率:98.12%
与多任务学习中观察到的显著性能提升类似,通过集成方法纳入人口统计特征和遗传风险因素,不仅显著提升了回归任务的性能,还进一步增强了分割任务的表现。这清晰地展示了利用多个数据源并利用其协同潜力的力量。
结束语
有机会深入研究神经科学与机器学习的交汇点,特别是在阿尔茨海默病方面,并意识到我的背景使我能够将不同领域的概念联系起来,这种经历是颠覆性的。自从体验到将当前神经科学研究与机器学习概念对接所带来的显著改进以来,我对整合多样数据源和以领域驱动方式应用模型架构的力量有了更深刻的认识。
我希望这项研究能够激励那些:
-
从认知科学转向机器学习
-
对获取和应用神经系统疾病技术充满热情
-
对医学领域的计算机视觉感兴趣
如有任何问题,请随时联系,我希望这对你来说与对我一样令人激动!
图 II 使用的文章:
使用 OpenAI 和 Python 提升你的简历:一步一步的指南
这是一个成功故事的经历,只需几个小时的工作
·
关注 发布于 Towards Data Science ·10 min read·2023 年 2 月 24 日
--
图片由作者使用 Midjourney 制作
几天前,我下班回家,开始玩视频游戏。
玩视频游戏对我来说是一个很好的消遣,因为我非常糟糕,我根本不专注于游戏;我只是开始思考。我让我的思绪随意飘荡。通常,我会想到我所喜欢的东西,比如人工智能。
所以,当我在 PS4 上进行任务时(当然,我又输了),我在想许多 NLP 任务,比如标记分类、下一个词预测(文本生成)、情感分析、文本分类以及其他许多任务,现在基本上可以在几秒钟内解决。
比如说,ChatGPT 是一个非常出色的聊天机器人,经过训练可以以对话的方式回答任何领域的问题。它可以总结文本、回答问题、写代码、做模仿、写歌曲、写食谱……
当前的趋势是构建更小、更具可扩展性且开源的代码,这些代码可以用来替代 ChatGPT 并且拥有免费的 API,所以重点并不仅仅是ChatGPT。重点在于这些所谓的“大型语言模型”,它们在大量文本上进行训练,凭借如此强大的计算能力,它们可能超越了你在笔记本电脑上可以建立的所有小型方法。
我的意思是,如果你需要总结一段文本(那不是医学上未知的、超级困难的文本)或理解一条评论是好是坏,没有理由自己开发模型,因为 OpenAI 的免费模型可能在几秒钟内就能做到这一点。
所以我在想,“什么任务需要大量工作,但现在我们可以在几秒钟内完成?”例如,写简历。
当然,没有人可以从零开始为我们写简历,因为他们不了解我们的职业生涯,但在线上有一些服务,你可以利用 AI 来改进你的简历。这些服务通常不是免费的,我认为现在它们已经过时了,因为 OpenAI 确实是免费的,而且可能比所有其他模型做得更好,除非那些模型是 Meta、Google 或微软的模型。😅
所以我决定使用 Streamlit 构建一个网页应用,大家可以上传自己的简历,人工智能(正如我们将看到的,更具体地说是 OpenAI)将在几秒钟内改进他们的简历。
在我看来,这款应用的工作方式如下:
作者提供的图片
很简单,对吧?
几个小时的工作,这就是结果:
现在,这看起来像是一个已经运营了几个月的初创公司,但事实上,现在只需要我花费几个小时,一些基本的 Python 技巧,加上 Open AI 的 GPT-3 魔力,就能完成这些!
我是一个极客,如果有人在没有代码的情况下向我展示这些,我是不会相信的,所以让我在下一章中详细解释一下我做了什么!
0. 一些考虑事项
在我们开始之前,有一些事情需要考虑,以使讨论更加全面。
一个公平的假设:
首先,HR 和招聘人员的世界因多种原因而充满挑战。例如,求职市场是动态的,并且始终在不断发展。AI 使用训练模型。这意味着一旦模型被训练,其获得的信息和性能与可能已经过时的数据集有关。
因此,我会说这个模型更多地作为语法检查和文本美化器工作,而不是一个真正审查你的简历并找到改进方法的专家。
一个安全问题:
提交个人数据到软件中是否安全?
嗯……是的,也不是。
在这篇文章中,我并没有邀请你提交除工作经验之外的任何东西,这些内容并不是什么秘密,因为任何拥有 LinkedIn 页面的人都能看到。如果你对此仍然感到不安,记住始终可以选择在本地运行我的代码,这样就不会使用任何网络应用程序,你可以将简历的所有信息保留给自己。
我不推荐在 AI 简历改进工具中添加个人信息,如地址、电话号码或电子邮件。
一个伦理问题:
在使用 AI 来改进你的简历时,有一些一般性的伦理考虑需要记住。
-
构建一个没有偏见的数据集是一个非常棘手的概念,因为我们所有人都可能在某种程度上有偏见:我们能做的唯一事情就是尽量构建一个尽可能少偏见的数据集。这也适用于这种情况。人工智能在招聘过程中的盲目和不受控制的使用,无论是招聘还是构建简历,都极具风险,因为机器学习决策算法在所有决策阶段和简历的各个部分都可能会出现偏见错误。(阅读更多 这里)
-
重要的是要对 AI 在简历中的使用保持透明。如果你使用 AI 生成内容或优化格式,我们需要确保在你的简历中披露这一点。这是我们在工作中应该做的事情。有时由于这些技术与我们的生活如此紧密,容易忘记正确披露,但这仍然需要提醒。还有一些工具可以用来判断文本是否由 AI 编写(阅读更多 这里)
-
最后,记住你只是在使用语言模型。模型真正做的只是以一种花哨的方式基于全球数十亿文本预测“下一个词”。你对自己了解胜过计算机,因此深入挖掘,提升你的优点,给自己一些肯定,然后再使用语言模型来改进你的简历😊
既然我们都达成一致了,那就开始代码吧 💻
1. Github!
首先,这里没有什么秘密!一切都是公开的,且在 Github 上! 👇
我现在将描述所有必要的内容,以生成结果。
[## GitHub - PieroPaialungaAI/AI_CV_improver: 使用人工智能改进你的 CV]
你好 😃 感谢你来到这里!这是一个 Python 项目,将帮助你使用人工智能改进 CV…
github.com](https://github.com/PieroPaialungaAI/AI_CV_improver?source=post_page-----e2c1a359e194--------------------------------)
1.1. Constants.py
让我们从简单的事情开始
constants.py 是一个包含……常量的文件。
它获取我们模板简历的关键,OpenAI 模型的温度,以及我们用来改进简历的提示。这就是他们所谓的“提示工程”。
注意!!!你需要使用你的 OpenAI 密钥来更改 OPENAIKEY。它不是公开的,你不应该共享,因此我将其命名为 fake_key。你可以在这里获取密钥
openai.com/api/
你可以通过更改 constants.py 文件来调整提示工程。难道这不是很简单吗?🙃
1.2 utils.py
utils.py 是一个帮助我们从 .txt 文件中提取内容并提取总结部分的文件。只是一个诚实的家伙,做他被支付的工作。
1.3 cv_parser.py
cv_parser.py 做的事情确实类似于 utils.py,我实际上不确定是否应该将代码拆分为两个 .py 文件。它只是一个工具箱;它处理模板和结果文件(分别是过程的开始和结束),并将工作经验解析成文本部分。它有点像工具所做的工作,但更相关于任务……我可能应该将它们放在同一个文件中,但我喜欢有序 😂
1.4 ai_improver.py
ai_improver.py 处理实际的 AI 部分。它通过使用 OpenAI 密钥连接到 OpenAI 来改进简历总结和列出的每一项工作经验。它还使用了我们在 constants.py 文件中构建并放置的所有提示。
1.5 app.py
这是我们用来运行应用程序的内容。我们用以下命令运行它
streamlit run app.py
它做的就是……一切。
这里是所有路径连接的地方。脚本执行以下步骤:
-
它从上传的文件中获取数据,并使用 utils.py 和 cv_parser.py 解析这些数据。
-
它使用ai_improver.py来处理 AI 部分。ai_improver.py 还使用 constants.py 文件。
-
它允许你通过下载按钮下载 AI 改进的简历。
1.6 main.py
main.py 的功能与 app.py 完全一样,但在本地运行(没有流媒体)。
使用以下命令运行它
python main.py
2. 结果!
让我们好好看看我得到的结果。
我的输入总结如下:
“我是一名物理学家,已在数据科学和深度学习方面发展了强大的技能。拥有复杂系统物理学和大数据的硕士学位(成绩:110/110 荣誉)。目前在辛辛那提大学航空航天工程与工程力学系担任研究助理,同时攻读博士学位。我的研究目标是寻找使用机器学习和数据科学优化工程实验的方法。为《Towards Data Science》撰写关于流行机器学习应用和技术解决现实问题的技术文章(每月超过 6 万次浏览)。《青少年人工智能》课程的创作者和教师。图像处理 I 和信号处理 I 课程的开发者。”
人工智能模型将其输出为一个对招聘人员更具吸引力和趣味性的总结:
我拥有广泛的物理学背景,拥有复杂系统物理学和大数据的硕士学位(成绩:110/110 荣誉)。我还在数据科学和深度学习方面发展了强大的技能。我目前在辛辛那提大学航空航天工程与工程力学系担任研究助理,同时攻读博士学位。我的研究重点是寻找使用机器学习和数据科学优化工程实验的方法。此外,我还是《Towards Data Science》的技术写作者,撰写关于流行的机器学习应用和解决现实世界问题的技术文章(每月超过 6 万次浏览)。我还是《青少年人工智能》课程的创作者和教师,并且是《图像处理 I》和《信号处理 I》课程的开发者。我在这些领域的经验和知识使我培养了一套在行业中非常受欢迎的独特技能。
我不打算撒谎……这对我来说确实听起来更好。
这无疑更流畅(这很可预测,因为我不是以英语为母语的人),听起来也更激动人心。不过,它略长一些,这可能是个问题,因为我们希望在简历中节省空间。
让我们来看一下某段工作经历的回顾。例如,其中之一是:
使用机器学习创建代理模型并提高机械工程实验的效率。实践传统和深度机器学习的研究。
而人工智能将其改正为:
利用机器学习创建代理模型并优化机械工程实验。拥有传统和深度机器学习的经验,并且有成功研究的可靠记录。
…有时我感觉这有点过于大胆了。😂
我仍然会说它确实提高了文本的质量。在这种情况下,它也保持了简洁,这也是一个优点!
3. 一些要点!
在这篇文章中,我分享了使用 OpenAI 的 GPT-3 模型来创建简历改进网络应用程序的经验。以下是我的经验的关键要点:
-
Open AI 模型是一个可以执行各种 NLP 任务的工具,包括令牌分类、文本生成、情感分析和文本分类。它非常有用,以至于许多以前的模型现在基本上已经过时。
-
这个应用的概念是下载模板, 填写你的简历,上传回去,然后让 AI 为你分析。
-
创建一个网页应用来提升你的简历非常简单。只需几个小时的代码和一些基础的 Python 编程。
-
这让我想到,AI 驱动应用的可能性是巨大的,随着像 GPT-3 和 Streamlit 这样的工具变得更加普及,即使是不懂编程(或编程能力不强)的人员,也可以在短短几小时内创造出令人印象深刻的结果。
最后,这次经历再次展示了 AI 的巨大力量及其改变我们工作和生活方式的潜力。但也许我们早已知道这一点🙃
4. 结论
如果你喜欢这篇文章,想了解更多机器学习的内容,或者只是想问我一些问题,你可以:
在 Linkedin 上关注我,我会在那里发布我的所有故事。
订阅我的 通讯。它会让你了解最新故事,并给你机会发信息给我,获取所有的修正或疑问。
成为 推荐会员,这样你就不会有“每月最大故事数量”的限制,你可以阅读我(以及数千位机器学习和数据科学顶级作者)关于最新技术的所有内容。
使用 OpenCLIP 进行图像搜索和自动字幕生成
原文:
towardsdatascience.com/using-openclip-for-image-search-and-automatic-captioning-fa1cbbd48ce4
如何利用更多数据和新的机器学习训练技术来改善图像和文本嵌入,以应用于各种场景
·发表于 Towards Data Science ·阅读时间 12 分钟·2023 年 3 月 7 日
--
“用于在大型库中查找图片的高科技计算机系统,”我使用 AI 图像创建程序 Midjourney 创建,并由作者编辑
自 2021 年推出以来,我一直在使用和撰写关于 OpenAI 的 CLIP 系统的文章 [1]。它包含可以用于各种跨模态比较的图像和文本编码模型,例如使用文本查询快速找到库中最匹配的图像。
2022 年 12 月,一个名为 LAION 的独立研究小组发布了一篇名为《对比语言-图像学习的可重复扩展定律》的论文 [2],描述了他们如何首先重新实现并训练了一个类似于 CLIP 的模型,然后通过使用更大的数据集和新的机器学习技术来改进系统。他们称他们的新模型为 OpenCLIP。
在这篇文章中,我将提供有关原始 CLIP 的一些背景信息,描述 LAION 如何改进该模型,并展示我使用 国会图书馆 Flickr 照片流 的图像进行的两个系统实验的一些结果。我还实现了一种称为“嵌入算术”的酷技术,这项技术来自 Meta AI,用于通过图像和文本提示搜索照片 [3]。
我将在文章末尾展示一个使用 LAION 的 OpenCLIP 模型变体自动生成图片字幕的演示。例如,我使用 Midjourney 创建了这篇文章的标题图片,文本提示为“用于在大型库中查找图片的高科技计算机系统”。当我将该图像输入 OpenCLIP 时,它生成了标题为“一个女孩站在图书馆里看电视”的字幕。还不错!
OpenAI 的 CLIP
在 2021 年,OpenAI 发布了一篇名为“从自然语言监督中学习可转移的视觉模型”的论文,描述了他们的新系统,称为对比语言-图像预训练(CLIP)。该系统包括两个 AI 模型,一个文本编码器和一个图像编码器,训练以生成称为嵌入的数字数组。他们发布了源代码和预训练模型。
CLIP 的系统组件, OpenAI 绘制的图示
OpenAI 使用 4 亿对图像/文本对训练编码器,目标是使编码器生成相似的嵌入,当文本和图像相似时。该系统可用于涉及文本和图像的各种任务,如检索,即用文本搜索图像,以及分类,即自动将图像分配到类别中。你可以在我的使用该系统的文章中了解更多关于 CLIP 的信息。
LAION 的 OpenCLIP
LAION 是一个独立研究者小组,提供数据集、工具和 AI 模型。他们重新实现了 OpenAI 的 CLIP 模型,并在他们的 20 亿对图像/文本对的数据集上进一步训练以提高性能。在他们的论文[2]中,他们讨论了如何通过使用 Google 发明的浮点数格式 bfloat16 解决在更多数据训练时遇到的问题[4]。
对于更大规模的模型……我们观察到训练过程中损失值的激增,对性能产生了不利影响。我们通过从 float16 的混合精度切换到 bfloat16 的混合精度解决了这个问题。我们推测 bfloat16 解决了这个问题,因为较大的模型通常显示更大的激活值……使 bfloat16 在其更宽的动态范围下更为合适。—— Mehdi Cherti 等,来自 LAION
他们的论文展示了他们的 OpenCLIP 系统在图像搜索等任务上优于 OpenAI 的 CLIP。下图中,OpenCLIP 用橙色表示,CLIP 用蓝色表示,其中数值越小越好。
使用 CLIP(蓝色)和 OpenCLIP 进行图像检索的结果(数值越小越好), LAION 绘制的图表
图表中有很多信息。我会看看能否将其拆解。横轴表示用于训练的计算量,以 Giga 乘加操作 (GMACs) 为单位。纵轴显示了结果的准确度,定义为 100 - Recall @ 5 针对 Flicker 30K 数据集。例如,如果系统搜索了五张猫的照片,只有四张包含猫,则 Recall@5 值为 80%。但从 100 中减去它将得到 20%,值越低越好。你跟上了吗?图表中的形状代表各种数据集,右侧的键显示了不同的大小。蓝线显示了 CLIP 在 CLIP-WIT 数据集上的表现,经过多种配置训练。橙线显示了最佳 OpenCLIP 模型在 LAION 数据集上的表现,使用了各种配置。底线:OpenCLIP 比 CLIP 更好。
如上图所示,LAION 计算了训练和性能之间关系的方程,带有指数成分。他们讨论了使用这种“缩放定律”来预测需要多少额外的训练以进一步提高他们的模型[2]。
在下一部分,我将展示如何构建和运行一些测试,以展示这两个系统在美国国会图书馆图像上的表现。
比较 OpenCLIP 和 CLIP
我使用了国会图书馆 Flickr 照片流中的照片进行测试。图书馆发布了超过四万张带有说明的图片供人们浏览和评论。请注意,所有图片均标记为“无已知版权限制”,因此可以用于任何目的。
这里是一些数据集样本,图像上方有说明。
来自国会图书馆 Flickr 照片流的样本,来自 LOC 的图像
通过这些样本,你可以对数据集中图像的类型有个了解,包括绘画、旧照片、新照片等。
为了测试这些系统,我将六个说明和图像通过 CLIP 和 OpenCLIP,并计算了余弦相似度,这是一种衡量文本和图像嵌入之间接近度的指标。请注意,结果范围大致从 0.0 到 0.4,较低的数字表示不匹配,较高的数字表示匹配。
图像横向展示在顶部,相应的说明纵向列在左侧。你可以看到,右侧的 OpenCLIP 结果在对角线块(较亮的黄色)上的匹配分数更高,而在非匹配分数(较暗的蓝色)上则更低,相比之下,左侧的 CLIP 结果较低。这意味着,如果你使用这些系统搜索带有文本的图像,使用 OpenCLIP 会得到比 CLIP 更好的结果。
使用 OpenCLIP 探索 LOC 图像
为了探索国会图书馆的 Flickr 照片流,我创建了一个Google Colab,下载了所有 4 万张图像,并通过 OpenCLIP 图像编码器进行文本搜索。
使用 OpenCLIP 搜索 LOC 图像的组件,图表由作者绘制,图像来自 LOC
我开始使用Flickr API将所有 4 万张照片下载到本地文件夹。接下来,我将图片发送到 OpenCLIP 图像编码器以创建图像嵌入。编码器之前由 LAION 使用 20 亿张带有说明的图像进行训练。然后,我输入了文本查询,例如“gone fishing”,并通过文本编码器创建了一个嵌入。我计算了文本嵌入与 40K 图像嵌入之间的余弦相似度,以找到最佳匹配。最后,我对数组进行排序并显示了搜索的前六张图像。
这里是使用 OpenCLIP 和数据集结果的一些示例搜索。
“boat vacation”
我输入了一个搜索短语并点击了“运行”按钮,以查看得分前六的结果。
“boat vacation”的 OpenCLIP 搜索结果,截图由作者提供,图像来自 LOC
果然,它找到了几艘假期中的船只。注意这些匹配的得分相对较低(0.259 到 0.281)。这些得分较低的原因可能是由于使用了有些抽象的词语“vacation”。接下来,我尝试了一些更具体的内容。
“建造飞机发动机”
在这里,我尝试使用了更具体的搜索短语。
“building an airplane engine”的 OpenCLIP 搜索结果,截图由作者提供,图像来自 LOC
好的,这次搜索的得分要高得多(0.302 到 0.326)。最佳结果展示了一张人们建造飞机发动机的漂亮照片。接下来,我尝试了一些有趣的东西。
“mini golf”
数据集中有很多具有美国风情的图像,所以我检查了一下是否有迷你高尔夫球场的图像。
果然,答案是“是的!”注意这些图像的得分较高(0.378 到 0.395)。最佳结果是一个经典的风车洞,风车叶片上写着“MINI GOLF”两次。在描述一种酷炫的新方法来优化图像搜索之后,我会重新审视这个搜索。
嵌入算术
2022 年 10 月,Meta AI 发布了一篇标题引人注目的论文,《多模态查询的嵌入算术用于图像检索》,其中“多模态”指的是其他形式的媒体,如文本[8]。
概念是这样的:如果你在数据集中找到一张图像,并且想要找到另一张保留部分特质但改变其他特质的图像,你可以通过结合图像和文本元素来构建查询。运行新的查询应该能够找到你所寻找的内容,前提是数据集中存在这样的图像。
这里是论文中的一个视觉示例。
使用嵌入算术进行图像检索,图片由 Meta AI 提供
这开始于顶部的猫图像,该图像被编码成嵌入 E(I)。然后,“cat”和“dog”这两个词通过文本编码器分别得到 E(W1) 和 E(W2)。E(W2) 和 E(W1) 之间的差值被加到猫图像的嵌入中,从而找到类似的狗图像所在的位置。从图像数据库中检索得到的结果显示了一个接近的匹配,如底部所示。匹配通过将原始标题中的“dog”替换为“cat”来评估,得到“狗坐在草地上。” 将转换后的标题文本嵌入与狗图像的嵌入进行比较,以查看是否匹配。
论文讨论了如何使用缩放因子 λ 来调整从文本提示中进行的修改量。这里是产生新嵌入 x 的方程。
论文讨论了如何在 1.0 到 1.5 之间的缩放因子对于许多搜索效果良好。
我在我的 Colab 中实现了这种形式的嵌入数学。这里是一些结果,从基于迷你高尔夫风车图像的修改搜索开始。
迷你高尔夫风车图像 + 1.5(“约塞米蒂·萨姆” - “风车”)
对于这次搜索,我从迷你高尔夫风车的肖像照片开始,添加了“约塞米蒂·萨姆”这一短语,并删除了“风车”。我使用了 1.5 的缩放因子。
迷你高尔夫风车图像 + 1.5(“约塞米蒂·萨姆” - “风车”),作者截图,图片来自 LOC
原始图像在左上角,最佳匹配在其旁边,得分很高,为 0.407。顶部的结果与起始图像非常相似,只是显示了约塞米蒂·萨姆而不是风车。接下来是一些路边餐馆的图像。
甜甜圈店图像 + 1.2(“汉堡包” - “甜甜圈”)
对于下一个测试,我从搜索“甜甜圈店”开始,选择了一张名为甜甜圈洞的有趣地方的图像。接下来,我使用了“汉堡包”作为正向提示,“甜甜圈”作为负向提示。我使用了 1.2 的缩放因子。以下是结果。
甜甜圈店图像 + 1.2(“汉堡包” - “甜甜圈”),作者截图,图片来自 LOC
哇,它找到了一家经典的麦当劳餐厅,其中金色拱门与起始图像中的巨型甜甜圈完美对齐。注意到顶部匹配的非常高的分数 0.533。我的最后一次搜索涉及一些名人的旧照片。
亚伯拉罕·林肯图像 + 1.1(“奥斯卡·王尔德” - “亚伯拉罕·林肯”)
对于我的最终测试,我首先搜索了“亚伯拉罕·林肯”,并选择了一张他坐在椅子上的著名图像。我使用系统检查数据集中是否有类似的奥斯卡·王尔德图像。我为这次测试使用了 1.1 的缩放因子。
亚伯拉罕·林肯图像 + 1.1(“奥斯卡·王尔德” - “亚伯拉罕·林肯”),作者截图,LOC 提供的图像
果然,它找到了一个色调为褐色的奥斯卡·王尔德坐在木椅上的图像。尽管姿势不同,但最高匹配的得分是我见过的最高,为 0.675。高分可能是因为著名人物的名字和面孔之间的强相关性超越了其他因素。接下来,我将展示我如何使用 OpenCLIP 生成图像字幕。
使用 CoCa 和 OpenCLIP 创建字幕
在他们 2022 年的论文《CoCa: 对比字幕生成模型是图像-文本基础模型》中,[5] 作者展示了如何训练类似于 OpenAI 的 CLIP 的模型来自动从图像生成字幕。
在这项工作中,我们展示了对比字幕生成模型(CoCa),这是一个新的图像-文本基础模型系列,它包含了现有的视觉预训练范式,并结合了自然语言监督。CoCa 在来自各种数据源的图像-文本对上进行了单阶段预训练,能够高效地结合对比和字幕生成目标于一个编码器-解码器模型中。 — 来自 Google 的 Jiahui Yu 和 Zirui Wang
独立研究员 Phil Wang,外号 lucidrains,将 CoCa 模型适配到了 OpenCLIP 上。结果非常出色。
例如,这里有六张图像,附有来自 LOC 的原始字幕。
以下是由 CoCa/OpenCLIP 生成的带字幕的图像:
尽管字幕缺少一些细节,如具体的个人(谁)和地点(哪里),但系统在描述图像的视觉内容(什么)方面做得非常出色。你可以在我的 Colab 这里 查看这一点。
社会影响
在互联网上训练的大型 AI 模型可能会表现出文化偏见,如果不加以缓解,可能会对社会造成伤害。OpenCLIP 模型的作者在他们的论文 [2] 中表达了他们的担忧。
我们的工作涉及研究大规模预训练模型的功能和属性。将这些模型公开可能有正面和负面的影响,就像任何具有通用功能的研究工具一样。…… 基于大规模预训练通用模型的技术可能被滥用,而民主机构的任务是制定涉及这些模型的敏感应用的规则。模型的开放发布也为广泛的研究社区提供了研究这些模型安全性方面的机会,以预防恶意方滥用技术,从而进行共同的透明努力。 — Mehdi Cherti 等人,来自 LAION
他们的透明政策允许其他研究人员评估和缓解他们模型的使用。
讨论与下一步
OpenCLIP 系统在大规模数据集中搜索图像时表现良好。新的数学嵌入技术提供了专家工具,帮助人们找到完美的镜头。CoCa/OpenCLIP 模型在为图像创建描述性标题方面做得很好。
一个改进的领域是看看这些系统是否可以微调以找到或创建个人照片的标题。与 OpenAI 不同,LAION 发布了他们模型的训练代码。尽管他们的代码是为大规模训练设计的,但如果能够调整为仅使用例如你叔叔 Bob 的十张照片来微调模型,那将会很有帮助。
源代码
本项目的源代码可在 GitHub 上获得。
知识共享 署名-相同方式共享
致谢
我想感谢 Jennifer Lim 在这个项目上的帮助。
参考文献
[1] A. Radford 等人, CLIP, 从自然语言监督中学习可转移的视觉模型 (2021)
[2] M. Cherti 等人, OpenCLIP, 对比语言-图像学习的可重复性缩放法则 (2022)
[3] G. Couairon 等人, 多模态查询的嵌入算术用于图像检索 (2022)
[4] S. Wang 和 P. Kanwar, BFloat16: 云 TPUs 高性能的秘密 (2019)
[5] J. Yu, CoCa: 对比式标题生成器是图像-文本基础模型 (2022)
使用 Plotly 3D 表面图可视化地质表面
原文:
towardsdatascience.com/using-plotly-3d-surface-plots-to-visualise-geological-surfaces-8829c06a5c9a
使用 Python 数据可视化库可视化地下结构
·发布于 Towards Data Science ·9 分钟阅读·2023 年 6 月 21 日
--
Hugin 组的 3D 表面图。图像由作者提供。
在地球科学中,全面了解地下的地质表面至关重要。了解构造的确切位置及其几何形状,可以更容易地识别潜在的石油和天然气勘探新目标以及潜在的碳捕集和储存地点。我们可以使用多种方法来完善这些表面,从地震数据到测井推导的构造顶部。通常,这些技术会相互结合以完善最终表面。
本文继续我之前的文章,后者关注于通过区域内的测井数据来推断和可视化地理空间变异。在这篇文章中,我们将探讨如何使用交互式 Plotly 图表创建 3D 表面。
由于建模地质表面是一个复杂的过程,通常涉及多个迭代和完善,本文演示了如何使用 Python 简单可视化这些数据的示例。
为了了解如何利用 Plotly 可视化我们在一个区域内的地质构造顶部,我们将使用两组数据。
第一组数据来自 28 个测井孔交点,这些数据用于克里金插值以生成低分辨率表面。相比之下,第二组数据来自解释的地震数据,提供了更高分辨率的表面。
两组数据均来自 Equinor Volve 数据集,详细信息见本文底部。
你还可以通过以下链接查看这一小系列中的其他文章:
-
利用 pykrige 和 matplotlib 进行地质变化的空间可视化
-
使用 Plotly Express 可视化 3D 线图中的井路径
导入库与数据
在对数据进行任何操作之前,我们首先需要导入所需的库。这些是:
-
pandas — 用于读取我们的数据,数据格式为
csv
-
matplotlib 用于创建我们的可视化
-
pykrige 用于进行克里金计算
-
numpy 用于一些数值计算
-
plotly.graph_objects 用于 3D 可视化表面
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from pykrige import OrdinaryKriging
import numpy as np
接下来,我们可以使用 pd.read_csv()
加载数据。
由于这些数据包含 Volve 领域所有井眼的地质表面信息,我们可以使用 query()
提取我们需要的地层数据。在这种情况下,我们将关注 Hugin 地层。
df = pd.read_csv('Data/Volve/Well_picks_Volve_v1 copy.csv')
df_hugin = df.query('SURFACE == "Hugin Fm. VOLVE Top"')
df_hugin
当我们运行上述代码时,我们会得到以下表格。你可能会注意到一些井眼多次遇到 Hugin 地层。这可能是由于井眼或地层几何形状导致井眼多次穿透地层。
Pandas 数据框包含关于 Volve 领域 Hugin 地层的井位信息。图片来自作者。
外推 TVDSS 生成地质表面
在我之前的文章中,我专注于如何使用一种称为克里金的过程来“填补测量点之间的空白”。我们不会在这篇文章中详细介绍这个过程;然而,你可以查看 这篇文章以获取更多信息。
一旦数据加载完毕,我们可以通过调用 pykrige 的 OrdinaryKriging
方法来运行克里金过程。
在此调用中,我们传入我们的 x
和 y
数据,表示井眼在地下遇到地层的位置的东向和北向坐标。
由于我们想要生成 Hugin 地层的表面,我们需要使用 TVDSS(真实垂直深度水下)测量。这真实反映了表面在所选基准面下的深度。
OK = OrdinaryKriging(x=df_hugin['Easting'],
y=df_hugin['Northing'],
z=df_hugin['TVDSS'],
variogram_model='linear',
verbose=True, enable_plotting=True)
Hugin 地层的普通克里金结果。图片来自作者。
一旦模型生成,我们可以将其应用于覆盖井眼/穿透点整个范围的两个数组。
这使我们能够生成一个值的网格,然后将其传递到我们上面生成的 OrdinaryKriging
对象中。
gridx = np.arange(433986, 438607, 50, dtype='float64')
gridy = np.arange(6477539, 6479393, 50,dtype='float64')
zstar, ss = OK.execute('grid', gridx, gridy)
最后,我们可以使用 matplotlib 的 imshow
方法生成一个简单的 2D 地图视图。
fig, ax = plt.subplots(figsize=(15,5))
# Create a 2D image plot of the data in 'zstar'
# The 'extent' parameter sets the bounds of the image in data coordinates
# 'origin' parameter sets the part of the image that corresponds to the origin of the axes
image = ax.imshow(zstar, extent=(433986, 438607, 6477539, 6479393),
origin='lower')
# Set the labels for the x-axis and y-axis
ax.set_xlabel('X Location (m)', fontsize=14, fontweight='bold')
ax.set_ylabel('Y Location (m)', fontsize=14, fontweight='bold')
# Add contours
contours = ax.contour(gridx, gridy, zstar, colors='black')
colorbar = fig.colorbar(image)
colorbar.set_label('DTC (us/ft)', fontsize=14, fontweight='bold')
# Display the plot
plt.show()
应用克里金插值后的 Hugin 组的 2D 表面和等高线图。图片由作者提供。
使用 Plotly 创建简单的 3D 表面图
要将我们的 2D 表面转换为 3D,我们需要使用 Plotly。我们可以使用 matplotlib 来完成这一点;然而,根据我的经验,使用 Plotly 生成 3D 可视化更容易、更顺畅且更具交互性。
在下面的代码中,我们首先创建坐标网格。为此,我们使用 numpy 的 linspace
函数。这个函数将创建一个在指定范围内均匀分布的数字集合。对于我们的数据集和示例,范围从 xgrid_extent
和 ygrid_extent
的最小值到最大值。
在此范围内使用的总值数量将等于我们上面创建的 zstar
网格中 x 和 y 元素的数量。
一旦我们的网格形成,我们就会调用 Plotly。
首先,我们创建图形对象,然后使用 fig.add_trace
将 3D 表面图添加到图形中。
添加后,我们需要调整图形的布局,以便添加轴标签,定义宽度和高度,并设置一些内边距。
xgrid_extent = [433986, 438607]
ygrid_extent = [6477539, 6479393]
x = np.linspace(xgrid_extent[0], xgrid_extent[1], zstar.shape[1])
y = np.linspace(ygrid_extent[0], ygrid_extent[1], zstar.shape[0])
fig = go.Figure()
fig.add_trace(go.Surface(z=zstar, x=x, y=y))
fig.update_layout(scene = dict(
xaxis_title='X Location',
yaxis_title='Y Location',
zaxis_title='Depth'),
width=1000,
height=800,
margin=dict(r=20, l=10, b=10, t=10))
fig.show()
当我们运行上述代码时,我们会得到一个交互式图,显示基于多次遇到的钻井井眼的 Hugin 组的地质表面。
使用 Plotly 生成的 Hugin 组的 3D 表面图。图片由作者提供。
使用 Plotly 进行完整解释的表面视图
Volve 数据集中有许多完全解释的表面,这些表面是根据地质解释生成的,包括地震数据。
这些数据包含场地中数据点的 x
和 y
位置,以及我们的 TVDSS 数据(z
)。
提供在 Volve 数据门户上的数据是 .dat 文件的形式,不过,通过在文本编辑器中稍作处理,可以轻松地将其转换为 CSV 文件并使用 pandas 加载。
hugin_formation_surface = pd.read_csv('Data/Volve/Hugin_Fm_Top+ST10010ZC11_Near_190314_adj2_2760_EasyDC+STAT+DEPTH.csv')
Hugin 组的 X、Y 和 Z 位置。图片由作者提供。
数据加载完成后,我们可以通过将 x、y 和 z 数据提取到变量中来简化操作。
x = hugin_formation_surface['x']
y = hugin_formation_surface['y']
z = hugin_formation_surface['z']
然后,我们需要在 x 和 y 数据位置之间创建一个规则间隔的网格。这可以使用 numpy 的 meshgrid 完成。
xi = np.linspace(x.min(), x.max(), 100)
yi = np.linspace(y.min(), y.max(), 100)
xi, yi = np.meshgrid(xi, yi)
有几种方法可以在一系列数据点之间进行插值。选择的方法将取决于数据的形式(规则采样数据点与不规则采样数据点)、数据大小和计算能力。
如果我们有像这里这样的大型数据集,使用一些方法如径向基函数的计算成本会更高。
在这个示例中,我们将使用 scipy 中的 LinearNDInterpolator 方法来构建我们的模型,然后将其应用于我们的 z(TVDSS)变量。
为了在点之间进行插值,我们需要将xi
、yi
重塑为 1D 数组,因为LinearNDInterpolator
期望 1D 数组。
xir = xi.ravel()
yir = yi.ravel()
interp = LinearNDInterpolator((x, y), z)
zi = interp(xir, yir)
一旦计算完成,我们可以继续使用Plotly图形对象创建我们的 3D 表面。
fig = go.Figure()
fig.add_trace(go.Surface(z=zi, x=xi, y=yi, colorscale='Viridis'))
fig.update_layout(scene = dict(
xaxis_title='Easting (m)',
yaxis_title='Northing (m)',
zaxis_title='Depth',
zaxis=dict(autorange='reversed')),
width=1000,
height=800,
margin=dict(r=20, l=10, b=10, t=10))
fig.update_traces(contours_z=dict(show=True,
usecolormap=True,
project_z=True,
highlightcolor="white"))
fig.show()
当我们运行上述代码时,会得到 Hugin 地层的以下 3D 表面图。
Hugin 地层的完全解释地质表面。图像由作者提供。
当我们将这个图与从井眼测量生成的图进行比较时,我们可以明显看到整体形状上的相似之处,尤其是中间的山谷。然而,地震导出的表面提供了比井导出的地层顶部更多的细节。
Hugin 地层表面的比较:左侧是从稀疏的井测量生成的表面,右侧是从地震数据生成的表面。图像由作者提供。
摘要
在这个简短的教程中,我们展示了如何使用Plotly的 3D 表面图生成地质表面的互动 3D 可视化。通过井测量得出的地层顶部,我们可以生成非常基础的 3D 表面。这是因为测量仅限于交叉了 Hugin 地层的井眼,这意味着我们得到的是低分辨率的表面。
相比之下,如果我们有更详细的测量点,比如来自地震导出的地层面,我们可以生成对地层的更真实的印象。
两种方法都是有效的,但请注意,当从仅通过井测量得到的地层顶部进行外推时,我们可能无法生成该区域地层的完整图像。
使用的数据集
本教程中使用的数据是 Equinor 在 2018 年发布的 Volve 数据集的一个子集。数据集的完整详细信息,包括许可证,可以在以下链接找到。
Equinor 已发布 Volve 油田 2008-2016 年的完整数据集。点击此处下载以供学习、研究等使用…
Volve 数据许可证基于 CC BY 4.0 许可证。许可证协议的完整详细信息可以在此处找到:
感谢阅读。在你离开之前,你一定要订阅我的内容,并将我的文章发送到你的收件箱。 你可以在这里做到这一点!
其次,你可以通过注册会员获得完整的 Medium 体验,并支持成千上万的其他作者和我。这只需每月$5,你就可以完全访问所有精彩的 Medium 文章,并有机会通过你的写作赚取收入。
如果你通过 我的链接注册, 你将直接用你的费用的一部分支持我,而不会额外增加你的费用。如果你这样做,非常感谢你的支持。
使用 Plotly Express 旭日图探索地质数据
原文:
towardsdatascience.com/using-plotly-express-sunburst-charts-to-explore-geological-data-841f179d08c8
用 Python 轻松快速地理解你的地质层级数据
·发表于数据科学前沿 ·阅读时间 5 分钟·2023 年 7 月 27 日
--
在互动式 Plotly Express 旭日图上展示的地质层级。图像由作者提供。
数据可视化在地球科学和数据科学领域发挥着至关重要的作用。它可以让我们深入了解地下情况,理解地质结构和层级关系。地下通常被细分为不同的类别,从最广泛的地质时间范围,如时代、纪、世,到岩性差异,如砂岩、石灰岩和页岩。
在处理地质层级数据时,数据可以以多种方式进行可视化。这包括传统的地质时间尺度图表和表格以及互动式旭日图。
旭日图(Sunburst charts)可以用独特的方式展示数据,是可视化地质层级数据等层级数据的好方法。它们通过使用多层同心圆环图来实现,这些图表可以根据使用的工具进行完全交互,帮助从最高层级钻取到最低层级。
为了演示这些图表,我们将使用Plotly Express,一个高级数据可视化 Python 库,从挪威大陆架上的一个井中获取一些数据,并可视化地质层级以及每个地层的岩性组成。我们还将看到如何在创建图表之前准备井中的数据。
导入库并加载数据
首先,我们需要两个库:pandas 用于加载和处理数据,plotly_express 用于创建可视化。
import pandas as pd
import plotly_express as px
接下来,我们将从 CSV 文件加载数据。有关使用的数据的详细信息可以在文章底部找到。
如果你有 LAS 文件,可以快速使用加载 LAS 文件功能,通过LASIO库来加载文件,然后将数据转换为pandas数据框。
df = pd.read_csv('Data/Xeek_Well_15-9-15.csv')
清理和准备数据
现在我们已经加载了数据,我们需要做一点清理。
首先,我们将删除 Formation 列中任何缺失信息的行,以简化问题。请注意,这样做可能会影响后续岩石类型和地层的计数,因此任何缺失数据应彻底检查和理解。
删除数据的替代方法是用占位符值(如“未知”)替换缺失数据。
data_cleaned = df.dropna(subset=['FORMATION'])
当我们查看数据框的前五行时,我们得到如下内容。
删除包含缺失值的行后,数据的前 5 行。图片由作者提供。
下一步是将数据结构化,以便在 sunburst 图上正确绘制。
为此,我们需要进行以下操作
-
按以下顺序对数据进行分组:Group、Formation 和 Lith。
-
计算每个 Group、Formation 和 Lith 的唯一组合的出现次数
grouped_data = data_cleaned.groupby(['GROUP', 'FORMATION', 'LITH']).size().reset_index(name='count')
当我们查看新的grouped_data
数据框的头部时,会得到每个地质层的岩石类型概述,这些概述随后与 Group 相关联。
分组和计数每个地质层岩石类型后的前五行数据框。图片由作者提供。
使用 Plotly Express 创建 Sunburst 图
现在数据已经是正确格式,我们可以最终使用 Plotly Express 创建 Sunburst 图。
我们要做的就是调用px.sunburst
并传入一些参数。
首先,我们传入分组数据框,然后为path
参数指定我们希望显示的环的顺序。在我们的情况下,我们将按 Group 到 Formation 再到 Lithology 的顺序。
接下来,我们将values
参数设置为计数列。这将允许我们控制每个岩石类型的切片大小。岩石类型的数量越多,切片就会越宽。
然后,我们设置图形的width
和height
,以及title
。
最后,我们希望岩石类型的颜色是不同的;然而,我们也可以将其设置为数据框中的任何其他列或自定义颜色刻度。
fig = px.sunburst(grouped_data,
path=["GROUP", "FORMATION", "LITH"],
values='count',
title="Geological Distribution within well 15/9-15",
width=800,
height=800,
color='LITH')
fig.show()
当我们运行上述代码时,我们会得到以下图形,这为每个地质组和地层以及不同岩石类型的分布和贡献提供了出色的概述。
太阳辐射图显示了与地层相关的不同地质群、地层和岩性。图像来源:作者。
使用 Plotly 的一个优点是它默认是互动的。我们不需要像使用 matplotlib 时那样添加额外的 Python 代码来使图表具有互动性。
在下面的示例中,我们可以深入探讨我们的数据,了解特定地层的岩性组成。
动画太阳辐射图显示了与地层相关的不同地质群、地层和岩性。图像来源:作者。
摘要
使用 Plotly 创建的太阳辐射图是展示地质信息的绝佳方式。它们提供了出色的可视化效果,并允许你通过互动深入数据,以了解地质层次结构中每一层的组成。
通过一些预处理,容易将现有的地质信息从井中提取并转换为适合绘图的格式。此外,在处理互动图形时,你对数据的感知要比研究静态图形要好得多。而且,它们也更有趣。
本教程中使用的数据集
训练数据集用于 Xeek 和 FORCE 2020 举办的机器学习竞赛 (Bormann et al., 2020)。此数据集受 Creative Commons Attribution 4.0 International 许可证保护。
完整的数据集可以通过以下链接访问:doi.org/10.5281/zenodo.4351155
。
感谢阅读。在你离开之前,你一定要订阅我的内容,以便将我的文章发送到你的邮箱。 你可以在这里订阅!
其次,你可以通过注册会员获得完整的 Medium 体验,并支持成千上万的其他作家和我。只需每月$5,你就可以全面访问所有精彩的 Medium 文章,并有机会通过写作赚钱。
如果你通过 我的链接注册, 你将直接用你的一部分费用支持我,这不会增加你的费用。如果你这样做,非常感谢你的支持。
使用 Polars 插件通过 Rust 实现 14 倍速度提升
在本地 Polars 库之外实现高速度
·
关注 发表在 Towards Data Science ·8 min read·2023 年 11 月 9 日
--
由 DALL-E 3 生成
介绍
Polars 正因其速度、内存效率和美观的 API 而风靡全球。如果你想了解它有多强大,不妨看看 DuckDB Benchmarks。而且这些测试甚至还没有使用 Polars 的最新版本。
尽管 Polars 可以做很多惊人的事情,但它传统上并不比 Pandas 更适合做你可能想做的所有计算。有几个例外情况 Polars 并未超越。不过,随着最近 Rust 的 Polars 插件系统的发布,这种情况可能会发生变化。
Polars 插件
什么是 polars 插件?它只是创建自己 Polars 表达式的一种方法,使用原生 Rust 并通过自定义命名空间将其暴露给表达式。它允许你利用 Rust 的速度,将其应用到你的 Polars DataFrame,以便以利用 Polars 提供的速度和内置工具的方式执行计算。
让我们看一些具体的例子。
顺序计算
Polars 在某些功能上似乎有所欠缺,例如那些需要知道 DataFrame 之前值的操作。性质上是顺序的计算在原生 Polars 表达式中并不总是特别容易或高效。让我们看看一个具体的例子。
我们有以下算法来计算给定运行中的数字数组的累计值,该运行定义为具有相同符号的一组数字。例如:
┌───────┬───────────┐
│ value ┆ run_value │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═══════╪═══════════╡
│ 1 ┆ 1 │ # First run starts here
│ 2 ┆ 3 │
│ 3 ┆ 6 │
│ -1 ┆ -1 │ # Run resets here
│ -2 ┆ -3 │
│ 1 ┆ 1 │ # Run resets here
└───────┴───────────┘
我们希望计算一个列的累计和,每当值的符号从正数变为负数或从负数变为正数时,累计和会重置。
让我们从一个用 pandas 编写的基准版本开始。
def calculate_runs_pd(s: pd.Series) -> pd.Series:
out = []
is_positive = True
current_value = 0.0
for value in s:
if value > 0:
if is_positive:
current_value += value
else:
current_value = value
is_positive = True
else:
if is_positive:
current_value = value
is_positive = False
else:
current_value += value
out.append(current_value)
return pd.Series(out)
我们遍历一个系列,计算每个位置的当前运行值,并返回一个新的 Pandas Series。
基准测试
在继续之前,我们将设置几个基准。我们将使用 pytest-benchmark 和 pytest-memray 测量执行速度和内存消耗。我们将设置问题,使得我们有一个实体列、一个时间列和一个特征列。目标是计算数据中每个实体在时间上的运行值。我们将把实体和时间戳的数量都设为 1,000,使我们拥有一个包含 1,000,000 行的 DataFrame。
当我们使用 Pandas 的 groupby apply 功能对我们的基准进行测试时,我们得到以下结果:
Pandas Apply Pytest-Benchmark(作者提供的图片)
Pandas 应用的 Memray 输出(作者提供的图片)
Polars 原始实现
好的,现在我们有了基准测试。接下来,让我们看看如何在 Polars 中实现相同的功能。我们将从一个非常相似的版本开始,该版本将通过映射函数应用到 Polars GroupBy 对象上。
def calculate_runs_pl_apply(s: pl.Series) -> pl.DataFrame:
out = []
is_positive = True
current_value = 0.0
for value in s:
if value is None:
pass
elif value > 0:
if is_positive:
current_value += value
else:
current_value = value
is_positive = True
else:
if is_positive:
current_value = value
is_positive = False
else:
current_value += value
out.append(current_value)
return pl.DataFrame(pl.Series("run", out))
现在让我们看看这与我们原来的 Pandas 基准测试的比较结果。
Pandas Apply 与 Polars Apply Pytest-Benchmark(作者提供的图片)
Polars 应用的 Memray 输出(作者提供的图片)
好吧,这样做效果并不好。不过这不应该让人感到惊讶。Polars 的开发者已经明确表示,Pandas 中非常常见的 groupby apply 方法在 Polars 中并不是一种高效的计算方式。这一点在这里得到了体现。速度和内存消耗都比我们原来的 Pandas 实现要差。
Polars 表达式实现
现在让我们把这个函数写成原生的 Polars 表达式。这是与 Polars 一起工作时的首选和优化方式。算法会看起来有些不同。但这是我想出的计算相同输出的方案。
def calculate_runs_pl_native(df: pl.LazyFrame, col: str, by: str) -> pl.LazyFrame:
return (
df.with_columns((pl.col(col) > 0).alias("__is_positive"))
.with_columns(
(pl.col("__is_positive") != pl.col("__is_positive").shift(1))
.over(by)
.fill_null(False)
.alias("__change_sides")
)
.with_columns(pl.col("__change_sides").cumsum().over(by).alias("__run_groups"))
.with_columns(pl.col(col).cumsum().over(by, "__run_groups").alias("runs"))
.select(~cs.starts_with("__"))
)
对我们正在做的事情的简要说明:
-
查找所有特征值为正的行
-
查找所有
__is_positive
列与前一行不同的行。 -
对
__change_sides
进行累加,以标记每个不同的运行 -
对每个不同的运行中的值进行累加
现在我们有了原生的 Polars 函数。让我们再次进行基准测试。
Pandas Apply 与 Polars Apply 与 Polars Native Pytest-Benchmark(作者提供的图像)
Polars Native 的 Memray 输出(作者提供的图像)
不幸的是,我们没有看到函数执行速度的提升。这很可能是因为我们需要做的 over
语句数量来计算运行值。然而,我们确实看到了预期中的内存减少。可能还有更好的方法使用 Polars 表达式来实现这一点,但我现在不打算担心这个问题。
Polars 插件
现在让我们来看看新的 Polars 插件。如果你想要设置这些插件的教程,可以查看 这里的文档。 在这里,我主要展示插件的具体实现。首先,我们将用 Rust 编写我们的算法。
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;
#[polars_expr(output_type=Float64)]
fn calculate_runs(inputs: &[Series]) -> PolarsResult<Series> {
let values = inputs[0].f64()?;
let mut run_values: Vec<f64> = Vec::with_capacity(values.len());
let mut current_run_value = 0.0;
let mut run_is_positive = true;
for value in values {
match value {
None => {
run_values.push(current_run_value);
}
Some(value) => {
if value > 0.0 {
if run_is_positive {
current_run_value += value;
} else {
current_run_value = value;
run_is_positive = true;
}
} else if run_is_positive {
current_run_value = value;
run_is_positive = false;
} else {
current_run_value += value;
}
run_values.push(current_run_value);
}
}
}
Ok(Series::from_vec("runs", run_values))
}
你会注意到这看起来与我们在 Python 中编写的算法非常相似。我们在这里没有做任何花哨的 Rust 魔法!我们使用 Polars 提供的宏来标注输出类型,仅此而已。然后我们可以将新的函数注册为表达式。
from polars import selectors as cs
from polars.utils.udfs import _get_shared_lib_location
lib = _get_shared_lib_location(__file__)
@pl.api.register_expr_namespace("runs")
class RunNamespace:
def __init__(self, expr: pl.Expr):
self._expr = expr
def calculate_runs(
self,
) -> pl.Expr:
return self._expr.register_plugin(
lib=lib,
symbol="calculate_runs",
is_elementwise=False,
cast_to_supertypes=True,
)
然后我们可以这样运行它:
from polars_extentsion import RunNamespace
df.select(
pl.col(feat_col).runs.calculate_runs().over(entity_col).alias("run_value")
).collect()
好了,现在让我们查看结果吧!
所有实现 Pytest-Benchmark(作者提供的图像)
Polars 插件的内存输出(作者提供的图像)
这才像样!我们获得了 14 倍的速度提升,内存分配从 ~57MiB 降低到 ~8MiB。
何时使用 Polars 插件
现在我已经展示了使用插件的强大功能,我们来谈谈不应使用插件的情况。我可能不使用插件的一些原因(每个都有其自身的注意事项):
-
如果您可以轻松使用原生 Polars 表达式编写一个非常快速的计算版本。 Polars 的开发者非常聪明。我不会赌自己写的函数比他们编写的快很多。Polars 的工具已经存在。利用它们擅长的部分吧!
-
如果您的计算没有自然的并行化。例如,如果我们没有在多个实体上运行上述问题,我们的加速可能会显著减少。我们既受益于 Rust 的速度,也受益于 Polars 自然能够同时对多个组应用 Rust 函数的能力。
-
如果您不需要顶级的速度或内存性能。 许多人会同意,编写 Rust 要比编写 Python 更困难且耗时。所以,如果您不介意您的函数运行时间从 200 毫秒变成 2 秒,您可能不需要使用插件。
牢记上述内容,这里有一些我感觉有时会让我倾向于使用插件的要求:
-
速度和内存非常重要。 我最近在一个 Polars 插件中重写了很多数据管道的功能,因为我们在 Polars 和其他工具之间来回切换,内存分配变得太大了。在我们想要的数据量下,难以在我们希望的基础设施上运行管道。插件使我们能够在更短的时间和更小的机器上运行相同的管道。
-
您有一个独特的使用案例。 Polars 提供了很多内置函数。但它是一个广泛适用的通用工具集。有时,这个工具集并不特别适用于您试图解决的问题。在这种情况下,插件可能正是您需要的。我遇到的两个最常见的例子是更复杂的数学计算,例如应用横截面线性回归,或顺序(基于行的)计算,就像我们在这里展示的那样。
新的插件系统是对 Polars 已经支持的所有列式计算的完美补充。通过这个补充,Polars 允许其功能有美妙的扩展性。除了编写自己的插件之外,还要留意一些酷炫的 Polars 插件包,它们可以扩展您的功能而无需自己编写插件!
Polars 发展迅速并引起了广泛关注。查看这个项目,开始使用它,留意他们将发布的其他精彩功能,也许在此期间还可以学习一些 Rust!
数据科学中的概率词使用
将模糊的反馈转化为机器学习的具体概率
·
关注 发表在 Towards Data Science ·9 分钟阅读·2023 年 2 月 7 日
--
图片由 Christina @ wocintechchat.com 提供,来自 Unsplash
在开始新的数据科学模型时,你必须评估提供给你的数据。通常,新的数据科学项目始于一个数据集,并与主题专家或其他联系人进行联系。主题专家为数据集的意义提供额外的背景。这包括数据集中的异常值或例外情况,以及主题专家认为的“正常”或“异常”情况,或“总是”或“从不”发生的情况。但如果“从不”并不意味着“从不”,“总是”也不意味着“总是”呢?这些词被称为“概率词”,包含其他地方无法找到的关键信息。本文讨论了如何使用这些概率词来了解你的数据并改进你的模型。
“从不”并不意味着“从不”
定义概率词
概率词是表达不确定性或概率的词语。它们包括“也许”,“可能”,“大概”,“很可能”,“不太可能”,“可能”,“不可能”等词语。这些词用来表示某人对事件的信念或信心的隐性分布。
每个人,无论是有意识还是无意识地,都在脑海中为这些概率词分配一个概率。然而,这些词的确切概率常常被个人对词语的解释所掩盖,从而与数据的关联变差。例如,“通常”是指 40%还是 80%的时间?这一定义在不同人和情况之间可能有所不同。一旦提取出这些概率的解释,就可以将其融入建模开发过程中。
人们心中的概率范围可以发生剧烈变化。
概率词与以往的工作
两项涉及概率词的显著研究包括 1993 年美国中央情报局(CIA)谢尔曼·肯特的《估计概率的词汇》(Words of Estimative Probability (cia.gov))。第二篇文章是由安德烈·莫博辛和迈克尔·J·莫博辛于 2018 年在《哈佛商业评论》中撰写的《如果你说某事“很可能”,人们认为它有多可能?》
肯特撰写的研究旨在解决在情报行业中,人们常用不具体的陈述来描述事件发生的可能性。在文章中,他利用了一组报告样本来建立单词与概率之间的映射,将数字赋予人们回应的不确定性。原始表格的输出范围从 0%的不可能到 100%的确定性,中间有一个“可能性的一般区域”。“可能性的一般区域”包含 7 种概率短语。这些词按确定性递增的顺序为:“不可能”,“几乎不可能”,“可能性不大”,“变化不定”,“可能”,“几乎确定”和“确定”。
后来,安德烈和迈克尔·莫博辛进行了跟进研究,进行了更新的调查,包含了更多不同的词汇。他们的目标是增加研究参与者的数量,并扩展到情报界之外。他们在网络上调查用户,将词汇与其解释的概率相关联。作者还试图识别其他背景方面的差异,例如性别和将英语作为第二语言的人。研究的一个教训是,人们应该使用概率来解释数据,而不是使用概率词,以避免在分享数据见解时的误解。此外,人们应使用明确的方法来收集概率。
本节仅为文章的简要概述,我强烈推荐完整阅读这些文章。但问题仍然存在——这对数据科学意味着什么?
“通常”到“罕见”的范围因人而异。
在数据科学中使用概率词汇
将模糊的词汇含义与具体示例联系起来,是扩展数据集知识和增加额外知识的极好方法。这些额外的信息可以增加模型可用的信息,提高模型性能。
了解数据背景
除了传统的数据探索技术,还有许多不同的方法可以深入了解数据集。这可以通过各种方法完成,通常涉及与受访者的讨论。这个人可以是主题专家、内容审核团队、会计师、用户,或与数据集、行业或问题相关的人。
与受访者交谈时,准备识别他们的概率词汇。首先询问数据的一般统计行为,如相关性,并将每个相关性视为假设。这个假设由受访者根据他们的经验来证明或反驳。目标是倾听人们在表达行动意见时所使用的限定词。你可以利用这些概率词汇来识别他们经验中的正常、异常和异常值。这也可以用于双重检查你的数据集是否与任务一致。例如,数据集中或相关人员中是否存在你不知道的偏差?受访者认为异常的事情在数据中是否经常发生?
在识别正常数据、异常数据和离群值时,脑海中或画出分布图会有所帮助。通过寻找概率词汇,我们试图识别受访者的事件样本在分布中的位置。然而,必须确保我们从受访者的角度收集数据。在数据中看似正常的事件在业务流程中可能非常意外或不寻常——这种知识在建模时是金子般的宝贵。
关于利率变化的隐式分布示例。
例如,设想你是一个数据科学家,任务是开发一个模型来预测美联储是否会提高利率。在学习利率、美联储的动作以及市场反应时,获取专家的观点至关重要,这能让我们了解他们认为哪些因素会影响这些决定。比如,我们询问一位基金经理他们认为利率会如何变化,他们说:“可能会以较慢的速度上涨。”在这种情况下,重要的是要求交易员用概率术语(比例、百分比等)来量化“可能”对他们来说意味着什么。通过要求交易员用百分比术语表达他们的理解,我们可以开始建立对以下内容的理解:
-
“可能”对他们来说意味着什么,以及实现这一点需要什么经济背景
-
什么会导致几乎肯定的正面或负面利率变动
-
他们认为美联储在这种情况下通常会做什么
-
一个不寻常的 负面利率变动会是什么样子
-
一个不寻常的 正面利率变动会是什么样子
通过要求交易员详细描述每个回应的背景来跟进。澄清交易员提到的每个词的概率可能性,并寻找口头解释的统计见解。这些见解包括多重共线性、次要效应以及其他影响模型性能但不在训练数据集中的来源。为了减少结果的偏差,尽量采访几个人。
有了这些信息,你对可能导致利率变化的原因会有更多了解,并能构建出交易员的信念。数据甚至可以用来建立一个网络抓取的情感模型,将外部情感转化为业务情感。
增强你的数据集
使用概率词汇,可以增强数据集以包含见解。例如,你可以添加一个分类列,标记“异常”情况。你可以使用这些数据预测更大数据集上的反馈,将其作为机器学习模型的原始输入,并量化这些“人类”信息对数据的价值。
一个包含反馈信息的示例数据集。
使用数据来预测更大数据集上的反馈被称为“弱学习”。在这种情况下,构建了一个模型,该模型利用反馈样本来预测其余数据集上的反馈内容。这意味着反馈样本可以扩展到覆盖整个数据集。然后,这些扩展的反馈可以作为另一个模型的输入或用于探索性数据分析。这种方法的好处是可以将数据样本扩展到覆盖大量数据。然而,这也带来了准确性的成本。由于模型是在小样本上训练的,模型可能会有更高的偏差或表现不完全符合受访者的实际情况。
如果你需要即时反馈以作为模型的一部分进行预测,那么可以使用“弱学习”来构建一个“即时反馈”系统。随着模型进行在线预测,“弱学习”模型接收原始数据,预测反馈内容,然后将原始数据和预测的反馈传递给主在线模型。这允许你构建一个功能齐全的模型,而无需人类时刻参与。
收集这些数据的另一个好处是,它可以用来量化采访提供的信息比原始数据更多。可以通过构建有反馈和没有反馈的模型来评估这一点。在训练完两个模型后,比较模型分数之间的差异,这将给出反馈的相对价值。如果你的模型在使用反馈数据后比仅使用原始数据训练的模型提高了 15%,那么这证明采访提高了模型的性能。如果这 15%的提升可以与业务影响挂钩,这可以帮助证明采访成本并给出反馈的美元价值。例如,如果一个预测模型的性能提高了 15%,而这转化为 20 万美元的价值,那么反馈的价值就是 20 万美元。
构建你自己的概率调查
现在概率词汇已经介绍完毕,这些知识可以用于创建你的概率调查。
要开始,可以从之前提到的原始研究中的常见概率词汇中获得灵感。可以添加你自己和你组织中常用的概率词汇。花一周时间识别你会议中使用的常见词汇并保持一个常见项目的清单也可能会有帮助。我在职业背景中使用的一个例子是“不确定性”。如果我正在为我的组织编写一个概率调查,我会想在词汇列表中包括像“极大不确定性”和“几乎确定”这样的条目以获取反馈。记住,你总是可以稍后添加词汇并收集更多反馈,所以不需要追求完美。
一旦收集了词汇列表,需要创建一个架构来从人们或其他数据源收集词汇。如果你已经有数据源,可以使用你喜欢的方法将数据输入到你的流程中。如果你正在收集来自组织内部人员的反馈,建立一个简单的调查架构以便收集信息是有益的。这可以包括像 Google Forms、Microsoft Forms 和 Streamlit 这样的工具。我的常用选择是 Streamlit,因为它设置迅速,使用 Python 构建,并且可以在本地 PC 上快速运行,或在其网站上运行。
你也可以收集每个人提供反馈时的基本元数据。像公司级别、部门和工作年限这样的元数据可能对细分不同部门如何使用概率词汇很有帮助。一旦收集了大量反馈,你可以收集数据并分析每个回应的分布。从这些分布中,你将能够回答如下问题:
-
哪些词具有相同的概率含义?
-
定义相似的词是否具有不同的概率?
-
例如,“确定性”和“绝对确定”可能分别与 70% 和 95% 的平均概率相关。
-
一个词的概率分布是什么?
-
这些词的分布特征是什么?(均值、中位数、众数、标准差等)
在这项分析之后,你可以利用这些信息帮助将面试、反馈会话和探索性数据分析中收集的非确定性词汇连接起来。这些数据可以用于数据管道中,以分类陈述中的概率、识别独特情况并改进模型结果。
除非另有说明,否则所有图片均由作者提供。
使用倾向评分匹配来构建领先指标
关于构建产品激活指标的简短指南
·
关注 发布于 Towards Data Science · 6 分钟阅读 · 2023 年 3 月 4 日
--
在上一篇文章中,我讨论了“输入 > 输出 > 结果”框架,以及“输出”是核心部分,但并不一定容易定义——仅仅因为你希望它受到输入的影响,但同时,你需要与结果有因果联系。
用户激活指标属于这种类型的指标。“激活”是由 Dave McClure(著名的 AAARRR 框架——意识、获取、激活、保留、推荐、收入)设计的海盗指标框架的第三个阶段,它通常定义为用户克服了第一组障碍,开始使用你的产品,从中获得了一些价值,并且现在更有可能在长期内被保留。
Some examples of product activation metric:
Loom: Sharing a loom¹
Zappier: Setting a zap¹
Zoom: Completing a zoom meeting within 7d of signup¹
Slack: Sending 2,000+ team messages in the first 30 days²
Dropbox: Uploading 1 file in 1 folder on 1 device within 1 hour²
HubSpot: Using 5 features within 60 days²
¹2022 product benchmark from Open View
https://openviewpartners.com/2022-product-benchmarks/
²Stage 2 Capital: the science of scaling:
https://www.stage2.capital/science-of-scaling
测量激活是重要的,因为它帮助你了解你的产品对新用户的吸引力如何,以及你是否有效地让他们成为“活跃”用户。这是通向用户忠诚的第一步——在这个阶段你可以知道用户是否可能长期留在你的产品中。如果激活率较低,可能表示产品或入驻流程存在问题,可能需要做出调整以改善用户体验并提高激活率。
定义产品激活为何困难
-
你希望激活能成为保留的良好预测指标,但同时,你希望它足够简单,因为这应该是用户遵循的简单第一步。
-
基本上,你要寻找用户能够采取的最小动作,这个动作可以展示产品的价值,但你希望这个小动作与保留有因果关系(无论你如何定义它)。
-
与任何“领先”指标一样,因果关系部分(“做某个动作导致长期保留”)是困难的。你通常从观察数据开始,传统的数据分析可能无法提供完整的图景,因为它可能忽略了影响激活/保留的混杂因素。
使用队列分析和倾向得分匹配来调查因果关系
通过队列分析,你可以开始建立对哪些用户行为可能是激活指标的良好候选的直觉。
其目的是:
-
根据用户注册你的产品的位置对他们进行分组。
-
根据用户是否达到了保留阶段将他们分开。
-
寻找那些在达成保留阶段的用户中做得非常多的动作,而在未达成的用户中做得较少的动作。
假设你运营一个健身应用。你开始创建每月的队列,并且注意到,70%在注册后的第一周内上传至少一次锻炼的用户在一年后仍然活跃于应用中,而如果他们没有上传,则只有 40%。这可以成为激活指标的初步想法。
这里的前提条件是你要弄清楚要研究哪个行动。在上述示例中,你必须想到查看谁跟踪了他们的锻炼。这是定量与定性相结合的地方,也是你的‘用户洞察’/常识发挥作用的地方。或者,如果你想寻求其他领域专家的帮助,那就需要你的网络技能。
一些建议:
-
你可能只想提出几个潜在的行动想法,不必过多地研究它们,因为俗话说:“如果你折磨数据足够久,它会承认任何事”(罗纳德·H·科斯)。你选择的行动越多,你发现一些东西的可能性就越大,但你也面临较高的假阳性风险。因此,坚持有意义且不过于牵强的东西可以作为一个好的经验法则。
-
你可能想要采用一种原则性的方法,只寻找你认为能够移动的事物。如果你提出的方案过于复杂或小众,你可能无法实施,这将使整个过程失去意义。
倾向得分匹配可以验证或否定你之前的直觉
一旦你确定了潜在的激活信号,下一步是确保它们的准确性。这时倾向得分匹配就能派上用场——以了解你之前找到的相关性是否实际上是因果关系。虽然这不是唯一的解决方案,并且确实需要对你的用户有一些了解(这可能并非总是可能的),但它相对容易实施,并且能让你对结果更有信心(直到进一步通过更稳健的方法如 A/B 测试进行三角验证)。
倾向得分匹配的理念如下:
-
为了找到采取行动与留存之间的因果联系,理想情况下,你可以复制采取行动的用户,并让克隆体不采取行动——以比较结果。
-
由于目前还无法做到(或许?),下一步最佳的方法是查看你的数据,找到与你采取行动的用户非常相似(几乎相同)的用户——但他们没有采取行动。
倾向得分匹配是一种方法学,可以帮助你找到这些非常相似的用户并进行配对。具体来说,它包括:
-
训练一个模型来预测你的用户采取你定义的行动的可能性(他们的倾向)。
-
基于之前发现的可能性来匹配用户(匹配部分)
(注意:你有多种方法可以完成这两个步骤,关于如何选择模型、如何选择合适的变量、选择什么匹配算法等方面,网上有一些很好的指南——有关更多信息,请参见 “倾向得分匹配的实施实践指南”)
再次以我们的健身应用程序为例:
-
你已经发现,在注册后的第一周上传至少一个锻炼计划的用户,仍然在一年后继续使用应用的比例为 70%,而如果他们没有上传,则为 40%。
-
你训练一个模型来预测用户在注册后一周内上传锻炼计划的可能性——你发现,通过大型健身网站的推荐链接下载应用的用户,其可能性非常高。
-
你根据可能性对用户进行排序,然后开始进行简单的 1:1 匹配(按可能性排序的第一个采取行动的用户与按可能性排序的第一个未采取行动的用户匹配,依此类推)。
-
匹配后,你会发现差异大大缩小,但这仍然对你考虑它作为潜在的激活指标很重要!
定义产品激活的万灵药?
阶段分析 + 倾向评分匹配可以帮助你隔离特定行动对用户行为的影响,这对于定义准确的激活指标至关重要。
但这种方法并不是万灵药——这套方法论有一堆假设,你需要对其进行微调/验证,以确保它适用于你的使用场景。
特别是,PSM 的有效性将高度依赖于你能预测自我选择的准确程度。如果你缺少关键特征,而未观察到的特征带来的偏差很大——那么 PSM 的估计可能会非常偏颇,并且不够有用。
综上所述——即使以不完美的方式使用这种方法,也能帮助你采用更加数据驱动的方法来选择指标,让你开始‘关注的内容’,直到你进入 A/B 测试阶段,并更好地理解推动长期成功的因素。
希望你喜欢阅读这篇文章!你有任何想分享的技巧吗?请在评论区告诉大家!
如果你想了解更多关于我的内容,以下是你可能喜欢的几篇文章:
增加你结果的信心,建立更强的个人品牌
towardsdatascience.com](/7-tips-to-avoid-public-embarrassment-as-a-data-analyst-caec8f701e42?source=post_page-----3e656dccbaf9--------------------------------) [## 如何构建成功的仪表盘]
来自某个曾经构建过几个不成功的案例的清单
towardsdatascience.com](/how-to-build-a-successful-dashboard-359c8cb0f610?source=post_page-----3e656dccbaf9--------------------------------) [## 如何选择数据项目进行工作]
如果你对如何使用时间有一个合理的方法,你可以优化你所创造的价值。
使用 Python 解决工程中最常见的问题之一
创建一个用于工作点分析的通用框架
·发布于 Towards Data Science ·17 分钟阅读·2023 年 1 月 2 日
--
在工程中,某些问题经常出现。本文的重点是讨论一种在我日常工作中经常遇到的问题,我决定分享如何使用 Python 来解决它。我们讨论的是什么类型的问题?就是解决系统工作点的问题!在深入复杂的代码之前,让我们用一个简单的例子来说明我的意思。
我们希望解决下图所示简单电路的工作点。这可以通过重新排列欧姆定律 (V=IR) 来实现,从而根据已知的输入电压和电阻隔离电流。
图片由作者提供
简单,对吧?不幸的是,大多数现实世界的问题都没有这么简单。例如,如果我告诉你,电阻器加热时,其电阻值发生变化,这实际上使电阻成为电流的函数。我们会得到如下形式的方程:
如果不知道实际的电阻函数形式,我们不能仅仅通过代数方法来求解电流。此外,如果方程很复杂,无法单独隔离电流怎么办?或者,假如电阻以电流的形式给出,并且是离散的表格数据 — 那么我们甚至没有代数表达式来操作以尝试求解电流。那我们应该如何确定电路中的电流呢?我们需要一种更通用的方法来解决这个问题。
解决这类问题的一般方法是将其表示为根查找问题。这实际上非常简单——我们只需将方程的右侧从左侧中减去,使得方程等于零。这样得到如下:
通过这样做,我们重新提出了问题。我们不再直接求解电流值,而是尝试找到一个电流值,将其输入方程的左侧,使其结果为零。我们为什么这样制定问题?因为存在大量的数值算法(如二分法、牛顿法等)来解决这种精确的问题!而且大多数算法不在乎方程的左侧有多复杂——它甚至不需要有封闭的代数形式(即,它可以由插值的离散数据、数值评估的积分或任意复杂度的函数组成)。只要我们能将问题表示为 f(x)=0,我们(几乎)总能找到问题的解决方案(如果问题陈述改变,代码也可以轻松修改/扩展——不需要重新做代数)。
本文接下来将通过一个例子来讲解如何将根查找方法应用于一个稍微复杂一点的实际问题,重点是 Python 中的代码结构和组织技巧。尽管这个问题(确定管道/泵系统中的水流量)有些领域特定,但所使用的方法和编码技术完全通用,适用于所有工程领域。考虑到这一点,我会尽量保持问题的物理建模方面处于高层次,使得不论技术背景如何,文章的主要学习目标仍然清晰明了。
作为旁注,我目前的领域“专长”在电机控制和功率电子学领域,与泵/管道应用相去甚远。我已经有好几年没有涉及这一主题,但认为它会是一个有趣的例子。我确信有许多人比我更有资格讨论泵/管道建模的具体细节,但我写这篇文章的意图在于方法论——而不是如何解决泵/管道问题。无论如何,我欢迎那些对该领域更了解的人的评论或改进建议!
问题
我们想要将水从一个水箱转移到另一个水箱。我们已经有一个泵和一些管道可以用来连接两个水箱,想要估算转移所有水所需的时间。每个水箱的体积是已知的,因此,如果我们可以估算水在两个水箱之间的流量,我们就可以估算转移过程需要多长时间。完整的设备如下所示。
图片由作者提供
这个特定问题(可以归类为“内部流动”问题)在机械工程领域中非常清楚。对于那些不太熟悉的人,或者需要快速复习的人,我们通常通过伯努利方程来解决这些问题(如下所示)。
伯努利方程本质上是一个能量守恒的陈述,它告诉我们流体粒子在沿流线(即如果流体中丢下一个假想粒子其会沿着的流动路径)移动时,能量如何在不同的能量机制之间转换。方程的左边表示流体粒子在任意第一个位置(位置 1)处的总单位重能量,是重力势能项、动能项和压力项的总和。当流体在系统中流动时,能量必须守恒,因此在流线上的任意第二个点(位置 2)处的总能量(由方程右边表示)必须等于位置 1 的总能量。
上述形式的伯努利方程被称为“扬程”形式,因为每个项的单位是长度/高度。这对我们的直觉很方便,因为我们基本上是在将每个项的能量等同于具有给定扬程高度的流体柱的重力势能。然而,伯努利方程的一个主要限制是它假设系统中没有损失(这并不是一个很好的假设)。为克服这一限制,我们可以在方程中补充两个附加项,如下所示:
Hp(Q)和 Hl(Q)项分别表示泵对系统增加的扬程和由于实际世界效应(如摩擦、粘度等)在系统中损失的扬程。注意这两个项都是系统流体流量 Q 的函数。(作为上述描述扬程解释的一个有趣结果,泵的扬程告诉你泵理论上能将流体推高多少)。我们将稍后更详细地检查泵和损失项,但在此之前,让我们简化上述方程以解决我们的具体问题。
再次查看上面的系统,我们将方便地选择两个位置用于伯努利方程,使得大多数项相互抵消。我们可以通过选择位置 1 和 2 分别位于每个水箱的自由水面来做到这一点,此时压力是常数且等于大气压力(P1=P2),速度大约是常数且为零(V1=V2=0)。我们还将假设在我们分析系统的瞬间两个水箱中的水高度相同,即 Z1=Z2。经过代数简化后,我们发现几乎所有项都抵消了,剩下的就是泵产生的扬程必须等于由于非理想性导致的系统损失的扬程。换句话说,泵在弥补系统中的任何能量损失。
在下图中可以定性地看到这种情况。泵产生的扬程随着流量的增加而减少,而管道系统中的损失随着流量的增加而增加。两条曲线交点(泵扬程 = 损失扬程)决定了系统的工作点(流量)。
作者提供的图片
在我们可以跳入代码之前的最后一步是将问题表示为根寻找问题。将方程右侧减去左侧,我们得到我们正在寻找的根求解问题。也就是说,我们将问题表述为:找到流量(Q),使得下面方程的左侧等于零。此时,泵扬程将等于系统的扬程损失。
代码
为了避免丢失我们正在做的大局,我不会解释代码的每一个细节(假设你已经具备合理的 Python 背景)。相反,我会集中精力确保叙述和代码结构清晰,并在需要时提供更多细节。与往常一样,如果有任何不清楚的地方,请随时提问。
设置
我们将从导入所有必要的模块开始。稍后将明显每个模块的使用方式,但值得注意的是,关键的导入语句来自 scipy。这些是针对当前问题的特定函数。代码块还设置了一些默认绘图设置(根据个人口味),创建了一个文件夹来保存生成的图形,并定义了一些单位转换常数,使得后续代码中的操作更加简便。
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
#these are the key libraries for solving the problem
from scipy.interpolate import interp1d
from scipy.optimize import root_scalar
#set plotting defaults
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12
figsize = (6.4,4)
#make folder to save plots to
plots_folder = Path('plots')
plots_folder.mkdir(exist_ok=True)
#define conversion constants for ease of use later
INCHES_TO_METERS = 25.4/1000
FEET_TO_METERS = 12*INCHES_TO_METERS
GALLONS_TO_M3 = 0.0037854118 #convert gallons to m³
接下来,我们将创建一个 Python dataclass
,它实际上充当了一个存储流体属性(密度、粘度和重力)的容器。默认情况下,这些属性设置为水的属性。请注意,虽然这不是完全必要的,但 Python 数据类非常方便。如果你对它们不熟悉,我强烈推荐你查看这个视频。
@dataclass
class Fluid():
#fluid defaults to water properties
rho: float = 997 #kg/m³
mu: float = 0.0007972 #N-s/m² = kg/m-s
g: float = 9.81 #m/s²
管道模型
下一步是建模管道的摩擦损失(上述扩展的 Bernoulli 方程中的 Hl(Q)项)。这通常使用如下的 Darcy-Weisbach 方程,其中 f 是摩擦因子(稍后会详细介绍),v 是流速,g 是重力,而 L 和 D 分别是管道的长度和直径。
不幸的是,摩擦因子(f)不是恒定的,而是依赖于流速、流体属性和管道尺寸。存在各种计算 f 的模型,但我们将使用下面的 Haaland 方程。
在这个方程中,epsilon 是管道的表面粗糙度(可以在工程教材的表格中找到),Re 是著名的雷诺数,如下所示。
最后,我们可以注意到每单位时间的扫过体积,或者称为体积流量(Q),等于管道的横截面积(A)乘以流速(v)。因此,给定管道中的流量,我们可以计算出管道中的相应流速:
希望所有这些方程没有让你忽视大局——我们只是查看了计算管道摩擦损失的一个特定模型。给定流量和管道尺寸,首先计算出相应的流速,然后通过上述方程计算管道的摩擦损失。这正是Pipe
类(如下所示)所实现的。
初始化方法存储了管道的尺寸(假设都以米为单位)和流体属性。A
方法计算管道的横截面积(对于不熟悉@property
装饰器的人,这篇文章解释得非常好)。Q_to_v
方法将每分钟加仑数(gpm)的流量转换为米/秒的流速。friction_factor
方法评估了上述提到的 Haaland 方程,而head_loss
和 head_loss_feet
分别评估了管道的摩擦损失(以米和英尺为单位)(使用 Darcy-Weisbach 方程)。
class Pipe():
def __init__(self, L, D, epsilon, fluid: Fluid):
#pipe dimensions are all assumed to be in meters
self.L = L
self.D = D
self.epsilon= epsilon
#fluid properties
self.fluid = fluid
@property
def A(self):
"""computes cross-sectional area of pipe in m²"""
return np.pi*(self.D/2)**2 #area in m²
def Q_to_v(self, gpm):
"""Converts gpm to fluid speed in pipe in m/s"""
Q = gpm*GALLONS_TO_M3/60 #flow rate in m³/s
return Q/self.A #flow velocity in m/s
def friction_factor(self, gpm):
"""computes Darcy friction factor, given flow rate in gpm
This method uses Haaland's equation, wich is an explicit approximation
of the well-known, but implicit Colebrook equation
"""
#first get flow velocity from flow rate and pipe dimensions
v = self.Q_to_v(gpm)
#compute Reynold's number
Re = self.fluid.rho*v*self.D/self.fluid.mu
#compute relative roughness
e_over_d = self.epsilon/self.D
#use Haaland's equation
f = (-1.8*np.log10((e_over_d/3.7)**1.11 + 6.9/Re))**-2
return f
def head_loss(self, gpm):
"""computes head loss in meters, given flow rate in gpm"""
#get flow velocity
v = self.Q_to_v(gpm)
#get Darcy friction factor
f = self.friction_factor(gpm)
#compute head loss in meters
hl = 0.5*f*(self.L/self.D)*v**2/self.fluid.g
return hl
def head_loss_feet(self, gpm):
"""computes head loss in feet, given flow rate in gpm"""
hl_meters = self.head_loss(gpm)
return hl_meters/FEET_TO_METERS
让我们看看管道类的实际应用。首先,我们可以创建一个水的Fluid
对象和一个长度为 100 英尺、直径为 1.25 英寸的Pipe
对象。
#create fluid object for water
water = Fluid()
#create pipe segment with water flowing in it
pipe = Pipe(L=100*FEET_TO_METERS,
D=1.25*INCHES_TO_METERS,
epsilon=0.00006*INCHES_TO_METERS,
fluid=water)
接下来,我们将绘制流量(以 gpm 为单位)与摩擦损失的关系曲线。当我们利用面向对象编程时,下面的代码变得如此简洁和易读,这难道不令人惊叹吗?
gpm_arr = np.linspace(1,30,100)
hl = [pipe.head_loss_feet(gpm) for gpm in gpm_arr]
fig, ax = plt.subplots(figsize=figsize)
ax.plot(gpm_arr, hl)
ax.set_xlabel('Flow Rate [gpm]')
ax.set_ylabel('Head Loss [ft]')
fig.tight_layout()
fig.savefig(plots_folder/'pipe_loss_curve.png')
图片来源:作者
泵模型
我们已经有了管道头损失的工作模型——现在我们需要一个泵产生的头部模型,Hp(Q)。我相信有解析模型可以用来确定泵的行为,但我们将假设我们已经有了一个具体的泵——即我在网上找到的随机的、分数马力的泵:
大多数泵都有一个数据表,其中包含描述泵行为的泵曲线。对于我们的泵,下面是泵曲线(请注意,出于版权原因,这只是制造商提供的图示的重建版——原版可以在这里找到)。
图片来源:作者
目前我们有了描绘泵行为的图像,但还没有一个可以实际用来确定它在系统中表现的数学模型。这个问题经常出现,我的解决方法是 1) 数字化数据,然后 2) 使用插值方案包装离散数据,以生成一个连续函数。让我来说明一下。
第一步)有许多工具可以将图像数据数字化——我个人最喜欢的是免费的在线工具WebPlotDigitizer。你可以加载感兴趣的图像,校准坐标轴,然后提取所需的数据曲线(可以手动提取,也可以使用自动提取工具)。数据可以导出为 .csv 文件。
图片来源:作者
第二步)现在我们已经得到了数字化的数据,我们只需要用某种插值器进行包装——这正是下面Pipe
类所做的。初始化方法接收 .csv 文件名,存储文件名,将数据加载到 pandas DataFrame 中,存储在 data
属性中,然后将数据传递给 scipy 的 interp1d
函数。interp1d
函数生成一个新的函数,默认使用线性插值将离散数据点转换为连续函数(interp1d
函数的完整文档可以在这里找到)。新生成的插值函数随后存储在 _interp
属性中以供后续访问。Pipe
类还包含一个 bounds
方法,该方法返回包含泵曲线数据中流量的最小/最大值的列表(这将在根查找算法中使用),以及一个 head_gain_feet
方法,该方法接受流量值(单位为 gpm),并调用由 interp1d
生成的底层插值函数。
class Pump():
def __init__(self, file):
#store file name
self.file = file
#read data into pandas dataframe and assign column names
self.data = pd.read_csv(file, names=['gpm', 'head [ft]']).set_index('gpm')
#create continuous interpolation function
self._interp = interp1d(self.data.index.to_numpy(), self.data['head [ft]'].to_numpy())
@property
def bounds(self):
"""returns min and max flow rates in pump curve data"""
return [self.data.index.min(), self.data.index.max()]
def head_gain_feet(self, gpm):
"""return head (in feet) produced by the pump at a given flow rate"""
return self._interp(gpm)
我们可以创建一个Pump
对象并查看我们读取的原始数据。
pump = Pump('pump_data.csv')
pump.data.head()
图片来源:作者
我们还可以将泵曲线数据与管道损失曲线一起绘制,以直观地查看系统将在哪里操作。
head_loss = [pipe.head_loss_feet(gpm) for gpm in pump.data.index]
fig, ax = plt.subplots(figsize=figsize)
ax.plot(pump.data, label='Pump Curve')
ax.plot(pump.data.index, head_loss, label='Pipe Head Loss')
ax.set_xlabel("Flow Rate [gpm]")
ax.set_ylabel("Head [ft]")
ax.legend(frameon=True, facecolor='w', framealpha=1, loc=6)
fig.tight_layout()
fig.savefig(plots_folder/'pump_curve_with_losses.png')
图片来源:作者
系统模型
我们终于建立了能够解决泵/管道系统操作点的基础设施。最后一步是创建一个System
类,该类接收一个Pipe
和Pump
对象,并执行根求解操作。正如下面的代码所示,System
类接收并存储一个Pipe
和Pump
对象。然后,它利用这两个对象创建一个residual
方法,该方法计算泵头和管道头损失之间的差值。这个residual
方法随后在get_operating_point
方法中被使用,以实际求解系统的操作点。该方法包装了 scipy 的root_scalar
函数,它作为各种根求解算法的接口。我们将让root_scalar
函数选择它认为最适合的算法,但为了帮助它,我们将指定一个我们知道根在其间的区间。在我们的案例中,这个区间是泵曲线数据的上限和下限流量。关于root_scalar
函数的完整文档可以在这里找到。
提示:将
Pipe
和Pump
对象注入System
类(与在实例化时让系统类创建Pipe
和Pump
对象相对)被称为“依赖注入”。这通常被认为是一种良好的编码实践,因为它使代码更具模块化、可扩展性,更易于调试/测试。
class System():
def __init__(self, pipe: Pipe, pump: Pump):
self.pipe = pipe
self.pump = pump
def residual(self, gpm):
"""
Computes the difference between the head produced by the pump
and the head loss in the pipe. At steady state, the pump head and
head loss will be equal and thus the residual function will go to zero
"""
return self.pump.head_gain_feet(gpm) - self.pipe.head_loss_feet(gpm)
def get_operating_point(self):
"""solve for the flow rate where the residual function equals zero.
i.e. the pump head equals the pipe head loss"""
return root_scalar(self.residual, bracket=self.pump.bounds)
让我们创建一个System
并运行get_operating_point
方法,以观察我们的劳动成果。正如代码输出所示,get_operating_point
方法仅返回root_scalar
函数的输出,这是一个RootResults
对象。这个对象本质上只是一个容器,存储各种属性,其中最重要的是root
属性,因为它包含我们问题的解决方案。
sys = System(pipe, pump)
res = sys.get_operating_point()
res
我们可以再次绘制相同的泵和头损失曲线,这次在计算出的稳态操作点处添加一条垂直线。
head_loss = [pipe.head_loss_feet(gpm) for gpm in pump.data.index]
fig, ax = plt.subplots(figsize=figsize)
ax.plot(pump.data, label='Pump Curve')
ax.plot(pump.data.index, head_loss, label='Pipe Head Loss')
#plot vertical line at operating point
ax.axvline(res.root, color='k', ls='--', lw=1)
ax.legend(frameon=True, facecolor='white', framealpha=1, loc=6)
ax.set_xlabel("Flow Rate [gpm]")
ax.set_ylabel("Head [ft]")
ax.set_title(f'Operating Point = {res.root:.1f} gpm')
fig.tight_layout()
fig.savefig(plots_folder/'intersection_solution.png')
图片来源:作者
完成了!我们已经程序化地确定了系统的操作点。因为我们使用的是一种比较通用的编码框架,我们可以轻松尝试使用不同的泵或管道进行相同的分析!我们甚至可以扩展我们的代码以包括多个泵,或各种管道配件/管道分支。
设计探索
作为一个小例子,突出我们设置代码方式的好处,我们将进行设计探索。使用相同的泵,我们希望了解管道长度对系统体积流量的影响。为此,我们只需遍历一个管道长度数组(从 100 到 1000 英尺),更新存储在System
中的Pipe
对象的长度属性,然后重新计算系统的工作点,将结果追加到列表中。最后,我们将水流量绘制为管道长度的函数。
#sweep pipe length from 100 to 1000 feet
lengths_feet = np.linspace(100, 1000, 1000)
lengths_meters = lengths_feet*FEET_TO_METERS
flow_rates = []
for l in lengths_meters:
#update pipe length
sys.pipe.L = l
#compute new flow rate solution
res = sys.get_operating_point()
#append solution to flow rates list
flow_rates.append(res.root)
#plot results
fig, ax = plt.subplots(figsize=figsize)
ax.plot(lengths_feet, flow_rates)
ax.set_xlabel("Pipe Length [ft]")
ax.set_ylabel("Flow Rate [gpm]")
# ax.set_ylim(bottom=0)
fig.tight_layout()
fig.savefig(plots_folder/'flow_vs_pipe_length.png')
图片由作者提供
几行代码让我们能够深入了解系统的行为。如果这是一个设计问题,这些见解可能会驱动关键的设计决策。
结论
本文虽然主要集中在一个特定领域的示例问题上,但突出了我经常使用的工作流程的几个方面。操作点分析的问题在工程和科学中经常出现,尽管有许多方法可以解决这个问题,但有些方法比其他方法更强大、可扩展和灵活。本文中使用的方法论(问题表述和代码结构原则)对我帮助极大,希望其他人也能受到启发,采用类似的工作流程!
随时留下任何评论或问题,或者在 LinkedIn 上与我联系——我非常乐意澄清任何不确定的点。最后,我鼓励你自己尝试代码(甚至将其用作自己工作流程的起始模板)——这篇文章的 Jupyter Notebook 可以在我的Github上找到。
尼古拉斯·赫门威
-
如果你喜欢这个,请在 Medium 上关注我
-
考虑订阅 电子邮件更新
-
有兴趣合作吗?让我们 在 LinkedIn 上联系
使用量子退火进行 scikit-learn 特征选择
对于具有大量特征的数据集,使用量子处理进行 scikit-learn 模型的特征选择
·
关注 发表在 Towards Data Science · 11 分钟阅读 · 2023 年 4 月 10 日
--
特征选择是机器学习中的一个广泛话题。正确实施时,它可以帮助减少过拟合、提高可解释性、降低计算负担等。用于特征选择的技术有很多种,它们的共同点在于它们会查看特征集,并尝试将那些能够带来良好结果(准确的预测、可解释的模型等)的特征与那些不利于此目标的特征分开。
特别困难的是特征数量非常大的情况。对所有特征组合进行穷举探索通常计算开销很大。像 R 中的regsubsets()
函数那样的逐步搜索算法可能会通过添加/删除特征并比较结果来尝试推断出有前途的特征组合。但最终,当特征数量很大时,搜索成功的程度与寻找最佳特征组合所花费的努力之间仍然存在权衡。
一般来说,当问题的最佳解涉及到大量组合的搜索时,量子退火可能值得研究。我将展示一个使用 D-Wave 最近发布的 scikit-learn 插件进行特征选择的例子,该数据集具有数百个特征。
D-Wave 和 scikit-learn
请记住,这并不是通用的门模型量子计算。这是一种算法,本质上类似于模拟退火,其中有一个目标函数,并且使用类似于模拟退火的方法来找到一个最小化目标的值组合。不同的是,这里的退火并非模拟的,而是通过编程使实际系统的物理能量与目标函数相匹配。系统的能量会降低直到达到最低点(退火),然后解就是系统的状态,这个状态会被读取并返回给用户。
D-Wave 构建了量子退火器,可以有效地解决许多组合复杂度很高的问题。只要你能将问题简化为二次二进制模型(BQM),或者带约束的 BQM(CQM),或者上述模型的某些离散推广(DQM),问题就可以提交给量子解算器。更多细节请参见文档。
理论上,你可以将特征选择算法表述为 BQM,其中特征的存在是值为 1 的二进制变量,而特征的缺失是值为 0 的变量,但这需要一些工作。D-Wave 提供了一个可以直接插入到 scikit-learn 管道中的scikit-learn 插件,简化了这一过程。
本文将首先展示显式地表述 BQM 的整个过程,然后是 CQM,将模型发送到量子解算器,然后解析结果。这是解决量子退火器上优化过程的一般方法,并且可以用于许多问题。
但如果你只是想选择数据集中最好的特征,一个简单的SelectFromQuadraticModel()
方法调用就足够了。这将整个算法压缩成一行代码。我们将在最后展示这一点。
问题与方法
考虑从 OpenML 下载的场景识别数据集,最初由Boutell, M., Luo, J., Shen, X., Brown, C. (2004)创建,并由其作者在 CC BY 许可证下分发。这是一个二分类数据集,具有一个依赖变量(Urban),并有近 300 个特征。响应变量的值(二元:城市或非城市)需要基于特征的组合进行预测。一个简单的分类模型如RandomForestClassifier()
应该表现良好,前提是特征选择得当。
>>> dataset.get_data()[0]
attr1 attr2 attr3 attr4 attr5 attr6 attr7 attr8 ... attr293 attr294 Beach Sunset FallFoliage Field Mountain Urban
0 0.646467 0.666435 0.685047 0.699053 0.652746 0.407864 0.150309 0.535193 ... 0.014025 0.029709 1 0 0 0 1 0
1 0.770156 0.767255 0.761053 0.745630 0.742231 0.688086 0.708416 0.757351 ... 0.082672 0.036320 1 0 0 0 0 1
2 0.793984 0.772096 0.761820 0.762213 0.740569 0.734361 0.722677 0.849128 ... 0.112506 0.083924 1 0 0 0 0 0
3 0.938563 0.949260 0.955621 0.966743 0.968649 0.869619 0.696925 0.953460 ... 0.049780 0.090959 1 0 0 0 0 0
4 0.512130 0.524684 0.520020 0.504467 0.471209 0.417654 0.364292 0.562266 ... 0.164270 0.184290 1 0 0 0 0 0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
2402 0.875782 0.901653 0.926227 0.721366 0.795826 0.867642 0.794125 0.899067 ... 0.254413 0.134350 0 0 0 0 0 1
2403 0.657706 0.669877 0.692338 0.713920 0.727374 0.750354 0.684372 0.718770 ... 0.048747 0.041638 0 0 0 0 0 1
2404 0.952281 0.944987 0.905556 0.836604 0.875916 0.957034 0.953938 0.967956 ... 0.017547 0.019734 0 0 0 0 0 1
2405 0.883990 0.899004 0.901019 0.904298 0.846402 0.858145 0.851362 0.852472 ... 0.226332 0.223070 0 0 0 0 0 1
2406 0.974915 0.866425 0.818144 0.936140 0.938583 0.935087 0.930597 1.000000 ... 0.025059 0.004033 0 0 0 0 0 1
[2407 rows x 300 columns]
一种选择特征的方法,描述在Milne, Rounds, 和 Goddard (2018)的论文中,适合量子退火机,并可以总结如下:
将数据集拆分为特征矩阵和响应向量——即 scikit-learn 中熟悉的(X, y)元组。使用这些,构造一个相关矩阵 C,其中 Cii 表示特征与响应的相关性,Cij 是特征之间的相互相关性。对于与响应高度相关的特征,我们希望尽可能多地捕获。对于彼此相关的特征,我们希望尽可能少地捕获,但不影响与响应强相关的特征。显然,我们希望在所有这些标准之间取得平衡。
设 Xi 为二元变量,指示特征 i 是否被选择到最终数据集中。如果选择了特征 i,则 Xi = 1,否则 Xi = 0。问题变成了等同于寻找极值:
目标函数
第一个项的总和表示来自特征的个别贡献——我们称之为线性项。第二个项的总和可以说包含了二次交互项。alpha 是一个偏置系数,它控制目标函数中我们允许的特征之间的交互量;其值需要探索以找到最佳结果。找到最小化目标函数的 Xi 集合等同于特征选择。
目标函数实际上是一个二次二元模型——BQM。它是二元的,因为 Xi 可以是 0 或 1。它是二次的,因为最高阶的项是二次交互项。这可以很容易地在量子退火机上解决。我们需要应用的唯一约束是,等于 1 的变量 Xi 的数量不能超过我们愿意接受的特征总数。
用困难的方法进行特征选择
让我们来解决这个问题。下面的代码块导入了我们需要的所有库,下载了数据集,并实例化了一个分类模型。
import numpy as np
import openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
import plotly.express as px
from plotly.subplots import make_subplots
import dimod
from dwave.system import LeapHybridCQMSampler
dataset = openml.datasets.get_dataset(312)
X, y, categorical_indicator, attribute_names = dataset.get_data(
target=dataset.default_target_attribute, dataset_format="dataframe"
)
X = X.astype(float)
y = y.values.astype(int)
clf = RandomForestClassifier()
我们关注两件事:
-
每个特征的“相关性”,即它与响应的相关强度
-
特征的“冗余性”,即相关矩阵中的非对角项(二次交互项的权重)
让我们构建几个函数来可视化相关性和冗余性,将它们应用于没有特征选择的整个数据集,然后在整个数据集上进行交叉验证(5 折)模型的性能。
def evaluate_model(m, X, y, indices=None):
if not indices is None:
X_filtered = X.iloc[:, indices]
else:
X_filtered = X
acc = np.mean(cross_val_score(clf, X_filtered, y, cv=5))
return acc
def show_relevance_redundancy(X, y, indices=None, title=""):
if not indices is None:
X = X.iloc[:, indices].copy()
y = y
fig = make_subplots(
rows=1,
cols=2,
column_widths=[0.68, 0.32],
column_titles=["relevance", "redundancy"],
)
trace_rel = px.bar(np.array([abs(np.corrcoef(x, y)[0, 1]) for x in X.values.T]))
trace_red = px.imshow(abs(np.corrcoef(X.values, rowvar=False)))
fig.add_trace(trace_rel.data[0], row=1, col=1)
fig.add_trace(trace_red.data[0], row=1, col=2)
fig.update_layout(width=1200, height=480, title=title)
fig.show()
show_relevance_redundancy(
X,
y,
None,
f"full dataset: acc={evaluate_model(clf, X, y, None):.4f}",
)
结果如下:
基线性能
模型的准确性为 0.9082。特征在相关性方面变化显著。有一些冗余的区域,特征之间似乎存在强相关性。
为了最小化目标函数,让我们创建一个带有约束的二次模型,并将其发送到量子退火器。代码将在下面解释。
k = 30
# Pearson correlation
correlation_matrix = abs(np.corrcoef(np.hstack((X, y[:, np.newaxis])), rowvar=False))
# fix the alpha parameter from the Milne paper
# to account for the numerous quadratic terms that are possible
beta = 0.5
alpha = 2 * beta * (k - 1) / (1 - beta + 2 * beta * (k - 1))
# generate weights for linear and quadratic terms, per Milne algorithm
Rxy = correlation_matrix[:, -1]
Q = correlation_matrix[:-1, :-1] * (1 - alpha)
np.fill_diagonal(Q, -Rxy * alpha)
# create binary quadratic model from the linear and quadratic weights
bqm = dimod.BinaryQuadraticModel(Q, "BINARY")
# create constrained quadratic model
cqm = dimod.ConstrainedQuadraticModel()
# the objective function of the CQM is the same as BQM
cqm.set_objective(bqm)
# constraint: limit the number of features to k
cqm.add_constraint_from_iterable(
((i, 1) for i in range(len(cqm.variables))), "==", rhs=k
)
# the sampler that will be used is the hybrid sampler in the DWave cloud
sampler = LeapHybridCQMSampler()
# solve the problem
sampleset = sampler.sample_cqm(cqm)
这里发生了很多事情。让我们解释每一步。
我们将特征数量限制为 k=30。这是模型的主要限制。
我们稍微偏离了论文中描述的目标函数。我们没有直接定义 alpha,而是使用了一个等效参数 beta,它具有相同的作用。然后我们以一种保持交互项贡献受控的方式定义 alpha——如果特征数量极大,这将确保交互项不会压倒目标函数中的线性项。
我们以一种可以直接构建 BQM 的方式塑造相关矩阵。我们将 BQM 限制在不超过 k=30 个特征的条件下,因此我们获得一个受限的二次模型 (CQM)。我们将 CQM 发送到量子退火器并收集结果。
应该注意的是,量子部分代码的运行时间约为 10 秒。对于许多问题,这是一个典型的基准时长,即使变量数量很大。运行时间往往随着问题复杂性的增加而增加,但这种增加可能不会像你从经典求解器中预期的那样急剧。
我们还没有完成。D-Wave 硬件通常会对解决方案空间进行采样,并返回大量可能的可行解决方案。这是因为硬件的工作方式:一个退火事件运行迅速且容易,因此重复进行退火是有意义的,这里自动发生了。如果生成足够多的样本,其中一些将是最佳解决方案。
所以我们需要对解决方案进行排序并挑选最佳方案。“最佳”意味着——它在目标函数中具有最低的值,D-Wave 称之为“能量”,因为它确实是量子处理器的物理能量。
# postprocess results
feasible = sampleset.filter(lambda s: s.is_feasible)
if feasible:
best = feasible.first
else:
assert len(cqm.constraints) == 1
best = sorted(
sampleset.data(),
key=lambda x: (list(cqm.violations(x.sample).values())[0], x.energy),
)[0]
assert list(best.sample.keys()) == sorted(best.sample.keys())
is_selected = np.array([bool(val) for val in best.sample.values()])
features = np.array([i for i, val in enumerate(is_selected) if val])
best_energy = best.energy
features
是一个包含量子退火器选定特征索引的数组。这是特征选择过程的结果。显然,它的长度将是 k=30。
让我们在特征选择后测量模型的准确性:
show_relevance_redundancy(
X,
y,
features,
f"explicit optimization: acc={evaluate_model(clf, X, y, features):.4f}",
)
显式优化
特征选择后的模型准确度为 0.9381。我们获得了确切数量的特征。大多数特征具有很高的相关性。特征之间的相关值通常较低。模型表现更好,可能更容易解释,并且选择正确特征所花费的时间并不长。
但是这个过程很长且繁琐,如果你只想选择一些特征的话。幸运的是,现在有一种更简单的方法。
轻松的特征选择方法
如果你安装了 D-Wave scikit-learn 插件,你只需执行以下操作:
X_new = SelectFromQuadraticModel(num_features=30, alpha=0.5).fit_transform(X, y)
这就是全部了。后台代码创建了二次模型,对其进行约束,发送到量子求解器,获取结果,再进行解析,最后选择最佳结果以 NumPy 数组格式返回,这是 scikit-learn 所期望的格式。运行时间大致相同。但让我们看看结果是什么样的:
X_new_df = pd.DataFrame(data=X_new, columns=list(range(X_new.shape[1])))
show_relevance_redundancy(
X_new_df,
y,
None,
f"plugin optimization: acc={evaluate_model(clf, X_new_df, y, None):.4f}",
)
插件优化
模型的准确度为 0.9369,本质上与我们通过显式优化获得的性能相同(在各种随机组件的典型波动范围内)。
选择的特征集略有不同。这可能是由于我的手动过程和库中的自动化实现之间的微小差异造成的。
在任何一种情况下,我们都将模型的性能从良好提升到了卓越,通过一种不进行启发式猜测的算法,而是考虑了所有特征的总和。
待探索的进一步主题
alpha 参数会使特征选择算法偏向于减少冗余但可能质量较低(alpha=0),或者偏向于提高质量但代价是增加冗余(alpha=1)。最佳值取决于你要解决的问题。
SelectFromQuadraticModel()
方法有一个名为 time_limit
的参数,正如名称所示:它控制求解器上的最大时间。你为所用的时间付费,因此这里的高值可能会很昂贵。另一方面,如果量子退火似乎未能收敛到足够好的解决方案,这里较高的值可能值得探索。这里展示的问题对量子处理器而言相当简单,所以我们仅使用了默认值。
请记住,你的计算机与云端量子求解器之间的数据交换可能并非简单,因此在时间敏感的应用程序(如金融)中,连接的延迟可能非常重要。
参考资料
D-Wave 网络研讨会介绍了 scikit-learn 插件。www.youtube.com/watch?v=VHEpe00AXPI
Milne, A., Rounds, M., & Goddard, P. (2018). 使用量子退火器进行信用评分和分类的最佳特征选择。1QBit.com。1QBit.com 上的白皮书
Boutell, M., Luo, J., Shen, X., Brown, C. (2004). 学习多标签场景分类。《模式识别》期刊,ScienceDirect。文章链接
入门指南:D-Wave 求解器
D-Wave scikit-learn 插件:D-Wave scikit-learn 插件
D-Wave 示例:CQM 特征选择:D-Wave 示例:CQM 特征选择
OpenML 场景识别数据集:OpenML 场景识别数据集
为本文创作的代码和工件:为本文创作的代码和工件
感谢罗斯-霍尔曼科技学院的马修·布特尔博士授权我在创作共用许可证下访问场景特征数据集。
所有图像均由作者创作。
使用 React 构建互动界面以展示令人兴奋的数据集
原文:
towardsdatascience.com/using-react-to-build-interactive-interfaces-to-exciting-dataset-c01691a5fc38
数据教程
使用网页开发创建更动态的数据可视化体验
·发表在 Towards Data Science ·5 分钟阅读·2023 年 9 月 19 日
--
界面演示视频
除了我作为一家小型机器学习公司的首席执行官的全职工作之外,我的爱好是创建美丽的数据可视化。
我通常使用 Matplotlib,但这次我想创造一个更互动的体验。
由于我喜欢网页开发和设计,我决定为世界银行的 人口估计和预测 数据集创建一个 React 应用程序。
这是一个令人着迷的数据集,你可以查看 1960 年至 2022 年所有国家和地区的人口金字塔,包括对 2050 年的预测。它的许可证是 Creative Commons Attribution 4.0。
这是一个非常适合互动界面的数据集,人们可以快速更改年份和地区。
在这个故事中,我将分享我的工作见解和所学到的经验。
如果你想测试这个解决方案,你可以在这里找到:datawonder.io/population-pyramids
让我们开始吧。
第一部分:准备数据
我想创建一个简单而快速的后台,将数据提供给前端,而不进行任何耗时的预处理。
相反,我的想法是提前处理所有数据,并在应用程序启动时将其全部加载到内存中。
世界银行的数据总是有一组指标,而我需要的指标具有以下格式:
人口年龄
,
有 17 个年龄组,从 0–4 岁到 80 岁及以上。每个指标都有一个单独的列表示每一年,如下方的 pandas 数据框所示。
因为我确切知道我需要数据的哪些部分,不想在后台进行任何过滤或其他操作。
相反,我决定将 DataFrame 转换为以下格式的 JSON 文件。
data = {
<Area1> = {
"1960": {
"total": X + Y ... + A + B + ...
"male": {
"0-4": X,
"4-9": Y,
...
}
"female": {
"0-4": A,
"4-9": B,
...
}
}
...
}
...
}
这就是准备数据的全部内容。
第二部分:创建后端
后端是这个 web 应用程序中最简单的部分,只提供用户界面的 JSON 数据。
当应用程序启动时,我将预处理的 JSON 加载到内存中作为 Python 字典。
import json
from typing import List
from fastapi import FastAPI, Query
app = FastAPI()
population_data_file = open("./data/data.json")
population_data = json.load(population_data_file)
population_data_file.close()
我决定使用 FastAPI,因为它是一个可行的选项,并且使创建端点变得容易。
当用户发送查询时,我可以立即访问数据。
@app.get("/population")
def get_data(areas: List[str] = Query(None)):
populations = []
for area in areas:
if area not in population_data.keys():
return "Unknown area", 400
populations.append({"area": area, "data": population_data[area]})
return {"populations": populations}
我还有一个端点来列出所有可用的国家和地区。
@app.get("/areas")
def get_countries():
return list(population_data.keys())
这就是我后端所有的代码。
第三部分:创建用户界面
构建前端是这个项目中最耗时的部分。我不能详细讲解每一行代码。
我的目标是创建一个简约的应用程序,它既响应迅速又易于使用,因为许多数据接口过于复杂。
我使用了两个主要的 JavaScript 库来让我的生活更轻松:Recharts 和 Chakra。
我不会详细讲解 Recharts,因为它们有很棒的文档来描述你可以做的所有事情。
实质上,它们提供了一个叫做 ResponsiveContainer
的 React 组件,作为你图表的基础。
然后你可以定义像BarChart
、YAxis
、XAxis
、Tooltip
和Legend
这样的东西来创建完美的图表。
结构如下;所有组件都有几个属性来确保样式正确。
<ResponsiveContainer>
<BarChart>
<CartesianGrid />
<YAxis />
<XAxis />
<Tooltip />
<Legend />
<Bar />
<Bar />
</BarChart>
</ResponsiveContainer>
例如,我将我的 BarChart
组件设置为垂直布局,并将条形图堆叠在“sign”上,以使它们朝相反方向延展。
<BarChart
stackOffset="sign"
layout="vertical"
data={
prepare_data(
population.data[year.toString()],
percentage
)
}
>
<Bar dataKey="Male" stackId="a" fill="#57BED6" z={100} />
<Bar dataKey="Female" stackId="a" fill="#EC6237" z={10} />
</BarChart>
prepare_data()
函数将一个区域和年份的人口数据转换为对象列表。
const prepare_data = (population, percentage) => {
return AGES.map((age) => ({
name: age,
Male:
(population["male"][age] * -1) /
(percentage ? 0.01 * population["total"] : 1),
Female:
population["female"][age] / (percentage ? 0.01 * population["total"] : 1),
}));
};
下面是每个图表的样子。
我为其他组件设置了额外的属性来样式化轴标签和网格等。
第四部分:部署应用程序
我决定使用 Digital Ocean 的应用平台作为我的部署方式,但这可能不是最便宜的选择。
他们的最小后端服务比我需要的要好得多,但我计划随着时间的推移添加额外的端点以支持其他数据集。
部署前端
如果你使用create-react-app
创建前端,你可以将 Digital Ocean 指向那个仓库,它会在你将代码推送到 master 时构建应用程序。
你无需创建 Docker 容器或其他类似的东西。这对于像我这样在 DevOps 上比较懒的人来说非常棒。
它对于这个应用程序也很完美,因为它不需要任何花哨的东西。
结论
静态数据可视化很棒,但有时你会想要一种更互动的体验,让用户自己探索模式。
在这种情况下,构建简单的 React web 应用程序来使用开源库创建图表和图形是出奇的简单。
在本教程中,我分享了一些关于如何创建一个用户界面的见解,该界面允许用户查看不同国家和地区的年龄分布。
我希望你能从中获得一些灵感,自己做一个类似的项目。
感谢你的阅读,下次见!😊
使用无服务器函数来管理和监控基于云的训练实验
一种简单的常规方法,可以帮你节省大量金钱
·
查看 发表在 Towards Data Science · 11 min 阅读 · 2023 年 12 月 17 日
--
图片由 Ziyou Zhang 在 Unsplash 提供
这篇博客文章是与我的同事Shay Margalit共同撰写的。它总结了他关于如何利用AWS Lambda函数来提高对Amazon SageMaker训练服务的使用控制和成本的研究。感兴趣吗?请继续阅读 😃.
我们很幸运(或者说非常不幸——这要看你问谁)能亲历这场被许多人预期会改变我们所知世界的人工智能革命。得益于硬件开发的进步和对大量数据的获取,这场革命很可能会影响我们日常生活的许多方面——虽然具体如何,没有人能确定。为了满足对人工智能日益增长的需求,基础的机器学习模型规模迅速扩大,训练这些模型所需的资源也在增加。总而言之,要在人工智能开发领域保持相关性,就需要对昂贵的重型设备进行大规模投资。
基于云的托管训练服务,例如 Amazon SageMaker、Google Vertex AI 和 Microsoft Azure ML,通过使开发者能够在他们原本负担不起的机器上进行训练,降低了人工智能开发的门槛。虽然这些服务减少了人工智能的前期成本,并使你仅为训练时间付费,但变量成本的潜在累积需要对如何使用训练服务以及它们如何贡献于整体训练费用进行仔细规划。然而,不可避免的是,事情并不总是按计划进行。借用一句古老的意第绪谚语“开发者计划,编程之神笑”。当涉及到高风险时,比如训练人工智能模型——一个错误的实验可能导致数百或数千美元的计算时间浪费,因此明智之举是建立多个防线。
防线第一步——鼓励健康的发展习惯
第一防线应关注机器学习算法工程师的开发实践。以下是一些你可能会考虑的指导原则:
-
鼓励对用于训练的硬件资源进行适当且成本优化的使用(例如,参见这里)。
-
及早识别并终止失败的实验。
-
通过定期分析和优化运行时性能来提高性价比(例如,参见这里)。
尽管制定和调整像上面这些 AI 开发原则可能会提高您的生产力并减少浪费,但它们并不能完全防护所有可能的失败。例如,专用的故障检测运行时进程可能无法解决培训实验停滞的情况(例如,由于培训应用程序进程中的死锁),而培训任务仍然保持活动状态,直到它被主动停止或超时。
第二道防线 — 部署跨项目的保护措施
在这篇文章中,我们提议建立第二道防线,监控项目(或组织)中的所有培训活动,验证其是否符合预定的规则,并在发现异常培训实验时采取适当的措施。一种方法是使用专用的无服务器函数,这些函数会在培训任务的不同阶段触发,并编程以评估任务的状态,并根据需要停止或重新启动(可能会更改任务设置)任务。在接下来的章节中,我们将演示如何使用 AWS Lambda 作为对抗异常 Amazon SageMaker 培训实验的第二道防线的几个示例。
免责声明
尽管我们选择了 Amazon SageMaker 和 AWS Lambda 进行演示,但本文内容同样适用于其他服务,并且可以为它们实现类似的功能。请不要将我们选择这些服务解读为对其使用的推荐。云端培训有多种选项,每种选项都有其自身的优缺点。最适合您的选择将极大地取决于您项目的具体细节。
虽然我们将分享一些无服务器代码的 Python 示例,但我们不会详细讲解如何创建和部署它们作为 AWS Lambda 函数。与 AWS Lambda 交互的方式有很多种。我们建议读者参阅官方 AWS 文档以了解更多信息。
以下示例是为演示目的而创建的。它们可能需要修改以适应您项目的具体需求。在调整我们提议的解决方案之前,请确保充分理解所有代码细节及相关服务费用。重要的是,我们将分享的代码没有经过严格测试。任何涉及创建和调用多个 Lambda 函数以及Amazon CloudWatch 警报(如本文所述)的解决方案都需要进行适当的验证,以防止冗余/孤立工件的积累。
我们强烈建议您将本文的详细信息与最新的 AWS Lambda 文档和支持库的最新版本进行核对。
强制执行开发者合规性
虽然云治理对于成功和高效地使用云服务通常是至关重要的,但其实施有时可能具有挑战性。例如:Amazon SageMaker 包含一个用于向训练作业添加标签的 API。这些标签可以用来包含与 SageMaker 作业相关的元数据,比如训练项目的名称、开发阶段、当前试验的目标、运行作业的开发组或用户的名称等等。这些元数据可以用来收集统计信息,比如每个项目或组的开发成本。在下面的代码块中,我们展示了如何将多个标签应用于 SageMaker 训练作业:
from sagemaker.pytorch import PyTorch
tags = [{'Key': 'username', 'Value': 'johndoe'},
{'Key': 'model_name', 'Value': 'mnist'},
{'Key': 'training_phase', 'Value': 'finetune'},
{'Key': 'description', 'Value': 'fine tune final linear layer'}]
# define the training job with tags
estimator = PyTorch(
entry_point='train.py',
framework_version='2.1.0',
role='<arn role>',
py_version='py310',
job_name='demo',
instance_type='ml.g5.xlarge',
instance_count=1,
tags=tags
)
# deploy the job to the cloud
estimator.fit()
自然,这些标签只有在我们能够强制执行其应用时才有用。这就是AWS Lambda的作用。使用Amazon EventBridge,我们可以监控 SageMaker 训练作业状态的变化,并注册一个在每次变化时被触发的函数。在下面的代码块中,我们提出了一个 Python 例程,每次作业启动时都会验证特定的 SageMaker 标签。如果缺少某个标签,作业将被自动终止。事件的结构可以在这里找到。请注意使用(更详细的)SecondaryStatus字段来轮询训练作业的状态(而不是TrainingJobStatus)。
import boto3
def stop_training_job(training_job_name):
sm_client = boto3.client("sagemaker")
response = sm_client.stop_training_job(TrainingJobName=training_job_name)
assert response['ResponseMetadata']['HTTPStatusCode'] == 200
# TODO - optionally send an email notification
def enforce_required_tags(training_job_name, event):
event_tags = event['detail']['Tags']
if 'model_name' not in event_tags:
stop_training_job(training_job_name)
# define lambda handler
def sagemaker_event_handler(event, _):
job_name = event['detail']['TrainingJobName']
job_secondary_status = event['detail']['SecondaryStatus']
if job_secondary_status == 'Starting':
enforce_required_tags(job_name, event)
AWS 提供了多种创建 Lambda 函数的方法。有关详细信息,请参见AWS Lambda文档。创建后,请确保将该函数设置为EventBridge 规则的目标。
同样的功能可以用来强制执行额外的开发规则,这些规则旨在控制成本,比如:可以使用的实例类型、每个作业的最大实例数量、作业的最大运行时间等等。
停止停滞的实验
想象一下以下场景:你计划了一个大型的云端训练作业,将在八台每小时 $30 的 ML 计算实例上运行三天。为了完成这个任务,你已确保了 $17,280 的预算(8 个实例 x 每小时 $30 x 24 小时 x 3 天)。你在出发去度假三天的周末前启动了训练作业。当你从假期回来时,发现作业开始后一小时内,训练过程停滞,导致昂贵的设备实际上在三天内完全闲置。你不仅浪费了 $17,280(祝你好运向老板解释这一点),而且你的开发进度现在也被推迟了三天!!
保护自己免受这种情况的一种方法是监控底层训练作业资源的使用情况。例如,如果你的训练实例的 GPU 使用率在较长时间内低于某个阈值,这很可能是一个信号,表明出现了问题,需要立即停止训练作业。
我们将通过定义一个 Amazon CloudWatch 警报 来实现这一点,该警报监控每个 SageMaker 作业的一个训练实例的 GPU 使用情况,并在警报触发时调用一个 AWS Lambda 函数来终止该作业。设置这一功能需要三个组件:一个 Amazon CloudWatch 警报(每个训练作业一个),一个 AWS Lambda 函数,以及一个 Amazon Simple Notification Service (SNS) 主题,用于将 Lambda 函数与 CloudWatch 警报连接起来。
首先,我们创建一个 SNS 主题。这可以通过 Amazon SNS 控制台完成,也可以通过 Python 实现,如下所示:
import boto3
sns_client = boto3.client('sns')
# Create a SNS notification topic.
topic = sns_client.create_topic(Name="SageMakerTrainingJobIdleTopic")
topic_arn = topic.arn
print(f"Created SNS topic arn: {topic_arn}")
接下来,我们扩展了上面定义的 sagemaker_event_handler 函数,每次启动训练作业时创建一个独特的警报。我们编程设置警报以测量五分钟周期内的平均 GPU 使用率,并在连续三次测量低于 1% 时提醒我们的 SNS 主题。训练作业完成时,警报将被删除。
def create_training_alarm(job_name):
topic_arn = '<sns topic arn>'
SAMPLE_PERIOD_SECONDS = 60 * 5 # 5 minutes
SAMPLE_POINTS_LIMIT = 3
GPU_UTIL_THRESHOLD_PERCENTAGE = 1
cloudwatch_client = boto3.client('cloudwatch')
# A new sample is generated each SAMPLE_PERIOD_SECONDS seconds.
# The alarm will set off it there will be more than SAMPLE_POINTS_LIMIT
# below the limit.
response = cloudwatch_client.put_metric_alarm(
AlarmName=job_name + 'GPUUtil',
AlarmActions=topic_arn,
MetricName='GPUUtilization',
Namespace='/aws/sagemaker/TrainingJobs',
Statistic='Average',
Dimensions=[{
"Name": "Host",
"Value": job_name+"/algo-1"
}],
Period=SAMPLE_PERIOD_SECONDS,
EvaluationPeriods=SAMPLE_POINTS_LIMIT,
DatapointsToAlarm=SAMPLE_POINTS_LIMIT,
Threshold=GPU_UTIL_THRESHOLD_PERCENTAGE,
ComparisonOperator='LessThanOrEqualToThreshold',
TreatMissingData='notBreaching'
)
assert response['ResponseMetadata']['HTTPStatusCode'] == 200
def delete_training_alarm(job_name):
cloudwatch_client = boto3.client('cloudwatch')
response = cloudwatch_client.delete_alarms(
AlarmNames=[job_name+'GPUUtil'])
def sagemaker_event_handler(event, _):
job_name = event['detail']['TrainingJobName']
job_secondary_status = event['detail']['SecondaryStatus']
if job_secondary_status == 'Starting':
enforce_required_tags(job_name, event)
elif job_secondary_status == 'Training':
create_training_alarm(job_name)
elif job_secondary_status in ['Completed', 'Failed', 'Stopped']:
delete_training_alarm(job_name)
最后,我们定义一个第二个 Python AWS Lambda 函数,该函数 解析从 SNS 主题接收到的消息,并终止与警报关联的训练作业。
import boto3, json
def lambda_sns_handler(event, context):
data = json.loads(event['Records'][0]['Sns']['Message'])
alarm_name = data['AlarmName']
training_job_name = alarm_name.replace('GPUUtil', '')
stop_training_job(training_job_name)
AWS 提供了多种机制来订阅 Lambda 函数到 SNS 主题,包括 AWS 控制台、AWS CLI 和 AWS 无服务器应用程序模型 (AWS SAM)。
我们描述的解决方案总结在下图中:
AWS 架构图(作者提供)
请注意,相同的架构也可以用于强制执行你机器学习训练项目的最低 GPU 利用率。GPU 通常是训练基础设施中最昂贵的资源,你的目标应是最大化所有训练工作负载的利用率。通过规定最低利用率(例如 80%),你可以确保所有开发者适当优化他们的工作负载。
确保开发的连续性
在我们之前的示例中,我们演示了如何识别并停止一个停滞的实验。在我们描述的大型训练任务场景中,这有助于节省大量资金,但并没有解决开发的三天延迟。显然,如果停滞的源头在你的代码中,那么推迟恢复训练直到问题解决是有意义的。然而,我们经常遇到的训练中断并不是由于我们的代码造成的,而是由于服务环境中的偶发故障。在这种情况下,你的优先任务可能是确保训练的连续性,而不是等待有人手动恢复训练任务(使用最新的训练检查点)。在下面的代码块中,我们使用boto3的create_training_job API 来扩展我们的sagemaker_event_handler函数,以(简单地)恢复任何运行至少两小时后失败的训练任务。
import boto3, datetime
def clone_job(training_name, disable_spot=False):
# get description
client = boto3.client('sagemaker')
desc = client.describe_training_job(TrainingJobName=training_name)
# update the training name
new_training_name = training_name + 'clone'
use_spots = (not disable_spot) and desc["EnableManagedSpotTraining"]
if disable_spot:
desc["StoppingCondition"].pop("MaxWaitTimeInSeconds", None)
client.create_training_job(
TrainingJobName=new_training_name,
HyperParameters=desc["HyperParameters"],
AlgorithmSpecification=desc["AlgorithmSpecification"],
RoleArn=desc["RoleArn"],
OutputDataConfig=desc["OutputDataConfig"],
ResourceConfig=desc["ResourceConfig"],
StoppingCondition=desc["StoppingCondition"],
EnableNetworkIsolation=desc["EnableNetworkIsolation"],
EnableInterContainerTrafficEncryption=desc[
"EnableInterContainerTrafficEncryption"
],
EnableManagedSpotTraining=use_spots,
Tags=client.list_tags(ResourceArn=desc['TrainingJobArn'])
)
def sagemaker_event_handler(event, _):
TRAIN_TIME_THRESHOLD = 2 * 60 * 60: # 2 hours
job_name = event['detail']['TrainingJobName']
job_secondary_status = event['detail']['SecondaryStatus']
if job_secondary_status == 'Starting':
enforce_required_tags(job_name, event)
elif job_secondary_status == 'Training':
create_training_alarm(job_name)
elif job_secondary_status in ['Completed', 'Failed', 'Stopped']:
delete_training_alarm(job_name)
if job_secondary_status == 'Failed':
start_time = datetime.datetime.utcfromtimestamp(
event['detail']['CreationTime']/1000)
end_time = datetime.datetime.utcfromtimestamp(
event['detail']['TrainingEndTime']/1000)
training_time_seconds = (end_time - start_time).seconds
if training_time_seconds >= TRAIN_TIME_THRESHOLD:
clone_job(job_name)
上述函数会自动恢复任何在两小时后失败的任务。一个更实用的解决方案可能会尝试诊断错误类型,以确定恢复任务是否合适。一种方法是解析失败描述消息和/或与失败任务相关的 CloudWatch 日志。
高级 Spot 实例利用
Amazon SageMaker 的一个吸引人的特点是其对托管抢占式训练的支持。Amazon EC2 Spot 实例允许你利用未使用的 EC2 容量,并以折扣价格购买。这些实例的一个问题是它们在使用过程中可能会被收回(“中断”)。因此,Spot 实例应仅用于容错工作负载。SageMaker 通过代表你识别 Spot 中断,并在新的 Spot 实例可用时自动重新启动作业,使利用 Spot 实例变得简单。尽管托管抢占式 实例可以用来降低训练成本,但有时这一策略可能会适得其反。例如,当 Spot 容量不足时,你的训练作业可能会在开始之前超时。或者,作业可能会经历频繁的中断,这使得它无法取得任何有意义的进展。这两种情况都可能干扰开发并降低生产力。这些类型的情况可以使用 AWS Lambda 进行监控和处理。在下面的代码块中,我们扩展了我们的sagemaker_event_handler 函数,以识别一个被中断超过三次的训练作业,并用一个禁用了托管抢占式训练的克隆作业替换它。
def sagemaker_event_handler(event, _):
TRAIN_TIME_THRESHOLD = 2 * 60 * 60: # 2 hours
MIN_ITERRUPTS = 3
job_name = event['detail']['TrainingJobName']
job_secondary_status = event['detail']['SecondaryStatus']
if job_secondary_status == 'Starting':
enforce_required_tags(job_name, event)
elif job_secondary_status == 'Training':
create_training_alarm(job_name)
elif job_secondary_status in ['Completed', 'Failed', 'Stopped']:
delete_training_alarm(job_name)
if job_secondary_status == 'Failed':
start_time = datetime.datetime.utcfromtimestamp(
event['detail']['CreationTime']/1000)
end_time = datetime.datetime.utcfromtimestamp(
event['detail']['TrainingEndTime']/1000)
training_time_seconds = (end_time - start_time).seconds
if training_time_seconds >= TRAIN_TIME_THRESHOLD:
clone_job(job_name)
if job_secondary_status == 'Interrupted':
transitions = event['detail']["SecondaryStatusTransitions"]
interrupts = [e for e in transitions if e["Status"] == "Interrupted"]
num_interrupts = len(interrupts)
if num_interrupts > MIN_ITERRUPTS:
stop_training_job(job_name)
clone_job(job_name, disable_spot=True)
上述实现仅根据训练作业的中断次数确定了 Spot 使用策略。一个更复杂的解决方案可能还会考虑其他作业(使用相同实例类型)、中断发生的时间段、活跃训练时间的数量和/或由于 Spot 实例容量不足而超时的最近作业数量。
总结
有效的 AI 模型开发需要定义一个创意且详细的训练基础设施架构,以最小化成本并最大化生产力。在这篇文章中,我们展示了如何使用无服务器的 AWS Lambda 函数来增强 Amazon SageMaker 的托管训练服务,以解决训练过程中可能出现的一些常见问题。当然,具体应用这些技术的方式将很大程度上取决于你项目的具体情况。
如果有任何问题、意见或修正,请随时联系。务必查看我们关于 DL 训练优化的其他文章。
使用 SHAP 调试 PyTorch 图像回归模型
原文:
towardsdatascience.com/using-shap-to-debug-a-pytorch-image-regression-model-4b562ddef30d
使用 DeepShap 来理解和改进驱动自动驾驶汽车的模型
·发表于 Towards Data Science ·阅读时间 11 分钟·2023 年 1 月 10 日
--
(来源:作者)
自动驾驶汽车让我感到恐惧。巨大的金属块在空中飞驰,如果出现问题没有人可以阻止它们。为了降低风险,仅仅评估这些“怪物”背后的模型是不够的。我们还需要了解它们是如何进行预测的。这是为了避免任何可能导致意外的边缘情况。
好吧,我们的应用程序并没有那么重要。我们将调试用于驱动迷你自动驾驶汽车的模型(最坏的情况可能是一个瘀伤的脚踝)。尽管如此,IML 方法仍然有用。我们将看看它们如何甚至能提高模型的性能。
具体来说,我们将:
-
使用 PyTorch 和图像数据以及连续目标变量来微调 ResNet-18
-
使用 MSE 和散点图来评估模型
-
使用 DeepSHAP 解释模型
-
通过更好的数据收集来纠正模型
-
讨论图像增强如何进一步改善模型
在这个过程中,我们将讨论一些关键的 Python 代码。你还可以在GitHub上找到完整的项目。
如果你是 SHAP 的新手,请查看下面的视频。如果你想要更多内容,可以查看我的SHAP 课程。 如果你注册我的新闻通讯,你可以免费访问 😃
Python 包
# Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import glob
import random
from PIL import Image
import cv2
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import shap
from sklearn.metrics import mean_squared_error
数据集
我们从在一个房间里收集数据开始这个项目(这将对我们造成困扰)。如前所述,我们使用图像来驱动自动化汽车。你可以在Kaggle上找到这些图像的例子。这些图像的尺寸均为 224 x 224 像素。
我们用下面的代码显示了其中一个。注意图像名称(第 2 行)。前两个数字是 224 x 224 帧中的 x 和 y 坐标。在图 1中,你可以看到我们使用绿色圆圈显示了这些坐标(第 8 行)。
#Load example image
name = "32_50_c78164b4-40d2-11ed-a47b-a46bb6070c92.jpg"
x = int(name.split("_")[0])
y = int(name.split("_")[1])
img = Image.open("../data/room_1/" + name)
img = np.array(img)
cv2.circle(img, (x, y), 8, (0, 255, 0), 3)
plt.imshow(img)
图 1:赛道输入图像示例(来源:作者)
这些坐标是目标变量。模型使用图像作为输入来预测它们。然后,这一预测用于引导汽车。在这种情况下,你可以看到汽车正在接近左转。理想的方向是朝向绿色圆圈所给的坐标。
训练 PyTorch 模型
我想专注于 SHAP,因此我们不会深入讨论建模代码。如果你有任何问题,随时在评论中提问。
我们首先创建ImageDataset类。这个类用于加载我们的图像数据和目标变量。它通过paths来完成这一任务。值得指出的是目标变量的缩放方式——x和y的范围都在-1和1之间。
class ImageDataset(torch.utils.data.Dataset):
def __init__(self, paths, transform):
self.transform = transform
self.paths = paths
def __getitem__(self, idx):
"""Get image and target (x, y) coordinates"""
# Read image
path = self.paths[idx]
image = cv2.imread(path, cv2.IMREAD_COLOR)
image = Image.fromarray(image)
# Transform image
image = self.transform(image)
# Get target
target = self.get_target(path)
target = torch.Tensor(target)
return image, target
def get_target(self,path):
"""Get the target (x, y) coordinates from path"""
name = os.path.basename(path)
items = name.split('_')
x = items[0]
y = items[1]
# Scale between -1 and 1
x = 2.0 * (int(x)/ 224 - 0.5) # -1 left, +1 right
y = 2.0 * (int(y) / 244 -0.5)# -1 top, +1 bottom
return [x, y]
def __len__(self):
return len(self.paths)
实际上,当模型部署时,仅使用 x 预测来引导汽车。由于缩放,x 预测的符号将决定汽车的方向。当x < 0时,汽车应该左转。同样,当x > 0时,汽车应该右转。x 值越大,转弯越急。
我们使用 ImageDataset 类创建训练和验证数据加载器。这是通过对所有来自房间 1 的图像路径进行随机80/20拆分来完成的。最终,我们在训练集和验证集中分别有1,217和305张图像。
TRANSFORMS = transforms.Compose([
transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
paths = glob.glob('../data/room_1/*')
# Shuffle the paths
random.shuffle(paths)
# Create a datasets for training and validation
split = int(0.8 * len(paths))
train_data = ImageDataset(paths[:split], TRANSFORMS)
valid_data = ImageDataset(paths[split:], TRANSFORMS)
# Prepare data for Pytorch model
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=valid_data.__len__())
注意valid_loader的batch_size。我们使用验证数据集的长度(即 305)。这使我们可以在一次迭代中加载所有验证数据。如果你处理的是更大的数据集,可能需要使用较小的批量大小。
我们加载一个预训练的 ResNet18 模型(第 5 行)。通过设置model.fc,我们更新了最后一层(第 6 行)。这是一个从 512 个节点到我们 2 个目标变量节点的全连接层。我们将使用 Adam 优化器来微调这个模型(第 9 行)。
output_dim = 2 # x, y
device = torch.device('mps') # or 'cuda' if you have a GPU
# RESNET 18
model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(512, output_dim)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters())
我已经使用 GPU 训练了模型(第 2 行)。你仍然可以在 CPU 上运行代码。微调不像从头开始训练那样计算密集!
最后,我们有了模型训练代码。我们使用 MSE 作为损失函数训练 10 个周期。我们的最终模型是验证集上 MSE 最低的模型。
name = "direction_model_1" # Change this to save a new model
# Train the model
min_loss = np.inf
for epoch in range(10):
model = model.train()
for images, target in iter(train_loader):
images = images.to(device)
target = target.to(device)
# Zero gradients of parameters
optimizer.zero_grad()
# Execute model to get outputs
output = model(images)
# Calculate loss
loss = torch.nn.functional.mse_loss(output, target)
# Run backpropogation to accumulate gradients
loss.backward()
# Update model parameters
optimizer.step()
# Calculate validation loss
model = model.eval()
images, target = next(iter(valid_loader))
images = images.to(device)
target = target.to(device)
output = model(images)
valid_loss = torch.nn.functional.mse_loss(output, target)
print("Epoch: {}, Validation Loss: {}".format(epoch, valid_loss.item()))
if valid_loss < min_loss:
print("Saving model")
torch.save(model, '../models/{}.pth'.format(name))
min_loss = valid_loss
评估指标
此时,我们想了解我们的模型表现如何。我们查看 MSE 和实际与预测 x 值的散点图。暂时忽略 y,因为它不会影响汽车的方向。
训练和验证集
图 2提供了训练和验证集的这些指标。对角线红线表示完美预测。对于x < 0和x > 0,此线周围有类似的变异。换句话说,模型能够以类似的准确性预测左转和右转。在训练和验证集上的类似表现也表明模型没有过拟合。
图 2:模型在训练和验证集上的评估(来源:作者)
要创建上述图,我们使用model_evaluation函数。注意,数据加载器应创建为在第一次迭代中加载所有数据。
def model_evaluation(loaders,labels,save_path = None):
"""Evaluate direction models with mse and scatter plots
loaders: list of data loaders
labels: list of labels for plot title"""
n = len(loaders)
fig, axs = plt.subplots(1, n, figsize=(7*n, 6))
# Evalution metrics
for i, loader in enumerate(loaders):
# Load all data
images, target = next(iter(loader))
images = images.to(device)
target = target.to(device)
output=model(images)
# Get x predictions
x_pred=output.detach().cpu().numpy()[:,0]
x_target=target.cpu().numpy()[:,0]
# Calculate MSE
mse = mean_squared_error(x_target, x_pred)
# Plot predcitons
axs[i].scatter(x_target,x_pred)
axs[i].plot([-1, 1],
[-1, 1],
color='r',
linestyle='-',
linewidth=2)
axs[i].set_ylabel('Predicted x', size =15)
axs[i].set_xlabel('Actual x', size =15)
axs[i].set_title("{0} MSE: {1:.4f}".format(labels[i], mse),size = 18)
if save_path != None:
fig.savefig(save_path)
使用下面的函数可以看出我们的意思。我们创建了一个新的train_loader,将批量大小设置为训练数据集的长度。加载保存的模型也很重要(第 2 行)。否则,你将使用上一个纪元训练的模型。
# Load saved model
model = torch.load('../models/direction_model_1.pth')
model.eval()
model.to(device)
# Create new loader for all data
train_loader = DataLoader(train_data, batch_size=train_data.__len__())
# Evaluate model on training and validation set
loaders = [train_loader,valid_loader]
labels = ["Train","Validation"]
# Evaluate on training and validation set
model_evaluation(loaders,labels)
移动到新位置
结果看起来不错!我们预计汽车会表现良好,它确实如此。直到我们将它移动到新位置:
图 3:模型在新位置表现不佳(来源:作者)
我们从新位置(房间 2 和房间 3)收集了一些数据。对这些图像进行评估时,你会发现我们的模型表现得不如预期。这很奇怪!汽车在完全相同的轨道上,为什么房间会有影响呢?
图 3:模型在房间 2 和房间 3 上的评估(来源:作者)
使用 SHAP 调试模型
我们寻求 SHAP 的答案。它可以用来理解哪些像素对给定的预测重要。我们首先加载保存的模型(第 2 行)。SHAP 尚未实现 GPU 支持,所以我们将设备设置为 CPU(第 5-6 行)。
# Load saved model
model = torch.load('../models/direction_model_1.pth')
# Use CPU
device = torch.device('cpu')
model = model.to(device)
计算 SHAP 值时,我们需要获取一些背景图像。SHAP 将在计算值时对这些图像进行集成。我们使用batch_size为 100 的图像。这应该能给我们合理的近似值。增加图像数量将提高近似精度,但也会增加计算时间。
#Load 100 images for background
shap_loader = DataLoader(train_data, batch_size=100, shuffle=True)
background, _ = next(iter(shap_loader))
background = background.to(device)
我们通过将模型和背景图像传入DeepExplainer函数来创建一个解释器对象。这个函数有效地为神经网络近似 SHAP 值。作为替代,你可以用GradientExplainer函数替换它。
#Create SHAP explainer
explainer = shap.DeepExplainer(model, background)
我们加载 2 张示例图像——一张右转和一张左转(第 2 行),并进行变换(第 6 行)。这很重要,因为图像应该与训练模型时使用的格式相同。然后,我们计算这些图像预测的 SHAP 值(第 10 行)。
# Load test images of right and left turn
paths = glob.glob('../data/room_1/*')
test_images = [Image.open(paths[0]), Image.open(paths[3])]
test_images = np.array(test_images)
test_input = [TRANSFORMS(img) for img in test_images]
test_input = torch.stack(test_input).to(device)
# Get SHAP values
shap_values = explainer.shap_values(test_input)
最后,我们可以使用image_plot函数来显示 SHAP 值。但我们首先需要重新构造这些值。SHAP 值的返回维度是:
( #targets, #images, #channels, #width, #height)
我们使用转置函数,所以我们有维度:
(#targets, #images, #width, #height, #channels)
请注意,我们也将原始图像传入了image_plot函数。由于变换,test_input图像可能会显得很奇怪。
# Reshape shap values and images for plotting
shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))
test_numpy = np.array([np.array(img) for img in test_images])
shap.image_plot(shap_numpy, test_numpy,show=False)
你可以在图 4中看到结果。第一列给出原始图像。第二列和第三列分别是 x 和 y 预测的 SHAP 值。蓝色像素降低了预测值。相比之下,红色像素增加了预测值。换句话说,对于 x 预测,红色像素导致了更尖锐的右转。
图 4:左转和右转的示例 shap 值(来源:作者)
现在我们有了进展。重要的结果是模型正在使用背景像素。你可以在图 5 中看到这一点,我们对右转的 x 预测进行了放大。换句话说,背景对预测很重要。这解释了表现不佳的原因!当我们转到一个新房间时,背景发生了变化,我们的预测变得不可靠。
图 5:右转 x 预测的 shap 值(来源:作者)
模型对房间 1 的数据过拟合。每张图像中都有相同的对象和背景。因此,模型将这些与左转和右转关联起来。由于训练和验证图像中都有相同的背景,我们在评估中无法识别出这一点。
图 6:对训练数据的过拟合(来源:作者)
改进模型
我们希望我们的模型在所有条件下表现良好。为此,我们期望它只使用轨迹上的像素。那么,让我们讨论一些提高模型鲁棒性的方法。
收集新数据
最好的解决方案是简单地收集更多的数据。我们已经有了一些来自房间 2 和 3 的数据。按照相同的过程,我们使用来自所有 3 个房间的数据训练一个新模型。查看图 7,它现在在新房间的图像上表现更好。
图 7:在房间 2 和 3 上评估新模型(来源:作者)
希望通过在多个房间的数据上训练,我们可以打破转弯与背景之间的关联。现在,左转和右转上存在不同的对象,但轨迹保持不变。模型应该学会轨迹才是预测的重要因素。
我们可以通过查看新模型的 SHAP 值来确认这一点。这些值对应于我们在图 4中看到的相同转弯。现在,背景像素的权重较少。好吧,虽然不完美,但我们在进步。
图 8:在所有 3 个房间上训练的模型的 shap 值(来源:作者)
我们可以继续收集数据。我们收集的数据地点越多,我们的模型就会越强大。然而,数据收集可能是耗时的(而且无聊!)。相反,我们可以考虑数据增强。
数据增强
数据增强是指我们使用代码系统地或随机地改变图像。这使我们能够人为地引入噪声并增加数据集的大小。
例如,我们可以通过翻转图像在垂直轴上来将数据集的大小翻倍。我们之所以可以这样做,是因为我们的轨道是对称的。如图 9 所示,删除也可能是一个有用的方法。这涉及到包含那些对象或整个背景被去除的图像。
图 9:使用删除进行图像增强的示例(来源:作者)
在构建强大的模型时,您还应该考虑诸如光照条件和图像质量等因素。我们可以通过颜色抖动或添加噪声来模拟这些因素。如果您想了解所有这些方法,请查看下面的文章。
使用 Python 进行数据增强,包括翻转、调整亮度、颜色抖动和随机噪声
towardsdatascience.com
在上述文章中,我们还讨论了为什么很难判断这些增强是否使模型更强大。我们可以在许多环境中部署模型,但这很耗时。幸运的是,SHAP 可以作为一种替代方案。与数据收集一样,它可以帮助我们了解增强如何改变模型的预测方式。
希望您喜欢这篇文章!您可以通过成为我的推荐会员 😃 来支持我。
[## 通过我的推荐链接加入 Medium — Conor O’Sullivan
作为 Medium 会员,您的部分会员费会分配给您阅读的作者,同时您可以完全访问每一个故事……
conorosullyds.medium.com](https://conorosullyds.medium.com/membership?source=post_page-----4b562ddef30d--------------------------------)
| Twitter | YouTube | Newsletter — 免费注册以获得Python SHAP 课程
数据集
JatRacer 图像(CC0: 公共领域) www.kaggle.com/datasets/conorsully1/jatracer-images
参考资料
SHAP,PyTorch 深度解释器 MNIST 示例 shap.readthedocs.io/en/latest/example_notebooks/image_examples/image_classification/PyTorch%20Deep%20Explainer%20MNIST%20example.html
使用斜率图表简化你的数据可视化
原文:
towardsdatascience.com/using-slope-charts-to-simplify-your-data-visualization-be1f0eaf1f0f
数据可视化,数据讲述
通过使用斜率图表来简化你繁杂的图表:一个 Python Altair 教程
·发布于 Towards Data Science ·阅读时间 5 分钟·2023 年 12 月 8 日
--
作者提供的图片
我们可能会绘制图表以包含尽可能多的概念在我们的可视化中。因此,我们的图表可能会很难阅读且具有干扰性。因此,在绘制任何内容之前,请坐下来计划你想要传达的信息。然后,查看你的数据,决定什么是有效的必要内容。将其余部分排除在你的可视化之外。
在本教程中,我们将看看如何使用斜率图表来简化一个繁杂的趋势线。如果你是数据分析师,你可能会因为使用斜率图表看到显著的信息损失而惊慌。但我向你保证,在某些情况下,这确实是值得的。
让我们看看斜率图表可以应用的场景。
你可以在何时使用斜率图表
斜率图表是一种只显示第一个和最后一个点的折线图,如下面的图示。
作者提供的图片
当你只想了解数据趋势的斜率时,斜率图表特别有用。因此,斜率图表有助于简化趋势线。 例如,你可以使用斜率图表来查看产品销售在一段时间内是增加还是减少。假设你有许多趋势线需要在同一图表中表示,而你只对每条趋势线的第一个和最后一个值感兴趣。你可以通过使用斜率图表来简化图表。
让我们实现一个实际的例子,看看如何在 Python 数据可视化库 Altair 中实现斜率图表。
示例:游客到达情况
考虑旅游住宿设施的到达数据集,这是由 Eurostat 发布的开放数据。假设你想将葡萄牙的游客到达情况与其他五个国家进行比较:德国、法国、意大利、英国和西班牙。
从将数据集导入为 Pandas DataFrame 开始,并使用melt()
函数来解压数据集,即将数据行转换为列。
import altair as alt
import pandas as pd
from ydata_profiling import ProfileReport
df = pd.read_csv('tourist_arrivals_countries.csv', parse_dates=['Date'])
df = pd.melt(df, id_vars='Date', value_name='Tourist Arrivals', var_name='Country')
接下来,使用 Python Altair 绘制基本图表。要了解 Altair 的介绍,请阅读这篇介绍文章由Soner Yıldırım提供。
chart = alt.Chart(df).mark_line().encode(
x = 'Date:T', #A
y = 'Tourist Arrivals:Q', #B
color=alt.Color('Country:N') #C
)
chart.save('raw-chart.html')
下图显示了结果图表:
图片由作者提供
图表难以阅读。因此,我们通过以下方法改进它:
-
去除缺失的年份: 过滤掉 1994 年之前和 2018 年之后的所有年份
-
按年分组并计算总和: 不是显示所有月份的细节,而是仅显示年度值。
以下代码片段展示了如何进行操作:
# extract year from date
df.loc[:, 'Year'] = df['Date'].dt.year
# filter out years before 1994 and after 2018
df = df[(df['Year'] >= 1994) & (df['Year'] <= 2018)]
# group by year and country
df = df.groupby(['Year', 'Country'])['Tourist Arrivals'].sum().reset_index()
现在,再次绘制图表:
chart = alt.Chart(df).mark_line().encode(
x = 'Year:O',
y = 'Tourist Arrivals:Q',
color=alt.Color('Country:N')
)
chart.save(‘chart.html’)
下图显示了结果图表:
图片由作者提供
从图表中,我们注意到所有国家的趋势线都有所上升。因此,如果我们想评估趋势线在最后一年(2018 年)和第一年(1984 年)之间的变化,我们可以排除所有中间年份(1985 年至 2017 年),仅显示第一个和最后一个值。
此外,为了更清楚地展示差距,我们可以计算 2018 年与 1984 年之间的百分比增长,如下代码所示:
# select only 1994 and 2018
df = df[(df['Year'] == 1994) | (df['Year'] == 2018)]
# add a new column containing the difference for each country between the number of tourist arrivals in the year and 1994
for country in df['Country'].unique():
current = df[df['Country'] == country]['Tourist Arrivals']
base = df[(df['Country'] == country) & (df['Year'] == 1994)]['Tourist Arrivals'].values[0]
df.loc[df['Country'] == country, 'PI'] = (current - base)/ base*100
现在,让我们构建图表。我们还通过在strokeWidth
通道中使用条件来突出显示与葡萄牙相关的趋势线。
base = alt.Chart(df).encode(
x = alt.X('Year:O', title=''),
y = alt.Y('Tourist Arrivals:Q', title='Tourist Arrivals'),
color=alt.Color('Country:N',
scale=alt.Scale(scheme='set2'),
legend=None),
strokeWidth=alt.condition(alt.datum.Country == 'PT', alt.value(7), alt.value(0.5))
).properties(
title='Tourist Arrivals in Portugal (1994-2018)',
width=600,
)
chart = base.mark_line()
chart.save('chart.html')
下图显示了结果图表:
图片由作者提供
只需包括一个要点:与每个图表相关的标签。让我们将它们作为文本注释添加,如下代码片段所示:
annotation = base.mark_text(
dx=10,
align='left',
fontSize=12
).encode(
# format the text to show only 2 decimal places and add a percentage sign
text='label:N'
).transform_filter(
alt.datum.Year == 2018
).transform_calculate(
label='datum.Country + "(Increase:" + format(datum.PI, ".2f") + "%)"'
)
chart = (chart + annotation)
使用transform_calculate()
来正确格式化标签。下图显示了结果图表:
图片由作者提供
你可以继续优化图表,例如旋转 x 轴标签。不管怎样,斜率图已经完成,非常清晰 😃
总结
恭喜你!你刚刚学习了如何在 Python Altair 中实现斜率图!
使用斜率图仅显示趋势线的第一个和最后一个值。
你可以在我的书《数据讲故事与生成 AI:使用 Python 和 Altair》的GitHub 仓库中找到描述场景的完整代码。此外,这个例子是书中一个练习的解决方案。敬请关注,以实现更多示例 😃
如果你已经读到这里,今天的内容我觉得非常满意。谢谢,下次见 😃
附加资源
[书籍] 数据讲故事与生成 AI:使用 Python 和 Altair
你可能也会感兴趣…
这篇文章将展示如何使用 Python Altair 库构建地理地图。
在 Python Altair 中构建地理地图的 3 种方法
对于数据科学家来说,数据可视化是一项基本技能。它有助于快速理解数据中的模式和关联,否则可能会被忽视。地理地图是一个很好的方式来 继续阅读…
你可能不知道的 3 种时间序列可视化方法
在这篇文章中,我们将描述三种时间序列可视化的替代方法:
-
日历热力图
-
箱形图
-
循环图 继续阅读…
使用符号回归为 Elo 著名评分系统增加不确定性
并创建一个意外有用的评分算法
·发表于Towards Data Science ·阅读时间 12 分钟·2023 年 7 月 6 日
--
照片由JESHOOTS.COM提供,来自Unsplash
一个通用评分系统
Elo 评分系统在一些领域变得非常著名。最著名的例子可能是它自 1960 年代以来作为国际象棋评分的基础。此外,网站 538 成功地对其进行了修改,用于他们大部分知名的体育评分。较少公开的是,许多视频游戏开发者在他们的匹配系统中使用了 Elo 系统的变种。如果你正在阅读这篇文章,我会假设你对该系统有一定的了解。它为何在如此多的领域被广泛使用?我认为是因为它的计算扩展性、多功能性和简单性。然而,它也有一些缺点。本文将解决一个非常关键的问题,同时保持上述优点。
符号回归
尽管大型语言模型目前获得了所有的关注(双关含义),但也有其他令人兴奋的模型在被单独开发,具有非常不同的用途。符号回归通常适合发现封闭形式的解析规则,而不是处理像图像分类或音频翻译这样的深度学习任务。例如,如果你想重新发现牛顿的冷却定律,你可以构建一个资源密集型的密集神经网络。虽然这样做在数据足够的情况下效果良好,但它无法推广到未见过的情况。然而,符号回归将是正确的工具。它可以用有限的数据找到精确的公式,因此不仅可以推广,还能节省相当多的计算资源。Cranmer 等人撰写的我最喜欢的论文之一对此进行了更深入的探讨,甚至发展了一个以前未发现的暗物质过密度方程。
问题
经典的 Elo 评分将每个评分视为同样确定。这通常对大规模评分系统而言是一个不准确的假设。简单来说,新进入评分系统的玩家应该几乎总是以比那些已存在一段时间的玩家更大的方差进行建模。同样,那些评分系统长时间未曾见过的玩家可能已经有所改进或退步。假设你和你的四个密友经常在马里奥卡丁车中互相竞技。现在,我们可以使用简化的假设,你们只进行一对一的比赛。虽然还没有进行任何比赛,但假设你们有以下基本真实的 Elo 评分:
-
爱丽丝:1900
-
鲍勃:1700
-
切尔西:1500
-
德米特里:1300
-
伊夫琳:1100
在经典的 Elo 系统中,预期的胜率完全由评分差异决定:
其中 WP_A 是玩家 A 的胜率,R_B 和 R_A 是每个玩家的评分
因为这是一个模拟,我们不必担心当前的状态、良好的睡眠或血液酒精含量。相反,我们真实的 Elo 评分通过上述方程对应真实的胜率。这将意味着在模拟中每个玩家对抗另一名玩家时,以下是真实的胜率百分比:
你决定使用 Elo 系统来确定每个人的水平,因为目前你不知道真实评分。Elo 系统所需的唯一参数是 K 值。K 值是每场比赛中玩家之间“下注”的评分点数,最终根据赛前获胜概率奖励给胜者。较高的 K 值用于更快响应近期结果,而较低的 K 值用于更信任现有评分。在大多数使用情况下,你会看到 K 值在 10 到 40 之间。高于 40,评分变得非常嘈杂。低于 10,评分几乎不会变化。以下是 5,000 次模拟中,你平均需要多长时间才能确定每个人的水平,使用 K 值为 12、24 和 36,经过大约 100 场比赛:
如你所见,即使使用相当高的 K 值(36),我们的系统仍需 100 场比赛才能收敛。随着 K 值的增加,一致性变得更差。例如,当 K=12 时,我们对绿色玩家 Chelsea 的评分近乎真实评分 1500。然而,使用较高的 K 值时,我们对 Chelsea 评分的估计会在 1400 到 1600 之间漂移。
一个选择是从高 K 值开始,以提高收敛速度,然后随着评分的准确性提高而减少 K 值。这是一种在更多数据到来时减少不确定性的巧妙方法,但结果显示,它的表现不如更有原则的方法。
其他系统
过去曾有很多尝试解决这个问题,特别是 1995 年的Glicko 系统和 2005 年的TrueSkill 系统。Glicko 依赖于一个封闭形式的更新方程,而 TrueSkill 使用贝叶斯因子图来更新评分。TrueSkill 具有处理多人比赛场景或团队场景的巨大优势,并且通常收敛更快。TrueSkill 通常需要额外的计算,尤其是当每场比赛/比赛/锦标赛的玩家数量较多时。它也比 Glicko 对参数变化更敏感。使用哪种系统的决定当然取决于应用(但通常最终选择 TrueSkill)。让我们看看 Glicko、TrueSkill、Elo 以及带有衰减 K 的 Elo 如何尝试收敛到 Alice 的评分:
如你所见,Glicko 和 TrueSkill 在平均情况下收敛快得多,比 Elo 和带有 K 值衰减的 Elo 更快。从根本上说,更快的收敛是因为它们融入了不确定性度量。TrueSkill 在我的模拟中,实际超出真实评分,这显示了该系统对参数选择的敏感性。如果我调整参数,它将优于 Glicko,并且不会超出真实评分。
符号回归与我自己的系统
在我了解到其他人如何解决这个问题之前,我决定自己解决这个问题。我很高兴没有看到他们的解决方案,因为这可能会使我放弃创造自己的方法。在这个过程中,我发现了一个很好的使用案例,就是常常被低估的符号回归。
入门数学
与任何经典模型一样,我们必须从假设开始。就本文而言,可以安全地假设玩家和团队的技能在大多数情况下是正态分布的。根据我的经验,这对于几乎任何竞争都是接近真实的。国际象棋评级被发现略微更适合逻辑分布,但我现在想避免过于详细。
假设我们有一个正态分布的 10,000 名马里奥卡丁车玩家的样本。我们来放大其中的两个玩家。其中一个是“蓝色玩家”,另一个是“红色玩家”。
我会说,在这种情况下,我的表现比较符合红色玩家(从差到更差)。
在这种情况下,我从这个分布中随机抽取了两个玩家。在经典的 Elo 公式中,比赛完全可以通过这两个玩家之间的评分差异来描述。经典 Elo 没有试图考虑我们可能对蓝色玩家的数据比红色玩家少。为了澄清,我们还没有使用这些分布来预测单场比赛的结果。这些分布仅仅代表了整体的技能水平。即使它们完全不重叠,红色玩家击败蓝色玩家的概率仍然很小。
比如说,我们只见过蓝色玩家一次,他或她表现得相当不错。可能是因为在正确的时间拾取了正确的道具箱,这很难说。另一方面,我们对红色玩家的表现很确定,虽然具体有多差还有些不确定。在这种新的 Elo 形式中,我们承认蓝色玩家实际上可能比红色玩家差(因为分布有些有意义地重叠)。
让我们把红色分布叫做 R,蓝色分布叫做 B,差异分布叫做 Z:
我保证,数学方程不多,而且只有简单的方程
由于这些是独立分布,B 和 R 的协方差为零。我们的方程此时变得非常简单。换句话说,经典 Elo 只关心均值 Z,约为 1.85。在这种情况下,我们还跟踪 Z 的标准差,约为 0.84。这样,当我们记录一个观察值时,可以以贝叶斯方式更新分布。
注意上面的 x 轴具有相同的刻度。我只是通过 sigmoid 函数将技能差异转化为获胜概率。
上面,我可视化了当我们从技能差异分布中运行 Elo 10,000 次时的情况。经典 Elo 会给蓝色玩家一个 86.4%的胜率,用黄色星星表示。我们的新 Elo 模型将根据从分布中抽样的结果给出许多不同的概率。有趣的是,由于一个尾部场景使胜率急剧下降,而另一个尾部场景达到了上限,因此我们分布的平均胜率实际上比仅仅一个点估计要低得多——83.7%。我只是猜测,但这也许有助于解释为什么国际象棋联合会发现弱者往往比 Elo 预测的赢得更多。
还值得快速提及的是,一些游戏比其他游戏更具运气成分。例如,你可以在 Uno、卡坦岛、尤克或者 Five Crowns 中变得更出色,但这些游戏中涉及的运气因素足够多,以至于一个差的玩家仍然可能击败排名更高的玩家。另一方面,如果我与 Magnus Carlson 下棋,我 100 场 100 次都会输,因此涉及的运气成分较少。为了建模不可减少的运气因素,需要额外的复杂性(我现在暂时避免)。
符号回归
现在,我们如何利用胜负结果来更新这些分布的参数呢?我们需要将其作为一个二元结果变量,以保持 Elo 系统的通用性——我们希望能够将这个系统应用于几乎所有的游戏(一个好的练习是创建一个基于连续结果更新的系统)。然而,据我所知,没有简单的闭式解来更新带有二元结果的正态分布的参数,因此会出现一些不完美之处。一方面,我可以花费几年时间通过大学课程来磨练我的数学技能,以至于知道如何坐下来解决这个问题。另一方面,我可以作弊,把它放入符号回归模型中,并额外写一篇 Medium 文章。我选择了后者,尽管我非常尊重那些选择前者的人。
Beta 分布非常适合建模二元结果的概率,我们已经通过对正态分布均值的 sigmoid 变换得到了这样一个分布的均值(上面的紫色图表/黄色星星)。问题是是否存在一个合理的公式将正态分布的方差转化为 Beta 分布方差的等效形式。为了回答这个问题,我随机生成了几千个具有已知均值和方差的正态技能差异分布,然后将这些随机生成的分布中的每个点通过 sigmoid 函数转换为其 0 到 1 之间的等效概率。以下是一些示例,其中我调整了上面紫色分布的方差,保持相同的均值:
随着方差的增加,已知的信息减少,我们对蓝色玩家比红色玩家更优秀的信心也减少。黄色是方差最高的分布。
我在 Beta 旁边加了一个问号,因为没有保证通过这样做生成 Beta 分布。实际上,我确实没有。然而,仅仅通过目测这些分布,我猜测我可以找到相应的 Beta 参数来很好地描述这些分布。让我们使用最大似然估计经验性地找到上述分布的参数,以拟合最佳的 Beta 参数。
最小化负对数似然,我得出了以下结果:
Beta 参数拟合远非完美,但在方差较低或概率不极端的情况下似乎表现非常好。在生成了数千个这些后,我简单地使用PySR 包拟合了一个符号回归模型。我给它 Beta 均值和正态方差,看看它是否能预测 Beta 样本量(对应于方差)。然后我将得到 Beta 分布模型的所有参数。拟合模型后,PySR 给出了迄今为止最高的评分(基于简洁性和性能)的方程:
其中 mu_b 是 Beta 均值,sigma 平方是正态方差
我对这个方程的简单程度感到惊讶,对于这样一个简单的方程,结果居然如此好。虽然它比上面更远离完美,但你可以看到它对数据的拟合相当好:
这就是符号回归的美妙之处。我们不需要花费数分钟或数小时来优化成千上万的数据点,而是编写一个一行的函数,可以立即给出类似的结果。
将这个关键方程转化为实际的评级系统需要几个额外的步骤,这些步骤不是文章的重点。如果你感兴趣,这涉及使用 Beta 分布的贝叶斯更新,然后使用代数返回到正态分布。我还通过实验发现,由于比赛是成对比较(我们从两个玩家那里获取信息),实际中我们的有效样本量远远高于更新典型的 Beta 分布(如教科书中的掷硬币示例)。对于较高的样本量的调整只涉及添加一个平方项。以下是我用于实现它的所有代码:
当然,此时我仍然不确定我的实现有多好。它会比 TrueSkill 和 Glicko 差得多吗?我很可能只是浪费了大量时间。所以,我对 Alice 进行了之前相同的收敛测试:
(这次我花时间拟合 TrueSkill)
令我惊讶的是,它表现非常好。由于只需要少量基本操作,它可以每秒计算数千个评分,因此具有极高的可扩展性。该系统还具有无需大量参数调整的优点,同时保持了 Elo 系统的多样性和简洁性。当然,我想进一步推进,看看它在哪里崩溃:
当概率接近 0 或 1 时,预测结果与最大似然估计不匹配。因此,最终由于 Beta 分布根本上不是一个完美的拟合,这里存在一些劣势。我相信有一个更复杂的方程可以帮助解决这个问题,或者某种修正方法。也许一位统计学家会告诉我,有比 Beta 分布更适合建模这些结果的分布。然而,我对我创建的系统感到非常惊喜。没有任何系统能做到完美,总是有需要考虑的缺陷。当然,有几种方法可以将这些想法融入实际应用中——如果你感兴趣,可以关注我,我会继续写相关内容。我希望这样的成功能激励读者尝试符号回归以适应自己的使用案例。这里有一个链接到我创建所有这些可视化的笔记本。这是一个库的一部分,我希望在其中实现所有这些评分系统的向量化版本(尽管这可能还需要几个月时间)。
参考文献:
[1] Miles Cranmer, 用于科学的可解释机器学习:PySR 和 SymbolicRegression.jl (2023)
[2] Mark Glickman, 国际象棋等级评估综合指南 (1995)
[3] Ralf Herbrich, Tom Minka, Thore Graepel, TrueSkill(TM): 贝叶斯技能评分系统 (2007),神经信息处理系统第 20 卷
除非另有说明,否则所有图像均由作者提供。
使用 SQL 中的 HAVING 和 DISTINCT 子句
原文:
towardsdatascience.com/using-the-having-and-distinct-clauses-in-sql-d9e3be67b4be
你应该知道的两个重要 SQL 子句
·发表于 Towards Data Science ·4 分钟阅读·2023 年 1 月 25 日
--
SQL 是从数据库中提取数据的强大工具——无论是从一个表还是多个表中。
话虽如此,在有效分析数据时,有些子句特别重要。
本文讨论的两个子句是 HAVING 和 DISTINCT 子句。
为什么我们需要 HAVING 子句?
HAVING 子句的目的是在使用 GROUP BY 函数时充当 WHERE 子句的等效物。
如果你经常使用 SQL,你会知道 GROUP BY 子句对于在表中聚合值是非常重要的,例如,获得特定数据组的平均值,计算特定组的最大值或最小值——以及众多其他功能。
假设数据库中存在以下表格:
来源:作者使用 PostgreSQL 创建的表。表格显示在 pgAdmin4 中。
我们可以看到表格包含:
-
由字母表示的品牌
-
每个品牌的产品由数字 ID 表示
-
每种产品的价格
让我们假设我们希望确定表中每个品牌的平均产品价格,但仅在每个品牌在表中存在多个条目的情况下,即我们可以看到品牌 B 和 D 只有一个条目。因此,我们不希望这些条目包含在分析中。
我们如何使用 HAVING 子句计算在表中有多个条目的品牌的平均价格?
以下是该子句:
select brand, avg(price) from brand_table group by brand having count(brand)>1 order by brand;
执行此子句时,SQL 返回如下结果:
来源:表格由作者使用 PostgreSQL 创建。表格在 pgAdmin4 中显示。
正如我们所见,SQL 只返回了品牌 A、C 和 E 的平均价格。由于品牌 B 和 D 在表中只有一个条目,因此这些品牌没有被包括在内。
如果我们仅选择使用 GROUP BY 语句而不使用 HAVING 子句,SQL 将包括所有品牌:
来源:表格由作者使用 PostgreSQL 创建。表格在 pgAdmin4 中显示。
然而,B 和 D 的上述价格并不特别重要——因为计算只有一个产品的平均价格没有意义。因此,我们使用 HAVING 子句来仅显示多个条目的品牌。
DISTINCT 子句的目的
最简单来说,DISTINCT 子句的目的是选择表中唯一的条目而没有重复。
在上表中,我们可以看到每个品牌条目有多个产品。然而,如果我们只是想在表中显示每个品牌,我们可以如下使用 DISTINCT:
select distinct(brand) from brand_table order by brand;
来源:数据由作者使用 PostgreSQL 生成。表格在 pgAdmin 4 中显示。
在上表中,你会注意到品牌 E 的两个产品价格都是 10。
当选择品牌和价格而不使用 DISTINCT 子句时,我们可以看到这两个条目都被显示:
select brand, price from brand_table order by brand, price;
来源:数据由作者使用 PostgreSQL 生成。表格在 pgAdmin 4 中显示。
然而,当我们包含 DISTINCT 子句时——我们看到表格仅显示了品牌 E 的一个条目,其中价格为 10——这正是我们想要的:
来源:数据由作者使用 PostgreSQL 生成。表格在 pgAdmin 4 中显示。
从这个角度来看,使用 DISTINCT 可以在必要时获取特定组中的唯一值——在这个案例中是品牌。即使价格相同的产品 ID 可能是唯一的——但在按品牌分析时这并不相关。
结论
在本文中,你已经看到了:
-
如何在 GROUP BY 语句中将 HAVING 用作 WHERE 等效项
-
使用 DISTINCT 来返回特定类别中的唯一值
非常感谢阅读,如有任何问题或反馈,我们将不胜感激!
免责声明:本文是“按原样”编写的,不提供任何担保。它的目的是提供数据科学概念的概述,不应被解读为专业建议。本文中的发现和解释仅代表作者的观点,不代表或与本文中提到的任何第三方相关联。作者与本文中提到的任何第三方没有关系。
在 Python 中使用 Tqdm 与 Asyncio
原文:
towardsdatascience.com/using-tqdm-with-asyncio-in-python-5c0f6e747d55
PYTHON CONCURRENCY
一种有效监控并发任务进度的方法
·发表于 Towards Data Science ·6 分钟阅读·2023 年 5 月 2 日
--
图片由 Jungwoo Hong 提供,来自 Unsplash
介绍
让我烦恼的事
对数据科学家来说,使用 Python 进行并发编程以提高效率并不罕见。观察后台各种子进程或并发线程,以保持计算或 IO 绑定任务的有序总是令人满意的。
但让我困扰的一件事是,当我在后台并发处理数百或数千个文件或执行数百个进程时,我总是担心是否有几个任务会悄悄挂起,从而使整个代码永远无法完成。我也很难知道代码目前的执行状态。
最糟糕的部分是,当我看着空白屏幕时,很难判断我的代码还需要多长时间才能执行完成或预计时间。这对我组织工作进度的能力非常有害。
因此,我想要一种方法来让我知道代码执行到了哪里。
过去是怎么做的
更传统的方法是在线程之间共享一个内存区域,在这个内存区域中放置一个计数器,让这个计数器在任务完成时加 1,然后使用一个线程不断打印这个计数器的值。
这不是一个好的解决方案:一方面,我需要在现有业务逻辑中添加计数代码,这违反了“低耦合,高内聚”的原则。另一方面,由于线程安全问题,我不得不非常小心锁机制,这会导致不必要的性能问题。
tqdm 就是解决方案
tqdm 使用进度条来指示任务的进度。图片由作者提供
有一天,我发现了 tqdm 库,它使用进度条来可视化我的代码进度。我可以用进度条来可视化我的 asyncio 任务的完成情况和预计完成时间吗?
我继续研究,最终成功了。然后我与大家分享这种方法,以便每个程序员都有机会监控他们的并发任务进度。让我们开始吧。
Python 中 asyncio 的背景
在我们开始之前,我希望你对 Python asyncio 有一些背景了解。我的文章描述了一些 asyncio 常见 API 的用法,这将帮助我们更好地理解 tqdm 的设计:
asyncio.gather、asyncio.as_completed 和 asyncio.wait 的最佳实践
tqdm 概览
正如官方网站所描述,tqdm 是一个用于显示循环进度条的工具。它使用简单,高度可定制,并且资源占用极低。
一种典型的用法是将一个可迭代对象传递给 tqdm 构造函数,然后你将得到一个如下所示的进度条:
或者你可以手动遍历并更新进度条的进度,当文件被读取时:
使用 tqdm 指示读取大数据集的进度。图片由作者提供
将 tqdm 与 asyncio 集成
总体来说,tqdm 使用起来非常简单。然而,GitHub 上关于将 tqdm 与 asyncio 集成的信息还不够多。因此,我深入挖掘了源代码,以查看 tqdm 是否支持 asyncio。
幸运的是,tqdm 的最新版本提供了 tqdm.asyncio
包,该包提供了 tqdm_asyncio
类。
tqdm_asyncio
类有两个相关的方法。一个是 tqdm_asyncio.as_completed
。从源代码中可以看出,它是 asyncio.as_completed
的一个封装:
@classmethod
def as_completed(cls, fs, *, loop=None, timeout=None, total=None, **tqdm_kwargs):
"""
Wrapper for `asyncio.as_completed`.
"""
if total is None:
total = len(fs)
kwargs = {}
if version_info[:2] < (3, 10):
kwargs['loop'] = loop
yield from cls(asyncio.as_completed(fs, timeout=timeout, **kwargs),
total=total, **tqdm_kwargs)
另一个是 tqdm_asyncio.gather
,从源代码中可以看出,它基于对 tqdm_asyncio.as_completed
的实现,该实现模拟了 asyncio.gather
的功能:
@classmethod
async def gather(cls, *fs, loop=None, timeout=None, total=None, **tqdm_kwargs):
"""
Wrapper for `asyncio.gather`.
"""
async def wrap_awaitable(i, f):
return i, await f
ifs = [wrap_awaitable(i, f) for i, f in enumerate(fs)]
res = [await f for f in cls.as_completed(ifs, loop=loop, timeout=timeout,
total=total, **tqdm_kwargs)]
return [i for _, i in sorted(res)]
接下来,我将描述这两个 API 的用法。在我们开始之前,我们还需要做一些准备工作。在这里,我写了一个简单的方法,用于模拟一个具有随机睡眠时间的并发任务:
紧接着,我们将创建 2000 个并发任务,然后使用 tqdm_asyncio.gather
替代熟悉的 asyncio.gather
方法,看看进度条是否正常工作:
tqdm_asyncio.gather
的效果。图片由作者提供
噠噠!我终于知道我的任务完成了。相当酷。
或者让我们用 tqdm_asyncio.as_completed
替换 tqdm_asyncio.gather
并再试一次:
tqdm_asyncio.as_completed 也运行良好。图像来源:作者
很好,它仍然运行良好。
高级技巧和窍门
一些常见的配置项
tqdm 有一套丰富的 配置项,这里列出一些常见的配置项。
desc
。你可以配置一个 desc 参数,在进度条前显示标题,这在区分多个任务组时很有用。
desc 配置项的作用。图像来源:作者
ncols
。如果默认的进度条过短,可以通过这个参数将其延长。
使用 ncols 更改进度条的宽度。图像来源:作者
colour
。PyCharm 的 CLI 默认将进度条显示为红色,这仍然有些刺眼,因此你可以使用这个参数将进度条更改为其他颜色。但截至写作本文时,我仍未找到将文本更改为白色的方法。
使用 color 来更改进度条的颜色。图像来源:作者
bar_format
。这个选项允许你灵活控制进度条显示的内容和格式。例如,如果你想在顶部显示 ETA。
使用 bar_format 自定义进度条的内容。图像来源:作者
异常处理
从源代码中可以看出,tqdm 通过 tqdm_asyncio.as_completed
方法实现了 gather
方法。因此,我们不能通过使用 return_exceptions
参数来跳过异常捕获。
这很遗憾。但我们仍然可以在 tqdm_asyncio.as_completed
中通过 try…exception
处理异常:
异常处理。图像来源:作者
现实世界中的应用案例
许多 asyncio 的代码示例使用 asyncio.sleep
来模拟 IO 绑定的情况,但这不幸地过于简化了现实情况。我们应该使用实际案例来解释如何在 asyncio 中使用 tqdm。
然而,由于篇幅原因,我们无法在本章中使用实际案例。在下一章中,我们将演示 tqdm 进度条如何在实际使用 asyncio 实现 map-reduce 程序处理大型文件的示例中工作。
## 在 Python 中结合多进程和 asyncio 提升性能
使用实际案例来解释代码实现
towardsdatascience.com
结论
在 asyncio 代码中使用 tqdm 指示进度有很多好处:
-
我们可以在调用者的进度条中显示进度,而无需干涉业务代码。
-
所有工作都可以在主进程中完成,无需担心线程安全和性能问题。
-
图形化展示总是比枯燥的文字描述要生动得多。
-
所有这些只需一行代码。
我还尝试过其他进度条库,比如 alive-progress,它在展示效果上要酷得多,但 alive-progress 不支持 asyncio。
如果设置得当,tqdm 也可以产生一些很酷的效果,但由于时间原因我没有深入研究,所以欢迎进一步讨论并留下评论。你可能会帮助到更多感兴趣的读者。
通过 加入 Medium,你将可以无限制地访问我所有的帖子以及成千上万其他作者的文章。这只需花费你一杯咖啡的钱,但对我来说是极大的鼓励。
本文最初发布在: www.dataleadsfuture.com/using-tqdm-with-asyncio-in-python/
使用 pykrige 和 matplotlib 进行地质变化的空间可视化
探索井测量数据中的空间地质变化
· 发布于 Towards Data Science ·阅读时长 7 分钟·2023 年 6 月 13 日
--
挪威大陆架上声波压缩缓慢度测量的空间变化。图像由作者提供。
在处理地质和岩石物理数据时,我们通常希望了解这些数据在我们研究的区域或场地上的变化。我们可以通过网格化实际测量值,并推测尚未通过钻孔探索的其他区域的数据来实现这一点。
一种进行这种外推的方法是克里金法,这是一种以南非矿业工程师 Danie G. Krige 命名的地统计学程序。克里金法的核心思想在于其估计技术:利用观察数据之间的空间相关性来预测未测量位置的值。
通过衡量变量在距离上的变化,这种方法建立了一种统计关系,可以用于预测区域内的值,将分散的数据点转化为连贯的空间地图。
在本教程中,我们将探讨一个名为pykrige的 Python 库。该库设计用于二维和三维克里金计算,并且使用井测数据非常简便。
导入库和数据
首先,我们需要导入我们将要用到的库。对于这篇文章,我们将需要以下库:
-
pandas — 用于读取我们以
csv
格式存储的数据 -
matplotlib 用于创建我们的可视化
-
pykrige 用于执行克里金计算
-
numpy 用于一些数值计算
import pandas as pd
import matplotlib.pyplot as plt
from pykrige import OrdinaryKriging
import numpy as np
一旦我们导入了库,现在我们可以导入我们的数据。
在本教程中,我们将使用从 Xeek 和 Force 2020 机器学习竞赛中获得的用于预测井日志测量的岩性数据集。该数据集的详细信息可以在本文底部找到。
这个竞赛数据集的子集包含 65 个井位置及 Balder 组的平均声学压缩慢度测量值。
要读取数据,我们可以使用 pandas 的 read_csv()
函数,并传入数据文件的位置。在此示例中,我们使用相对于 Jupyter Notebook 的路径,但如果文件位于其他位置,我们也可以使用绝对路径。
df = pd.read_csv('Data/Xeek Force 2020/Xeek_2020_Balder_DTC_AVG.csv')
df
当我们查看数据框时,我们将看到我们有 65 口井,包含 Balder 组顶部的位置(X_LOC 和 Y_LOC 为网格坐标,LAT 和 LON 为纬度和经度)。我们还记录了遇到该地层的真实垂直深度(TVDSS)以及声学压缩慢度的均值(DTC)。
包含我们选择的井位置数据和挪威北海 Balder 组的 DTC — 声学压缩慢度值的数据框。图片来源于作者。
可视化井的空间位置
现在我们的数据已成功加载到数据框中,我们可以可视化数据,以了解井的位置。为此,我们将使用 matplotlib 的散点图,并传入经度和纬度列。
plt.scatter(df['Longitude'], df['Latitude'], c=df['DTC'])
当我们运行上述代码时,我们会得到以下图表。
基本的 matplotlib 图形展示了我们在挪威北海区域的井位置和 DTC 值。图片来源于作者。
我们可以看到上面的图形非常基本,没有颜色条或坐标轴标签。
让我们稍微修改图表,添加这些特性。
cm = plt.cm.get_cmap('viridis')
plt.figure(figsize=(10,10))
scatter = plt.scatter(df['LON'], df['LAT'], c=df['DTC_MEAN'], cmap=cm, s=50)
plt.colorbar(scatter)
plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.show()
当我们运行上述代码时,我们会得到以下图形,这为我们提供了更多有关数据的信息。我们可以使用颜色条来估算点值。
在添加颜色条和标签后的挪威北海选择井的 matplotlib 散点图。图片来源于作者。
应用克里金法
为了更好地理解数据点及 DTC 测量在 Balder 组区域的变化,我们可以使用克里金法和数据点来填补测量值之间的空白。
为此,我们需要从 pykrige 库中创建一个 OrdinaryKriging
对象。
在此对象中,我们传入了 x 和 y 的位置数据,以及我们要映射到 z 参数的数据。
我们还需要选择要使用的变差函数模型。在此案例中,我们将使用指数模型。有关模型类型的更多细节,请参见文档。
由于我们使用纬度和经度作为 x 和 y 坐标,我们可以将 coordinates_type 参数更改为 geographic
OK = OrdinaryKriging(x=df['LON'],
y=df['LAT'],
z=df['DTC_MEAN'],
variogram_model='exponential',
verbose=True, enable_plotting=True,
coordinates_type='geographic')
当我们运行上述代码时,我们会返回以下模型总结和半变差函数。
来自 pykrige 的模型总结。图片由作者提供。
下面是返回的参数的简要总结:
-
颗粒:颗粒是变差函数的 y 截距,表示零距离处的方差,通常是由于测量误差或非常小尺度的变化造成的。
-
完全天花板:天花板是变差函数达到的最大方差并开始趋于平稳的点,这发生在点之间距离非常远时。
-
范围:范围是变差函数达到天花板的距离,意味着在此距离之外,进一步分离点不会增加方差。
-
部分天花板:部分天花板是天花板和颗粒之间的差异,表示数据中空间结构的方差量。
这可以让我们理解模型根据生成的线条和点的形状对数据的适用性。
显示克里金结果
要开始显示我们的数据,我们需要创建一个数据网格。
为此,我们首先创建纬度和经度的数组,介于我们定义的坐标之间。在这种情况下,我们希望图表从北纬 57.5 度扩展到北纬 62 度,从东经 1.5 度扩展到东经 4.5 度。
使用 np.arange
将允许我们以规则间隔创建这些数组。
grid_lat = np.arange(57.5, 62, 0.01, dtype='float64')
grid_long = np.arange(1.5, 4.5, 0.01,dtype='float64')
现在我们有了 X 和 Y 坐标,我们可以创建我们的值网格。为此,我们调用 OK.execute
,并传入我们的纬度和经度数组。
zstar, ss = OK.execute('grid', grid_long, grid_lat)
这将返回两个数组。我们的数据网格(zstar)和与之相关的不确定性(ss)
接下来,我们可以使用我们的数据数组,并使用 matplotlib 的 imshow
绘制它。
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(10,10))
image = ax.imshow(zstar, extent=(1.5, 4.5, 57.5, 62), origin='lower')
ax.set_xlabel('Longitude', fontsize=14, fontweight='bold')
ax.set_ylabel('Latitude', fontsize=14, fontweight='bold')
scatter = ax.scatter(x=df['LON'], y=df['LAT'], color='black')
colorbar = fig.colorbar(image)
colorbar.set_label('DTC (us/ft)', fontsize=14, fontweight='bold')
plt.show()
当我们运行此代码时,我们得到以下地图,显示了在我们 65 口井中 Balder 地层的声学压缩慢度的变化。
使用 pykrige 生成的声学压缩慢度 (DTC) 数据网格。图片由作者提供。
我们可以看到,在北纬 59 至 60 度之间,我们有更快的岩石,而在东北和西南地区,我们有更慢的岩石。
为了解释这一点,我们需要了解每个井的位置的形成深度。这将帮助我们确定差异是否与埋藏和压实或其他地质过程有关。
我们将在未来的文章中看到如何做到这一点。
可视化克里金不确定性
在查看此类数据时,一个关键点是理解与克里金相关的不确定性。
我们可以通过重复使用相同的绘图代码来实现这一点,而不是传入 zstar
,我们可以将其替换为之前创建的 ss
变量。
fig, ax = plt.subplots(figsize=(10,10))
image = ax.imshow(ss, extent=(1.5, 4.5, 57.5, 62), origin='lower')
ax.set_xlabel('Longitude', fontsize=14, fontweight='bold')
ax.set_ylabel('Latitude', fontsize=14, fontweight='bold')
scatter = ax.scatter(x=df['LON'], y=df['LAT'], color='black')
colorbar = fig.colorbar(image)
colorbar.set_label('DTC (us/ft)', fontsize=14, fontweight='bold')
plt.show()
通过以下图表,我们能够看到不确定性高或低的区域。
由 pykrige 生成的声波压缩慢度(DTC)不确定性数据网格。图像由作者提供。
在我们井覆盖较少的区域,我们的不确定性会更高,而在我们有多个井的区域,我们的不确定性会更低。
总结
在本教程中,我们已经看到如何对井日志测量值(DTC)进行平均处理,并将其映射到整个区域。这使我们能够了解地理区域内数据的趋势。
然而,在查看这些数据时,我们必须记住,我们看到的是一个 2D 表面,而不是我们在地下遇到的更复杂的 3D 结构。因此,测量的变化可能是由深度变化造成的。
使用的数据集
本文使用的数据集是由 Xeek 和 FORCE 2020 (Bormann 等,2020) 组织的机器学习比赛使用的训练数据集的一个子集。它在挪威政府发布的 NOLD 2.0 许可证下发布,详细信息可以在这里找到:挪威开放政府数据许可证(NLOD)2.0。完整数据集可以在这里访问。
数据集的完整参考资料是:
Bormann, Peter, Aursand, Peder, Dilib, Fahad, Manral, Surrender, & Dischington, Peter. (2020). FORCE 2020 Well well log and lithofacies dataset for machine learning competition [数据集]. Zenodo. doi.org/10.5281/zenodo.4351156
感谢阅读。在你离开之前,你应该一定要订阅我的内容,并将我的文章发送到你的邮箱中。 你可以在这里完成订阅!
其次,你可以通过订阅会员来获得完整的 Medium 体验,并支持成千上万的其他作家和我。它每月只需$5,并且你可以完全访问所有精彩的 Medium 文章,还可以有机会通过你的写作赚钱。
如果你通过 我的链接注册,你将通过你的一部分费用直接支持我,这不会增加你的费用。如果你这样做,非常感谢你的支持。
利用 PyArrow 改进 pandas 和 Dask 工作流
现在最大限度地利用 PyArrow 在 pandas 和 Dask 中的支持
·
关注 发表在 Towards Data Science · 13 min 阅读 · 2023 年 6 月 6 日
--
所有图片均由作者创建
介绍
本文探讨了我们现在可以在何处使用 PyArrow 来改进我们的 pandas 和 Dask 工作流。对 PyArrow 数据类型的总体支持在 pandas 2.0 中添加到 pandas 和 Dask 中。这解决了两个库用户长期以来的一些痛点。pandas 用户常常向我抱怨 pandas 不支持任意数据类型中的缺失值,或非标准数据类型支持不够好。Dask 用户特别烦恼的问题是处理大数据集时内存不足。与 NumPy 对象列相比,PyArrow 支持的字符串列可以减少多达 70% 的内存消耗,因此有潜力缓解这个问题,并提供巨大的性能提升。
pandas 和 Dask 对 PyArrow 数据类型的支持仍然相对较新。建议在至少 pandas 2.1 发布之前,谨慎选择 PyArrow 的 dtype_backend
。并非所有 API 部分都已优化。然而,在某些工作流中,你应该能获得显著的改进。本文将介绍几个我建议立即切换到 PyArrow 的示例,因为它已经提供了巨大的好处。
Dask 本身可以通过 PyArrow 数据类型在多方面受益。我们将探讨 PyArrow 支持的字符串如何轻松缓解 Dask 集群内存不足的问题,以及我们如何通过利用 PyArrow 改善性能。
我是 pandas 核心团队的一员,并且在实现和改进 pandas 中的 PyArrow 支持方面参与了大量工作。我最近加入了 Coiled,在那里我负责 Dask。我的任务之一是改进 Dask 中的 PyArrow 集成。
PyArrow 支持的总体概述
PyArrow 数据类型最初是在 pandas 1.5 中引入的。该实现是实验性的,我不建议在 pandas 1.5.x 上使用它。对这些数据类型的支持仍然相对较新。pandas 2.0 提供了巨大的改进,包括使选择 PyArrow 支持的 DataFrames 变得容易。我们仍在努力在各处提供适当的支持,因此在至少 pandas 2.1 发布之前应谨慎使用。这两个项目都在不断努力改进对 Dask 和 pandas 的支持。
我们鼓励用户尝试使用它们!这将帮助我们更好地了解仍然缺乏支持或不够快的地方。反馈有助于我们改进支持,并将显著减少创建流畅用户体验所需的时间。
数据集
我们将使用包含所有 Uber 和 Lyft 乘车记录的纽约市出租车数据集。它具有一些有趣的属性,如价格、小费、司机薪酬等。数据集可以在这里找到(参见 服务条款),并存储在 parquet 文件中。在分析 Dask 查询时,我们将使用一个公开的 S3 存储桶以简化我们的查询:s3://coiled-datasets/uber-lyft-tlc/
。我们将使用 2022 年 12 月的数据集进行 pandas 查询,因为这是在我的机器(24GB 内存)中舒适地适配的最大数据集。我们必须避免过度使用 RAM,因为这可能会在分析性能时引入副作用。
我们还将研究 read_csv
的性能。我们将使用可以在这里找到的 芝加哥犯罪 数据集。
Dask 集群
设置 Dask 集群有各种不同的选项,参见 Dask 文档 获取非详尽的部署选项列表。我将使用 Coiled 在 AWS 上通过 30 台机器创建一个集群。
import coiled
cluster = coiled.Cluster(
n_workers=30,
name="dask-performance-comparisons",
region="us-east-2", # this is the region of our dataset
worker_vm_type="m6i.large",
)
Coiled 已连接到我的 AWS 账户。它在我的账户中创建集群并为我管理所有资源。30 台机器足以舒适地操作我们的数据集。我们将研究如何通过一些小的修改将所需的工作节点数减少到 15 台。
pandas StringDtype
由 PyArrow 支持
我们从 pandas 1.0 中最初引入的一个功能开始。将 pandas 或 Dask 中的 dtype 设置为 string
会返回一个具有 StringDtype
的对象。这个功能相对成熟,应该能提供平滑的用户体验。
历史上,pandas 通过 dtype 为 object
的 NumPy 数组来表示字符串数据。NumPy 对象数据存储为指向内存中实际数据的指针数组。这使得遍历包含字符串的数组非常缓慢。pandas 1.0 最初引入了所谓的 StringDtype
,允许对字符串进行更简单和一致的操作。这个 dtype 仍然由 Python 字符串支持,因此性能也不佳。它提供了字符串数据的清晰抽象。
pandas 1.3 最终引入了一种创建高效字符串 dtype 的增强功能。该数据类型由 PyArrow 数组支持。PyArrow 提供了一种数据结构,能够进行高效且节省内存的字符串操作。从那时起,用户可以使用在内存中连续的字符串 dtype,因此非常快速。该 dtype 可以通过 string[pyarrow]
请求。或者,我们可以通过指定 string
作为 dtype 并设置来请求它:
pd.options.mode.string_storage = "pyarrow"
由于 Dask 构建在 pandas 之上,这里也可以使用这种字符串数据类型。除此之外,Dask 还提供了一个方便的选项,可以自动将所有字符串数据转换为 string[pyarrow]
。
dask.config.set({"dataframe.convert-string": True})
这是避免字符串列使用 NumPy 对象数据类型的方便方法。此外,它还具有创建原生 PyArrow 数组的优势,适用于处理 Arrow 对象的 I/O 方法。在提供显著性能提升的同时,PyArrow 字符串消耗的内存显著减少。与 NumPy 对象相比,使用 PyArrow 字符串的平均 Dask DataFrame 的内存消耗约为原始内存的 33-50%。这解决了 Dask 用户在处理大型数据集时内存不足的最大痛点。该选项启用了 Dask 测试套件中的全局测试。这确保了 PyArrow 支持的字符串足够成熟,以提供流畅的用户体验。
让我们看几个典型的字符串操作。我们将从一些 pandas 示例开始,然后切换到在 Dask 集群上执行的操作。
我们将使用 df.convert_dtypes
将我们的对象列转换为 PyArrow 字符串数组。还有更高效的方法来获取 pandas 中的 PyArrow 数据类型,我们会在后面探讨。我们将使用 2022 年 12 月的 Uber-Lyft 数据集,这个文件在我的机器上可以舒适地放入内存中。
import pandas as pd
pd.options.mode.string_storage = "pyarrow"
df = pd.read_parquet(
"fhvhv_tripdata_2022-10.parquet",
columns=[
"tips",
"hvfhs_license_num",
"driver_pay",
"base_passenger_fare",
"dispatching_base_num",
],
)
df = df.convert_dtypes(
convert_boolean=False,
convert_floating=False,
convert_integer=False,
)
在这个示例中,我们的 DataFrame 对所有非字符串列都有 NumPy 数据类型。让我们开始筛选所有由 Uber 操作的乘车记录。
df[df["hvfhs_license_num"] == "HV0003"]
这个操作创建了一个 True/False 值的掩码,指定 Uber 是否进行了乘车。这不使用任何特殊的字符串方法,但等式比较会交给 PyArrow。接下来,我们将使用 pandas 实现的 String 访问器,它允许你对每个元素执行各种字符串操作。我们想要找到所有从以 "B028"
开头的基地发出的乘车记录。
df[df["dispatching_base_num"].str.startswith("B028")]
startswith
会遍历我们的数组,检查每个字符串是否以指定的子字符串开头。PyArrow 的优势显而易见。数据在内存中是连续的,这意味着我们可以高效地进行遍历。此外,这些数组还具有一个第二个数组,其中包含指向每个字符串首个内存地址的指针,这使得计算起始序列更快。
最后,我们查看一个 GroupBy
操作,它对 PyArrow 字符串列进行分组。分组的计算也可以交给 PyArrow,这比在 NumPy 对象数组上因子化要高效得多。
df.groupby(
["hvfhs_license_num", "dispatching_base_num"]
).mean(numeric_only=True)
让我们看看这些操作与字符串列由 NumPy 对象数据类型表示的 DataFrames 的对比。
结果或多或少符合我们的预期。基于字符串的比较在 PyArrow 字符串上执行时显著更快。大多数字符串访问器应该能提供巨大的性能提升。另一个有趣的观察是内存使用,与 NumPy 对象 dtype 相比,减少了大约 50%。我们将进一步用 Dask 详细研究这一点。
Dask 镜像了 pandas API,并将大多数操作委派给 pandas。因此,我们可以使用相同的 API 来访问 PyArrow 字符串。上述选项是全局请求这些选项的便利方式,我们将在这里使用:
dask.config.set({"dataframe.convert-string": True})
在开发过程中,这种选项的最大好处之一是能够轻松地在 Dask 中全局测试 PyArrow 字符串,以确保一切顺利。我们将利用 Uber-Lyft 数据集进行探索。该数据集在我们的集群上占用大约 240GB 的内存。我们的初始集群有 30 台机器,足以舒适地进行计算。
import dask
import dask.dataframe as dd
from distributed import wait
dask.config.set({"dataframe.convert-string": True})
df = dd.read_parquet(
"s3://coiled-datasets/uber-lyft-tlc/",
storage_options={"anon": True},
)
df = df.persist()
wait(df) # Wait till the computation is finished
我们将数据保存在内存中,以避免 I/O 性能对我们性能测量的影响。我们的数据现在已存储在内存中,这使得访问速度很快。我们将执行类似于 pandas 计算的操作。主要目标之一是展示 pandas 的好处如何转化为在 Dask 分布式环境中的计算。
第一个观察结果是,使用 PyArrow 支持的字符串列的数据框仅消耗 130GB 内存,仅为使用 NumPy 对象列时的一半。我们的数据框中只有几个字符串列,这意味着在切换到 PyArrow 字符串时,字符串列的内存节省实际上比约 50% 更高。因此,我们将在对 PyArrow 字符串列执行操作时将集群大小减少到 15 个工作节点。
cluster.scale(15)
我们通过对数据框进行后续筛选来衡量掩码操作和一个字符串访问器的性能。
df = df[df["hvfhs_license_num"] == "HV0003"]
df = df[df["dispatching_base_num"].str.startswith("B028")]
df = df.persist()
wait(df)
我们可以看到,可以使用与之前示例中相同的方法。这使得从 pandas 迁移到 Dask 相对容易。
此外,我们将再次对数据执行 GroupBy
操作。这在分布式环境中要困难得多,因此结果更具趣味性。之前的操作在大型集群上相对容易并行,而 GroupBy
则更具挑战性。
df = df.groupby(
["hvfhs_license_num", "dispatching_base_num"]
).mean(numeric_only=True)
df = df.persist()
wait(df)
我们得到了 2 和 3 倍的显著改进。这一点尤其引人注目,因为我们将集群的大小从 30 台机器减少到 15 台,成本降低了 50%。随后,我们还将计算资源减少了 2 倍,这使得我们的性能提升更加令人印象深刻。因此,性能分别提高了 4 倍和 6 倍。我们可以在较小的集群上执行相同的计算,这样既节省了成本,也在总体上更高效,并且仍然能够获得性能提升。
总结一下,我们看到 PyArrow 字符串列在与 DataFrame 中的 NumPy 对象列比较时有了巨大的改进。切换到 PyArrow 字符串是一个相对较小的改变,但可能极大地提升依赖于字符串数据的平均工作流程的性能和效率。这些改进在 pandas 和 Dask 中都同样明显!
I/O 方法中的 Engine 关键字
我们现在将看看 pandas 和 Dask 中的 I/O 函数。一些函数有自定义实现,如 read_csv
,而其他函数则委派给另一个库,如 read_excel
委派给 openpyxl
。其中一些函数新增了 engine
关键字,使我们能够委派给 PyArrow
。PyArrow 解析器默认是多线程的,因此可以提供显著的性能提升。
pd.read_csv("Crimes_-_2001_to_Present.csv", engine="pyarrow")
这个配置将返回与其他引擎相同的结果。唯一的区别是使用了 PyArrow 来读取数据。read_json
也提供了相同的选项。添加 PyArrow-engines 是为了提供更快的数据读取方式。改进的速度只是一个优势。PyArrow 解析器返回的数据是一个 PyArrow Table。PyArrow Table 提供了内置功能来转换为 pandas DataFrame
。根据数据的不同,这可能需要在转换为 NumPy(字符串、带缺失值的整数等)时进行拷贝,从而带来不必要的减速。这就是 PyArrow dtype_backend
的作用。它在 pandas 中实现为一个 ArrowExtensionArray
类,背后是一个 PyArrow ChunkedArray。直接的结果是,从 PyArrow Table 转换为 pandas 非常便宜,因为它不需要任何拷贝。
pd.read_csv(
"Crimes_-_2001_to_Present.csv",
engine="pyarrow",
dtype_backend="pyarrow",
)
这将返回一个由 PyArrow 数组支持的 DataFrame
。pandas 并未在所有方面都得到优化,因此这可能会在后续操作中造成减速。如果工作负载特别重 I/O,这可能是值得的。让我们看一下直接比较:
我们可以看到,与 C-engine 相比,PyArrow-engine 和 PyArrow dtypes 提供了 15 倍的加速。
相同的优势适用于 Dask。Dask 封装了 pandas 的 csv 读取器,因此,它免费获得了相同的功能。
对 Dask 的比较要复杂一些。首先,我的示例从本地机器读取数据,而我们的 Dask 示例将从 S3 存储桶中读取数据。网络速度将是一个相关因素。此外,分布式计算也有一些开销,我们需要考虑到这一点。
我们这里只关注速度,所以我们将从一个公共 S3 存储桶中读取一些时间序列数据。
import dask.dataframe as dd
from distributed import wait
df = dd.read_csv(
"s3://coiled-datasets/timeseries/20-years/csv/",
storage_options={"anon": True},
engine="pyarrow",
parse_dates=["timestamp"],
)
df = df.persist()
wait(df)
我们将为 engine="c"
、engine="pyarrow"
以及额外的 engine="pyarrow"
配置 dtype_backend="pyarrow"
执行这个代码片段。让我们来看看一些性能比较。两个示例都在集群上的 30 台机器上执行。
PyArrow 引擎运行速度约为 C 引擎的两倍。两种实现使用了相同数量的机器。使用 PyArrow dtype_backend
时内存使用量减少了 50%。如果仅将对象列转换为 PyArrow 字符串,也可以实现相同的减少,这能在后续操作中提供更好的体验。
我们已经看到 Arrow 引擎提供了显著的速度提升,超过了自定义 C 实现。它们尚未支持自定义实现的所有功能,但如果你的使用场景与支持的选项兼容,你应该能免费获得显著的速度提升。
使用 PyArrow dtype_backend
的情况稍微复杂一些。并非所有 API 区域都经过优化。如果你在 I/O 函数之外花费大量时间处理数据,那么这可能无法满足你的需求。如果你的工作流程在读取数据时花费大量时间,它会加速你的处理。
PyArrow 原生 I/O 读取器中的 dtype_backend
其他一些 I/O 方法也有一个 engine 关键字。read_parquet
是最常见的例子。不过这里的情况稍有不同。这些 I/O 方法默认已经使用 PyArrow 引擎。因此解析尽可能高效。另一个潜在的性能提升是使用dtype_backend
关键字。通常,PyArrow 会返回一个 PyArrow 表,然后转换为 pandas DataFrame。PyArrow 的数据类型会转换为其 NumPy 等效类型。设置dtype_backend="pyarrow"
可以避免这种转换。这能显著提高性能并节省大量内存。
让我们看看 pandas 性能的比较。我们从 2022 年 12 月读取了 Uber-Lyft 出租车数据。
pd.read_parquet("fhvhv_tripdata_2022-10.parquet")
我们读取了带有和不带有dtype_backend="pyarrow"
的数据。
我们可以很容易地看到,转换所花费的时间最多,读取 Parquet 文件后完成的转换。如果避免转换为 NumPy 数据类型,函数运行速度提高了三倍。
Dask 为read_parquet
提供了一个专门的实现,相比于 pandas 实现,具有一些针对分布式工作负载的优势。共同点是这两个函数都调度到 PyArrow 来读取 parquet 文件。两者都共同点是数据在成功读取文件后会转换为 NumPy 数据类型。我们正在读取整个 Uber-Lyft 数据集,这在我们的集群上消耗了约 240GB 的内存。
import dask.dataframe as dd
from distributed import wait
df = dd.read_parquet(
"s3://coiled-datasets/uber-lyft-tlc/",
storage_options={"anon": True},
)
df = df.persist()
wait(df)
我们在 3 种不同的配置下读取了数据集。首先是使用默认的 NumPy 数据类型,然后打开 PyArrow 字符串选项:
dask.config.set({"dataframe.convert-string": True})
最后是dtype_backend="pyarrow"
。让我们看看这在性能方面意味着什么:
类似于我们的 pandas 示例,我们可以看到转换为 NumPy dtypes 会占用我们大量的运行时间。PyArrow dtypes 提供了显著的性能提升。两个 PyArrow 配置所使用的内存是 NumPy dtypes 的一半。
PyArrow 字符串比一般的 PyArrow dtype_backend
更成熟。根据我们得到的性能图表,当使用 PyArrow 字符串和 NumPy dtypes 对于所有其他 dtypes 时,我们得到的性能提升大致相同。如果某个工作流程在 PyArrow dtypes 上的表现还不够好,我建议仅启用 PyArrow 字符串。
结论
我们已经看到如何在 pandas 和 Dask 中利用 PyArrow。PyArrow 支持的字符串列有潜力以积极的方式影响大多数工作流程,并为 pandas 2.0 提供顺畅的用户体验。Dask 提供了一个方便的选项,可以在可能的情况下全局避免使用 NumPy 对象 dtype,这使得选择 PyArrow 支持的字符串变得更加容易。PyArrow 在其他可用领域也提供了巨大的加速。PyArrow 的 dtype_backend
仍然相当新,但有潜力显著减少 I/O 时间。探索它是否能解决性能瓶颈是非常值得的。许多工作正在进行中,以改进对一般 PyArrow dtypes 的支持,并有潜力在不久的将来加速平均工作流程。
pandas 目前有一个提议,从 pandas 3.0 开始默认将字符串推断为 PyArrow 支持的字符串。此外,这还包括更多领域,其中更多地依赖 PyArrow 是非常有意义的(例如:Decimals、结构化数据等)。您可以在这里阅读提议的详细信息。
感谢阅读。欢迎在评论中分享您对两个库中 PyArrow 支持的想法和反馈。我将会写后续的文章,专注于这个话题以及 pandas 一般内容。如果您喜欢阅读关于 pandas 和 Dask 的更多内容,可以关注我在 Medium 上的更新。
V-Net,U-Net 在图像分割中的“大哥”
原文:
towardsdatascience.com/v-net-u-nets-big-brother-in-image-segmentation-906e393968f7
欢迎阅读这本关于 V-Net 的指南,它是著名的 U-Net 的“亲戚”,专用于 3D 图像分割。你将对它了如指掌!
·发表在Towards Data Science ·阅读时间 8 分钟·2023 年 7 月 28 日
--
欢迎踏上探索深度学习架构的激动人心的旅程!你可能已经对U-Net有所了解,它是计算机视觉领域的一个重大突破,显著重塑了图像分割的格局。
今天,让我们将焦点转向 U-Net 的“大哥”——V-Net。
由研究人员 Fausto Milletari、Nassir Navab 和 Seyed-Ahmad Ahmadi 发表的论文“VNet: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation”介绍了一种突破性的 3D 图像分析方法。
这篇文章将带你了解这篇开创性的论文,阐明其独特的贡献和架构进展。无论你是经验丰富的数据科学家、正在成长的 AI 爱好者,还是对最新科技趋势感兴趣的人,这里都有你需要的内容!
关于 U-Net 的简要提醒
在深入 V-Net 的核心之前,让我们先欣赏一下它的架构灵感——U-Net。如果这是你第一次接触 U-Net,别担心,我有一个关于 U-Net 架构的快速简单教程来帮助你。它如此简洁,你最多五分钟就能掌握这个概念!
下面是对 U-Net 的简要回顾:
U Net 架构,见于U Net 文章
U-Net 以其对称结构而闻名,呈现“U”字形。这种架构由两个不同的路径组成:
-
收缩路径(左): 在这里,我们逐渐减少图像的分辨率,同时增加过滤器的数量。
-
扩展路径(右): 这个路径作为收缩路径的镜像。我们逐渐减少过滤器的数量,同时增加分辨率,直到与原始图像大小对齐。
U-Net 的美在于其创新使用的“残差连接”或“跳跃连接”。这些连接收缩路径和扩展路径中的对应层,使网络能够保留在收缩过程中通常丢失的高分辨率细节。
残差连接,来自 U Net 论文
为什么这很重要?因为它简化了反向传播过程中的梯度流动,特别是在早期层中。实质上,我们绕过了梯度消失的风险——一个常见的问题,梯度接近零,阻碍了学习过程:
图片来源:作者
现在,带着对 U-Net 的理解,让我们进入 V-Net 的世界。V-Net 的核心与 U-Net 相似,采用类似的编码器-解码器理念。但正如你很快会发现的,它具有一套独特的特征,使其与其兄弟 U-Net 有所区别。
V-Net 架构,来自 VNet 论文
V-Net 与 U-Net 有何不同?
让我们深入探讨吧!
区别 1:3D 卷积代替 2D 卷积
第一个区别显而易见。虽然 U-Net 是为 2D 图像分割量身定制的,但医学图像通常需要 3D 视角(例如体积脑扫描、CT 扫描等)。
这就是 V-Net 发挥作用的地方。V-Net 中的“V”代表“体积”,这种维度的变化要求将 2D 卷积替换为 3D 卷积。
区别 2:激活函数,PreLU 代替 ReLU
深度学习领域已经爱上了 ReLU 函数,因为它的简单性和计算效率。与 sigmoid 或 tanh 等其他函数相比,ReLU 是“非饱和”的,这意味着它减少了梯度消失的问题。
(左)ReLU,(中)LeakyReLU 和(右)PReLU,来自 PReLU 论文
但 ReLU 并不完美。它因一种被称为“Dying ReLU 问题”的现象而臭名昭著,其中许多神经元总是输出零,成为“死神经元”。为了解决这个问题,引入了 LeakyReLU,它在零的左侧有一个小但非零的斜率。
更进一步推理,V-Net 利用了参数化 ReLU(PReLU)。与其硬编码 LeakyReLU 的斜率,不如让网络来学习它?
毕竟,这是深度学习的核心哲学,我们希望尽可能少地引入归纳偏差,让模型自己学习一切,前提是我们有足够的数据。
区别 3:基于 Dice 分数的不同损失函数
现在,我们来到 V-Net 可能最具影响力的贡献——损失函数的改变。与 U-Net 的交叉熵损失函数不同,V-Net 使用 Dice 损失函数。
交叉熵函数,来自作者
但这个函数的主要问题是它对不平衡的类别处理不好。而这个问题在医学图像中非常常见,因为大多数时候背景的存在远多于感兴趣区域。
例如考虑这张图片:
背景无处不在,来自作者
结果是,一些模型可能变得“懒惰”,在任何地方都预测背景,因为它们仍然会得到一个较小的损失。
所以,V-Net 使用一种对这个问题更有效的损失函数:Dice 系数。
它更好的原因在于它将预测区域和真实值之间的重叠度量为比例,因此考虑了类别的大小。
尽管背景几乎无处不在,Dice 分数测量预测与真实值之间的重叠,因此即使类别占据主导地位,我们仍然可以得到一个介于 0 和 1 之间的数字。
Dice 系数,来自 VNet 论文
我认为这可能是本文的主要贡献,因为从 2D 卷积到 3D 卷积是处理 3D 图像的一个非常自然的想法。然而,这种损失函数已在图像分割任务中被广泛采用。
实际上,混合方法通常被证明是有效的,结合交叉熵损失和 Dice 损失,以利用两者的优点。
V-Net 的性能
因此,我们已经探讨了 V-Net 的独特方面,但你可能在想,“所有这些理论很棒,但 V-Net 实际上能否有效?”好吧,让我们对 V-Net 进行测试!
作者在 PROMISE12 数据集上评估了 V-Net 的性能。
PROMISE12 数据集是为 MICCAI 2012 前列腺分割挑战赛提供的。
V-Net 在 50 张磁共振(MR)图像上进行训练,这个数量并不多!
VNet 在 PROMISE 2012 挑战数据集上的分割,来自 VNet 论文
PROMISE 2012 挑战数据集上的定量指标,来自 VNet 论文
正如我们所见,即使标签很少,V-Net 也能够生成良好的定性分割,并获得非常好的 Dice 分数。
V-Net 的主要限制
确实,V-Net 在图像分割领域,特别是医学成像中设立了新的基准。然而,每项创新都有成长空间。在这里,我们将讨论 V-Net 可以改进的一些显著领域:
局限性 1:模型大小
从 2D 转到 3D 带来了显著的内存消耗增加。这种增加的连锁反应是多方面的:
-
该模型需要大量的内存空间。
-
它严重限制了批量大小(因为将多个 3D 张量加载到 GPU 内存中变得具有挑战性)。
-
医学成像数据稀缺且标注成本高,使得拥有如此多参数的模型更难以拟合。
局限性 2:不使用无监督学习或自监督学习
- V-Net 完全在监督学习的背景下运作,忽视了无监督学习的潜力。在未标注扫描大大超过标注扫描的领域中,融入无监督学习可能会带来突破性的改变。
局限性 3:没有不确定性估计
- V-Net 不估计不确定性,这意味着它无法评估自身预测的信心。这是贝叶斯深度学习闪耀的领域。(请参阅这篇贝叶斯深度学习简明介绍)。
局限性 4:缺乏鲁棒性
- 卷积神经网络(CNNs)传统上在泛化方面表现不佳。它们对于对比度变化、多模态分布或不同分辨率等变化不够鲁棒。这是 V-Net 还可以改进的另一个领域。
结论
V-Net,作为 U-Net 的较少知名但强大的对手,已经彻底改变了计算机视觉,尤其是图像分割。它从 2D 转向 3D 图像,并引入了现在广泛使用的 Dice 系数,为该领域设立了新的标准。
尽管存在局限性,V-Net 应该是开始进行 3D 图像分割任务的首选模型。为了进一步改进,探索无监督学习和整合注意力机制似乎是有前景的方向。
感谢阅读!在你离开之前:
- 查看我在 Github 上的AI 教程汇编
[## GitHub - FrancoisPorcher/awesome-ai-tutorials: 最佳 AI 教程汇编,助你成为…
最佳 AI 教程汇编,助你成为数据科学大师!- GitHub …
你应该在你的收件箱中获取我的文章。 点击此处订阅。
如果你想访问 Medium 上的优质文章,只需每月$5 的会员订阅。如果你通过 我的链接进行注册,你可以在没有额外费用的情况下用你的一部分费用支持我。
如果你觉得这篇文章有见解且有帮助,请考虑关注我并点赞,以获取更多深入内容!你的支持帮助我继续制作有助于我们集体理解的内容。
参考文献
-
Milletari, F., Navab, N., & Ahmadi, S. A. (2016). V-Net:用于体积医学图像分割的全卷积神经网络。在 2016 年第四届国际 3D 视觉会议(3DV)上(第 565–571 页)。IEEE。
-
Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net:用于生物医学图像分割的卷积网络。在国际医学图像计算与计算机辅助干预会议(第 234–241 页)。Springer, Cham。
改善咖啡脱气的真空腔
咖啡数据科学
再探脱气
·
关注 发布于 Towards Data Science ·5 分钟阅读·2023 年 6 月 27 日
--
家用烘焙已经是我一段时间以来的热情所在,但我在等待咖啡休息时总是很不耐烦。我曾发现等待的最佳时间是 3 到 5 周,这让我感到挑战。所以我开发了一种 使用水或湿度更快脱气的技术。这将最佳使用时间缩短到了 1 到 2 周,但咖啡在 3 到 4 周时常常会变味,而没有这种处理的情况下则是 6 到 7 周。
引入真空腔
我最初购买了一个真空室,以查看是否可以在没有水的情况下更快地脱气。然而,气体的内部压力很高(根据 Samo Smrke 的说法接近 6 巴),将外部压力从 1 巴减少到 0 巴仅会改变气体内部与外部压力之间的差值,从 5 巴变为 6 巴。
所有图片由作者提供
让我们将两者合并:
-
烘焙咖啡
-
通过在豆子中添加和混合水分,增加 4%的湿度
-
静置 12 到 24 小时以便吸收
-
放入真空室
主要挑战在于脱气会导致真空室压力增加,所以我每天需要开启两次以保持良好的真空状态。我希望水分引起的氧化作用能被真空减小。
首先,我观察了咖啡在真空罐与真空室中的重量损失。真空罐配有手动泵,但其真空效果不如真空室。果然,真空室比罐子减少了更多的重量,我怀疑重量损失的主要成分是二氧化碳。通常咖啡中有 6 mg/g 的二氧化碳,因此 100 克咖啡中应有 0.6 克二氧化碳。
设备/技术
意式咖啡机: 优质意式咖啡机, 泵与倒废配置
咖啡磨豆机: Niche Zero
咖啡: 家庭烘焙咖啡, 中度烘焙(第一次裂纹 + 1 分钟)
冲泡准备: Staccato 压粉
预注水: 长时间,大约 25 秒
滤篮: 20 Wafo Spirit
其他设备: Acaia Pyxis 称, DiFluid R2 TDS 仪
性能指标
我使用了两组指标来评估技术之间的差异:最终评分和咖啡提取。
最终评分是 7 个指标(锐度、丰富度、糖浆感、甜度、酸度、苦味和回味)的评分卡平均值。这些评分当然是主观的,但它们经过了对我口味的校准,并帮助我提高了制作咖啡的水平。评分存在一定的变动。我希望每个指标的评分都能保持一致,但有时细微差别很难把握。
总溶解固体 (TDS)是通过折射仪测量的,这个数字结合了咖啡萃取的输出重量和输入重量,用于确定杯中提取的咖啡百分比,称为提取率 (EY)。
强度半径 (IR)定义为 TDS 与 EY 的控制图上从原点到半径的距离,因此 IR = sqrt(TDS² + EY²)。这个指标有助于在输出产量或冲泡比例之间规范化样本表现。
数据
我在 4 次烘焙中抽取了 23 组配对样本,并发现大多数情况下口味有所改善。
编辑:发布后,我在查看数据时发现一些样本对没有对齐。我旨在使样本拍摄时间相近,以便烘焙时间相似。这一变化没有影响口味评分,但由于样本对 DE 与 Kim 的错位,丢失了一个样本对。这一变化确实影响了 EY 和 IR 结果,我已相应修改了文本。
我在 EY 和 TDS 上看到了一些改善,但从图表上看不太明显。
细分味道成分时,发现有积极效果。酸味和苦味没有看到统计上的显著影响。
从更高层次来看,TDS/EY/IR 在统计上没有显著变化。如果收集更多数据,这可能会改变,我对此实验感兴趣,想尝试用热预浸的方法。
除了酸味和苦味之外,所有的味道指标都在统计上有显著改善。
我继续使用这种方法进行脱气,我并不建议他人使用这种技术,除非他们真的想要突破咖啡的界限。我怀疑烘焙师可以有效地利用这种方法来处理他们在吧台上使用的咖啡。像所有建议一样,改善口味的脱气技术也有改进的空间。
如果你愿意,可以关注我在Twitter、YouTube和Instagram上的账号,我会发布不同机器上的浓缩咖啡镜头和咖啡相关的内容。你还可以在LinkedIn找到我。你还可以关注我在Medium上的账号,并订阅。
我的进一步阅读:
使用 SQL 验证字符串是否为 HTML
是否可以仅通过 SQL 检查字符串是否包含有效的 HTML?
·
关注 发表在 Towards Data Science ·10 min read·2023 年 1 月 19 日
--
图片由 Valery Sysoev 提供,来源于 Unsplash
检查一个字符串是否为有效的 HTML 是一个非常复杂的任务,并不是简单就能完成的。事实上,编写一个通用的 HTML 字符串验证器是一项相当复杂的任务,这里不打算深入探讨。如果你想在 PostgreSQL 中验证一个包含有效 XML/HTML 的字符串,你应该使用 XML 数据类型 或 XML 函数 库。
先前的文章
-
使用 SQL 验证平衡括号
问题陈述
我们希望编写一个 SQL 查询,当输入一个字符串时,能够返回 TRUE 或 FALSE,基于提供的字符串是否为有效的 HTML。
也就是说,让我们深入探讨验证字符串是否是有效 HTML 的受限问题。为了本文的目的,我们如何定义有效的 HTML?
-
应包含平衡的开闭命名标签。例如, 应与其后的 配对。
-
唯一允许的未配对标签是
和
。 -
我们假设没有如 <a/b> 这样的格式错误的标签。
-
我们假设文本本身不包含 < 和 > 字符,并且它们被正确地转义为 < 和 >。
-
HTML 标签内部没有属性。例如, 是有效的,但 是无效的。
额外背景
XML/HTML 不能使用正则表达式解析,因为 HTML 不是一种正则语言。
我们希望能够验证以下 HTML 文档。
有效的文档
<html>
<head>
</head>
<html>
<html>
<body>
<table>
<tr><td>Hello</td></tr>
<tr>
<td>
<table><tr><td>World</td></tr></table>
</td>
</tr>
</table>
</body>
</html>
<html>
<body>
<div>Hello</div>
</body>
</html>
无效的文档
[1] 在下面的文档中, 标签丢失。
<html>
<head>
</html>
[2] 在下面的文档中, 标签出现在 标签之前。
<html>
<head>
</html>
</head>
两种解决方案的预处理
首先,我们将输入字符串分词,提取文档字符串中出现的开标签和闭标签。我们使用下面的方法将 HTML 标签提取到一个单独的表中,每行代表文档中的一个 HTML 标签。
SELECT
REGEXP_MATCHES(page, '</?[^>]+>', 'g') AS html_tag
FROM html_string
这个表格随后再次处理,以便为每行分配以下内容:
-
row_num:行号。
-
root_tag:如果标签是开标签,它与 html_tag 相同,否则它是去除 / 字符后的标签,以便闭标签的对应开标签在此列中存在。
-
delta:根据这是开 HTML 标签还是闭 HTML 标签,值为 +1 或 -1。
SELECT
ROW_NUMBER() OVER () AS row_num,
html_tag[1],
REPLACE(html_tag[1], '/', '') AS root_tag,
CASE WHEN html_tag[1] LIKE '%/%' THEN -1 ELSE +1 END AS delta
FROM page_as_rows
然后,我们移除所有未配对的标签(即不应有开闭配对的标签)。对于本文,有且仅有两个这样的标签,即
和
。
unpaired_tags(tag_name) AS (
-- Lets track tags that don't need a close tag and then
-- use this set to eliminate such tags from our list of
-- tags since they tags won't contribute to making the
-- input string invalid HTML.
VALUES('<br>'), ('<br/>')
),
only_paired_tags AS (
-- Use the set "unpaired_tags" to keep only tags that
-- have an open and close pair.
SELECT
*
FROM tags_numbered
WHERE root_tag NOT IN (SELECT tag_name from unpaired_tags)
),
第一个解决方案:不正确的解决方案
第一个解决方案看似正确,但未能正确识别我们上面提到的第二个无效文档为无效 HTML。
这个解决方案的主要思路是查看每个输入标签的前缀,并为每种遇到的标签类型维护一个单独的计数器。我们可以通过对 root_tag 列进行分组来实现这一点。
-
我们计算每个根标签的累计和。如果累计和变为负数(对于任何标签前缀),则意味着在开标签之前存在一个闭标签。
-
我们检查最后前缀中标签的运行总和(基本上是整个文档)。如果这个运行总和对任何标签的值不为 0,则文档无效。我们在上面已经涵盖了负面案例。如果这个总和为正,则表示我们在某处有一个未关闭的打开标签。
WITH page_as_rows AS (
-- This returns a single row for every HTML open or close
-- tag in the original string. We assume an HTML tag to be
-- a string enclosed in angle brackets with an optional /
-- to indicate a close tag.
SELECT
REGEXP_MATCHES(page, '</?[^>]+>', 'g') AS html_tag
FROM html_string
),
tags_numbered AS (
-- Let's number the rows so that we know what order the tags
-- are present in the original input string (document). We
-- also replace the close tag with the open tag (by removing
-- the / character) so that we can later match up an open and
-- a close tag. We keep track of whether a tag is an open or
-- close tag by assigning the value +1 or -1 (in a column
-- named delta) along with each row. This will later help
-- us determine if any prefix of the input string is a
-- valid intermediate state for an HTML document.
SELECT
ROW_NUMBER() OVER () AS row_num,
html_tag[1],
REPLACE(html_tag[1], '/', '') AS root_tag,
CASE WHEN html_tag[1] LIKE '%/%' THEN -1 ELSE +1 END AS delta
FROM page_as_rows
),
unpaired_tags(tag_name) AS (
-- Lets track tags that don't need a close tag and then
-- use this set to eliminate such tags from our list of
-- tags since they tags won't contribute to making the
-- input string invalid HTML.
VALUES('<br>'), ('<br/>')
),
only_paired_tags AS (
-- Use the set "unpaired_tags" to keep only tags that
-- have an open and close pair.
SELECT
*
FROM tags_numbered
WHERE root_tag NOT IN (SELECT tag_name from unpaired_tags)
),
self_joined AS (
-- This the main logic. We generate a prefix of the document
-- ending at every tag. We generate O(n²) such prefixes and
-- use that to check if that prefix is a valid prefix of any
-- HTML document.
SELECT
rhs.row_num, lhs.html_tag, lhs.root_tag, lhs.delta
FROM only_paired_tags lhs INNER JOIN only_paired_tags rhs
ON lhs.row_num <= rhs.row_num
ORDER BY rhs.row_num ASC, lhs.row_num ASC
),
grouped AS (
-- Do check validity, we need to first determine:
-- 1\. If we have seen a close tag before its corresponding
-- open tag. If this happens, then nthe running sum
-- for one of the tag elements will become negative (<0).
-- 2\. If the final row of aggregated tag running sums has
-- all zeros. If there is a value that is greater than 0
-- it indicates that we have an open tag without a
-- corresponding close tag. While this is a valid
-- "intermediate" state to be in, it is NOT a valid
-- "final" state to be in.
SELECT
row_num, root_tag, SUM(delta) AS delta_sum
FROM self_joined
GROUP BY 1, 2
),
min_max_for_prefix AS (
-- Compute the MIN and MAX delta sums across all tags
-- for every prefix of the input string.
SELECT
row_num,
MIN(delta_sum) AS min_delta_sum,
MAX(delta_sum) AS max_delta_sum
FROM grouped
GROUP BY 1
ORDER BY row_num ASC
),
overall_min_max AS (
-- Then check if any of the running sums are negative OR
-- if the running sum for the entire document is a value
-- other than 0.
SELECT
MIN(min_delta_sum) AS overall_min,
MAX(
CASE WHEN row_num = (
SELECT MAX(row_num) FROM min_max_for_prefix
) THEN max_delta_sum ELSE NULL END
) AS last_row_max
FROM min_max_for_prefix
)
SELECT * FROM overall_min_max;
此解决方案未能使第二个无效的 HTML 文档失效,因为我们未能跟踪关闭标签相对于文档中其他标签的出现顺序。虽然我们正确检查了关闭标签是否在其对应的打开标签之后出现,但我们没有检查在此关闭标签和其对应的打开标签之间是否有其他未关闭的打开标签。
例如,在下面的示例中,第 2 行的 标签没有关闭,而封闭的 标签(第 1 行)已被第 3 行的 标签关闭。
1 <html>
2 <head>
3 </html>
4 </head>
运行时复杂度:此解决方案的运行时复杂度为O(n²),其中n是文档中的标签数量。这是因为我们将每个标签与之前的每个其他标签连接。这是整个解决方案中的主要成本。
第二种解决方案:由内而外的解决方案
“由内而外的解决方案”从最内层的匹配 HTML 标签对开始处理字符串。此解决方案依赖于在有效的 HTML 文档中,至少有一个匹配的打开/关闭标签对相邻出现。
如果我们移除这个匹配对,那么我们可以找到并消除下一个匹配对,直到没有更多标签(在有效的 HTML 文档中)。我们知道在一个包含 2N 个标签的文档中,我们最多会执行 N 次匹配和消除过程来达到一个空的标签列表。如果在 N 轮匹配和消除之后,我们仍然有一些标签剩下,这表明这是一个无效的 HTML 文档。
示例-1:对于下面的输入,
<html>
<head>
</head>
<body>
<table>
<tr>
<td>
</td>
</tr>
<tr>
<td>
</td>
</tr>
</table>
</body>
</html>
这是递归执行的样子,下面的动画展示了打开/关闭标签对的匹配和消除顺序。
第二种解决方案处理有效的 HTML 字符串(作者提供的图片)
示例-2:对于下面的输入,
<html>
<body>
<table>
<tr>
</tr>
</body>
</table>
</html>
这是递归执行的样子,下面的动画展示了打开/关闭标签对的匹配和消除顺序。由于这是无效的输入,处理在没有匹配的相邻标签对可以处理时停止。
第二种解决方案处理无效的 HTML 字符串(作者提供的图片)
-- Actual solution
WITH RECURSIVE page_as_rows AS (
-- This returns a single row for every HTML open or close
-- tag in the original string. We assume an HTML tag to be
-- a string enclosed in angle brackets with an optional /
-- to indicate a close tag.
SELECT
REGEXP_MATCHES(page, '</?[^>]+>', 'g') AS html_tag
FROM html_string
),
tags_numbered AS (
-- Let's number the rows so that we know what order the tags
-- are present in the original input string (document). We
-- also replace the close tag with the open tag (by removing
-- the / character) so that we can later match up an open and
-- a close tag. We keep track of whether a tag is an open or
-- close tag by assigning the value +1 or -1 (in a column
-- named delta) along with each row. This will later help
-- us determine if any prefix of the input string is a
-- valid intermediate state for an HTML document.
SELECT
ROW_NUMBER() OVER () AS row_num,
html_tag[1] AS html_tag,
REPLACE(html_tag[1], '/', '') AS root_tag,
CASE WHEN html_tag[1] LIKE '%/%' THEN -1 ELSE +1 END AS delta
FROM page_as_rows
),
unpaired_tags(tag_name) AS (
-- Lets track tags that don't need a close tag and then
-- use this set to eliminate such tags from our list of
-- tags since they tags won't contribute to making the
-- input string invalid HTML.
VALUES('<br>'), ('<br/>')
),
only_paired_tags AS (
-- Use the set "unpaired_tags" to keep only tags that
-- have an open and close pair.
SELECT
*
FROM tags_numbered
WHERE root_tag NOT IN (SELECT tag_name from unpaired_tags)
),
max_num_rounds AS (
-- The maximum number of recursive rounds of paired tag
-- elimination that we want to run.
SELECT
COUNT(1) / 2 + 2 AS max_rounds
FROM only_paired_tags
),
-- This recursive query matches and eliminates paired tags in an
-- "inside out" manner by locating adjacent matching paired tags
-- and removing them.
eliminate_paired_tags AS (
SELECT
row_num,
html_tag,
root_tag,
delta,
(SELECT max_rounds FROM max_num_rounds) AS iters
FROM only_paired_tags
UNION ALL
(
WITH lead_lag AS (
SELECT
row_num,
html_tag,
root_tag,
delta,
iters - 1 AS iters,
LAG(root_tag) OVER(ORDER BY row_num ASC) AS lag_root_tag,
LAG(delta) OVER(ORDER BY row_num ASC) AS lag_delta,
LEAD(root_tag) OVER(ORDER BY row_num ASC) AS lead_root_tag,
LEAD(delta) OVER(ORDER BY row_num ASC) AS lead_delta
FROM eliminate_paired_tags
WHERE iters - 1 > -1
),
truncated AS (
SELECT
row_num,
html_tag,
root_tag,
delta,
iters
FROM lead_lag
-- Eliminate correctly paired tags. We eliminate a pair of rows
-- independently.
--
-- 1\. If a row contains an open tag and is followed by a
-- corresponding close tag, we can eliminate the row with
-- this open tag.
--
-- 2\. If a row contains an close tag and is preceeded by a
-- corresponding open tag, we can eliminate the row with
-- this close tag.
WHERE NOT (
(
root_tag = lag_root_tag AND delta = -1 AND lag_delta = +1
) OR (
root_tag = lead_root_tag AND delta = +1 AND lead_delta = -1
)
)
)
SELECT * FROM truncated
)
),
from_last_round AS (
SELECT * FROM eliminate_paired_tags
WHERE iters = 0
)
SELECT
CASE
WHEN COUNT(1) > 0
THEN 'Invalid HTML'
ELSE 'Valid HTML'
END
AS validation_status
FROM from_last_round;
运行时复杂度:此解决方案的运行时复杂度为O(n²),其中n是文档中的标签数量。我们分析了有效和无效输入的成本:
-
有效输入:在最坏的情况下,我们将消除一对标签,并且需要 O(n) 轮来消除文档中的所有 n/2 对标签。在每轮中,我们需要保留并重新处理所有剩余的行。我们从 n 行数据开始,在最坏的情况下每个递归步骤减少 2 行。因此,行数的减少为 n, n-2, n-4, n-6, …, 0(n/2 次),总和为 O(n²)。
-
无效输入:在无效输入的情况下,我们将无法匹配一对开/关标签,并且在整个 n 次递归轮中不会消除任何行。因此,每轮中的行数看起来是 n, n, n, …, n(n 次),总和为 O(n²)。
SQL Fiddle
该帖中所有解决方案的 SQL Fiddle 链接可以在 这里 找到。
结论
我们看到了一种检查字符串是否为有效 HTML 的方法。该方法利用递归 CTE 迭代地减少每一步的问题集大小。
SQL 中的递归 CTE 是强大的工具,如果创意使用,可以解决各种问题。然而,递归 CTE 并不很节省空间。传统的命令式编程语言允许你进行原地更新,而递归 CTE 需要你在每一步都复制数据。
使用 SQL 验证平衡括号
使用 SQL 检查一个包含开括号和闭括号的字符串的格式是否正确
·
关注 发表在 Towards Data Science ·5 分钟阅读·2023 年 1 月 4 日
--
验证一个字符串是否包含平衡的括号是一个实际问题,来源于各种实际场景中的字符串/表达式解析/验证。本文将探讨如何仅使用声明式 SQL 来验证一个只包含开括号‘(’和闭括号‘)’的字符串。
照片由 Elena Mozhvilo 提供,来源于 Unsplash
之前的文章: SQL 中数组的最长递增子序列
问题陈述
给定一个只包含开括号和闭括号的字符串,你能确定这些括号是否是平衡的吗?平衡的意思是:
-
每个开括号 ‘(’ 后面必须有一个匹配的闭括号 ‘)’。
-
一个闭括号 ‘)’ 总是与一个出现在它之前的开括号 ‘(’ 配对。
以下是有效的平衡括号字符串示例:
-
((()()))
-
()()()()
-
(()()(()()))
以下是无效的平衡括号示例:
-
((()( — 这里,第一个和最后一个开括号没有匹配的闭括号
-
()()()()) — 这里,最后一个闭括号没有匹配任何其他未匹配的开括号
问题陈述的在线覆盖
输入表模式
输入表有 2 列:
-
idx:问题索引。即第一个待检查的字符串将有 idx = 1,第二个待检查的字符串将有 idx = 2,依此类推。
-
括号:一个只包含开括号和闭括号的字符串。这个字符串需要检查是否结构正确。
CREATE TABLE inputs(idx SERIAL PRIMARY KEY, parens TEXT NOT NULL);
INSERT INTO inputs(parens) VALUES
('((()()))'), ('((()('), ('()()()())'), ('()()()()'), ('(()()(()()))');
SELECT * FROM inputs;
平衡括号问题的输入表(作者提供的图片)
解决方案:维护一个运行计数器:O(n)
在命令式编程语言中,这个问题通常通过维护一个栈来解决,每当在输入字符串中遇到一个 ‘(’ 时,就将其压入栈中。因此,栈中只包含 ‘(’ 字符。每当在输入中看到 ‘)’ 时,就将其与栈顶的 ‘(’ 匹配,并将 ‘)’ 字符弹出。
由于我们只有一种类型的开括号(和闭括号),我们可以去除栈,只维护一个计数器,用于计数到目前为止看到的未匹配的 ‘(’ 字符。每当遇到一个 ‘(’ 字符时,我们就增加计数器,每当遇到一个 ‘)’ 字符时,我们就减少计数器。
-
负计数器值:如果计数器的值达到负值 (< 0),则表示有未匹配的闭括号。
-
最终计数器值:在处理完输入字符串中的所有字符后,如果计数器的值 != 0,则表示有问题。正值表示输入中有未匹配的 ‘(’ 字符,而负值已经在上面考虑过。
这个解决方案的 SQL 代码如下:
WITH as_split AS (
-- First split each input string such that every character is on its
-- own row. We ensure that we tag each input row with the original
-- index of the input string from which it came so that we know which
-- problem it is a part of.
SELECT
idx,
UNNEST(
STRING_TO_ARRAY(
parens, NULL
)
) AS ch
FROM inputs
),
with_annotation AS (
-- Annotate characters from each problem (unique index) with its
-- position using ROW_NUMBER(). Also annotate an open paren with
-- +1 and a close paren with a -1 number, so that we can maintain
-- a running sum of these counters later.
SELECT
idx,
ch,
ROW_NUMBER() OVER(PARTITION BY idx) AS row_num,
CASE WHEN ch = '(' THEN +1 ELSE -1 END AS ctr
FROM as_split
),
with_running_sum AS (
-- Find the running sum for characters in each problem. Note that we are
-- solving all the problems at once (think "batch API") instead of feeding
-- feeding each problem once into the solution machine.
SELECT
idx,
ch,
ctr,
row_num,
SUM(ctr) OVER(PARTITION BY idx ORDER BY row_num ASC) AS running_sum,
MAX(row_num) OVER(PARTITION BY idx) AS max_row_num
FROM with_annotation
),
with_result AS (
-- The result is valid only if we never hit a negative running sum
-- (which indicates that we have an extra close paren than a corresponding
-- open paren) and if we end with a running sum of 0\. If we end with a
-- running sum > 0 then we have an additional open paren that is not
-- matched with a corresponding close paren.
SELECT
idx,
CASE WHEN MIN(running_sum) < 0 THEN TRUE ELSE FALSE END AS has_negative,
CASE WHEN SUM(
CASE WHEN row_num = max_row_num THEN running_sum ELSE 0 END
) = 0 THEN TRUE ELSE FALSE END AS ends_with_zero_sum
FROM with_running_sum
GROUP BY 1
)
SELECT
lhs.idx,
rhs.parens,
CASE
WHEN has_negative OR NOT ends_with_zero_sum THEN FALSE
ELSE TRUE END
AS is_valid_parens
FROM with_result lhs INNER JOIN inputs rhs
ON lhs.idx = rhs.idx
ORDER BY lhs.idx ASC;
在上述解决方案中,使用了相当多的中间表来帮助可读性,并将各种处理步骤分开。
这是我们第一次在 SQL 中使用UNNEST关键字。这也是我第一次编写一个批量处理多个输入并一次性解决的解决方案。我利用了idx字段,它表示输入字符串的索引。所有中间表都使用idx字段来分离不同问题的解决方案。
O(n)解决方案的结果(作者提供的图片)
估计成本: 在一个具有 5 个不同输入行的表上,该查询的估计成本为45k。这个成本大部分似乎来自于使用窗口聚合函数。
虽然我标记了运行时间为O(n),但这取决于数据库引擎如何内部执行查询。例如,如果引擎发现使用ROW_NUMBER()分配的row_num列结果具有该列值严格递增,并且数据库能够在 CTE 表中保持这一行顺序,那么它可以避免在遇到ORDER BY子句时在窗口函数执行中进行排序。
SUM(ctr) OVER(PARTITION BY idx ORDER BY row_num ASC) AS running_sum,
上述ORDER BY在OVER()子句中是至关重要的,以确保我们得到的是一个运行总和,而不是整个分区的总体总和。
SQL Fiddle
这篇文章中 SQL Fiddle 解决方案的链接可以在这里找到。
扩展和练习
-
如果我们有多种括号类型,比如方括号[],花括号{}。你会如何演变解决方案来适应这一变化?
-
如果组成开括号和闭括号对的字符串不是固定的,而是在另一个维度表中指定的,你会如何演变解决方案来适应这一变化?
我们将看到(在后续文章中)这些扩展如何解决 SQL 从业者遇到的实际问题。
VALL-E — 语音合成的未来?
微软研究院新文本到语音模型的论文解析
·
关注 发表在 Towards Data Science ·15 min read·2023 年 4 月 14 日
--
DALL-E 2:一个接收文本输入的一侧,另一侧输出声波的唱片机。数字艺术。
大家好,
在本文中,我们将深入探讨微软研究院开发的一种新兴的文本到语音模型,名为 VALL-E。介绍这项工作的论文已于 2023 年 1 月 5 日发布,从那时起,它在网上获得了广泛关注。值得注意的是,撰写本文时,没有发布预训练模型,目前唯一的选择是自行训练该模型。
尽管本文提出的想法新颖有趣,值得深入探讨,不管我是否能立即用它克隆我的声音。
本文将按以下结构组织:
-
第一部分 — 语音合成介绍,基本概念
-
第二部分 — VALL-E: 作为语言模型的语音合成
-
第三部分 — Encodec: VALL-E 背后的工作马
-
第四部分 — VALL-E 的问题定义与训练
-
第五部分 — 一些编码
-
第六部分 — 结论与未来展望
第一部分 — 语音合成介绍,基本概念
语音合成技术并不新鲜,自 1939 年贝尔实验室推出的第一个电子语音合成器“Voder”以来,这项技术就已经存在,并且当时需要手动操作。此后,该领域取得了令人难以置信的发展,直到 2017 年左右,主流技术仍是连接语音合成。这项技术基于连接预录制的语音片段以生成可理解的语音。尽管这种技术可以产生逼真的效果,但其缺点显而易见——它无法生成预录数据库中不存在的新声音,也无法生成不同语调或情感的语音。
快进到深度学习时代。如今,语音合成的主流策略总结在图 1 中。让我们来看看它的不同部分。
图 1. 一个模型神经语音合成流程图。作者提供的图像。
-
首先,我们有一个音素化器,将文本转化为音素。音素是词语发音的文本表示(例如——“tomato”这个词在美式和英式口音中有不同的音素),这种表示帮助下游模型获得更好的结果。
-
之后,我们会有一个声学模型将这些音素转化为梅尔频谱图,这是一种在时间-频率域中表示音频的方式。频谱图是通过对原始音频波形的重叠时间窗口应用短时傅里叶变换(STFT)获得的(这里是一个关于梅尔频谱图的出色解释——
medium.com/analytics-vidhya/understanding-the-mel-spectrogram-fca2afa2ce53
)。当然,在这种情况下,频谱图是由统计模型创建的,因为实时语音合成中不存在输入音频。最近的模型架构示例包括 Tacotron2、DeepVoice 3 和 TransformerTTS。 -
最终阶段是将 Mel 频谱图转换为波形。波形通常以 24/48 kHz 的采样率进行采样,每个样本被数字化为 16 位数值。这些数字表示每时刻的气压量,即我们最终听到的声音。为什么我们不能简单地将频谱图确定性地转换为波形?因为这需要在时间域上进行大量的上采样,这需要创建频谱图中不存在的信息,同时频谱图不包含相位信息(只有频率)。因此,就像将音素转换为 Mel 频谱图一样,在这里我们也需要一个统计模型来将频谱图转换为波形,这些模型称为 Vocoders。Vocoders 的例子包括 WaveNet、WaveRNN 和 MelGAN。
此外,还有最近的模型如 VITS 和 YourTTS,它们使用端到端模型从文本输入生成波形。另一个这样的端到端系统的例子是 Deepmind 的一篇论文《End-to-End Adversarial Text-to-Speech》(Yannic Kilcher 在此处进行了极好的解释 — www.youtube.com/watch?v=WTB2p4bqtXU
)。在这篇论文中,他们采用了类似 GAN 的训练过程来生成真实的语音声波。他们还需要解决对齐问题,即生成样本中的单词发音与真实样本中的单词发音在时间上对齐的程度。这个问题不会“自动解决”,需要在模型架构中进行明确处理。
这些端到端 TTS 模型的主要缺点是其极其复杂性。文本和语音是如此不同的模态,这要求复杂的模型显式地解决对齐、说话者身份和语言等问题,使得这些模型高度复杂。VALL-E 的魅力在于,它将生成语言模型的相对简单性巧妙地应用于语音生成领域。对于像我这样刚接触 TTS 和语音领域,并且有一定 NLP 经验的人来说,它提供了一个很好的切入点。
这简短的概述无法充分体现 TTS(文本转语音)这一广阔领域的深度,研究和理解它可能需要一生的时间(我鼓励你深入探讨)。然而,我们今天要讨论的是 VALL-E,所以请允许我直接切入主题。
第二部分 — VALL-E: 语言模型的文本转语音
与其他文本到语音系统一样,VALL-E 的输入是音素化的文本,输出是相应的声音波形。此外,VALL-E 采用了提示机制,其中 3 秒的音频样本作为额外输入提供给模型。这允许生成与给定音频提示条件相关的输入文本的语音发声——实际上,这意味着能够进行零-shot 语音生成,即从训练数据中未见过的声音生成语音。VALL-E 的高层结构见图 2。
图 2. VALL-E 的高层结构。图片来自原始论文[1]。
让我们理解一下这个管道中的过程。首先,我们对文本进行音素转换,这是一种标准程序,我们已经了解过,不需要任何学习机制。为了让模型处理这些音素,我们有一个音素嵌入层,该层将音素词汇表中的索引向量作为输入,并输出对应于输入索引的嵌入矩阵。
3 秒的音频提示(输出语音的条件)被送入音频编解码器编码器。在 VALL-E 中,他们使用了一个预训练的音频编码器——Encodec(由 Facebook Research 开发——arxiv.org/abs/2210.13438
)。Encodec 以语音波形为输入,通过递归向量量化(RVQ)和编码-解码神经架构输出其压缩的离散表示。我们将在本文的第三部分深入探讨 Encodec,但现在可以假设它通过将音频信号分割成固定时间窗口并为每个窗口分配来自已知音频嵌入词汇表的表示,从而输出音频信号的离散表示(概念上,与词嵌入非常相似)。
一旦模型接收到这两个输入,它就可以像自回归语言模型一样输出下一个离散音频表示。由于音频表示来自 Encodec 学习的固定词汇表,我们可以简单地将其视为从固定词汇表中预测下一个单词(在我们的例子中是声音表示)。在这些声音表示被预测之后,它们通过 Encodec 模型的解码器部分被转换回原始的波形表示。
在图 3 中,我们将 VALL-E 的管道与传统的神经 TTS 管道进行比较。我们看到主要的区别在于音频的中间表示。在 VALL-E 中,他们放弃了 Mel 谱图,使用了由 Encodec 模型创建的表示。然而值得注意的是,Encodec 在底层也使用了谱图表示,因此它在这个架构中仍然有所使用,尽管不那么显著。
图 3. VALL-E 流水线与传统神经 TTS 流水线的对比。图片由作者提供。
在 VALL-E 论文的结果部分,他们展示了在 LibriSpeech 数据集上超越了之前的最先进的零样本 TTS 模型 YourTTS,在几个指标上,包括基于人类的评估(如相似性均值选项评分(SMOS))和基于算法的评估(如词错误率(WER))。在一个有趣的消融研究中,他们表明,音素提示有助于生成内容(通过降低 WER),而音频提示有助于说话人相似性(通过提高说话人相似性指标)。
我们现在将深入探讨 Encodec 模型,该模型负责将音频转换为离散标记并再转换回来,并且是本文中使用语言模型方法进行音频生成的基础。
第三部分 — Encodec:VALL-E 背后的主力
在图 4 中,我们可以看到 Encodec 架构。它是一个编码器-解码器架构,通过重建任务学习音频信号的压缩表示。让我们详细了解其不同部分,以理解其内部运作。
图 4. Encodec 架构。图片取自原始论文 [2]。
在最左侧,我们有原始波形,其采样频率为 24/48 kHz,每个样本由 16 位(65536 选项)表示。原始信号传递给编码器,编码器包括用于下采样的 1D 卷积操作和用于序列建模的两层 LSTM。编码器的输出是 75/150 潜在时间步(与原始的 24/48K 比较!),深度维度为 128。
解码器仅仅是编码器的镜像版本,使用转置卷积来对潜在空间进行上采样并构建音频波形(这是对转置卷积的一个很好的解释 towardsdatascience.com/what-is-transposed-convolutional-layer-40e5e6e31c11
)。
有趣的部分当然是量化器。Encodec 如何对声音的连续域进行量化?使用一种叫做残差向量量化(RVQ)的技术,该技术包括将输入向量投影到给定大小代码本中的最近条目。让我们详细解析这句话。首先,什么是代码本?
对于 VALL-E 的情况,代码本是一个大小为 1024 的向量字典,其中每个条目表示一个大小为 128 的向量。我们在向量量化中的目标是将某个向量映射到代码本中最接近的向量(例如通过欧几里得距离),之后可以通过该向量在代码本中的索引来表示它(假设每个人都能访问代码本)。当然,这样我们会丢失大量的信息。如果代码本中的向量都不能准确地类似于我们的向量怎么办?这就是 RVQ 中的“残差”!
在图 5 中,我展示了如何使用残差向量量化对向量进行量化。在这个例子中,我们有 3 个代码本。输入向量与第一个代码本中的每一个向量进行比较,并分配给最接近的一个(C1,1)。然后,计算 C1,1 与输入之间的残差,我们尝试将残差匹配到下一个代码本,以此类推,直到我们遍历完所有的代码本。最终的 RVQ 表示是每个代码本中匹配的索引(在我们的例子中是 1、3、2)。这种编码方法非常高效。如果我们有 8 个代码本,每个包含 1024 个条目——我们可以用仅仅 1024*8=8192 个数字表示 1024⁸=1.2e+24 种不同的向量!当然,发送方和接收方必须持有相同的代码本,这种量化方法才能有效。如果你想深入了解 RVQ,比如代码本是如何训练的,我推荐阅读另一篇 Encodec 所基于的论文《SoundStream》——arxiv.org/abs/2107.03312
(是的,这确实是一个深奥的领域)。
图 5. 残差向量量化示例。图片由作者提供。
回到图 4 中的 Encodec 流程,我们来注意与其训练过程相关的 3 个额外细节:
-
梅尔频谱图是从输入音频和生成音频中创建的。这些频谱图进行比较,比较信号用于作为损失来指导模型训练。
-
使用了几个鉴别器来比较原始波形和合成波形的短时傅里叶变换(STFT)。这种 GAN 损失与梅尔频谱图比较给出的信号不同,并且对 Encodec 有用。
-
量化器包含用于进一步压缩音频信号的变压器。这不是VALL-E 中预测下一个语音标记的变压器,尽管可能会造成混淆。要进一步理解 Encodec 中的变压器,我建议阅读 Aleksa Gordic 的论文或观看视频——
www.youtube.com/watch?v=mV7bhf6b2Hs
。
让我们总结一下我们目前所知道的。VALL-E 是一个文本到语音模型,其操作模式类似于语言模型,预测给定提示的下一个离散音频标记,该提示由音素化文本和音频输入组成。这些离散标记由另一个名为 Encodec 的模型学习(Encodec 本身基于 SoundStream),它使用编码器-解码器架构和残差向量量化将音频转换为离散代码。
第四部分——VALL-E 的问题表述和训练
VALL-E 包含两个变压器模型,用于处理输入数据(音素化文本和音频)——一个自回归(AR)变压器仅关注过去的数据,而一个非自回归(NAR)变压器关注所有时间点。让我们看看原因。
VALL-E 使用了八个不同的编码本作为 Encodec 模型的一部分,其中每个编码本包含 1024 个条目。第一个量化器(编码本)的编码由 AR 模型根据方程 1 进行处理。让我们首先澄清一些术语:
-
C 代表生成的输出 — 作为离散音频编码
-
C~ 是 3 秒的输入声学提示
-
x 是作为音素序列的输入文本
-
C:,₁ 代表来自第一个量化器/编码本的数据
因此,方程 1 显示,第一个量化器的输出基于输入数据以及第一个量化器的前时间步的输出(就像自回归语言模型一样)。
方程 1。自回归模型 — 应用于第一个量化器。图片来源于原始论文 [1]。
在方程 2 中,我们可以看到量化器 2 到 8 的编码生成过程。与之前的情况不同,这里每个量化器的输出都基于前一个量化器的所有时间步(例如,在计算量化器 #7 的编码时,模型依赖于量化器 1 到 6 生成的数据)。与自回归模型不同,这种方式允许在单个量化器中并行生成所有时间步,因为它仅依赖于前一个量化器的编码,而不是当前量化器的前时间步。作者强调了这一点,因为在需要实时生成语音的文本转语音模型中,快速推理尤为重要。
方程 1。非自回归模型 — 应用于第二到第八个量化器。图片来源于原始论文 [1]。
方程 1 和 2 在图 6 中以视觉方式呈现,图中展示了 AR 和 NAR 模型,并突出了它们之间的区别。我们可以看到,自回归变换器仅用于预测 C:,₁,即第一个量化器的标记。在此过程中,它关注于之前生成的标记。NAR 变换器关注于之前的量化器,而不是之前的时间步(当前量化器的前标记在 NAR 模型中不可用)。
图 6。VALL-E 中的自回归(AR)和非自回归(NAR)模型。图片来源于原始论文 [1]。
VALL-E 已经在 LibriLight 数据集上的 60K 小时音频中进行了训练,数据集包含 7000 个不同的说话者(比之前的最先进技术多出 100 倍以上的数据)。该数据集仅包含音频,因此用于标记的自动语音识别模型被使用。Encodec 模型作为预训练模型使用,VALL-E 并未对其进行微调。
在训练中,从 LibriLight 中随机抽取了 10-20 秒的样本。对于声学提示,从相同的语句中提取了另外 3 秒。他们使用了 16 个 Tesla V-100 GPU 来训练模型,相较于大型最先进语言模型,这一配置相当简朴!
我们了解了流程和数据,现在让我们尝试在 GitHub 中使用非官方的 Pytorch 实现的 VALL-E。
第五部分 — 一些编码
VALL-E 在 GitHub 上没有官方实现,因此在我的实验中,我将依赖于已发布的非官方版本 — github.com/enhuiz/vall-e
。此外,还没有发布模型检查点,因此你必须从头开始训练。
还有一个 Google Colab 笔记本可供参考一个简单的训练示例 — colab.research.google.com/drive/1wEze0kQ0gt9B3bQmmbtbSXCoCTpq5vg-?usp=sharing
。在这个例子中,他们在一个“hello world”的单句上过拟合了模型,并展示了模型能够重现这个单句。我对两件事感兴趣:
-
我想用自己的声音复制他们的“hello world”实验,只是为了确认流水线正常工作
-
我想复制 James Skelton 在 Paperspace 上做的实验 —
blog.paperspace.com/training-vall-e-from-scratch-on-your-own-voice-samples/
,他在自己的少量录音上训练了一个模型,并成功用它复制了他的声音(在他已经录制的东西上)
为什么要进行有限的实验?因为从头开始训练这个模型需要很多资源,而我目前没有,另外我认为迟早会发布预训练模型。
那么我是怎么成功的呢?我设法复制了“hello world”的实验,但不幸的是,我没能复制 Paperspace 的实验 — 我只得到了一个产生混乱声音的模型,它模糊地提醒了我的声音。这可能是因为资源不足(我在 Google Colab 实例上进行训练,它在 12 小时后会超时)。但我仍然想和你一起讨论整个过程。我的 VALL-E 笔记本版本在这里 — colab.research.google.com/drive/1NNOsvfiOfGeV-BBgGkwf0pyGAwAgx3Gi#scrollTo=SbWtNBVg_Tfd
。
一旦在 Colab 笔记本中运行以下行 —
!git clone --recurse-submodules https://github.com/enhuiz/vall-e.git
你会在文件浏览器中看到一个名为vall-e
的目录。路径content/vall-e/data/test
包含了“hello world”实验的数据。请注意,它包含两个文件,因为只有一个文件会出问题。要复制这个实验,只需使用!rm content/vall-e/data/test/*
删除数据目录中的文件,录制自己说“Hello world”,并将其保存为两个带有不同名称的.wav 文件。将.wav 文件放入包含两个文本文件(文本文件的名称应与带有.normalized.txt
后缀的.wav 文件相同)的数据目录中。
接着,你将运行这两个单元格:
!python -m vall_e.emb.qnt data/test
!python -m vall_e.emb.g2p data/test
第一个单元将对你自己的数据运行 Encodec 模型并进行量化,就像我们之前讨论的那样。第二个单元将把文本“hello world”转换为音素。
之后,处理过的数据已经准备好,你可以运行执行训练过程的单元。NAR 和 AR 模型有独立的训练(记住,如前所述,NAR 模型训练依赖于 AR 模型,但 AR 模型仅使用和生成第一个量化器数据,因此与 NAR 模型独立)。
!python -m vall_e.train yaml=config/test/ar.yml
!python -m vall_e.train yaml=config/test/nar.yml
模型训练完成后,你将运行此单元:
!mkdir -p zoo
!python -m vall_e.export zoo/ar.pt yaml=config/test/ar.yml
!python -m vall_e.export zoo/nar.pt yaml=config/test/nar.yml
这将最新的模型检查点(已自动创建)保存到名为zoo
的目录中。
最终,你将使用以下内容进行模型推理:
!python -m vall_e 'hello world' /content/vall-e/data/test/hello_world.wav toy.wav --ar-ckpt zoo/ar.pt --nar-ckpt zoo/nar.pt
这将使用文本提示“Hello world”以及相同的语音提示来运行模型。它将生成的样本保存为toy.wav
,然后你可以使用以下内容来收听:
from IPython.display import Audio
Audio('toy.wav')
就这样!你创建了自己的 VALL-E“Hello world”。除非你拥有大量计算资源,否则最好等待一个预训练的模型,以进一步使用这个模型。
第六部分 — 结论与未来展望
在这篇文章中,我们了解了 VALL-E,这是一种由微软研究院开发的新型文本到语音架构。VALL-E 以类似语言模型的方式生成音频,这使其与最近的先进方法有所区别,后者通常是端到端的或遵循文本->谱图->波形创建流程。
我们还讨论了 Encodec 模型,它执行音频量化,并在 VALL-E 的训练中用作预训练模型。Encodec 本身很吸引人,能够使用残差向量量化创建超压缩的音频表示。VALL-E 的创建者利用了这一特性,并在此量化之上构建了一个生成的“语言”模型。
最终,我们看到了一些代码,并用自己的声音重复了来自非官方代码的“hello world”实验。该论文的官方代码尚未发布,模型检查点也未发布。看到和使用一个预训练的 VALL-E 模型会很有趣,我相信它迟早会出现。不过,这次是一次有趣的学习旅程。
下次见!
Elad
参考文献
[1] arxiv.org/abs/2301.02111
— VALL-E 论文(神经编解码语言模型是零-shot 文本到语音合成器)
[2] arxiv.org/abs/2301.02111
— Encodec 论文(高保真神经音频压缩)
[3] wiki.aalto.fi/display/ITSP/Concatenative+speech+synthesis
— 拼接语音合成的解释
[4] www.youtube.com/watch?v=aLBedWj-5CQ&t=1s
— 深入了解语音合成的聚会(HuggingFace)
[5] www.youtube.com/watch?v=MA8PCvmr8B0
— 推动神经文本到语音的前沿(微软研究院)
[6] www.youtube.com/watch?v=G9k-2mYl6Vo&t=5593s
— John Tan Chong Min 关于 VALL-E 的优秀视频
[7] www.youtube.com/watch?v=mV7bhf6b2Hs
— Aleksa Gordic 关于 Encodec 的优秀视频
深度学习中的梯度爆炸与消失问题
原文:
towardsdatascience.com/vanishing-exploding-gradient-problem-neural-networks-101-c8f48ec6a80b
如何确保你的神经网络不会“死掉”或“爆炸”
·发布于 Towards Data Science ·阅读时间 9 分钟·2023 年 12 月 8 日
--
www.flaticon.com/free-icons/neural-network
。标题=”神经网络图标。” 神经网络图标由 Paul J. 创建 — Flaticon。
什么是梯度消失与梯度爆炸?
在我之前的一篇文章中,我们解释了神经网络如何通过反向传播算法进行学习。主要思路是我们从输出层开始,将误差“传播”到输入层,在这个过程中根据损失函数更新权重。如果你对此不熟悉,我强烈建议你查看那篇文章:
解释神经网络如何通过手动和使用 PyTorch 的代码“训练”和“学习”数据中的模式
[towardsdatascience.com
权重通过其相对于损失函数的偏导数进行更新。问题在于,当我们接近网络的下层时,这些梯度变得越来越小。这导致下层的权重在训练网络时几乎没有变化。这被称为梯度消失问题。
相反,梯度也可能在各层中持续增大。这就是梯度爆炸问题,它主要是递归神经网络中的一个问题。
然而,2010 年由Xavier Glorot和Yoshua Bengio发布的论文诊断了导致梯度出现问题的几个原因。主要的罪魁祸首是sigmoid 激活函数和权重的初始化方式(通常来自标准正态分布)。这种组合导致了层间方差的变化,直到它们在 sigmoid 函数的极端边缘饱和。
以下是 sigmoid 函数的数学方程和图示。请注意,在其极端情况下,梯度变为零。因此,在这些饱和点没有“学习”发生。
Sigmoid 函数。方程由作者使用 LaTeX 编写。
Sigmoid 函数。由作者在 Python 中生成的图示。
我们现在将介绍一些可以减少训练过程中梯度消失或爆炸的技术。
如果你想了解更多关于激活函数及其优缺点的信息,请查看我之前的帖子:
解释神经网络为何能学习(几乎)任何和一切
towardsdatascience.com
Glorot/Xavier 初始化
幸运的是,论文的作者提出了中和上述问题的方法。他们建议了一种新的权重初始化方法,命名为 Glorot 初始化,以作者命名,确保层间的方差保持不变。
上述论文中包含了完整的数学细节,但他们提出的初始化策略是:
正态分布
对于正态分布X ~ N(0, 𝜎²),权重将如下初始化:
Glorot 正态分布初始化的均值和方差。方程由作者使用 LaTeX 编写。
如果n_in = n_out,则我们有LeCun 初始化以计算机科学家Yann LeCun命名。这个初始化方法是在 Glorot 的论文发表前十年,即 1990 年代提出的。
LeCun 初始化。方程由作者使用 LaTeX 编写。
均匀分布
对于均匀分布X ~ U(-a, a):
Glorot 均匀分布初始化的均值和方差。方程由作者在 LaTeX 中编写。
使用这些初始化设置我们的神经网络可以更快地收敛,因为在训练开始时权重不会过小或过大。
以上表达式仅适用于 sigmoid 和tanh激活函数。例如,对于ReLU(修正线性单元)激活函数,正常分布方差需要使用1/n_in进行初始化。
有兴趣的读者可以在这里找到完整的激活函数及其对应的初始化列表。
更好的激活函数
ReLU
对于大多数行业标准的神经网络,sigmoid 激活函数已被大量弃用,取而代之的是 ReLU,因为它对大正值不会饱和(它也更具计算效率):
ReLU 函数。方程由作者在 LaTeX 中编写。
ReLU 激活函数。图由作者在 Python 中生成。
然而,ReLU 并不完美,且受到dying ReLU 问题的影响。 这是指神经元开始“死亡”,因为它们只输出零,因为它们的输入加权和总是负的。这导致零梯度,从而网络停止“学习”任何东西。这对较大的学习率特别糟糕。
ReLU 可以很容易地在PyTorch中应用,如下所示:
import torch
import torch.nn as nn
from torch.nn.functional import relu
class Model(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.input_layer = nn.Linear(in_features=10, out_features=20)
self.output_layer = nn.Linear(in_features=20, out_features=1)
def forward(self, x):
x = relu(self.input_layer(x))
x = self.output_layer(x)
return x
Leaky ReLU
你可以通过使用经典 ReLU 函数的变体来解决这个问题,例如‘Leaky’ ReLU,其中负输入不是零,而是具有一些浅坡度的梯度α:
Leaky ReLU 函数。方程由作者在 LaTeX 中编写。
Leaky ReLU,注意负 x 值的小梯度。图由作者在 Python 中生成。
Leaky ReLU 通常超越经典 ReLU,因为它减少了这种“死亡神经元”问题的可能性,从而实现更稳健的学习。通常发现更大的“泄漏”更好,但要适度。
还有其他变体也常常改善基本 ReLU 的性能:
-
随机 Leaky ReLU (RReLU):在训练过程中,超参数α被随机初始化。
-
参数化 Leaky ReLU (PReLU): 在训练过程中超参数 α 会被学习。
你可以在 PyTorch 中应用 Leaky ReLU,如下所示:
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.input_layer = nn.Linear(in_features=10, out_features=20)
self.output_layer = nn.Linear(in_features=20, out_features=1)
self.leaky_relu = nn.LeakyReLU(negative_slope=0.1)
def forward(self, x):
x = self.leaky_relu(self.input_layer(x))
x = self.output_layer(x)
return x
指数线性单元
我们将考虑的最后一个激活函数是 指数线性单元(ELU)。
ELU 函数。作者使用 LaTeX 表示的方程。
ELU,注意负值的指数值。图像由作者在 Python 中生成。
ELU 和 ReLU 之间的主要区别是:
-
对于负值,ELU 函数不为零,因此缓解了普通 ReLU 的神经元死亡问题。
-
它的梯度对于负输入是非零的。
-
由于指数计算,ELU 的计算速度比 ReLU 及其变体要慢。
-
ELU 通常比 ReLU 性能更好
PyTorch 中的 ELU:
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.input_layer = nn.Linear(in_features=10, out_features=20)
self.output_layer = nn.Linear(in_features=20, out_features=1)
self.elu = nn.ELU()
def forward(self, x):
x = self.elu(self.input_layer(x))
x = self.output_layer(x)
return x
选择哪个?
有这么多激活函数,很难知道选择哪个。一般规则是 ELU > Leaky ReLU > ReLU > tanh > sigmoid。然而,计算速度可能会影响你的模型,所以你可能需要重新考虑选择哪个激活函数。最好考虑一下你的问题并尝试几种,看看哪个最适合。
用于生成这些激活函数图的代码可以在我的 GitHub 上找到:
[## Medium-Articles/Data Science Basics/activation_functions.py at main · egorhowell/Medium-Articles
我在我的 Medium 博客/文章中使用的代码。通过创建一个账户为 egorhowell/Medium-Articles 开发做贡献……
批量归一化
使用 Glorot 初始化和 ReLU(及其变体)激活函数有助于减少算法开始时梯度消失/爆炸的可能性,但在训练过程中没有帮助。
防止训练过程中梯度消失/爆炸的一种方法是 批量归一化(BN)。这个过程在激活函数应用之前或之后零中心化和重新归一化输出或输入。然后,它会对结果进行平移和缩放,使每一层都有自己“学习”的均值和方差。
这类似于我们在训练前对特征进行缩放和归一化的原因。它确保所有特征处于同一水平,并且较大值的特征不会淹没较小值的特征。批量归一化是在每一层的输出之后应用这一过程,因为它们是下一层的输入!
关于为什么我们需要归一化特征的详细解释在这里。
算法如下所示:
批量归一化算法。方程由作者在 LaTeX 中生成。
其中:
-
x_i 是批量归一化层的输入。
-
μ_B 是批次的均值。
-
σ²_B 是批次的方差。
-
ϵ 是为数值稳定性(避免除以零)添加的小常数。
-
γ 是在训练过程中学习的缩放参数。
-
β 是在训练过程中学习的偏移参数。
我们可以通过它使其更清晰。第一部分只是计算网络中每层输入或输出的均值和方差。这些输出随后使用均值和方差进行归一化。最后一部分是使用 γ 超参数进行缩放,使用 β 超参数进行偏移。这些超参数由网络学习,这使得 BN 如此强大。每一层都会有其自定义的变换!
另一件重要的事情是,在训练期间,BN 会跟踪指数移动平均(EMA)的均值和方差。这在预测时使用,因为你不能真的将 BN 应用于单个预测行,这在推理过程中发生。
你可以在 PyTorch 中按如下方式应用 BN:
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.input_layer = nn.Linear(20, 10)
self.bn1 = nn.BatchNorm1d(10)
self.hidden_layer = nn.Linear(10, 5)
self.bn2 = nn.BatchNorm1d(5)
self.output_layer = nn.Linear(5, 1)
def forward(self, x):
x = self.input_layer(x)
x = self.bn1(x)
x = nn.ReLU()(x)
x = self.hidden_layer(x)
x = self.bn2(x)
x = nn.ReLU()(x)
x = self.output_layer(x)
return x
如你所见,我们在所有隐藏层之间插入一个批量归一化层。
批量归一化已被证明能改善深度神经网络的训练,并减少梯度消失问题的影响。然而,每个 epoch 的训练速度较慢,因为我们需要将隐藏层的输出通过批量归一化层,这增加了网络中的总参数数量。
批量归一化也可以作为正则化器!
梯度裁剪
减少梯度消失/爆炸几率的最终方法是将梯度裁剪到某个最大阈值。例如,我们可以将梯度裁剪到最大值 1。这样,任何大于 1 的梯度都会被“裁剪”到 1。这种技术通常应用于递归神经网络(RNN),因为很难对 RNN 应用批量归一化(batch norm)。
这是一个如何在 PyTorch 中应用梯度裁剪的示例:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# Apply gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
optimizer.step()
这里 max_norm
是梯度被裁剪的阈值。
总结与进一步思考
消失和爆炸梯度的发生是由于神经网络层之间的方差变化以及由于反向传播时的乘法效应导致梯度减少。在这篇文章中,我们讨论了三种减少这种效应的方法:更好的激活函数、批量归一化和梯度裁剪。在我看来,批量归一化与 ReLU 激活函数结合使用可能是最好的选择。批量归一化通过规范化和缩放输入来确保每层之间的方差保持恒定。
另一个信息!
我有一个免费的通讯,Dishing the Data,每周分享成为更好的数据科学家的技巧。没有“空洞的内容”或“点击诱饵”,只有来自实践数据科学家的纯粹可操作见解。
[## Dishing The Data | Egor Howell | Substack
如何成为更好的数据科学家。点击阅读《Dishing The Data》,由 Egor Howell 发表的 Substack 出版物…
联系我!
参考资料与进一步阅读
-
关于消失梯度研究的论文: https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
-
批量归一化的精彩视觉解释
随机森林中的变量重要性
传统方法与新发展
·
关注 发表在 Towards Data Science ·15 分钟阅读·2023 年 11 月 3 日
--
(分布式)随机森林的特点。本文中:产生变量重要性的能力。来源:作者。
随机森林及其推广(特别是广义随机森林(GRF)和分布随机森林(DRF))是强大且易于使用的机器学习方法,任何数据科学家都不应缺少它们。这些方法不仅在大量数据集上表现稳健而无需调优,还能轻松处理缺失值,甚至提供置信区间。本文重点讨论它们能够提供的另一个特性:特征重要性的概念。具体来说,我们关注:
-
传统的随机森林(RF)用于预测给定 p 个预测变量时变量 Y 的条件期望 X。
-
分布随机森林,用于预测给定 p 个预测变量 X 时 d-变量 Y 的整个条件分布。
不幸的是,与许多现代机器学习方法一样,这些森林缺乏可解释性。即,涉及的操作过多,似乎无法确定预测变量与 Y 之间的实际功能关系。解决此问题的一种常见方法是定义变量重要性度量(VIMP),至少可以帮助决定哪些预测变量是重要的。通常,这有两个不同的目标:
(1) 寻找少量具有最大准确度的变量,
(2) 检测和排名所有影响变量,以便进一步探索。
(1) 和 (2) 之间的差异很重要,只要 X 中的元素之间存在依赖关系(几乎总是如此)。例如,如果两个变量与 Y 高度相关且彼此相关,则可以在目标 (1) 中移除其中一个输入,而不会影响准确度,因为这两个变量传达了相同的信息。然而,对于目标 (2),两个变量都应包含在内,因为这两个变量在实际应用中可能对领域专家有不同的含义。
今天我们专注于(1),并试图找到较少数量的预测因子,这些因子展示了或多或少相同的预测准确性。例如,在下面的工资示例中,我们能够将预测因子的数量从 79 减少到大约 20,仅有少量准确性损失。这些最重要的预测因子包含了如年龄和教育等众所周知影响工资的变量。此外,还有许多关于(2)的精彩文章,使用 Shapley 值,例如这篇或这篇。也有非常近期且令人兴奋的学术文献讨论如何高效计算与随机森林结合的 Shapley 值。这些是第二篇文章的材料。
今天我们关注的两个度量实际上是更通用的变量重要性度量,可以用于任何方法,基于我们将要讨论的丢弃重学原则。然而,我们在此仅专注于树基方法。此外,我们不会详细解释这些方法,而是专注于它们的应用及为何新版本优于传统版本。
随机森林变量重要性度量概述。均值减少纯度(MDI)和均值减少准确性(MDA)都是由 Breiman 提出的。然而,由于其经验性,本方法仍存在若干问题,最近由 Sobol-MDA 解决。来源:作者
起始
RF 的变量重要性度量实际上与 RF 本身一样古老。均值减少准确性(MDA)这一方法最早由 Breiman 在其开创性的随机森林论文中提出[1]。其原理很简单:对于每一个维度 j=1,…,p,将完整预测的准确性与当 X_j 被随机置换时的预测准确性进行比较。这一想法是打破 X_j 和 Y 之间的关系,并将 X_j 不参与预测 Y 的准确性与它可能有用的情况进行比较。
MDA 在 R 和 Python 中有各种不同的版本实现:
MDA 的不同版本,在不同包中实现。来源:[3]中的表 1
不幸的是,以这种方式排列变量 X_j 不仅打破了它与 Y 的关系,也打破了它与 X 中其他变量的关系。如果 X_j 与所有其他变量独立,这不是问题,但一旦存在依赖关系,就会成为问题。因此,[3] 证明了只要 X 中存在依赖关系,MDA 就会收敛到一些荒谬的结果。特别是,MDA 可能会给予一个对预测 Y 并不重要的变量 X_j 高重要性,但这个变量与另一个对预测 Y 实际上很重要的变量,比如 X_l,高度相关(如下例所示)。同时,它可能无法检测到实际相关的变量,这在 [3, Section 2.1] 的大量文献中得到了证明。从直觉上讲,我们希望测量的是如果不包括 X_j 模型的性能,而我们测量的是使用排列的 X_j 变量的模型性能。
第二种传统的准确性度量是 均值减少不纯度(MDI),它是对所有分裂在给定协变量上的节点的不纯度加权减少的总和,平均到森林中的所有树。不幸的是,MDI 从一开始就定义不明确(不清楚它应该度量什么),且几篇论文强调了这种方法的实际问题(例如 [5])。因此,我们不会详细讨论 MDI,因为 MDA 通常是首选。
现代发展 I: Sobol-MDA
长时间以来,我认为这些稍显非正式的度量是我们能做到的最好方法。一篇改变了这一点的论文最近才发表。在这篇论文中,作者理论上证明了上述流行度量实际上存在很大缺陷,并未测量我们想要测量的内容。因此,第一个问题可能是:我们实际上想要测量什么?一个潜在的答案是:Sobol 指数(最早在计算机科学文献中提出):
让我们来详细了解一下。首先,tau(X)=E[ Y | X] 是我们希望估计的条件期望函数。由于它是随机 X 的函数,因此这是一个随机变量。现在 X^{(-j)} 是移除协变量 j 的 p-1 向量。因此,ST^{(j)} 是移除第 j 个输出变量后解释方差的减少量。
以上是更传统的度量方式。然而,对我来说,写作:
更直观。在这里,d 是两个随机向量之间的距离,对于上述的 ST{(j)}*,这个距离就是通常的欧几里得距离。因此,*ST 的上半部分只是测量了我们想要的(tau(X)) 和没有变量 j 时得到的之间的平均平方距离。
问题变成了如何有效地估计这个距离。事实证明,直观的“丢弃和重新学习”原则就足够了:简单地使用随机森林估计 tau(X),然后丢弃 X_j 并重新拟合随机森林以获得 tau(X^{(-j)}) 的估计值,就可以得到一致的估计量:
其中 tau_n(X_i) 是对测试点 X_i 使用所有 p 个预测变量的随机森林估计值,同样 tau_n(X_i^{(-j)}) 是仅使用 p-1 个预测变量的重新拟合森林。
然而,这意味着森林需要重新拟合 p 次,当 p 很大时效率不高!因此,[3] 中的作者开发了他们所称的 Sobol-MDA。与每次重新拟合森林不同,该方法仅拟合一次森林。然后,将测试点丢弃到同一森林中,结果预测“投影”以形成公式 (1) 中的度量。也就是说,X_j 上的分裂会被简单忽略(记住目标是获得没有 X_j 的估计)。作者能够表明,使用这种投影方法计算上述公式 (1) 也能得到一致的估计量!这确实是一个很棒的想法,使得该算法即使在高维情况下也能适用。
投影方法的示意图。左侧是 RF 对二维空间的划分。右侧的投影方法忽略了 X² 中的分裂,从而在进行预测时将其移除。可以看出,点 X 被投影到右侧的 X^{(-j)} 上。来源:[3] 中的图 1
该方法在 R 语言中实现于 soboldMDA 包中,基于非常快速的 ranger 包。
现代发展 II:基于 MMD 的敏感性指数
通过使用距离 d 的公式,产生了一个自然的问题:是否可以使用不同的距离来获得更复杂问题的变量重要性度量。其中一个近期的例子是使用 MMD 距离作为 d:
MMD 距离是一个很棒的工具,它允许通过使用核 k(如高斯核)相对轻松地构建分布之间的距离:
目前我将细节留给后续文章。最重要的结论是I^{(j)}考虑了比条件期望更一般的目标。它将变量X_j视为重要,只要它以任何方式影响Y的分布。可能是X_j仅改变方差或分位数,而条件均值Y保持不变(见下面的例子)。在这种情况下,Sobol-MDA 将不会识别X_j为重要,但 MMD 方法会。这并不一定意味着它更好,它只是另一种工具:如果你关注预测条件期望,ST{(j)}*是合适的度量。然而,如果你关注预测分布的其他方面,尤其是分位数,*I会更适合。再次强调,I^{(j)}可以通过使用丢弃重学原则(每次去掉变量\(j\)时重新拟合 DRF),或者使用与 Sobol-MDA 相同的投影方法来一致地估计。基于丢弃重学的实现附在本文末尾。我们在这里将这种方法称为MMD-MDA。
模拟数据
现在我们用一个简单的模拟例子来说明这两种现代度量方法:我们首先从Gitlab下载并安装 Sobol-MDA 包,然后加载这个例子所需的所有包:
library(kernlab)
library(drf)
library(Matrix)
library(DescTools)
library(mice)
library(sobolMDA)
source("compute_drf_vimp.R") ##Contents of this file can be found below
source("evaluation.R") ##Contents of this file can be found below
然后我们从这个简单的例子中进行模拟:我们将X_1, X_2, X_4, …, X_10独立地在(-1,1)之间均匀分布,并通过取X_3=X_1 + 均匀误差来创建X_1和X_3之间的依赖关系。然后我们模拟Y为
##Simulate Data that experiences both a mean as well as sd shift
# Simulate from X
x1 <- runif(n,-1,1)
x2 <- runif(n,-1,1)
X0 <- matrix(runif(7*n,-1,1), nrow=n, ncol=7)
x3 <- x1+ runif(n,-1,1)
X <- cbind(x1,x2, x3, X0)
# Simulate dependent variable Y
Y <- as.matrix(rnorm(n,mean = 0.8*(x1 > 0), sd = 1 + 1*(x2 > 0)))
colnames(X)<-paste0("X", 1:10)
head(cbind(Y,X))
然后我们分析 Sobol-MDA 方法以估计给定X的Y的条件期望:
## Variable importance for conditional Expectation Estimation
XY <- as.data.frame(cbind(Xfull, Y))
colnames(XY) <- c(paste('X', 1:(ncol(XY)-1), sep=''), 'Y')
num.trees <- 500
forest <- sobolMDA::ranger(Y ~., data = XY, num.trees = num.trees, importance = 'sobolMDA')
sobolMDA <- forest$variable.importance
names(sobolMDA) <- colnames(X)
sort(sobolMDA, decreasing = T)
X1 X8 X7 X6 X5 X9
0.062220958 0.021946135 0.016818860 0.016777223 -0.001290326 -0.001540919
X3 X10 X4 X2
-0.001578540 -0.007400854 -0.008299478 -0.020334150
如所见,它正确识别出X_1是最重要的变量,而其他变量则被同等排序(不)重要。这是合理的,因为Y的条件期望仅由X_1改变。关键是,尽管存在X_1和X_3之间的依赖关系,该度量仍然能够做到这一点。因此,我们在这个例子中成功实现了上述目标(1)。另一方面,我们也可以看看传统的 MDA:
forest <- sobolMDA::ranger(Y ~., data = XY, num.trees = num.trees, importance = 'permutation')
MDA <- forest$variable.importance
names(MDA) <- colnames(X)
sort(MDA, decreasing = T)
X1 X3 X6 X7 X8 X2
0.464516976 0.118147061 0.063969310 0.032741521 0.029004312 -0.004494380
X4 X9 X10 X5
-0.009977733 -0.011030996 -0.014281844 -0.018062544
在这种情况下,尽管它正确识别了X_1作为最重要的变量,但它还将X_3排在第二位,其值似乎明显高于其他变量。这尽管X_3与X_2, X_4,…, X_10一样不重要!
但如果我们更普遍地关注预测Y的分布,例如估计分位数呢?在这种情况下,我们需要一个能够识别X_2对Y条件方差影响的度量方法。在这里,MMD 变量重要性度量方法发挥作用:
MMDVimp <- compute_drf_vimp(X=X,Y=Y)
sort(MMDVimp, decreasing = T)
X2 X1 X10 X6 X8 X3
0.683315006 0.318517259 0.014066410 0.009904518 0.006859128 0.005529749
X7 X9 X4 X5
0.003476256 0.003290550 0.002417677 0.002036174
再次,该度量能够正确识别重要的内容:X_1 和 X_2 是两个最重要的变量。尽管 X_1 和 X_3 之间存在依赖关系,它仍然能够做到这一点。有趣的是,它还给予 X_2 的方差变化比 X_1 的期望变化更高的重要性。
实际数据
最后,我展示了一个实际数据应用来演示变量重要性度量。注意,使用 DRF,我们甚至可以查看多变量 Y,但为了简化,我们专注于单变量设置,并考虑来自美国人口普查局的 2018 年美国社区调查的工资数据。在第一篇 DRF 论文 中,我们从美国人口普查局的 2018 年美国社区调查中获得了约 100 万名全职员工的数据,从中提取了工资信息及所有可能与工资相关的协变量。这些数据非常适合用来试验像 DRF 这样的方法(实际上我们只会使用其中的一小部分进行分析)。我们加载的数据可以在 这里 找到。
# Load data (https://github.com/lorismichel/drf/blob/master/applications/wage_data/data/datasets/wage_benchmark.Rdata)
load("wage_benchmark.Rdata")
##Define the training data
n<-1000
Xtrain<-X[1:n,]
Ytrain<-Y[1:n,]
Xtrain<-cbind(Xtrain,Ytrain[,"male"])
colnames(Xtrain)[ncol(Xtrain)]<-"male"
Ytrain<-Ytrain[,1, drop=F]
##Define the test data
ntest<-2000
Xtest<-X[(n+1):(n+ntest),]
Ytest<-Y[(n+1):(n+ntest),]
Xtest<-cbind(Xtest,Ytest[,"male"])
colnames(Xtest)[ncol(Xtest)]<-"male"
Ytest<-Ytest[,1, drop=F]
我们现在计算两个变量重要性度量(这会花费一些时间,因为仅实现了丢弃再学习方法用于 DRF):
# Calculate variable importance for both measures
# 1\. Sobol-MDA
XY <- as.data.frame(cbind(Xtrain, Ytrain))
colnames(XY) <- c(paste('X', 1:(ncol(XY)-1), sep=''), 'Y')
num.trees <- 500
forest <- sobolMDA::ranger(Y ~., data = XY, num.trees = num.trees, importance = 'sobolMDA')
SobolMDA <- forest$variable.importance
names(SobolMDA) <- colnames(Xtrain)
# 2\. MMD-MDA
MMDVimp <- compute_drf_vimp(X=Xtrain,Y=Ytrain,silent=T)
print("Top 10 most important variables for conditional Expectation estimation")
sort(SobolMDA, decreasing = T)[1:10]
print("Top 5 most important variables for conditional Distribution estimation")
sort(MMDVimp, decreasing = T)[1:10]
Sobol-MDA:
education_level age male
0.073506769 0.027079349 0.013722756
occupation_11 occupation_43 industry_54
0.013550320 0.010025332 0.007744589
industry_44 occupation_23 occupation_15
0.006657918 0.005772662 0.004610835
marital_never married
0.004545964
MMD-MDA:
education_level age male
0.420316085 0.109212519 0.027356393
occupation_43 occupation_11 marital_never married
0.016861954 0.014122583 0.003449910
occupation_29 marital_married industry_81
0.002272629 0.002085207 0.001152210
industry_72
0.000984725
在这种情况下,这两种变量重要性度量在哪些变量重要上有相当一致的意见。虽然这不是因果分析,但变量 “年龄”、“教育水平”和“性别” 确实被这两种度量视为非常重要的,这一点也很令人欣慰。
为了获得一小组预测变量,现在可以对 j=1,…p-1 进行操作,
(I) 移除最不重要的变量
(II) 在测试集上计算损失(例如均方误差)
(III) 重新计算剩余变量的变量重要性
(IV) 重复直到满足某个停止标准
比如,如果损失增加超过 5%,可以停止。为了简化本文,我仅使用上面保存的相同变量重要性值,即“SobolMDA”和“MMDVimp”。也就是说,我忽略了步骤(III),只考虑(I)、(II)和(IV)。当估计目标是完整的条件分布时,步骤(II)也不是完全清楚的。我们使用称为 MMD 损失的度量,这在我们的论文中有更详细的描述 ([4])。这种损失考虑了我们在预测分布时所犯的错误。对于条件均值,我们简单地使用均方误差。这在下面找到的函数“evalall”中完成:
# Remove variables one-by-one accoring to the importance values saved in SobolMDA
# and MMDVimp.
evallistSobol<-evalall(SobolMDA, X=Xtrain ,Y=Ytrain ,Xtest, Ytest, metrics=c("MSE"), num.trees )
evallistMMD<-evalall(MMDVimp, X=Xtrain ,Y=Ytrain ,Xtest, Ytest, metrics=c("MMD"), num.trees )
plot(evallistSobol$evalMSE, type="l", lwd=2, cex=0.8, col="darkgreen", main="MSE loss" , xlab="Number of Variables removed", ylab="Values")
plot(evallistMMD$evalMMD, type="l", lwd=2, cex=0.8, col="darkgreen", main="MMD loss" , xlab="Number of Variables removed", ylab="Values")
这导致以下两张图片:
注意到这两者都有些许波动线条,这首先是由于我没有重新计算重要性测量,例如,省略了步骤(III),其次是森林的随机性。除此之外,图形很好地展示了每次移除变量后错误如何逐渐增加。这种增加对于最不重要的变量开始时很慢,然后对于最重要的变量加速,正如预期的那样。特别是,如果移除 50 个最不重要的变量,损失几乎保持不变!实际上,在这两种情况下,移除约 70 个变量而损失增加不超过 6%是可能的。不过,需要注意的是,许多预测变量是单独编码的分类变量的一部分,因此在移除预测变量时需要小心,因为它们对应于一个分类变量的水平。然而,在实际应用中,这可能仍然是可取的。
结论
在本文中,我们探讨了随机森林中变量重要性的现代方法,目标是获得一个小的预测变量或协变量集合,既考虑条件期望,也更一般地考虑条件分布。我们在工资数据示例中看到,这可以显著减少预测变量,同时保持几乎相同的准确度。
如上所述,所呈现的测量方法并不严格限制于随机森林,但原则上可以更一般地使用。然而,森林允许优雅的投影方法,使得可以计算所有变量的重要性测量,而无需每次都重新拟合森林(!)这在[3]和[4]中都有描述。
文献
[1] Breiman, L. (2001). 随机森林。机器学习, 45(1):5–32。
[2] Breiman, L. (2003a). 设置、使用和理解随机森林 v3.1。技术报告,UC Berkeley,统计系。
[3] Bénard, C., Da Veiga, S., 和 Scornet, E. (2022). 随机森林的均值减少准确度:不一致性及通过 Sobol-MDA 的实际解决方案。Biometrika, 109(4):881–900。
[4] Clément Bénard, Jeffrey Näf, 和 Julie Josse. 基于 MMD 的分布随机森林变量重要性,2023。
[5] Strobl, C., Boulesteix, A.-L., Zeileis, A., 和 Hothorn, T. (2007). 随机森林变量重要性测量中的偏差:说明、来源和解决方案。BMC Bioinformatics, 8:25。
附录:代码
#### Contents of compute_drf_vimp.R ######
#' Variable importance for Distributional Random Forests
#'
#' @param X Matrix with input training data.
#' @param Y Matrix with output training data.
#' @param X_test Matrix with input testing data. If NULL, out-of-bag estimates are used.
#' @param num.trees Number of trees to fit DRF. Default value is 500 trees.
#' @param silent If FALSE, print variable iteration number, otherwise nothing is print. Default is FALSE.
#'
#' @return The list of importance values for all input variables.
#' @export
#'
#' @examples
compute_drf_vimp <- function(X, Y, X_test = NULL, num.trees = 500, silent = FALSE){
# fit initial DRF
bandwidth_Y <- drf:::medianHeuristic(Y)
k_Y <- rbfdot(sigma = bandwidth_Y)
K <- kernelMatrix(k_Y, Y, Y)
DRF <- drf(X, Y, num.trees = num.trees)
wall <- predict(DRF, X_test)$weights
# compute normalization constant
wbar <- colMeans(wall)
wall_wbar <- sweep(wall, 2, wbar, "-")
I0 <- as.numeric(sum(diag(wall_wbar %*% K %*% t(wall_wbar))))
# compute drf importance dropping variables one by one
I <- sapply(1:ncol(X), function(j) {
if (!silent){print(paste0('Running importance for variable X', j, '...'))}
DRFj <- drf(X = X[, -j, drop=F], Y = Y, num.trees = num.trees)
DRFpredj <- predict(DRFj, X_test[, -j])
wj <- DRFpredj$weights
Ij <- sum(diag((wj - wall) %*% K %*% t(wj - wall)))/I0
return(Ij)
})
# compute retraining bias
DRF0 <- drf(X = X, Y = Y, num.trees = num.trees)
DRFpred0 = predict(DRF0, X_test)
w0 <- DRFpred0$weights
vimp0 <- sum(diag((w0 - wall) %*% K %*% t(w0 - wall)))/I0
# compute final importance (remove bias & truncate negative values)
vimp <- sapply(I - vimp0, function(x){max(0,x)})
names(vimp)<-colnames(X)
return(vimp)
}
#### Contents of evaluation.R ######
compute_mmd_loss <- function(Y_train, Y_test, weights){
# Y_train <- scale(Y_train)
# Y_test <- scale(Y_test)
bandwidth_Y <- (1/drf:::medianHeuristic(Y_train))²
k_Y <- rbfdot(sigma = bandwidth_Y)
K_train <- matrix(kernelMatrix(k_Y, Y_train, Y_train), ncol = nrow(Y_train))
K_cross <- matrix(kernelMatrix(k_Y, Y_test, Y_train), ncol = nrow(Y_train))
weights <- matrix(weights, ncol = ncol(weights))
t1 <- diag(weights%*%K_train%*%t(weights))
t2 <- diag(K_cross%*%t(weights))
mmd_loss <- mean(t1) - 2*mean(t2)
mmd_loss
}
evalall <- function(Vimp, X ,Y ,Xtest, Ytest, metrics=c("MMD","MSE"), num.trees ){
if (ncol(Ytest) > 1 & "MSE" %in% metrics){
metrics <- metrics[!( metrics %in% "MSE") ]
}
# Sort for increasing importance, such that the least important variables are removed first
Vimp<-sort(Vimp)
if ( is.null(names(Vimp)) ){
stop("Need names for later")
}
evalMMD<-matrix(0, nrow=ncol(X))
evalMSE<-matrix(0, nrow=ncol(X))
###Idea: Create a function that takes a variable importance measure and does this loop!!
for (j in 1:ncol(X)){
if (j==1){
if ("MMD" %in% metrics){
DRFred<- drf(X=X,Y=Y)
weights<- predict(DRFred, newdata=Xtest)$weights
evalMMD[j]<-compute_mmd_loss(Y_train=Y, Y_test=Ytest, weights)
}
if ("MSE" %in% metrics){
XY <- as.data.frame(cbind(X, Y))
colnames(XY) <- c(paste('X', 1:(ncol(XY)-1), sep=''), 'Y')
RFfull <- sobolMDA::ranger(Y ~., data = XY, num.trees = num.trees)
XtestRF<-Xtest
colnames(XtestRF) <- paste('X', 1:ncol(XtestRF), sep='')
predRF<-predict(RFfull, data=XtestRF)
evalMSE[j] <- mean((Ytest - predRF$predictions)²)
}
}else{
if ("MMD" %in% metrics){
DRFred<- drf(X=X[,!(colnames(X) %in% names(Vimp[1:(j-1)])), drop=F],Y=Y)
weights<- predict(DRFred, newdata=Xtest[,!(colnames(Xtest) %in% names(Vimp[1:(j-1)])), drop=F])$weights
evalMMD[j]<-compute_mmd_loss(Y_train=Y, Y_test=Ytest, weights)
}
if ("MSE" %in% metrics){
XY <- as.data.frame(cbind(X[,!(colnames(X) %in% names(Vimp[1:(j-1)])), drop=F], Y))
colnames(XY) <- c(paste('X', 1:(ncol(XY)-1), sep=''), 'Y')
RFfull <- sobolMDA::ranger(Y ~., data = XY, num.trees = num.trees)
XtestRF<-Xtest[,!(colnames(Xtest) %in% names(Vimp[1:(j-1)])), drop=F]
colnames(XtestRF) <- paste('X', 1:ncol(XtestRF), sep='')
predRF<-predict(RFfull, data=XtestRF)
evalMSE[j] <- mean((Ytest - predRF$predictions)²)
# DRFall <- drf(X=X[,!(colnames(X) %in% names(Vimp[1:(j-1)])), drop=F], Y=Y, num.trees=num.trees)
# quantpredictall<-predict(DRFall, newdata=Xtest[,!(colnames(Xtest) %in% names(Vimp[1:(j-1)])), drop=F], functional="quantile",quantiles=c(0.5))
# evalMAD[j] <- mean(sapply(1:nrow(Xtest), function(j) abs(Ytest[j] - quantpredictall$quantile[,,"q=0.5"][j]) ))
}
}
}
return(list(Vimp=Vimp, evalMMD=evalMMD, evalMSE=evalMSE ))
}
实验中的方差减少 — 第一部分:直觉
方差减少的直觉及其在随机实验中的重要性。
·
关注 发表在 Towards Data Science ·7 min read·2023 年 2 月 2 日
--
这是一个由两篇文章组成的系列的第一部分,我们将在其中深入探讨实验中的方差减少。在本文中,我们将讨论为何方差减少是必要的,并建立其机制的直觉。在第二部分中,我们将评估该领域的最新方法:MLRATE,并将其与其他成熟方法如 CUPED 进行比较。
让我们先讨论一下为何在实验中需要减少方差。看看,当涉及因果推断时,我们的结论可能由于两种不同的原因而错误:系统性偏差和随机变异。系统性偏差主要是由于自我选择或未测量的混杂因素造成的坏结果,而随机变异发生是因为我们获得的数据仅是我们尝试研究的总体的一个样本。随机实验允许我们消除系统性偏差,但它们并非免疫于基于随机变异的错误。
均值差估计器
让我们考虑一个例子,我们的营销团队想要弄清楚发送电子邮件促销的影响。作为营销数据科学团队,我们决定通过从客户基础中随机选择 2000 个客户来进行实验。我们的响应变量是电子邮件发送后的两周内的支出金额,对于每个客户,我们让硬币投掷决定该客户是否接收到促销。
由于治疗是随机分配的,我们没有系统性偏差,而实验中处理组(T)和对照组(C)之间的简单均值差(DIM)是总体平均治疗效应(ATE)的无偏估计:
关键点在于,这个估计量在平均情况下给我们真实效果,但对于我们最终分析的某个样本,它可能会远离真实值。
在以下代码中,我们首先生成一个 2000 个客户的样本,随机选择其中一半进行治疗,该治疗对客户支出的效果为$5。然后,我们使用线性回归对这个样本应用 DIM 估计器,希望能恢复真实效果。
def dgp(n=2000, p=10):
Xmat = np.random.multivariate_normal(np.zeros(p), np.eye(p), size=n).astype('float32')
T = np.random.binomial(1, 0.5, n).astype('int8')
col_list = ['X' + str(x) for x in range(1,(p+1))]
df = pd.DataFrame(Xmat, columns = col_list)
# functional form of the covariates
B = 225 + 50*df['X1'] + 5*df['X2'] + 20*(df['X3']-0.5) + 10*df['X4'] + 5*df['X5']
# constant ate
tau = 5
Y = (B + tau*T + np.random.normal(0,25,n)).astype('float32')
df['T'] = T
df['Y'] = Y
return df
data = dgp()
ols = smf.ols('Y ~ T', data = data).fit(cov_type='HC1',use_t=True)
ols.summary().tables[1]
不幸的是,这个特定样本的 DIM 估计值远离真实效果并且是负值。它还伴随着一个相当宽的置信区间,这使我们无法得出任何结论。
另一方面,如果我们有机会在每次不同样本的情况下平行重复这个实验成千上万次,我们会看到估计值的密度会在真实效果附近达到峰值。为了验证这一点,让我们创建一个模拟,生成 2000 个客户的样本,并在每次调用时运行前述实验,共 10000 次。
def experiment(**kwargs):
dct = {}
n = kwargs['n']
p = kwargs['p']
df = dgp(n,p)
#Difference-in-means
mu_treated = np.mean(df.query('T==1')['Y'])
mu_control = np.mean(df.query('T==0')['Y'])
dct['DIM'] = mu_treated - mu_control
return dct
def plot_experiment(results):
results_long = pd.melt(results, value_vars=results.columns.tolist() )
mu = 5
p = (ggplot(results_long, aes(x='value') ) +
geom_density(size=1, color='salmon')+
geom_vline(xintercept=mu, colour='black', linetype='dashed' ) +
annotate("text", x=mu, y=.1, label="True Mean", size=15)+
labs(color='Method') +
xlab('Estimate') +
theme(figure_size=(10, 8))
)
return p
%%time
tqdm._instances.clear()
sim = 10000
results = Parallel(n_jobs=8)(delayed(experiment)(n=2000, p=10)\
for _ in tqdm(range(sim)) )
results_df = pd.DataFrame(results)
plot = plot_experiment(results_df)
果然,我们的 10000 个估计值的平均值恢复了真实效果$5。
作者的图 1
结果的变异
为了理解为什么我们在实际存在处理效果时未能揭示该效果,我们需要考虑到底是什么决定了客户支出。如果我们的电子邮件促销与收入、与业务的关系长度、最近性、购买频率和以前购买的价值等其他因素相比,对客户支出的影响有限,那么看到支出变异性更多地由其他因素解释就不足为奇了。
事实上,将客户支出与处理指标绘图揭示了支出从大约$4 到超过$400 的变异情况。由于我们知道促销对支出的影响很小,因此在所有变异中检测这一点变得相当具有挑战性。因此,从对照组到处理组的回归线在下面有一个轻微的负斜率。
作者图 2
增加样本量是让 DIM 估计器更容易检测真实效果的一种方法。如果我们能接触到 10000 名客户而不是 2000 名客户,我们可以获得更接近真实效果的估计,如下所示。
data = dgp(n=10000, p=10)
ols = smf.ols('Y ~ T', data = data).fit(cov_type='HC1',use_t=True)
ols.summary().tables[1]
回归调整
现在,如果我们在时间或其他资源上受到限制,不能增加样本量,但仍然希望能够检测到小效果怎么办?这引出了减少结果方差的想法,这与增加样本量类似,如果效果确实存在,它使检测效果变得更容易。
实现这一目标最简单的方法之一是使用回归调整来控制所有其他决定支出的因素。请注意,我们为此示例生成的样本有 10 个协变量,其中 5 个直接与结果相关。如前所述,我们可以将其视为以下几个因素:收入、与业务的关系长度、最近性、购买频率和以前购买的价值。将它们作为回归中的控制变量意味着在观察处理效果时保持这些因素不变。这意味着如果我们观察在实验时收入水平和购买行为相似的客户,结果的方差会更小,我们将能够捕捉到如下所示的小处理效果:
data = dgp(n=2000, p=10)
ols = smf.ols('Y ~ T+' + ('+').join(data.columns.tolist()[0:10]),
data = data).fit(cov_type='HC1',use_t=True)
ols.summary().tables[1]
为了帮助直观理解,让我们首先将协变量对处理和结果的影响部分剔除,然后绘制结果残差。
model_t = smf.ols('T ~ ' + ('+').join(data.columns.tolist()[0:10]), data=data).fit(cov_type='HC1',use_t=True)
model_y = smf.ols('Y ~ ' + ('+').join(data.columns.tolist()[0:10]), data=data).fit(cov_type='HC1',use_t=True)
residuals = pd.DataFrame(dict(res_y=model_y.resid, res_t=model_t.resid))
p1=(ggplot(residuals, aes(x='res_t', y='res_y'))+
geom_point(color='c') +
ylab("Spending Amount ($) Residual") +
xlab("Received Coupon Residual") +
geom_smooth(method='lm',se=False, color="salmon")+
theme(figure_size=(10, 8)) +
theme(axis_text_x = element_text(angle = 0, hjust = 1))
)
p1
我们看到两个组中结果残差的剩余变异性都大大减少,并且从对照组残差到处理组残差的回归线现在有了一个正斜率。
作者图 3
将原始支出金额的方差与支出金额残差的方差进行比较,确认我们能够实现超过 80%的方差减少。
print("Spending Amount Variance:", round(np.var(data["Y"]),2))
print("Spending Amount Residual Variance:", round(np.var(residuals["res_y"]),2))
由于回归机制,你如果感兴趣可以在这篇文章中了解更多,我们可以通过对支出金额的残差进行回归,得到与对处理的残差回归相同的结果:
model_res = smf.ols('res_y ~ res_t', data=residuals).fit()
model_res.summary().tables[1]
结论
对我们的实验分析应用简单的回归调整,可以大大减少结果指标的方差,这在我们尝试检测实验中的小效果时尤为有用。在本系列的第二部分,我们将深入探讨各种方法在这个领域的表现。具体而言,我们将运行具有不同复杂度的数据生成过程的模拟,并应用 MLRATE - 机器学习回归调整处理效应估计器,这一领域的最新发明,以及其他方法如 CUPED 和回归调整。敬请关注。
代码
此分析的代码可以在我的 github仓库中找到。
感谢阅读!欢迎评论/建议。
实验中的方差减少 —— 第二部分:协变量调整方法
深入探讨 MLRATE —— 机器学习回归调整治疗效应估计器,并与其他方法进行比较
·
关注 发表在 Towards Data Science ·9 分钟阅读·2023 年 2 月 7 日
--
图片由 Sam Moghadam Khamseh 提供,来源于 Unsplash
这是系列文章中的第二篇,我们讨论了实验中的方差减少。在第一篇文章中,我们讨论了为什么在实验中减少结果度量的方差是必要的,并展示了简单的回归调整如何带来显著的好处,同时对这一主题建立了直觉。
实验中的方差减少 — 第一部分:直觉
方差减少的直觉及其在随机实验中重要的原因。
towardsdatascience.com
在这篇文章中,我们将分析几种已建立的协变量调整方法的方差减少性能。具体而言,我们将对数据生成过程中的不同复杂度运行模拟,并将以下方法应用于每个实验数据:
1. 回归调整(OLS_adj)
2. 带交互项的回归调整(OLS_int)
3. 使用实验前数据的控制实验(CUPED)
4. 差异中的差异(DID)
5. 机器学习回归调整处理效应估计器(MLRATE)
我们想要找到的平均处理效应(ATE)是处理(用 1 表示)和控制(用 0 表示)之间的结果预期差异,Y:
思考条件平均处理效应(CATE)也很有用,它是在描述协变量X的单位子集上的 ATE:
对整个协变量空间的 CATE 取平均,得到 ATE:
由于我们有一个随机分配的实验,每种方法都是 ATE 的无偏估计。然而,每种方法实现的方差减少水平可能非常不同,具体取决于基础数据生成过程。
首先讨论每种方法的机制。
均值差异(DIM)
DIM 是 ATE 的一个简单而一致的估计,并且将在我们的分析中作为基线,因为它不涉及任何协变量调整:
回归调整(OLS_adj)
这是治疗指示符* T 的系数估计,如果单位i处于治疗中则为 1,否则为 0,它来自包含协变量X*的结果回归,线性且加性,并假设所有单位的处理效应是恒定的:
要了解为何τ的 OLS 估计量是 ATE 的一个一致估计量,我们考虑处理组和对照组的回归函数分别计算,然后取差异:
因此:
带交互项的回归调整(OLS_int)
这是治疗指示符* T 的系数估计,它来自包含协变量X以及T*与去均值协变量之间交互作用的结果回归。
与 OLS_adj 相反,在这里我们不假设效应在所有单位之间是恒定的,而是允许处理效应随着协变量而变化,尽管是以线性和加法的方式。
τ 的 OLS 估计量再次是 ATE 的一致且渐近正态的估计量。要了解原因,我们再次分别取处理组和对照组的回归函数,然后取差值:
所以,我们有:
使用预实验数据的对照实验 (CUPED)
CUPED (Deng et al., 2013) 的基础思想是使用与结果 Y 高度相关但与处理 T 无关的预实验协变量 X。结果的预实验值 Y 是一个自然的候选者,因为它满足这些标准。在我们的数据中,如果可以访问这样的协变量,我们将按如下方式应用 CUPED:
- 获取 theta:
theta = cov(Y,X)/var(X)
2. 为每个单位 i 创建一个转化后的结果 Y_cuped:
Y_cuped = Y - theta*(X - mean(X))
3. 从以下估计 τ:
这样 Y 的方差被减少了 1-Corr(X, Y):
Var(Y_cuped) = Var(Y)(1-Corr(Y,X))
在我们的模拟中,我们设计了 X1,使其满足上述标准。
差分中的差分 (DID)
在协变量调整的背景下,DID 估计量不应与面板数据应用中的 DID 估计量混淆。在这个背景下,DID 是通过首先训练一个机器学习模型 g(X) 来预测 Y 从 X,然后计算 Y − g(X) 的处理组和对照组平均值之间的差异(Yongyi et al., 2021)。
DID 的动机在于,由于 g(X) 和 T 是独立的,结果估计量与 DIM 估计量具有相同的期望:
然而,如果 g(X) 是 Y 的一个良好预测变量,那么 Var(Y) 将超过 Var(Y − g(X)),基于 Y − g(X) 的平均值的 DID 估计将比基于 Y 的平均值的 DIM 估计具有更低的方差。
此外,使用机器学习使我们能够以数据驱动的方式捕捉结果与协变量之间的复杂相互作用,而不依赖于 OLS_adj 和 OLS_int 中的函数形式假设。
一个关键点是使用交叉拟合。简单来说,我们将数据分为两部分,用一部分建立 g(X),用另一部分获取预测结果,并通过交换这两个部分重复这个过程。我们最终得到每个观测值的预测结果,这些预测结果是由仅在其他观测值上训练的模型生成的。在这个分析中,我们对 DID 和 MLRATE 都应用了交叉拟合,因为这两者在最终估计中都使用了机器学习预测,但原始论文仅对 MLRATE 应用了交叉拟合(Yongyi et al., 2021)。
DID 的计算方法如下:
-
通过交叉拟合训练 g(X) 并从 X 预测 Y。
-
为每个单位 i 创建 Y_res:
Y_res = Y - g(X)
3. 从以下估计 τ:
机器学习回归调整处理效应估计量(MLRATE)
DID 和 MLRATE 之间的主要区别在于,DID 直接从结果Y中减去机器学习预测值g(X),而 MLRATE 则在后续的线性回归步骤中包括预测值以及 T 与去均值预测值之间的交互项(Yongyi et al., 2021)。
MLRATE 的步骤如下:
-
通过交叉拟合训练g(X)并从X预测Y。
-
获取τ估计值:
研究表明,将预测值作为回归量可以保证估计器对差预测的稳健性,并且 MLRATE 的渐近方差不大于 DIM 估计器。
def mlpredict(dfml,p):
XGB_reg = XGBRegressor(learning_rate = 0.1,
max_depth = 6,
n_estimators = 500,
reg_lambda = 1)
X_mat = dfml.columns.tolist()[0:p]
kfold = StratifiedKFold(n_splits=2, shuffle=True, random_state=1)
ix = []
Yhat = []
for train_index, test_index in kfold.split(dfml, dfml["T"]):
df_train = dfml.iloc[train_index].reset_index()
df_test = dfml.iloc[test_index].reset_index()
X_train = df_train[X_mat].copy()
y_train = df_train['Y'].copy()
X_test = df_test[X_mat].copy()
XGB_reg.fit(X_train, y_train)
Y_hat = XGB_reg.predict(X_test)
ix.extend(list(test_index))
Yhat.extend(list(Y_hat))
df_ml = pd.DataFrame({'ix':ix,'Yhat':Yhat}).sort_values(by='ix').reset_index(drop=True)
df_ml[['Y','T']] = dfml[['Y','T']]
df_ml['Ytilde'] = df_ml['Yhat'] - np.mean(df_ml['Yhat'])
df_ml['Yres'] = df_ml['Y'] - df_ml['Yhat']
df_ml = df_ml.drop('ix', axis=1)
return df_ml
数据生成过程 (DGP)
为了比较这 5 种协变量调整方法在不同复杂度下减少方差的效果,我们有一个 DGP,包含N = 2000个独立同分布观测值和 10 个协变量,分布为N(0,I(10×10))。处理指示符为Ti ∼ Bernoulli(0.5),误差项分布为N(0,25²)。处理与协变量和误差项独立,误差项本身与协变量独立。结果Yi和处理效应函数τ(Xi)取决于以下函数形式:
- 协变量的线性效应与恒定处理效应
2. 协变量的线性效应与变化的处理效应
3. 协变量的非线性效应与恒定处理效应
4. 协变量的非线性效应与变化的处理效应
def dgp(n=2000, p=10, linear=True, constant=True):
Xmat = np.random.multivariate_normal(np.zeros(p), np.eye(p), size=n).astype('float32')
T = np.random.binomial(1, 0.5, n).astype('int8')
col_list = ['X' + str(x) for x in range(1,(p+1))]
df = pd.DataFrame(Xmat, columns = col_list)
# functional form of the covariates
if linear:
B = 225 + 50*df['X1'] + 5*df['X2'] + 20*(df['X3']-0.5) + 10*df['X4'] + 5*df['X5']
else:
B = 225 + 50*df['X1'] + 5*np.sin(np.pi*df['X1']*df['X2'] ) + 10*(df['X3']-0.5)**2 + 10*df['X4']**2 + 5*df['X5']**3
# constant ate or non-constant
tau = 5 if constant else 5*df['X1'] + 5*np.log(1 + np.exp(df['X2']))
Y = (B + tau*T + np.random.normal(0,25,n)).astype('float32')
df['T'] = T
df['Y'] = Y
return df
我们在每种函数形式下生成 1000 个数据集,并将每种协变量调整方法应用于这些数据。
def experiment(**kwargs):
dct = {}
n = kwargs['n']
p = kwargs['p']
linear = kwargs['linear']
constant = kwargs['constant']
df = dgp(n,p,linear,constant)
#1\. Difference-in-means
mu_treated = np.mean(df.query('T==1')['Y'])
mu_control = np.mean(df.query('T==0')['Y'])
dct['DIM'] = mu_treated - mu_control
#2\. OLS adjusted
if kwargs['OLS_adj']:
ols_adj = smf.ols('Y ~' + ('+').join(df.columns.tolist()[0:(p+1)]),
data = df).fit(cov_type='HC1',use_t=True)
dct['OLS_adj'] = ols_adj.params['T']
#3\. OLS interacted
if kwargs['OLS_int']:
df = df.assign(**({c+'tilde': (df[c] - df[c].mean()) for c in df.columns.tolist()[0:p]}))
ols_int = smf.ols('Y ~' + ('+').join(df.columns.tolist()[0:(p+1)]) + '+' + 'T:('+('+').join(df.columns.tolist()[p+2:])+')',
data = df).fit(cov_type='HC1',use_t=True)
dct['OLS_int'] = ols_int.params['T']
#4\. CUPED
if kwargs['CUPED']:
theta = smf.ols('Y ~ X1',data = df).fit(cov_type='HC1',use_t=True).params['X1']
df['Y_res'] = df['Y'] - theta*(df['X1'] - np.mean(df['X1']))
cuped = smf.ols('Y_res ~ T', data=df).fit(cov_type='HC1',use_t=True)
dct['CUPED'] = cuped.params['T']
pred_df = mlpredict(df,p)
#5\. Difference-in-differences
if kwargs['DID']:
mu2_treated = np.mean(pred_df.query('T==1')['Yres'])
mu2_control = np.mean(pred_df.query('T==0')['Yres'])
dct['DID'] = mu2_treated - mu2_control
#6\. MLRATE
if kwargs['MLRATE']:
mlrate = smf.ols('Y ~ T + Yhat + T:Ytilde',
data = pred_df).fit(cov_type='HC1',use_t=True)
dct['MLRATE'] = mlrate.params['T']
return dct
最后,我们打印估计值、它们的标准误差和 95%置信区间,并绘制分布图。
def plot_experiment(results, constant = True ):
results_long = pd.melt(results, value_vars=results.columns.tolist() )
print(round(results_long.groupby('variable').agg(mean=("value", "mean"), std=("value", "std"))
.reset_index().sort_values(by='std', ascending=False).reset_index(drop=True)
.assign(CI_lower= lambda x: x['mean'] - x['std']*1.96,
CI_upper= lambda x: x['mean'] + x['std']*1.96,),3)
)
mu = 5 if constant else 4
p = (ggplot(results_long, aes(x='value',color='variable') ) +
geom_density(size=1 )+
scale_color_manual(values = ['black', 'blue', 'green', 'c','red', 'salmon', 'magenta' ]) +
geom_vline(xintercept=mu, colour='black', linetype='dashed' ) +
annotate("text", x=mu, y=.1, label="True Mean", size=15)+
labs(color='Method') +
xlab('Estimate') +
theme(figure_size=(10, 8))
)
print(p)
发现
仿真结果显示在下图中,主要发现如下:
-
每种协变量调整方法在标准误差上都优于 DIM 估计器,无论 DGP 是什么,然而,减少的程度取决于 DGP 的复杂性。
-
当 DGP 包括协变量的线性效应时,无论是恒定处理效应还是变化处理效应,OLS_adj 和 OLS_int 估计器的标准误差最小。
-
当 DGP 包括协变量的非线性效应时,基于机器学习的估计器 DID 和 MLRATE 在恒定和变化的处理效应下的标准误差最小。
-
由于协变量的非线性和变化的处理效应在实际应用中更为常见,因此我们得出结论,基于机器学习的调整方法在实验中减少方差方面优于其他方法。
-
最终,CUPED 在所有场景下的表现都不如其他方法。这主要是因为 CUPED 的原始版本使用了单一协变量进行调整,我们也是如此实现的。可以扩展 CUPED 以处理多个协变量。
%%time
tqdm._instances.clear()
sim = 1000
results1 = Parallel(n_jobs=8)(delayed(experiment)(n=2000, p=10, linear=True, constant=True,
OLS_adj=True, OLS_int=True, CUPED=True, DID=True, MLRATE=True)\
for _ in tqdm(range(sim)) )
results_df1 = pd.DataFrame(results1)
plot_experiment(results_df1)
图 1:协变量的线性效应与恒定处理效应
results2 = Parallel(n_jobs=8)(delayed(experiment)(n=2000, p=10,linear=True, constant=False,
OLS_adj=True, OLS_int=True, CUPED=True, DID=True, MLRATE=True)\
for _ in tqdm(range(sim)) )
results_df2 = pd.DataFrame(results2)
plot_experiment(results_df2, False)
图 2:协变量的线性效应与变化的处理效应
results3 = Parallel(n_jobs=8)(delayed(experiment)(n=2000, p=10,linear=False, constant=True,
OLS_adj=True, OLS_int=True, CUPED=True, DID=True, MLRATE=True)\
for _ in tqdm(range(sim)) )
results_df3 = pd.DataFrame(results3)
plot_experiment(results_df3)
图 3:协变量的非线性效应与恒定的处理效应
results4 = Parallel(n_jobs=8)(delayed(experiment)(n=2000, p=10,linear=False, constant=False,
OLS_adj=True, OLS_int=True, CUPED=True, DID=True, MLRATE=True)\
for _ in tqdm(range(sim)) )
results_df4 = pd.DataFrame(results4)
plot_experiment(results_df4, False)
图 4:协变量的非线性效应与变化的处理效应
结论
这标志着关于实验中方差减少的两篇文章系列的结束。在第一部分中,我们对方差减少在实验中的重要性建立了直觉,而在第二部分中我们分析了不同的协变量调整方法。我们发现所有这些方法都比简单的 DIM 估计器得到更紧的置信区间。因此,协变量调整应该成为分析实验时的标准实践。至于应用哪种方法,我们已经看到基于机器学习的估计器表现最佳,尤其是当 DGP 中存在复杂的相互作用时,这无疑更真实地反映了现实世界。
代码
原始笔记本可以在我的 github 仓库中找到。
感谢阅读!希望你觉得这值得你花时间。
我致力于为从事因果推断及市场数据科学的实践者撰写高质量且有用的文章。
如果你对这些领域感兴趣,可以考虑关注我,也欢迎分享你的评论/建议。
参考文献
[1] W. Lin,关于回归调整实验数据的无神论笔记:重新审视 Freedman 的批评。 (2013),《应用统计年鉴》。
[2] A. Deng, Y. Xu, R. Kohavi, T. Walker,通过利用实验前数据提高在线对照实验的敏感性 (2013),WSDM。
[3] G. Yongyi, C. Dominic, M. Konutgan, W. Li, C. Schoener, M. Goldman, 在线实验中的方差减少的机器学习(2021),NeurIPS。
重要性采样的方差减少
原文:
towardsdatascience.com/variance-reduction-with-importance-sampling-4e5ca4b1c5a7
数学解释和 Python 实现
·发布于Towards Data Science ·阅读时长 6 分钟 ·2023 年 1 月 23 日
--
图片来源:Edge2Edge Media 在Unsplash
在之前的文章中,我介绍了不同的数值采样技术,其中之一是重要性采样。在那篇文章中,我们使用了这一技术来允许从复杂的分布中采样,而从这些分布中采样在其他情况下是不切实际的。然而,重要性采样还有另一个常见用途,即方差降低:通过选择合适的建议分布,我们可以降低估计量的方差——这就是我们在这里要讨论的内容。
重要性采样回顾
假设我们不仅仅想计算随机变量 X
的期望 E[X]
,而是该变量函数的期望 f[X]
。在连续情况下,这可以计算为:
我们可以通过数值近似,即蒙特卡洛方法,来近似这个期望,通过从分布 p
中采样 n
个随机值,然后计算样本均值:
重要性采样背后的想法是使用一个简单的重新表述技巧,将期望写为
— 给出 f(x)p(x)/q(x)
在分布 q
下的期望!通过这种方式,我们可以通过从 q
中采样来计算样本均值:
方差减少
标准蒙特卡洛估计量的方差为:
改编后的重要性采样估计器的方差为:
因此,作为第一步,我们确实观察到方差的差异,这意味着我们很有可能找到减少方差的方法。实际上,选择q
可以将方差减少到 0:
(将这个术语插入上述方程中,并设想f(x)p(x)
互相抵消——留下Var[E[f(X)]]=0
。)
自然地,我们不知道E[f(X)]
,毕竟我们进行采样的原因是为了找到f
的期望值。
然而,我们可以将E[f(X)]
视为某种归一化常数,并从中获得一个重要的见解:我们应该构造q
,使其在f(x)p(x)
较高的地方具有高密度。有了这些,我们进入一个实际例子并应用这一学习。
实际例子
为了演示,我们选择一个尖锐的函数f
和一个与其重叠不太好的概率分布p
。为简化起见,我们将两者都设置为正态分布,例如f = N(5, 1)
和p = N(9, 2)
:
图片由作者提供
我希望选择两个正态分布不会让读者感到困惑,因此让我们重申一下我们要做的事情:我们想计算E[f(X)]
,其中X
是遵循分布p
的随机变量——也就是说,我们想计算p
下f
的均值。注意,这个均值不是通常与正态分布相关的均值(这是 x 轴上的一个值,即分布的众数),而是我们现在关注的p
下 y 值的均值:在这个例子中是~0.36——这是一个鲜为人知且使用较少的值。
为了在数值上近似这一点,如上所述,我们现在将从分布p
中抽样值x
,并计算f(x)
的经验均值。
从直觉上可以理解为什么从这个分布中采样是不好的,希望前面的部分能够对此有更好的解释:对于从p
中抽样的大多数值,f
将接近 0——但对于一些抽样值,f
将非常大——因此我们得到一个大的方差。
因此,按照上述介绍的思路,我们现在提出了一个新的分布q = N(5.8, 1)
,它满足推导出的标准,即在f(x)p(x)
较高的区域其密度也较高:
图片由作者提供
注意,找到这个函数并不是很简单,现实世界中还有许多更复杂的情况。我们必须尽量满足标准,同时还要注意满足重要性采样要求,即p
覆盖q
等。对于这个例子,我实际上绘制了p(x)f(x)
,然后选择了最接近的q
。
Python 实现
让我们用 Python 进行编码。首先,我们引入必要的函数和分布,为了方便使用 functools.partials
来获得一个表示具有固定均值/标准差的正态分布的函数:
MEAN_F, STD_F = 5, 1
MEAN_P, STD_P = 9, 2
MEAN_Q, STD_Q = 5.8, 1
def normal_dist(
mean: float, standard_deviation: float, x: np.ndarray
) -> np.ndarray:
return (
1
/ (standard_deviation * np.sqrt(2 * np.pi))
* np.exp(-0.5 * ((x - mean) / standard_deviation) ** 2)
)
f = partial(normal_dist, MEAN_F, STD_F)
p = partial(normal_dist, MEAN_P, STD_P)
q = partial(normal_dist, MEAN_Q, STD_Q)
然后,我们生成上述图表以供参考:
x = np.linspace(0, 15, 100)
plt.plot(x, f(x), "b-", label="f")
plt.plot(x, p(x), "r-", label="p")
plt.plot(x, q(x), "y-", label="q")
plt.legend()
plt.show()
最后,我们进入(重要性)采样部分。首先,我们计算 E[f(X)]
的直接蒙特卡罗估计器。我们从 p
中生成随机样本 x
,并计算 f(x)
的均值:
x_p = np.random.normal(loc=MEAN_P, scale=STD_P, size=NUM_SAMPLES)
y_p = f(x_p)
现在我们应用重要性采样,即从 q
中抽样并通过重要性权重进行修正:
x_q = np.random.normal(loc=MEAN_Q, scale=STD_Q, size=NUM_SAMPLES)
y_q = f(x_q) * p(x_q) / q(x_q)
综合起来:
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
NUM_SAMPLES = 1000000
MEAN_F, STD_F = 5, 1
MEAN_P, STD_P = 9, 2
MEAN_Q, STD_Q = 5.8, 1
def normal_dist(
mean: float, standard_deviation: float, x: np.ndarray
) -> np.ndarray:
return (
1
/ (standard_deviation * np.sqrt(2 * np.pi))
* np.exp(-0.5 * ((x - mean) / standard_deviation) ** 2)
)
f = partial(normal_dist, MEAN_F, STD_F)
p = partial(normal_dist, MEAN_P, STD_P)
q = partial(normal_dist, MEAN_Q, STD_Q)
x = np.linspace(0, 15, 100)
plt.plot(x, f(x), "b-", label="f")
plt.plot(x, p(x), "r-", label="p")
plt.plot(x, q(x), "y-", label="q")
plt.legend()
plt.show()
x_p = np.random.normal(loc=MEAN_P, scale=STD_P, size=NUM_SAMPLES)
y_p = f(x_p)
x_q = np.random.normal(loc=MEAN_Q, scale=STD_Q, size=NUM_SAMPLES)
y_q = f(x_q) * p(x_q) / q(x_q)
print(
f"Original mean / variance: {np.mean(y_p):.6f} / {np.var(y_p):.6f}"
)
print(
f"Importance sampling mean / variance: {np.mean(y_q):.6f} / {np.var(y_q):.6f}"
)
输出结果将类似于:
原始均值 / 方差: 0.036139 / 0.007696 重要性采样均值 / 方差: 0.036015 / 0.000027
因此,我们仍然获得了正确的均值,但方差减少了约 100 倍!
结论
重要性采样是一个巧妙的重构技巧,使我们能够通过从不同的提案分布中采样来计算期望值和其他矩。它不仅允许从复杂的、否则难以采样的分布中采样,还改变了结果估计器的方差。在这篇文章中,我们展示了如何利用这一点来减少方差。特别地,我们证明并展示了在 p(x)f(x)
(原始分布和问题函数的乘积)高的区域选择具有高概率的提案分布可以获得最佳结果。
感谢阅读!
本文是关于采样系列的第二部分。你可以在这里找到其他部分:
-
第一部分: 采样方法介绍
-
第三部分: 马尔可夫链蒙特卡罗(MCMC)方法介绍
使用 Gumbel Softmax 的离散分布变分自编码器(VAE)
理论与 PyTorch 实现
·发表于 Towards Data Science ·17 分钟阅读·2023 年 8 月 9 日
--
unsplash.com/photos/sbVu5zitZt0
由于这篇文章将会很详尽,我将为读者提供一个索引以便更好地导航:
-
介绍
-
变分自编码器(VAEs)简要介绍
-
Kullback–Leibler (KL) 散度
-
VAE 损失
-
重参数化技巧
-
从分类分布中采样与 Gumbel-Max 技巧
-
实现
介绍
生成模型如今变得非常流行,这要归功于它们能够通过学习和捕捉训练数据的基础概率分布来生成具有固有变异性的全新样本。
我们可以识别出两大主要的生成模型家族:生成对抗网络(GANs)、变分自编码器(VAEs)和扩散模型。在这篇文章中,我们将深入探讨 VAEs,特别是关注具有分类潜在空间的 VAEs。
变分自编码器(VAEs)简要介绍
变分自编码器(VAEs)是一种用于无监督机器学习的深度神经网络。它们属于自编码器家族,自编码器是设计用于通过压缩然后重构数据来学习高效数据表示的神经网络。
VAEs 的主要思想是学习潜在空间中的数据概率分布。这个潜在空间是输入数据的低维表示,其中每个点对应于一个特定的数据样本。例如,给定一个维度为 3 的潜在空间中的向量,我们可以认为第一个维度表示眼睛的形状,第二个维度表示胡须的多少,第三个维度表示生成的人的脸上的肤色。
VAEs 具有两个关键组件:
-
编码器:编码器网络接收输入数据,并将其映射到潜在空间中一个概率分布的参数(通常是高斯分布)。编码器不是直接在潜在空间中产生一个单一的点,而是输出分布的均值和方差。
输出一个分布而不是潜在空间中的单个点作为正则化,这样当我们在潜在空间中选择一个随机点时,解码这个数据点后我们总能得到一个有意义的图像。
-
解码器:解码器网络从潜在空间中采样,并将其重建回原始数据空间。它使用类似于编码器的过程但相反,将潜在表示转换回数据空间。
让我们来说明这个过程:
VAE 编码器-解码器图示,图片由作者提供 (1)
其中 x 是输入图像,z 是潜在空间中的一个采样向量,μ 和 σ 是潜在空间参数,其中μ是均值向量,σ是标准差向量。最后,x’ 是从潜在变量重建的图像。
我们希望这个潜在空间具备两个特性:
-
潜在空间中接近的点应输出相似的图片。
-
从潜在空间中采样的任何点都应产生与训练数据相似的东西,即,如果我们训练的是人的面孔,它不应产生任何有 3 只眼睛或 4 只耳朵的面孔。
要实现第一个目标,我们需要让编码器将相似的图片映射到接近的潜在空间参数,然后解码器将它们映射回看起来相似的图片——这通过图像重建损失来实现。为了实现第二个目标,我们需要添加一个正则化项。这个正则化项是编码器返回的参数与均值为 0、方差为 1 的标准高斯分布——N(0,1)之间的 Kullback–Leibler(KL)散度。通过保持潜在空间接近 N(0,1),我们确保编码器不会为每个样本产生相距过远的分布(通过使均值非常不同和方差非常小),这会导致过拟合。如果发生这种情况,在潜在空间中采样一个远离任何训练点的值将无法产生有意义的图像。
Kullback–Leibler (KL) 散度
KL 散度,简称 Kullback-Leibler 散度,是衡量一个概率分布与另一个分布的不同程度的指标。给定两个概率分布 P(X)和 Q(X),其中 X 是随机变量,KL(Q || P)表示从 Q 到 P 的 KL 散度,是一个非负值,表示使用 Q 来近似 P 时信息的丧失程度。它不是对称度量,即 KL(Q || P)通常不同于 KL(P || Q)。连续和离散变量的公式如下:
KL 散度,离散情况 (2)
KL 散度,连续情况 (3)
但是这个公式的直觉是什么?它是如何推导出来的?
假设我们有一个包含 从分布 P(x) 中抽样得到的观察数据 — {x1, x2, …, xn} 的数据集,我们想要比较这些观察数据在真实分布 P(x) 和近似分布 Q(x) 下的生成可能性。在概率分布下观察整个数据集的可能性可以通过每个观察数据的个体概率的乘积来计算:
-
在 P(x) 下的数据似然:L_P = P(x1) * P(x2) * … * P(xn)
-
在 Q(x) 下的数据似然:L_Q = Q(x1) * Q(x2) * … * Q(xn)
通过比较比率 L_P / L_Q,我们可以比较它们的相似度。如果比率接近 1,则近似分布与真实分布相似;而如果这个比率很高,意味着根据近似分布从真实分布中抽样的序列的可能性显著较低,则两个分布不同。显然,这个比率不能小于 1,因为数据是从真实分布 P(x) 中抽样的。
对这个比率两边取对数,我们得到:
比率 L_P / L_Q 的对数 (4)
现在,如果我们对数据集上真实分布 P(x) 的对数进行期望计算,我们得到期望对数似然比:
比率 L_P / L_Q 的期望对数 (5)
这不过是 KL 散度!作为额外的内容,让我们深入了解 KL 散度如何与交叉熵相关联。细心的读者可能已经认识到:
公式中的 Σ P(x) * log(P(x)) 是 P(x) 的熵的负值,而 Σ P(x) * log(Q(x)) 是 P(x) 和 Q(x) 之间的交叉熵。所以,我们有:
KL 散度作为交叉熵和熵之间的差异 (6)
现在,真实数据分布 P(x) 的熵是一个不依赖于近似分布 Q(x) 的常数。因此,最小化期望对数似然比 E[log(L_P / L_Q)] 等同于最小化真实分布 P(x) 和近似分布 Q(x) 之间的交叉熵 H(P, Q)。
VAE 损失
在“变分自编码器(VAEs)简介”部分,我们提供了关于如何优化 VAEs 的一些直觉,并且潜在空间应该满足 2 个属性,以在从潜在空间抽样 任何 随机数据点时生成有意义的图像,这由重构损失和 KL 散度正则化强制执行。在本节中,我们将深入探讨这两个方面的数学。
给定一些从潜在变量 z 生成的训练数据 x = {x1, x2, …, xn},我们的目标是最大化这些数据的似然,以训练我们的变分自编码器模型。数据的似然由以下公式给出:
数据似然 (7)
我们将潜在变量积分出去,因为它是不可观察的。
现在,p(x|z)可以通过解码器网络轻松计算,而p(z)被假定为高斯分布。然而,我们面临一个大问题——在有限的时间内计算这个积分实际上是不可能的,因为我们需要在所有潜在空间上进行积分。因此,我们使用贝叶斯规则以不同的方式计算我们的p(x):
p(x)的贝叶斯规则 (8)
现在,p(z|x)是难以处理的。p(z∣x)的难处理性源于我们需要对每个数据点x的所有可能值z计算p(z∣x)的积分。形式上,这个积分可以表示为:
p(z|x)的贝叶斯规则 (9)
由于这种难处理性,在 VAE 中,我们 resort 使用一个近似分布(在我们情况下是高斯分布)q(z∣x),这更容易处理且计算上可行。这个近似分布是通过编码器网络学习的:
p(z|x)的近似分布 (10)
现在我们已经准备好所有元素,可以用解码器网络计算的p(x|z)来近似p(x),以及由编码器q近似的p(z|x)。对方程 9 的两边应用对数并进行一些代数变换,我们得到:
p(x)的对数概率 (11)
现在,对两边应用期望算子:
p(x)的对数概率的期望 (12)
这等于:
p(x)的对数概率的期望 — 不同形式 (13)
在上图中,第一个项是重建项,即我们的模型从潜在变量重建训练数据x的效果。第二个项是z的先验——N(0,1)与来自编码器的样本之间的 KL 散度。第三个项是编码器和解码器后验之间的 KL 散度,这是难以处理的。如果我们忽略最后一项,我们得到数据似然的下界,因为 KL 总是≥0,这称为证据下界(ELBO)。因此,我们最终得到:
证据下界(ELBO) (14)
因此,在训练 VAE 时,我们尝试最大化 ELBO,这等同于最大化我们数据的概率。
重参数化技巧
让我们先了解重参数化技巧,因为理解这一点对于理解 Gumbel-Softmax 使用类似的东西至关重要。
正如我们在第一部分中所看到的,编码器输出正态分布的均值和方差参数,然后我们从具有这些参数的正态变量中抽取一个随机向量,并通过解码器传递这个潜在向量以重建初始图像。为了最小化重建损失并使网络学习,我们需要从这个重建损失中进行反向传播,但存在一个问题——潜在变量 Z,即从高斯中抽样的变量,是不可微分的。想一想——你如何对一个样本进行求导?因此,我们不能使用反向传播。解决方案是使用重新参数化技巧。
为了使随机变量Z可微分,我们需要将其分为一个可微分的确定性部分和一个不可微分的随机部分。任何来自随机正态分布的样本 Z ~ N(μ, σ) 可以写成:
Z = μ + N(0,1) = σ = μ + ε σ 其中 ε ~ N(0,1)
所以μ和σ是确定的,我们可以对其进行反向传播,而ε是随机部分,我们不能对其进行反向传播。因此,我们可以对μ和σ进行求导:
随机变量 Z 对均值和标准差的导数 (15)
…以学习潜在空间中正态分布的均值和标准差。
从分类分布中进行抽样 & Gumbel-Max 技巧
如果我们希望将潜在空间建模为分类分布,而不是具有连续潜在分布的情况,会怎么样?你会问,为什么有人要这样做?好吧,离散表示在许多情况下是有用的,例如在强化学习问题中采样离散动作、生成离散文本标记等等。
那么我们如何从分类分布中进行抽样并学习其参数,使其可微分?我们可以重复使用重新参数化技巧的想法,将其调整到这个问题上!
首先,让我们尝试理解如何从分类分布中进行抽样。假设我们有以下概率向量:
theta = [0.05, 0.25, 0.7] 代表以下类别——[红色, 蓝色, 白色]。为了进行抽样,我们需要一个随机源,通常使用 0 到 1 之间的均匀分布。请记住,在均匀分布中,0 到 1 之间的抽样是同样可能的。因此,我们从均匀分布中抽样,并将其转换为分类分布,我们可以根据我们的概率theta进行切片。我们定义一个累计和向量 theta_cum = [0.05, 0.3, 1],它代表下面的图。
给定来自均匀分布的样本,例如 0.31,我们选择累计概率超过生成随机数的类别。
argmax(theta_cum ≥ U(0,1)) = argmax([False, True, True]) 这对应于示例中的“蓝色”,因为 argmax 选择第一个对应于True的索引。
累积概率分类分布,图像作者提供 (16)
现在,我们可以用另一种方式从分类分布中采样 — 不是使用均匀分布,而是使用定义为的 Gumbel 分布:
Gumbel 分布 (17)
假设我们有一个 (log) 概率向量,如之前所示
theta = [log(alpha1), log(alpha2), log(alpha3)], 这些是我们希望通过反向传播估计的参数。为了使用反向传播,我们复现了重新参数化技巧部分中所做的 — 拥有一个确定性部分,即作为我们参数的类别对数概率和一个由随机标准 Gumbel 噪声给出的随机部分。
要使用 Gumbel 从分类分布中采样,我们可以按以下步骤操作:
argmax([log(alpha1) + G1, log(alpha2) + G2, log(alpha3) + G3])
其中 theta 是确定性部分,Gumbel 噪声是随机部分。我们可以通过这两个部分的和进行传播。然而,argmax 不是一个可微的 函数。因此,我们用具有温度 τ 的 Softmax 替代它,以使一切可微。于是,类别的概率 yi 变成:
使用 Gumbel-Softmax 分布采样 (18)
低 τ 会使 Softmax 更接近 argmax,而较高的 τ 会使其更接近均匀分布。实际上,当我们将温度降低到如 1e-05 这样的非常低的值时,概率几乎像选择 argmax,即我们基本上是从离散分布中采样。
实现
我们以 MNIST 数据集为例 (许可:公共领域 / 来源:yann.lecun.com/exdb/mnist/
,也可以在 torchvision.datasets 中找到),目标是学习一个生成模型,假设图像是二值的。潜在变量大小假设为 20,包含 10 个分类变量(10 个数字)。先验是一个包含 10 个类别的分类分布,均匀先验概率为 1/10。
1. 首先实现 Gumbel softmax 函数 gumbel_softmax
。如前所述,这由每个类别的对数概率(logits)之和加上 Gumbel 分布给出的随机性构成。在 3 个类别的情况下,我们有:
softmax([log(alpha1) + G1, log(alpha2) + G2, log(alpha3) + G3]) 使用 Softmax 替代 argmax 以实现可微性。
def sample_gumbel(shape, eps=1e-20):
# sample from a uniform distribution
U = torch.rand(shape)
if is_cuda:
U = U.cuda()
return -torch.log(-torch.log(U + eps) + eps)
def gumbel_softmax_sample(logits, temperature):
y = logits + sample_gumbel(logits.size())
return F.softmax(y / temperature, dim=-1)
def gumbel_softmax(logits, temperature, hard=False):
y = gumbel_softmax_sample(logits, temperature)
if not hard:
return y.view(-1, latent_dim * categorical_dim)
shape = y.size()
_, ind = y.max(dim=-1)
y_hard = torch.zeros_like(y).view(-1, shape[-1])
y_hard.scatter_(1, ind.view(-1, 1), 1)
y_hard = y_hard.view(*shape)
# skip the gradient of y_hard
y_hard = (y_hard - y).detach() + y
return y_hard.view(-1, latent_dim * categorical_dim)
附注:
我们可以注意到 gambel_softmax
函数中的一个小技巧——如果参数 hard
为 True,我们使用 argmax 而不是 softmax。在评估时,我们通常使用 argmax(这是我们在 model.sample_img
中所做的),而在训练期间,我们使用 softmax,因为 argmax 操作是不可微分的。然而,这不是必需的,我们也可以在训练期间使用 argmax,通过 跳过 y_hard
在 gumbel_softmax
函数中的梯度并对 softmax y
进行微分。一个简短的示例会有所阐明:
skip_d = False
a = torch.Tensor([1])
a.requires_grad = True
b = torch.Tensor([2])
b.requires_grad = True
c = 2 * (a + b)
if skip_d:
d = c ** 2
d = (d - c).detach() + c
else:
d = c ** 2
f = d * 4
f.retain_grad()
d.retain_grad()
c.retain_grad()
loss = f * 3
loss.backward()
print(loss)
print(a.grad, b.grad, c.grad, d.grad, f.grad)
# Loss value: tensor([432.])
# (tensor([288.]), tensor([288.]), tensor([144.]), tensor([12.]), tensor([3.]))
# Running the same with skip_d = True we get:
# tensor([432.])
# (tensor([24.]), tensor([24.]), tensor([12.]), tensor([12.]), tensor([3.]))
当 skip_d = False 时,我们有:
dl/df = 3
dl/dd = dl/df * df/dd = (3) * (4) = 12
dl/dc = dl/df * df/dd * dd/dc = (3) * (4) * (2 * c) = 144
dl/da = dl/df * df/dd * dd/dc * dc/da = (3) * (4) * (2 * c) * (2) = 288
dl/db = dl/df * df/dd * dd/dc * dc/db = (3) * (4) * (2 * c) * (2) = 288
当 skip_d = True: dl/df = 3
dl/dd = dl/df * df/dd = (3) * (4) = 12
dl/dc = dl/df * df/dd = (3) * (4) = 12
从现在开始,我们跳过 dd/dc,即我们将梯度 dl/dc = dl/dd。
dl/da = dl/df * df/dd * dc/da = (3) * (4) * (2) = 24
dl/db = dl/df * df/dd * dc/db = (3) * (4) * (2) = 24
在上述示例中,loss 的值是相同的,但梯度却不同。在我们的模型中,值不会相同,因为当 hard=True
时我们将 latent_z
设置为 y_hard
,而当 hard=False
时设置为 softmax y
,但 y
的反向传播梯度在两种情况下都是相同的。
2. 现在让我们定义我们的 VAE 模型。编码器将图像映射到分类变量的对数概率,由 3 个线性层和 ReLU 非线性层组成。解码器将潜在空间向量映射回图像空间,由 3 个线性层、2 个 ReLU 非线性层和最后一个 sigmoid 非线性层组成。Sigmoid 直接输出概率,这很方便,因为我们将 MNIST 图像(每个像素)建模为 Bernoulli 变量。
class VAE_model(nn.Module):
def __init__(self):
super(VAE_model, self).__init__()
self.fc1 = nn.Linear(784, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, latent_dim * categorical_dim)
self.fc4 = nn.Linear(latent_dim * categorical_dim, 256)
self.fc5 = nn.Linear(256, 512)
self.fc6 = nn.Linear(512, 784)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def encode(self, x):
h1 = self.relu(self.fc1(x))
h2 = self.relu(self.fc2(h1))
return self.relu(self.fc3(h2))
def decode(self, z):
h4 = self.relu(self.fc4(z))
h5 = self.relu(self.fc5(h4))
return self.sigmoid(self.fc6(h5))
在前向函数中,我们首先通过编码器计算 logits,使用 Gumbel Softmax:
logits_z = self.encode(data.view(-1,
logits_z = logits_z.view(-1, latent_dim, categorical_dim)
latent_z = gumbel_softmax(logits_z, temp)
latent_z = latent_z.view(-1, latent_dim * categorical_dim)
然后,我们解码它们,给出每个像素的 Bernoulli 概率。我们可以从中采样,以生成具有概率参数的图像:
probs_x = self.decode(latent_z)
# we assumed distribution of the data is Bernoulli
dist_x = torch.distributions.Bernoulli(probs=probs_x, validate_args=False)
接下来,让我们计算 ELBO 损失
EBLO 损失 (19)
对于第一项(重建损失),我们需要计算在我们估计的模型下真实数据的对数似然,这告诉我们真实图像在我们模型下的可能性。我们之前从解码器计算了 dist_x
,这就是我们用来估计该概率的:
# reconstruction loss - log probabilities of the data
rec_loss = dist_x.log_prob(data.view(-1, 784)).sum(dim=-1)
然后,我们计算由 KL 散度给出的正则化,该散度是由 10 类别的分类分布与均匀先验概率 1/10 之间的差异和潜在空间的分类参数给出的:
# KL divergence loss
KL = (posterior_distrib.probs * (logits_z_log - prior_distrib.probs.log())).view(-1, latent_dim * categorical_dim).sum(dim=-1)
包括训练函数和绘图工具在内的完整代码如下:
torch.manual_seed(0)
batch_size = 100
temperature = 1.0
seed = 0
log_interval = 10
hard = False
is_cuda = torch.cuda.is_available()
torch.manual_seed(seed)
if is_cuda:
torch.cuda.manual_seed(seed)
kwargs = {'num_workers': 1, 'pin_memory': True} if is_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data/MNIST', train=True, download=True,
transform=transforms.ToTensor()),
batch_size=batch_size, shuffle=True, **kwargs)
def sample_gumbel(shape, eps=1e-20):
# sample from a uniform distribution
U = torch.rand(shape)
if is_cuda:
U = U.cuda()
return -torch.log(-torch.log(U + eps) + eps)
def gumbel_softmax_sample(logits, temperature):
y = logits + sample_gumbel(logits.size())
return F.softmax(y / temperature, dim=-1)
def gumbel_softmax(logits, temperature, hard=False):
y = gumbel_softmax_sample(logits, temperature)
if not hard:
return y.view(-1, latent_dim * categorical_dim)
shape = y.size()
_, ind = y.max(dim=-1)
y_hard = torch.zeros_like(y).view(-1, shape[-1])
y_hard.scatter_(1, ind.view(-1, 1), 1)
y_hard = y_hard.view(*shape)
# skip the gradient of y_hard
y_hard = (y_hard - y).detach() + y
return y_hard.view(-1, latent_dim * categorical_dim)
class VAE_model(nn.Module):
def __init__(self):
super(VAE_model, self).__init__()
self.fc1 = nn.Linear(784, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, latent_dim * categorical_dim)
self.fc4 = nn.Linear(latent_dim * categorical_dim, 256)
self.fc5 = nn.Linear(256, 512)
self.fc6 = nn.Linear(512, 784)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def sample_img(self, img, temp, random=True):
# evaluation
with torch.no_grad():
logits_z = self.encode(img.view(-1, 784))
logits_z = logits_z.view(-1, latent_dim, categorical_dim)
if random:
latent_z = gumbel_softmax(logits_z, temp, True)
else:
latent_z = logits_z.view(-1, latent_dim * categorical_dim)
logits_x = self.decode(latent_z)
# probs instead of logits because we have sigmoid activation
# in the decoder
dist_x = torch.distributions.Bernoulli(probs=logits_x)
sampled_img = dist_x.sample()
return sampled_img
def encode(self, x):
h1 = self.relu(self.fc1(x))
h2 = self.relu(self.fc2(h1))
return self.relu(self.fc3(h2))
def decode(self, z):
h4 = self.relu(self.fc4(z))
h5 = self.relu(self.fc5(h4))
return self.sigmoid(self.fc6(h5))
def forward(self, data, temp, hard):
logits_z = self.encode(data.view(-1, 784))
logits_z = logits_z.view(-1, latent_dim, categorical_dim)
# estimated posterior probabiity coefficients
probs_z = F.softmax(logits_z, dim=-1)
posterior_distrib = torch.distributions.Categorical(probs=probs_z)
# categorical prior
probs_prior = torch.ones_like(logits_z)/categorical_dim
prior_distrib = torch.distributions.Categorical(probs=probs_prior)
latent_z = gumbel_softmax(logits_z, temp)
latent_z = latent_z.view(-1, latent_dim * categorical_dim)
probs_x = self.decode(latent_z)
# we assumed distribution of the data is Bernoulli
dist_x = torch.distributions.Bernoulli(probs=probs_x, validate_args=False)
# Losses
# reconstruction loss - log probabilities of the data
rec_loss = dist_x.log_prob(data.view(-1, 784)).sum(dim=-1)
logits_z_log = F.log_softmax(logits_z, dim=-1)
# KL divergence loss
KL = (posterior_distrib.probs * (logits_z_log - prior_distrib.probs.log())).view(-1, latent_dim * categorical_dim).sum(dim=-1)
elbo = rec_loss - KL
loss = -elbo.mean()
return loss
def train(epoch, model, optimizer):
model.train()
train_loss = 0
temp = temperature
for batch_idx, (data, _) in enumerate(train_loader):
if is_cuda:
data = data.cuda()
optimizer.zero_grad()
loss = model(data, temp, hard)
loss.backward()
train_loss += loss.item() * len(data)
optimizer.step()
if batch_idx % 100 == 1:
temp = np.maximum(temp * np.exp(-ANNEAL_RATE * batch_idx), temp_min)
if batch_idx % log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100\. * batch_idx / len(train_loader),
loss.item()))
print("Temperature : ", temp)
sampled = model.sample_img(data[0].view(-1, 28*28), temp).view(28, 28).detach().cpu()
fig, axs = plt.subplots(1, 2, figsize=(6,4))
fig.suptitle('Reconstructed vs Real')
axs[0].imshow(sampled.reshape(28,28))
axs[0].axis('off')
axs[1].imshow(data[0].reshape(28,28).detach().cpu())
axs[1].axis('off')
plt.show()
print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset)))
### Train
temp_min = 0.5
ANNEAL_RATE = 0.00003
latent_dim = 20
categorical_dim = 10
my_model = VAE_model()
my_model.to('cuda:0')
optimizer = optim.Adam(my_model.parameters(), lr=1e-3)
for epoch in range(3):
train(epoch, my_model, optimizer)
在训练开始时,我们有较高的损失和糟糕的重建效果:
重建与真实,训练的开始,作者图像(20)
在训练接近尾声时,我们得到了相当好的重建效果和显著降低的损失。显然,我们可以继续训练更长时间,以获得更好的重建效果。
重建与真实,训练的结束,作者图像(21)
结论
在这篇文章中,我们发现 VAE 也可以用分类潜在空间来建模。当我们想在强化学习问题中采样离散动作或生成文本的离散标记时,这非常有用。在尝试对 argmax 操作进行微分以选择分类变量时,我们遇到了一个问题,因为 argmax 是不可微分的,但 thanks to Gumbel Softmax 和重新参数化技巧的启发,解决了这个问题。
[## 使用我的推荐链接加入 Medium — Alexey Kravets
作为 Medium 的会员,你的部分会员费会分配给你阅读的作者,并且你可以完全访问所有故事…
medium.com](https://medium.com/@alexml0123/membership?source=post_page-----b3f749b3417e--------------------------------)
参考资料
[1] jhui.github.io/2017/03/06/Variational-autoencoders/
[2] blog.evjang.com/2016/11/tutorial-categorical-variational.html
[3] www.youtube.com/watch?v=Q3HU2vEhD5Y&list=PL5-TkQAfAZFbzxjBHtzdVCWE0Zbhomg7r&index=19
[4] arxiv.org/pdf/1611.01144.pdf
[5] github.com/shaabhishek/gumbel-softmax-pytorch
变分推断:基础知识
·
关注 发表在 数据科学前沿 ·9 分钟阅读·2023 年 6 月 16 日
--
我们生活在量化的时代。然而,严格的量化比说起来容易做起来难。在生物等复杂系统中,数据的收集可能既困难又昂贵。而在医疗和金融等高风险应用中,考虑不确定性至关重要。变分推断——一种处于人工智能研究前沿的方法——是一种解决这些问题的方式。
本教程将介绍变分推断的基础知识:何时、为何以及如何使用变分推断。
变分推断何时有用?
变分推断在以下三个密切相关的用例中非常有吸引力:
1. 如果你拥有少量数据(即观察值较少),
2. 如果你关心不确定性,
3. 用于生成建模。
我们将在我们的实例中讨论每种用例。
1. 变分推断与少量数据
图 1:变分推断允许你在领域知识与样本信息之间进行权衡。图像由作者提供。
有时候,数据收集是昂贵的。例如,DNA 或 RNA 测量每次观察可能会花费几千欧元。在这种情况下,你可以用领域知识代替额外的样本进行硬编码。变分推断可以帮助你在收集更多样本时系统性地“减少”领域知识,并更多地依赖于数据(见图 1)。
2. 不确定性的变分推断
对于安全关键型应用,如金融和医疗保健,确定性很重要。不确定性可以影响模型的所有方面,最明显的是预测输出。模型的参数(例如权重和偏置)则不那么明显。你可以将参数赋予一个分布,使其变得模糊,而不是通常的数字数组——权重和偏置。变分推断允许你推断出合理值的范围。
3. 用于生成建模的变分推断
生成模型提供了数据生成的完整规范。例如,如何生成猫或狗的图像。通常,有一个潜在表示 z 具有语义意义(例如,z 描述了一只暹罗猫)。通过一系列(非线性)变换和采样步骤,z 被转换为实际的图像 x(例如,暹罗猫的像素值)。变分推断是一种推断和从潜在语义空间 z 进行采样的方法。一个著名的例子是 变分自编码器。
变分推断是什么?
从本质上讲,变分推断是一种贝叶斯方法[1]。从贝叶斯的角度来看,你仍然让机器像往常一样从数据中学习。不同的是,你给模型一个提示(先验),并允许解(后验)变得更加模糊。更具体地说,假设你有一个训练集 X = [x₁, x₂,..,x**ₘ]ᵗ,共有 m 个样本。我们使用贝叶斯定理:
p(Θ|X) = p(X|Θ)p(Θ) /p(X),
推断一个范围——一个分布——的解决方案Θ。这与传统的机器学习方法形成对比,后者通过最小化损失 ℒ(Θ, X) = ln p(X|Θ)来寻找一个特定的解决方案Θ。贝叶斯推断的核心在于找到一种方法来确定 p(Θ|X): 给定训练集 X 的参数 Θ 的后验分布。一般来说,这是一个困难的问题。实际上,有两种方法用于求解 p(Θ|X*): (i) 使用模拟(马尔科夫链蒙特卡罗)或 (ii) 通过优化。
变分推断涉及选项 (ii)。
证据下界(ELBO)
图 2: 变分推断的示意图。 我们寻找一个接近 p(Θ|X) 的分布 q(Θ)。图像由作者提供。
变分推断的核心思想是寻找一个分布 q(Θ),作为 p(Θ|X) 的替代(代理)。然后我们尝试通过改变 Φ 的值,使 q[Θ|Φ(X)] 看起来类似于 p(Θ|X)(见图 2)。这通过最大化证据下界(ELBO)来完成:
ℒ(Φ) = E[ln p(X,Θ) — ln q(Θ|Φ)],
其中期望 E[·] 是对 q(Θ|Φ) 进行的。 (注意 Φ 隐式依赖于数据集 X,但为了方便书写,我们将忽略这一显式依赖。)
对 ℒ 进行基于梯度的优化时,乍一看,我们必须在对 Φ 求导时小心,因为 E[·] 对 q(Θ|Φ) 的依赖。幸运的是,像 JAX 这样的自动梯度包支持重参数化技巧 [2],允许你直接从随机样本(例如伽马分布的样本)中进行求导,而无需依赖高方差的黑箱变分方法 [3]。简而言之:使用一批 [Θ₁, Θ₂,..] ~ q(Θ|Φ) 来估计 ∇ℒ(Φ),然后让你的自动梯度包处理细节。
从头开始进行变分推断
图 3: 来自 scikit-learn 的手写“零”的示例图像。图像由作者提供。
为了巩固我们的理解,让我们从头开始使用 JAX 实现变分推断。在这个例子中,你将对来自 scikit-learn 的手写数字进行生成模型训练。你可以按照 Colab notebook 进行操作。
为了简单起见,我们将只分析数字“零”。
from sklearn import datasets
digits = datasets.load_digits()
is_zero = digits.target == 0
X_train = digits.images[is_zero]
# Flatten image grid to a vector.
n_pixels = 64 # 8-by-8.
X_train = X_train.reshape((-1, n_pixels))
每张图像是一个 8x8 的离散像素值数组,范围从 0 到 16。由于像素是计数数据,我们使用 泊松分布 和伽马 先验 来对像素 x 进行建模,其中 Θ 是速率参数。速率 Θ 决定了像素的平均强度。因此,联合分布 为:
p(x,Θ) = 泊松(x|Θ)伽马(Θ|a, b),
其中 a 和 b 是 伽马分布 的形状参数和速率参数。
图 4: 使用数字“零”的领域知识作为先验。图像由作者提供。
先验——在这种情况下是 Gamma(Θ|a, b)——是你注入领域知识的地方(用例 1)。例如,你可能对“平均”的数字零是什么样子有一些想法(见图 4)。你可以使用这些 a priori 信息来指导你选择 a 和 b。为了使用图 4 作为先验信息——我们称之为 x₀——并将其重要性作为两个例子来加权,然后设置 a = 2x₀; b = 2。
用 Python 写出来的代码如下:
import jax.numpy as jnp
import jax.scipy as jsp
# Hyperparameters of the model.
a = 2\. * x_domain_knowledge
b = 2.
def log_joint(θ):
log_likelihood = jnp.sum(jsp.stats.gamma.logpdf(θ, a, scale=1./b))
log_likelihood += jnp.sum(jsp.stats.poisson.logpmf(X_train, θ))
return log_likelihood
请注意,我们使用了 JAX 实现的 numpy 和 scipy,以便我们可以进行求导。
接下来,我们需要选择一个替代分布 q(Θ|Φ)。提醒一下,我们的目标是改变 Φ 使得替代分布 q(Θ|Φ) 匹配 p(Θ|X)。因此,q(Θ) 的选择决定了近似的水平(我们在上下文允许的地方省略对 Φ 的依赖)。为了说明问题,我们选择一个由 gamma 分布组成的变分分布:
q(Θ|Φ) = Gamma(Θ|α,β),
其中我们使用了简写 Φ = {α,β}。
接下来,为了实现证据下界 ℒ(Φ) = E[ln p(X,Θ) — ln q(Θ|Φ)],首先写下期望括号内的项:
@partial(vmap, in_axes=(0, None, None))
def evidence_lower_bound(θ_i, alpha, inv_beta):
elbo = log_joint(θ_i) - jnp.sum(jsp.stats.gamma.logpdf(θ_i, alpha, scale=inv_beta))
return elbo
在这里,我们使用了 JAX 的 vmap 来矢量化函数,以便我们可以在批量 [Θ₁, Θ₂,..,Θ₁₂₈]ᵗ 上运行它。
为了完成 ℒ(Φ) 的实现,我们对变分分布 Θᵢ ~ q(Θ) 的实现进行平均。
def loss(Φ: dict, key):
"""Stochastic estimate of evidence lower bound."""
alpha = jnp.exp(Φ['log_alpha'])
inv_beta = jnp.exp(-Φ['log_beta'])
# Sample a batch from variational distribution q.
batch_size = 128
batch_shape = [batch_size, n_pixels]
θ_samples = random.gamma(key, alpha , shape=batch_shape) * inv_beta
# Compute Monte Carlo estimate of evidence lower bound.
elbo_loss = jnp.mean(evidence_lower_bound(θ_samples, alpha, inv_beta))
# Turn elbo into a loss.
return -elbo_loss
关于这些参数,有几个要点需要注意:
-
我们将 Φ 打包为一个字典(或者技术上说是一个 pytree),其中包含 ln(α), 和 ln(β)。这个技巧确保了 α>0 和 β>0——这是 gamma 分布施加的一个要求——在优化过程中。
-
loss 是 ELBO 的随机估计。在 JAX 中,我们每次采样时都需要一个新的伪随机数生成器(PRNG)key。在这种情况下,我们使用 key 采样 [Θ₁, Θ₂,..,Θ₁₂₈]ᵗ。
这完成了模型 p(x,Θ)、变分分布 q(Θ) 和损失 ℒ(Φ) 的规范说明。
模型训练
接下来,我们通过改变 Φ = {α,β} 来最小化损失 ℒ(Φ),使 q(Θ|Φ) 匹配后验 p(Θ|X)。怎么做?使用传统的梯度下降法!为了方便,我们使用了 Optax 中的 Adam 优化器,并用先验 α = a 和 β = b 初始化参数 [记住,先验是 Gamma(Θ|a, b) 并且编码了我们的领域知识]。
# Initialise parameters using prior.
Φ = {
'log_alpha': jnp.log(a),
'log_beta': jnp.full(fill_value=jnp.log(b), shape=[n_pixels]),
}
loss_val_grad = jit(jax.value_and_grad(loss))
optimiser = optax.adam(learning_rate=0.2)
opt_state = optimiser.init(Φ)
在这里,我们使用value_and_grad同时评估 ELBO 及其导数。这对于监控收敛非常方便!然后我们使用jit) 即时编译结果函数,使其更加高效。
最终,我们将训练模型 5000 步。由于损失是随机的,对于每次评估,我们需要为其提供一个伪随机数生成器(PRNG)密钥。我们通过分配 5000 个密钥来实现这一点,使用随机拆分。
n_iter = 5_000
keys = random.split(random.PRNGKey(42), num=n_iter)
for i, key in enumerate(keys):
elbo, grads = loss_val_grad(Φ, key)
updates, opt_state = optimiser.update(grads, opt_state)
Φ = optax.apply_updates(Φ, updates)
恭喜!你已经成功地使用变分推断训练了第一个模型!
你可以通过这里在 Colab 上访问包含完整代码的笔记本。
结果
图 5:变分分布与精确后验分布的比较。图片来源:作者。
让我们退一步欣赏一下我们所构建的(图 5)。对于每个像素,替代的 q(Θ) 描述了关于平均像素强度的 uncertainty(用例 2)。特别地,我们选择的 q(Θ) 捕捉了两个互补的元素:
-
典型的像素强度。
-
图像间强度的变化程度(变异性)。
结果表明,我们选择的联合分布 p(x,Θ) 有一个精确的解:
p(Θ|X) = Gamma(Θ|a + Σxᵢ, m + b),
其中 m 是训练集中样本的数量 X。在这里,我们可以明确看到领域知识——以 a 和 b 形式体现——在我们收集更多样本 xᵢ 时如何被调节。
我们可以轻松比较学到的形状 α 和速率 β 与真实值 a + Σxᵢ 和 m + b。在图 5 中,我们比较了两个特定像素的分布——q(Θ|Φ) 与 p(Θ|X) —。结果令人惊叹,完美匹配!
附加:生成合成图像
图 6:使用变分推断生成的合成图像。图片来源:作者。
变分推断非常适合生成建模(用例 3)。有了替代后验 q(Θ),生成新的合成图像是简单的。两步过程是:
- 样本像素强度 Θ ~ q(Θ).
# Extract parameters of q.
alpha = jnp.exp(Φ['log_alpha'])
inv_beta = jnp.exp(-Φ['log_beta'])
# 1) Generate pixel-level intensities for 10 images.
key_θ, key_x = random.split(key)
m_new_images = 10
new_batch_shape = [m_new_images, n_pixels]
θ_samples = random.gamma(key_θ, alpha , shape=new_batch_shape) * inv_beta
- 使用 x ~ 泊松(x|Θ) 采样图像。
# 2) Sample image from intensities.
X_synthetic = random.poisson(key_x, θ_samples)
你可以在图 6 中看到结果。请注意,“零”字符的锐度稍逊于预期。这是我们建模假设的一部分:我们将像素建模为相互独立而非相关。要考虑像素相关性,你可以扩展模型以聚类像素强度:这称为泊松分解[4]。
摘要
在本教程中,我们介绍了变分推断的基础知识,并将其应用于一个玩具示例:学习手写数字零。得益于自动求导,从头实现变分推断只需几行 Python 代码。
变分推断在数据较少的情况下特别强大。我们展示了如何融合和权衡领域知识与数据中的信息。推断的替代分布 q(Θ) 提供了模型参数的“模糊”表示,而不是一个固定值。如果你处于一个不确定性重要的高风险应用中,这种方法是理想的!最后,我们展示了生成模型。只要你能从 q(Θ) 中采样,生成合成样本就很容易。
总结来说,通过利用变分推断的力量,我们可以解决复杂问题,使我们能够做出明智的决策、量化不确定性,并最终释放数据科学的真正潜力。
致谢
我想感谢 Dorien Neijzen 和 Martin Banchero 的校对。
参考文献:
[1] Blei, David M., Alp Kucukelbir, 和 Jon D. McAuliffe. “变分推断:统计学家的综述.” 美国统计协会杂志 112.518 (2017): 859–877.
[2] Figurnov, Mikhail, Shakir Mohamed, 和 Andriy Mnih. “隐式重新参数化梯度.” 《神经信息处理系统进展》 31 (2018).
[3] Ranganath, Rajesh, Sean Gerrish, 和 David Blei. “黑箱变分推断.” 人工智能与统计学. PMLR, 2014.
[4] Gopalan, Prem, Jake M. Hofman, 和 David M. Blei. “可扩展推荐与泊松分解.” arXiv 预印本 arXiv:1311.1704 (2013).
机器学习中的各种部署类型
学习各种部署策略,以成功构建端到端的机器学习管道
·
关注 发表在 Towards Data Science · 6 分钟阅读 · 2023 年 1 月 6 日
--
图片由 Jiawei Zhao 提供,来源于 Unsplash
机器学习有很大的范围和需求,尤其是在最新的自动驾驶行业中,驾驶员通过 AI 的帮助获得辅助。此外,还有其他行业受益,如制药行业,它们开始使用 AI 来开发有趣的产品,这些产品本质上用于预测性医疗。其他行业还包括电子商务,在这些行业中,最相关的产品被推荐给用户,提高了客户购买产品的倾向。
通常,关于机器学习的能力以及它们如何在大量任务中取得最先进结果以实现高精度的讨论很多。然而,最少讨论的话题是如何在实时中进行部署,以及在生产阶段进行持续监控和评估。这是许多在线机器学习和深度学习课程中被忽视的关键因素。一个机器学习模型只有在我们能够将其作为应用提供给最终用户时才算优秀。
查看所有依赖机器学习的不同类型行业会让很多人倾向于这个领域,并使公司取得成功。有大量的在线课程突出机器学习的关键领域,如特征工程、数据准备、模型构建、超参数调整等等。然而,这些课程中缺少一个重要的元素:部署。
照片由Mediamodifier提供,来源于Unsplash。
在这篇文章中,我们将详细了解各种部署策略,这些策略对于希望通过建立AI 能力来给团队留下深刻印象的人来说是至关重要的。现在,让我们详细探讨一下机器学习的部署策略。
批量推理
现在你已经训练并进行了机器学习模型的超参数调整,是时候将最佳模型投入生产。批量推理是一种部署策略,其中机器学习模型以实时方式部署,只接受数据批次。采用这种策略,模型通常能够离线工作或处理周期性任务,例如生成报告或预测。
批量推理可以在我们希望对客户对各种产品的情感进行分类的场景中非常有用。换句话说,他们可能会给出评论,如果我们想了解客户对产品的整体情感,批量推理可以是一种很好的机器学习模型部署策略。
实时推理
这是一种将机器学习模型在接收数据时进行实时运行的部署方式。因此,它们会准备好以某种格式接收数据并提供实时预测,以便可以采取相应的行动。此外,根据项目和团队目标,可能还需要满足实时推理的要求,如低延迟系统或更高的预测准确性。
实时推理的一个经典例子是在进行交易时检测欺诈活动的可能性。机器学习模型最初会使用包含欺诈和非欺诈交易的数据进行训练。在选择出最佳模型后,它会进行实时推理,以便客户能够了解是否发生了欺诈活动。
本地部署
在产品实时部署之前,机器学习团队通常需要进行高安全性措施和数据合规检查。在这种情况下,数据和生产中的机器学习代码的重要性更高。
本地部署涉及在组织设施内的物理设备或服务器上部署机器学习模型。因此,它可以提供对数据和模型的高安全性和控制。
本地部署在预测性维护中可能非常有用,其中使用机器学习模型来确定各种制造设备的故障可能性。我们不依赖互联网提供实时预测,而是使用我们自己的一组服务器和机器,这些设备能够提供机器学习预测所需的计算能力。每当模型预测制造材料存在各种缺陷时,人们可以通过更换这些产品来采取行动。
云部署
这是一种将我们的机器学习服务提供到云中的部署方式。因此,我们利用集群设备的计算资源和内存。因此,我们应能根据用户执行的各种机器学习操作的流量来扩展我们的应用程序。
云部署在我们不确定训练和部署模型所需资源数量时可能会很有用。此外,这些服务仅在用户使用我们的预测时才会初始化。
使用云部署的一个流行示例是预测客户在特定服务集上流失的可能性。如果我们建立了一个服务,让订阅者使用,我们将基于一组预测特征预测客户是否会离开服务。由于我们无法完全了解可能注册服务和同时离开服务的客户总数,因此在云中部署训练模型是一种好的方法,因为这会根据流量需求简化扩展。
移动部署
这是一种在移动设备(如智能手机和平板电脑)上部署机器学习模型的部署方式。这种类型的部署示例包括个人助理、图像识别和语言翻译应用程序。
由于我们在资源受限的环境中部署模型,这与在服务器上部署的环境不同,因此必须在最终形成 ML 产品之前考虑硬件因素。机器学习应用可能非常有用,并且可以具有合理的准确性。然而,如果模型在硬件资源较少的情况下无法生成预测,那么它可能不适合用于移动应用程序。
在尝试实时在移动设备上部署这些产品时,必须考虑低延迟要求、偏差-方差权衡以及其他因素。
边缘部署
这是一种在边缘设备(如物联网 (IoT))上部署机器学习模型的方式。这些设备位于网络的边缘,并依赖于稳定的互联网连接来进行预测。尽管如此,也有一些物联网设备不需要互联网连接,而是拥有能够生成预测的硬件。
在尝试使用这种类型的部署时,必须考虑一些要求。重要的考虑因素包括处理能力、内存容量和连接性。这些因素对用于机器学习的物联网设备的性能有很大影响。
另一个重要的考虑因素是优化模型以适应边缘部署,并考虑在云端运行模型的可行性。这可能涉及减少需要在边缘设备上使用的 ML 模型的复杂性或大小。因此,部署类型将取决于用于提供机器学习能力的设备类型。
结论
总的来说,我们已经探讨了很多机器学习和深度学习模型的部署选项。在许多在线课程中,重点强调了机器学习模型及其内部工作原理。这些课程很好地突出了一些关于这些模型的细微差别,这些差别可以在测试之前进行深入解读。然而,也应该对部署方面给予足够重视,因为由于众多考虑因素,实时部署这些模型可能具有挑战性。
阅读完这篇文章后,希望你对机器学习的常见部署选项有了一些了解。谢谢。
以下是你可以联系我或查看我工作的方式。
GitHub: suhasmaddali (Suhas Maddali ) (github.com)
YouTube: www.youtube.com/channel/UCymdyoyJBC_i7QVfbrIs-4Q
LinkedIn: (1) Suhas Maddali, Northeastern University, Data Science | LinkedIn
Medium: Suhas Maddali — Medium
机器学习中的向量表示
原文:
towardsdatascience.com/vector-representations-for-machine-learning-5047c50aaeff
数据科学家如何将现实世界的对象转换为数值表示以开发机器学习模型
·发表于 Towards Data Science ·8 分钟阅读·2023 年 4 月 25 日
--
机器学习工程师利用世界的数值表示来构建和训练预测算法。
在监督学习的背景下,这些表示方式使计算机能够学习它们与目标变量之间的关系。
假设一个向量仅仅是一个数字列表。
X = [1, 2, 3, 4, 5]
这个列表与目标变量 y
相关。
X = [1, 2, 3, 4, 5]; y = 1
机器学习模型学习特征与目标之间的关系,并输出预测——在这种情况下,是一种将一个类别标识为数字 1 的分类。
在这篇文章中,我将写关于如何使用向量以数值格式表示复杂概念。
其理论基础是机器学习模型无法从以非数值格式提供的观察数据中学习。
文本、图像、声音和其他输入观察数据必须首先转化为适合学习的数值格式。
有多种技术可以将现象转化为向量,这取决于我们处理的数据类型。
-
我们将从介绍One-Hot 编码的概念开始,这是一种用于将单词表示为数值向量的技术。
-
接下来,我们将探讨这种技术的局限性,并介绍嵌入的概念,这是一种可以将单词、图像、声音等表示为比 One-Hot 编码所需的成千上万个类别更小的数值向量的技术。
-
我们还将提到TF-IDF 和词袋模型,它们在文本向量化中至关重要。
我们如何将现象编码为向量?
我们将以文本为例继续讨论。这个例子非常明显,因为正如我们所猜测的那样,机器学习模型不能直接使用文本进行学习。我们需要首先将每个字符或词汇转换为数字。
假设我们想创建词汇的数值表示
-
国王
-
皇后
-
王子
-
公主
对这些词汇进行编码的最简单方法是依次为每个词汇分配一个数字。
作者提供的图片
词汇已正确转换为数值格式,按照映射关系
map = {
"King": 1,
"Queen": 2,
"Prince": 3,
"Princess": 4
}
但有一个问题。如果我们将这些数据提供给任何预测模型,它会为王子和公主分配更高的数学值,使它们比国王和皇后更重要。
显然,这将为模型提供错误的信息,导致模型学习错误的关系。我们需要使我们的数值表示更为精确。
One-Hot Encoding
为了解决上述数值表示问题,可以使用One-Hot Encoding技术。
在这种情况下,每个词将由一个数值向量表示,向量的大小等于需要表示的词总数。该向量的所有值都为零,只有一个值表示特定的词。
例如,对于“国王”、“皇后”、“王子”和“公主”这四个词,每个词将由一个包含四个元素的数组表示,其中值“1”在对应于该词的位置,所有其他位置的值为“0”。
该技术解决了在数值表示中为词汇分配更高数学值的问题,这些词汇在数量上并不比其他词汇更重要。
作者提供的图片
现在,我们的模型为数据集中每个词(在此案例中仅包含 4 个词)具有“平衡”的向量表示。
但是……如果我们的词汇表由成千上万甚至数百万个词汇组成呢?考虑到意大利词典中大约有 270,000 个词汇,应用 One-Hot Encoding 至少会有问题。
执行这种编码所需的计算资源将是相当可观的,最终的表示将是“仅仅”平衡的:没有关于词汇之间关系的信息。
嵌入
为了克服 One-Hot Encoding 的局限性,可以使用称为嵌入的技术。这允许将词汇表示为可控大小的数值向量,相较于 One-Hot Encoding 所需的成千上万的类别。
其想法是创建一个数值表示词汇的方式,这种表示方式考虑了词汇之间的语义关系。
实际上,每个词汇都表示为一个实数向量,其中每个维度表示词义的不同方面。
理解嵌入很简单:相关的单词应该在向量空间中靠得很近,而不相关的单词则应该相距较远。
让我们尝试创建一个图表,捕捉前面提到的一些单词特征。
作者提供的图像
我们可以看到王子和公主的词语彼此多么接近,就像国王和女王一样。
假设性别变量只能取两个值,M 和 F(我们使用 0 和 1),而年龄变量只能取三个值[年轻、中年、老年](我们使用 0、1、2),我们可以看到嵌入如何表示这些关系。
作者提供的图像
这种表示通过使用性别和年龄维度,成功捕捉了个体的贵族地位。
在 X 轴上移动,我们可以观察到这两个贵族如何在表示性别差异的维度上等距(0:男性,1:女性)。而在 Y 轴上移动,我们可以观察到年龄是如何通过嵌入距离 Y 轴的方式来表示的。
这样,词嵌入可以作为机器学习模型的输入,使得复杂的概念能以数字格式更准确地表示。
在这个例子中我们只有两个维度。实际上,神经网络是通过特定任务来训练的,以在多个维度上找到这些表示。
为了更好地理解,像 GPT-3 这样的模型使用了超过 12,000 维。
行业的一个里程碑
嵌入表示不仅可以用于单词,还可以用于表示图像、声音等。
向量表示的使用在当今的机器学习中至关重要。深度学习领域的各种创新和技术源于向量化的概念。
像 GPT-3.5 这样的模型通过结合向量表示、经过充分研究的优化算法和大量计算资源而产生。
理论上,这种方法没有限制。
更多的数据 → 更高质量的向量 → 使用这些向量的模型进行更好的训练。
嵌入的局限性
尽管嵌入是一种非常有用的技术,用于以数字格式表示复杂概念,但它们也有局限性。
特别重要的是要强调,嵌入是从训练数据中构建的,因此可能会受到数据中任何偏见的影响。
如前所述,嵌入的质量取决于训练数据的质量。如果训练数据不代表模型将使用的领域,嵌入可能无法捕捉概念之间所有的语义关系。
此外,嵌入可能需要大量的内存来存储,特别是当维度数量很大时。这对于需要在资源受限的设备上运行的机器学习模型,尤其是移动设备,可能特别有问题。
其他文本表示方式
由于文本是我们周围最常见的数据格式(只需想到互联网上大量的文本数据),一些文本向量化技术是常见且众所周知的。
其中之一是TF-IDF 转换,这是一种文本向量化技术,根据词在文档中的频率及其在语料库中的总体频率,为每个词分配一个权重。
这样,文档中出现频率较高但在语料库中出现频率较低的词汇将比那些在各处频繁出现的词汇具有更高的权重。这种技术在自然语言处理领域用于文本分析中被广泛使用。
我邀请有兴趣的读者通过阅读以下文章来了解更多关于 TF-IDF 模型的知识
## Text Clustering with TF-IDF in Python
文本聚类的简单管道解释。完整示例和代码
TF-IDF 基于词袋模型,该模型将文档表示为一个无序的词汇集合,忽略句子结构和词序。
通过这种方式,词袋模型可以用来表示任何文档为一组数值,其中每个数值代表一个词在文档中的频率。当然,这并不能充分表示词汇之间的关系,这一点由词嵌入(embeddings)提供。
结论
在这篇文章中,我们已经看到如何使用向量以数字格式表示复杂的概念。
对数据科学家来说,以向量化的方式思考是重要的。像这样的疑问
-
我如何将这个刺激转换成一个数字?
-
神经网络如何解释这些数据?
-
我如何改善这个表示?
这些问题至关重要,能够充分回答这些问题的团队将创造出更好的系统。
数据科学家从向量的角度看待世界。
如果你想支持我的内容创作活动,请随意通过下面的推荐链接加入 Medium 的会员计划。我将获得你投资的一部分,你将能够无缝访问 Medium 上丰富的数据科学及更多领域的文章。
## Join Medium with my referral link - Andrea D'Agostino
阅读 Andrea D'Agostino(以及 Medium 上的其他数千位作家的)每个故事。你的会员费直接……
推荐阅读
对于有兴趣的人,这里是我推荐的每个机器学习相关主题的书单。这些书籍在我看来是必读的,并且对我的职业生涯产生了深远的影响。
免责声明:这些是亚马逊的附属链接。我将从亚马逊那里获得少量佣金作为推荐费。您的体验不会改变,您不会被收取额外费用,但这将帮助我扩大业务并制作更多有关人工智能的内容。
-
机器学习简介: 自信的数据技能:掌握数据处理的基本知识,提升你的职业生涯 由 Kirill Eremenko 编著
-
Sklearn / PyTorch: 用 PyTorch 和 Scikit-Learn 进行机器学习:使用 Python 开发机器学习和深度学习模型 由 Sebastian Raschka 编著
-
Sklearn / TensorFlow: 动手实践机器学习:使用 Scikit-Learn、Keras 和 TensorFlow 由 Aurelien Géron 编著
-
自然语言处理: 文本作为数据:机器学习与社会科学的新框架 由 Justin Grimmer 编著
-
数据可视化: 用数据讲故事:商务专业人士的数据可视化指南 由 Cole Knaflic 编著
有用的链接(由我编写)
-
学习如何在 Python 中进行顶级探索性数据分析: Python 中的探索性数据分析 — 分步过程
-
学习 PyTorch 的基础知识: PyTorch 入门:从训练循环到预测
-
学习 TensorFlow 的基础知识: 入门 TensorFlow 2.0 — 深度学习简介
-
使用 TF-IDF 在 Python 中进行文本聚类: 使用 TF-IDF 在 Python 中进行文本聚类
向量搜索并不是你所需的一切
原文:
towardsdatascience.com/vector-search-is-not-all-you-need-ecd0f16ad65e
·发表于数据科学前沿 ·阅读时间 6 分钟·2023 年 9 月 18 日
--
人工智能软件被用来提升本文文本的语法、流畅性和可读性。
引言
检索增强生成(RAG)已经革新了开放领域问答,使系统能够对各种查询生成类似人类的响应。RAG 的核心是一个检索模块,它扫描大量语料库以找到相关的上下文段落,这些段落随后由神经生成模块(通常是预训练的语言模型,如 GPT-3)处理,以制定最终答案。
尽管这种方法非常有效,但也不是没有局限性。
最关键的组成部分之一,即对嵌入段落的向量搜索,存在固有的限制,这可能会妨碍系统以细致的方式进行推理。这在需要跨多个文档进行复杂的多跳推理时尤为明显。
向量搜索指的是使用数据的向量表示来搜索信息。这涉及两个关键步骤:
- 将数据编码为向量
首先,被搜索的数据会被编码为数值向量表示。对于像段落或文档这样的文本数据,这通过诸如 BERT 或 RoBERTa 的嵌入模型来完成。这些模型将文本转换为表示语义意义的连续数字的稠密向量。图像、音频和其他格式也可以使用适当的深度学习模型编码为向量。
2. 使用向量相似性搜索
一旦数据被编码为向量,搜索涉及到找到与搜索查询的向量表示相似的向量。这依赖于距离度量,如余弦相似度,以量化两个向量的接近程度并对结果进行排序。距离最小(相似度最高)的向量被返回为最相关的搜索结果。
向量搜索的主要优势在于能够搜索语义相似性,而不仅仅是字面上的关键词匹配。向量表示捕捉了概念意义,从而能够识别出更相关但语言上不同的结果。这使得搜索质量比传统的关键词匹配更高。
然而,将数据转换为向量并在高维语义空间中进行搜索也存在局限性。平衡向量搜索的权衡是一个活跃的研究领域。
在本文中,我们将剖析向量搜索的局限性,探讨它为何难以捕捉文档之间的多样关系和复杂的相互联系。我们还将深入研究如知识图谱提示等替代技术,这些技术有望克服这些不足之处。
随着我们的生活中越来越多地整合 AI 工具,了解当前 AI 工具的优缺点是至关重要的。本文旨在提供对向量搜索在增强大语言模型推理能力方面的优缺点的全面视角。
问题与答案之间的语义差距
在向量搜索中,输入问题和语料库中的段落都被编码为密集的向量表示。通过找到与问题向量具有最高语义相似性的段落来检索相关的上下文。
然而,问题往往与它们寻求的实际答案存在间接关系。
“法国的首都是什么?”的向量可能不一定与陈述“巴黎是法国人口最多的城市”的段落具有高度相似性。
这种语义差距意味着包含答案的段落可能会被忽视。
嵌入无法捕捉问题与答案之间的推理联系。
段落粒度的重要性
在向量搜索系统中,段落通常由单一的嵌入向量表示。这些段落的粒度可以有所不同。
如果段落非常大,例如整个文档,它可能包含多个概念。段落的某些部分可能相关,而其他部分则不相关。
但由于单个向量代表整个段落,因此无法区分相关部分和无关部分。整个段落可能与问题向量只有微弱的相似性。
相反,使用句子级别的块可以帮助隔离概念。但这会增加索引中的向量数量,增加计算开销。
选择段落大小时存在精度与可处理性之间的固有权衡。
对复杂推理的挑战
有些问题需要综合多个文档中的事实。
例如,“酿酒的最早历史记录是什么?” 可能需要从不同来源拼凑日期。
向量搜索对于这种多跳推理能力不足。每个段落独立地对问题进行评分。没有机制可以共同分析或连接不同结果中的信息。
随着问题变得越来越复杂,简单的相似性搜索达到了极限。系统难以从不同的段落中收集和上下文化事实。
黑箱模型工作原理
在标准向量搜索流程中,如何选择初始检索的段落是不透明的。排名取决于语义相似性模型的内部工作。
这种缺乏透明性使得结果难以解释、验证和改进,也限制了在业务关键应用中的部署。
为了增加监督,排名算法应提供一些可解释性,以说明为什么某些段落被认为是相关的。
建模多样关系
标准向量搜索的核心限制在于其单一关注语义相似性。
然而,现实世界的推理需要对内容之间的多样关系进行建模。
多文档问答(MD-QA)涉及回答需要综合多篇信息的问题…
知识图谱通过明确编码各种连接到互联图结构中来克服这一点。具体而言:
-
主题关系 — 如果段落共享稀有或关键关键词,则这些段落会被链接。这捕捉到讨论主题的相似性。
-
语义关系 — 段落嵌入被比较以连接那些语义上接近的段落,即使它们不共享相同的术语。
这超越了表面层次的主题匹配。
- 结构关系 — 段落与它们出现的特定部分、页面或文档相连。
这编码了上下文层次结构。
- 时间关系 — 讨论时间顺序事件的段落按时间顺序链接在一起。
这代表了事件的流动。
- 实体关系 — 在引用相同现实世界实体的段落之间添加了指代链接。
这允许以实体为中心的推理。
通过结合这些超越语义相似性的多样信号,知识图谱提示(KGP)提供了一个更丰富的推理基础,用于关于互联信息的推理。
结构关系
相比之下,标准向量搜索没有这些结构关系的概念。段落被视为原子,没有任何周围的上下文。
知识图谱对结构关系的建模值得进一步讨论。通过将段落链接到它们出现的特定文档或部分,信息的上下文层次被编码。
这使得可以明确推理某一事实所在的部分、其来源的文档以及发布的网站。
对层次文档结构的编码为确定重要性、有效性和相关性提供了有用的归纳偏差,这在跨段落推理时尤为重要。
时间关系
在孤立的向量搜索中完全不存在这种归纳偏差。向量相似性评分没有考虑时间动态。检索到的段落是断裂的快照,缺乏叙事流。
KGP 中时间关系的明确建模也带来了显著的优势。根据所描述事件的时间顺序排列段落,使得对展开的叙事和时间线进行推理成为可能。
知识图谱通过根据相对时间链接事件来克服这一局限性。这解锁了更丰富的推理能力。
实体关系
在标准的向量搜索中,这些实体链接没有直接建模。有关实体的宝贵知识在段落嵌入中丢失。
知识图谱连接实体引用的能力是一项强大的资产。链接讨论相同现实世界实体、概念或人物的段落,可以围绕这些共享元素进行重点推理。
KGP 保留了这一信号,使得可以以实体为中心探索知识图谱。这在跨文档聚合关于特定实体的事实时提供了结构性优势。
结论
向量搜索基于语义相似性实现了高效的近似匹配。然而,在 RAG 系统的检索步骤中,单独使用时存在明显的局限性。
采用结合向量搜索与基于图谱的知识表示、多步推理模块和透明排名算法的混合方法可以帮助克服这些弱点。
一如既往,没有单一的解决方案——利用多样化的技术工具包是实现现实世界问答系统的强大检索的关键。
此图像是使用 AI 图像生成模型创建的。
来源:
labelbox.com/blog/how-vector-similarity-search-works/
www.elastic.co/what-is/vector-search
medium.com/@PolonioliAI/limitations-of-vectors-and-neural-search-4d81fd64482f
medium.com/vector-database/frustrated-with-new-data-our-vector-database-can-help-e5c430b29be7
www.singlestore.com/blog/why-your-vector-database-should-not-be-a-vector-database/
clickhouse.com/blog/vector-search-clickhouse-p1
www.searchenginejournal.com/semantic-search-with-vectors/467574/
www.usenix.org/system/files/osdi23-zhang-qianxi_1.pdf
www.infoworld.com/article/3651360/solving-complex-problems-with-vector-databases.html
people.eecs.berkeley.edu/~matei/papers/2020/sigir_colbert.pdf
blog.futuresmart.ai/gpt-4-semantic-search-and-vector-databases-revolutionizing-question-answering
blog.vespa.ai/constrained-approximate-nearest-neighbor-search/
www.pinecone.io/learn/vector-search-filtering/
向量化:是什么以及它是如何工作的?
图片由Mariana Beltrán提供,来自Unsplash
O(n)比 O(1)更快,缓存行,Pandas 2.0 和列的持续增长
·
关注 发布在 Towards Data Science ·10 min read·Apr 13, 2023
--
这是本文的第二次迭代。在完成第一次迭代后,我让它静置一段时间进行编辑,因为标题看起来不太好——一个关于向量化的 13 分钟长篇,其中包含与数据库理论和历史趋势的松散联系。
在等待重新草拟时,我发现了几个关于新版本 Pandas 2.0 的性能比较——尤其是与 Polars 的比较。此时我必须坦白——Pandas 对我而言是起点,我甚至还没有 pip install
Polars 进行测试。我总是犹豫在尚未得到广泛支持的新工具和流行工具之间进行替换,直到:
-
现有工具开始失败(我在 SQL 聚合后使用的数据不够大)。
-
有一些其他明显有力的证据支持采用。
然而,新版 Pandas 2.0 性能相对较差确实让我产生了疑问——如果 Polars 在内存操作中如此快速,它是如何实现的?
Polars 的作者写了我原来的文章(但更好)。
向量化。Polars 之所以快速,是因为原作者设计了整个系统以向量化为核心。在这篇‘hello world’文章中,Polars 作者 Ritchie Vink 通过清晰简洁的语言和简单的视觉效果解释了 Polars 如何实现其目标——因为它不仅仅是以向量化为理念进行构建,而是完全围绕这些原则构建的。
这篇文章的其余部分并非仅仅是重述那些内容,而是通过回顾一些想法和历史背景,阐明我们如何将基于列(或数组)的计算变得更加‘主流’,以及这如何开始渗透到现代数据科学工具包中。
“我不在乎你的花哨数据结构是什么,但我知道数组会胜过它。”
上述内容出自这次讲座由 Scott Meyers 主讲,他将上述引言归于一位支持数组的算法交易公司 CTO。这个想法在 Polars 文章中也有所提及,但概念相同——在实际应用中,有时你需要抛弃基本的时间复杂度分析,因为 O(n) 算法可以超越 O(1) 算法。
我来自非计算机科学背景,但大学(特别是美国大学)提供的大量在线材料使我能够学习一些基础的‘算法’和‘数据结构’课程。根据我所见,联合目标可以(可能总结得很糟糕)如下:
-
使用逻辑来提出一个涉及最少步骤(算法)的过程。
-
组织你的数据,以便选择具有最小步骤的算法(数据结构)。
理解两者核心概念的重要性在于众多公司用大量的Leetcode动态规划或二叉搜索树等与日常无关的题目来考察他们的候选人。一个典型的例子是——在查找东西时使用哈希映射而不是数组。这是因为:
-
哈希映射的查找是 O(1)
-
(未排序的)数组是 O(n),因为你可能需要遍历整个数组
为什么上述情况并不总是成立?因为“O(1)算法比 O(n)算法更快”是不完整的。真正的说法应该是“O(1)算法比 O(n)算法更快,前提是起始点相同*”。
在纽约驾驶法拉利
现代 CPU 被描述为类似于在纽约驾驶一辆法拉利。毋庸置疑,这个比喻显然来自于一个认为开车唯一目的就是从 A 点到 B 点的计算机科学家,但这一点依然成立。如果你只是不断地停车和启动(尽管停车和启动非常非常快),为什么需要这样一辆快车呢?
将‘车’替换为‘处理器’,我们就看到了现代 CPU。这是因为处理器速度与内存速度之间的相对速度改进(尽管近年来速度提升变得较慢)。这一点在 Herb Sutter 的C++导向讲座中得到了很好的阐述(如果你只想了解概述,可以从 12:00 开始观看约 20 分钟),下面的图表也很好地展示了这一点:
新的向量时代的机会和选择——ResearchGate 上的科学图
简单来说,我们有一个“摩尔定律引发的”处理器操作速度和将数据送到处理器的速度之间的差距。如果处理器大部分时间都在空闲等待新数据进行操作,那么就没有必要拥有如此快速的处理器。
我们如何规避这个问题?缓存、缓存和更多缓存
下面是许多人对计算机架构的标准印象(跳到 Herb Sutter 讲座的 22:30,观看这个有趣的例子):
作者提供的图片
换句话说——我们有:
-
我们的 CPU 执行任务
-
我们的内存(或 RAM)‘快速’访问
-
我们的磁盘访问速度慢
如上图所示,内存可能相对于从硬盘检索数据是快的——但相对于我们当前处理器的数量和速度,检索速度确实不快。为了解决这个问题,硬件开发者决定将内存放置到CPU(或‘芯片’)上。这些被称为缓存,每个处理核心都有多个。 每个核心包括:
-
L1 缓存:这是最快的,并分为指令缓存(存储你代码转化成的指令)和数据缓存(存储你操作的变量)
-
L2 缓存:比 L1 大,但较慢
然后,机器上的所有核心共享一个 L3 缓存——同样比 L1 和 L2 大,但再次较慢。
下图基于来自 Intel 白皮书的放大照片(Intel 架构基础知识(v.1, Jan. 2014)),展示了 Intel i7 处理器的 4 核心 CPU 布局——L1 和 L2 缓存在每个‘核心’部分内:
图片由作者提供——原图见此文档:Intel 架构基础知识(v.1, Jan. 2014)
但这同样适用于更新的 PC。我在 Macbook Air M2 上写这篇文章, 在维基百科上快速查阅表明 M2 芯片包含 L1、L2 和一个共享的‘最后一级缓存’(L3)。
那么,从缓存中检索数据比从内存中快多少呢?
对此——一图胜千言——或者更具体地说,是一个动画。来自游戏软件优化公司 Overbyte 的以下链接展示了相对性能的极端差异:www.overbyte.com.au/misc/Lesson3/CacheFun.html
。
速度在机器之间有所不同,但——以标准的‘时钟周期’单位来衡量,我们大致有:
-
L1:~1–3x 时钟周期
-
L2:~10x 时钟周期
-
L3:~40x 时钟周期
-
主内存:~100–300x 时钟周期
换句话说:如果我们的数据在 L1 内存中,我们的计算机会比从主内存中提取数据快 30–300 倍。主内存指的是 RAM,而非硬盘。当你的 CPU 在 L1 缓存中找不到数据时,它会搜索 L2,然后是 L3,最后到主内存每次这些搜索失败都会标记为‘缓存未命中’。 缓存未命中越少,你的代码执行越快。
那么,我如何将数据加载到缓存中?缓存行
你的 CPU 为你完成了这一切。基于你的代码,这些代码被翻译成指令集(最低级的命令——即使汇编也被汇编成指令),你的 CPU 将数据加载到缓存中。然后它对数据进行操作并将结果存储在缓存中。
特别是,如果你定义以下内容:
x = 1
y = 2
x + y
你的 CPU 不仅仅是将整数 x
和 y
加载到缓存中——它实际上加载的是‘缓存行’。 缓存行 是 CPU 处理数据的最底层单位。一个缓存行将包含你需要的数据,但也包括内存中周围的其他数据,这些数据构成了整个缓存行——通常是一个 64 字节的连续内存块。
不仅如此,CPU 还被设计来进行聪明的数据优化提取。为什么?因为这很慢——所以我们越早将数据加载到缓存中进行操作,就越早解决数据存储便宜而处理器快速之间的瓶颈。
-
处理器对数据进行操作的速度
-
将数据提供给处理器所需的时间——在缓存中
为了做到这一点,CPU 实现了类似于预取的功能——这意味着识别程序内存访问中的模式,并预测你将会使用哪些内存。
总结
可能最好的总结方式是回答(最后)我们最初的问题——什么是向量化?
-
向量化是充分利用你超快处理器的手段
-
它通过将数据组织成可预测的连续块(数组)来完成这一点,这些数据块可以一起加载(缓存行)并进行预取
-
这可以防止你的 CPU 在等待从 L2/L3 缓存或更糟——主内存中加载数据时无所事事
但为什么是现在?如果这看起来如此明显,为什么这不是一直以来的基础?
发生了什么变化?相对的历史进展
之前的计算机科学家并不是愚蠢的(正好相反)而错过了这个显而易见的想法。相反,我们现在正在开发最适合的工具包:
-
我们当前的硬件格局
-
我们当前的使用案例
由于历史进展(上面的摩尔定律图表),我们正处于这样的情况:
-
数据存储便宜而处理器迅速
-
我们有广泛的数据分析使用案例来推动决策制定(无论是自动化还是非自动化)
结果是,由于相对的进步,我们现在面临一个瓶颈——将所有这些数据输入到我们超快的处理器中。而且为了清楚起见,这个问题并不新鲜。如上图所示,这已经是一个多年来逐渐扩大的问题,但我们应用了 L1/L2 缓存的权宜之计;因为这意味着大多数人不必担心这个问题。
但是现在差距 如此 巨大,数据规模也在不断增长,以至于问题需要在更接近源头的地方解决。换句话说,如果你想让你的数据以闪电般的速度被操作,那么不要把工作推给 CPU 设计师,而是自己动手,并将数据存储在数组中。
Kdb+、Dremel 和 BigQuery,PyArrow
一般来说,当我们处理数据时,通常有两种方式:
-
在内存中
-
在磁盘上
对我来说,这仅仅意味着数据是否一般存储在一个数据 库 或者数据 框架 中。Polars,以及最近的 Pandas 2.0,未必做了完全新的事情,而是更多地重新设计了我们内存中数据的表示方式,使其更接近于在磁盘上通常的存储方式。
为什么?因为我们在如何以可以快速过滤和汇总的数据存储方式上取得了显著进展——那么为什么不在内存中以能够利用这些进展的方式来表示数据呢?我们完全可以采纳像 Kdb+ 这样的技术驱动思想,并将其实现于内存数据存储方式中。
迈向一致的方法
Polars 基于 PyArrow——这是 Apache Arrow 内存列式数据格式 的 Python 实现。新版本的 Pandas 2.0 也是如此。PyArrow 在从磁盘加载 Apache Parquet 格式的数据时表现特别好。Parquet 基本上是原始 Dremel 论文 中描述的列格式。Dremel 是支撑 Google Big Query 的分析引擎。
要点是:这些都是相互关联的概念,最新的内存数据科学工具并不一定是一个彻底的变化,而是朝着一个日益一致的理念迈进,这个理念支撑着我们的数据分析工具。我们将数据存储在数组/列中——既在磁盘上 也 在内存中。
为什么?因为这是考虑到处理硬件、数据存储成本和内存数据检索速度之间的相对速度改进后,做事的最佳方式。
结论:Wes McKinney 的简历
这可能看起来是一个奇怪的结论话题,但通过 Pandas 创始人的职业生涯可以看出上述向语言无关、一致的数据操作和建模方法的轨迹。他最初在 AQR Capital Management 使用电子表格处理数据,然后创建并推广了 Pandas(利用 NumPy 的向量友好型 ndarray),现在他 深度参与 Apache Arrow。
他的职业生涯与现代“数据科学堆栈”(至少在 Python 中)朝着更加面向列的内存数据表示方式的演变紧密相连。看起来数组的推广势不可挡,并且我个人认为这种趋势不会放缓。请原谅这句双关语。
使用 JAX 向量化和并行化 RL 环境:以光速进行 Q 学习⚡
在这篇文章中,我们学习如何向量化一个 RL 环境,并在 CPU 上并行训练 30 个 Q 学习代理,每秒进行 180 万次迭代。
·
关注 发表在 Towards Data Science · 11 分钟阅读 · 2023 年 10 月 15 日
--
图片来源于 Google DeepMind 的 Unsplash
在前面的故事中,我们介绍了时序差分学习,特别是Q 学习,并将其应用于 GridWorld 的背景中。
在动态网格世界中比较无模型(Q-learning)和有模型(Dyna-Q 和 Dyna-Q+)的 TD 方法。
[towardsdatascience.com
虽然这个实现用于展示这些算法在性能和探索机制上的差异,但速度非常慢。
实际上,环境和代理主要使用Numpy编写,这在强化学习中并非标准,尽管它使代码易于理解和调试。
在这篇文章中,我们将看到如何通过向量化环境和无缝并行化数十个代理的训练来扩展强化学习实验。特别地,本文涵盖了:
-
JAX 基础和强化学习的有用功能
-
向量化环境及其高速原因
-
在 JAX 中实现环境、策略和 Q-learning 代理
-
单代理训练
-
如何并行化代理训练,以及这有多简单!
本文中展示的所有代码均可在 GitHub上找到:
[## GitHub - RPegoud/jym: JAX 实现的 RL 算法和向量化环境
JAX 实现的 RL 算法和向量化环境 - GitHub - RPegoud/jym: JAX 实现的 RL…
github.com](https://github.com/RPegoud/jym?source=post_page-----49d07373adf5--------------------------------)
JAX 基础
JAX 是 Google 开发的另一种Python 深度学习框架,被 DeepMind 等公司广泛使用。
“JAX 是Autograd(自动微分)和XLA(加速线性代数,TensorFlow 编译器)的结合,旨在实现高性能数值计算。” — 官方文档
与大多数 Python 开发人员习惯的不同,JAX 不采用面向对象编程(OOP)范式,而是采用函数式编程(FP)[1]。
简而言之,它依赖于纯函数(确定性和无副作用)以及不可变的数据结构(不是在原地修改数据,而是创建具有所需修改的新数据结构)作为主要构建块。因此,FP 鼓励一种更具函数性和数学性的编程方法,使其非常适合于数值计算和机器学习任务。
让我们通过查看 Q 更新函数的伪代码来说明这两种范式之间的差异:
-
面向对象 方法依赖于一个包含各种 状态变量(如 Q 值)的 类实例。更新函数被定义为一个类方法,它 更新实例的 内部状态。
-
函数式编程 方法依赖于 纯函数。实际上,这个 Q 更新是 确定性的,因为 Q 值作为参数传递。因此,对这个函数的任何调用只要 输入相同 就会产生 相同的输出,而类方法的输出可能依赖于实例的内部状态。此外,数据结构 如数组在 全局范围 内被 定义 和 修改。
在 面向对象编程 和 函数式编程 中实现 Q 更新(作者制作)
因此,JAX 提供了各种 函数装饰器,在 RL 的上下文中尤为有用:
- vmap (向量化映射):允许作用于单个样本的函数应用于一个 批次。例如,如果 env.step() 是一个在单个环境中执行一步的函数,那么 vmap(env.step)() 是一个在 多个环境 中执行一步的函数。换句话说,vmap 为函数添加了一个 批次维度。
使用 vmap 向量化的 step 函数示例(作者制作)
-
jit (即时编译):允许 JAX 执行 “JAX Python 函数的即时编译” 使其 兼容 XLA。本质上,使用 jit 允许我们 编译函数 并提供 显著的速度提升(以在首次编译函数时的一些额外开销为代价)。
-
pmap (并行映射):类似于 vmap,pmap 实现了简便的并行化。然而,它不是为函数添加批次维度,而是复制函数并在 多个 XLA 设备 上执行它。注意:应用 pmap 时,jit 也会被 自动 应用。
使用 pmap 并行化的 step 函数示例(作者制作)
既然我们已经掌握了 JAX 的基础知识,我们将探讨如何通过向量化环境获得巨大的速度提升。
向量化环境:
首先,什么是向量化环境,它解决了什么问题?
在大多数情况下,RL 实验由于 CPU-GPU 数据传输 而 变慢。深度学习 RL 算法如 近端策略优化(PPO)使用神经网络来近似策略。
像深度学习中的常规做法一样,神经网络在训练和推理时使用GPU。然而,在大多数情况下,环境运行在CPU上(即使在使用多个环境并行的情况下也是如此)。
这意味着,通过策略(神经网络)选择动作并从环境中接收观察和奖励的常规 RL 循环需要不断的来回交换,这影响了性能。
此外,使用诸如 PyTorch 的框架而不进行“jitting”可能会导致一些开销,因为 GPU 可能需要等待 Python 将观察和奖励从 CPU 发送回来。
通常的 RL 批量训练设置在PyTorch中(由作者制作)
另一方面,JAX 使我们能够轻松地在 GPU 上运行批量环境,消除由 GPU-CPU 数据传输引起的摩擦。
此外,随着 jit 将我们的 JAX 代码编译为 XLA,执行不再(或至少减少)受到 Python 低效的影响。
RL 批量训练设置在JAX中(由作者制作)
有关元学习 RL 研究的更多细节和令人兴奋的应用,我强烈推荐Chris Lu的这篇博客文章。
环境、代理和策略实现:
让我们查看 RL 实验中不同部分的实现。以下是我们需要的基本函数的高级概述:
简单 RL 设置所需的类方法(由作者制作)
环境
该实现遵循Nikolaj Goodger在其关于在 JAX 中编写环境的精彩文章中提供的方案。
如何以 1.25 亿步/秒运行 CartPole
让我们从环境及其方法的高级视图开始。这是实现 JAX 环境的一般计划:
让我们更详细地查看类方法(作为提醒,函数以“_”开头的是 私有的 ,不应在类的作用域之外调用):
-
_get_obs:此方法将环境状态转换为代理的观察。在部分可观察或随机环境中,应用于状态的处理函数将在这里。
-
_reset:由于我们将并行运行多个代理,因此我们需要一个方法来在完成一个回合后进行单独的重置。
-
_reset_if_done:此方法将在每一步调用,并在“done”标志设置为 True 时触发 _reset。
-
reset:此方法在实验开始时被调用,以获取每个代理的初始状态以及相关的随机密钥。
-
步骤:给定一个状态和一个动作,环境返回一个观察(新状态)、一个奖励和更新后的“done”标志。
实际上,GridWorld 环境的通用实现如下:
请注意,如前所述,所有类方法都遵循函数式编程范式。实际上,我们从未更新类实例的内部状态。此外,类属性都是常量,在实例化后不会被修改。
让我们更仔细地看一下:
-
init: 在我们的 GridWorld 中,可用的动作是[0, 1, 2, 3]。这些动作通过自我移动转化为二维数组,并在步骤函数中添加到状态中。
-
_get_obs: 我们的环境是确定性和完全可观察的,因此代理直接接收到状态,而不是处理后的观察。
-
_reset_if_done: 参数env_state对应于(state, key)元组,其中 key 是一个jax.random.PRNGKey. 如果done标志被设置为 True,该函数会返回初始状态,然而,我们不能在 JAX jitted 函数中使用传统的 Python 控制流。使用jax.lax.cond,我们实际上得到一个等效的表达式:
def cond(condition, true_fun, false_fun, operand):
if condition: # if done flag == True
return true_fun(operand) # return self._reset(key)
else:
return false_fun(operand) # return env_state
- step: 我们将动作转换为移动,并将其添加到当前状态中(jax.numpy.clip确保代理保持在网格内)。然后我们更新env_state元组,然后检查环境是否需要重置。由于步骤函数在训练中频繁使用,对其进行 jitting 可以显著提高性能。@partial(jit, static_argnums=(0, )装饰器表示该类方法的“self”参数应被视为静态。换句话说,类属性是常量,在对步骤函数的连续调用中不会改变。
Q-Learning 代理
Q-learning 代理由update函数定义,以及一个静态的学习率和折扣因子。
再次强调,当对更新函数进行 jitting 时,我们将“self”参数传递为静态。同时,请注意,q_values矩阵是就地修改的,使用set(),其值未作为类属性存储。
Epsilon-Greedy 策略
最后,本实验中使用的策略是标准的epsilon-greedy 策略。一个重要细节是它使用随机平局,这意味着如果最大 Q 值不是唯一的,动作将从最大 Q 值中进行均匀采样(使用 argmax 会始终返回具有最大 Q 值的第一个动作)。如果 Q 值被初始化为零矩阵,这一点尤其重要,因为动作 0(向右移动)将始终被选择。
否则,策略可以通过这段代码总结:
action = lax.cond(
explore, # if p < epsilon
_random_action_fn, # select a random action given the key
_greedy_action_fn, # select the greedy action w.r.t Q-values
operand=subkey, # use subkey as an argument for the above funcs
)
return action, subkey
注意,当我们在 JAX 中使用key时(例如这里我们采样了一个随机浮点数并使用了 random.choice),通常的做法是之后拆分 key(即“转到新的随机状态”,更多细节见这里)。
单代理训练循环:
现在我们有了所有必要的组件,让我们训练一个单一的代理。
这是一个Pythonic的训练循环,正如你所见,我们基本上是使用策略选择一个动作,在环境中执行一步,并更新 Q 值,直到一个回合结束。然后我们重复这个过程N回合。正如我们稍后会看到的,这种训练代理的方式相当低效,但它以一种可读的方式总结了算法的关键步骤:
在单个 CPU 上,我们在 11 秒内完成了 10,000 个回合,以每秒 881 个回合和 21,680 步的速度。
100%|██████████| 10000/10000 [00:11<00:00, 881.86it/s]
Total Number of steps: 238 488
Number of steps per second: 21 680
现在,让我们使用 JAX 语法重复相同的训练循环。以下是rollout函数的高级描述:
使用JAX 语法的训练 rollout 函数(作者制作)
总结一下,rollout 函数:
-
初始化 观察值、奖励和完成标志为空数组,维度等于时间步的数量,使用jax.numpy.zeros. Q 值被初始化为一个形状为[timesteps+1, grid_dimension_x, grid_dimension_y, n_actions]的空矩阵。
-
调用env.reset()函数来获取初始状态
-
使用jax.lax.fori_loop()函数调用fori_body()函数N次,其中N是timestep参数
-
fori_body()函数的行为类似于之前的 Python 循环。在选择一个动作、执行一步并计算 Q 更新后,我们在原地更新 obs、rewards、done 和 q_values 数组(Q 更新目标是时间步t+1*)。
这种额外的复杂性导致了85 倍加速,我们现在以大约183 万步每秒的速度训练我们的代理。请注意,这里训练是在单个 CPU上进行的,因为环境较为简单。
然而,端到端的向量化应用于 复杂环境和受益于多 GPU 的算法时效果更佳(Chris Lu 的文章报告了 CleanRL PyTorch PPO 实现与 JAX 复现之间惊人的4000 倍加速)。
100%|██████████| 1000000/1000000 [00:00<00:00, 1837563.94it/s]
Total Number of steps: 1 000 000
Number of steps per second: 1 837 563
在训练我们的代理后,我们绘制了 GridWorld 中每个单元格(即状态)的最大 Q 值,并观察到它已经有效地学会了从初始状态(右下角)到目标(左上角)的路径。
GridWorld 中每个单元格的最大 Q 值的热图表示(作者制作)
并行代理训练循环:
如承诺的那样,现在我们已经编写了训练 单个代理 所需的函数,剩下的工作就是在批处理环境中训练 多个代理,几乎没有其他工作!
由于 vmap 的帮助,我们可以快速将之前的函数转换为处理数据批次。我们只需指定预期的输入和输出形状,例如对于 env.step:
-
in_axes = ((0,0), 0) 表示输入形状,由 env_state 元组(维度 (0, 0))和一个 observation(维度 0)组成。
-
out_axes = ((0, 0), 0, 0, 0) 表示输出形状,输出为 ((env_state), obs, reward, done)。
-
现在,我们可以在一个 array 的 env_states 和 actions 上调用 v_step,并接收一个处理后的 array,其中包含 env_states、observations、rewards 和 done flags。
-
注意,我们还对所有批处理函数进行了 jit 优化以提高性能(可以说,对 env.reset() 进行 jit 优化是多余的,因为它在我们的训练函数中只调用一次)。
我们必须做的最后一个调整是 为我们的数组添加批处理维度,以考虑每个代理的数据。
通过这样做,我们获得了一个函数,可以在 并行 训练 多个代理,与单个代理函数相比,只需最小的调整:
使用这个版本的训练函数,我们得到了类似的性能:
100%|██████████| 100000/100000 [00:02<00:00, 49036.11it/s]
Total Number of steps: 100 000 * 30 = 3 000 000
Number of steps per second: 49 036 * 30 = 1 471 080
就这些了!感谢你读到这里,希望这篇文章为你提供了有关在 JAX 中实现矢量化环境的有用介绍。
如果你喜欢这篇文章,请考虑 分享 这篇文章并 收藏 我的 GitHub 仓库,谢谢你的支持! 🙏
[## GitHub - RPegoud/jym: JAX 实现的 RL 算法和矢量化环境
JAX 实现的 RL 算法和矢量化环境 - GitHub - RPegoud/jym: JAX 实现的 RL...
最后,对于那些希望深入了解的人,这里有一个 有用的资源 列表,帮助我入门 JAX 并撰写这篇文章:
精心策划的 JAX 文章和资源汇总:
[1] Coderized, (函数式编程) 最纯粹的编码风格,几乎不可能出错, YouTube
[2] Aleksa Gordić, 从零到英雄的 JAX YouTube 播放列表 (2022), The AI Epiphany
[3] Nikolaj Goodger, 用 JAX 编写 RL 环境 (2021)
[4] Chris Lu, 通过 PureJaxRL 实现 4000 倍加速和元进化发现 (2023), 牛津大学, Foerster 人工智能研究实验室
[5] Nicholas Vadivelu,Awesome-JAX (2020),一个 JAX 库、项目和资源的列表
[6] JAX 官方文档,使用 PyTorch 数据加载训练简单神经网络
使用 Modelbit 通过 Git 进行机器学习模型部署的版本控制
开发、部署和跟踪!
·
关注 发表于 Towards Data Science ·7 min read·May 10, 2023
--
介绍
版本控制对所有开发过程至关重要,允许开发者随时间跟踪软件变更(代码、配置、数据等)。
此外,它促进团队成员之间的合作,使他们能够在同一代码库上共同工作,而不会干扰彼此的工作。
在数据团队中,版本控制在部署模型时尤为关键。
它使他们能够准确识别发生了什么变化、何时发生变化以及是谁进行了更改——这是在部署过程中诊断和解决出现的问题或模型在部署后表现不佳时至关重要的信息。
模型版本控制(作者提供的图片)
在这种情况下,基于 git 的功能可以提供快速回滚到先前版本的能力。
因此,在本文中,我将展示如何利用 Git 功能为你的模型部署提供支持。
更具体地,我们将使用 Modelbit 的 git 功能进行部署,并将 GitHub 与 Modelbit 同步以实现协作功能。
让我们开始吧 🚀!
Git 对数据团队的重要性
在深入了解如何操作之前,让我们先建立更多关于基于 git 的版本控制的动机,以及它为何至关重要。
#1)协作
随着数据科学项目越来越大,有效的协作变得越来越重要。
使用版本控制,团队可以在相同的代码库/数据上工作,并改进相同的模型,而不会干扰彼此的工作。
分支模型(作者提供的图片)
此外,还可以轻松跟踪更改,审查彼此的工作,并解决冲突(如有)。
#2)可重复性
可重复性是构建可靠机器学习的关键方面之一。在一个系统上工作而在另一个系统上不起作用反映了不良的可重复性实践。
你可能会好奇,这为何重要?
它确保了结果可以被他人复制和验证,从而提高了你工作的整体可信度。
使用版本控制的可重复性(作者提供的图片)
版本控制使你能够跟踪用于生成特定结果的确切代码版本和配置,从而使将来复制结果变得更加容易。
这对于许多人可以使用的开源项目尤其有用。
#3)持续集成和部署(CI/CD)
CI/CD 使团队能够快速有效地构建、测试和部署代码。
在机器学习中,持续集成(CI)可能涉及自动构建和测试对 ML 模型的更改,一旦这些更改被提交到代码库。
在持续部署(CD)中,目标是将模型的最新更改反映出来,一旦它们通过了测试。因此,它应该无缝更新生产中的更改,使最新版本的模型对最终用户可用。
既然我们知道了版本控制从开发和部署的角度为何重要,那么让我们看看如何利用 git 基于的功能在 Modelbit 的部署阶段。
本地仓库与 Modelbit 集成
Modelbit 完全由 git 驱动。因此,每当你推送模型进行部署时,它会将部署内部维护为 git 仓库。
基于 Git 的部署(作者图片)
由于支持 git,它本地提供了所有版本控制的优点,适用于你的部署、模型和数据集。
更进一步,你可以从本地计算机克隆到远程 git 仓库,并执行所有 git 命令,如 git pull
、git push
或进行分支等。
连接到 Modelbit git 仓库
要访问 Modelbit git 仓库,你需要添加一个 SSH 密钥,以将你的本地计算机连接到 Modelbit。
打开终端并运行以下命令:
ssh-keygen -t rsa -b 4096 -C "My SSH key"
这将创建一个 SSH 密钥。要查看它,请运行以下命令:
cat ~/.ssh/id_rsa.pub
上述命令取自官方 GitHub 文档。
现在,复制 cat
命令的完整输出,并转到 Modelbit 仪表盘中的 Git 设置。点击“添加密钥”,然后粘贴上面获得的输出。如下所示:
添加 SSH 密钥(作者图片)
完成了!
现在我们已连接到 Modelbit 的远程 git 仓库。
部署模型
让我们从 Jupyter Notebook 推送一个模型用于部署。我不会详细说明,因为我已经在之前的博客中讲过这个。
## 直接从 Jupyter Notebook 部署机器学习模型
用一行代码部署机器学习模型
towardsdatascience.com
简而言之,你应该训练一个模型,定义一个预测函数,并将这个函数对象推送用于部署,如下所示:
## Train Model
from sklearn.linear_model import LinearRegression
model = LinearRegression().fit(x, y)
## Define Prediction function
def Linear_Model(input_x):
if isinstance(input_x, (int, float)): ## check input type
return model.predict([[input_x]])[0] ## prediction
else:
return None
## Deploy it
import modelbit
mb = modelbit.login() ## authenticate the notebook here.
mb.deploy(Linear_Model)
一旦我们部署了一个模型,我们会在 Modelbit 仪表盘中看到以下内容:
部署仪表盘(作者图片)
克隆 Modelbit 仓库
让我们克隆到这个仓库以查看其内容。在终端中运行以下命令。
modelbit clone my_linear_model
这将克隆到 Modelbit 的 git 仓库,并创建一个名为 my_linear_model
的文件夹。
克隆部署仓库(作者图片)
一旦你运行命令,复制获得的链接以进行认证。
如上所示,克隆创建了一个新的本地仓库,数据集、部署和端点在 Modelbit 的远程 git 仓库的主分支中。
当前的仓库结构如下:
my_linear_models
├── bin
├── datasets
├── endpoints
└── deployments
└── Linear_Model
├── source.py ## source code
└── data
└── model.pkl ## model pickle
将更改推送到 Modelbit
现在我们已经克隆到远程仓库,我们可以在本地进行任何更改并推送它们。
让我们将一个虚拟 CSV 文件添加到 Linear_Model
文件夹,提交到本地仓库,并推送到远程仓库。
my_linear_models
├── bin
├── datasets
├── endpoints
└── deployments
└── Linear_Model
├── source.py ## source code
├── dummy_data.csv ## added locally
└── data
└── model.pkl ## model pickle
让我们将 CSV 文件添加到暂存区:
git add deployments/Linear_Model/git dummy_data.csv
接下来,让我们将其提交到本地仓库:
git commit -m "Add dummy data csv"
最后,让我们推送它:
git push
这样,虚拟 CSV 文件已被提交到远程 Modelbit git 仓库。
注意:我们将 CSV 添加到
Linear_Model
文件夹而不是datasets
文件夹是有原因的。datasets
文件夹仅支持通过 SQL 查询的数据集。这些查询的结果在运行时可用于运行部署。目前还不支持其他自定义数据集。
分支
如果您希望在远程 Modelbit 仓库中创建和工作于一个单独的分支,这也是可能的。
按如下方式从仪表盘创建新分支:
分支远程仓库(图片由作者提供)
接下来,假设我们希望在这个分支上本地改进我们的模型。在你的笔记本中,你可以按如下方式切换到这个新分支:
## notebook.ipynb
mb.switch_branch("another_branch")
现在,从笔记本中进行的所有新部署(以及其他提交,如果有的话)将推送到远程 Modelbit git 仓库的another_branch
分支。
同步 GitHub
远程 Modelbit 仓库可以与您的个人 GitHub 仓库自动同步。
这对于在 Modelbit 部署上执行基于 GitHub 的代码审查、CI/CD 和 Pull Request 工作流特别有用。
#1)创建一个新的 GitHub 仓库
下面,我在 GitHub 上创建了一个空仓库。
新的 GitHub 仓库(图片由作者提供)
接下来,我们应该授予 Modelbit 对这个仓库的写权限。
#2 复制 GitHub 仓库的 SSH URL
在Code
— SSH
下,复制 URL。
仓库 SSH URL(图片由作者提供)
#3)在 Modelbit 中添加 Git 远程
在仪表盘中,转到Git Settings
,Add Git Remote
,粘贴复制的仓库 URL,然后Connect Remote
。
将 Git 远程添加到 Modelbit(图片由作者提供)
#4)授予 Modelbit 写权限
从上述同步面板中,复制部署密钥:
部署密钥(图片由作者提供)
现在转到 GitHub 仓库的Settings
,Deploy keys
,Add deploy key
。粘贴密钥,给它一个标题,授予写权限,然后点击Add key
。
在 GitHub 仓库中添加部署密钥(图片由作者提供)
完成!GitHub 仓库已自动更新:
GitHub 中的部署代码(图片由作者提供)
现在,远程 Modelbit git 仓库与您的 GitHub 仓库已同步,您可以将其用于各种协作工作。
结论
至此,我们的博客结束了。
在这篇文章中,我们学习了 Git 功能对数据团队的重要性,以及如何通过 Modelbit 使用 git 来支持模型部署。
接下来,我们查看了如何将 Modelbit 内部创建的远程 git 仓库连接到个人 GitHub 仓库。
话虽如此,Modelbit 仍处于开发初期,目前可能还不是其他服务(如 Heroku)的终极替代品。
然而,根据我在使用 Modelbit 和 Heroku 的经验,我认为 Modelbit 的部署过程更为简化,不论是对经验丰富的用户还是新手都更为友好。
我期待看到他们的后续发展!
感谢阅读!
实践中的版本控制:数据、机器学习模型和代码
原文:
towardsdatascience.com/version-controlling-in-practice-data-ml-model-and-code-e13c518067dc
MLOps 中版本管理的逐步指南
·发表于数据科学前沿 ·13 分钟阅读·2023 年 12 月 2 日
--
照片由 Christopher Gower 提供,来源于 Unsplash
版本控制是一个至关重要的实践!没有它,你的项目可能会变得杂乱无章,使得回滚到任何期望的点变得困难。你可能会丧失重要的模型配置、权重、来自长期训练的实验结果,甚至整个项目本身。你也可能会在代码出现问题时与队友产生分歧和冲突,从而阻碍有效的合作。在本文中,我们通过一个实际的例子来探讨版本控制的重要性,例子中使用了一些该领域最常见的工具。本文的完整代码库可以在相关仓库中访问。
目录:
· 1. 介绍
· 2. 工具
· 3. 设置你的项目
∘ 3.1. 项目文件夹
∘ 3.2. 项目环境
· 4. 代码版本管理
· 5. 数据版本管理
· 6. 模型版本管理
· 结论
1. 介绍
版本控制是记录文件或文件集随时间变化的实践,使用版本控制系统,以便我们可以在以后回忆特定版本。在 MLOps 中,版本控制是我认为在开始机器学习项目时需要考虑的首要原则之一。为了确保我们充分利用所有好处,版本控制应应用于不同的机器学习工作流步骤,包括数据、机器学习模型(ML 模型)和代码。
为什么要进行版本控制? 使用版本控制来管理代码、数据和模型可以实现可重复性(这是另一个重要的 MLOps 原则),通过允许在任何时间点重新创建项目的特定状态;跟踪和监控变化,通过建立系统化的方法来捕捉、记录和管理开发生命周期中的变化;协作,通过跟踪不同贡献者所做的更改,并高效地合并这些更改,以及其他许多重要的好处,如错误恢复和可追溯性。
版本控制用例? 让我们考虑一个具体的场景,在手写数字分类项目中,我们将用作本文的示例。
-
代码。 假设我们引入了优化以提高速度。然而,部署后,用户报告了预测中出现了意外的不准确。得益于项目的强大代码版本控制实践,我们可以迅速识别出与错误相关的提交,并在解决错误、修复并重新集成到主项目版本之前,暂时回滚部署。
-
数据。 假设我们决定扩展数据集以增强模型的泛化能力。然而,在扩展的数据集用于训练后,观察到模型性能出现了意外的变化。因此,我们审查版本控制历史,确定可能导致问题的具体扩展技术,并迅速回滚到数据集的先前版本。然后,我们共同优化数据扩展方法,确保只有经过验证的更改才会重新集成到主项目版本中。
-
机器学习模型。 假设现在,我们开始优化模型架构以提高准确性。我们实施了一个卷积神经网络(CNN)以改进特征提取,并将其集成到主项目中。然而,在部署过程中,出现了细微的差异,影响了实时预测。因此,我们回滚到之前更稳定的模型版本。接着,我们共同解决问题,进行彻底测试,然后将优化后的模型重新集成到主项目版本中。
尽管这篇文章专注于如何在项目中使用版本控制,但它也是我 MLOps 文章系列的一部分。此外,通过阅读我之前和之后的教程,你将能够从工作流程到模型部署和跟踪,创建自己的端到端 MLOps 项目。
如果你对 MLOps 感兴趣,可以查看我的文章:
-
教程 1:MLOps 的关键开始:探索其基本组件
-
教程 2:面向初学者的 MLOps 工作流程介绍
-
教程 6: 实践中的测试:代码、数据和 ML 模型
-
教程 7: 实践中的追踪:代码、数据和 ML 模型
2. 工具
在进行机器学习项目或任何计算机科学项目时,在开始编程之前,需要选择合适的工具。工具的选择取决于项目需求、团队专业技能、数据量和成本等不同因素。
在本文中,选择了以下工具:
-
Python 作为编程语言,结合了丰富的生态系统、社区支持、学习的便捷性、多功能性、集成能力、广泛的库、数据科学工具、可扩展性以及行业采用,这些因素共同促成了它在机器学习项目领域的突出地位。
-
Git 用于代码版本控制。Git,全称为全球信息跟踪器,是一种开源的分布式版本控制系统(DVCS),广泛用于软件开发中以跟踪项目开发过程中源代码的更改。它是一个关键工具,能够有效管理代码更改、无缝协作并保持可靠的版本历史。它已成为行业标准,被全球开发者用于各种规模的项目。
-
DVC 用于数据版本控制。DVC,全称为数据版本控制,是一种开源版本控制系统,广泛用于数据管理。它旨在管理大规模数据集、使项目可重现以及促进更好的协作。它在 Git 仓库之上运行,具有类似的感觉和流程。DVC 的一个关键特性是数据版本控制:它允许将数据集与代码分开进行版本控制。因此,数据可以被跟踪、共享,并在不同版本之间轻松切换。
-
MLflow 用于模型版本控制。它是一个开源平台,旨在管理端到端的机器学习生命周期并促进 ML 从业者之间的协作。其与流行库的兼容性和强大的社区支持使其成为以统一且可扩展的方式管理完整机器学习生命周期的有吸引力的选择。
3. 设置你的项目
在开始之前,确保你的系统上已安装 Git 和 DVC。如果尚未安装,你可以从官方 Git 网站和官方 DVC 网站下载并安装,或者如果你使用的是 Ubuntu,可以直接执行以下命令:
sudo apt install git-all # to install git
pip install dvc # to install DVD (do not install it for now!)
然而,强烈建议在安装 DVD 之前创建一个虚拟环境;因此,我们将在创建虚拟环境后几分钟内进行安装。此外,请注意:
DVC 并不替代或包含 Git。你必须在系统中安装
git
,以启用数据版本控制和快速实验等重要功能(推荐)。 [1]
3.1. 项目文件夹
让我们开始设置项目文件夹!为此,有几种方法,包括:
-
从头创建文件夹: 这是最直接的方法,但需要手动添加标准文件并随后构建项目。我不推荐在中大型项目中使用这种方法。
-
导入现有模板: 这是进行简单维护、易于协作以及确保良好透明性、可重复性和可重用性的最佳选择。本文中,我们将使用以下项目结构,用于通过这个 Github 模板或这个Cookiecutter MLOps 仓库创建的机器学习项目,但可以随意探索其他模板。如果你渴望深入了解 ML 项目结构,我邀请你阅读我专门讨论这一主题的文章:考虑 MLOps 的机器学习项目结构。
使用这个 Github 模板或这个Cookiecutter MLOps 仓库的项目结构
- 克隆/分叉现有项目: 当处理现有项目时,这通常是最佳选择。它支持协作和代码重用。对于本文,可以随意克隆或分叉我的仓库,以便轻松重用提供的代码。要克隆项目,请使用:
# Clone repository:
git clone git@github.com:Chim-SO/hand-written-digits-classification.git
使用 Github 模板或克隆 Github 仓库需要对 Github 有一定的了解。不过,请放心!你仍然可以跟随本教程,我将为你提供必要的命令和解释。
3.2. 项目环境
另一个重要步骤是设置虚拟环境,这是软件开发中的最佳实践,可以增强项目隔离、依赖管理、可重复性、协作和整体项目整洁性。
- 首先创建一个名为
handwritten-digits-classification-env
的虚拟环境并激活它:
python -m venv venv/handwritten-digits-classification-env
source venv/handwritten-digits-classification-env/bin/activate
-
在那之后,通常在使用 GPU 时,我们需要更新环境以适应适当的 Cuda 版本(有关详细信息,请参见这篇文章)。然而,为了使本教程简单易懂,并且项目要求较简单,因为数据和模型都不大,所以不需要 GPU。
-
最后,我们通过执行以下命令来安装需求和 DVC:
pip install -r requirements.txt
pip install dvc
4. 代码版本控制
在设置好仓库后,我们现在准备开始版本控制!在本教程中,我们采用一种简单的功能分支工作流。这种工作流涉及为每个新功能创建一个专门的分支,而不是直接更改主分支。然后,我们使用变基/合并方法将功能分支无缝集成到主分支中。
- 我们首先列出所有分支,并检查当前所在的分支,通常用星号(
*
)标记:
git branch # List local branches
* master
git branch -r # List remote branches
remotes/origin/HEAD -> origin/master
origin/master
git branch -a # List all local and remote branches
* master
remotes/origin/HEAD -> origin/master
remotes/origin/master
这里我只有一个分支,即主分支,并且它是当前分支。
- 如果你还不在主分支上,请使用以下命令切换到它:
git checkout master # switch to the main directory
git pull origin master # mendatory when working in collaboration but you can skip it now
- 我们首先创建一个名为
feature/data
的分支,在其中添加所有与数据处理相关的代码:
git branch feature/data # to create a branch
git checkout feature/data # to switch to the created branch
# or use the combined creation and switch command
git checkout -b feature/data
- 在添加了所有必要的代码后,我们通过使用
merge
命令将代码导入主分支,该命令将命名分支中的更改并入当前工作分支:
git checkout master # switch to the main directory
git merge feature/data # apply changes to master
- 类似地,我们创建了另一个名为
feature/model
的分支,在其中添加了所有与模型创建、训练和验证相关的代码,并将其合并到主分支中:
# Model branch creation:
git checkout master # switch to the main directory
git checkout -b feature/model
# Development ...
# Merge branch
git checkout master # switch to the main directory
git merge feature/model # apply changes to master
此时,我们可以说我们创建了代码的简单第一个版本!现在是时候标记这个特定点,添加一个标签,如下所示:
git tag -a v1.0 -m "Version 1.0"
整个工作流描述如下:
每个圆圈代表一个提交,可以使用以下命令显示:
git log --pretty=format:"%h - %an, %ar : %s"
回到我们的代码问题示例:
- 假设在部署后出现了问题,我们决定暂时将部署回滚到之前的版本:
git revert <commit-hash> # Revert the merge commit
撤销操作通过创建一个新的提交来撤销指定提交所做的修改,但我们可能需要解决在此过程中出现的任何冲突,类似于常规合并中发生的情况。
- 通过使用提交历史,我们识别出模型分支中的一个特定优化可能导致了这个问题。因此,创建了一个名为
hotfix/inference-bug
的热修复分支来解决这个错误:
git checkout -b hotfix/inference-bug
- 然后我们对代码进行必要的修正并提交更改:
git commit -m "Fix bug in digit classification during inference"
- 错误修复经过彻底测试,并且如果我们真的在团队中工作,还会为代码审查打开一个新的拉取请求,最终热修复被合并到主分支中:
git checkout main
git merge hotfix/inference-bug
修复了错误的代码重新部署到生产环境。
5. 数据版本控制
现在代码准备好了,我们可以将数据集下载到其第一个版本,然后将其转换为csv
格式。
-
首先,我们需要确保存储数据的文件夹未被 git 忽略。这通过检查
.gitignore
文件并移除/注释掉排除数据文件夹的行来完成。如果你使用的是我提供的模板,请注释掉第 79 行。 -
现在,我们开始创建一个分支
feature/data-csv
,在项目文件夹内初始化 DVC 项目,最后将创建的文件添加到 Git 中:
# Branch creation:
git checkout master # switch to the main directory
git pull origin master # mendatory when working in collaboration but you can skip it now
git checkout -b feature/data-dvc
# DVC initialisation:
dvc init
# Add to Git the created files:
git commit -m "chore: Initialize DVC."
- 然后,我们下载我们的数据集,将其添加到 DVC 中,并将新的 DVC 文件添加到 git 中:
# Download data
python src/data/ingestion.py -r data/raw
# Add data to dvc
dvc add data/raw/test_images.gz data/raw/test_labels.gz data/raw/train_images.gz data/raw/train_labels.gz
#Add dvc files to git and commit
git add data/raw/.gitignore data/raw/test_images.gz.dvc data/raw/test_labels.gz.dvc data/raw/train_images.gz.dvc data/raw/train_labels.gz.dvc
git commit -m "Add raw data"
将文件添加到 dvc 将生成存储在新文件中的元数据,文件扩展名为.dvc
。另外,请注意,即使数据文件夹由 git 跟踪,但一旦我们将其添加到 DVC 中,它会创建.gitignore
并将数据路径添加到其中,因此它将被忽略。
- 我们还将其转换为
csv
格式,并将生成的文件data/preprocessed/train.csv
和data/preprocessed/test.csv
以与之前相同的方式添加到 DVC 中:
# transform data:
python src/data/build_features.py -r data/raw/ -p data/processed/
# Add to dvc:
dvc add data/processed/train.csv data/processed/test.csv
#Add dvc files to git and commit
git add data/processed/.gitignore data/processed/test.csv.dvc data/processed/train.csv.dvc
git commit -m "Add processed data"
目前,数据已下载并创建。下一步是合并到主分支并添加 git 标签:
# Apply changes:
git checkout master # switch to the main directory
git pull origin master # mendatory when working in collaboration, you can skip it now
git merge feature/data-dvc # apply changes to master
# Tag this point:
git tag -a v1.1 -m "Data collected and processed"
回到我们的数据问题示例:
- 假设我们对处理过的数据进行了离线增强,并将其添加到 DVC 中:
# Add to dvc after update:
dvc add data/processed/train.csv data/processed/test.csv
git add data/processed/test.csv.dvc data/processed/train.csv.dvc
git commit -m "Data augmentation offline"
- 然而,训练后,模型表现不佳,因此我们决定重新使用之前的版本,如下所示:
git checkout data/processed/test.csv.dvc data/processed/train.csv.dvc
dvc checkout data/processed/test.csv.dvc data/processed/train.csv.dvc
6. 模型版本控制
如前所述,我们使用 MLflow 来跟踪和管理我们的模型。由于目前我们将在本地工作,我们启动一个本地的 MLflow Tracking Server:
mlflow server --host 127.0.0.1 --port 8080
- 创建一个分支,我们在其中训练并保存模型:
# Branch creation:
git checkout master # switch to the main directory
git checkout -b feature/model-dvc
- 现在,我们初始化 MLflow 运行上下文以启动一个运行,训练模型,然后使用 MLflow 保存模型:
# Create model:
model = create_model(x_train[0].shape)
# Log parameters:
loss = 'categorical_crossentropy'
metric = 'accuracy'
# Train:
model.compile(loss=loss, optimizer='adam', metrics=[metric])
history = model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, verbose=1,
validation_data=(x_val, y_val))
# ....
# Set tracking server uri for logging
mlflow.set_tracking_uri(config['mlflow']['tracking_uri'])
# Create an MLflow Experiment
mlflow.set_experiment(config['mlflow']['experiment_name'])
# Start an MLflow run
with mlflow.start_run():
# Save model:
signature = infer_signature(x_train, y_train)
mlflow.tensorflow.log_model(model, output_path, signature=signature)
#Log other metrics and parametrics:
# Next tutorial.
- 合并到主分支并添加 git 标签:
# Apply changes:
git checkout master # switch to the main directory
git merge feature/model-mlflow # apply changes to master
# Tag this point:
git tag -a v1.2 -m "Model versioning mlflow"
- 使用以下命令训练模型:
python -m src.models.cnn.train -c configs/cnn.yaml
其中configs/cnn.yaml
文件包含一些配置参数,如批量大小和训练轮次。
- 我们可以在 MLflow UI 中查看运行结果,只需在浏览器中导航到之前的 URL。点击实验名称
cnn
以列出其相关的运行,然后点击为运行生成的随机名称:
- 通过点击运行名称,显示 RUN 页面,其中显示了执行的详细信息:
当你使用 MLflow 保存模型时,它会创建一个包含以下内容的目录结构:
-
data
文件夹包含序列化的模型参数文件。 -
MLmodel
文件,包含有关模型的元数据,如框架、模型签名及其他属性。 -
conda.yaml
、python_env.yaml
和requirements.txt
文件,这些文件帮助重建相同的环境以便加载模型。
它还提供了对模型架构的深入理解,并展示了如何执行预测,支持 Spark DataFrame 和 Pandas DataFrame 两种灵活的选择。MLflow 的另一个显著特点是能够保留生成模型的提交 ID。此外,它引入了一个简单的模型注册选项,这是一个将在后续文章中探讨的话题。
结论
我们来到了本文的结尾。在这篇文章中,我们通过一个实际示例学习了如何对机器学习项目中的三个元素:代码、数据和机器学习模型进行版本控制。版本控制是 MLOps 的一个基本原则,它能实现细致的跟踪、无缝的协作以及机器学习工作流程的强大可复现性。本文的完整代码库可在 相关的仓库 中访问。
感谢阅读本文。您可以在我的 GitHub 个人主页 中找到我提供的所有不同教程的示例。如果您喜欢我的教程,请通过关注我和订阅来支持我。这样,您将收到我新文章的通知。如果您有任何问题或建议,请随时留言。
初学者友好的 MLOps 介绍
towardsdatascience.com
参考文献
图片来源
本文中所有未在说明中提及来源的图片和图表均由作者提供。
VIM 教程 — 像专业人士一样编辑文本
原文:
towardsdatascience.com/vim-tutorial-edit-text-like-a-pro-ea13e45010f5
易于遵循的 Vim 文本编辑器教程
·发布于 Towards Data Science ·12 分钟阅读·2023 年 3 月 6 日
--
图片由 Pereanu Sebastian 在 Unsplash 上提供
Vim 是一个功能强大且高度可定制的文本编辑器,自 1991 年以来一直存在。它以其模式编辑界面而闻名——这意味着它允许用户在不同的模式之间切换,用于编辑、导航和选择文本。Vim 被程序员和其他需要快速高效方式编辑和操作文本文件的高级用户广泛使用。使用 Vim 的主要优势之一是其速度和效率。一旦你掌握了命令和界面,你可以用几个击键执行复杂的编辑任务。
在这篇文章中,我们将介绍 Vim 中最常用的命令。我在本节之后包含了一个速查表,供将来参考。速查表之后是命令的详细教程。我建议首先阅读详细教程部分,然后在未来使用时,你可以简单地参考速查表。
开始吧!
(除非另有说明,否则所有图片均由作者提供。)
[## 通过我的推荐链接加入 Medium - Farzad Mahmoodinobar
阅读 Farzad (以及 Medium 上其他作者) 的每一个故事。你的会员费直接支持 Farzad 和其他人……
medium.com](https://medium.com/@fmnobar/membership?source=post_page-----ea13e45010f5--------------------------------)
重要 Vim 命令概览表
下表总结了一些你需要了解的重要 Vim 命令。每个命令在下一节中有更详细的描述。
Vim 命令速查表
Vim 命令 — 详细教程与实现
在本节中,我们将详细介绍一些最常用的 Vim 命令。我们还将把每个命令应用于文本文件以查看结果。我们将使用命令行界面(CLI)来运行 Vim 命令(如果你需要复习 CLI,可以访问这个 CLI 教程)。如果你准备好了,打开 CLI,开始吧!
1. “vim"
— 创建文件
首先,我们将创建一个名为 demo.txt
的文件,以便可以在该文本文件中应用命令,如下所示:
vim demo.txt
这个命令使用 Vim 在 CLI 中创建一个文本文件。一旦运行此命令,你将在 CLI 中看到一个类似于这张图片的空白页面:
如你所见,文件是空的,我们只能在左下角看到文件名。让我们看看如何在文本文件中添加一些文本。
2. “i” — 插入
i
命令代表“insert”,用于进入 Vim 的插入模式。当你处于插入模式时,可以像在任何其他文本编辑器中一样输入和编辑文本。要使用此命令,只需在普通模式下按 i
。
继续按 i
,然后输入以下内容:
Topic: VIM Tutorial
Author: Farzad Nobar
If you find this helpful, follow me on Medium at: https://medium.com/@fmnobar
确保不要按其他任何键。以下是此时我 CLI 的样子:
我们可以在该图片中看到文本,并且在最底部我们可以看到我们处于 — INSERT —
模式。为了退出该模式,只需按下键盘上的 Escape 键(确保在进入下一步之前执行此操作)。
到目前为止,我们已经在文件中添加了这些文本,但尚未保存更改。接下来我们来看看这个问题。
3. “:w” — 写入
:w
命令代表“write”,用于在 Vim 中保存对文件的更改。要使用此命令,只需在普通模式下输入 :w
并按下回车键(或 Enter)。这将保存对文件所做的任何更改,以下是更改后我的 CLI 的样子:
专业提示: 请注意,“Write”命令是在
:
之后输入的。在 Vim 中,我们通过输入命令:
进入“命令”模式。换句话说,我们首先通过输入:
进入命令模式,然后输入w
,这就成为了“Write”命令。我们可以随时通过按下 Escape 键退出命令模式。我们将在接下来的部分中更多地使用命令模式。
现在我们已经保存了更改,接下来我们可以退出(即关闭)文件。
4. ":q"
— 退出
:q
命令代表“quit”,用于退出 Vim。要使用此命令,只需在普通模式下输入 :q
(然后按下回车/Enter)。如果文件中有未保存的更改,Vim 会提示你在退出之前保存这些更改。
5. “:wq” — 写入并退出
:wq
命令代表“写入并退出”,用于将更改保存到文件并退出 Vim(而不是像之前的示例那样分开进行)。要使用此命令,在普通模式下键入:wq
。这将保存你对文件所做的任何更改并退出 Vim。
6. “h”、“j”、“k”、“l”、“w”、“b” — 文本文件中的导航
这些命令用于在 Vim 中移动光标。h
将光标向左移动,j
向下移动,k
向上移动,l
向右移动。w
将光标移动到下一个单词的开头,b
将光标移动到当前单词的开头。这些命令在普通模式下使用。
为了练习这些,请使用以下命令打开我们创建的文件:
vim demo.txt
现在你已经在 Vim 中打开了文件,你可以使用导航键进行移动——试试看吧!
技巧提示: 请注意,打开现有文件和创建新文件的命令是相同的。如果文件已存在,Vim 将打开现有文件。如果文件不存在,Vim 将创建一个新文件。
7. “x” — 删除
x
命令用于删除光标下的字符。要使用此命令,只需将光标放在你要删除的字符上,并在普通模式下按x
即可。
让我们用这个来删除“Topic”中的“ic”字母。我看到的结果如下:
但如果我们不小心删除了这些呢?让我们看看下一步如何撤销这个操作。
8. “u” — 撤销
u
命令用于在 Vim 中撤销更改。要使用此命令,只需在普通模式下按u
即可。
我按了u
两次,删除操作被撤销,如下所示:
9. “Ctrl-r” — 重做
Ctrl-r
命令用于在 Vim 中重做更改。要使用此命令,只需在普通模式下按Ctrl-r
即可。
假设我们又改变了主意,实际上我们是要删除那两个字母。我按了Ctrl-r
两次,在下面的截图中我们可以看到字母再次被删除:
好的。我们先暂时保留这些更改。接下来我们来谈谈如何使用 Vim 搜索术语。
10. “dd” — 删除行
dd
命令用于在 Vim 中删除整行。要使用此命令,只需在普通模式下按dd
即可。
11. “/search_term” — 搜索
/
命令用于在 Vim 中搜索术语。要使用此命令,键入/
后跟你要搜索的术语,然后按回车。Vim 将突出显示文件中术语的下一个出现位置。
试试看吧。我使用/medium
搜索了“medium”,结果如下:
正如我们在图片中看到的,光标跳到了文本最后一行的“medium”一词的开头。
现在让我们使用 :set hlsearch
高亮显示搜索结果,现在搜索的词汇(例如本例中的“medium”)被高亮显示,如下图所示:
我们可以使用 :set nohlsearch
关闭此高亮搜索。
12. “😒/search_term/replace_term/g” — 替换
:s/search_term/replace_term/g
命令用于在 Vim 中替换所有出现的词汇。要使用此命令,在普通模式下键入 :s/search_term/replace_term/g
并按回车键。Vim 将用替换词汇替换所有出现的搜索词汇。
我通过使用以下命令将“Tutorial”替换为“guideline”进行尝试::s/Tutorial/Guideline/g
以下是结果:
接下来,让我们讨论一下如何完成文本复制并粘贴的操作。
13. “yy” — 剪切
yy
命令用于在 Vim 中复制一行文本。要使用此命令,只需将光标放在要复制的行上,并在普通模式下按 yy
。
14. “p” — 粘贴
p
命令用于粘贴已被剪切或删除的文本。要使用此命令,只需在普通模式下按 p
键即可。
让我们尝试剪切和粘贴命令。首先在文件的开头按 yy
,然后按 p
,看看会发生什么。请注意,按下 yy
后实际上没有明显的变化,但实际情况是在后台进行的。一旦按下 p
,你将看到结果。我碰巧按了两次 p
,以下是结果:
15. ":set number"
— 显示行号
:set number
命令用于在 Vim 中显示行号。要使用此命令,在普通模式下键入 :set number
。
这个命令不言而喻。以下是我输入命令后 CLI 的样子:
16. “:set nonumber” — 隐藏行号
:set nonumber
命令用于在 Vim 中隐藏行号。要使用此命令,在普通模式下键入 :set nonumber
。
以下是隐藏行号的结果:
17. “:set paste” — 无自动缩进的粘贴
:set paste
命令用于粘贴文本而不进行自动缩进。要使用此命令,在普通模式下键入 :set paste
,然后粘贴文本,最后键入 :set nopaste
退出粘贴模式。此命令的一个应用场景是粘贴对缩进敏感的文本输入,例如编程脚本(例如 Python 中的 for 循环等)。
18. “:set spell” — 拼写检查
:set spell
命令用于启用 Vim 的拼写检查。要使用此命令,在普通模式下键入 :set spell
。Vim 会高亮显示任何拼写错误的单词。
让我们看看 Vim 是否在我们的文件中识别出了什么:
正如预期的那样,我的名字被突出显示,Vim 未将其识别为正确拼写的词——我不会把这个放在心上!
19. “:set nospell” — 禁用拼写检查
:set nospell
命令用于在 Vim 中禁用拼写检查。要使用此命令,请在正常模式下输入:set nospell
。
20. “:w filename” — 写入文件
:w filename
命令用于将当前文件的内容写入一个具有指定文件名的新文件。要使用此命令,请在正常模式下输入:w filename
。
试试这个。我将首先使用以下命令将文件内容保存到名为demo2.txt
的新文件中::w demo2.txt
。然后我关闭了 Vim(使用:q!
),接着查看文件(使用 CLI 中的ls
命令),看到新文件demo2.txt
现在被添加到了那里。让我们用 Vim 打开它,看看它的样子:
21. “:q!” — 不保存退出
:q!
命令用于退出 Vim 而不保存对文件的任何更改。要使用此命令,请在正常模式下输入:q!
。如果你在之前的例子中注意到,我使用:q!
退出了 Vim,因此我们期望更改没有被保存。让我们打开demo.txt
并进行验证。结果如下:
正如预期的那样,修改没有被保存,我最终得到了原始文件。
22. “:e filename” — 打开一个文件
:e filename
命令用于在 Vim 中打开一个文件。要使用此命令,请在正常模式下输入:e filename
。
现在我们已经打开了demo.txt
文件,接下来使用以下命令打开第二个名为demo2.txt
的文件,结果如下:
正如预期的那样,第二个文件已被打开。
23. “:set tabstop=4” — 设置制表符宽度
:set tabstop=4
命令用于将 Vim 中的制表符宽度设置为 4 个空格。要使用此命令,请在正常模式下输入:set tabstop=4
。
试试这个。我将首先将 tabstop 设置为 4,进入插入模式,然后在文件第二行的开头添加一个制表符,结果如下:
现在让我们将 tabstop 大小更改为 8,看看它如何变化:
请注意,现有的制表符在文件中自动应用了更改,制表符大小从 4 增加到 8,如上图所示。
24. “:set expandtab” — 使用空格代替制表符
:set expandtab
命令用于在 Vim 中将制表符替换为空格。要使用此命令,请在正常模式下输入:set expandtab
。这个命令一目了然,我们可以继续下一个命令。
25. “:set syntax=python” — 语法高亮
:set syntax=python
命令用于在 Vim 中启用 Python 代码的语法高亮显示。要使用此命令,请在正常模式下输入:set syntax=python
。
我很想包括这一点,因为它与我的工作有关,并且我每天都在使用 Python。为了测试这一点,让我们采取以下步骤:
1. 使用 Vim 打开一个新的 Python 文件,方法是:vim demo.py
2. 粘贴以下命令(从我关于多变量分析的帖子中选择)
# Import libraries
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
import matplotlib.pyplot as plt
%matplotlib inline
# Show all columns/rows of the dataframe
pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)
# Read the data
df = pd.read_csv('auto-cleaned.csv')
# Return top 5 rows of the dataframe
df.head()
以下是我看到的结果:
3. 然后我通过使用:set syntax=python
将语法设置为 Python
4. 然后我通过使用:syntax on
来开启语法高亮
以下是结果,它符合预期的 Python 语法:
正如预期的那样,Python 语法得到了识别和颜色编码。
结论
在这篇文章中,我们介绍了 Vim,这是一个在命令行界面(CLI)中使用的多功能文本编辑器,被程序员和高级用户所利用。然后我们逐步讲解了一些在 Vim 中最常用的命令及其示例。我希望通过本教程,你对使用 Vim 感到更加舒适,并开始将其整合到你日常的文本编辑活动中。
感谢阅读!
如果你觉得这篇文章对你有帮助,请在 Medium 上关注我并订阅以接收我最新的帖子!
视觉基础的重复计数在实际应用中的探索
对视觉基础重复计数的不同方法进行回顾
·
关注 发表在 Towards Data Science ·9 分钟阅读·2023 年 2 月 21 日
--
src: 图片由 @paragdmehta 提供,展示了一个重复的模式。
在这篇文章中,我尝试解释我对不同视觉基础的重复计数技术的探索,并讨论它们的优缺点。具体而言,我重点介绍了计算机视觉在重复计数中应用的五种主要方式。
穿戴式传感器在重复和组计数中非常受欢迎。由于这些传感器价格昂贵,而且在大多数情况下仅限于跟踪特定的身体部位,近年来,越来越多的关注转向了使用视觉基础的方法进行重复计数。
从活动监控、运动和游戏中的无数应用,到帮助了解生物事件(心跳、脉搏计数等)发生的次数,重复计数是学术界和工业界积极解决的问题。
关键词: 重复计数、计算机视觉、姿态估计。
目录
-
RepNet: 无类别的重复计数应用
-
基于规则的姿态估计运动重复计数
-
基于信号处理的运动重复计数
-
GymCam
-
基于深度学习的光流方法进行重复计数
我们在博客中讨论的大多数技术不会是通用的,而是专门针对特定问题(例如锻炼)。此外,若要深入了解该技术,请参阅提供的参考文献。
1. RepNet: 无类别的重复计数应用
论文: Counting Out Time: 无类别的视频重复计数 [1]
图:RepNet 体系结构。源:arxiv.org/pdf/2006.15418.pdf
在重复计数方面最突出的工作之一是 RepNet,这是一个端到端的深度学习模型,可以准确预测广泛的重复运动计数。
RepNet 模型接收视频流作为输入,并预测两个输出:
每帧周期长度: 对于每一个重复动作的帧,我们希望知道该动作的周期长度(以时间单位表示)。
每帧周期性: 表示当前帧是否为重复的一部分的分数。
RepNet 模型的几个关键亮点包括:时间自相似矩阵(TSM):
图:RepNet 利用时间自相似矩阵(TSM)。源:arxiv.org/pdf/2006.15418.pdf
-
TSM 是这种重复计数技术的亮点。它是 RepNet 体系结构的信息瓶颈。该矩阵通过计算两个嵌入之间的成对相似性函数来帮助将帧彼此关联起来。
-
还可以通过(使用启发式方法)从这些 TSM 推断重复次数,这使得 RepNet 模型的预测具有解释性。
-
多样化的真实世界重复视频确保这些 TSM 足够多样化,因此 RepNet 除了重复计数之外还有一系列应用。
这种重复计数方法最令人印象深刻的一点是它是无类别的(通用的),适用于广泛的重复动作。RepNet 模型是流行的 Transformers 在计算机视觉中应用的经典实例。
源:arxiv.org/pdf/2006.15418.pdf
然而,该模型存在一定限制,因为输入视频中的帧数必须有限。这可以归因于 TSM 的大小等于输入帧的数量。
该模型相当庞大且复杂,因此在移动应用程序或任何生产环境中部署这一模型将是相当具有挑战性的,并可能存在延迟问题。
2. 基于规则的姿态估计重复计数
博客: 使用计算机视觉赢得互动 [2]
这是行业中最常见的想法。许多健康和健身初创公司一直在努力构建准确、轻量化的最先进姿态估计模型,这些模型可以用来准确计算运动中的重复次数,并提供姿势纠正反馈等。
涉及的主要步骤:
-
给定一个特定的练习,你首先需要为该练习中的状态制定定义(规则)。一个练习中可以有多个状态。例如,深蹲练习可以分为两个状态,即下半部状态和上半部状态。在运动过程中,做练习的人将从一个状态转移到另一个状态。这些状态规则可以看作是运动过程中激活区域的表示。
-
例如,对于深蹲,这些规则可以是(th指阈值):
down: (left_knee_hip_dist_y < th1 and right_knee_hip_dist_y < th2)
up: (left_knee_hip_dist_y > th3 and right_knee_hip_dist_y > th4)
- 在推理过程中,我们首先实时计算模型中姿态关键点的度量(角度、归一化的距离),并检查某个特定规则是否被激活,然后使用标志进行重复计数。
来源:使用计算机视觉赢得互动。图片由作者提供。
该方法的主要优点之一是重复计数快速且准确,延迟非常低。然而,一些主要的缺点包括:
-
这不是通用的重复计数。
-
姿态估计模型对背景噪音非常敏感,因此对重复计数也有影响。
-
可扩展性问题: 手动编写规则是一个耗时的过程。我们还需要测试不同角度、方向、姿势等的规则变体。想象一下为语料库中的数百种练习编写规则。
3. 使用信号处理思想进行重复计数
博客: 使用信号处理思想构建运动重复计数器 [3]
目标: 使用零交叉和峰值检测等信号处理思想来制作一个运动重复计数器。
这种方法与基于规则的重复计数非常相似,只是需要手动编写不同状态的规则。这种方法通过推断参考线(可以视为状态边界)来半自动化状态计算方法,使用训练师的视频来进行特定运动/练习,然后使用该参考线来计数任何视频中的该练习的重复次数。
在这里,我们将运动视为关键点度量的波形集合。这些度量包括不同身体关键点之间的角度和距离,关键点是通过姿态估计模型(Tensorflow 的 Movenet 姿态估计模型)计算的。
主要步骤:
-
我们首先使用训练师参考视频(作为输入)计算关键点之间的度量(距离和角度)。这些度量在时间上表示为信号。
-
我们过滤掉所有静态信号,并创建非静态信号的组合信号。然后,我们使用汇总信号的均值计算参考线。
-
在推断过程中,我们首先在测试用户输入视频上计算度量,并实时计算整体组合信号。
-
我们创建一个固定大小的移动窗口,并检查整体信号(来自 3)与参考线(来自 2)的交点。这个交点表明重复动作已经完成。
使用零交叉思想的结果。图片由作者提供。 源
这种方法快速、易于实施且相当准确。然而,主要的缺点包括以下几点:
-
重复计数是 独占性的 和 非通用的。
-
对背景噪音高度敏感。
-
缩放问题:需要使用参考视频计算零交叉线来进行任何活动(同时确保视频没有噪音)。
4. GymCam
论文: 在不受约束的场景中检测、识别和跟踪同时进行的运动[4]
GymCam 是一个基于视觉的系统,用于自动化的运动重复计数和跟踪。它基于这样的假设:任何在健身房内的重复运动都属于某种运动。在这里,系统的输入是来自摄像头的视频流,输出是多个与运动相关的度量,包括重复计数。
步骤总结
GymCam:主要步骤。图片由作者提供。
-
使用密集的 光流 算法检测视频中的所有潜在运动轨迹。 运动轨迹也可能是非运动活动的结果,例如热身、用户的步态、到处走动等。
-
检测所有 运动 轨迹在场景中。 他们是怎么做到的?首先,他们执行一个特征提取步骤,从任何轨迹的 5 秒窗口中提取手工特征。他们使用一个MLP 基于的二分类器模型,该模型输入特征并输出该输入轨迹(特征)是否与运动相关的概率。
-
聚类 运动轨迹在空间和时间上的练习。 聚类后,通过将所有属于给定聚类的轨迹合并生成一个平均运动轨迹。请注意,聚类的数量是预定义的。这些平均轨迹随后用于计数和跟踪练习的重复次数。
-
重复计数 和 运动识别: 平均轨迹然后被转换为特征向量,这些特征向量被输入到 MLP 回归模型和 MLP 分类模型中,以分别推断重复计数和运动标签。
从组合轨迹中进行重复计数和练习识别。图片由作者提供。
该系统的一些显著特点是:它是一个端到端 系统,可以在现实世界环境中执行重复计数。光流 识别所有运动,因此即使用户几乎不可见,也足以跟踪运动并执行重复计数。
系统存在的问题:
-
多个用户重叠 在视频中进行运动。因此,很难确定这些用户的确切边界并推断重复次数。
-
噪声敏感: 噪声人类行为,如热身、休息、用户的步态等,可能会表现出周期性,因此可能对重复计数产生不必要的贡献。
-
重复计数不是通用的:该系统仅限于运动重复计数。
5. 基于 DL 的光流方法
博客: 使用深度学习和光流算法的运动计数应用 [5]
另一个有趣的思路是利用视觉解决重复计数问题,即光流方法。
主要步骤
- 使用密集光流算法查找重复活动中视频帧的颜色编码表示。 这里的关键在于不同的重复运动状态将有不同的颜色编码。
密集光流将向下运动编码为绿色,将向上运动编码为紫色。Gif 由作者提供。
关于光流算法的详细信息,请参见opencv文档 这里(以及实现)。
2. 数据集创建:下一步是生成一个带有颜色编码的图像和视频数据集,并将其标注为不同的运动状态(如向上或向下)。
3. 模型训练:下一步涉及训练一个普通的 CNN 模型,以对帧进行多类分类。在测试时,将来自光流的颜色编码帧输入到模型中,模型预测运动状态之一并捕获类别标签。这基本上是一个通过模型进行的颜色匹配问题,因为模型更为健壮。
该方法准确且容易在生产环境中部署。然而,缺点往往超过优点:
-
重复计数是 独占性 和 类别依赖性 的。
-
扩展问题:每次将新运动添加到语料库时,都需要标注数据集并训练模型。
-
方向敏感性:相同的动作在不同的方向上会有不同的颜色编码,从而导致错误的模型预测。这是该方法的主要限制之一。
-
噪声敏感性:背景中的任何轻微噪声都会改变这些颜色编码,从而影响模型的预测。
参考文献
[1]. Dwibedi, Debidatta 和 Aytar, Yusuf 和 Tompson, Jonathan 和 Sermanet, Pierre 和 Zisserman, Andrew. 计时:在实际环境中进行类无关的视频重复计数。IEEE/CVF 计算机视觉与模式识别会议(CVPR)。DOI: doi.org/10.48550/arxiv.1902.09868
[2]. Aakash Agrawal. 利用计算机视觉赢得交互性。 The .fit 方式。
[3]. Aakash Agrawal. 利用信号处理中的想法构建一个运动重复计数器。Towards Data Science。
[4]. Rushil Khurana, Karan Ahuja, Zac Yu, Jennifer Mankoff, Chris Harrison, 和 Mayank Goel. 2019. GymCam:在不受约束的场景中检测、识别和跟踪同时进行的运动。Proc. ACM Interact. Mob. Wearable Ubiquitous Technol. 2, 4, Article 185. DOI: doi.org/10.1145/3287063
[5]. Art Kulakov. 我如何使用深度学习和光流算法创建健身动作计数应用程序。Towards Data Science。
我希望你喜欢探索一些基于视觉的重复计数技术。大多数想法都很容易实现和部署。我希望了解任何阅读这篇博客的人的反馈。我很乐意回答有关上述任何概念的疑问/问题。欢迎反馈。你可以通过 Linkedin 联系我。
感谢阅读!
面向视觉的语义占用预测用于自动驾驶
关于 2023 年上半年学术“占用网络”的综述
·发布在Towards Data Science ·阅读时间 11 分钟·2023 年 5 月 29 日
--
现有的 3D 物体检测方法在自动驾驶中的一个关键痛点是它们通常输出简洁的 3D 边界框,忽略了更精细的几何细节,并且在处理一般的、超出词汇表的物体时表现不佳。这一痛点存在于单目 3D 物体检测和 BEV 多摄像头物体检测中。为了解决这个问题,占用网络作为一种面向视觉的通用障碍检测解决方案,首次在特斯拉在 CVPR 2022 的主题演讲中介绍,并在AI Day 2022中推广。更多细节请参见之前关于可行驶空间的博客系列。
截至 2023 年的行业应用最新趋势
在学术界,与占用网络相对应的感知轨道被称为 语义占用预测(SOP),有时也称为 语义场景完成(SSC),两者之间有一些细微的区别。语义占用预测为场景中的每个体素分配占用状态和语义标签。这是一种通用且足够表达的表示方式,可以描述已知类别但形状不规则或不在已知白名单中的物体。本文将回顾截至 2023 年初的语义占用预测的最新方法。这个领域在学术界受到广泛关注,顶级会议上提交了大量论文,今年的 CVPR 也有一个占用预测挑战。
语义场景完成 vs 语义占用预测
语义场景完成(SSC)的概念最初在 SSCNet 论文(CVPR 2017)中提出,后来由 SemanticKITTI(ICCV 2019)普及,该数据集提供了官方数据集和比赛赛道。最近,出现了一种略有不同的任务语义占用预测(SOP)。SSC 和 SOP 都旨在预测给定空间位置处体素的占用状态和语义类别,但存在一些细微差别。首先,SSC 的输入模式是由 LiDAR 或其他主动深度传感器收集的部分 3D 数据,因此称为 3D 语义场景的“完成”。SOP 使用 2D 图像,可能是多摄像头和多帧的。除此之外,SSC 通常关注静态场景,而 SOP 也可以处理动态物体。总之,SOP 似乎是更通用和更受欢迎的术语,在本文中,我们将交替使用语义占用预测、语义场景完成和占用网络。
MonoScene 的开创性工作首次使用单目图像在 SemanticKITTI 上进行语义占用预测任务。它仍然将任务称为 SSC,这可能是因为 SemanticKITTI 主要包含静态场景。后续研究更倾向于使用 SOP 这一术语,将任务扩展到 NuScenes 和 Waymo 等其他数据集,并处理动态物体。
开创性工作的高级总结
我将首先总结过去一年研究的爆炸性增长,然后跟进各种技术细节的总结。下面是一个总结待审阅工作的整体发展脉络的图示。值得注意的是,该领域仍在快速发展,并且尚未趋向于一个普遍接受的数据集和评估指标。
语义占用预测领域的发展时间线(来源:作者创作)
MonoScene(CVPR 2022),首次尝试视觉输入
MonoScene是首个仅使用RGB 图像作为输入来重建户外场景的工作,相较于以往研究中使用的激光雷达点云。它是一个单摄像头解决方案,专注于前摄像头的 SemanticKITTI 数据集。
MonoScene 的架构(来源:MonoScene)
论文提出了许多想法,但似乎只有一个设计选择至关重要——FLoSP(Feature Line of Sight Projection)。这个想法类似于沿视线传播特征的思想,也被OFT(BMVC 2019)或Lift-Splat-Shoot(ECCV 2020)所采用。其他新颖性如 Context Relation Prior 和直接优化度量的独特损失,根据消融研究看来似乎不那么有用。
VoxFormer (CVPR 2023),显著改进了 monoScene
VoxFormer的关键见解在于 SOP/SSC 必须同时解决两个问题:可见区域的场景重建和遮挡区域的场景幻觉。VoxFormer 提出了一种重建和密集化的方法。在第一次重建阶段,论文利用单目深度方法将 RGB 像素提升为伪 LiDAR 点云,然后将其体素化为初始查询提案。在第二次密集化阶段,这些稀疏的查询通过图像特征增强,并利用自注意力进行标签传播以生成密集预测。VoxFormer 在 SemanticKITTI 上的性能显著优于 MonoScene,且仍为单摄像头解决方案。图像特征增强架构大量借鉴了BEVFormer的变形注意力思想。
VoxFormer 的架构(来源:VoxFormer)
TPVFormer (CVPR 2023),首次多摄像头尝试
TPVFormer是首个将 3D 语义占用预测推广到多摄像头设置的工作,并将 SOP/SSC 的理念从 SemanticKITTI 扩展到 NuScenes。
TPVFormer 的架构(来源:TPVFormer)
TPVFormer 将 BEV 的理念扩展到三个正交轴。这允许在不压制任何轴的情况下建模 3D,避免了立方复杂度。具体来说,TPVFormer 提出了两个步骤的注意力来生成 TPV 特征。首先,它使用图像交叉注意力(ICA)来获取 TPV 特征。这本质上借鉴了BEVFormer的思想,并扩展到其他两个正交方向以形成 TriPlane View 特征。然后,它使用交叉视图混合注意力(CVHA)通过关注其他两个方向来增强每个 TPV 特征。
预测比 TPVFormer 中的监督更密集,但仍然存在间隙和孔洞(来源:TPVFormer)
TPVFormer 使用了来自普通 NuScenes 数据集的稀疏激光雷达点进行监督,没有任何多帧密集化或重建。尽管在训练时进行的监督是稀疏的,但它声称模型能够在推理时预测所有体素的更密集和一致的体积占据。然而,这种密集的预测仍然不如后来的研究,如使用密集 NuScenes 数据集的 SurroundOcc 那样密集。
SurroundOcc(Arxiv 2023/03)和 OpenOccupancy(Arxiv 2023/03),是首次尝试密集标签监督
SurroundOcc认为密集预测需要密集标签。论文成功证明了更密集的标签可以显著提高之前方法(如 TPVFormer)的性能,提升近 3 倍。其最大贡献是生成密集占据真实数据的管道,无需昂贵的人为标注。
SurroundOcc 的 GT 生成管道(来源:SurroundOcc)
密集占据标签的生成涉及两个步骤:多帧数据汇总和密集化。首先,将动态物体和静态场景的多帧激光雷达点分别拼接在一起。累计的数据比单帧测量更密集,但仍然存在许多孔洞,需要进一步密集化。密集化是通过三角网格的泊松表面重建和最近邻(NN)将标签传播到新填充的体素上来完成的。
OpenOccupancy与 SurroundOcc 同时期且在精神上相似。与 SurroundOcc 一样,OpenOccupancy 也使用了一个管道,该管道首先分别汇总动态物体和静态场景的多帧激光雷达测量数据。为了进一步密集化,OpenOccupancy 采用了Augment-and-Purify(AAP)方法,而不是SurroundOcc所采用的泊松重建方法。具体来说,一个基线模型使用汇总的原始标签进行训练,然后用其预测结果与原始标签融合以生成更密集的标签(即“增强”)。这种更密集的标签大约密集 2 倍,并由人工标注者进行手动精炼(即“净化”)。总共投入了 4000 小时的人工时间来精炼 nuScenes 的标签,大约每 20 秒的片段需要 4 小时的人工时间。
SurroundOcc 的架构(来源:SurroundOcc)
CONet 的架构(来源:OpenOccupancy)
与密集标签生成管道的贡献相比,SurroundOcc 和 OpenOccupancy 的网络架构不如创新。SurroundOcc 主要基于BEVFormer,通过粗到细的步骤来增强 3D 特征。OpenOccupancy 提出了 CONet(级联占用网络),其方法类似于Lift-Splat-Shoot,将 2D 特征提升到 3D,然后通过级联方案增强 3D 特征。
Occ3D(Arxiv 2023/04),首次尝试遮挡推理
Occ3D还提出了一种生成密集占用标签的管道,包括点云聚合、点标注和遮挡处理。这是首篇明确处理密集标签的可见性和遮挡推理的论文。可见性和遮挡推理对于 SOP 模型的车载部署至关重要。在训练过程中,必须对遮挡和可见性进行特殊处理,以避免对不可观察场景的过度幻想导致假阳性。
值得注意的是,激光雷达的可见性与相机的可见性不同。激光雷达的可见性描述了完整性,因为有些体素在多帧数据聚合后仍不可观察。这在整个序列中是一致的。同时,相机的可见性关注于检测的可能性,而且在每个时间戳上不同。评估只在激光雷达和相机视图中的“可见”体素上进行。
在密集标签的准备过程中,Occ3D 仅依赖于多帧数据聚合,没有像 SurroundOcc 和 OpenOccupancy 那样的第二阶段密集化。作者声称,对于 Waymo 数据集,标签在没有密集化的情况下已经相当密集。对于 nuScenes,尽管点云聚合后标注仍有孔洞,但 Poisson 重建导致结果不准确,因此没有执行密集化步骤。也许 OpenOccupancy 的 Augment-and-Purify 方法在这种情况下更为实用。
CTF-Occ 在 Occ3D 中的架构(来源:Occ3D)
Occ3D 还提出了一种神经网络架构:粗到细占用(CTF-Occ)。粗到细的思想与 OpenOccupancy 和 SurroundOcc 中的基本相同。CTF-Occ 提出了增量标记选择以减少计算负担。它还提出了一种隐式解码器,用于输出任何给定点的语义标签,类似于占用网络的思想。
技术细节比较
上述关于语义占用预测的研究总结在下表中,包括网络架构、训练损失、评估指标以及检测范围和分辨率。
最近的语义占用预测论文技术细节的比较截至 2023 年 4 月(来源:作者创建,文本版)
网络架构
大多数研究基于已验证的 BEV 感知最先进方法,如BEVFormer和Lift, Splat, Shoot。架构主要分为两个阶段:2D 到 3D 特征提升和 3D 特征增强。有关更详细的总结,请参见上表。架构似乎已基本趋于一致。最重要的是密集占用注释生成管道,以及训练过程中的密集监督。
以下是SurroundOcc、OpenOccupancy和Occ3D中生成密集占用标签的自动标记管道的总结。
SurroundOcc、OpenOccupancy和Occ3D中密集标签管道的总结(来源:作者创建)
训练损失
语义占用预测任务与语义分割非常相似,SOP 需要对 3D 空间中的每个体素预测一个语义标签,而语义分割需要对每个测量点(无论是图像中的像素,还是激光雷达扫描中的 3D 点)进行预测。语义分割的主要损失函数是交叉熵损失和Lovasz 损失。Lovasz 扩展使得神经网络可以直接优化平均交集-并集(IoU)指标。
可能受到 Lovasz 的启发,monoScene 提出了几种其他损失函数,可以直接优化评估指标。然而,它们似乎较为深奥,并未完全通过消融研究得到支持。
评估指标
主要指标是几何占用预测的 IoU(一个体素是否被占用)和语义分类的 mIoU(平均 IoU)。这些指标可能不适合工业应用。
基于视觉的 SOP 任务需要成熟以满足工业使用并取代激光雷达。虽然 IoU 指标中计算了精度和召回率,但精度在 ADAS(高级驾驶员辅助系统)应用中始终更为重要,以避免虚假刹车,只要我们还有司机在驾驶。
检测范围和分辨率
目前所有跟踪算法在自车周围预测 50 米的范围。体素分辨率从 SemanticKITTI 的 0.2 米到 NuScenes 和 Waymo 数据集的 0.4 米或 0.5 米不等。这是一个良好的起点,但可能仍然不适用于工业应用。
一个更合理的分辨率和范围可能是 50 米范围内的 0.2 米,以及 50 米到 100 米范围内的 0.4 米。
相关任务
有两个与 SOP 相关的任务,即周围深度图和激光雷达语义分割,我们将在下面简要回顾。
周围深度图预测任务(例如 FSM 和 SurroundDepth)扩展了单目深度预测,并利用重叠摄像头视场中的一致性来进一步提升性能。它更多关注测量源,通过为图像中的每个像素赋予深度值(自下而上),而 SOP 则更多关注 BEV 空间中的应用目标(自上而下)。Lift-Splat-Shoot 与 BEVFormer 在 BEV 感知中的类似情况也是如此,前者是自下而上的方法,后者是自上而下的。
激光雷达语义分割专注于为激光雷达扫描中的每个点云分配语义类别标签。现实世界中的 3D 感知本质上是稀疏和不完整的。为了全面的语义理解,单纯解析稀疏测量而忽略未观察到的场景结构是不够的。
主要收获
-
语义占用预测中的神经网络架构似乎已经基本收敛。最重要的是自动标签管道,以生成密集的占用标签和训练过程中的密集监督。
-
当前常用数据集采用的检测范围和体素分辨率对于工业应用来说是不够的。我们需要更大的检测范围(例如 100 米)和更细的分辨率(例如 0.2 米)。
-
目前的评估指标对于工业应用也显得不够。对于 ADAS 应用,精确度比召回率更为重要,以避免频繁的虚假刹车。
-
语义占用预测的未来方向可能包括场景流估计。这将有助于预测未知障碍物的未来轨迹,并在自车轨迹规划过程中进行碰撞避免。
注:本博客文章中的所有图片均由作者创作,或来自公开的学术论文。详情请参见图注。
参考文献
-
SSCNet: 从单幅深度图像进行语义场景补全,CVPR 2017
-
SemanticKITTI: 用于激光雷达序列语义场景理解的数据集,ICCV 2019
-
MonoScene: 单目 3D 语义场景补全,CVPR 2022
-
VoxFormer: 基于相机的稀疏体素变换器用于 3D 语义场景补全,CVPR 2023
-
TPVFormer: 三视角视图用于基于视觉的 3D 语义占用预测,CVPR 2023
-
Occ3D: 自主驾驶的大规模 3D 占用预测基准,Arxiv 2023/04
-
SurroundOcc: 自主驾驶的多摄像头 3D 占用预测,Arxiv 2023/03
-
OpenOccupancy: 周围语义占用感知的大规模基准,Arxiv 2023/03
-
SimpleOccupancy: 自主驾驶中 3D 占用估计的简单尝试,Arxiv 2023/03
-
OccFormer: 基于视觉的 3D 语义占用预测的双路径变换器,Arxiv 2023/04
-
BEVFormer: 通过时空变换器从多摄像头图像学习鸟瞰视图表示,ECCV 2022
-
FSM: 来自多摄像头的完整周围单目深度,ICRA 2021
-
SurroundDepth: 通过周围视图纠缠进行自监督多摄像头深度估计,CoRL 2022
-
Lift, Splat, Shoot (LSS): 通过隐式反投影到 3D 来编码来自任意相机设备的图像,ECCV 2020
-
OFT: 单目 3D 物体检测的正射特征变换,BMVC 2019
使用冻结的大型语言模型进行视觉问答
原文:
towardsdatascience.com/visual-question-answering-with-frozen-large-language-models-353d42791054
与大型语言模型讨论图像,而无需在图像上训练大型语言模型。
·发表于 Towards Data Science ·18 分钟阅读·2023 年 10 月 9 日
--
“弥合模态”,由 MidJourney 制作。除非另有说明,否则所有图像均由作者提供。
在这篇文章中,我们将使用 Q-Former,一种弥合计算机视觉和自然语言模型的技术,来创建一个视觉问答系统。我们将讨论必要的理论,参照 BLIP-2 论文,然后实现一个可以与大型语言模型讨论图像的系统。
我们将构建的内容
谁会觉得这篇文章有用? 对计算机视觉、自然语言处理和多模态建模感兴趣的数据科学家。
这篇文章的难度如何? 中级。如果你没有一些计算机视觉和自然语言处理的经验,你可能会觉得有些困难。
前提条件: 对变换器、嵌入和编码器-解码器有较高的熟悉度。所有这些主题在以下文章中都有涵盖:
Transformers — 直观且全面的解释
探索现代机器学习的浪潮:一步一步拆解变换器
towardsdatascience.com](/transformers-intuitively-and-exhaustively-explained-58a5c5df8dbb?source=post_page-----353d42791054--------------------------------)
视觉语言建模的简要时间线
视觉语言建模真正起步于 2016 年,随着论文 VQA: Visual Question Answering 的发布,该论文正式提出了以下问题类别:
给定一张图片和关于这张图片的自然语言问题,任务是提供一个准确的自然语言回答——VQA: 视觉问答
在 2016 年,当 VQA(视觉问答)被普及时,典型的方法看起来是这样的:
2016 年的一个 VQA 模型使用 LSTM 将问题嵌入向量,使用现有的计算机视觉网络将图像嵌入向量,然后通过一个密集层在正确的输出选择中考虑这两个向量。来自VQA: 视觉问答。
在 VQA 的早期阶段,从头开始训练视觉和语言组件,将输出传递到密集网络,并选择 n 个可能输出中的一个作为响应是合适的。
随着视觉和语言模型变得更强大,视觉问答让位于视觉语言建模(VLM),它通常被认为是对视觉问答的扩展。现代视觉语言模型不仅可以回答“这张图片中有车吗”这样简单的问题,还可以询问图片中是什么类型的车,然后询问这辆车的驾驶情况、这辆车出演过的最受欢迎的电影等。
这是视觉语言建模实际应用的一个例子。这个特定的例子来自BLIP-2 论文,我们将用它作为本帖的参考。
从 VQA 到 VLM 的转变在很大程度上是由于将大型语言模型融入视觉系统,提供了复杂的推理能力和开箱即用的百科知识。
视觉语言建模的难点是,一直以来都是多模态。你需要擅长图像、自然语言,并且需要让它们良好地协同工作。随着视觉和语言模型的不断扩大,用于将它们结合起来进行视觉语言建模的系统变得越来越复杂。
这带来了实际问题。大型语言模型是巨大的,因此更新它们的参数以学习一些新任务是昂贵的(比如,几千到几百万美元的费用)。此外,当在完全新模式的数据上训练模型时,该模型通常会灾难性遗忘;这是一个术语,用于描述模型在调整到新用例时遗忘关键信息的情况。如果你随意将图像编码器和大型语言模型结合起来,你可能会得到一个对图像和文本理解都很差的模型。
BLIP-2 论文提出了 Q-Former 来解决灾难性遗忘问题,并且通过利用现有模型来降低成本。
Q-Former 的简要介绍
如果你想在一个周末从零开始构建一个 VQA 系统,你可以考虑以下方法:
-
通过标题生成器处理你想讨论的图像
-
将用户提出的问题和生成的标题结合起来,使用一些模板创建 LLM 的提示。
-
将该提示传递给 LLM,LLM 将返回最终输出。
一个天真的方法成功的流程图。用户提出了一个能够从生成的标题中回答的问题。
如果你问的是关于图像主题的简单问题,这种方法可能会奏效,但如果你有更晦涩的问题,你可能就会失望了。
一个天真的方法未成功的流程图。用户提出了一个无法从生成的标题中回答的问题。
Q-Former 作为查询 transformer(因此得名)用于根据图像转换用户的查询。其目的是根据用户的提示从图像中提取正确的信息,并将其提供给 LLM。
Q-Former 所做的概念图。它利用提示和图像来构建 LLM 的输入。实际上,Q-Former 并不生成文本,而是生成高维嵌入,但这就是概念的本质。
BLIP-2 架构
在我们真正深入之前,先对其进行高层次的了解。
Q-Former 的稍微准确一点的描述,以及周围的组件。图像编码器将图像嵌入其最重要的部分,文本编码器对用户的提示做同样的处理,而 Q-Former 将它们结合起来,为 LLM 创建输入。
BLIP-2 架构中存在 Q-Former 的组件如下:
-
图像编码器: 一个预训练模型,将图像嵌入到抽象表示中,从而使图像分类等任务变得更容易。实际上,你可以将其看作是提取图像的重要内容。一个流行的例子是 CLIP。
-
文本编码器: 一个预训练模型,将文本嵌入到抽象表示中。这些模型通常将单词视为高维空间中的点,相似的单词会位于该空间中的相似点。一个流行的例子是 Word2Vec。
-
LLM: 一个经过训练的大型语言模型,用于执行通用语言任务。类似于聊天 GPT。
-
Q-Former: 一个将嵌入的图像和嵌入的提示结合成与 LLM 兼容的格式的 transformer 模型。Q-Former 的主要工作是正确地上下文化这两种输入,并以有利于文本生成的方式将其提供给 LLM。
由于 Q-Former 的灵活性,可以在 BLIP-2 中使用不同的编码器和 LLM。我在这篇文章中不会深入探讨这些内容,但我将很快撰写一篇关于 CLIP 图像编码的文章,并且我有一篇关于 LLM 和文本嵌入的文章,可能对不熟悉的人有所帮助:
探索现代机器学习的浪潮:逐步拆解 transformer
Q-Former,简而言之
Q-Former 的高层次概念图
首先,需要对注意力有一个整体理解,因为它构成了 Q-Former 架构的主要部分。我在这篇文章中直观且详尽地讲解了注意力机制,但基本上,注意力机制会生成修改过的输入副本,然后将这些副本混合在一起。
如果我们将文本输入“What Color is the Background”通过自注意力机制处理,那么句子中每个词的向量将与其他每个词的向量结合。这将产生一个抽象矩阵,其中包含输入中所有词的上下文信息。
多头自注意力,简而言之。这一机制在数学上将不同输入(在这个例子中是词语)的向量结合起来,生成一个矩阵,该矩阵编码了整个输入的更深层次的含义。在这篇文章中直观且详尽地解释了这一点。
即使你对注意力机制有所了解,可能仍不清楚为什么自注意力块会被分成两半。实际上,Q-Former 中的两个自注意力块实际上是一个。自注意力机制左侧的输入可以与自注意力机制右侧的输入完全互动,反之亦然。这个划分不是基于模型的工作原理,而是基于模型的训练方式。我们将在下一节中详细讨论,但要点是:由于 Q-Former 的训练方式,自注意力块擅长处理仅图像、仅文本以及同时处理这两者。因此,它有点像两个注意力块,但实际上是一个大的注意力块。
一个概念图展示了 Q-Former 中的自注意力机制如何既隔离文本和图像表示,又促进它们的交互。这是在训练的引导阶段完成的,我们将在下一节中讨论。
图表左下角学习到的标记本质上是模型在第一个自注意力块中使用的学习常量。我们稍后会详细讨论它们,但简而言之,我喜欢从两种方式来理解它们:
-
如果你从文本的自注意力角度考虑它们,它们在决定文本如何最初介绍给图像。
-
如果你从与图像互动的角度考虑它们,它们作为初始化,最终会通过图像进行修改,最终成为模型的提示。
此外,如本节的第一张图片所示,有虚线递归连接将两个前馈网络的输出连接回输入。整个区域被描绘为黄色的是一个 Q-Former 块。这些块堆叠在一起以创建完整的 Q-Former。
这就是所有组件,这可能会让人惊讶。仅仅从组件来看,并不明显为什么 Q-Former 在桥接图像和文本方面表现得特别好。要理解这一点,你需要了解 Q-Former 是如何训练的。
Q-Former 的训练方式
Q-Former 的训练可以分为两个阶段:引导和生成学习预训练。引导阶段可以进一步分为三个子阶段。我们将逐步探讨所有这些。
这些训练阶段的命名可能会有些困惑。什么是“引导”?为什么会有预训练步骤却没有“训练步骤”?我认为这些阶段的命名是以下定义的结果:
-
引导是指使用可能不完全适合最终用例的数据,以将模型从随机初始化状态提升到在相关任务中表现良好的状态的过程。
-
预训练是指使用大量数据将模型调整到一个普遍良好的状态,为最终任务做好准备。
-
微调是指将预训练模型进行调整,并提供少量任务特定数据,以优化其最终建模任务。
BLIP-2 的核心优势之一是零样本性能。BLIP-2 承诺能够在没有针对 VQA 数据集进行微调的情况下,出色地完成诸如视觉问答等任务。它使用带有图片说明的图像数据集(这些说明解释了图像的内容)来进行引导和预训练,但从未实际对 VQA 进行微调。
引导
引导阶段旨在鼓励模型在需要理解文本和图像的各种任务中表现良好。就像我的自监督学习帖子一样,你可以将其视为一种“游戏”,模型学习为视觉问答的最终任务做准备。
引导阶段有三个子阶段。这些将在后续部分中探讨,但总的来说:
-
图像-文本对比学习:模型学习如何将属于一组的图像-字幕对分在一起,并通过 对比学习 将不属于一组的图像-字幕对分开。
-
图像基础的文本生成: 将字幕分成两部分,隐藏的和未隐藏的部分,并尝试根据未隐藏的部分和图像来猜测隐藏的部分。
-
图像-文本匹配: 将 Q-Former 的输出传入 牺牲密集网络,该网络将输出转换为二元分类,然后使用此二元分类来决定一个字幕是否属于某个图像。
图像-文本对比学习
图像-文本对比学习的实际操作。图像侧的所有向量(即 LLM 的输入)都与文本侧的类标记进行比较。在这个例子中,相似度很高,因为图像和文本匹配。我们希望,如果图像和文本不匹配,最大相似度分数会很低。
在这种自举模式中,Q-Former 中的自注意力机制被分为两部分。这是通过应用于注意力层的掩码完成的,称为“单模态自注意力掩码”。这是一个复杂的短语,但表示一个简单的概念:
在自注意力机制中,每次文本侧与图像侧互动时,只需将值设置为零。
这实际上阻止了 Q-Former 图像侧和文本侧之间的所有通信。
这种训练方法还使用了一个特殊的标记,称为“类”标记。这个想法受到 BERT 的启发,这是一个我将来会详细介绍的里程碑。基本上,你有一个任意的标记,让模型知道“嘿,我们现在在进行图像-文本对比学习”。然后,你会忽略文本侧除了类标记之外的任何其他输出,并使用它来计算损失。因此,模型知道当“类”标记出现时,它是特殊的,并会尝试学习如何操控左侧的视觉和右侧的文本,以最大化对比损失的性能。
对比损失,本质上,是将匹配对拉近而将不匹配对拉远的任务。 我有一篇更详细地介绍对比损失的文章,但本质上,对比损失会查看一组图像及其标题,并尝试让模型学习哪些图像属于哪些标题。在我们的案例中,这通过计算两侧向量的相似度并找到最大相似度值来实现。匹配的文本和图像对应该具有较大的相似度分数,而不匹配的对应该具有较小的相似度分数。
通过执行这种自举操作,我们鼓励模型学习对齐图像和文本。换句话说,我们正在训练模型来学习哪些图像与某段文本相关,哪些则不相关。
单模态自注意力的概念图,这是在此训练阶段使用的掩码策略。请注意图像端和文本端都有完全的注意力,但两者之间没有注意力。
图像基础文本生成
“图像基础文本生成”预训练步骤的概念图。在这里,Q-Former 的一个输入文本部分被隐藏,Q-Former 的任务是尝试填补隐藏的文本。
在这种自举模式下,我们要求 Q-Former 完成一个部分隐藏的标题。我们应用了“多模态因果自注意力掩码”,允许 Q-Former 的文本端与图像端交互,但隐藏了部分需要由 Q-Former 预测的标题。我们还将“类别”标记替换为“解码器”标记,以让模型知道它应该执行什么任务。
多模态因果自注意力的概念图;这是在此训练阶段使用的掩码策略。请注意,除了那些模型应输出的标记外,所有标记之间都允许完全的注意力。
图像-文本匹配
“图像-文本匹配”预训练步骤的概念图。在这个阶段没有使用掩码,允许所有文本和图像标记在自注意力机制中进行交互。请注意,输出是错误的,因为图像与文本“A Painting of a Monster Truck”不兼容。
在这种预训练模式下,我们创建了一个临时线性分类器(一个密集网络),并将 Q-Former 的所有输出标记输入其中。这个线性分类器将标记投影为“真实”或“虚假”预测,用于训练模型预测输入文本是否与输入图像匹配。不同的组合对,无论是匹配还是不匹配的组合,都被输入到模型中。
我在这篇文章中讨论了使用密集网络来投影模型输出用于某些预训练任务的概念。本质上,线性分类器用于训练,但在推理时被丢弃,这有助于模型学习文本和图像的一般表示,但有助于防止模型在任务中过于专业化;如此专业化,以至于在将标记传递给 LLM 的实际任务中表现较差。
你可以把 Q-Former 想象成“理解文本和图像”的部分,而临时线性分类器则是“将这种理解转化为是或否的回答”的部分。在这一步之后,我们会丢弃“将这种理解转化为是或否的回答”的部分,保留一般的文本和图像理解。
从自举中得到什么
在上一节中,我们讨论了自举的三个阶段:图像-文本对比学习、图像基础文本生成和图像-文本匹配。通过优化 Q-Former 以完成这些各种任务,Q-Former 被鼓励建立图像和文本的强表示,并建立一个强大的系统来关联这两者。
关于学习到的标记的说明
如前所述,学习到的标记(在 BLIP-2 论文中称为“查询向量”)与图像和文本交互以提取关键信息。为了进一步阐述这一点,我想分享 BLIP-2 论文中关于查询向量的以下引文:
关于查询向量的一般情况:
查询通过自注意力层相互作用,并通过交叉注意力层(插入每隔一个变换器块)与冻结的图像特征互动。查询还可以通过相同的自注意力层与文本交互。
关于自举阶段如何与查询向量相关:
我们的目标是训练 Q-Former,使查询能够学习提取对文本最具信息性的视觉表示。
关于查询向量如何关联文本和图像信息:
由于 Q-Former 的架构不允许冻结的图像编码器和文本标记之间直接交互,因此生成文本所需的信息必须首先由查询提取,然后通过自注意力层传递给文本标记。因此,查询被迫提取能够捕捉文本所有信息的视觉特征。
预训练
现在我们有了一个在文本和图像上具有良好内部表示的 Q-Former,我们可以将其连接到 LLM 上,并使用它来训练 Q-Former。
BLIP-2 的生成预训练图
我们可以将图像的描述分成两部分,一个前缀和一个后缀。我们可以通过整个 BLIP-2 架构传递前缀,并修改 Q-Former 的权重以鼓励 LLM 的输出为后缀。从概念上讲,这使图像和文本的表示在 Q-Former 中与特定 LLM 模型的需求对齐。
理论总结
太好了,现在我们了解了 BLIP-2 架构;组件、Q-Former(其核心组件)如何工作以及如何进行训练。在下一节中,我们将使用预训练的 BLIP-2 进行图像描述、VQA,甚至进行小型图像基础对话。
整个BLIP-2架构
在BLIP-2 论文中定义的 BLIP-2 架构与 Q-Former 的关系。这里我们可以看到 Q-Former 的输入、学习到的查询、三种训练引导策略和支持这些策略的三种掩码。一个微妙的说明:BLIP-2 每隔一个块就将图像暴露给 Q-Former。我邀请你自行思考这样做的理由。对我来说,这感觉像是一种残差连接,鼓励对图像进行反复且越来越复杂的分析。没有图像访问权限的块可以形成复杂的关系,然后在随后的块中与图像进行比较。
使用 Hugging Face 的 Q-Former 进行 VQA
在未来的帖子中,我会从零开始编写和训练一个 Q-Former,但我认为你会同意这篇帖子已经足够长了。现在让我们使用预构建的解决方案来实验 Q-Formers。
完整的笔记本可以在这里查看:
[## MLWritingAndResearch/VQAWithQFormer.ipynb at main · DanielWarfield1/MLWritingAndResearch
用于机器学习写作和研究的笔记本示例 - MLWritingAndResearch/VQAWithQFormer.ipynb at main ·…
由 SalesForce 内的机器学习团队 LAVIS(发布 BLIP-2 论文的团队)慷慨地提供了 Hugging Face 上的端到端预训练解决方案:
"""Downloading the BLIP-2 Architecture
loading as an 8 bit integer to save on GPU memory. This may have some impact on
performance.
"""
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map="auto", load_in_8bit=True) # load in int8
如你所见,BLIP-2 包含两个部分;一个处理器和一个模型。首先,让我们探讨处理器。
处理器
在 HuggingFace 提供的示例中,处理器用于在将输入(包括文本和图像)传递给 BLIP-2 之前进行预处理。让我们加载一张图像,生成一些文本,并将其传递给处理器,看看结果如何。
"""Loading and displaying a sample image
"""
import requests
from PIL import Image
url = 'https://github.com/DanielWarfield1/MLWritingAndResearch/blob/main/Assets/Images/pexels-thuany-marcante-1805053.jpg?raw=true'
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
print(f'Image dimensions: {image.width}px X {image.height}px')
dsfact = 15
display(image.resize((int(image.width/dsfact), int(image.height/dsfact))))
我们将使用的一个示例图像
将这张图片和一些示例文本传递给处理器后,我们可以看到从处理器获得了一个字典:
"""Exploring the outputs of the processor
"""
processor_result = processor(image, text='a prompt from the user about the image', return_tensors="pt").to("cpu", torch.float16)
processor_result.keys()
从处理器获得的 pixel_values 是将图像缩小到 224 x 224 的变换,颜色值被标准化到适合建模的范围。
"""Understanding resolution and plotting one of the color channels
"""
import matplotlib.pyplot as plt
#printing the processed image shape
print(f'processed image shape: {processor_result["pixel_values"].numpy().shape}')
#extracting one of the color channels from the processed image
print('single color channel:')
processed_im_c0 = processor_result['pixel_values'].numpy()[0,0]
#rendering
plt.imshow(processed_im_c0, interpolation='nearest')
plt.show()
"""Understanding the distribution of values allong each color channel,
both in the processed image and in the original image
"""
bins = 100
#extracting all color channels from the processed image
processed_im_c1 = processor_result['pixel_values'].numpy()[0,1]
processed_im_c2 = processor_result['pixel_values'].numpy()[0,2]
#plotting modified pixel value distributions
plt.figure()
plt.hist([processed_im_c0.flatten(),
processed_im_c1.flatten(),
processed_im_c2.flatten()], bins, stacked=True, density = True)
plt.title('processed image value distribution')
plt.show()
#plotting original pixel value distributions
import numpy as np
image_np = np.array(image)
plt.figure()
plt.hist([image_np[:,:,0].flatten(),
image_np[:,:,1].flatten(),
image_np[:,:,2].flatten()], bins, stacked=True, density = True)
plt.title('raw image value distribution')
plt.show()
处理过的图像的值被缩减到一个合理的范围,并且它们似乎平均接近于零
原始图像的值具有偏倚分布,覆盖了更宽的值范围
从处理器获得的 input_ids 是词片索引。句子的各个部分被分配了单独的索引,这些索引后来被用于词向量嵌入器,然后应用于 BLIP-2。
"""Exploring the input_ids from the processor given a variety of prompts
"""
print('input_ids for "a short prompt":')
sampres = processor(image, text='a short prompt', return_tensors="pt").to("cpu", torch.float16)
print(sampres['input_ids'])
print('input_ids for "a much much much much longer prompt":')
sampres = processor(image, text='a much much much much longer prompt', return_tensors="pt").to("cpu", torch.float16)
print(sampres['input_ids'])
print('input_ids for "alongcompoundword":')
sampres = processor(image, text='alongcompoundword', return_tensors="pt").to("cpu", torch.float16)
print(sampres['input_ids'])
由于我们是在推断模型,因此处理器提供的掩码只是全 1,允许模型看到所有输入值。
"""Understanding the mask from the processor
"""
print('input_ids for "a short prompt":')
sampres = processor(image, text='a short prompt', return_tensors="pt").to("cpu", torch.float16)
print(sampres['input_ids'])
print('mask for "a short prompt":')
print(sampres['attention_mask'])
调用模型
现在我们对处理器的功能有了初步了解,我们可以开始使用它将数据传递给 BLIP-2,并开始生成输出。
图像标题生成: 如果你提供图像而没有文本,BLIP-2 将为图像生成标题。
"""Getting BLIP-2 to describe the image, unprompted
this is done by only passing the image, not the text
"""
inputs = processor(image, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)
提示图像标题生成: 如果你为标题提供一个前缀,BLIP-2 将尝试完成这个标题。
"""Prompted caption example 1
"""
prompt = "this is a picture of"
inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)
完成提示“这是一张……的图片”的结果
"""Prompted caption example 2
"""
prompt = "the weather looks"
inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)
完成提示“天气看起来……”的结果
视觉问答: 通过调用格式化特殊查询的 BLIP-2,可以实现视觉问答,而无需经过视觉问答数据的训练。
prompt = "Question: what season is it? Answer:"
inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)
基于视觉的对话: 我们可以将提示格式化成类似对话的内容,从而与模型就图像进行对话。
"""Visual Conversation
conversing about the image
"""
#imagine these are generated by a person as a response to output, rather than pre-defined.
questions = [
"What's in this photo?",
"What is vernacular architecture?"
]
#defining the state of the conversation as it progresses, to be passed to the model
conv_state = ''
#asking all questions in order
for question in questions:
#updating the conversational state with the question
conv_state = conv_state+' Question: ' + question + ' Answer: '
#passing the state thus far to the model
inputs = processor(image, text=conv_state, return_tensors="pt").to(device, torch.float16)
#generating a response
generated_ids = model.generate(**inputs, max_new_tokens=40)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
#rendering conversation
print('Question: '+question)
print('Answer: ' + generated_text)
#updating the conversational state with the answer
conv_state = conv_state + generated_text + '\n'
结论
在这篇文章中,我们回顾了多模态图像和语言建模的发展历史;从最初的视觉问答,到现代阶段使用大型语言模型和图像编码器。我们描述了一种将图像暴露给大型语言模型的方法,即 BLIP-2 架构,并描述了其最重要组件 Q-Former 的内部工作原理。接着,我们探索了 BLIP-2 在标题生成、视觉问答和基于视觉的对话中的实际应用。
关注以获取更多信息!
我描述机器学习领域的论文和概念,重点是提供实用和直观的解释。我计划在未来的文章中从零开始实现一个 Q-Former。
归属声明: 本文档中的所有图片均由丹尼尔·沃菲尔德创作,除非另有来源说明。您可以将此帖中的任何图片用于非商业目的,只要您注明此文章,danielwarfield.dev
,或两者皆可。
使用 Python 可视化卫星图像的 RGB 通道
原文:
towardsdatascience.com/visualising-the-rgb-channels-of-satellite-images-with-python-6d541af1f98d
在可视化卫星图像时,如何处理多个光谱波段、大像素值和倾斜的 RGB 通道
·发表于 Towards Data Science ·阅读时长 6 分钟·2023 年 4 月 11 日
--
(来源:SWED)
卫星图像包含大量信息。缺点是它们的可视化并不简单。与普通图像不同,卫星图像可能具有:
-
超过 12 个通道
-
大像素值
-
倾斜的像素值
我们将讨论这些关键考虑因素。然后,我们将这些因素考虑到一个 Python 函数中,这样你在组合 RGB 通道时就能拥有更大的灵活性。具体来说,它允许你调整图像的亮度和色调。代码已给出,完整项目可以在 GitHub 上找到。
导入和数据集
我们的第一个导入是 地理空间数据抽象库(gdal)。这在处理遥感数据时非常有用。我们还引入了更多标准的 Python 包(第 4-5 行)。最后,glob 用于处理文件路径(第 7 行)。
# Imports
from osgeo import gdal
import numpy as np
import matplotlib.pyplot as plt
import glob
我们将可视化来自测试集的图像 Sentinel-2 水边数据集 (SWED)。该数据集包含来自 49 个位置的 98 张海岸线图像。我们将在下方加载这些图像的所有路径。
#Load paths
paths = glob.glob("../data/SWED/test/images/*")
关键考虑因素
光谱波段数量
我们加载第一张图像(第 2 行),并输出其形状(第 3 行)。这给出的值为(12,256,256)。换句话说,每张图像由 12 个通道的 256x256 像素组成。如果这是一个普通图像,我们会期望其维度为(256,256,3)。也就是说,每个像素有 3 个通道——RGB。
#Load first image
img = gdal.Open(paths[0]).ReadAsArray()
img.shape #(12,256,256)
12 个通道是哨兵-2 光谱带。每个光谱带测量的是从地球表面反射的不同波长的电磁辐射。这包括人眼无法看到的红外光。在可视化图像时,我们只对可见光光谱带感兴趣。
对于我们的特定图像,显光带位于位置 3(红色)、2(绿色)和 1(蓝色)。我们在下面选择这些。我们还转置了数组,使其维度为(256,256,3)。
#Get RGB image
rgb = img[[3,2,1]].transpose(1,2,0)
大像素值
为了捕捉尽可能多的细节,卫星图像具有大的动态范围。因此,像素值不会落入典型的[0, 255]范围内。对于我们的图像,打印最小和最大像素值给出的是150 和 8600。
# Pixel range
print(rgb.min(),rgb.max()) #150 8,600
哨兵-2 图像具有最大反射值为 10000。虽然像素值有时会高于此值,但在可视化 RGB 通道时,我们可以忽略这些大值。因此,我们通过将图像除以 10000 并剪裁到 0 和 1 之间来缩放图像。这确保所有像素值都在 0 和 1 之间。
#Scale image
rgb = np.clip(rgb/10000,0,1)
偏向的像素值
现在我们可以使用 matplotlib 显示我们的卫星图像(第 2 行)。你会注意到图像的亮度/对比度较低。这在显示 RGB 通道时是典型的。
#Display RGB image
plt.imshow(rgb)
图 1:低亮度的 RGB 通道(来源:作者)(数据集:SWED)
如前所述,卫星图像的动态范围很大。这是为了捕捉更广泛的像素范围——从非常亮到非常暗。结果是,像素往往偏向较低的值。你可以在下面的直方图中看到这一点。这些直方图显示了上述图像的 RGB 通道的像素频率。
图 2:卫星图像 RGB 通道的偏向像素值(来源:作者)
#Display histograms of pixel intesity for each band
fig, axs = plt.subplots(1,3,figsize=(18,5))
fig.patch.set_facecolor('xkcd:white')
labels = ['Red','Green','Blue']
for i,ax in enumerate(axs):
ax.hist(rgb[:,:,i].flatten(),bins=100)
ax.set_title(labels[i],size=20,fontweight="bold")
ax.set_xlabel("Pixel Value",size=15)
ax.set_ylabel("Frequency",size =15)
获取更好图像的一个简单方法是剪裁每个通道的像素值范围(第 2 行)。我们只取像素值从 0 到 0.3,并将它们缩放回 0 到 1 之间。在图 3中,你可以看到结果图像更亮。
# Clip RGB image to 0.3
rgb = np.clip(rgb,0,0.3)/0.3
plt.imshow(rgb)
图 3:高亮度的 RGB 通道(来源:作者)(数据集:SWED)
整合所有内容
对于快速可视化,上述内容足矣。如果你想对过程有更多控制,可以使用下面的函数。这将缩放图像,选择 RGB 通道,并使用不同的截止值剪裁每个通道。还有一个显示选项,将输出带有所选截止值的 RGB 直方图。
def visualise_rgb(img,clip=[0.3,0.3,0.3],display=True):
"""Visulaise RGB image with given clip values and return image"""
# Scale image
img = np.clip(img/10000,0,1)
# Get RGB channels
rgb = img[[3,2,1]]
#clip rgb values
rgb[0] = np.clip(rgb[0],0,clip[0])/clip[0]
rgb[1] = np.clip(rgb[1],0,clip[1])/clip[1]
rgb[2] = np.clip(rgb[2],0,clip[2])/clip[2]
rgb = rgb.transpose(1,2,0)
if display:
#Display histograms of pixel intesity with given clip values
fig, axs = plt.subplots(1,4,figsize=(22,5))
fig.patch.set_facecolor('xkcd:white')
labels = ['Red','Green','Blue']
for i,ax in enumerate(axs[0:3]):
ax.hist(img[3-i].flatten(),bins=100)
ax.set_title(labels[i],size=20,fontweight="bold")
ax.axvline(clip[i],color="red",linestyle="--")
ax.set_yticks([])
#Display RGB image
axs[3].imshow(rgb)
axs[3].set_title("RGB",size=20,fontweight="bold")
axs[3].set_xticks([])
axs[3].set_yticks([])
return rgb
你可以在下面看到这个功能的实际效果。在这个案例中,我们使用 0.3 的截止值剪切每个通道。图 4 显示了这些截止值的直方图和结果 RGB 图像。这在尝试不同截止值时可能会很有用。
img = gdal.Open(paths[0]).ReadAsArray()
rgb = visualise_rgb(img,[0.3,0.3,0.3])
图 4:RGB 通道可视化功能的输出(来源:作者)(数据集:SWED)
调整亮度
不同的图像会有不同的最佳截止值。实际上,本文的封面图像是使用 3 个不同的值创建的。上述功能使调整这些值变得简单。在图 5中,你可以看到调整截止值如何改变亮度。
rgb_1 = visualise_rgb(img,[0.15,0.15,0.15],display=False)
rgb_2 = visualise_rgb(img,[0.3,0.3,0.3],display=False)
rgb_3 = visualise_rgb(img,[0.45,0.45,0.45],display=False)
图 5:调整卫星图像的亮度(来源:作者)(数据集:SWED)
调整颜色(色调)
它还可以让你对每个颜色通道有更多控制。如图 6所示,如果我们降低某个通道的截止值,该颜色将变得更加突出。这可以帮助你微调可视化效果。
rgb_1 = visualise_rgb(img,[0.2,0.3,0.3],display=False)
rgb_2 = visualise_rgb(img,[0.3,0.3,0.2],display=False)
rgb_3 = visualise_rgb(img,[0.3,0.2,0.3],display=False)
图 6:调整卫星图像的颜色(来源:作者)(数据集:SWED)
在处理遥感数据时,通常重要的是可视化你的图像。这将帮助你建立对问题的直觉。也许更重要的是,你可以创建用于向他人解释工作的美丽图像。上述功能将帮助你创建最引人注目的图像。请记住,你可能需要根据你的卫星图像数据集进行调整。
希望你喜欢这篇文章!你可以通过成为我的 推荐会员 来支持我 😃
[## 通过我的推荐链接加入 Medium — Conor O’Sullivan
作为 Medium 会员,你的会员费的一部分将会分配给你阅读的作者,你将可以访问所有故事…
conorosullyds.medium.com](https://conorosullyds.medium.com/membership?source=post_page-----6d541af1f98d--------------------------------)
| Twitter | YouTube | Newsletter — 免费注册获取 Python SHAP 课程
数据集
Sentinel-2 水边数据集(SWED)(许可:Sentinel 数据法律声明) openmldata.ukho.gov.uk/
参考文献
GISGeography,Sentinel 2 波段及组合(2022),gisgeography.com/sentinel-2-bands-combinations/
Csaba,在 Python 中创建 Sentinel-2 RGB 合成图像(2021),www.satmapper.hu/en/rgb-images/
使用 Plotly Express 在 3D 线图上可视化井路径
原文:
towardsdatascience.com/visualising-well-paths-on-3d-line-plots-with-plotly-express-de73ca73dea3
使用 Plotly Express 的 3D 线图
·发布于 Towards Data Science ·阅读时间 7 分钟·2023 年 6 月 4 日
--
使用 Plotly Express 进行的 3D 井路径可视化。图片由作者提供。
可视化是我们理解井记录数据和地下数据的关键任务之一。这包括在井记录图、散点图和直方图上查看数据。通过这样做,我们可以对数据有一个可靠的理解。然而,处理 2D 图有时是不够的,我们需要通过 3D 增加额外的维度。
在岩石物理学和地球科学中,3D 可视化的一个优秀应用案例是可视化井的路径。
在石油和天然气勘探的早期阶段,井是垂直钻入地下的。然而,随着技术的进步,行业从垂直钻探转向水平钻探,最终使用地质导向的创新技术钻探复杂的井路径。
井路径几何的这种演变突显了在三维中可视化井路径的重要性。这样做可以让我们更好地理解井如何穿透地质层,并规划未来的井以避免碰撞或问题。
在本文中,我们将探讨如何利用记录井位置的井调查数据,并使用 Plotly Express 的 3D 线图展示这些数据。
数据
在这个简短的教程中,我们将使用来自荷兰北海区域八口井的勘探数据。这些数据来源于提供免费使用的井记录数据的NLOG.nl网站,这些数据是过去几十年中获取的。
这些文件包含以下列:
-
MD:测量深度(米)
-
INC:钻孔倾斜度(度)
-
AZI:钻孔方位角(度)
-
TVD:真实垂直深度(米)
-
XOFFSET:井在 X 方向上的位置(米)
-
YOFFSET:井在 Y 方向的位置(米)
为了简化数据加载,每个文件中的列名都是相同的。
数据集的完整细节可以在文章末尾找到。
此外,本教程中使用的所有井数据都具有相同的来源。我将在未来的文章中讲解如何将 XOFFSET 和 YOFFSET 转换为网格坐标。
设置数据
首先,我们需要导入多个库。对于这篇文章,我们将相对简单地导入pandas — 用于读取我们的csv
数据,和Plotly Express — 用于创建我们的可视化,还有os — 用于读取包含文件的目录。
import pandas as pd
import plotly.express as px
import os
我们可以通过几种方式将数据加载到 Python 和Pandas中。
为了简化操作,我保留了长形式的代码,以帮助那些新手理解发生了什么。
首先,我们需要创建一个空列表来存储我们的文件路径。
# Create empty list to store file paths
survey_data = []
接下来,我们将创建一个变量来存储包含调查数据的 CSV 文件的文件路径位置。
# Set up the file path location. This can be a relative or absolute path
file_path = 'Data/Notebook 43/'
最后,我们可以遍历file_path
目录中的每个文件,并检查是否有扩展名为.csv 的文件。一旦找到这些文件,完整的文件路径将被附加到 file_path 列表中。
# Loop through each file within the file_path location
for f in os.listdir(file_path):
# Check for CSV files
if f.endswith('.csv'):
# Add CSV file location to the survey_data list
survey_data.append(f'{file_path}{f}')
survey_data
当我们查看survey_data
列表时,我们得到以下内容:
['Data/Notebook 43/NLOG - A12-A-01 - Survey.csv',
'Data/Notebook 43/NLOG - A12-A-06 - Survey.csv',
'Data/Notebook 43/NLOG - A12-A-07 - Survey.csv',
'Data/Notebook 43/NLOG - A12-A-02 - Survey.csv',
'Data/Notebook 43/NLOG - A12-A-05 - Survey.csv',
'Data/Notebook 43/NLOG - A12-A-07-ST1 - Survey.csv',
'Data/Notebook 43/NLOG - A12-A-04 - Survey.csv',
'Data/Notebook 43/NLOG - A12-A-03 - Survey.csv']
一旦我们有了文件路径,我们需要使用 pandas 中的.concat
和map
函数将它们合并成一个pandas数据框。
df = pd.concat(map(pd.read_csv, survey_data))
当我们查看数据框(df
)时,我们将得到以下内容。
这是来自八个荷兰部门石油和天然气井的合并井调查数据的数据框。图片由作者提供。
或者,如果你是一个更高级的 Python 用户,并且希望寻找更高效的方法将多个文件直接加载到pandas数据框中,你可以使用以下函数。
如你所见,它更加紧凑。
# Alternatively:
import glob
df = pd.concat(map(pd.read_csv, glob.glob('Data/Notebook 43/*.csv')))
为了确认我们已经加载了所有的井数据,我们可以调用WELL列并查看其中的唯一值。
df['WELL'].unique()
这将返回一个包含 8 个井的数组。
array(['A12-A-01', 'A12-A-06', 'A12-A-07', 'A12-A-02', 'A12-A-05',
'A12-A-07-ST1', 'A12-A-04', 'A12-A-03'], dtype=object)
创建一个交互式的 Plotly Express 3D 折线图
对于这个例子,我们将使用Plotly Express 的 3D 折线图而不是使用 matplotlib 的 3D 图。从我的经验来看,这种方式生成和处理 3D 图似乎更快、更高效、更简便。
要创建我们的3D 折线图(Line Plot),我们首先需要创建一个图形并将其分配给px.line_3d()
。
在此函数中,我们将传入我们的 XOFFSET、YOFFSET、TVD 和 WELL 列。
这将允许我们绘制每个井的 X、Y 和 TVD(真实垂直深度)位置,并通过 WELL 列用颜色区分它们。
fig = px.line_3d(x=df.XOFFSET,
y=df.YOFFSET,
z=df.TVD,
color=df.WELL)
fig.show()
一旦我们运行上述代码,我们会得到如下图形。
Plotly Express 3D 线条图显示了在调整比例和线条之前的井路径情况。图片由作者提供。
从图中可以明显看出,我们需要调整它,以获得更好的效果。
这个图表存在一些问题,我们需要解决:
-
整体图形较小
-
线条宽度非常细
-
井路径颠倒了
我们可以通过一些调整来解决这个问题。
首先,我们将更新图表的布局,并定义一个width
和height
参数。
然后,我们将更新包含 TVD 测量的 z 轴,并将其设置为反向比例。
最后,我们可以通过调用update_traces()
来加粗线条。
fig = px.line_3d(x=df.XOFFSET,
y=df.YOFFSET,
z=df.TVD,
color=df.WELL)
fig.update_layout(width=800,
height=800,
autosize=False,
scene={'zaxis':{'autorange':'reversed'}})
fig.update_traces(line={'width':10})
fig.show()
当我们运行上述代码时,我们会得到如下互动图形。
Plotly Express 3D 线条图显示了在调整比例和线条后井路径的情况。图片由作者提供。
我们可以看到,井路径现在正确定位,并且我们有了一个稍微大的图形进行操作。
以这种方式可视化数据使我们能够看到井的去向,以及是否有任何井的侧钻。如果我们专注于规划未来的井,我们能够在过程早期识别任何潜在问题。
总结
清晰地可视化井路径是理解井位置的绝佳方法。它不仅可以显示井在地下的位置,还能显示相对于其他附近井的位置。
在这个简短的教程中,我们已经了解了如何利用 Python 和 Plotly Express 库,在互动 3D 图上可视化多个井路径。这使我们能够以易用的方式深入了解井的具体位置。
在你的下一个项目中试试看吧。
数据详情
本教程中使用的数据下载自NLOG.nl,这是一个包含整个荷兰北海区域井记录数据的网站。数据可以免费下载和使用。数据许可证的完整细节可以在这里找到,但知识产权部分提供了使用的摘要:
NLOG.NL 不对通过本网站提供的信息(域名、商标权、专利及其他知识产权除外)主张任何权利。用户可以在不事先获得 NLOG.NL 的书面许可或合法授权方同意的情况下,复制、下载、公开、分发或简化本网站提供的信息。用户还可以复制、重复、处理或编辑这些信息和/或布局,前提是标明 NLOG.NL 作为来源。
感谢阅读。在离开之前,你一定要订阅我的内容,将我的文章直接送到你的收件箱。 你可以在这里完成订阅!或者,你也可以 注册我的新闻通讯 以便将额外的内容直接免费送到你的收件箱。
其次,你可以通过注册会员获得完整的 Medium 体验,并支持成千上万的其他作家和我。它只需每月 $5,你将全面访问所有精彩的 Medium 文章,并有机会通过写作赚取收入。
如果你通过 我的链接注册,你将直接用部分费用支持我,而且不会增加你的成本。如果你这样做了,非常感谢你的支持。
Arabica 中的可视化模块加速了文本数据的探索
原文:
towardsdatascience.com/visualization-module-in-arabica-speeds-up-text-data-exploration-47114ad646ce
Arabica 现在提供了 unigram、bigram 和 trigram 词云、热图和折线图,以进一步加速时间序列文本数据分析。
·发表于 Towards Data Science ·阅读时长 6 分钟·2023 年 1 月 9 日
--
图 1. Bigram 词云,作者提供的图像。
介绍
Arabica 是一个用于探索性文本数据分析的 Python 库,专注于从时间序列的角度分析文本。它反映了许多文本数据集作为时间上重复观察的经验现实。时间序列文本数据包括新闻文章标题、研究文章摘要和元数据、产品评论、社交网络通信等。 Arabica 通过提供这些方法简化了这些数据集的探索性分析(EDA):
-
arabica_freq: 描述性 n-gram 分析和时间序列 n-gram 分析,用于基于 n-gram 的文本数据集 EDA
-
cappuccino: 用于数据的可视化探索。
本文介绍了Cappuccino,Arabica 的时间序列文本数据探索性分析可视化模块。请阅读文档和教程这里以获取对 Arabica 的一般介绍。
EDIT Jan 2023: Arabica 已经更新。请查看 文档 以获取参数的完整列表。
2. Cappuccino,探索性文本数据分析的可视化
实现的图表包括词云(unigram、bigram 和 trigram 版本)、热图和折线图。它们帮助发现(1)最频繁的 n-grams,反映其时间序列特征(词云)和(2)n-grams 随时间的发展(热图、折线图)。
这些图表设计用于演示、报告和实证研究。因此,它们具有高分辨率(像素取决于图表中显示的数据范围)。
Cappuccino 依赖于matplotlib、worcloud、和plotnine来创建和展示图表,同时使用cleantext和NTLK停用词语料库进行预处理。Plotnine 将广泛使用的ggplot2库实现到了 Python 中。要求在这里。
方法的参数如下:
def cappuccino(text: str, # Text
time: str, # Time
plot: str = '', # Chart type: 'wordcloud'/'heatmap'/'line'
ngram: int = '', # N-gram size, 1 = unigram, 2 = bigram, 3 = trigram
time_freq: str= '', # Aggregation period: 'Y'/'M'', if no aggregation: 'ungroup'
max_words int = '', # Max number for most frequent n-grams displayed for each period
stopwords: [], # Languages for stop words
skip: [ ], # Remove additional strings
numbers: bool = False, # Remove numbers
punct: bool = False, # Remove punctuation
lower_case: bool = False # Lowercase text
)
3. 描述性 n-gram 可视化
Arabica 中的描述性分析提供了 n-gram 频率计算,但不对特定时期进行汇总。简单来说,首先计算每个文本记录的 n-grams 频率,其次,将这些频率对整个数据集进行汇总,最后,将这些频率可视化到图表中。
词云
让我们以百万新闻标题为例,这些新闻标题按日发布,数据集的时间范围是 2003–2–19 到 2016–09–18。该数据集由澳大利亚广播公司提供,采用CC0: 公共领域许可证。我们将数据子集化为前 50,000 条标题。
首先,使用 pip install arabica
安装 Arabica,然后导入 Cappuccino:
from arabica import cappuccino
使用 pandas
读取数据后,数据如下所示:
图 2. 百万新闻标题数据
我们将文本转换为小写,清理数据中的标点符号和数字,去除英文停用词和其他不需要的字符串(“g”、“br”),并绘制包含 100 个最频繁单词的词云:
cappuccino(text = data['headline'],
time = data['date'],
plot = 'wordcloud',
ngram = 1, # n-gram size, 1 = unigram, 2 = bigram, 3 = trigram
time_freq = 'ungroup', # no period aggregation
max_words = 100, # displays 100 most frequent words
stopwords = ['english'], # remove English stopwords
skip = ['g','br'], # remove additional strings
numbers = True, # remove numbers
punct = True, # remove punctuation
lower_case = True # lowercase text
)
它返回词云:
图 3. 词云,作者图片。
在将 ngram = 2
更改后,我们获得了包含 100 个最频繁二元组的词云(见封面图片)。另外,ngram = 3
显示最频繁的三元组:
图 4. 词云 — 三元组,作者图片。
4. 时间序列 n-gram 可视化
时间序列文本数据通常显示出随时间变化的波动。选举前的政治声明和新冠疫情期间的新闻标题就是很好的例子。为了展示时间段内的 n-grams,Arabica 实现了热图和折线图用于按月和按年展示。
作者图片,来源:Draw.io
热图
一个显示每个月十个最频繁单词的热图,代码如下:
cappuccino(text = data['headline'],
time = data['date'],
plot = 'heatmap',
ngram = 1, # n-gram size, 1 = unigram, 2 = bigram
time_freq = 'M', # monthly aggregation
max_words = 10, # displays 10 most frequent words for each period
stopwords = ['english'], # remove English stopwords
skip = ['g', 'br'], # remove additional strings
numbers = True, # remove numbers
punct = True, # remove punctuation
lower_case = True # lowercase text
)
单元词热图是输出结果:
图 5. 热图 — 单元词,作者图片。
单元组热图让我们首次了解数据的时间变异性。我们可以清晰地识别数据中的重要模式:
最常见的 n-grams:“us”、“警察”、“新”、“人”。
离群值(仅在一个周期内出现的术语):“战争”、“wa”、“雨”、“被杀”、“伊拉克”、“担忧”、“预算”、“巴厘岛”。
我们可以考虑在分析的后期去除离群值。或者,通过更改 ngram = 2
和 max_words = 5
创建一个二元组 热图,显示每个周期中最常见的五个二元组。
图 6. 热图 — 二元组,作者提供的图像。
线图
通过更改 plot = ‘line’
显示带有 n-grams 的线图。通过将 ngram
参数设置为 1 和 max_words = 5
,我们为每个周期中的五个最常见的词创建一个线图:
图 7. 线图 — 单元组,作者提供的图像。
类似地,通过更改 ngram = 2
和 max_words = 3
,二元组线图看起来像这样:
图 8. 线图 — 二元组,作者提供的图像。
最终备注
Cappuccino 在可视化探索具有时间序列特征的文本数据方面大有帮助。只需一行代码,我们就能预处理数据并提供数据集的初步探索性视图。以下是一些建议:
-
可视化的频率也取决于数据中时间维度的长度。在长时间序列中,月度图无法清晰显示数据,而在短时间序列(少于一年)的年频率图则无法提供时间上的变异性。
-
根据项目中的数据集选择合适的可视化形式。对于 n-gram 变异性高的数据集(见图 8),线图不是一个好的选择。在这种情况下,即使每个周期中有许多 n-grams,热图也能更好地显示数据。
我们可以用 Arabica 回答的一些问题是 (1) 特定领域(经济学、生物学等)的概念如何随时间演变,使用研究文章元数据,(2) 在总统竞选期间,哪些关键主题被强调,使用Twitter 推文,(3) 公司应该改进品牌和沟通的哪些部分,使用客户产品评论。
本教程中的完整代码在我的 GitHub。欲了解更多示例,请阅读 文档 和关于 arabica_freq 方法的 教程。
编辑: Arabica 现在有了一个 情感和结构性断裂 分析模块。阅读更多内容并查看这些教程中的实际应用:
-
时间序列文本数据中的情感分析和结构性突破
-
N-gram 和情感分析的客户满意度测量
你喜欢这篇文章吗?你可以邀请我 喝咖啡 并支持我的写作。你也可以订阅我的 电子邮件列表 以获取关于我新文章的通知。谢谢!
照片由 Kanwardeep Kaur 拍摄,来源于 Unsplash
嵌入的可视化
可视化高维数据的方式不止一种。在这里,我们回顾了人工智能的发展历史,以探索这些可视化的演变。
·
关注 发表在 Towards Data Science ·7 分钟阅读·2023 年 5 月 27 日
--
我在 1990 年将我的第一篇关于人工智能的论文提交到一个小型地方会议——“中西部人工智能与认知科学学会”。在那些日子里,人工智能领域完全被“符号”研究所定义。这种方法被称为“古老的人工智能”或 GOFAI(发音为“go fi”,类似“wifi”)。我们这些现在被称为“深度学习”的研究者不得不真的争论我们所研究的内容是否应被视为人工智能。
被排除在人工智能之外是把双刃剑。一方面,我不同意当时定义的人工智能的基本信条。基本假设是“符号”和“符号处理”必须是所有人工智能的基础。因此,我很高兴能在一个甚至不被认为是人工智能的领域工作。另一方面,如果你不将自己的想法包装成至少与人工智能相关的内容,就很难找到愿意倾听你意见的人。
这个小型会议接受了关于“人工智能”和“认知科学”的论文——我视此为对“符号处理”之外的想法的邀请。所以我提交了我的第一篇论文,并且被接受了!论文展示了一种处理自然语言的神经网络方法。我们这个领域的许多人称这种神经网络研究为“连接主义”,但现在这种研究,如前所述,会被标记为“深度学习”(DL)——尽管我的初期研究并不深……只有三层!现代 DL 系统可以由数百层组成。
我的论文在会议上被接受了,我于 1990 年在伊利诺伊州的卡本代尔进行了展示。后来,会议组织者 John Dinsmore 邀请我提交论文的一个版本用于他正在编辑的一本书。我觉得自己无法独立完成论文,于是邀请了两位研究生朋友(Lisa Meeden 和 Jim Marshall)加入我。他们同意了,我们最终在书中完成了一章。这本书的标题是“符号与连接主义范式:弥合差距”。我们的论文很契合书的主题。我们将论文标题为“探索符号/亚符号连续体:RAAM 的案例研究”。令我高兴的是,这本书聚焦于这两种人工智能方法之间的分裂。我认为这个领域至今仍在挣扎于这种分歧。
关于我初期研究的更多内容,我会稍后再说。现在我想谈谈这个领域如何处理“嵌入”的可视化。首先,我们当时并没有称这些向量为“嵌入”。大多数研究使用了“隐层表示”这样的术语。这包括了连接主义系统为解决问题而学习的任何内部表示。按照我们当时的定义,有三种层:“输入”(在这里你插入数据集)、“输出”(在这里你放置期望的输出或“目标”)以及其他所有层——即“隐层”。隐层是网络激活在输入和输出之间流动的地方。隐层激活通常是高维的,并且是网络学习到的“概念”的表示。
像今天一样,视觉化这些高维向量被认为有助于深入理解这些系统的工作原理及其常见故障。在我们书中的章节中,我们使用了三种类型的可视化:
-
所谓的“Hinton 图”
-
聚类图,或树状图
-
投影到二维空间
第一种方法是 Hinton 和 Shallice 在 1991 年 使用的一个新创意。(这就是我们今天所知道的 Geoffrey Hinton。更多内容将在未来的文章中介绍)。这个图表是一个简单的想法,实用性有限。基本想法是,激活、权重或任何类型的数值数据可以用盒子表示:白色盒子(通常表示正数)和黑色盒子(通常表示负数)。此外,盒子的大小表示相对于模拟神经元中的最大值和最小值的值的大小。
这是我们论文中的表示,显示了网络隐藏层中的平均“嵌入”作为对网络呈现的单词的表示:
我们论文中的图 10。
Hinton 图确实有助于可视化数据中的模式。但它们并不真正有助于理解表示之间的关系,也无助于在维度数目大大增加时的理解。现代嵌入可以有数千维。
为了帮助解决这些问题,我们转向第二种方法:聚类图或 树状图。这些图表显示了任意两个模式之间的距离(无论如何定义)作为一个层次树。以下是我们论文中使用欧几里得距离的一个示例:
我们论文中的图 9。
这与 Hinton 图中显示的信息相同,但格式更加实用。在这里,我们可以看到个别模式之间以及整体模式之间的内部关系。请注意,垂直排序无关紧要:分支点的水平位置是图表的有意义方面。
在上面的树状图中,我们手动构建了整体图像,基于程序计算出的树簇。今天,有自动构建这种树和图像的方法。然而,当模式的数量远超过几十个时,图表可能变得难以理解。这是今天由 matplotlib 制作的一个示例。你可以在这里了解更多关于 API 的信息:matplotlib 树状图。
现代树状图,包含大量模式。图像由作者制作。
最后,我们来到了最后一种方法,也就是今天主要使用的方法:投影方法。这种方法使用算法找到一种将嵌入的维度数减少到更容易被人类理解的数量(例如,2 或 3 维)并绘制为散点图的方法。
在 1990 年,投影高维数据到较小维度集的主要方法是主成分分析(简称 PCA)。维度降维是一个活跃的研究领域,目前仍在不断开发新的方法。
目前最常用的降维算法包括:
-
PCA
-
t-SNE
-
UMAP
哪种方法最好?这实际上取决于数据的细节以及你创建维度降维的目标。
PCA 可能是总体上最好的方法,因为它是确定性的,并且允许你从高维空间创建到降维空间的映射。这对于在一个数据集上训练,然后检查测试数据集在学习空间中的投影非常有用。然而,PCA 可能会受到未缩放数据的影响,并可能导致一个“点球”,这对结构模式的洞察较少。
t-SNE,即 t-分布随机邻域嵌入,由 Roweis 和 Hinton(是的,就是那个 Hinton)于 2002 年创建。这是一种学习型投影,可以利用未缩放的数据。然而,t-SNE 的一个缺点是它不会创建映射,而只是一个用于寻找聚类的学习方法。也就是说,与具有Projection.fit()和Projection.transform()方法的其他算法不同,t-SNE 只能进行拟合。(有些实现,如openTSNE,提供了转换映射。然而,openTSNE 似乎与其他算法非常不同,速度较慢,支持度也低于其他形式。)
最后,还有 UMAP,即均匀流形近似与投影。该方法由McInnes 和 Healy 于 2018 年创建。对于许多高维空间,这可能是最佳的折中方案,因为它计算开销相对较小,但能够在降维中保留重要的表示结构。
这是一个将维度降维算法应用于 sklearn 中可用的未缩放乳腺癌数据的示例:
三种投影方法的降维示例,包括 PCA、t-SNE 和 UMAP。图像由作者制作。
你可以自己测试维度降维算法,以找到适合你用例的最佳方法,并使用Kangas DataGrid创建如上所示的图像。
如前所述,维度缩减仍然是一个活跃的研究领域。我完全期待在这一领域看到持续的改进,包括可视化信息在深度学习网络中的流动。以下是我们书籍章节中的一个最终示例,展示了激活在我们模型的表示空间中的流动:
图 7 来自我们的论文。神经网络解码部分中单步的隐藏层激活。
对人工智能、机器学习和数据科学中的想法来源感兴趣吗?请考虑点赞和订阅。告诉我你感兴趣的内容!
纽约市的可视化
原文:
towardsdatascience.com/visualizations-of-new-york-city-f7b4772e9518
使用 Python 和 Plotly 让 NYC 开放数据生动起来
·发表于 Towards Data Science ·阅读时间 9 分钟·2023 年 8 月 18 日
--
图片由 Fabien BELLANGER 提供,来源于 Unsplash
纽约市的 开放数据平台 是一个令人惊叹的信息来源。城市收集和生成的所有公共数据都 依法 通过该门户提供,并且对公众免费使用。
数据集涵盖了交通、住房和机动车事故等信息,甚至包括中央公园松鼠普查和公园护林员报告的攻击性乌龟遭遇情况。
像这些地理、基础设施和社会学数据集代表了现实世界的过程和事件。即使你与纽约市或城市地区没有联系,或者对它们兴趣不大,它们也为你提供了一个机会,让你处理的数据更接近于你在专业角色中会遇到的数据,而不像 MNIST 或泰坦尼克号乘客那样。更好的是,它们几乎一样容易获取。
我们将演示这些数据集的使用有多么简单,并在此过程中构建一些有趣的可视化。
为了尽可能简洁地保持代码块,这里是本文中所有代码所需的模块:
import folium
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
from scipy.stats import gaussian_kde
import seaborn as sns
from shapely.geometry import Point, shape, box, Polygon
如果你想自己复制任何内容,确保它们已经安装。
建筑物足迹
这是我最喜欢玩的数据集之一。数据包括 NYC 大多数建筑物的足迹多边形、年龄和高度。
我们将从数据提取开始,分开于可视化代码,因为我们使用此数据集制作了几个不同的可视化。
# Pull data
api_endpoint = 'https://data.cityofnewyork.us/resource/qb5r-6dgf.json'
limit = 1000 # Number of rows per request
offset = 0 # Starting offset
data_frames = [] # List to hold chunks of data
# Loop to fetch data iteratively
# while offset <= 100000: # uncomment this and comment while True to fetch a
# sample much faster
while True: # while True will take a long time but gets all the data
url = f"{api_endpoint}?$limit={limit}&$offset={offset}"
chunk = pd.read_json(url)
if chunk.empty:
break # Stop the loop when no more data is returned
data_frames.append(chunk)
offset += limit
# Concatenate all chunks into a single DataFrame
data = pd.concat(data_frames, ignore_index=True)
# Convert the 'the_geom' column from a dictionary to a Shapely geometry object
data['geometry'] = data['the_geom'].apply(lambda x: shape(x))
# Convert the Pandas DataFrame to a GeoDataFrame
gdf = gpd.GeoDataFrame(data, geometry='geometry', crs="EPSG:4326")
# Convert 'MultiPolygon' to representative points for visualization
gdf['centroid'] = gdf['geometry'].centroid.to_crs(epsg=3395).centroid.to_crs(epsg=4326)
# Get rid of columns we don't need anymore
keep_cols = ['cnstrct_yr', 'heightroof', 'geometry', 'centroid']
gdf = gdf[keep_cols]
这为我们提供了一个工作数据集,长得像这样。
几何列是建筑轮廓的多边形,重心列是轮廓中心的单一点,即建筑位置的单一经纬度。现在我们可以深入到有趣的部分了。
新旧建筑的对比
此时,城市已成为旧褐石房、密集的独栋住宅和排屋、方块形公寓楼以及俯瞰河流的玻璃塔楼的混合体。
比如我们想找一个历史悠久的社区居住,可能希望有很多老旧建筑来增加一些历史魅力。在许多美国城市,尤其是在东北部,1930 年代及之前的建筑被称为“战前建筑”或口语中的“prewars”。
为了找到适合我们需求的社区,我们可以使用散点图在地图上标记每栋建筑的重心位置,并使用 KDE 图突出显示战前建筑的高密度。
# Create the bounding box from the provided corner points
bounding_box = box(-74.0294, 40.685, -73.91695, 40.742)
# Filter the GeoDataFrame using the bounding polygon
gdf = gdf[gdf['centroid'].within(bounding_box)]
# Create a new column for building decade
gdf['decade'] = (gdf['cnstrct_yr'] // 10) * 10
# Remove rows where 'cnstrct_yr' is NaN
gdf = gdf[gdf['cnstrct_yr'].notna()]
# Get unique decades
unique_decades = sorted(gdf['decade'].unique())
# Use the Cividis colorscale and split it for our unique decades
colorscale = px.colors.sequential.Cividis
num_decades = len(unique_decades)
colors = [colorscale[i * len(colorscale) // num_decades] for i in range(num_decades)]
color_map = dict(zip(unique_decades, colors))
# Filter the data for buildings built in the 1930s and earlier
old_buildings = gdf[gdf['decade'] <= 1930]
# Create a new figure for better control
fig = go.Figure()
# Add the traces for each decade
for decade, color in color_map.items():
subset = gdf[gdf['decade'] == decade]
# Add the original trace with showlegend set to False
fig.add_trace(go.Scattermapbox(
lat=subset['centroid'].y,
lon=subset['centroid'].x,
mode='markers',
marker=go.scattermapbox.Marker(
size=3,
color=color,
opacity=0.8
),
text=decade,
name=str(int(decade)),
hoverinfo='none',
showlegend=False
))
# Add a dummy trace with larger markers for the legend
# placed outside the visible map
fig.add_trace(go.Scattermapbox(
lat=[90], # Latitude outside the visible map
lon=[0], # Longitude outside the visible map
mode='markers',
marker=go.scattermapbox.Marker(
size=10,
color=color,
opacity=1
),
legendgroup=str(int(decade)),
showlegend=True,
name=str(int(decade)),
hoverinfo='none'
))
# Add heatmap for older buildings
fig.add_trace(go.Densitymapbox(
lat=old_buildings['centroid'].y,
lon=old_buildings['centroid'].x,
radius=4,
colorscale="Greens",
opacity=1,
name="Density of Prewar Buildings",
showlegend=True,
zmax=3,
zmin=0,
showscale=False
))
fig.update_layout(
title='Buildings by Decade with Density Underlay for Prewar Buildings',
autosize=True,
mapbox=dict(
accesstoken=None,
bearing=0,
center=dict(lat=40.71359, lon=-73.97216),
pitch=0,
zoom=12.6,
style='carto-positron'
),
height=800,
width=1200,
legend=dict(tracegroupgap=0)
)
# Display the map
fig.show()
显示战前建筑数量较多的街道和社区。图表由 Plotly 创建。
如果你对纽约市熟悉,这些数据中的一些结果并不令人惊讶,但有些可能会让人感到意外。科布尔希尔和西村的经典褐石房排是显而易见的。但是,我实际上没有预料到格林点和威廉斯堡(布鲁克林北端)有如此多的老建筑。
建筑尺寸
由于数据集中有建筑轮廓和屋顶高度,我们可以计算建筑体积。然后,我们可以通过行政区可视化建筑的平均大小以进行比较。
为了实现这一点,我们将使用建筑轮廓的几何形状和屋顶高度来计算数据中每栋建筑的体积。我们将使用每个轮廓的重心来确定其所在的行政区。
# ...following from data pull code block
# Define borough bounding boxes
# These are very loose bounds and a more thorough analysis should use a higher
# precision polygon.
boroughs = {
"Manhattan": box(-74.02, 40.70, -73.91, 40.88),
"Bronx": box(-73.93, 40.80, -73.79, 40.92),
"Brooklyn": box(-74.05, 40.57, -73.85, 40.74),
"Queens": box(-73.94, 40.54, -73.70, 40.80),
"Staten Island": box(-74.26, 40.50, -74.03, 40.65)
}
# Assign borough to each building based on its centroid
def assign_borough(centroid):
for borough, bbox in boroughs.items():
if bbox.contains(centroid):
return borough
return None
# Assuming the gdf variable already contains your data
gdf['borough'] = gdf['centroid'].apply(assign_borough)
# Calculate building volume using footprint area and height
gdf['volume'] = gdf['geometry'].area * gdf['heightroof']
# Compute average volume by borough
avg_volume_by_borough = gdf.groupby('borough')['volume'].median()
# Create 3D bar shapes using surface plots
def create_3d_bar(x, y, z, dx, dy, dz):
# Define vertices of the bar
x_data = [[x, x, x+dx, x+dx, x], [x, x, x+dx, x+dx, x]]
y_data = [[y, y+dy, y+dy, y, y], [y, y+dy, y+dy, y, y]]
z_data = [[z, z, z, z, z], [z+dz, z+dz, z+dz, z+dz, z+dz]]
return go.Surface(
x=x_data,
y=y_data,
z=z_data,
colorscale=[[0, 'blue'], [1, 'blue']],
showscale=False,
opacity=0.5
)
# Define bar dimensions
dx = 0.4
dy = 0.4
# Create figure
fig = go.Figure()
# Add bars to the figure
for i, borough in enumerate(avg_volume_by_borough.index):
fig.add_trace(create_3d_bar(i, 0, 0, dx, dy, avg_volume_by_borough[borough]))
# Define the layout with adjusted aspect ratio for wider chart area
fig.update_layout(
title='Average Building Volume by Borough',
scene=dict(
xaxis=dict(
title='Borough',
tickvals=list(range(len(avg_volume_by_borough))),
ticktext=avg_volume_by_borough.index
),
yaxis=dict(title='', visible=False),
zaxis=dict(title='Average Building Volume (m³)'),
aspectratio=dict(x=3, y=2, z=1.5) # Adjusting the aspect ratio for wider x-axis
),
margin=dict(t=40, b=40, l=60, r=60)
)
fig.show()
各区建筑体积的体积柱状图。图表由 Plotly 创建。
曼哈顿的建筑更大这一点并不令人惊讶。然而,你可能会预期其他三个区之间存在更大的差异,尤其是布鲁克林的塔楼俯瞰东河与斯塔滕岛明显的纽约市郊外风貌之间的差异。
四个外区都有大量低密度住房。特别是皇后区,其面积巨大,比第二大区布鲁克林大 56%。
我们可以在建筑尺寸分布的箱形图中看到这一点。请注意体积是对数化的,以更好地展示外区与曼哈顿大量高楼之间的差异。
皇后区的小型独栋住宅数量较大,相较于斯塔滕岛尤为明显。
WiFi 热点
纽约市在全市范围内运营 wifi 热点,其位置可通过开放数据门户访问。这些只是市政府运营的热点,所以像星巴克这样的地方不包括在内。
我们可以直接从门户网站读取 JSON 文件并创建位置地图。
# Define the URL
url = "https://data.cityofnewyork.us/resource/yjub-udmw.json"
# Send a GET request
response = requests.get(url)
# Load the response into a JSON
data = response.json()
# Convert the JSON data into a DataFrame
df = pd.DataFrame(data)
# Convert lat and lon columns to float
df['latitude'] = df['latitude'].astype(float)
df['longitude'] = df['longitude'].astype(float)
# Map the borough codes to borough names
borough_dict = {'1': 'Manhattan', '2': 'Bronx', '3': 'Brooklyn', '4': 'Queens', '5': 'Staten Island'}
df['borough'] = df['borough'].map(borough_dict)
# Replace the 'token' with your own Mapbox access token
px.set_mapbox_access_token('token')
fig = px.density_mapbox(
df,
lat='latitude',
lon='longitude',
zoom=10,
mapbox_style="carto-positron",
title="Distribution of WiFi Hotspots in NYC",
radius=6
)
fig.update_layout(
height=800,
width=1200
)
fig.show()
纽约市运营的 wifi 热点的位置。图表由 Plotly 创建。
纽约市的大多数地铁站都有免费 wifi,这在地图上显示得非常清晰。中央公园两侧延伸出的高密度热点区域大多是地铁站。
你还可以通过寻找排列成不自然直线的地点,清楚地看到两条主要的地铁线路通向皇后区和布鲁克林。
松鼠普查
松鼠普查(www.thesquirrelcensus.com/
)是一个科学、设计和讲故事的项目,专注于东部灰松鼠(Sciurus carolinensis)。他们统计松鼠数量并将结果公之于众。
这些数据包含了 3,023 次观察的松鼠数据,包括位置坐标、年龄、主要和次要毛色、高度、活动、通讯以及松鼠与人类之间的互动。
现在让我们提出一个有趣的问题。
中央公园中最吵的松鼠在哪里?
数据中有三个字段表示松鼠发出的不同噪音/叫声:‘kuks’、‘quaas’和‘moans’。我们将任何噪音视为吵闹,并创建一个密度图以显示吵闹的松鼠聚集在哪里。
# Pull data
data_url = 'https://data.cityofnewyork.us/api/views/vfnx-vebw/rows.csv?accessType=DOWNLOAD'
squirrels = pd.read_csv(
data_url,
usecols=['X', 'Y', 'Kuks', 'Quaas', 'Moans']
)
# Create column denoting that the squirrel made any kind of noise
squirrels['noisy'] = squirrels[['Kuks', 'Quaas', 'Moans']].any(axis=1)
# Filter out the quiet squirrels
noisy_squirrels = squirrels[squirrels['noisy']]
# Convert noisy column to integer
noisy_squirrels['noisy'] = noisy_squirrels['noisy'].astype(int)
# Create the density heatmap
fig = px.density_mapbox(
noisy_squirrels, lat='Y', lon='X', z='noisy', radius=50,
center=dict(lat=40.783, lon=-73.969), # Center coordinates for Central Park
zoom=13,
mapbox_style="stamen-terrain",
#mapbox_style="stamen-watercolor",
color_continuous_scale=["white", "orange", "red"],
range_color=[0, 5], # Adjusting the range for color scale
)
# Set the bearing to orient Central Park horizontally
fig.update_layout(
mapbox_bearing=0,
height=700,
width=1000,
title='Density of Noisy Squirrel Observations',
showlegend=False,
coloraxis_showscale=False
)
fig.show()
松鼠发出声音的观察密度图。图表由 Plotly 创建。
看起来上西区的松鼠有最多的发言权。
最后的说明
在 NYC Open Data 上还有大量数据可以深入探索,确保自己查看。大多数大城市也有自己的开放数据计划。值得注意的广泛数据计划包括洛杉矶、多伦多、伦敦和新加坡。
NYC Open Data 是公共领域数据,根据开放数据常见问题 “对开放数据的使用没有限制”,包括商业用途。有关详细信息,请参阅使用条款。
本帖中所有代码均可在GitHub上获得,我已经将 MIT 许可证附加到仓库中,以便你可以自由使用其中的任何内容。