MindSpore:【课程作业经验】基于TextCNN文本情感分类

# 基于TextCNN的文本情感分类

在本次实验中,我们使用 MindSpore 实现 TextCNN,完成 aclImdb 数据集上的情感分类任务。

数据加载

基于 TextCNN 的方法需要指定文本句子长度 maxlen,并对句子进行预处理:只保留字母、数字和空白字符,其余符号删去。标签标定为:pos 为 1,neg 为 0。

import os

# Maximum number of words kept per review; also used as the TextCNN sequence length.
maxlen = 20
sentences = []
labels = []
# Build paths portably instead of hard-coding Windows backslash separators.
posdirname = os.path.join("aclImdb", "train", "pos")
negdirname = os.path.join("aclImdb", "train", "neg")
file_num = 10000


def _clean_review(text, max_tokens):
    """Keep only alphanumeric/whitespace characters, then the first ``max_tokens`` words.

    Truncating by words (rather than characters) avoids cutting the last word in
    half. Re-joining with single spaces makes ``" ".join(...).split(" ")`` and
    ``str.split()`` tokenize the result identically, so vocabulary building and
    input encoding cannot disagree.
    """
    kept = ''.join(ch for ch in text if ch.isalnum() or ch.isspace())
    return ' '.join(kept.split()[:max_tokens])


def _load_dir(dirname, label, limit):
    """Append up to ``limit`` cleaned reviews from ``dirname`` to ``sentences``, tagging each with ``label``."""
    for txtfile in os.listdir(dirname)[:limit]:
        with open(os.path.join(dirname, txtfile), encoding="utf-8") as txt:
            sentences.append(_clean_review(txt.read(), maxlen))
        labels.append(label)


_load_dir(posdirname, 1, file_num)  # positive reviews -> label 1
_load_dir(negdirname, 0, file_num)  # negative reviews -> label 0

模型构建:

我们按照如下方法构建 TextCNN 卷积网络:

class TextCNN(nn.Cell):
    """TextCNN classifier: embedding -> parallel conv/ReLU/max-pool branches -> linear layer.

    Args:
        embedding_size: Dimension of each word embedding vector.
        sequence_length: Number of token ids per input row; each branch max-pools
            over all ``sequence_length - size + 1`` convolution positions.
        num_classes: Number of output logits.
        filter_sizes: Kernel height (in tokens) of each convolution branch. One
            branch is created per entry, and any number of entries is supported.
        num_filters: Output channels of every convolution branch.
        vocab_size: Number of rows in the embedding table (largest token id + 1).
    """

    def __init__(self, embedding_size, sequence_length, num_classes, filter_sizes, num_filters, vocab_size):
        super(TextCNN, self).__init__()
        self.num_filters_total = num_filters * len(filter_sizes)
        self.filter_sizes = filter_sizes
        self.sequence_length = sequence_length
        self.W = nn.Embedding(vocab_size, embedding_size)
        # Bias-free Dense layer plus a separate bias Parameter initialised to ones.
        self.Weight = nn.Dense(self.num_filters_total, num_classes, has_bias=False)
        self.Bias = Parameter(Tensor(np.ones(num_classes), mindspore.float32), name='bias')
        self.filter_list = nn.CellList()
        for size in filter_sizes:
            seq_cell = nn.SequentialCell([
                # The kernel spans the full embedding width, so the conv output width is 1.
                nn.Conv2d(1, num_filters, (size, embedding_size), pad_mode='valid'),
                nn.ReLU(),
                # Pool over every valid convolution position along the sequence.
                nn.MaxPool2d(kernel_size=(sequence_length - size + 1, 1))
            ])
            self.filter_list.append(seq_cell)

        # construct() transposes each branch so filters sit on the last axis (axis 3 of a
        # 4-D tensor). Concatenate there, independent of how many branches exist.
        self.concat = ops.Concat(axis=3)

    def construct(self, X):
        """Return logits of shape (batch, num_classes) for token-id input ``X``.

        Assumes ``X`` is an integer tensor of shape (batch, sequence_length).
        """
        embedded_chars = self.W(X)                      # (batch, seq, emb)
        embedded_chars = embedded_chars.expand_dims(1)  # (batch, 1, seq, emb): one input channel
        pooled_outputs = []
        for conv in self.filter_list:
            pooled = conv(embedded_chars)               # (batch, num_filters, 1, 1) when seq == sequence_length
            pooled = pooled.transpose((0, 3, 2, 1))     # (batch, 1, 1, num_filters)
            pooled_outputs.append(pooled)

        # Concatenate all branches, not just a hard-coded first three.
        h_pool = self.concat(tuple(pooled_outputs))
        h_pool_flat = h_pool.view(-1, self.num_filters_total)
        logits = self.Weight(h_pool_flat) + self.Bias
        return logits
    
