import collections
import math
import random
import sys
import time
import os
import numpy as np
import torch
from torch import nn
import torch.utils.data as Data
2. The PTB Dataset
In short, Word2Vec learns from a corpus how to map discrete words to vectors in a continuous space while preserving their semantic similarity. This post trains the word embedding model on the classic PTB corpus. PTB (Penn Tree Bank) is a small, widely used corpus sampled from Wall Street Journal articles, split into training, validation, and test sets; we train the embeddings on the PTB training set.
Loading the dataset
with open('../../inputs/ptb.train.txt', 'r') as f:
    lines = f.readlines()  # sentences in this dataset are separated by newlines
    raw_dataset = [st.split() for st in lines]  # st is short for sentence; tokens within a sentence are separated by spaces

print('# sentences: %d' % len(raw_dataset))

# For the first 3 sentences of the dataset, print the number of tokens and the first 5 tokens.
# The end-of-sentence marker is '<eos>', rare words are all replaced by '<unk>', and numbers by 'N'.
for st in raw_dataset[:3]:
    print('# tokens:', len(st), st[:5])
counter = collections.Counter([tk for st in raw_dataset for tk in st])  # tk is short for token
counter = dict(filter(lambda x: x[1] >= 5, counter.items()))  # keep only tokens that appear at least 5 times in the dataset

idx_to_token = [tk for tk, _ in counter.items()]
token_to_idx = {tk: idx for idx, tk in enumerate(idx_to_token)}
dataset = [[token_to_idx[tk] for tk in st if tk in token_to_idx]
           for st in raw_dataset]  # tokens in raw_dataset are converted to their indices in this step
num_tokens = sum([len(st) for st in dataset])
print('# tokens: %d' % num_tokens)
# tokens: 887100
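As a quick, purely illustrative check of the mapping (not part of the original post), the first raw sentence can be converted to indices and back; any token that occurs fewer than 5 times is simply dropped along the way:

sample = raw_dataset[0]
sample_idx = [token_to_idx[tk] for tk in sample if tk in token_to_idx]  # out-of-vocabulary tokens are dropped
print(sample[:5])                                 # original tokens
print([idx_to_token[i] for i in sample_idx[:5]])  # tokens recovered from their indices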
def discard(idx):
    '''
    @params:
        idx: the index of a token
    @return: True/False, whether the token should be discarded
    '''
    return random.uniform(0, 1) < 1 - math.sqrt(
        1e-4 / (counter[idx_to_token[idx]] / num_tokens))

subsampled_dataset = [[tk for tk in st if not discard(tk)] for st in dataset]
print('# tokens: %d' % sum([len(st) for st in subsampled_dataset]))
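The discard rule implements the subsampling heuristic from the Word2Vec paper: a token $w_i$ with relative frequency $f(w_i) = \mathrm{count}(w_i) / \text{num\_tokens}$ is dropped with probability

$$P(\text{discard } w_i) = \max\!\left(0,\; 1 - \sqrt{\frac{t}{f(w_i)}}\right), \qquad t = 10^{-4},$$

so very frequent words such as 'the' are heavily thinned out, while words with frequency below $t$ are never dropped.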
def compare_counts(token):
    return '# %s: before=%d, after=%d' % (token, sum(
        [st.count(token_to_idx[token]) for st in dataset]), sum(
        [st.count(token_to_idx[token]) for st in subsampled_dataset]))
print(compare_counts('the'))
print(compare_counts('join'))
# tokens: 376269
# the: before=50770, after=2204
# join: before=45, after=45
Extracting center words and context words
def get_centers_and_contexts(dataset, max_window_size):
    '''
    @params:
        dataset: a list of sentences, each sentence a list of tokens already converted to their indices
        max_window_size: maximum size of the context window
    @return:
        centers: the list of center words
        contexts: the list of context windows, one per center word, each a list of context words
    '''
    centers, contexts = [], []
    for st in dataset:
        if len(st) < 2:  # a sentence needs at least 2 tokens to form a "center word - context word" pair
            continue
        centers += st
        for center_i in range(len(st)):
            window_size = random.randint(1, max_window_size)  # randomly choose the context window size
            indices = list(range(max(0, center_i - window_size),
                                 min(len(st), center_i + 1 + window_size)))
            indices.remove(center_i)  # exclude the center word from its own context
            contexts.append([st[idx] for idx in indices])
    return centers, contexts
all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('dataset', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
print('center', center, 'has contexts', context)
dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
center 0 has contexts [1]
center 1 has contexts [0, 2, 3]
center 2 has contexts [1, 3]
center 3 has contexts [1, 2, 4, 5]
center 4 has contexts [2, 3, 5, 6]
center 5 has contexts [3, 4, 6]
center 6 has contexts [4, 5]
center 7 has contexts [8]
center 8 has contexts [7, 9]
center 9 has contexts [8]
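The training loop below relies on several objects that are defined in the full post but do not appear in this excerpt: a DataLoader data_iter that yields (center, context_negative, mask, label) batches produced by negative sampling, the forward function skip_gram, a masked binary cross-entropy loss, and the model net made of two embedding layers. As a rough, self-contained sketch of those pieces (embed_size = 100 and the exact loss reduction are assumptions; the original definitions may differ):

def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    '''Inner products between each center word vector and its context/negative word vectors.'''
    v = embed_v(center)                      # shape: (batch, 1, embed_size)
    u = embed_u(contexts_and_negatives)      # shape: (batch, max_len, embed_size)
    return torch.bmm(v, u.permute(0, 2, 1))  # shape: (batch, 1, max_len)

class SigmoidBinaryCrossEntropyLoss(nn.Module):
    '''Binary cross-entropy with logits; the mask zeroes out padded positions in each window.'''
    def forward(self, inputs, targets, mask=None):
        inputs, targets, mask = inputs.float(), targets.float(), mask.float()
        res = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets, reduction='none', weight=mask)
        return res.mean(dim=1)

loss = SigmoidBinaryCrossEntropyLoss()

embed_size = 100  # assumed embedding dimension
net = nn.Sequential(
    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size),  # center-word embeddings
    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size),  # context-word embeddings
)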
def train(net, lr, num_epochs):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("train on", device)
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(num_epochs):
        start, l_sum, n = time.time(), 0.0, 0
        for batch in data_iter:
            center, context_negative, mask, label = [d.to(device) for d in batch]
            pred = skip_gram(center, context_negative, net[0], net[1])
            l = loss(pred.view(label.shape), label, mask).mean()  # average loss over one batch
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            l_sum += l.cpu().item()
            n += 1
        print('epoch %d, loss %.2f, time %.2fs'
              % (epoch + 1, l_sum / n, time.time() - start))
train(net, 0.01, 5)
# train on cpu
# epoch 1, loss 0.61, time 221.30s
# epoch 2, loss 0.42, time 227.70s
# epoch 3, loss 0.38, time 240.50s
# epoch 4, loss 0.36, time 253.79s
# epoch 5, loss 0.34, time 238.51s
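After training, the center-word embedding layer net[0] can be queried for nearest neighbours by cosine similarity. A minimal sketch of such a lookup (the helper name get_similar_tokens and the query word 'chip' are illustrative, not from the original post):

def get_similar_tokens(query_token, k, embed):
    '''Print the k tokens whose vectors have the highest cosine similarity with query_token.'''
    W = embed.weight.data
    x = W[token_to_idx[query_token]]
    # add 1e-9 for numerical stability
    cos = torch.matmul(W, x) / (torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9).sqrt()
    _, topk = torch.topk(cos, k=k + 1)
    topk = topk.cpu().numpy()
    for i in topk[1:]:  # topk[0] is the query token itself
        print('cosine sim=%.3f: %s' % (cos[i], idx_to_token[i]))

get_similar_tokens('chip', 3, net[0])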