卷积神经网络CNN实战:MINST手写数字识别——调用模型/模型预测

import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt

from net import CNN 

# 初始化模型
model = CNN()

# 加载模型文件
model_path = 'C:/Users/25372/Desktop/newbie/output/model.pth'  # 替换为你的模型文件路径
model.load_state_dict(torch.load(model_path))
model.eval()  # 切换到评估模式

# 读取和处理 PNG 图像
image_path = 'img1.png'  # 替换为你的图像文件路径
image = Image.open(image_path).convert('L')  # 转换为灰度图像

# 定义图像预处理转换
transform = transforms.Compose([
    transforms.Resize((28, 28)),  # 调整图像大小
    transforms.ToTensor(),        # 转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化
])

# 预处理图像
image_tensor = transform(image)
image_tensor = image_tensor.unsqueeze(0)  # 添加批次维度

# 进行预测
with torch.no_grad():
    output = model(image_tensor)
    _, predicted = torch.max(output, 1)
    predicted_class = predicted.item()

# 转换图像为适合 OpenCV 的格式
image = (np.array(image)).astype(np.uint8)  # 确保图像数据在 0-255 范围内

# 将图像转换为 BGR 格式以便可用 OpenCV 显示
image_bgr = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)  # 将灰度图转换为 BGR 图

# 显示图像
cv2.imshow(f'Predicted class: {predicted_class}', image)
cv2.waitKey(0)  # 等待按键
cv2.destroyAllWindows()  # 关闭所有窗口
posted @ 2024-07-22 15:34  SXWisON  阅读(4)  评论(0编辑  收藏  举报