nlp 计数法 PTB数据集
2022-04-05 14:37 jym蒟蒻 阅读(183) 评论(0) 编辑 收藏 举报计数方法应用于PTB数据集
- PTB数据集
- ptb.py
- 使用ptb.py
- 计数方法应用于PTB数据集
内容如下:
每一行保存一个句子;稀有单词被替换成特殊字符 <unk>;具体的数字被替换成“N”。
we 're talking about years ago before anyone heard of asbestos having any questionable properties
there is no asbestos in our products now
neither <unk> nor the researchers who studied the workers were aware of any research on smokers of the kent cigarettes
we have no useful information on whether users are at risk said james a. <unk> of boston 's <unk> cancer institute
dr. <unk> led a team of researchers from the national cancer institute and the medical schools of harvard university and boston university
使用PTB数据集:
从下面这行代码可以看出,使用PTB数据集时,会把每个换行符替换成 <eos>,再按空白切分,也就是把所有句子首尾相连成一个长的单词序列。
words = open(file_path).read().replace('\n', '<eos>').strip().split()
ptb.py的作用是:下载PTB数据集,把它保存到ptb.py所在的文件夹中,然后从数据集中提取出corpus、word_to_id和id_to_word。
import sys
import os
sys.path.append('..')
try:
import urllib.request
except ImportError:
raise ImportError('Use Python3!')
import pickle
import numpy as np
# Remote location of the raw PTB text files (tomsercu/lstm repository on GitHub).
url_base = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/'

# Split name -> name of the raw text file to download for that split.
key_file = dict(
    train='ptb.train.txt',
    test='ptb.test.txt',
    valid='ptb.valid.txt',
)

# Split name -> name of the .npy file associated with that split.
save_file = dict(
    train='ptb.train.npy',
    test='ptb.test.npy',
    valid='ptb.valid.npy',
)

# File name of the pickled (word_to_id, id_to_word) vocabulary.
vocab_file = 'ptb.vocab.pkl'

# Every dataset file lives in the directory that contains this module.
dataset_dir = os.path.split(os.path.abspath(__file__))[0]
def _download(file_name):
    """Download ``file_name`` from ``url_base`` into ``dataset_dir`` if it is missing.

    Nothing happens when ``dataset_dir/file_name`` already exists. Otherwise
    the file is fetched into a temporary ``.part`` sibling and only moved to
    its final name once the transfer has finished, so an interrupted download
    can never be mistaken for a complete file by a later call.

    :param file_name: name of the file, relative to both ``url_base`` and
        ``dataset_dir``.
    :raises urllib.error.URLError: if the download fails even after the
        fallback attempt without certificate verification.
    """
    file_path = dataset_dir + '/' + file_name
    if os.path.exists(file_path):
        return

    print('Downloading ' + file_name + ' ... ')

    tmp_path = file_path + '.part'
    try:
        try:
            urllib.request.urlretrieve(url_base + file_name, tmp_path)
        except urllib.error.URLError:
            # SECURITY NOTE: this fallback disables HTTPS certificate
            # verification for the rest of the process (it replaces the
            # default context factory), not just for this request. It is kept
            # for backward compatibility with environments that lack usable
            # CA certificates.
            import ssl
            ssl._create_default_https_context = ssl._create_unverified_context
            urllib.request.urlretrieve(url_base + file_name, tmp_path)
        # Publish the file under its final name only after a full transfer.
        os.replace(tmp_path, file_path)
    finally:
        # On failure, do not leave a truncated temporary file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print('Done')
def load_vocab():
    """Return the vocabulary mappings built from the training split.

    If the pickle cache ``dataset_dir/vocab_file`` exists, it is loaded and
    returned as is. Otherwise the training text is downloaded if needed,
    every newline is replaced by ``'<eos>'``, the text is split on
    whitespace, and each distinct word receives an id in order of first
    appearance. The resulting mappings are written to the cache before being
    returned.

    :return: tuple ``(word_to_id, id_to_word)`` of two dicts mapping
        word -> int id and int id -> word.
    """
    vocab_path = dataset_dir + '/' + vocab_file

    # Fast path: reuse the cached vocabulary.
    # NOTE(security): pickle.load can execute arbitrary code from a crafted
    # file; this is only safe while the cache file is written by this module.
    if os.path.exists(vocab_path):
        with open(vocab_path, 'rb') as f:
            word_to_id, id_to_word = pickle.load(f)
        return word_to_id, id_to_word

    word_to_id = {}
    id_to_word = {}
    # The vocabulary is always built from the training split.
    file_name = key_file['train']
    file_path = dataset_dir + '/' + file_name

    _download(file_name)

    # Use a context manager so the file handle is closed deterministically.
    with open(file_path) as f:
        words = f.read().replace('\n', '<eos>').strip().split()

    for word in words:
        if word not in word_to_id:
            # Ids are dense and assigned in order of first appearance.
            tmp_id = len(word_to_id)
            word_to_id[word] = tmp_id
            id_to_word[tmp_id] = word

    with open(vocab_path, 'wb') as f:
        pickle.dump((word_to_id, id_to_word), f)

    return word_to_id, id_to_word
def load_data(data_type='train'):
'''
:param data_type: 数据的种类:'train' or 'test' or 'valid (val)'
:return:
'''
if data_type == 'val': data_type = 'valid'
save_path = dataset_dir + '/' + save_file[data_type]
word_to_id, id_to_word = load_vocab()
if os.path.exists(save_path):
corpus = np