triton部署基于wenet的流式asr服务

1、docker镜像下载

下载链接:https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver/tags

或者直接 docker pull nvcr.io/nvidia/tritonserver:23.01-py3

 

2、克隆wenet项目

git clone https://github.com/wenet-e2e/wenet.git

3、下载预训练模型,下载链接

https://github.com/wenet-e2e/wenet/blob/main/docs/pretrained_models.en.md

模型文件主要有以下四个文件:

  • final.pt 模型参数文件
  • global_cmvn 归一化单数文件
  • train.yaml  训练配置文件
  • words.txt   字表映射文件

 

4、模型转onnx格式

model_dir =./ 20210618_u2pp_conformer_exp模型实际位置

onnx_model_dir = onnx模型导出位置

python3 wenet/bin/export_onnx_gpu.py --config=$model_dir/train.yaml --checkpoint=$model_dir/final.pt --cmvn_file=$model_dir/global_cmvn --ctc_weight=0.5 --output_onnx_dir=$onnx_model_dir --fp16

onnx格式模型有以下几个文件

  • config.yaml 模型配置文件
  • decoder_fp16.onnx
  • decode.onnx
  • encoder_fp16.onnx
  • encoder.onnx
  • global_cmvn 归一化单数文件
  • train.yaml  训练配置文件
  • words.txt   字表映射文件

 

5、配置文件拷贝

cp $model_dir/words.txt $model_dir/train.yaml $onnx_model_dir/

cp $model_dir/words.txt  $onnx_model_dir/unit.txt

 

6、启动容器:

docker run --gpus '"device=0"' -itd  -p 8000:8000 -p 8001:8001 -p 8002:8002 --shm-size=1g --ulimit memlock=-1 --name triton_server nvcr.io/nvidia/tritonserver:23.01-py3

 

7、模型脚本拷贝到容器或启动模型时做路径映射,添加路径映射的方式使得多个模型切换不方便,且存在修改/workspace/script/convert_start_server.sh 修改模型路径乱码的情况(暂未定位原因)

docker cp ./wenet/runtime/gpu/model_repo_stateful 容器ID:/ws/model_repo/

docker cp ./output_onnx_model 容器ID:/ws/onnx_model/

docker cp /home/ai_data2/jaxyu/wenet/runtime/gpu/scripts 容器ID:/workspace/

 

8、容器环境搭建或基于triton_server重新build一个镜像

  • apt-get update
  • apt-get -y install swig
  • apt-get -y install python3-dev
  • apt-get install -y cmake
  • pip3 install torch=2.0.1 torchaudio -i https://pypi.tuna.tsinghua.edu.cn/simple/
  • pip3 install -v kaldifeat pyyaml onnx -i https://pypi.tuna.tsinghua.edu.cn/simple/
  • cd /workspace
  • git clone https://github.com/Slyne/ctc_decoder.git && cd ctc_decoder/swig && bash setup.sh
  • export PYTHONPATH=/root/.local/lib/python3.8/site-packages:/root/.local/lib/python3.8/site-packages/swig_decoders-1.1-py3.8-linux-x86_64.egg:$PYTHONPATH

 

 

9、启动服务

cd  /workspace

sh  ./scripts/convert_start_server.sh

 

 

10、报错信息处理

报错一:No module named 'swig_decoders' 

export PYTHONPATH=/root/.local/lib/python3.8/site-packages:/root/.local/lib/python3.8/site-packages/swig_decoders-1.1-py3.8-linux-x86_64.egg:$PYTHONPATH

报错二:FileNotFoundError: [Errno 2] No such file or directory: '/ws/onnx_model/units.txt' 

cp /ws/onnx_model/words.txt /ws/onnx_model/units.txt

 

11、hotwords添加

vim  -c 'set encoding=utf-8' /ws/model_repo/wenet/hotwords.yaml    #修改热词文件

vim /ws/model_repo/wenet/config_template.pbtxt   将 none改为/ws/model_repo/wenet/hotwords.yaml

 

