Code
1. BERT binary classification
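The script expects `train.csv` and `test.csv` to be tab-separated files with `query`, `title`, and `label` columns (label 0/1), matching the `pd.read_csv(..., sep="\t")` call below. A minimal sketch that writes a toy training file in this format; the rows are made-up examples for illustration only:

```python
import pandas as pd

# Toy query/title/label rows (made up for illustration) in the format load_dataset() expects
toy = pd.DataFrame({
    "query": ["python 安装教程", "深度学习 入门"],
    "title": ["如何在 Windows 上安装 Python", "红烧肉的家常做法"],
    "label": [1, 0],   # 1 = query and title are relevant, 0 = irrelevant
})
toy.to_csv("train.csv", sep="\t", index=False)   # tab-separated, with a header row
```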
```python
import tensorflow as tf
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, TFBertModel
import pandas as pd

# Load the pretrained BERT model and tokenizer
bert_model_name = './bert'
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
bert_model = TFBertModel.from_pretrained(bert_model_name)

# Encode a (query, title) pair into BERT inputs
def encode_texts(query, title, tokenizer, max_length=128):
    encoded_dict = tokenizer.encode_plus(
        query,
        title,
        add_special_tokens=True,      # add [CLS], [SEP] tokens
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='tf'           # return TensorFlow tensors
    )
    return encoded_dict['input_ids'], encoded_dict['attention_mask']

# Build the classification model
def build_model(bert_model):
    input_ids = tf.keras.layers.Input(shape=(128,), dtype=tf.int32, name='input_ids')
    attention_mask = tf.keras.layers.Input(shape=(128,), dtype=tf.int32, name='attention_mask')
    bert_output = bert_model(input_ids, attention_mask=attention_mask)
    cls_output = bert_output.last_hidden_state[:, 0, :]   # take the [CLS] vector
    dense = tf.keras.layers.Dense(256, activation='relu')(cls_output)
    dropout = tf.keras.layers.Dropout(0.3)(dense)
    dense2 = tf.keras.layers.Dense(128, activation='relu')(dropout)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(dense2)  # sigmoid for binary classification

    # Note: much smaller learning rates (e.g. 2e-5) are usually used when fine-tuning BERT
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-07)
    model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=output)
    model.compile(optimizer=optimizer,
                  loss='binary_crossentropy',
                  metrics=['accuracy', tf.keras.metrics.AUC(name='auc')])
    return model

# Read a dataset (tab-separated file with query / title / label columns)
def load_dataset(file_path, tokenizer, max_length=128):
    queries, titles, labels = [], [], []
    data = pd.read_csv(file_path, sep="\t")
    for query, title, label in zip(data['query'].tolist(), data['title'].tolist(), data['label'].tolist()):
        queries.append(query)
        titles.append(title)
        labels.append(int(label))

    input_ids_list = []
    attention_mask_list = []
    for query, title in zip(queries, titles):
        input_ids, attention_mask = encode_texts(query, title, tokenizer, max_length)
        input_ids_list.append(input_ids)
        attention_mask_list.append(attention_mask)

    input_ids = tf.concat(input_ids_list, axis=0)
    attention_masks = tf.concat(attention_mask_list, axis=0)
    labels = tf.convert_to_tensor(labels)
    return {'input_ids': input_ids, 'attention_mask': attention_masks}, labels

# Load the training and test data
train_data, train_labels = load_dataset('train.csv', tokenizer)
test_data, test_labels = load_dataset('test.csv', tokenizer)

# Convert the TensorFlow tensors to numpy arrays
train_input_ids_np = train_data['input_ids'].numpy()
train_attention_masks_np = train_data['attention_mask'].numpy()
train_labels_np = train_labels.numpy()

# Further split the training data into training and validation sets
train_input_ids, val_input_ids, train_attention_masks, val_attention_masks, train_labels, val_labels = train_test_split(
    train_input_ids_np, train_attention_masks_np, train_labels_np, test_size=0.1, random_state=42)

# Convert the numpy arrays back to TensorFlow tensors
train_inputs = {'input_ids': tf.convert_to_tensor(train_input_ids),
                'attention_mask': tf.convert_to_tensor(train_attention_masks)}
val_inputs = {'input_ids': tf.convert_to_tensor(val_input_ids),
              'attention_mask': tf.convert_to_tensor(val_attention_masks)}
train_labels = tf.convert_to_tensor(train_labels)
val_labels = tf.convert_to_tensor(val_labels)

# Instantiate the model
model = build_model(bert_model)
model.summary()

# Train the model
epochs = 3
batch_size = 8
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    history = model.fit(
        x={'input_ids': train_inputs['input_ids'], 'attention_mask': train_inputs['attention_mask']},
        y=train_labels,
        validation_data=(
            {'input_ids': val_inputs['input_ids'], 'attention_mask': val_inputs['attention_mask']},
            val_labels
        ),
        epochs=1,          # train one epoch at a time
        batch_size=batch_size,
        shuffle=True
    )
    # Evaluate on the test set after each epoch
    loss, accuracy, auc = model.evaluate(test_data, test_labels)
    print(f"Test loss: {loss}, Test accuracy: {accuracy}, Test AUC: {auc}")
    # Save the model as a TensorFlow SavedModel after the second epoch
    if epoch == 1:
        model.save("./bert_relevance_model", save_format='tf')
```
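After training, the in-memory `model` (or the SavedModel written to `./bert_relevance_model`, which can typically be reloaded with `tf.keras.models.load_model`) can score a new query/title pair. A minimal sketch that reuses `encode_texts`, `tokenizer`, and `model` from the script above; the example pair is made up:

```python
# Score a single (query, title) pair with the trained model
query = "python 安装教程"
title = "如何在 Windows 上安装 Python"

input_ids, attention_mask = encode_texts(query, title, tokenizer, max_length=128)
score = model.predict({"input_ids": input_ids, "attention_mask": attention_mask})[0][0]
print(f"relevance score: {score:.4f}")   # sigmoid output in [0, 1]; closer to 1 = more relevant
```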