PyTorch 内置检测模型 Mask R-CNN（ResNet50-FPN v2）的使用：加载下载到本地的权重文件

 

 

        ## Run torchvision's Mask R-CNN (ResNet50-FPN v2) using ONLY a locally
        ## downloaded checkpoint, then draw the detections whose confidence score
        ## is at least `score_threshold`.
        import matplotlib.pyplot as plt
        import torch
        import torchvision.transforms as T
        import torchvision
        import cv2
        from torchvision.io.image import read_image, ImageReadMode
        from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights

        ## Silence some noisy warnings
        import warnings
        warnings.filterwarnings("ignore",category=ResourceWarning)
        warnings.filterwarnings("ignore",category=DeprecationWarning)

        img_path = "./jupyterlab/doc/ccc.jpg"  ## sample photo (a woman riding a bicycle); any image will do
        weights_path = "E:/study_2022/working_python/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth"
        score_threshold = 0.9  ## discard detections scoring below 90%

        ## Decode the image once, forced to 3-channel RGB (uint8 tensor, C x H x W).
        ## Forcing RGB keeps grayscale/RGBA files from breaking the model. Drawing on
        ## this same tensor guarantees the boxes line up with the pixels the model saw.
        img = read_image(img_path, mode=ImageReadMode.RGB)

        ## The metadata (COCO category names, preprocessing transform) must come from
        ## the weights enum that matches the checkpoint. The file above is the
        ## Mask R-CNN v2 COCO_V1 checkpoint, not a Faster R-CNN one.
        weights_info = MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1
        ## Build the architecture without downloading anything, then load the local weights.
        model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights=None, progress=False, weights_backbone=None)
        ## map_location="cpu" lets a checkpoint saved from GPU load on a CPU-only machine.
        ## NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
        myweights = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(myweights)
        model.eval()

        preprocess = weights_info.transforms()
        batch = [preprocess(img)]
        ## Inference only: no_grad avoids building an autograd graph (saves memory and time).
        with torch.no_grad():
            prediction = model(batch)[0]
        ## `prediction` also holds "masks"; this demo only draws boxes and labels.
        categories = weights_info.meta["categories"]
        labels = [categories[i] for i in prediction["labels"].tolist()]
        boxes = prediction["boxes"].tolist()
        scores = prediction["scores"].tolist()

        ## Convert to H x W x C for cv2/matplotlib. The permuted view is not
        ## C-contiguous, which cv2's drawing functions require, so make a contiguous copy.
        myimg = img.permute(1, 2, 0).contiguous().numpy()
        for label, box, score in zip(labels, boxes, scores):
            if score < score_threshold:
                continue
            ## cv2 only accepts integer pixel coordinates
            x1, y1, x2, y2 = (int(v) for v in box)
            ## The array is RGB, so (255, 0, 0) is red.
            cv2.rectangle(myimg, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=3)
            cv2.putText(myimg, label, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 2, color=(255, 0, 0), thickness=3)
        plt.figure(figsize=(7, 5))
        plt.imshow(myimg)
        plt.xticks([])
        plt.yticks([])
        plt.show()

 

posted @ 2022-10-26 16:16  凤凰城堡  阅读(723)  评论(0)  编辑  收藏  举报