学习笔记(24)- plato-训练中文模型
先处理中文语料。参考上篇笔记
1. 准备model_definition_file
文件
官方文档给了例子,
plato/example/config/ludwig/metalWOZ_seq2seq_ludwig.yaml
---
input_features:
-
name: user
type: text
level: word
encoder: rnn
cell_type: lstm
reduce_output: null
output_features:
-
name: system
type: text
level: word
decoder: generator
cell_type: lstm
attention: bahdanau
training:
epochs: 100
2. 开始训练模型
注意模型的保存路径
ludwig train \
--data_csv data/metalwoz.csv \
--model_definition_file plato/example/config/ludwig/metalWOZ_seq2seq_ludwig.yaml \
--output_directory "models/joint_models/"
3. 写类文件,加载模型
模型训练完毕之后,就可以使用了。
那么如何使用呢? 需要写类实现接口。
写一个类,继承Conversational Module,来加载和查询模型。
这个类只需要加载模型,查询并负责输出。
我们需要把输入文本转换为pandas dataframe,从输出捕获预测序列,将他们组织为字符串,并返回。
参考 plato.agent.component.joint_model.metal_woz_seq2seq.py
package: plato.agent.component.joint_model.metal_woz_seq2seq
class: MetalWOZSeq2Seq
文件:
plato/agent/component/joint_model/metal_woz_seq2seq.py
"""
MetalWOZ is an MetalWOZ class that defines an interface to Ludwig models.
"""
class MetalWOZSeq2Seq(ConversationalModule):
……
4. 运行Agent
写一个yaml文件,就可以运行Agent了,
参考plato/example/config/application/metalwoz_generic.yaml
,这是一个seq2seq的例子。
plato run --config metalwoz_text.yaml
plato/example/config/application/metalwoz_text.yaml
5. 测试结果
可以做一些输入和测试,看看效果