Using pre-trained models from torchvision

Fine-tuning with a pre-trained model from torchvision

Fine-tuning a model that someone else has pre-trained usually involves the following steps:

  1. Load the pre-trained model: Load the pre-trained model's architecture and weights into memory. The weights are usually stored in a checkpoint file, such as a TensorFlow .ckpt file or a PyTorch .pth file.

  2. Replace the last layer: Replace the last layer of the pre-trained model, typically the classification layer, with a new one that matches the number of classes in your task. This new layer will be randomly initialized.

  3. Freeze some layers: Freeze some layers of the pre-trained model to prevent them from being updated during training. The intuition is that these layers have already learned useful features for general-purpose tasks and should not be disturbed, while the newly added layer will be fine-tuned to adapt to the new task.

  4. Train the model: Train the entire model, or just the last layer, on your own dataset using standard supervised learning techniques.

  5. Evaluate the model: Evaluate the fine-tuned model on a validation set to assess its performance.

  6. Adjust the model: If the performance is not satisfactory, adjust the hyperparameters, or try unfreezing more layers and continuing the training process.
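Step 6 mentions unfreezing more layers. One way this might look is sketched below, using a small hypothetical stand-in network (`TinyNet`, with made-up layer names `block1`, `block2` and `fc`) rather than a real pre-trained model. The newly unfrozen block gets a smaller learning rate than the head through optimizer parameter groups; the specific learning rates are assumptions, not recommendations.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pre-trained network with an `fc` head."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.block1 = nn.Linear(8, 16)
        self.block2 = nn.Linear(16, 16)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        x = torch.relu(self.block1(x))
        x = torch.relu(self.block2(x))
        return self.fc(x)


model = TinyNet()

# Starting point: everything frozen except the head (steps 2 and 3).
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True

# Step 6: additionally unfreeze the block closest to the head.
for param in model.block2.parameters():
    param.requires_grad = True

# Rebuild the optimizer so it sees the newly trainable parameters.
# The unfrozen block uses a smaller learning rate than the head.
optimizer = torch.optim.Adam([
    {"params": model.block2.parameters(), "lr": 1e-4},
    {"params": model.fc.parameters(), "lr": 1e-3},
])

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)
```

The optimizer is recreated after unfreezing because an optimizer built only from the head's parameters would never update the newly unfrozen block.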

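Here's an example in PyTorch that demonstrates fine-tuning a pre-trained ResNet-50 model on the CIFAR-10 dataset. The data loaders are left as placeholders; since the torchvision weights were trained on ImageNet, CIFAR-10 images are typically resized and normalized with the ImageNet statistics before being fed to the network: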

import torch
import torch.nn as nn
import torchvision.models as models

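# Load pre-trained model
# (torchvision >= 0.13 selects weights via `weights=`; older versions use pretrained=True)
pretrained_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)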
num_features = pretrained_model.fc.in_features
num_classes = 10

# Replace the last layer
pretrained_model.fc = nn.Linear(num_features, num_classes)

# Freeze some layers
for param in pretrained_model.parameters():
    param.requires_grad = False
for param in pretrained_model.fc.parameters():
    param.requires_grad = True

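# Train the model
train_loader = ...
val_loader = ...
num_epochs = 10  # assumed value; adjust for your task
optimizer = torch.optim.Adam(pretrained_model.fc.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    pretrained_model.train()  # training mode (BatchNorm/Dropout behave as in training)
    for x, y_true in train_loader:
        optimizer.zero_grad()
        y_pred = pretrained_model(x)
        loss = loss_fn(y_pred, y_true)
        loss.backward()
        optimizer.step()

    # Evaluate the model
    pretrained_model.eval()  # evaluation mode (BatchNorm uses its running statistics)
    with torch.no_grad():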
        total_correct = 0
        total_examples = 0
        for x, y_true in val_loader:
            y_pred = pretrained_model(x)
            y_pred_class = torch.argmax(y_pred, dim=1)
            total_correct += torch.sum(y_pred_class == y_true).item()
            total_examples += len(y_true)

        accuracy = total_correct / total_examples
        print(f"Epoch {epoch}, Validation Accuracy: {accuracy:.3f}")
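Before starting a long training run, it can be worth checking that the freezing step did what was intended. The sketch below counts trainable versus total parameters and confirms that only the head receives gradients. `TinyNet` and `count_parameters` are hypothetical names introduced for illustration; the same helper could be called on `pretrained_model`.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pre-trained model with an `fc` head."""

    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 16)  # plays the role of the frozen backbone
        self.fc = nn.Linear(16, 10)       # plays the role of the new head

    def forward(self, x):
        return self.fc(torch.relu(self.features(x)))


def count_parameters(model):
    """Return (trainable, total) parameter counts for a model."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


model = TinyNet()

# Same freezing pattern as in the example above.
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True

trainable, total = count_parameters(model)
print(f"trainable: {trainable} / total: {total}")

# After a backward pass, frozen parameters have no gradient at all.
loss = model(torch.randn(4, 8)).sum()
loss.backward()
frozen_has_grad = model.features.weight.grad is not None
head_has_grad = model.fc.weight.grad is not None
print(f"backbone has grad: {frozen_has_grad}, head has grad: {head_has_grad}")
```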

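For a model that is not in torchvision, you need the model's source code in order to know its architecture, and you also need the saved weights file.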
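The reason both pieces are needed is that a weights file saved from a `state_dict` holds only named tensors, not the code that wires them together. The following sketch illustrates this with a hypothetical `TinyNet` class and a temporary file named `tiny_net_weights.pth` (both invented for this example): the weights are saved, loaded back as a plain dictionary, and then copied into a freshly constructed model of the same architecture.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical architecture; in practice this comes from the model's source code."""

    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 16)
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        return self.fc(torch.relu(self.features(x)))


source = TinyNet()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_net_weights.pth")
    # Save only the weights (a mapping from parameter names to tensors).
    torch.save(source.state_dict(), path)
    # Load them back onto the CPU regardless of where they were saved.
    state_dict = torch.load(path, map_location="cpu")

# The file contains names and tensors only, no architecture.
keys = list(state_dict.keys())
print(keys)

# The architecture must be rebuilt from code before the weights can be used.
restored = TinyNet()
result = restored.load_state_dict(state_dict)  # strict=True by default
print(result)
```

If the class definition does not match the saved keys, `load_state_dict` raises an error under its default `strict=True`, which is usually the first sign that the source code and the weights file do not belong together.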

Suppose a custom model, ChenNet, has its architecture defined in a Python script named chen_net.py and its weights saved in a file named chen_net_weights.pth. Here's an example of how you can fine-tune it on a new dataset:

import torch
import torch.nn as nn
from chen_net import ChenNet  # import the ChenNet model

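# Load pre-trained model (ChenNet must match the architecture the weights were saved from)
pretrained_model = ChenNet()
pretrained_model.load_state_dict(torch.load("chen_net_weights.pth", map_location="cpu"))

# Replace the last layer (assumes ChenNet exposes its classifier as `fc`)
num_features = pretrained_model.fc.in_features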
num_classes = 10  # replace with your own number of classes
pretrained_model.fc = nn.Linear(num_features, num_classes)

# Freeze some layers
for param in pretrained_model.parameters():
    param.requires_grad = False
for param in pretrained_model.fc.parameters():
    param.requires_grad = True

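# Train the model
train_loader = ...
val_loader = ...
num_epochs = 10  # assumed value; adjust for your task
optimizer = torch.optim.Adam(pretrained_model.fc.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    pretrained_model.train()  # training mode (BatchNorm/Dropout behave as in training)
    for x, y_true in train_loader:
        optimizer.zero_grad()
        y_pred = pretrained_model(x)
        loss = loss_fn(y_pred, y_true)
        loss.backward()
        optimizer.step()

    # Evaluate the model
    pretrained_model.eval()  # evaluation mode (BatchNorm uses its running statistics)
    with torch.no_grad():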
        total_correct = 0
        total_examples = 0
        for x, y_true in val_loader:
            y_pred = pretrained_model(x)
            y_pred_class = torch.argmax(y_pred, dim=1)
            total_correct += torch.sum(y_pred_class == y_true).item()
            total_examples += len(y_true)

        accuracy = total_correct / total_examples
        print(f"Epoch {epoch}, Validation Accuracy: {accuracy:.3f}")
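The example above assumes the custom model exposes its classifier as an attribute named `fc`, which is not guaranteed for arbitrary architectures. As a rough sketch, the hypothetical helper `replace_last_linear` below looks for the last registered `nn.Linear` module and swaps it for a new one. `OtherNet` is an invented stand-in model. Note that "last registered" is only a heuristic: it matches the final layer of the forward pass in many models, but not all, so the returned name should be checked by hand.

```python
import torch
import torch.nn as nn


def replace_last_linear(model, num_classes):
    """Replace the last registered nn.Linear in `model`; return its dotted name."""
    last_name, last_layer = None, None
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            last_name, last_layer = name, module
    if last_layer is None:
        raise ValueError("model contains no nn.Linear layer")

    # Split "classifier.1" into parent path "classifier" and attribute "1".
    parent_name, _, attr = last_name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, attr, nn.Linear(last_layer.in_features, num_classes))
    return last_name


class OtherNet(nn.Module):
    """Hypothetical model whose head is not called `fc`."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(16, 1000))

    def forward(self, x):
        return self.classifier(self.features(x))


model = OtherNet()
replaced = replace_last_linear(model, num_classes=10)
print(f"replaced layer: {replaced}")

# The model now produces one logit per class of the new task.
out = model(torch.randn(4, 8))
print(out.shape)
```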