5.AlexNet猫狗分类(Lightning框架)
Published on 2024-06-27 20:34 in 分类: 算法 with 真真夜夜
分类: 算法

5.AlexNet猫狗分类(Lightning框架)

    # net.py
    import torch
    import torch.nn as nn
    import lightning as L
    from torchmetrics.classification import BinaryAccuracy
    
    class AlexNet(L.LightningModule):
        """AlexNet-style CNN packaged as a LightningModule for binary classification.

        ``forward`` returns raw logits of shape ``(batch, num_classes)``. The
        training and validation steps squeeze dimension 1 of that output and
        score it with binary cross-entropy on logits plus ``BinaryAccuracy``.
        """

        def __init__(self, num_classes=1):
            super().__init__()
            self.save_hyperparameters()

            # Convolutional feature extractor: five conv layers, three max-pools.
            conv_stack = [
                nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2),
                nn.Conv2d(96, 256, kernel_size=5, padding=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2),
                nn.Conv2d(256, 384, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(384, 384, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(384, 256, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2),
            ]
            self.features = nn.Sequential(*conv_stack)

            # Fully connected head; its first layer expects 256 * 6 * 6 inputs.
            head = [
                nn.Dropout(),
                nn.Linear(256 * 6 * 6, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, num_classes),
            ]
            self.classifier = nn.Sequential(*head)

            # Separate metric objects for the training and validation loops.
            self.train_accuracy = BinaryAccuracy()
            self.val_accuracy = BinaryAccuracy()

        def forward(self, x):
            """Run the feature extractor, flatten to (batch, 9216), and classify."""
            feature_maps = self.features(x)
            flattened = feature_maps.view(feature_maps.size(0), 256 * 6 * 6)
            return self.classifier(flattened)

        def _loss_and_accuracy(self, batch, metric):
            """Return ``(loss, accuracy)`` for one ``(images, labels)`` batch."""
            images, labels = batch
            logits = self(images).squeeze(1)
            loss = nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
            return loss, metric(logits, labels)

        def training_step(self, batch, batch_idx):
            """Compute, log, and return the training loss for one batch."""
            loss, acc = self._loss_and_accuracy(batch, self.train_accuracy)
            log_opts = dict(on_step=True, on_epoch=True, prog_bar=True, logger=True)
            self.log('train_loss', loss, **log_opts)
            self.log('train_acc', acc, **log_opts)
            return loss

        def validation_step(self, batch, batch_idx):
            """Compute, log, and return the validation loss for one batch."""
            loss, acc = self._loss_and_accuracy(batch, self.val_accuracy)
            log_opts = dict(on_step=True, on_epoch=True, prog_bar=True, logger=True)
            self.log('val_loss', loss, **log_opts)
            self.log('val_acc', acc, **log_opts)
            return loss

        def configure_optimizers(self):
            """Return an Adam optimizer over all parameters with lr=1e-4."""
            return torch.optim.Adam(self.parameters(), lr=1e-4)
    
    
    # main.py
    import torch
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    import lightning as L
    from lightning.pytorch.callbacks import ModelCheckpoint
    from net import AlexNet  # 从net.py导入AlexNet
    
    # Root directory of the dataset.
    data_dir = './data'

    # Seed through Lightning and set the float32 matmul precision to 'high'.
    L.seed_everything(42)
    torch.set_float32_matmul_precision('high')

    # Per-image preprocessing: resize to 224x224, then convert to a tensor.
    transform = transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ])
    
    # 定义LightningDataModule
    class DataModule(L.LightningDataModule):
        """LightningDataModule backed by ``ImageFolder`` datasets.

        Images are read from the ``train``, ``val`` and ``test`` subdirectories
        of ``data_dir``.

        Args:
            data_dir: Root directory that contains the split subdirectories.
            batch_size: Batch size used by every dataloader.
            transform: Transform applied to every image of every split.
            num_workers: Worker processes per dataloader. Defaults to 4, the
                value that used to be hard-coded.
        """

        def __init__(self, data_dir, batch_size, transform, num_workers=4):
            super().__init__()
            self.data_dir = data_dir
            self.batch_size = batch_size
            self.transform = transform
            self.num_workers = num_workers

        def _image_folder(self, split):
            """Build the ImageFolder dataset for one split subdirectory."""
            return datasets.ImageFolder(f'{self.data_dir}/{split}', transform=self.transform)

        def setup(self, stage=None):
            """Create the datasets needed for ``stage``.

            Only the splits a stage uses are opened, so fitting does not
            require the ``test`` directory to exist. ``stage=None`` builds
            all three splits.
            """
            if stage in (None, 'fit'):
                self.train_dataset = self._image_folder('train')
            if stage in (None, 'fit', 'validate'):
                self.val_dataset = self._image_folder('val')
            if stage in (None, 'test'):
                self.test_dataset = self._image_folder('test')

        def _loader(self, dataset, shuffle):
            """Wrap ``dataset`` in a DataLoader with the module's shared settings."""
            return DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=shuffle,
                num_workers=self.num_workers,
            )

        def train_dataloader(self):
            """Return the shuffled training dataloader."""
            return self._loader(self.train_dataset, shuffle=True)

        def val_dataloader(self):
            """Return the validation dataloader (fixed order)."""
            return self._loader(self.val_dataset, shuffle=False)

        def test_dataloader(self):
            """Return the test dataloader (fixed order)."""
            return self._loader(self.test_dataset, shuffle=False)
    
    # Guard the entry point: the dataloaders use worker processes, and on
    # platforms that start workers by re-importing this script an unguarded
    # top level would launch training again inside every worker.
    if __name__ == '__main__':
        # Instantiate the data module and the model.
        data_module = DataModule(data_dir=data_dir, batch_size=32, transform=transform)
        model = AlexNet(num_classes=1)

        # Keep only the checkpoint with the highest validation accuracy.
        checkpoint_callback = ModelCheckpoint(
            monitor='val_acc',  # metric to monitor
            dirpath='checkpoints',  # output directory
            filename='best-checkpoint',  # checkpoint file name
            save_top_k=1,  # keep only the best model
            mode='max'  # larger is better
        )

        # Train with the Lightning Trainer.
        trainer = L.Trainer(
            max_epochs=15,
            accelerator='gpu',
            devices=1,
            callbacks=[checkpoint_callback],
        )
        trainer.fit(model, datamodule=data_module)

        # After training, the best checkpoint can be loaded from this path.
        best_model_path = checkpoint_callback.best_model_path
        print(f"Best model saved at: {best_model_path}")
    
    import os
    import torch
    import torch.nn as nn
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    import pandas as pd
    import matplotlib.pyplot as plt
    from torchmetrics.classification import BinaryAccuracy
    import lightning as L
    from PIL import Image
    import seaborn as sns
    from sklearn.metrics import confusion_matrix
    from net import AlexNet  # 从 net.py 导入 AlexNet
    
    # Seed through Lightning and set the float32 matmul precision.
    L.seed_everything(42)
    torch.set_float32_matmul_precision("high")

    # Preprocessing. The resize must match the one used for training
    # (224x224 in main.py); evaluating at another resolution would score the
    # model on inputs it was not trained on.
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]
    )

    # Dataset root directory.
    data_dir = "./data"

    # Load the test split. shuffle=False keeps batches in dataset order.
    test_dataset = datasets.ImageFolder(data_dir + "/test", transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=804, shuffle=False, num_workers=4)

    # Load the best checkpoint written by the training script.
    best_model_path = "checkpoints/best-checkpoint.ckpt"
    best_model = AlexNet.load_from_checkpoint(best_model_path, num_classes=1)

    # Switch to evaluation mode and move the model to the GPU when available.
    best_model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    best_model.to(device)

    # Evaluation metric and loss function.
    test_accuracy = BinaryAccuracy().to(device)
    test_loss_fn = nn.BCEWithLogitsLoss()

    # Paths of misclassified images, grouped by true label.
    mis_cat = []
    mis_dog = []

    # True and predicted labels collected over the whole test set.
    true_labels = []
    predicted_labels = []
    
    # Dataset index of the first sample in the current batch. The loader is
    # not shuffled, so offset + position-in-batch is the index into
    # test_dataset.imgs. Using the in-batch position alone would report the
    # wrong files for every batch after the first.
    batch_start = 0

    # Disable gradient tracking during inference.
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass; drop the singleton logit dimension.
            outputs = best_model(images).squeeze(1)

            # Loss and accuracy for this batch.
            test_loss = test_loss_fn(outputs, labels.float())
            test_acc = test_accuracy(outputs, labels)

            # Threshold the sigmoid probabilities at 0.5 to get 0/1 predictions.
            labels_np = labels.cpu().numpy()
            preds = (torch.sigmoid(outputs).cpu() > 0.5).numpy().astype(int)

            # Record the file path of every misclassified sample.
            for local_idx in (preds != labels_np).nonzero()[0]:
                image_path = test_dataset.imgs[batch_start + int(local_idx)][0]
                # Label 0 is treated as cat (assumes ImageFolder assigned
                # index 0 to the cat folder - confirm against class_to_idx).
                if labels_np[local_idx] == 0:
                    mis_cat.append(image_path)
                else:
                    mis_dog.append(image_path)

            # Accumulate labels for the confusion matrix.
            true_labels.extend(labels_np)
            predicted_labels.extend(preds)
            batch_start += len(labels_np)

            print(f"测试损失: {test_loss.item():.4f}, 测试准确率: {test_acc.item():.4f}")
    
    # Compute the confusion matrix. labels=[0, 1] pins it to 2x2 so it always
    # lines up with the two tick labels, even if one class never occurs.
    cm = confusion_matrix(true_labels, predicted_labels, labels=[0, 1])

    # Draw the confusion matrix as an annotated heatmap.
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        cbar=False,
        annot_kws={"fontsize": 15, "fontweight": "bold"},
        xticklabels=["Cat", "Dog"],
        yticklabels=["Cat", "Dog"],
    )
    plt.xlabel("Predicted Labels", fontsize=14)
    plt.ylabel("True Labels", fontsize=14)
    plt.title("confusion_matrix", fontsize=16)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    # Save the figure; create the output directory first because savefig
    # does not create missing directories.
    os.makedirs("results", exist_ok=True)
    plt.savefig("results/confusion_matrix.png", bbox_inches="tight")
    plt.show()
    
    
    # 绘制分类错误的猫和狗图像
    def plot_images(cat_paths, dog_paths, n=5):
        """Plot up to ``n`` images from each list in a 2 x n grid, save and show it.

        Args:
            cat_paths: Image paths drawn in the first row, titled "Dog <k>".
            dog_paths: Image paths drawn in the second row, titled "Cat <k>".
            n: Number of grid columns and maximum images taken from each list.

        The figure is written to ``results/mistake.png``; the ``results``
        directory is created if it does not exist.
        """
        # NOTE(review): the titles are the opposite of the list names -
        # presumably they show the (wrong) predicted class; confirm intent.
        rows = ((cat_paths, "Dog"), (dog_paths, "Cat"))

        plt.figure(figsize=(15, 10))
        for row, (paths, title) in enumerate(rows):
            for col, img_path in enumerate(paths[:n]):
                plt.subplot(2, n, row * n + col + 1)
                # Close the file as soon as the pixels have been handed to
                # matplotlib instead of leaving the handle open.
                with Image.open(img_path) as img:
                    plt.imshow(img)
                plt.title(f"{title} {col + 1}")
                plt.axis("off")

        # savefig does not create missing directories.
        os.makedirs("results", exist_ok=True)
        plt.savefig("results/mistake.png")
        plt.show()
    
    
    # Announce, then plot up to five misclassified images per true class.
    print("显示分类错误的猫和狗图像:")
    plot_images(cat_paths=mis_cat, dog_paths=mis_dog, n=5)
    
    posted @   真真夜夜  阅读(41)  评论(0)  编辑  收藏  举报
    相关博文:
    阅读排行:
    · winform 绘制太阳,地球,月球 运作规律
    · TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
    · AI 智能体引爆开源社区「GitHub 热点速览」
    · Manus的开源复刻OpenManus初探
    · 写一个简单的SQL生成工具
    点击右上角即可分享
    微信分享提示