PyTorch Embedding

Embedding

Embedding without pretrained weights (default random initialization)

import torch.nn as nn
emb = nn.Embedding(num_embeddings, embedding_dim)  # num_embeddings: vocab size; embedding_dim: vector size; weights are randomly initialized
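
A minimal usage sketch (the vocabulary size and dimension below are made-up values for illustration):

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=4)  # 10-word vocab, 4-dim vectors
ids = torch.tensor([[1, 2, 3], [4, 5, 6]])              # batch of index sequences
out = emb(ids)
print(out.shape)  # torch.Size([2, 3, 4]): one 4-dim vector per index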

Loading pretrained word vectors (e.g. GloVe)

import os
import pickle

import numpy as np

def build_embedding_matrix(word2idx, embed_dim, dat_fname):
    if os.path.exists(dat_fname):
        print('loading embedding_matrix:', dat_fname)
        embedding_matrix = pickle.load(open(dat_fname, 'rb'))
    else:
        print('loading word vectors...')
        embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))  # idx 0 and len(word2idx)+1 are reserved all-zero rows
        fname = './glove.twitter.27B/glove.twitter.27B.' + str(embed_dim) + 'd.txt' \
            if embed_dim != 300 else './glove/glove.42B.300d.txt'
        word_vec = _load_word_vec(fname, word2idx=word2idx)
        print('building embedding_matrix:', dat_fname)
        # fill embedding_matrix (vectors only) from word_vec (word -> vector)
        for word, i in word2idx.items():
            vec = word_vec.get(word)
            if vec is not None:
                # words not found in the GloVe file keep their all-zero rows
                embedding_matrix[i] = vec
        pickle.dump(embedding_matrix, open(dat_fname, 'wb'))

    return embedding_matrix

def _load_word_vec(path, word2idx=None):  # word2idx maps word -> index
    fin = open(path, 'r', encoding='utf-8', newline='\n', errors='ignore')  # GloVe text format
    word_vec = {}
    for line in fin:
        # each line: a word followed by its vector components
        tokens = line.rstrip().split()
        if word2idx is None or tokens[0] in word2idx:
            # np.asarray converts the string components to a float32 ndarray
            word_vec[tokens[0]] = np.asarray(tokens[1:], dtype='float32')
    fin.close()
    return word_vec
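
A hypothetical call, assuming a toy word2idx and that the GloVe file exists at the path hard-coded above:

# hypothetical vocabulary; indices start at 1 so row 0 stays all-zeros
word2idx = {'the': 1, 'cat': 2, 'sat': 3}
embedding_matrix = build_embedding_matrix(word2idx, embed_dim=100,
                                          dat_fname='100_embedding_matrix.dat')
print(embedding_matrix.shape)  # (len(word2idx) + 2, 100) -> (5, 100)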

Model

emb = nn.Embedding.from_pretrained(torch.tensor(embedding_matrix, dtype=torch.float))  # embedding_matrix built above
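
Note that from_pretrained freezes the weights by default; pass freeze=False to fine-tune them during training. A short sketch:

import torch
import torch.nn as nn

# freeze=False makes the pretrained weights trainable
emb = nn.Embedding.from_pretrained(
    torch.tensor(embedding_matrix, dtype=torch.float), freeze=False)
ids = torch.tensor([1, 2, 3])
vectors = emb(ids)  # rows of embedding_matrix at the given indices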
