PyTorch 内置检测模型 Mask R-CNN（ResNet-50 FPN v2）的使用：加载本地权重文件

 

 1         ##完全使用本地权重,识别时根据识别准确率来确定是否绘制
 2         import matplotlib.pyplot as plt
 3         import torch
 4         import torchvision.transforms as T
 5         import torchvision
 6         import cv2
 7         from torchvision.io.image import read_image
 8         from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights
 9 
10         import warnings
11         warnings.filterwarnings("ignore",category=ResourceWarning)
12         warnings.filterwarnings("ignore",category=DeprecationWarning)
13 
116         img_path = "./jupyterlab/doc/ccc.jpg"        ##骑着自行车的美女,任选
17         img = read_image(img_path)##用pytorch提供的io函数
18 
19         weights_info = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
20         ##读本地权重文件,权重文件到pytorch网站下载
21         model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights=None, progress=False, weights_backbone=None)
22         myweights = torch.load('E:/study_2022/working_python/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth')
23         model.load_state_dict(myweights)
24         model.eval()##识别工作模式
25         
26         preprocess = weights_info.transforms()
27         batch = [preprocess(img)]
28         prediction = model(batch)[0]
29         labels = [weights_info.meta["categories"][i] for i in prediction["labels"]]
30         boxes = [i for i in prediction["boxes"]]
31         scores = [i for i in prediction["scores"]]
32 
33         myimg = cv2.imread(img_path)
35         myimg = cv2.cvtColor(myimg, cv2.COLOR_BGR2RGB)
36         for i,score in enumerate(scores):
37             if score.item() < 0.9 : continue##舍弃准确率90%以下的
38             myimg = cv2.addWeighted(myimg, alpha=0.5, src2=myimg, beta=0.5, gamma=1)
39             ##注意:cv2这里只接受整型坐标值
40             start_point = (int(boxes[i][0]), int(boxes[i][1]))
41             end_point = (int(boxes[i][2]), int(boxes[i][3]))
42             cv2.rectangle(myimg, start_point, end_point, color = (255,0,0), thickness=3)
43             cv2.putText(myimg, labels[i], start_point, cv2.FONT_HERSHEY_SIMPLEX, 2, color = (255,0,0), thickness=3)
44         plt.figure(figsize=(7, 5))
45         plt.imshow(myimg)
46         plt.xticks([])
47         plt.yticks([])
48         plt.show()

 

posted @ 2022-10-26 16:24  凤凰城堡  阅读(478)  评论(0)  编辑  收藏  举报