Transformers Pipeline + Mistral-7B-Instruct-v0.x: Modifying the Chat Template
When testing the "Generate with transformers" example code from https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3, the following error is raised:
from transformers import pipeline
messages = [
{"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
{"role": "user", "content": "Who are you?"},
]
chatbot = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3")
chatbot(messages)
TemplateError: Conversation roles must alternate user/assistant/user/assistant/...
The error occurs because Mistral's chat template does not support the system role: it only accepts turns that strictly alternate between user and assistant.
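A quick workaround, before touching the template at all, is to fold the system prompt into the first user turn so that the roles alternate as expected. A minimal sketch:

messages = [
    {
        "role": "user",
        "content": "You are a pirate chatbot who always responds in pirate speak!\n\nWho are you?",
    },
]
chatbot(messages)  # no system role, so the shipped template accepts it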
For reference, tokenizer.apply_chat_template falls back to a library-wide default when a tokenizer defines no template of its own; in the transformers source, this default (default_chat_template) formats messages as ChatML:
def default_chat_template(self):
    """
    This template formats inputs in the standard ChatML format. See
    https://github.com/openai/openai-python/blob/main/chatml.md
    """
    return (
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '<|im_start|>assistant\n' }}"
        "{% endif %}"
    )
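Mistral-7B-Instruct, however, ships its own template in tokenizer_config.json, and that is where the alternation check raising the error above lives. You can print it directly; a quick sketch:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
# the Jinja template loaded from tokenizer_config.json
print(tokenizer.chat_template)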
To support a system prompt when running Mistral through the Transformers pipeline, we need to replace the chat template. The template below prepends the system message before the first [INST] block; the '-' in the {%- ... -%} markers strips surrounding whitespace, so the multi-line layout does not leak into the rendered prompt:
{%- if messages[0]['role'] == 'system' -%}
    {%- set system_message = messages[0]['content'] | trim + '\n\n' -%}
    {%- set messages = messages[1:] -%}
{%- else -%}
    {%- set system_message = '' -%}
{%- endif -%}
{{- bos_token + system_message -}}
{%- for message in messages -%}
    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
        {{- raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') -}}
    {%- endif -%}
    {%- if message['role'] == 'user' -%}
        {{- '[INST] ' + message['content'] | trim + ' [/INST]' -}}
    {%- elif message['role'] == 'assistant' -%}
        {{- ' ' + message['content'] | trim + eos_token -}}
    {%- endif -%}
{%- endfor -%}
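To use the template from Python, copy it verbatim into a raw triple-quoted string (raw so the '\n' escapes reach Jinja intact); the variable name mistral_chat_template is our own choice:

mistral_chat_template = r"""{%- if messages[0]['role'] == 'system' -%}
    {%- set system_message = messages[0]['content'] | trim + '\n\n' -%}
    {%- set messages = messages[1:] -%}
{%- else -%}
    {%- set system_message = '' -%}
{%- endif -%}
{{- bos_token + system_message -}}
{%- for message in messages -%}
    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
        {{- raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') -}}
    {%- endif -%}
    {%- if message['role'] == 'user' -%}
        {{- '[INST] ' + message['content'] | trim + ' [/INST]' -}}
    {%- elif message['role'] == 'assistant' -%}
        {{- ' ' + message['content'] | trim + eos_token -}}
    {%- endif -%}
{%- endfor -%}"""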
Then override the shipped chat_template when calling apply_chat_template:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
prompt = tokenizer.apply_chat_template(
    messages,
    chat_template=mistral_chat_template,  # the template string defined above
    tokenize=False,
    add_generation_prompt=True,  # a no-op here: the [INST] format needs no generation prompt
)
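Printing the result for the pirate example above should give roughly the following (with bos_token being '<s>'):

<s>You are a pirate chatbot who always responds in pirate speak!

[INST] Who are you? [/INST]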
With this change, inference proceeds without the TemplateError.
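Alternatively, assign the template to the tokenizer once, and the original pipeline call from the top of the post works unchanged. A sketch reusing the mistral_chat_template string and messages list defined above:

from transformers import pipeline

chatbot = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3")
# replace the shipped template with the system-prompt-aware one
chatbot.tokenizer.chat_template = mistral_chat_template
chatbot(messages)  # no TemplateError: the system prompt is now rendered into the prompt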