# Softmax cross-entropy computed on raw logits; sparse=True means targets are class
# indices (not one-hot), and reduction='mean' averages the loss over the batch.
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True,reduction='mean')
# NOTE(review): `model` is only assigned later (model = TextCNN(...) after the vocabulary
# is built), so this line must run after the model is constructed; otherwise it raises NameError.
optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)

参数设定

# TextCNN hyperparameters. The convolution sequence length equals the padded input length.
embedding_size, num_classes, num_filters = 2, 2, 3
sequence_length = maxlen
filter_sizes = [2] * 3  # three convolution branches, each with kernel height 2

# Build the vocabulary with the same whitespace tokenisation used to encode inputs
# (str.split() with no argument, which also drops empty tokens). Sorting makes word ids
# reproducible across runs, because the iteration order of a set of strings depends on
# hash randomisation.
word_list = sorted(set(" ".join(sentences).split()))
# Id 0 is reserved for padding, so real words are numbered from 1.
word_dict = {w: i for i, w in enumerate(word_list, start=1)}
# The embedding table must cover the padding id 0 plus every word id.
vocab_size = len(word_dict) + 1

model = TextCNN(embedding_size, sequence_length, num_classes, filter_sizes, num_filters, vocab_size)

输入格式转换

我们使用上面构造的 word_dict 将每个词转换为它在词表中的索引(不足 maxlen 的部分用 0 补齐);词向量本身由模型中的 Embedding 层学习得到,并不是 word2vec。

# Encode each sentence as exactly `maxlen` word ids. Look words up in word_dict, truncate
# longer sentences and right-pad shorter ones with id 0, so every row has the same length
# (a ragged nested list cannot be turned into a Tensor).
# Words missing from word_dict fall back to id 0 instead of raising KeyError; this can
# happen if the vocabulary was built with a different tokenisation.
inputs = []
for sentence in sentences:
    ids = [word_dict.get(word, 0) for word in sentence.split()][:maxlen]
    ids.extend([0] * (maxlen - len(ids)))
    inputs.append(ids)
inputs = Tensor(inputs, mindspore.int32)
# Class-index labels (1 = pos, 0 = neg) as int32 for the sparse cross-entropy loss.
targets = Tensor(labels, mindspore.int32)

## 模型训练

from mindspore import context

# Run in static-graph mode.
context.set_context(mode=context.GRAPH_MODE)

# Attach the loss to the model, then wrap it with the optimizer so that each call performs
# a forward pass, a backward pass and a parameter update.
net_with_criterion = nn.WithLossCell(model, criterion)
train_network = nn.TrainOneStepCell(net_with_criterion, optimizer)
train_network.set_train()

epoch = 5000
log_every = 1000
for step in range(epoch):
    step_no = step + 1
    # The full `inputs`/`targets` tensors are passed in on every step.
    loss = train_network(inputs, targets)
    if step_no % log_every != 0:
        continue
    print('Epoch:', '%04d' % step_no, 'cost =', '{:.6f}'.format(loss.asnumpy()))

测试结果

test_text = 'The film lacks style'
# Look up each word's id; words not in the vocabulary map to 0 (the padding value)
# instead of raising KeyError.
tests = [word_dict.get(n, 0) for n in test_text.split()]
print(tests)
# Truncate/pad to exactly maxlen ids so the input matches the training shape.
tests = tests[:maxlen]
tests.extend([0] * (maxlen - len(tests)))
test_batch = Tensor(np.array([tests]), mindspore.int32)  # a batch containing one sentence

# Switch to inference mode; argmax over the logits gives the class (0 = neg, 1 = pos).
model.set_train(False)
predict = model(test_batch).asnumpy().argmax(1)
if predict[0] == 0:
    print(test_text,"is Negative...")
else:
    print(test_text,"is Positive!!")
    
    
The film lacks style is Negative...
posted @ 2022-08-12 10:43  Skytier  阅读(202)  评论(0编辑  收藏  举报