代码来源于:《TensorFlow机器学习实战指南》(曾益强 译,2017年9月)——第七章:自然语言处理
代码地址:https://github.com/nfmcclure/tensorflow-cookbook
在讲述 skip-gram、CBOW、Word2Vec、Doc2Vec 模型时需要复用的函数:
- 加载数据函数
- 归一化文本函数
- 生成词汇表函数
- 生成单词索引表
- 生成批量数据函数
加载数据函数
def _decode_review_lines(raw_lines):
    """Decode byte lines as ISO-8859-1 and drop every non-ASCII character."""
    return [
        line.decode('ISO-8859-1').encode('ascii', errors='ignore').decode()
        for line in raw_lines
    ]


def load_movie_data(data_folder_name):
    """Load the movie review polarity data, downloading it on first use.

    If both ``rt-polarity.pos`` and ``rt-polarity.neg`` already exist in
    ``data_folder_name`` they are read from disk.  Otherwise the archive is
    downloaded, the two review files are extracted, reduced to ASCII and
    saved into ``data_folder_name`` for future calls.

    Args:
        data_folder_name: Directory that holds (or will hold) the two
            review files.

    Returns:
        A ``(texts, target)`` tuple.  ``texts`` is the list of positive
        review lines followed by the negative review lines (line endings
        are kept).  ``target`` is a list of the same length holding ``1``
        for each positive line and ``0`` for each negative line.
    """
    pos_file = os.path.join(data_folder_name, 'rt-polarity.pos')
    neg_file = os.path.join(data_folder_name, 'rt-polarity.neg')

    # Both files are read below, so both must be present to use the cache.
    if os.path.isfile(pos_file) and os.path.isfile(neg_file):
        with open(pos_file, 'r') as temp_pos_file:
            pos_data = list(temp_pos_file)
        with open(neg_file, 'r') as temp_neg_file:
            neg_data = list(temp_neg_file)
    else:
        # Not downloaded yet: fetch the archive into memory.
        movie_data_url = 'http://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz'
        tmp = io.BytesIO()
        with urllib.request.urlopen(movie_data_url) as stream_data:
            while True:
                chunk = stream_data.read(16384)
                if not chunk:
                    break
                tmp.write(chunk)
        tmp.seek(0)

        with tarfile.open(fileobj=tmp, mode="r:gz") as tar_file:
            pos_data = _decode_review_lines(
                tar_file.extractfile('rt-polaritydata/rt-polarity.pos'))
            neg_data = _decode_review_lines(
                tar_file.extractfile('rt-polaritydata/rt-polarity.neg'))

        # Create the target folder (the original referenced an undefined
        # name here).  An empty folder name means the current directory.
        os.makedirs(data_folder_name or '.', exist_ok=True)

        # Save files so the next call can skip the download.
        with open(pos_file, "w") as pos_file_handler:
            pos_file_handler.write(''.join(pos_data))
        with open(neg_file, "w") as neg_file_handler:
            neg_file_handler.write(''.join(neg_data))

    texts = pos_data + neg_data
    target = [1] * len(pos_data) + [0] * len(neg_data)
    return (texts, target)
归一化文本函数
def normalize_text(texts, stops):
    """Normalize each text: lower-case, strip symbols, drop stop words.

    Args:
        texts: Iterable of strings to clean.
        stops: Container of stop words; a word is dropped when
            ``word in stops`` is true.

    Returns:
        A list with one cleaned string per input text, whose remaining
        words are separated by single spaces.
    """
    digits = '0123456789'
    cleaned = []
    for raw in texts:
        lowered = raw.lower()
        # Drop punctuation and digits character by character.
        letters_only = ''.join(
            ch for ch in lowered
            if ch not in string.punctuation and ch not in digits
        )
        # Splitting on whitespace and re-joining with one space also
        # collapses any extra whitespace left behind.
        kept_words = [w for w in letters_only.split() if w not in stops]
        cleaned.append(' '.join(kept_words))
    return cleaned
生成词汇表函数
def build_dictionary(sentences, vocabulary_size):
    """Build a word -> integer id vocabulary from whitespace-split sentences.

    Id 0 is reserved for the placeholder ``'RARE'``, which stands for every
    word that is not frequent enough to make it into the vocabulary.  The
    ``vocabulary_size - 1`` most frequent words follow, numbered in
    decreasing order of frequency.

    Args:
        sentences: Iterable of strings.
        vocabulary_size: Maximum number of entries, including ``'RARE'``.

    Returns:
        A dict mapping each kept word to its integer id.
    """
    # Flatten all sentences into one stream of tokens.
    tokens = []
    for sentence in sentences:
        tokens.extend(sentence.split())

    frequent = collections.Counter(tokens).most_common(vocabulary_size - 1)
    vocabulary = ['RARE'] + [token for token, _ in frequent]

    # Each word receives the dictionary's size at insertion time as its id.
    word_dict = {}
    for token in vocabulary:
        word_dict[token] = len(word_dict)
    return word_dict