12、语言模型添加

vim /ws/model_repo/wenet/config_template.pbtxt  修改lm_path参数

训练代码:lmplz -o 5 <input.txt> lm.arpa

lm.arpa 格式语言模型启动服务时会提示转成二进制更快:build_binary -s lm.arpa lm.bin

语言模型构建可参考链接:https://www.zhihu.com/tardis/bd/art/399494766

需注意训练时采用的是字符级还是字词级的ngram

 

13、测试案例

退出容器,切换到项目 /wenet/runtime/gpu/client目录下执行以下代码:

未添加热词时结果为:大学生利用漏洞免费吃肯德基或型
添加热词”获刑”后结果为:大学生利用漏洞免费吃肯德基获刑

 

 

import argparse
import os
import tritonclient.grpc as grpcclient
from utils import cal_cer
from speech_client import *
import numpy as np
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "-v",
    "--verbose",
    action="store_true",
    required=False,
    default=False,
    help="Enable verbose output",
)
parser.add_argument(
    "-u",
    "--url",
    type=str,
    required=False,
    default="ip:8001",
    help="Inference server URL. Default is " "localhost:8001.",
)
parser.add_argument(
    "--model_name",
    required=False,
    default="streaming_wenet",
    #default="attention_rescoring",
    choices=["attention_rescoring", "streaming_wenet"],
    help="the model to send request to",
)
parser.add_argument(
    "--wavscp",
    type=str,
    required=False,
    default=None,
    help="audio_id \t wav_path",
)
parser.add_argument(
    "--trans",
    type=str,
    required=False,
    default=None,
    help="audio_id \t text",
)
parser.add_argument(
    "--data_dir",
    type=str,
    required=False,
    default=None,
    help="path prefix for wav_path in wavscp/audio_file",
)
parser.add_argument(
    "--audio_file",
    type=str,
    required=False,
    default=None,
    help="single wav file path",
)
# below arguments are for streaming
# Please check onnx_config.yaml and train.yaml
parser.add_argument("--streaming", action="store_true", required=False)
parser.add_argument(
    "--sample_rate",
    type=int,
    required=False,
    default=16000,
    help="sample rate used in training",
)
parser.add_argument(
    "--frame_length_ms",
    type=int,
    required=False,
    default=25,
    help="frame length",
)
parser.add_argument(
    "--frame_shift_ms",
    type=int,
    required=False,
    default=10,
    help="frame shift length",
)
parser.add_argument(
    "--chunk_size",
    type=int,
    required=False,
    default=16,
    help="chunk size default is 16",
)
parser.add_argument(
    "--context",
    type=int,
    required=False,
    default=7,
    help="subsampling context",
)
parser.add_argument(
    "--subsampling",
    type=int,
    required=False,
    default=4,
    help="subsampling rate",
)

# FLAGS = parser.parse_args()
FLAGS = parser.parse_args(args=[])

speech_client_cls = StreamingSpeechClient
x="./test_wavs/mid.wav"
with grpcclient.InferenceServerClient(url=FLAGS.url, verbose=FLAGS.verbose) as triton_client:
    protocol_client = grpcclient
    speech_client = speech_client_cls(triton_client, FLAGS.model_name, protocol_client, FLAGS)
    predictions = []
    result = speech_client.recognize(x)

#output
Get response from 1th chunk: 
Get response from 2th chunk: 大学
Get response from 3th chunk: 大学生利用
Get response from 4th chunk: 大学生利用漏洞
Get response from 5th chunk: 大学生利用漏洞免费吃
Get response from 6th chunk: 大学生利用漏洞免费吃肯德机
Get response from 7th chunk: 大学生利用漏洞免费吃肯德机获刑
Get response from 8th chunk: 大学生利用漏洞免费吃肯德基获刑

 

参考文档与链接: https://github.com/wenet-e2e/wenet/tree/main/runtime/gpu

posted @ 2023-10-20 11:46  glowwormss  阅读(571)  评论(0编辑  收藏  举报