Deploying a wenet-based streaming ASR service with Triton
1. Download the Docker image
Download page: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver/tags
Or pull it directly: docker pull nvcr.io/nvidia/tritonserver:23.01-py3
2. Clone the wenet project
git clone https://github.com/wenet-e2e/wenet.git
3. Download a pretrained model from:
https://github.com/wenet-e2e/wenet/blob/main/docs/pretrained_models.en.md
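The model directory mainly contains these four files:
- final.pt: model parameters (checkpoint)
- global_cmvn: CMVN (feature normalization) statistics
- train.yaml: training configuration
- words.txt: vocabulary (token-to-ID mapping)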
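Before converting, it can help to confirm that the downloaded directory is complete. A minimal Python sketch; the function name `missing_model_files` is hypothetical and the file list is just the four files above:

```python
from pathlib import Path

# The four files listed above for a wenet pretrained model directory.
REQUIRED_MODEL_FILES = ("final.pt", "global_cmvn", "train.yaml", "words.txt")


def missing_model_files(model_dir):
    """Return the required file names that are absent from model_dir."""
    root = Path(model_dir)
    return [name for name in REQUIRED_MODEL_FILES if not (root / name).is_file()]
```

For example, `missing_model_files("./20210618_u2pp_conformer_exp")` should return an empty list when the download is complete.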
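4. Convert the model to ONNX
model_dir=./20210618_u2pp_conformer_exp   # replace with the actual location of the model
onnx_model_dir=/path/to/onnx_model        # directory the ONNX model is exported to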
python3 wenet/bin/export_onnx_gpu.py --config=$model_dir/train.yaml --checkpoint=$model_dir/final.pt --cmvn_file=$model_dir/global_cmvn --ctc_weight=0.5 --output_onnx_dir=$onnx_model_dir --fp16
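The ONNX model directory contains the following files:
- config.yaml: model configuration
- decoder_fp16.onnx
- decoder.onnx
- encoder_fp16.onnx
- encoder.onnx
- global_cmvn: CMVN (feature normalization) statistics
- train.yaml: training configuration
- words.txt: vocabulary (token-to-ID mapping)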
5. Copy the configuration files
cp $model_dir/words.txt $model_dir/train.yaml $onnx_model_dir/
cp $model_dir/words.txt $onnx_model_dir/units.txt
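The two `cp` commands can also be scripted together with a check that the ONNX directory holds the files listed in step 4 plus `units.txt` (the file the server looks for, per error 2 below). This is a sketch under the assumption that the expected set matches this guide's listing; adjust `EXPECTED_ONNX_DIR_FILES` if your export produced a different set. `copy_config_files` and `missing_onnx_files` are hypothetical helper names:

```python
import shutil
from pathlib import Path

# Files this guide lists for the ONNX directory, plus units.txt from step 5.
EXPECTED_ONNX_DIR_FILES = (
    "config.yaml", "decoder_fp16.onnx", "decoder.onnx",
    "encoder_fp16.onnx", "encoder.onnx", "global_cmvn",
    "train.yaml", "words.txt", "units.txt",
)


def copy_config_files(model_dir, onnx_model_dir):
    """Python equivalent of the two cp commands in step 5."""
    src, dst = Path(model_dir), Path(onnx_model_dir)
    for name in ("words.txt", "train.yaml"):
        shutil.copy2(src / name, dst / name)
    # units.txt is simply a copy of words.txt under the name the server reads.
    shutil.copy2(src / "words.txt", dst / "units.txt")


def missing_onnx_files(onnx_model_dir):
    """Return expected file names that are absent from onnx_model_dir."""
    root = Path(onnx_model_dir)
    return [n for n in EXPECTED_ONNX_DIR_FILES if not (root / n).is_file()]
```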
6. Start the container:
docker run --gpus '"device=0"' -itd -p 8000:8000 -p 8001:8001 -p 8002:8002 --shm-size=1g --ulimit memlock=-1 --name triton_server nvcr.io/nvidia/tritonserver:23.01-py3
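7. Copy the model repository, the ONNX model and the scripts into the container. Mapping the paths with -v when starting the container is an alternative, but it makes switching between multiple models inconvenient, and editing the model path in /workspace/scripts/convert_start_server.sh produced garbled characters (cause not yet identified). The container name triton_server from step 6 can be used in place of the container ID:
docker cp ./wenet/runtime/gpu/model_repo_stateful triton_server:/ws/model_repo/
docker cp ./output_onnx_model triton_server:/ws/onnx_model/
docker cp ./wenet/runtime/gpu/scripts triton_server:/workspace/
8. Set up the environment inside the container (or build a new image on top of the tritonserver image with the same steps):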
- apt-get update
- apt-get -y install swig
- apt-get -y install python3-dev
- apt-get install -y cmake
- pip3 install torch==2.0.1 torchaudio -i https://pypi.tuna.tsinghua.edu.cn/simple/
- pip3 install -v kaldifeat pyyaml onnx -i https://pypi.tuna.tsinghua.edu.cn/simple/
- cd /workspace
- git clone https://github.com/Slyne/ctc_decoder.git && cd ctc_decoder/swig && bash setup.sh
- export PYTHONPATH=/root/.local/lib/python3.8/site-packages:/root/.local/lib/python3.8/site-packages/swig_decoders-1.1-py3.8-linux-x86_64.egg:$PYTHONPATH
9. Start the service
cd /workspace
sh ./scripts/convert_start_server.sh
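Model loading can take a while after the script starts. A minimal polling sketch that works with any client exposing `is_server_ready()` and `is_model_ready(name)`, which `tritonclient.grpc.InferenceServerClient` provides; `wait_until_ready` and its timeout values are illustrative, and the model name `streaming_wenet` is taken from the test client's default in step 13:

```python
import time


def wait_until_ready(client, model_name, timeout_s=120.0, poll_s=2.0):
    """Poll until the server and model_name report ready.

    Returns True on success, False if timeout_s elapses first.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            if client.is_server_ready() and client.is_model_ready(model_name):
                return True
        except Exception:
            # The server may still be starting and refuse connections.
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)
```

On the host (port 8001 is mapped in step 6), this could be called as `wait_until_ready(grpcclient.InferenceServerClient(url="localhost:8001"), "streaming_wenet")` after `import tritonclient.grpc as grpcclient`.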
10. Troubleshooting
Error 1: No module named 'swig_decoders'. Add the decoder's install location to PYTHONPATH:
export PYTHONPATH=/root/.local/lib/python3.8/site-packages:/root/.local/lib/python3.8/site-packages/swig_decoders-1.1-py3.8-linux-x86_64.egg:$PYTHONPATH
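`export` only affects the current shell and the processes started from it, so it needs to run in the same shell that launches `convert_start_server.sh`. To check whether it took effect, a small standard-library sketch (`module_available` and `decoder_paths_on_sys_path` are hypothetical names):

```python
import importlib.util
import sys


def module_available(name):
    """Return True if `name` is importable with the current sys.path."""
    return importlib.util.find_spec(name) is not None


def decoder_paths_on_sys_path(fragment="swig_decoders"):
    """Return sys.path entries containing `fragment`, e.g. the .egg directory."""
    return [p for p in sys.path if fragment in p]


if __name__ == "__main__":
    print("swig_decoders importable:", module_available("swig_decoders"))
    print("matching sys.path entries:", decoder_paths_on_sys_path())
```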
Error 2: FileNotFoundError: [Errno 2] No such file or directory: '/ws/onnx_model/units.txt'. Copy words.txt to units.txt:
cp /ws/onnx_model/words.txt /ws/onnx_model/units.txt
11. Adding hotwords
vim -c 'set encoding=utf-8' /ws/model_repo/wenet/hotwords.yaml   # edit the hotword file
vim /ws/model_repo/wenet/config_template.pbtxt   # change the hotword path value from none to /ws/model_repo/wenet/hotwords.yaml
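12. Adding a language model
vim /ws/model_repo/wenet/config_template.pbtxt   # set the lm_path parameter
Training with KenLM (-o 5 builds a 5-gram model): lmplz -o 5 < input.txt > lm.arpa
An lm.arpa model works, but on startup the service suggests converting it to binary for faster loading (then point lm_path at lm.bin): build_binary -s lm.arpa lm.bin
For details on building the language model, see: https://www.zhihu.com/tardis/bd/art/399494766
Note whether the n-gram model is trained at the character level or the word level.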
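`lmplz` treats whitespace-separated strings as tokens, so for a character-level model each Chinese character in input.txt has to be separated by spaces; a word-level model would instead need text split by a word segmenter (not shown). A minimal sketch of the character-level preparation, assuming a UTF-8 corpus with one sentence per line; `to_char_level` and `convert_corpus` are hypothetical helpers, and only the basic CJK block U+4E00 to U+9FFF is split into single characters:

```python
def is_cjk(ch):
    """True for characters in the basic CJK Unified Ideographs block."""
    return "\u4e00" <= ch <= "\u9fff"


def to_char_level(line):
    """Space-separate CJK characters; keep non-CJK runs (e.g. 'KFC') whole."""
    tokens, buf = [], []
    for ch in line.strip():
        if ch.isspace() or is_cjk(ch):
            if buf:  # flush a pending non-CJK run
                tokens.append("".join(buf))
                buf = []
            if not ch.isspace():
                tokens.append(ch)
        else:
            buf.append(ch)
    if buf:
        tokens.append("".join(buf))
    return " ".join(tokens)


def convert_corpus(src_path, dst_path):
    """Write a character-level copy of a UTF-8 corpus, skipping empty lines."""
    with open(src_path, encoding="utf-8") as fin, \
            open(dst_path, "w", encoding="utf-8") as fout:
        for line in fin:
            out = to_char_level(line)
            if out:
                fout.write(out + "\n")
```

The output file can then be passed to the `lmplz` command above as input.txt.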
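13. Test example
Exit the container, go to wenet/runtime/gpu/client in the project (the host needs the tritonclient[grpc] package), and run the following code:
Result without hotwords: 大学生利用漏洞免费吃肯德基或型 (获刑, "sentenced", is misrecognized as its homophone 或型)
Result after adding the hotword “获刑”: 大学生利用漏洞免费吃肯德基获刑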
import argparse

import tritonclient.grpc as grpcclient
from speech_client import StreamingSpeechClient  # speech_client.py in wenet/runtime/gpu/client

parser = argparse.ArgumentParser()
parser.add_argument(
    "-v", "--verbose", action="store_true", required=False, default=False,
    help="Enable verbose output",
)
parser.add_argument(
    "-u", "--url", type=str, required=False,
    default="localhost:8001",  # use the server IP if the client runs on another host
    help="Inference server URL. Default is localhost:8001.",
)
parser.add_argument(
    "--model_name", required=False,
    default="streaming_wenet",  # alternative: "attention_rescoring"
    choices=["attention_rescoring", "streaming_wenet"],
    help="the model to send request to",
)
parser.add_argument("--wavscp", type=str, required=False, default=None, help="audio_id \t wav_path")
parser.add_argument("--trans", type=str, required=False, default=None, help="audio_id \t text")
parser.add_argument("--data_dir", type=str, required=False, default=None,
                    help="path prefix for wav_path in wavscp/audio_file")
parser.add_argument("--audio_file", type=str, required=False, default=None, help="single wav file path")
# Below arguments are for streaming.
# Please check onnx_config.yaml and train.yaml.
parser.add_argument("--streaming", action="store_true", required=False)
parser.add_argument("--sample_rate", type=int, required=False, default=16000, help="sample rate used in training")
parser.add_argument("--frame_length_ms", type=int, required=False, default=25, help="frame length")
parser.add_argument("--frame_shift_ms", type=int, required=False, default=10, help="frame shift length")
parser.add_argument("--chunk_size", type=int, required=False, default=16, help="chunk size default is 16")
parser.add_argument("--context", type=int, required=False, default=7, help="subsampling context")
parser.add_argument("--subsampling", type=int, required=False, default=4, help="subsampling rate")
# args=[] ignores the command line (handy in a notebook); use parser.parse_args() in a script.
FLAGS = parser.parse_args(args=[])

speech_client_cls = StreamingSpeechClient
wav_file = "./test_wavs/mid.wav"
with grpcclient.InferenceServerClient(url=FLAGS.url, verbose=FLAGS.verbose) as triton_client:
    protocol_client = grpcclient
    speech_client = speech_client_cls(triton_client, FLAGS.model_name, protocol_client, FLAGS)
    # recognize() prints the partial result after each chunk
    result = speech_client.recognize(wav_file)

Output:
Get response from 1th chunk:
Get response from 2th chunk: 大学
Get response from 3th chunk: 大学生利用
Get response from 4th chunk: 大学生利用漏洞
Get response from 5th chunk: 大学生利用漏洞免费吃
Get response from 6th chunk: 大学生利用漏洞免费吃肯德机
Get response from 7th chunk: 大学生利用漏洞免费吃肯德机获刑
Get response from 8th chunk: 大学生利用漏洞免费吃肯德基获刑
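The streaming arguments (`--chunk_size`, `--subsampling`, `--context` and the frame settings) determine how much audio each request covers. The sketch below reproduces the chunk arithmetic of wenet's chunk-by-chunk encoder (`decoding_window = (chunk_size - 1) * subsampling + context`, `stride = subsampling * chunk_size`) and converts it to samples, assuming unpadded 25 ms / 10 ms framing at 16 kHz. `speech_client.py` may slice the audio slightly differently, so treat the numbers as an estimate; the function names are hypothetical:

```python
def chunk_frames(chunk_size=16, subsampling=4, context=7):
    """Feature frames needed by the first chunk and added by each later chunk."""
    decoding_window = (chunk_size - 1) * subsampling + context
    stride = chunk_size * subsampling
    return decoding_window, stride


def chunk_samples(chunk_size=16, subsampling=4, context=7,
                  sample_rate=16000, frame_length_ms=25, frame_shift_ms=10):
    """Audio samples for the first chunk and new samples per later chunk.

    Without padding, n frames need frame_length + (n - 1) * frame_shift
    samples; because frames overlap, each later chunk only needs
    stride * frame_shift new samples.
    """
    frame_length = sample_rate * frame_length_ms // 1000
    frame_shift = sample_rate * frame_shift_ms // 1000
    first, stride = chunk_frames(chunk_size, subsampling, context)
    return frame_length + (first - 1) * frame_shift, stride * frame_shift
```

With the defaults above, the first chunk is 67 frames (10960 samples, about 0.69 s of audio) and each later chunk adds 64 frames (10240 samples, 0.64 s).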
References: https://github.com/wenet-e2e/wenet/tree/main/runtime/gpu