import torch
import torch.nn as nn
import lightning as L
from torchmetrics.classification import BinaryAccuracy
class AlexNet(L.LightningModule):
    """AlexNet-style CNN wrapped as a LightningModule for binary classification.

    The network is five convolutional layers followed by three linear
    layers.  The classifier's first layer expects a flattened 256x6x6
    feature map, which is what the feature extractor yields for e.g.
    224x224 RGB inputs.

    Training and validation use ``BCEWithLogitsLoss`` on the output squeezed
    along dim 1, so the steps only work for a single-logit head
    (``num_classes=1``, the default).

    Args:
        num_classes: Width of the final linear layer.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        self.save_hyperparameters()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
        # Built once instead of on every step.  It holds no parameters, so
        # existing checkpoints still load.
        self.criterion = nn.BCEWithLogitsLoss()
        self.train_accuracy = BinaryAccuracy()
        self.val_accuracy = BinaryAccuracy()

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        x = self.features(x)
        # flatten() also handles non-contiguous feature maps, unlike view().
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def _shared_step(self, batch, accuracy, prefix):
        """Compute the loss for ``batch`` and log ``<prefix>_loss`` / ``<prefix>_acc``."""
        images, labels = batch
        logits = self(images).squeeze(1)
        loss = self.criterion(logits, labels.float())
        # Pass probabilities explicitly rather than relying on the metric
        # guessing whether its input is logits or probabilities.
        accuracy(torch.sigmoid(logits), labels)
        self.log(f'{prefix}_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        # Logging the Metric object lets Lightning compute the epoch value
        # from the accumulated state and reset it afterwards.
        self.log(f'{prefix}_acc', accuracy, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch, log ``train_loss`` / ``train_acc``, return the loss."""
        return self._shared_step(batch, self.train_accuracy, 'train')

    def validation_step(self, batch, batch_idx):
        """Run one validation batch, log ``val_loss`` / ``val_acc``, return the loss."""
        return self._shared_step(batch, self.val_accuracy, 'val')

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``lr=1e-4``."""
        return torch.optim.Adam(self.parameters(), lr=1e-4)
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from net import AlexNet
# Global run configuration: fixed seed (42) and 'high' float32 matmul precision.
L.seed_everything(42)
torch.set_float32_matmul_precision('high')

# Root folder of the image data.
data_dir = './data'

# Preprocessing pipeline: resize to 224x224, then convert to a tensor.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
class DataModule(L.LightningDataModule):
    """LightningDataModule serving ``ImageFolder`` datasets from one root folder.

    The root is read through its ``train``, ``val`` and ``test``
    sub-folders.

    Args:
        data_dir: Root folder (string or path-like).
        batch_size: Batch size used by every dataloader.
        transform: Transform passed to every ``ImageFolder``.
        num_workers: Worker processes per dataloader (default 4, the value
            that was previously hard-coded).
    """

    def __init__(self, data_dir, batch_size, transform, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transform
        self.num_workers = num_workers

    def _image_folder(self, split):
        """Build the ``ImageFolder`` for the ``split`` sub-folder of ``data_dir``."""
        # f-string formatting also accepts path-like ``data_dir`` values.
        return datasets.ImageFolder(f'{self.data_dir}/{split}', transform=self.transform)

    def _loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader with the module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``; all of them when ``stage`` is None.

        Building only what the stage needs means fitting does not require
        the ``test`` folder to exist.
        """
        if stage in (None, 'fit'):
            self.train_dataset = self._image_folder('train')
        if stage in (None, 'fit', 'validate'):
            self.val_dataset = self._image_folder('val')
        if stage in (None, 'test'):
            self.test_dataset = self._image_folder('test')

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (no shuffling)."""
        return self._loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the test dataloader (no shuffling)."""
        return self._loader(self.test_dataset, shuffle=False)
# Training entry point.  The guard matters because the dataloaders use worker
# processes: under the "spawn" start method each worker re-imports this
# module, and unguarded top-level code would start training again.
if __name__ == "__main__":
    data_module = DataModule(data_dir=data_dir, batch_size=32, transform=transform)
    model = AlexNet(num_classes=1)

    # Keep only the checkpoint with the highest validation accuracy.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_acc',
        dirpath='checkpoints',
        filename='best-checkpoint',
        save_top_k=1,
        mode='max',
    )

    trainer = L.Trainer(
        max_epochs=15,
        # 'auto' still picks the GPU when one is present, but no longer
        # fails on hosts without one.
        accelerator='auto',
        devices=1,
        callbacks=[checkpoint_callback],
    )
    trainer.fit(model, datamodule=data_module)

    best_model_path = checkpoint_callback.best_model_path
    print(f"Best model saved at: {best_model_path}")
