基于TigerBot-13b训练其函数调用能力

写在前面

原生的tigerbot似乎并不支持函数调用,于是我来支持一下

 

数据集

我在huggingface上找了个英文的数据集

https://huggingface.co/datasets/sadmoseby/sample-function-call

这里面包含了1k组的函数调用,这个数据集的特点如下:

1. 包含有单个/多个/没有函数调用的情形

2. 描述函数的json_schema与OpenAI格式的一致(但多函数情况下,并没有用列表框起来)

3. 数据虽然是多轮对话的数据,但是每一个都是一整条的数据,且每个的开头与tigerbot的头不太一致

 

数据转换

我写了一个数据转换的代码,具体任务如下:

1. 将多个函数时没有用列表格式框选的情况给修复了

2. 切分为了多轮的对话,有多条训练数据

3. 修改了开头的情况

代码如下

  1 import re
  2 
  3 import json
  4 import re
  5 
  6 # system_prompt中可能有多个函数,多个函数的话要转为标准的[]格式
  7 def get_function_json(input_string):
  8     # 使用正则表达式分割字符串,找出独立的 JSON 字符串
  9     json_strings = re.findall(r'\{[\s\S]+?\}\s*(?=\{|$)', input_string)
 10 
 11     # 解析每个 JSON 字符串并把它们加入到列表中
 12     json_objects = []
 13     for json_str in json_strings:
 14         input_string = input_string.replace(json_str, '')
 15         try:
 16             json_obj = json.loads(json_str)
 17             json_objects.append(json_obj)
 18         except json.JSONDecodeError as e:
 19             print(f"Error decoding JSON: {e}")
 20     # 打印结果或进行其他操作
 21     if json_objects:
 22         return input_string + json.dumps(json_objects, ensure_ascii=False, indent=4)
 23     else:
 24         return input_string
 25 
 26 # 切分读入的数据
 27 def split_string_with_keywords(s, keywords):
 28     # 将关键词列表转化为正则表达式,使用括号捕获分隔符
 29     # 比如 ['system', 'assistant'] 会被转换成 (system)|(assistant)
 30     regex_pattern = '({})'.format('|'.join(map(re.escape, keywords)))
 31 
 32     # 使用 re.split,它会返回包含分隔符的列表
 33     parts = re.split(regex_pattern, s)
 34 
 35     # 初始化结果列表
 36     result = []
 37 
 38     # 存储上一个匹配到的关键词,初始时没有关键词
 39     last_keyword = None
 40 
 41     # 遍历分割后的列表
 42     for part in parts:
 43         # 如果当前部分是关键词,记录下来并继续下一轮循环
 44         if part in keywords:
 45             last_keyword = part
 46             continue
 47         # 如果当前部分不是关键词,且上一部分是关键词,则将其作为结果加入
 48         if last_keyword:
 49             result.append((last_keyword, part.strip()))
 50             last_keyword = None  # 重置关键词
 51 
 52     return result
 53 
 54 max_len = 0
 55 
 56 
 57 def count_words_and_punctuation(s):
 58     # 使用正则表达式来匹配单词和标点符号
 59     # \w+ 匹配单词字符(字母、数字、下划线)出现一次或多次组成的单词
 60     # | 表示或,用来分隔不同的匹配规则
 61     # \s 表示空白字符
 62     # [^\w\s] 匹配任意不是单词字符和不是空白字符的字符,即标点符号
 63     matches = re.findall(r'\w+|[^\w\s]', s)
 64 
 65     # 计算匹配项的数量,即单词和标点符号的总数
 66     return len(matches)
 67 
 68 def solve(input):
 69     global max_len
 70     max_len = max(max_len , count_words_and_punctuation(input))
 71     import json
 72     # 基础替换
 73     input = input.replace('<|endoftext|>', '')
 74 
 75     replace_map = {
 76         'SYSTEM:' : '\n\n### System:\n ',
 77         'ASSISTANT:': '\n\n### Response:\n ',
 78         'USER:': '\n\n### Instruction:\n ',
 79         'FUNCTION RESPONSE:': '\n\n### Function:\n '
 80     }
 81 
 82     data = split_string_with_keywords(input, list(replace_map.keys()))
 83 
 84     # 更换函数的格式
 85     if data[0][0] == 'SYSTEM:':
 86         data[0] = (data[0][0], get_function_json(data[0][1]))
 87 
 88     return_data = []
 89     train_str = ''
 90     for element in data:
 91         train_str += replace_map[element[0]]
 92         if element[0] == 'ASSISTANT:':
 93             return_data.append({
 94                 "instruction": train_str,
 95                 "input": "",
 96                 "output": element[1]
 97             })
 98         train_str += element[1]
 99 
