Transformers Pipeline + Mistral-7B-Instruct-v0.x: Modifying the Chat Template

When testing the "Generate with transformers" code provided at https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3, I got the following error:

from transformers import pipeline

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]
chatbot = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3")
chatbot(messages)
TemplateError: Conversation roles must alternate user/assistant/user/assistant/...

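The error occurs because Mistral's chat template does not support a system prompt: it requires the roles to alternate strictly user/assistant/user/assistant/..., starting with user, so a leading "system" message is rejected.
For comparison, the generic fallback template in the source behind tokenizer.apply_chat_template (ChatML, used only when a tokenizer defines no chat_template of its own) is shown below. It contains no role check, so the exception above is raised by the Mistral-specific template that ships with the model: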
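The role check behind this error can be mirrored in plain Python. This is a sketch: `check_alternation` is a hypothetical helper, and it assumes the model's template uses the same test as the `raise_exception` line in the custom template further down, which the identical error message suggests.

```python
def check_alternation(messages):
    # Hypothetical helper mirroring the template's role check:
    # position i must be "user" when i is even and must NOT be "user" when i is odd.
    for i, message in enumerate(messages):
        if (message["role"] == "user") != (i % 2 == 0):
            raise ValueError(
                "Conversation roles must alternate user/assistant/user/assistant/..."
            )


messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

try:
    check_alternation(messages)
except ValueError as err:
    # The leading "system" message sits at position 0, which must be "user"
    print(err)

# Dropping the system message makes the same conversation pass the check
check_alternation(messages[1:])
```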

def default_chat_template(self):
    """
    This template formats inputs in the standard ChatML format. See
    https://github.com/openai/openai-python/blob/main/chatml.md
    """
    return (
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '<|im_start|>assistant\n' }}"
        "{% endif %}"
    )
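To see what this fallback produces, here is a plain-Python equivalent of the ChatML template above. `render_chatml` is a hypothetical name for illustration, not library code. It makes the point concrete: ChatML simply writes the role name into each turn, so a "system" message passes through without complaint.

```python
def render_chatml(messages, add_generation_prompt=False):
    # Hypothetical plain-Python equivalent of the ChatML fallback template
    parts = []
    for message in messages:
        # Each turn: <|im_start|>{role}\n{content}<|im_end|>\n
        parts.append("<|im_start|>" + message["role"] + "\n" + message["content"] + "<|im_end|>" + "\n")
    if add_generation_prompt:
        # Open an assistant turn so the model continues from here
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]
print(render_chatml(messages, add_generation_prompt=True))
```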

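To support a system prompt when using the Transformers pipeline with a Mistral model, we need to replace the chat template with one that handles an optional leading system message: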

# Written as one Python string with no indentation or line breaks between tags,
# so rendering does not add stray whitespace to the prompt (e.g. a newline before bos_token).
mistral_chat_template = (
    "{% if messages[0]['role'] == 'system' %}"
    "{% set system_message = messages[0]['content'] | trim + '\n\n' %}"
    "{% set messages = messages[1:] %}"
    "{% else %}"
    "{% set system_message = '' %}"
    "{% endif %}"
    "{{ bos_token + system_message }}"
    "{% for message in messages %}"
    "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
    "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
    "{% endif %}"
    "{% if message['role'] == 'user' %}"
    "{{ '[INST] ' + message['content'] | trim + ' [/INST]' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ ' ' + message['content'] | trim + eos_token }}"
    "{% endif %}"
    "{% endfor %}"
)
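As a sanity check on the format, the sketch below reproduces in plain Python the prompt string this template is intended to produce. `render_mistral_with_system` is a hypothetical helper, and the default `bos_token="<s>"` / `eos_token="</s>"` values are assumptions about the Mistral tokenizer's special tokens; inside transformers the real values come from the tokenizer.

```python
def render_mistral_with_system(messages, bos_token="<s>", eos_token="</s>"):
    # Hypothetical plain-Python mirror of the custom Jinja template.
    # An optional leading system message is trimmed and placed right after BOS.
    if messages[0]["role"] == "system":
        system_message = messages[0]["content"].strip() + "\n\n"
        messages = messages[1:]
    else:
        system_message = ""

    out = bos_token + system_message
    for i, message in enumerate(messages):
        # Same alternation check as the template: even positions must be "user"
        if (message["role"] == "user") != (i % 2 == 0):
            raise ValueError("Conversation roles must alternate user/assistant/user/assistant/...")
        if message["role"] == "user":
            out += "[INST] " + message["content"].strip() + " [/INST]"
        elif message["role"] == "assistant":
            out += " " + message["content"].strip() + eos_token
    return out


messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]
print(repr(render_mistral_with_system(messages)))
# '<s>You are a pirate chatbot who always responds in pirate speak!\n\n[INST] Who are you? [/INST]'
```

Note that the template never reads add_generation_prompt: each user turn already ends with [/INST], which is where the model's reply begins.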

Then pass it in code to override the tokenizer's chat_template:

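from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
prompt = tokenizer.apply_chat_template(
    messages,
    chat_template=mistral_chat_template,  # the custom template string defined above
    tokenize=False,
    add_generation_prompt=True,
)
# The pipeline reads the template from its own tokenizer, so for chatbot(messages) set it there:
chatbot.tokenizer.chat_template = mistral_chat_template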

With the template overridden, inference runs without the error.

Posted @ 2024-07-17 11:14 by 一蓑烟雨度平生