【837】Hugging Face - Text classification
Reference: Hugging Face - Text classification
Main steps:
1. Load IMDb dataset
Start by loading the IMDb dataset from the 🤗 Datasets library:
```python
from datasets import load_dataset

imdb = load_dataset("imdb")
```
There are two fields in this dataset:
- text: the movie review text.
- label: a value that is either 0 for a negative review or 1 for a positive review.
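To make the two fields concrete, here is a minimal stdlib sketch of records shaped like IMDb rows. The review strings are made-up placeholders, not real IMDb data, and `to_batch` is a hypothetical helper that mimics the columnar layout (a dict of lists, one list per field) that 🤗 Datasets passes to a function mapped with `batched=True`. That layout is why `examples["text"]` in the preprocessing step is a list of strings rather than a single review.

```python
from collections import Counter

# Made-up placeholder rows with the same two fields as IMDb: "text" and "label".
rows = [
    {"text": "A dull, overlong mess.", "label": 0},
    {"text": "Wonderful acting and a clever script.", "label": 1},
    {"text": "I walked out halfway through.", "label": 0},
]

LABEL_NAMES = {0: "negative", 1: "positive"}


def to_batch(records):
    """Hypothetical helper: turn a list of row dicts into a dict of lists,
    the layout a function mapped with batched=True receives."""
    return {key: [r[key] for r in records] for key in records[0]}


batch = to_batch(rows)
print(batch["label"])  # [0, 1, 0]
print(Counter(LABEL_NAMES[y] for y in batch["label"]))
```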
2. Preprocess
The next step is to load a DistilBERT tokenizer to preprocess the text field:
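```python
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")


def preprocess_function(examples):
    # Tokenize a batch of reviews; truncation=True cuts reviews longer than the model's maximum input length.
    return tokenizer(examples["text"], truncation=True)


# batched=True passes many examples at once to preprocess_function, which speeds up the mapping.
tokenized_imdb = imdb.map(preprocess_function, batched=True)

# Pads each batch only up to its longest sequence and returns TensorFlow tensors.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
```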
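DataCollatorWithPadding pads dynamically: each batch is padded only to the length of its own longest sequence, rather than padding the whole dataset to the model's maximum length. The stdlib sketch below imitates that idea on plain lists of token ids. `pad_batch` is a hypothetical helper and the token ids are made up; the pad id of 0 is an assumption here (in real code read it from `tokenizer.pad_token_id`), and the real collator returns TensorFlow tensors and handles more cases than this.

```python
def pad_batch(batch_ids, pad_id=0):
    """Hypothetical helper: pad a batch of token-id lists to the longest sequence
    in that batch, returning input_ids and attention_mask."""
    longest = max(len(ids) for ids in batch_ids)
    input_ids, attention_mask = [], []
    for ids in batch_ids:
        n_pad = longest - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        # 1 marks real tokens, 0 marks padding that the model should ignore.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Made-up token ids: two sequences of different lengths in one batch.
batch = pad_batch([[101, 2023, 102], [101, 2023, 3185, 2001, 102]])
print(batch["input_ids"])       # [[101, 2023, 102, 0, 0], [101, 2023, 3185, 2001, 102]]
print(batch["attention_mask"])  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```

Because the padded length depends only on the batch, batches of short reviews stay short, which is the main saving over padding everything to the maximum length.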
3. Evaluate
Including a metric during training is often helpful for evaluating your model's performance. You can quickly load an evaluation method with the 🤗 Evaluate library. For this task, load the accuracy metric (see the 🤗 Evaluate quick tour to learn more about how to load and compute a metric):
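```python
import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    # eval_pred holds the model's logits and the reference labels.
    predictions, labels = eval_pred
    # Turn each row of logits into the id of the highest-scoring class.
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)
```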
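compute_metrics receives raw logits (one row of scores per example, one score per label) along with the gold labels; the argmax picks a class per row and accuracy is the fraction of matches. The pure-Python sketch below spells out that arithmetic without NumPy or 🤗 Evaluate. `accuracy_from_logits` is a hypothetical name and the logits are toy values; the returned dict mirrors the `{"accuracy": ...}` shape that `accuracy.compute` produces.

```python
def accuracy_from_logits(logits, labels):
    """Hypothetical helper: argmax over each row of logits, then the share of correct predictions."""
    # Same as np.argmax(logits, axis=1): index of the largest score in each row (first one on ties).
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


# Toy logits for 4 reviews and 2 classes (0 = NEGATIVE, 1 = POSITIVE).
logits = [[0.1, 0.9], [2.0, -1.0], [0.3, 0.2], [-0.5, 1.5]]
labels = [1, 0, 1, 1]
print(accuracy_from_logits(logits, labels))  # {'accuracy': 0.75}
```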
4. Train
Before you start training your model, create a map of the expected ids to their labels with id2label and label2id:
```python
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}
```
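The two dictionaries must be exact inverses of each other, or predicted ids will not map back to the right names. A small sketch, assuming you start from a hypothetical `label_names` list ordered by integer id (not a name the dataset provides), builds both mappings from one source and checks the round trip:

```python
# Hypothetical starting point: label names ordered by their integer id.
label_names = ["NEGATIVE", "POSITIVE"]

# Build both mappings from the same list so they cannot drift apart.
id2label = {i: name for i, name in enumerate(label_names)}
label2id = {name: i for i, name in id2label.items()}

# Round-trip check: every id maps to a name and back to the same id.
assert all(label2id[id2label[i]] == i for i in id2label)
print(id2label)  # {0: 'NEGATIVE', 1: 'POSITIVE'}
print(label2id)  # {'NEGATIVE': 0, 'POSITIVE': 1}
```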
To finetune a model in TensorFlow, start by setting up an optimizer function, learning rate schedule, and some training hyperparameters:
```python
from transformers import create_optimizer
import tensorflow as tf

batch_size = 16
num_epochs = 5
batches_per_epoch = len(tokenized_imdb["train"]) // batch_size
total_train_steps = int(batches_per_epoch * num_epochs)

# The learning-rate schedule is sized for num_epochs, so model.fit should train for the same number of epochs.
optimizer, schedule = create_optimizer(init_lr=2e-5, num_warmup_steps=0, num_train_steps=total_train_steps)
```
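To see what these numbers work out to, here is the step arithmetic for the IMDb train split (25,000 reviews) together with a sketch of the learning-rate curve. It assumes that, with num_warmup_steps=0 and its other arguments left at their defaults, create_optimizer decays the rate linearly from init_lr to 0 over num_train_steps; `linear_decay_lr` is a hypothetical re-implementation of that curve, not the library code. It also shows why the epochs passed to model.fit should match num_epochs: stopping after 3 of 5 planned epochs leaves the rate at about 40% of init_lr.

```python
def linear_decay_lr(step, init_lr, total_steps):
    """Hypothetical sketch of a linear decay from init_lr to 0 with no warmup."""
    step = min(step, total_steps)  # the rate stays at 0 once total_steps is reached
    return init_lr * (1 - step / total_steps)


num_train_examples = 25_000  # size of the IMDb "train" split
batch_size = 16
num_epochs = 5

batches_per_epoch = num_train_examples // batch_size  # integer division ignores the final partial batch
total_train_steps = batches_per_epoch * num_epochs

print(batches_per_epoch, total_train_steps)  # 1562 7810
print(linear_decay_lr(0, 2e-5, total_train_steps))  # 2e-05
# Learning rate reached if training stops after 3 of the planned 5 epochs.
print(linear_decay_lr(3 * batches_per_epoch, 2e-5, total_train_steps))
```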
Then you can load DistilBERT with TFAutoModelForSequenceClassification along with the number of expected labels, and the label mappings:
```python
from transformers import TFAutoModelForSequenceClassification

model = TFAutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
)
```
Convert your datasets to the tf.data.Dataset format with prepare_tf_dataset(), then compile the model, set up the callbacks, and train:
```python
tf_train_set = model.prepare_tf_dataset(
    tokenized_imdb["train"],
    shuffle=True,
    batch_size=batch_size,
    collate_fn=data_collator,
)

tf_validation_set = model.prepare_tf_dataset(
    tokenized_imdb["test"],
    shuffle=False,
    batch_size=batch_size,
    collate_fn=data_collator,
)

# No loss argument is needed: Transformers TF models fall back to a task-appropriate built-in loss.
model.compile(optimizer=optimizer)

from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback

# Computes accuracy on the validation set at the end of each epoch.
metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_validation_set)

# Pushes the model and tokenizer to the Hugging Face Hub (you must be logged in to the Hub).
push_to_hub_callback = PushToHubCallback(
    output_dir="my_awesome_model",
    tokenizer=tokenizer,
)

callbacks = [metric_callback, push_to_hub_callback]

# Train for num_epochs so the epoch count matches the learning-rate schedule built above.
model.fit(x=tf_train_set, validation_data=tf_validation_set, epochs=num_epochs, callbacks=callbacks)
```
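Conceptually, prepare_tf_dataset() shuffles the rows (when shuffle=True), groups them into batches of batch_size, and runs collate_fn on each batch. The generator below is a simplified stdlib sketch of that flow on made-up token ids, not the library's implementation, which builds a streaming tf.data.Dataset. `iterate_batches` and `collate` are hypothetical names, and the rule that the last incomplete batch is dropped only when shuffling is an assumption based on my reading of the prepare_tf_dataset documentation (drop_remainder defaulting to the value of shuffle).

```python
import random


def iterate_batches(rows, batch_size, collate_fn, shuffle=False, drop_remainder=None, seed=0):
    """Hypothetical sketch: optionally shuffle, cut into batches, collate each batch."""
    if drop_remainder is None:
        drop_remainder = shuffle  # assumed default, mirroring prepare_tf_dataset
    order = list(range(len(rows)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        chunk = [rows[i] for i in order[start:start + batch_size]]
        if drop_remainder and len(chunk) < batch_size:
            break  # skip the last, incomplete batch
        yield collate_fn(chunk)


def collate(chunk):
    """Toy collator: pad token-id lists with 0 to the longest one in the chunk."""
    longest = max(len(ids) for ids in chunk)
    return [ids + [0] * (longest - len(ids)) for ids in chunk]


# Made-up "tokenized" rows of lengths 1 to 7.
rows = [[1] * n for n in range(1, 8)]
train_batches = list(iterate_batches(rows, 3, collate, shuffle=True))
eval_batches = list(iterate_batches(rows, 3, collate, shuffle=False))
print(len(train_batches), len(eval_batches))  # 2 3
```

Dropping the ragged final batch only for the shuffled training set keeps every training step the same size, while the validation pass still sees every review.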
5. Inference
Great, now that you’ve finetuned a model, you can use it for inference!
```python
text = "This was a masterpiece. Not completely faithful to the books, but enthralling from beginning to end. Might be my favorite of the three."

# "stevhliu/my_awesome_model" is the model from the original tutorial; substitute the Hub repo you pushed to.

# Option 1: the pipeline handles tokenization, the forward pass, and post-processing.
from transformers import pipeline

classifier = pipeline("sentiment-analysis", model="stevhliu/my_awesome_model")
classifier(text)

# Option 2: the same steps done manually.
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("stevhliu/my_awesome_model")
inputs = tokenizer(text, return_tensors="tf")

model = TFAutoModelForSequenceClassification.from_pretrained("stevhliu/my_awesome_model")
logits = model(**inputs).logits

# Pick the class with the highest logit and map it back to its label name.
predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
model.config.id2label[predicted_class_id]
```
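Both paths end in the same place: the pipeline reports the label of the highest-scoring class together with a score, which for single-label classification like this is the softmax probability of that class. A pure-Python sketch of that final step, assuming one row of two logits like the model returns and the id2label mapping from step 4 (the logit values are made up, and `classify_logits` is a hypothetical helper):

```python
import math

id2label = {0: "NEGATIVE", 1: "POSITIVE"}


def classify_logits(logits, id2label):
    """Hypothetical helper: softmax over one row of logits, then pick the top class."""
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return {"label": id2label[best], "score": probs[best]}


# Made-up logits for one review: class 1 (POSITIVE) scores higher.
result = classify_logits([-1.5, 2.0], id2label)
print(result["label"])  # POSITIVE
print(round(result["score"], 4))
```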