Using pre-trained models from torchvision
Fine-tuning with a pre-trained model from torchvision
Fine-tuning a model pre-trained by others usually involves the following steps:
- Load the pre-trained model: Load the pre-trained model's weights and architecture into memory. The weights are usually stored in a file format such as TensorFlow's .ckpt or PyTorch's .pth.
- Replace the last layer: Replace the last layer of the pre-trained model, typically the classification layer, with a new one that matches the number of classes in your task. This new layer will be randomly initialized.
- Freeze some layers: Freeze some layers of the pre-trained model to prevent them from being updated during training. The intuition is that these layers have already learned useful features for general-purpose tasks and should not be disturbed, while the newly added layer is trained to adapt to the new task.
- Train the model: Train the entire model, or just the last layer, on your own dataset using standard supervised learning techniques.
- Evaluate the model: Evaluate the fine-tuned model on a validation set to assess its performance.
- Adjust the model: If the performance is not satisfactory, adjust the hyperparameters, or try unfreezing more layers and continuing the training process.
Here's an example in PyTorch that demonstrates fine-tuning a pre-trained ResNet-50 model on the CIFAR-10 dataset:
import torch
import torch.nn as nn
import torchvision.models as models

# Load pre-trained model (newer torchvision versions use
# weights=models.ResNet50_Weights.DEFAULT instead of pretrained=True)
pretrained_model = models.resnet50(pretrained=True)
num_features = pretrained_model.fc.in_features
num_classes = 10  # CIFAR-10 has 10 classes

# Replace the last layer with a randomly initialized classifier
pretrained_model.fc = nn.Linear(num_features, num_classes)

# Freeze all layers, then unfreeze only the new classification layer
for param in pretrained_model.parameters():
    param.requires_grad = False
for param in pretrained_model.fc.parameters():
    param.requires_grad = True

# Train the model
train_loader = ...
val_loader = ...
num_epochs = 10  # number of fine-tuning epochs
optimizer = torch.optim.Adam(pretrained_model.fc.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    pretrained_model.train()  # training mode (affects BatchNorm/Dropout)
    for x, y_true in train_loader:
        optimizer.zero_grad()
        y_pred = pretrained_model(x)
        loss = loss_fn(y_pred, y_true)
        loss.backward()
        optimizer.step()

    # Evaluate the model on the validation set after each epoch
    pretrained_model.eval()
    with torch.no_grad():
        total_correct = 0
        total_examples = 0
        for x, y_true in val_loader:
            y_pred = pretrained_model(x)
            y_pred_class = torch.argmax(y_pred, dim=1)
            total_correct += torch.sum(y_pred_class == y_true).item()
            total_examples += len(y_true)
        accuracy = total_correct / total_examples
    print(f"Epoch {epoch}, Validation Accuracy: {accuracy:.3f}")
For a model that is not in torchvision, you need the model's source code in order to know its architecture, as well as a saved weights file.
Assuming that the architecture of the ChenNet model is defined in a Python script named chen_net.py and the weights are saved in a file named chen_net_weights.pth, here's an example of how you can fine-tune it on a new dataset:
import torch
import torch.nn as nn
from chen_net import ChenNet  # import the ChenNet model definition

# Load pre-trained model: build the architecture, then load the saved weights
pretrained_model = ChenNet()
pretrained_model.load_state_dict(torch.load("chen_net_weights.pth"))

# Replace the last layer (assumes the classifier is an attribute named fc)
num_features = pretrained_model.fc.in_features
num_classes = 10  # replace with your own number of classes
pretrained_model.fc = nn.Linear(num_features, num_classes)

# Freeze all layers, then unfreeze only the new classification layer
for param in pretrained_model.parameters():
    param.requires_grad = False
for param in pretrained_model.fc.parameters():
    param.requires_grad = True

# Train the model
train_loader = ...
val_loader = ...
num_epochs = 10  # number of fine-tuning epochs
optimizer = torch.optim.Adam(pretrained_model.fc.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    pretrained_model.train()
    for x, y_true in train_loader:
        optimizer.zero_grad()
        y_pred = pretrained_model(x)
        loss = loss_fn(y_pred, y_true)
        loss.backward()
        optimizer.step()

    # Evaluate the model on the validation set after each epoch
    pretrained_model.eval()
    with torch.no_grad():
        total_correct = 0
        total_examples = 0
        for x, y_true in val_loader:
            y_pred = pretrained_model(x)
            y_pred_class = torch.argmax(y_pred, dim=1)
            total_correct += torch.sum(y_pred_class == y_true).item()
            total_examples += len(y_true)
        accuracy = total_correct / total_examples
    print(f"Epoch {epoch}, Validation Accuracy: {accuracy:.3f}")