使用SSD进行视频的物体检测

  1. video_test.py

    from utils.tag_video import VideoTag
    from nets.ssd_net import SSD300
    
    if __name__ == '__main__':
        # Label set used by the detector: index 0 is the background class,
        # followed by the 20 object categories.
        class_names = [
            "background",
            "aeroplane", "bicycle", "bird", "boat", "bottle",
            "bus", "car", "cat", "chair", "cow",
            "diningtable", "dog", "horse", "motorbike", "person",
            "pottedplant", "sheep", "sofa", "train", "tvmonitor",
        ]
        num_classes = len(class_names)
        input_shape = (300, 300, 3)

        # Build the SSD300 network and restore the already-trained weights.
        model = SSD300(input_shape, num_classes=num_classes)
        model.load_weights("./ckpt/pre_trained/weights_SSD300.hdf5", by_name=True)

        # Run detection over the sample video and display the annotated frames.
        tagger = VideoTag(model, input_shape, num_classes)
        tagger.run("./datasets/person_video3.mp4")
    
  2. utils/tag_video.py

    """
    配置获取相关预测数据类别,网络参数
    获取摄像头视频
    获取摄像每帧数据,进行格式形状处理
    模型预测、结果NMS过滤
    画图:显示物体位置,FPS值(每秒帧数)
    """
    from tensorflow.python.keras.preprocessing.image import img_to_array
    from tensorflow.python.keras.applications.imagenet_utils import preprocess_input
    from utils.ssd_utils import BBoxUtility
    import numpy as np
    import cv2
    
    
    class VideoTag(object):
        """Run an SSD detector over a video source and display annotated frames.

        For every frame the model is run, the raw predictions are decoded by
        ``BBoxUtility.detection_out``, detections below a confidence threshold
        are dropped, and the remaining boxes are drawn with a class label.
        """

        # Default label set: index 0 is the background class, followed by
        # 20 object categories.
        DEFAULT_CLASS_NAMES = (
            "background", "aeroplane", "bicycle", "bird", "boat",
            "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
            "dog", "horse", "motorbike", "person", "pottedplant", "sheep",
            "sofa", "train", "tvmonitor",
        )

        def __init__(self, model, input_shape, num_classes, class_names=None):
            """Store the model and precompute one display colour per class.

            :param model: object exposing ``predict(batch)``.
            :param input_shape: network input shape; assumed to be
                (height, width, channels) -- TODO confirm against the model.
            :param num_classes: number of classes, including background.
            :param class_names: optional sequence of ``num_classes`` labels.
                Defaults to ``DEFAULT_CLASS_NAMES``.
            :raises ValueError: if ``class_names`` is given and its length
                differs from ``num_classes``.
            """
            self.model = model
            self.input_shape = input_shape
            self.num_classes = num_classes
            self.bbox_util = BBoxUtility(num_classes=self.num_classes)

            if class_names is None:
                self.class_names = list(self.DEFAULT_CLASS_NAMES)
            else:
                self.class_names = list(class_names)
                if len(self.class_names) != self.num_classes:
                    raise ValueError(
                        "class_names has {} entries but num_classes is {}".format(
                            len(self.class_names), self.num_classes))

            # One BGR colour per class: spread the hue over the classes at a
            # fixed saturation/value, then convert HSV -> BGR with OpenCV.
            self.class_colors = []
            for i in range(self.num_classes):
                hsv = np.zeros((1, 1, 3), dtype="uint8")
                hsv[0, 0, 0] = int(255 * i / self.num_classes)
                hsv[0, 0, 1] = 128
                hsv[0, 0, 2] = 255
                bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)[0, 0]
                self.class_colors.append((int(bgr[0]), int(bgr[1]), int(bgr[2])))

        @staticmethod
        def _filter_detections(detections, conf_thresh):
            """Return the detection rows whose confidence is >= ``conf_thresh``.

            :param detections: 2-D array-like of rows
                ``[label, confidence, xmin, ymin, xmax, ymax]``.
            :param conf_thresh: minimum confidence (column 1) to keep a row.
            :return: 2-D ``numpy`` array of the kept rows; shape ``(0, 6)``
                when the input is empty or not 2-D.
            """
            detections = np.asarray(detections)
            if detections.ndim != 2 or detections.shape[0] == 0:
                return np.empty((0, 6))
            return detections[detections[:, 1] >= conf_thresh]

        def _draw_detections(self, image, detections):
            """Draw a box and a "<class> <confidence>" label for each detection.

            :param image: BGR image that is annotated in place.
            :param detections: rows ``[label, confidence, xmin, ymin, xmax, ymax]``;
                coordinates are scaled by the image size, so they are assumed
                to be normalised to [0, 1] -- TODO confirm.
            """
            height, width = image.shape[0], image.shape[1]
            for i, (label, conf, x0, y0, x1, y1) in enumerate(detections):
                xmin = int(round(x0 * width))
                ymin = int(round(y0 * height))
                xmax = int(round(x1 * width))
                ymax = int(round(y1 * height))

                class_num = int(label)
                print("该帧图片检测到第{}物体,索引为为{}".format(i, class_num))
                color = self.class_colors[class_num]

                # Bounding box of the detected object.
                cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)

                # Filled label background at the box's top-left corner, then
                # the label text on top of it.
                text = "{} {:.2f}".format(self.class_names[class_num], conf)
                cv2.rectangle(image, (xmin, ymin - 10), (xmin + 80, ymin + 5), color, -1)
                cv2.putText(image, text, (xmin + 5, ymin), cv2.FONT_HERSHEY_SIMPLEX,
                            0.35, (255, 255, 255), 1, cv2.LINE_AA)

        def run(self, file_path, conf_thresh=0.6):
            """Detect objects frame by frame and show the annotated video.

            Stops when the source runs out of frames or the user presses ``q``.
            The capture handle and the display window are released on every
            exit path, including errors raised while processing a frame.

            :param file_path: argument passed to ``cv2.VideoCapture``.
            :param conf_thresh: minimum confidence for a detection to be drawn.
            :raises IOError: if the source cannot be opened.
            :return: None
            """
            cap = cv2.VideoCapture(file_path)
            if not cap.isOpened():
                cap.release()
                raise IOError(("打开本视频或者摄像头失败!"))

            try:
                while True:
                    ret, orig_image = cap.read()
                    if not ret:
                        print("视频检测结束")
                        break

                    # cv2.resize expects dsize as (width, height), while the
                    # network input shape is assumed to be (height, width, ch).
                    net_size = (self.input_shape[1], self.input_shape[0])
                    resized = cv2.resize(orig_image, net_size)
                    # OpenCV delivers BGR; convert to RGB for the network.
                    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

                    # The displayed image is the network-sized frame scaled
                    # back up to the original frame size.
                    to_draw = cv2.resize(resized, (int(orig_image.shape[1]),
                                                   int(orig_image.shape[0])))

                    # Add the batch dimension (3-D -> 4-D), preprocess, predict,
                    # then decode/NMS-filter the raw predictions.
                    x = preprocess_input(np.array([img_to_array(rgb)]))
                    y = self.model.predict(x)
                    results = self.bbox_util.detection_out(y)

                    # Only inspect the result after the emptiness check: an
                    # empty result may not be an array with a ``.shape``.
                    if len(results) > 0 and len(results[0]) > 0:
                        detections = np.asarray(results[0])
                        print(detections.shape)
                        self._draw_detections(
                            to_draw, self._filter_detections(detections, conf_thresh))

                    # NOTE(review): this is the frame rate reported by the
                    # capture source, not the measured detection throughput.
                    fps = "FPS: " + str(cap.get(cv2.CAP_PROP_FPS))
                    cv2.rectangle(to_draw, (0, 0), (50, 17), (255, 255, 255), -1)
                    cv2.putText(to_draw, fps, (3, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 0, 0), 1)

                    cv2.imshow("SSD detector result", to_draw)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
            finally:
                # Always free the capture device/file and close the window.
                cap.release()
                cv2.destroyAllWindows()
            return None
    
posted @ 2022-08-01 10:27  BNTU  阅读(103)  评论(0编辑  收藏  举报