import os
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import pandas as pd
import matplotlib.pyplot as plt
from torchmetrics.classification import BinaryAccuracy
import lightning as L
from PIL import Image
import seaborn as sns
from sklearn.metrics import confusion_matrix
from net import AlexNet
# Evaluation script: load the best checkpoint, score it on the test set,
# collect misclassified image paths and save a confusion matrix.
L.seed_everything(42)
torch.set_float32_matmul_precision("high")

# Use the same 224x224 preprocessing the model was trained with.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
data_dir = "./data"
test_dataset = datasets.ImageFolder(data_dir + "/test", transform=transform)
# shuffle=False keeps batches in dataset order, which the path lookup below
# depends on.
test_loader = DataLoader(test_dataset, batch_size=804, shuffle=False, num_workers=4)

best_model_path = "checkpoints/best-checkpoint.ckpt"
best_model = AlexNet.load_from_checkpoint(best_model_path, num_classes=1)
best_model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
best_model.to(device)

test_accuracy = BinaryAccuracy().to(device)
test_loss_fn = nn.BCEWithLogitsLoss()

# Paths of misclassified images, split by true label (0 -> mis_cat,
# 1 -> mis_dog).
# NOTE(review): assumes ImageFolder maps cat -> 0 and dog -> 1 - confirm
# against the folder names.
mis_cat = []
mis_dog = []
true_labels = []
predicted_labels = []

loss_sum = 0.0      # sum of per-sample losses over the whole test set
num_samples = 0     # samples seen so far; also the dataset index of the
                    # first sample in the current batch

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = best_model(images).squeeze(1)
        probs = torch.sigmoid(outputs)
        batch_size = labels.size(0)

        # The loss is a batch mean; weight it by batch size so the final
        # average is correct when the last batch is smaller.
        loss_sum += test_loss_fn(outputs, labels.float()).item() * batch_size
        test_accuracy.update(probs, labels)

        labels_np = labels.cpu().numpy()
        preds = (probs.cpu() > 0.5).numpy().astype(int)
        for idx, incorrect in enumerate(preds != labels_np):
            if incorrect:
                # Convert the in-batch index to a dataset index.
                image_path = test_dataset.imgs[num_samples + idx][0]
                if labels_np[idx] == 0:
                    mis_cat.append(image_path)
                else:
                    mis_dog.append(image_path)

        true_labels.extend(labels_np)
        predicted_labels.extend(preds)
        num_samples += batch_size

# Report metrics over the whole test set, not just the last batch.
test_loss = loss_sum / max(num_samples, 1)
test_acc = test_accuracy.compute()
print(f"测试损失: {test_loss:.4f}, 测试准确率: {test_acc.item():.4f}")

# Fix the label set so the matrix is always 2x2 and matches the tick labels.
cm = confusion_matrix(true_labels, predicted_labels, labels=[0, 1])
plt.figure(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    cbar=False,
    annot_kws={"fontsize": 15, "fontweight": "bold"},
    xticklabels=["Cat", "Dog"],
    yticklabels=["Cat", "Dog"],
)
plt.xlabel("Predicted Labels", fontsize=14)
plt.ylabel("True Labels", fontsize=14)
plt.title("confusion_matrix", fontsize=16)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
# savefig does not create missing directories.
os.makedirs("results", exist_ok=True)
plt.savefig("results/confusion_matrix.png", bbox_inches="tight")
plt.show()
def plot_images(cat_paths, dog_paths, n=5):
    """Plot up to ``n`` images from each list in a 2-row grid and save it.

    The first row shows ``cat_paths`` and the second ``dog_paths``.  The
    figure is written to ``results/mistake.png`` and then displayed.

    NOTE(review): panels for ``cat_paths`` are titled "Dog" and panels for
    ``dog_paths`` are titled "Cat" - presumably the title names the
    (wrong) predicted class; confirm this is intended.

    Args:
        cat_paths: Image file paths for the first row.
        dog_paths: Image file paths for the second row.
        n: Number of grid columns; at most ``n`` images per row are shown.
    """
    plt.figure(figsize=(15, 10))
    rows = ((cat_paths, "Dog"), (dog_paths, "Cat"))
    for row, (paths, title) in enumerate(rows):
        for i, img_path in enumerate(paths[:n]):
            plt.subplot(2, n, row * n + i + 1)
            # Close the file handle once the image has been drawn.
            with Image.open(img_path) as img:
                plt.imshow(img)
            plt.title(f"{title} {i + 1}")
            plt.axis("off")
    # savefig does not create missing directories.
    os.makedirs("results", exist_ok=True)
    plt.savefig("results/mistake.png")
    plt.show()
# Announce, then render, the misclassified images gathered during evaluation.
print("显示分类错误的猫和狗图像:")
plot_images(cat_paths=mis_cat, dog_paths=mis_dog, n=5)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· winform 绘制太阳,地球,月球 运作规律
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· AI 智能体引爆开源社区「GitHub 热点速览」
· Manus的开源复刻OpenManus初探
· 写一个简单的SQL生成工具