100     return return_data
101 
102 import pandas as pd
103 
104 train_data = []
105 
106 # 读取Parquet文件
107 df = pd.read_parquet('train-00000-of-00001.parquet')
108 column_name = df.columns[0]
109 for value in df[column_name]:
110     train_data += solve(value)
111 
112 with open('train_function_call.json', 'w', encoding='utf-8') as f:
113     json.dump(train_data, f, ensure_ascii=False, indent=4)
114 print(max_len)

 改好格式的数据如下(以response来切分,response前为输入,response为需要模型生成的输出):

### System:
 You are a helpful assistant with access to the following functions. Use them if required -
[
    {
        "name": "search_books",
        "description": "Search for books based on specified criteria",
        "parameters": {
            "type": "object",
            "properties": {
                "title": {
                    "type": "string",
                    "description": "The title of the book"
                },
                "author": {
                    "type": "string",
                    "description": "The author of the book"
                },
                "genre": {
                    "type": "string",
                    "description": "The genre of the book"
                }
            },
            "required": [
                "title"
            ]
        }
    }
]

### Instruction:
 I am looking for a book but I can't remember the full title. I know it has the word "Sun" in it and it's a science fiction novel.

### Response:
 <functioncall> {"name": "search_books", "arguments": '{"title": "Sun", "genre": "science fiction"}'}

### Function:
 {"results": [{"title": "The Sun Also Rises in Space", "author": "John Doe", "genre": "science fiction"}, {"title": "Sunset on Mars", "author": "Jane Doe", "genre": "science fiction"}]}

### Response:
 I found two science fiction books with "Sun" in the title. The first one is "The Sun Also Rises in Space" by John Doe and the second one is "Sunset on Mars" by Jane Doe.

 

 

启动训练

笔者依然在恒源云上,基于tigerbot-13b-chat-v5-4k进行训练。

考虑到vllm暂时不支持PEFT格式的adapter,此次依然采用了freeze训练。

为了尽可能地训练更多的层,笔者采用了单个A100-80G的显卡,这样可以在seq_len达到3072的情况下,训练10层的tranformer参数。

注意,此次的template和以前不太一样(因为有各种的function和自己添加的system),所以添加了一个新的模板

 1 register_template(
 2     name="null",
 3     prefix=[
 4         ""
 5     ],
 6     prompt=[
 7         "{{query}}"
 8     ],
 9     system="",
10     sep=[]
11 )

训练命令如下

 1 python src/train_bash.py \
 2     --stage sft \
 3     --model_name_or_path /hy-tmp/tigerbot-13b-chat-v5-4k \
 4     --do_train True \
 5     --finetuning_type freeze \
 6     --num_layer_trainable 10 \
 7     --template null \
 8     --dataset_dir data \
 9     --dataset train_function_call \
10     --cutoff_len 3072 \
11     --learning_rate 1e-4 \
12     --num_train_epochs 1.0 \
13     --per_device_train_batch_size 4 \
14     --gradient_accumulation_steps 2 \
15     --logging_steps 1 \
16     --save_steps 10000 \
17     --output_dir /hy-tmp/tigerbot-13b-function-call \
18     --fp16 True \
19     --plot_loss True \
20     --overwrite_output_dir
posted @ 2023-12-23 16:01  AlphaInf  阅读(107)  评论(0编辑  收藏  举报