生成单词索引表
def text_to_numbers(sentences, word_dict):
    """Convert sentences into lists of integer word ids.

    Args:
        sentences: Iterable of strings; each is split on whitespace.
        word_dict: Mapping from word to integer id.

    Returns:
        A list with one list of ids per sentence.  Words missing from
        ``word_dict`` are mapped to 0, the rare-word index.
    """
    return [
        [word_dict.get(token, 0) for token in sentence.split()]
        for sentence in sentences
    ]
生成批量数据函数
def _centered_windows(sentence, window_size):
    """Return one (window, center_position) pair per word of the sentence."""
    return [
        (sentence[max(ix - window_size, 0):ix + window_size + 1],
         min(ix, window_size))
        for ix in range(len(sentence))
    ]


def _skip_gram_samples(sentence, window_size):
    """Return (center words, context words), one entry per pairing."""
    pairs = [
        (window[center], neighbour)
        for window, center in _centered_windows(sentence, window_size)
        for neighbour in window[:center] + window[center + 1:]
    ]
    return [p[0] for p in pairs], [p[1] for p in pairs]


def _cbow_samples(sentence, window_size):
    """Return (full context windows, center words) for the sentence."""
    samples = [
        (window[:center] + window[center + 1:], window[center])
        for window, center in _centered_windows(sentence, window_size)
    ]
    # Only keep windows with a complete context of 2 * window_size words.
    samples = [s for s in samples if len(s[0]) == 2 * window_size]
    return [s[0] for s in samples], [s[1] for s in samples]


def _doc2vec_samples(sentence, window_size, sentence_ix):
    """Return (left windows + document index, following words)."""
    starts = range(0, len(sentence) - window_size)
    # The document index goes last in each row; consumers must take the
    # last element of a batch row to recover it.
    batch = [sentence[i:i + window_size] + [sentence_ix] for i in starts]
    labels = [sentence[i + window_size] for i in starts]
    return batch, labels


def _can_yield_samples(sentence_len, window_size, method):
    """Tell whether a sentence of this length can produce any sample."""
    if method == 'skip_gram':
        return window_size >= 1 and sentence_len >= 2
    if method == 'cbow':
        return sentence_len >= 2 * window_size + 1
    return sentence_len >= window_size + 1  # doc2vec


def generate_batch_data(sentences, batch_size, window_size, method='skip_gram'):
    """Generate a random training batch from sentences of word indices.

    Sentences are drawn at random (with ``np.random.choice``) until
    ``batch_size`` samples have been collected; at most ``batch_size``
    samples are taken from each drawn sentence.  Sentences too short to
    produce a sample for the chosen method are skipped.

    Args:
        sentences: Sequence of sentences, each a list of word indices.
        batch_size: Number of samples to return.
        window_size: Number of words on each side of the center word
            (``'skip_gram'``, ``'cbow'``) or to the left of the target
            word (``'doc2vec'``).
        method: ``'skip_gram'`` (center word -> one context word),
            ``'cbow'`` (full context -> center word) or ``'doc2vec'``
            (left window plus sentence index -> next word).

    Returns:
        ``(batch_data, label_data)`` as numpy arrays.  ``label_data`` has
        one row per sample and a single column.

    Raises:
        ValueError: If ``method`` is unknown, or if no sentence is long
            enough to produce a sample for the given method and window.
    """
    batch_data = []
    label_data = []
    usable_checked = False
    while len(batch_data) < batch_size:
        # Select a random sentence.  Index the one-element array before
        # converting: int() on a non-0-d array is deprecated in NumPy.
        rand_sentence_ix = int(np.random.choice(len(sentences), size=1)[0])
        rand_sentence = sentences[rand_sentence_ix]

        if method == 'skip_gram':
            batch, labels = _skip_gram_samples(rand_sentence, window_size)
        elif method == 'cbow':
            batch, labels = _cbow_samples(rand_sentence, window_size)
        elif method == 'doc2vec':
            batch, labels = _doc2vec_samples(
                rand_sentence, window_size, rand_sentence_ix)
        else:
            raise ValueError('Method {} not implmented yet.'.format(method))

        if not batch:
            # This sentence is too short.  Before redrawing, make sure (once)
            # that some sentence can succeed, so the loop cannot spin forever.
            if not usable_checked:
                if not any(_can_yield_samples(len(s), window_size, method)
                           for s in sentences):
                    raise ValueError(
                        'No sentence is long enough to produce {} samples '
                        'with window_size={}.'.format(method, window_size))
                usable_checked = True
            continue

        batch_data.extend(batch[:batch_size])
        label_data.extend(labels[:batch_size])

    # Trim the overshoot from the last sentence.
    batch_data = batch_data[:batch_size]
    label_data = label_data[:batch_size]

    # Convert to numpy arrays; labels become a single column.
    batch_data = np.array(batch_data)
    label_data = np.transpose(np.array([label_data]))
    return (batch_data, label_data)
年岁有加并非垂老
理想丢弃方堕暮年