PyTorch内置检测模型Mask R-CNN(ResNet50-FPN v2)的使用:加载本地权重文件
# Run torchvision's Mask R-CNN (ResNet50-FPN v2) on one image using only a local
# weights file (downloaded beforehand from the PyTorch website), and draw only the
# detections whose confidence score is at least SCORE_THRESHOLD.
import warnings

import cv2
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as T
from torchvision.io.image import read_image
from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_V2_Weights,
    MaskRCNN_ResNet50_FPN_V2_Weights,
)

warnings.filterwarnings("ignore", category=ResourceWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

img_path = "./jupyterlab/doc/ccc.jpg"  # any test image (e.g. a person riding a bicycle)
SCORE_THRESHOLD = 0.9  # detections scoring below this confidence are not drawn

img = read_image(img_path)  # load with torchvision's own image I/O

# Use the weights enum that matches the architecture built below (Mask R-CNN v2),
# so the category names and preprocessing transforms belong to this checkpoint.
weights_info = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT

# Build the architecture without downloading anything, then load the local checkpoint.
model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
    weights=None, progress=False, weights_backbone=None
)
# map_location="cpu" lets a checkpoint containing CUDA tensors load on a CPU-only machine.
myweights = torch.load(
    "E:/study_2022/working_python/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth",
    map_location="cpu",
)
model.load_state_dict(myweights)
model.eval()  # evaluation mode: the model returns detections rather than training losses

preprocess = weights_info.transforms()
batch = [preprocess(img)]
# Gradients are never used here; skipping autograd saves memory and time.
with torch.no_grad():
    prediction = model(batch)[0]
categories = weights_info.meta["categories"]
labels = [categories[i] for i in prediction["labels"].tolist()]
boxes = prediction["boxes"].tolist()
scores = prediction["scores"].tolist()

myimg = cv2.imread(img_path)
if myimg is None:  # cv2.imread signals failure by returning None instead of raising
    raise FileNotFoundError(f"cv2 could not read image: {img_path}")
myimg = cv2.cvtColor(myimg, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
for label, box, score in zip(labels, boxes, scores):
    if score < SCORE_THRESHOLD:
        continue
    # cv2 drawing functions only accept integer pixel coordinates.
    start_point = (int(box[0]), int(box[1]))
    end_point = (int(box[2]), int(box[3]))
    cv2.rectangle(myimg, start_point, end_point, color=(255, 0, 0), thickness=3)
    cv2.putText(
        myimg, label, start_point, cv2.FONT_HERSHEY_SIMPLEX, 2,
        color=(255, 0, 0), thickness=3,
    )

plt.figure(figsize=(7, 5))
plt.imshow(myimg)
plt.xticks([])
plt.yticks([])
plt.show()