maskrcnn_video

# -*- coding: utf-8 -*-
# ----------------------------
# !  Copyright (C) 2022
#    All rights reserved.
#    File name: xxx.py
#    Abstract:  xxx
#    Version:   1.0
#    Author:    刘恩甫
#    Date:      2022-x-x
# -----------------------------

import logging
import numpy as np
import os
import tensorflow as tf
import cv2
import shutil
import collections
from PIL import Image
import PIL.ImageColor as ImageColor
import PIL.ImageDraw as ImageDraw
import PIL.ImageFont as ImageFont
from time import time
import math
import random
from datetime import datetime

image_shape = (600, 600, 3)
# Class-ID-to-name mapping.
category_index = {1: {'id': 1, 'name': 'truck'},
                  2: {'id': 2, 'name': 'crane'},
                  3: {'id': 3, 'name': 'claw'},}
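
# NOTE: this script uses the TensorFlow 1.x graph/session API (tf.GraphDef,
# tf.gfile, tf.Session). On TensorFlow 2.x, a common workaround (an assumption
# here, not part of the original script) is to import through the compat layer:
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_v2_behavior()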

def load_pb_and_get_input_output_node(pb_path):
    '''
    Load the frozen graph and fetch its input/output nodes.
    :param pb_path: path to the frozen .pb file
    :return: input image tensor, dict of output tensors, the loaded graph
    '''
    # Load the frozen graph.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(pb_path, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

        # Input image node.
        image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')
        # Output nodes.
        ops = tf.get_default_graph().get_operations()
        all_tensor_names = {output.name for op in ops for output in op.outputs}
        tensor_dict = {}
        for key in ['num_detections', 'detection_boxes', 'detection_scores',
                    'detection_classes', 'detection_masks_reframed']:
            tensor_name = key + ':0'
            if tensor_name in all_tensor_names:
                tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(tensor_name)
        return image_tensor, tensor_dict, detection_graph
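
# Minimal usage sketch of the loader above (hypothetical 'model.pb' path; assumes
# the frozen graph exports the standard TF Object Detection API tensor names):
#
#   image_tensor, tensor_dict, graph = load_pb_and_get_input_output_node('model.pb')
#   with graph.as_default(), tf.Session() as sess:
#       outputs = sess.run(tensor_dict,
#                          feed_dict={image_tensor: np.expand_dims(img, 0)})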

def resize_img(img, dst_img_size):
    height_scale = dst_img_size[0] / img.shape[0]
    width_scale = dst_img_size[1] / img.shape[1]
    scale = min(height_scale, width_scale)
    resize_height = int(round(scale * img.shape[0]))
    resize_width = int(round(scale * img.shape[1]))
    resized_img = cv2.resize(img, (resize_width, resize_height))
    before_y = int((dst_img_size[0] - resize_height) / 2)
    after_y = dst_img_size[0] - resize_height - before_y
    before_x = int((dst_img_size[1] - resize_width) / 2)
    after_x = dst_img_size[1] - resize_width - before_x
    pad_width = ((before_y, after_y), (before_x, after_x), (0, 0))
    return np.pad(resized_img, pad_width, 'constant', constant_values=0),\
           [before_y,after_y,before_x,after_x,scale]
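
# Worked example (letterboxing): resizing a 1080x1920 frame to (600, 600) gives
# scale = min(600/1080, 600/1920) = 0.3125, a 338x600 resized image, and
# symmetric vertical padding, so resize_info = [131, 131, 0, 0, 0.3125].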

def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)

STANDARD_COLORS = [
    'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
    'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
    'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
    'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
    'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
    'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
    'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
    'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
    'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
    'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
    'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
    'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
    'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
    'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
    'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
    'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
    'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
    'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
    'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
    'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
    'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
    'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
    'WhiteSmoke', 'Yellow', 'YellowGreen'
]
def visualize_boxes_and_labels_on_image_array(
        image,boxes,classes,scores,category_index,instance_masks=None,
        groundtruth_box_visualization_color='black',
        use_normalized_coordinates=False,max_boxes_to_draw=20,min_score_thresh=.7,
        agnostic_mode=False,line_thickness=4,skip_scores=False,skip_labels=False):
  '''
  Visualization helper: collect the display string, color and mask for each
  detection above the score threshold.
  :param image:
  :param boxes:
  :param classes:
  :param scores:
  :param category_index:
  :param instance_masks:
  :param groundtruth_box_visualization_color:
  :param use_normalized_coordinates:
  :param max_boxes_to_draw:
  :param min_score_thresh: score threshold for keeping a detection
  :param agnostic_mode:
  :param line_thickness:
  :param skip_scores:
  :param skip_labels:
  :return: list of [class_name, (x1, y1, x2, y2), mask] entries
  '''
  box_to_display_str_map = collections.defaultdict(list)  # display strings per box
  box_to_color_map = collections.defaultdict(str)  # color per box
  box_to_instance_masks_map = {}  # mask per box

  # Draw at most max_boxes_to_draw boxes.
  for i in range(min(max_boxes_to_draw, boxes.shape[0])):
    if scores is None or scores[i] > min_score_thresh:
      box = tuple(boxes[i].tolist())
      if instance_masks is not None:
        box_to_instance_masks_map[box] = instance_masks[i]
      if scores is None:
        box_to_color_map[box] = groundtruth_box_visualization_color
      else:
        display_str = ''
        if not skip_labels:
          if not agnostic_mode:
            if classes[i] in category_index.keys():
              class_name = category_index[classes[i]]['name']
            else:
              class_name = 'N/A'
            display_str = str(class_name)
        if not skip_scores:
          if not display_str:
            display_str = '{}%'.format(int(100*scores[i]))
          else:
            display_str = '{}: {}%'.format(display_str, int(100*scores[i]))
        box_to_display_str_map[box].append(display_str)
        if agnostic_mode:
          box_to_color_map[box] = 'DarkOrange'
        else:
          box_to_color_map[box] = STANDARD_COLORS[classes[i] % len(STANDARD_COLORS)]

  # Collect boxes and masks.
  res_list = []
  for box, color in box_to_color_map.items():
    ymin, xmin, ymax, xmax = box
    int_box = int(xmin), int(ymin), int(xmax), int(ymax)

    # Drawing is disabled here; the original calls are kept for reference:
    # draw_mask_on_image_array(image, box_to_instance_masks_map[box], color=color)
    # draw_bounding_box_on_image_array(image, ymin, xmin, ymax, xmax, color=color,
    #     thickness=line_thickness, display_str_list=box_to_display_str_map[box],
    #     use_normalized_coordinates=use_normalized_coordinates)

    cls = box_to_display_str_map[box][0].split(':')[0]
    mask = box_to_instance_masks_map[box] * 255
    res_list.append([cls, int_box, mask])
  return res_list
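
# Each entry of res_list has the form [class_name, (x1, y1, x2, y2), mask],
# e.g. ['truck', (12, 34, 256, 198), uint8_mask] (coordinates illustrative),
# where mask holds values 0 or 255.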

def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
  if image.dtype != np.uint8:
    raise ValueError('`image` not of type np.uint8')
  if mask.dtype != np.uint8:
    raise ValueError('`mask` not of type np.uint8')
  if np.any(np.logical_and(mask != 1, mask != 0)):
    raise ValueError('`mask` elements should be 0 or 1')
  if image.shape[:2] != mask.shape:
    raise ValueError('The image has spatial dimensions %s but the mask has '
                     'dimensions %s' % (image.shape[:2], mask.shape))
  rgb = ImageColor.getrgb(color)
  pil_image = Image.fromarray(image)

  solid_color = np.expand_dims(np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3])
  pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA')
  pil_mask = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L')
  pil_image = Image.composite(pil_solid_color, pil_image, pil_mask)  # composite the colored mask over the image
  np.copyto(image, np.array(pil_image.convert('RGB')))

def draw_bounding_box_on_image_array(image,ymin,xmin,ymax,xmax,color='red',
                                     thickness=4,display_str_list=(),
                                     use_normalized_coordinates=True):
  """Adds a bounding box to an image (numpy array).

  Bounding box coordinates can be specified in either absolute (pixel) or
  normalized coordinates by setting the use_normalized_coordinates argument.

  Args:
    image: a numpy array with shape [height, width, 3].
    ymin: ymin of bounding box.
    xmin: xmin of bounding box.
    ymax: ymax of bounding box.
    xmax: xmax of bounding box.
    color: color to draw bounding box. Default is red.
    thickness: line thickness. Default value is 4.
    display_str_list: list of strings to display in box
                      (each to be shown on its own line).
    use_normalized_coordinates: If True (default), treat coordinates
      ymin, xmin, ymax, xmax as relative to the image.  Otherwise treat
      coordinates as absolute.
  """
  image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
  draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color,
                             thickness, display_str_list,
                             use_normalized_coordinates)
  np.copyto(image, np.array(image_pil))

def draw_bounding_box_on_image(image,ymin,xmin,ymax,xmax,color='red',thickness=4,
                               display_str_list=(),use_normalized_coordinates=True):
  draw = ImageDraw.Draw(image)
  im_width, im_height = image.size
  if use_normalized_coordinates:
    (left, right, top, bottom) = (xmin * im_width, xmax * im_width, ymin * im_height, ymax * im_height)
  else:
    (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
  # Draw the bounding box outline.
  draw.line([(left, top), (left, bottom), (right, bottom),(right, top), (left, top)], width=thickness, fill=color)
  try:
    font = ImageFont.truetype('arial.ttf', 24)
  except IOError:
    font = ImageFont.load_default()

  # If the total height of the display strings added to the top of the bounding box exceeds the top of the image,
  # stack the strings below the bounding box instead of above.
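  # NOTE: ImageFont.getsize() was removed in Pillow 10. On newer Pillow, an
  # equivalent (an untested substitution here) is deriving width/height from
  # font.getbbox(ds).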
  display_str_heights = [font.getsize(ds)[1] for ds in display_str_list]
  # Each display_str has a top and bottom margin of 0.05x.
  total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)
  # Bottom of the stacked display strings.
  if top > total_display_str_height:
    text_bottom = top
  else:
    text_bottom = bottom + total_display_str_height
  # Reverse list and print from bottom to top.
  for display_str in display_str_list[::-1]:
    text_width, text_height = font.getsize(display_str)
    margin = np.ceil(0.05 * text_height)
    # Text background: [(left, top), (right, bottom)]
    draw.rectangle([(left, text_bottom - text_height - 2 * margin), (left + text_width,text_bottom)],fill=color)
    draw.text((left + margin, text_bottom - text_height - margin),display_str,fill='black',font=font)
    text_bottom -= text_height - 2 * margin

def Perspective_transform(image, pts):
    '''Perspective transform: warp the quadrilateral `pts` to an upright rectangle.'''
    pts=pts.squeeze().astype(np.float32)
    (tl, tr, br, bl) = pts
    widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
    widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
    maxWidth = max(int(widthA), int(widthB))
    heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
    heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
    maxHeight = max(int(heightA), int(heightB))

    # in the top-left, top-right, bottom-right, and bottom-left order
    dst = np.array([[0, 0],
                    [maxWidth - 1, 0],
                    [maxWidth - 1, maxHeight - 1],
                    [0, maxHeight - 1]], dtype="float32")
    M = cv2.getPerspectiveTransform(pts, dst)
    warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
    return warped
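
# Example call with hypothetical corner points (clockwise from top-left), as
# expected by the (tl, tr, br, bl) unpacking above:
#
#   pts = np.array([[[10, 20]], [[410, 30]], [[400, 330]], [[15, 320]]])
#   warped = Perspective_transform(image, pts)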

def post_process(output_dict, resize_info_, image_np):
    '''Type conversions and mapping detections back to the original image.'''
    output_dict['num_detections'] = int(output_dict['num_detections'][0])
    output_dict['detection_classes'] = output_dict['detection_classes'][0].astype(np.uint8)
    output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
    output_dict['detection_scores'] = output_dict['detection_scores'][0]
    output_dict['detection_masks'] = output_dict['detection_masks_reframed'][0]

    # Post-process the raw detections (up to 300 of them).
    detection_boxes = output_dict['detection_boxes']  # normalized coordinates
    detection_classes = output_dict['detection_classes']
    detection_scores = output_dict['detection_scores']
    detection_masks = output_dict.get('detection_masks')
    # Map detection_boxes back to the first-resize image:
    # first undo the padding...
    detection_boxes *= image_shape[0]  # unnormalize to the padded 600x600 input
    detection_boxes[:, 0] -= resize_info_[0]  # ymin
    detection_boxes[:, 1] -= resize_info_[2]  # xmin
    detection_boxes[:, 2] -= resize_info_[0]  # ymax
    detection_boxes[:, 3] -= resize_info_[2]  # xmax
    # ...then undo the scaling.
    detection_boxes /= resize_info_[4]

    # Post-process detection_masks: crop off the padding and resize each mask
    # to the first-resize image size.
    new_detection_masks = np.zeros((detection_masks.shape[0], image_np.shape[0], image_np.shape[1]))
    for i in range(len(detection_masks)):
        new_mask = detection_masks[i, resize_info_[0]:(image_shape[0] - resize_info_[1]),
                   resize_info_[2]:(image_shape[1] - resize_info_[3])]
        new_mask = cv2.resize(new_mask, (image_np.shape[1], image_np.shape[0]))
        new_detection_masks[i] = new_mask

    new_detection_masks = new_detection_masks.astype(np.uint8)

    return detection_boxes, detection_classes, detection_scores,new_detection_masks
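
# Worked example of the box mapping above: with resize_info_ = [131, 131, 0, 0,
# 0.3125], a normalized ymin of 0.5 becomes 0.5 * 600 = 300 in the padded input,
# 300 - 131 = 169 after removing the top padding, and 169 / 0.3125 = 540.8 in
# the first-resize image.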

def get_truck_mask(res_list, image_np):
    truck_list = []
    truck_mask = np.zeros_like(image_np)
    for r in res_list:
        if r[0] == 'truck':
            truck_list.append(r)
    truck_mask = cv2.cvtColor(truck_mask, cv2.COLOR_BGR2GRAY)
    # Merge the masks of all detected trucks.
    for truck in truck_list:
        truck_mask = cv2.bitwise_or(truck_mask, truck[2])
    truck_mask = cv2.cvtColor(truck_mask, cv2.COLOR_GRAY2BGR)
    return truck_mask, truck_list

def cos_dist(a, b):
    if len(a) != len(b):
        return None
    part_up = 0.0
    a_sq = 0.0
    b_sq = 0.0
    for a1, b1 in zip(a, b):
        part_up += a1*b1
        a_sq += a1**2
        b_sq += b1**2
    part_down = math.sqrt(a_sq*b_sq)
    if part_down == 0.0:
        return None
    else:
        return part_up / part_down
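
# Equivalent vectorized form of cos_dist (same value for nonzero vectors):
#   np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))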

def clockwise(pts):
    # sort the points based on their x-coordinates
    xSorted = pts[np.argsort(pts[:, 0]), :]

    # grab the left-most and right-most points from the sorted
    # x-coordinate points
    leftMost = xSorted[:2, :]
    rightMost = xSorted[2:, :]

    # now, sort the left-most coordinates according to their
    # y-coordinates so we can grab the top-left and bottom-left
    # points, respectively
    leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
    (tl, bl) = leftMost

    # now that we have the top-left and bottom-left coordinates, use them as a
    # base vector to calculate the angles between the other two vectors

    vector_0 = np.array(bl - tl)
    vector_1 = np.array(rightMost[0] - tl)
    vector_2 = np.array(rightMost[1] - tl)

    angle = [np.arccos(cos_dist(vector_0, vector_1)), np.arccos(cos_dist(vector_0, vector_2))]
    (br, tr) = rightMost[np.argsort(angle), :]

    # return the coordinates in top-left, top-right,bottom-right, and bottom-left order
    return np.array([tl, tr, br, bl], dtype="float32")

def getDist_P2P(Point0, PointA):
    '''Euclidean distance between two points.'''
    distance = math.pow((Point0[0] - PointA[0]), 2) + math.pow((Point0[1] - PointA[1]), 2)
    distance = math.sqrt(distance)
    return distance

def get_rotated_rect(approx, truck_mask_copy):
    '''Compute the minimum-area bounding rectangle; return its corners in
    clockwise order together with the clockwise-ordered input points.'''
    approx = approx.squeeze()
    min_area_rect = cv2.minAreaRect(approx)
    angle = min_area_rect[-1]
    rotated_rect = cv2.boxPoints(min_area_rect)
    rotated_rect = np.intp(rotated_rect)  # np.int0 is a deprecated alias of np.intp
    rotated_rect = clockwise(rotated_rect)
    approx = np.array(clockwise(approx)).astype(np.float32)
    # Draw the minimum-area rectangle.
    print('rotated_rect:', rotated_rect)
    for i in range(len(rotated_rect) - 1):
        cv2.line(truck_mask_copy, rotated_rect[i].astype(int), rotated_rect[i + 1].astype(int), (0, 255, 0))
    cv2.line(truck_mask_copy, rotated_rect[-1].astype(int), rotated_rect[0].astype(int), (0, 255, 0))
    return rotated_rect, approx

def search_approx(epi_thres_list, cnt, truck_mask_copy):
    fitting_record = []
    for epi_thres in epi_thres_list:
        # Polygonal approximation of the contour.
        epsilon = epi_thres * cv2.arcLength(cnt, True)
        fitting_points = cv2.approxPolyDP(cnt, epsilon, True)
        fitting_record.append((epi_thres, fitting_points))

    for record in fitting_record:
        if len(record[1]) == 4:
            approx = np.array(clockwise(record[1].squeeze())).astype(np.float32)
            approx = approx.reshape((-1, 1, 2))
            print("approx 1")
            break
    else:
        # No threshold yielded exactly four points: fall back to the minimum-area
        # rectangle (clockwise corners) and pick, for each corner, the closest
        # point of the finest fit.
        rotated_rect, approx = get_rotated_rect(fitting_record[0][1], truck_mask_copy)
        min_points = []
        for rr in rotated_rect:
            tmp_min_record = []
            for bb in approx:
                dist = getDist_P2P(rr, bb)
                tmp_min_record.append((dist, bb))
            closest_point = sorted(tmp_min_record, key=lambda x: x[0])[0][1]
            min_points.append(closest_point)
        print('min_points', min_points)
        approx = clockwise(np.array(min_points)).astype(np.float32)  # fixed: clockwise() expects an ndarray, not a list
        print("approx 2")

    return approx

def get_warpPerspective(truck_mask, src_img, epi_thres_list=[0.01, 0.02, 0.03, 0.04, 0.05, 0.075, 0.1]):
    '''Find four corner points, then warp the perspective to get a rectified image.
    epi_thres: contour-fitting parameter; smaller values give a finer fit.
    '''
    # Binarize the mask.
    gray = cv2.cvtColor(truck_mask, cv2.COLOR_BGR2GRAY)
    ret, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    if len(contours) == 0:  # no truck present
        return np.zeros_like(gray)
    else:
        truck_mask_copy = truck_mask.copy()
        cnt = contours[np.argmax([cv2.contourArea(c) for c in contours])]  # largest contour = main truck
        # Find the four fitted corner points.
        approx = search_approx(epi_thres_list, cnt, truck_mask_copy)
        approx = np.array(clockwise(approx.squeeze()))

        cv2.polylines(truck_mask_copy, [approx.astype(int)], True, (0, 0, 255), 2)
        for p in approx.squeeze():
            cv2.circle(truck_mask_copy, (int(p[0]), int(p[1])), 5, (255, 255, 255), 5)
        cv2.imshow('truck_mask_copy', truck_mask_copy)

        # Apply the perspective transform and keep the result landscape.
        warpPerspective = Perspective_transform(src_img, approx)
        if warpPerspective.shape[0] > warpPerspective.shape[1]:
            warpPerspective = np.rot90(warpPerspective)

        return warpPerspective

def save_image(save_path, num, image):
    """Save an image to save_path with a serial-number filename.

    Args:
        save_path: directory prefix for the output file
        num: serial number
        image: image data

    Returns:
        None
    """
    image_path = save_path+'{}.jpg'.format(str(num))
    cv2.imwrite(image_path, image)

def compute_area(rect):
    '''Rectangle area from its top-left and bottom-right corners.'''
    return int(math.fabs(rect[2] - rect[0]) * math.fabs(rect[3] - rect[1]))

def is_overlap(truck_box, crane_box):
    '''Check whether two detection boxes intersect: returns True when the
    overlap area is positive, otherwise a falsy 0.'''
    x1,y1,w1,h1=truck_box[0],truck_box[1],truck_box[2]-truck_box[0],truck_box[3]-truck_box[1]
    x2,y2,w2,h2=crane_box[0],crane_box[1],crane_box[2]-crane_box[0],crane_box[3]-crane_box[1]

    if (x1 > x2 + w2):
        return 0
    if (y1 > y2 + h2):
        return 0
    if (x1 + w1 < x2):
        return 0
    if (y1 + h1 < y2):
        return 0
    colInt = abs(min(x1 + w1, x2 + w2) - max(x1, x2))
    rowInt = abs(min(y1 + h1, y2 + h2) - max(y1, y2))
    overlap_area = colInt * rowInt
    print('overlap_area:',overlap_area)
    return overlap_area>0
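
# Worked example: is_overlap((0, 0, 100, 100), (50, 50, 150, 150)) gives
# colInt = min(100, 150) - max(0, 50) = 50 and rowInt = 50, so
# overlap_area = 2500 and the function returns True.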

def get_truck_crane_claw(image_src,res_list):
    truck_info, crane_info, claw_info = [], [], []
    for v in res_list:
        if v[0] == 'truck':
            truck_info.append(v)
        elif v[0] == 'crane':
            crane_info.append(v)
        elif v[0] == 'claw':
            claw_info.append(v)

    truck_area = []
    for v in truck_info:
        rect_info = v[1]
        rect_area = compute_area(rect_info)
        truck_area.append((v[0], v[1], rect_area, v[2]))
    main_truck_rect, crane_rect, claw_rect = [], [], []
    if len(truck_area) != 0:
        # Take the largest truck as the main truck (matching the
        # largest-contour choice in get_warpPerspective).
        main_truck_rect = sorted(truck_area, key=lambda x: x[2], reverse=True)[0][1]
    if len(crane_info) != 0:
        crane_rect = crane_info[0][1]
    if len(claw_info) != 0:
        claw_rect = claw_info[0][1]

    return main_truck_rect,crane_rect,claw_rect

def maskrcnn_algorithm(image_src, frame_interval_count):
    '''Run detection on one frame.
    NOTE: sess, tensor_dict and image_tensor are globals created in __main__.'''
    h, w, c = image_src.shape

    resize_ratio = 600 / max(h, w)  # first resize factor
    multiplying = 1 / resize_ratio  # inverse factor, maps boxes back to the original image

    # First resize, to speed up post-processing.
    image_np = cv2.resize(image_src, dsize=None, fx=resize_ratio, fy=resize_ratio)
    # Second resize: pad to the model input shape (600, 600, 3).
    image_np_resize, resize_info_ = resize_img(image_np, image_shape)
    # Inference.
    output_dict = sess.run(tensor_dict, feed_dict={image_tensor: np.expand_dims(image_np_resize, 0)})
    # Post-processing.
    detection_boxes, detection_classes, detection_scores, new_detection_masks = post_process(output_dict, resize_info_, image_np)
    # Collect the results as [cls, box, mask] entries.
    res_list = visualize_boxes_and_labels_on_image_array(image_np, detection_boxes, detection_classes, detection_scores, category_index,
                                                         instance_masks=new_detection_masks, use_normalized_coordinates=False, line_thickness=2)

    # Extract the boxes of the detected truck, crane and claw.
    main_truck_rect, crane_rect, claw_rect = get_truck_crane_claw(image_src, res_list)
    return main_truck_rect, crane_rect, claw_rect, image_np, multiplying, res_list


# Cumulative distribution function of the grey-level histogram.
def C(rk):
    # hist: frequency per grey level; bins: bin edges.
    hist, bins = np.histogram(rk, 256, [0, 256])
    # The cumulative sum is the (unnormalized) CDF.
    return hist.cumsum()

# Grey-level equalization mapping.
def T(rk):
    cdf = C(rk)
    # Stretch the CDF to the full [0, 255] range.
    cdf = (cdf - cdf.min()) * 255 / (cdf.max() - cdf.min())
    return cdf.astype('uint8')
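
# Usage note: applying the equalization map is a NumPy fancy-index lookup,
# exactly as done in sift_detection() below:
#   equalized = T(gray.flatten())[gray]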

# ***************************** K-means clustering **********************************
# input: 2-D data points; xMax, yMax: coordinate bounds (image size)
def Kmeans(input, k, xMax, yMax):
    # Append a class label to each point (999 = unassigned).
    keyPoint = [[0 for x in range(3)] for y in range(len(input))]
    for i in range(len(keyPoint)):
        keyPoint[i][0] = input[i][0]
        keyPoint[i][1] = input[i][1]
        keyPoint[i][2] = 999
    # Initialize k random centers.
    center = [[0 for x in range(3)] for y in range(k)]
    for i in range(k):
        center[i][0] = random.randint(0, xMax)
        center[i][1] = random.randint(0, yMax)

    # Stopping criteria.
    time = 0  # iteration counter (locally shadows the imported time())
    timeMax = 4
    changed = 0  # reassignment flag (unused)
    a = 0.01  # minimum movement as a fraction of the image size
    move = 0  # total center movement, compared against moveMax (unused)
    moveMax = a * xMax

    # Iterate until the iteration limit.
    while time < timeMax:
        time = time + 1
        # Assign each point to its nearest center.
        for i in range(len(keyPoint)):
            dis = -1
            for j in range(k):
                x = keyPoint[i][0] - center[j][0]
                y = keyPoint[i][1] - center[j][1]
                disTemp = x * x + y * y
                # Track the nearest center found so far.
                if (disTemp < dis) | (dis == -1):
                    dis = disTemp
                    keyPoint[i][2] = j
        # Recompute each center as the mean of its assigned points.
        for i in range(k):
            xSum = 0
            ySum = 0
            num = 0
            for j in range(len(keyPoint)):
                if keyPoint[j][2] == i:
                    xSum = xSum + keyPoint[j][0]
                    ySum = ySum + keyPoint[j][1]
                    num = num + 1
            if num != 0:
                center[i][0] = xSum / num
                center[i][1] = ySum / num
    # Count the number of points in each cluster.
    for i in range(len(keyPoint)):
        center[keyPoint[i][2]][2] = center[keyPoint[i][2]][2] + 1
    return center
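
# Note: OpenCV's built-in cv2.kmeans could replace the hand-rolled loop above.
# A minimal sketch (an alternative, not what this script uses), assuming `pts`
# is an Nx2 float32 array:
#
#   criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
#   _, labels, centers = cv2.kmeans(pts, k, None, criteria, 3,
#                                   cv2.KMEANS_RANDOM_CENTERS)
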
# Mean-shift clustering.
# input: 2-D data points; r: window radius
def MeanShift(input, r):
    classification = []
    startNum = 100  # number of random starting points
    radium = r  # window radius
    num = len(input)  # number of samples
    Sample = np.int32([[0, 0, 0] for m in range(num)])  # third column = class label (0 = unassigned)
    for i in range(num):
        Sample[i][0] = input[i][0]
        Sample[i][1] = input[i][1]

    # Run mean shift from randomly chosen starting points.
    for i in range(startNum):
        ptr = random.randint(0, num - 1)

        # Current window center.
        center = [0, 0]
        center[0] = Sample[ptr][0]
        center[1] = Sample[ptr][1]
        Flag = 0
        # Shift until convergence, at most 10 iterations.
        iteration = 0
        while ((Flag == 0) & (iteration < 10)):
            orientation = [0, 0]  # shift direction
            # Accumulate the offsets of all samples inside the window.
            for j in range(num):
                oX = Sample[j][0] - center[0]
                oY = Sample[j][1] - center[1]
                dist = math.sqrt(oX * oX + oY * oY)
                # The point lies inside the window.
                if dist <= radium:
                    orientation[0] = orientation[0] + oX / 20
                    orientation[1] = orientation[1] + oY / 20
            # Shift the center.
            center[0] = center[0] + orientation[0]
            center[1] = center[1] + orientation[1]
            # Stop once the center barely moves.
            oX = orientation[0]
            oY = orientation[1]
            iteration = iteration + 1
            if math.sqrt(oX * oX + oY * oY) < 3:
                Flag = 1

        # Record the converged center as a new class if it is not a duplicate.
        Flag = 1
        for j in range(len(classification)):
            # Too close to an existing class center: treat as a duplicate.
            oX = classification[j][0] - center[0]
            oY = classification[j][1] - center[1]
            if math.sqrt(oX * oX + oY * oY) < math.sqrt(classification[j][2]) + 30:
                Flag = 0
                break
        if Flag == 1:
            temp = [center[0], center[1], 0]
            classification.append(temp)

    # Assign every sample to its nearest class.
    for i in range(num):
        Index = 0
        minValue = 99999
        # Find the nearest class center.
        for j in range(len(classification)):
            xx = classification[j][0] - Sample[i][0]
            yy = classification[j][1] - Sample[i][1]
            distination = abs(xx * xx + yy * yy)
            if distination <= minValue:
                Index = j
                minValue = distination
        Sample[i][2] = Index
        classification[Index][2] = classification[Index][2] + 1

    return classification

def sift_detection(before_in, after_leave):
    func = 2
    a = 1  # display scale
    detectDensity = 1.5  # grid sampling density
    shreshood = 350  # descriptor-difference threshold
    windowSize = 40

    # Convert to grayscale.
    gray1 = cv2.cvtColor(before_in, cv2.COLOR_BGR2GRAY)
    gray2 = cv2.cvtColor(after_leave, cv2.COLOR_BGR2GRAY)
    # Histogram equalization.
    rk1 = gray1.flatten()
    rk2 = gray2.flatten()
    imgDst1 = T(rk1)[gray1]
    imgDst2 = T(rk2)[gray2]

    img1 = cv2.cvtColor(imgDst1, cv2.COLOR_GRAY2BGR)
    img2 = cv2.cvtColor(imgDst2, cv2.COLOR_GRAY2BGR)

    # Resize both images to a common size.
    h1, w1, _ = img1.shape
    h2, w2, _ = img2.shape
    src_h, src_w = max(h1, h2), max(w1, w2)
    h, w = src_h, src_w

    img1 = cv2.resize(img1, dsize=(w, h))
    img2 = cv2.resize(img2, dsize=(w, h))

    # NOTE: cv2.xfeatures2d requires opencv-contrib-python; on OpenCV >= 4.4,
    # SIFT is also available in the main module as cv2.SIFT_create().
    sift = cv2.xfeatures2d.SIFT_create()
    # Detect keypoints and compute descriptors.
    kp1, des1 = sift.detectAndCompute(img1, None)
    kp2, des2 = sift.detectAndCompute(img2, None)

    # Keypoint matching with FLANN.
    FLANN_INDEX_KDTREE = 0
    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=6)
    search_params = dict(checks=10)

    flann = cv2.FlannBasedMatcher(index_params, search_params)

    matches = flann.knnMatch(des1, des2, k=2)

    # Lowe's ratio test.
    good = []
    for m, n in matches:
        if m.distance < 0.7 * n.distance:
            good.append(m)

    # Split the good matches into source/destination points for the homography.
    pts_src = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
    pts_dst = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
    # Estimate the homography with RANSAC.
    M, mask = cv2.findHomography(pts_src, pts_dst, cv2.RANSAC, 5.0)

    # Output the change-detection result ********************************************
    if func == 2:
        # Translation components of M.
        dX = M[0][2]  # x shift; negative means the left image is offset left
        dY = M[1][2]  # y shift; negative means the left image is offset up

        # Image dimensions.
        height, width, channel = img1.shape

        # Keypoint scale for the dense sampling grid.
        size = int(width * 0.01)

        # Bounds of the sampling region, initialized for a min/max search.
        xMinLeft = width
        xMaxLeft = 0
        yMinLeft = height
        yMaxLeft = 0
        xMinRight = width
        xMaxRight = 0
        yMinRight = height
        yMaxRight = 0

        # Use the inlier matches to determine a suitable detection region.
        for i in range(len(pts_src)):
            if mask[i] == 1:
                if pts_src[i][0][1] < yMinLeft:
                    yMinLeft = pts_src[i][0][1]
                if pts_src[i][0][1] > yMaxLeft:
                    yMaxLeft = pts_src[i][0][1]
                if pts_src[i][0][0] < xMinLeft:
                    xMinLeft = pts_src[i][0][0]
                if pts_src[i][0][0] > xMaxLeft:
                    xMaxLeft = pts_src[i][0][0]
        for i in range(len(pts_dst)):
            if mask[i] == 1:
                if pts_dst[i][0][1] < yMinRight:
                    yMinRight = pts_dst[i][0][1]
                if pts_dst[i][0][1] > yMaxRight:
                    yMaxRight = pts_dst[i][0][1]
                if pts_dst[i][0][0] < xMinRight:
                    xMinRight = pts_dst[i][0][0]
                if pts_dst[i][0][0] > xMaxRight:
                    xMaxRight = pts_dst[i][0][0]

        xMinLeft = xMinLeft + 2 * size
        yMinLeft = yMinLeft + 3 * size

        # Build the dense sampling grid over the detection region.
        interval = detectDensity * size  # spacing between sample points
        searchWidth = int((xMaxLeft - xMinLeft) / interval - 2)
        searchHeight = int((yMaxLeft - yMinLeft) / interval - 2)
        searchNum = searchWidth * searchHeight
        demo_src = np.float32([[0] * 2] * searchNum * 1).reshape(-1, 1, 2)
        for i in range(searchWidth):
            for j in range(searchHeight):
                demo_src[i + j * searchWidth][0][0] = xMinLeft + i * interval + size
                demo_src[i + j * searchWidth][0][1] = yMinLeft + j * interval + size

        # Map the grid points into the right image via the homography.
        demo_dst = cv2.perspectiveTransform(demo_src, M)

        # Prepare a side-by-side canvas for the difference visualization.
        heightO = max(img1.shape[0], img2.shape[0])
        widthO = img1.shape[1] + img2.shape[1]  # fixed: was img1.shape[1] twice
        output = np.zeros((heightO, widthO, 3), dtype=np.uint8)
        output[0:img1.shape[0], 0:img1.shape[1]] = img1
        output[0:img2.shape[0], img2.shape[1]:] = img2[:]
        # output2
        output2 = output

        # Convert the sample points to cv2.KeyPoint objects.
        kp_src = [cv2.KeyPoint(demo_src[i][0][0], demo_src[i][0][1], size) for i in range(demo_src.shape[0])]
        kp_dst = [cv2.KeyPoint(demo_dst[i][0][0], demo_dst[i][0][1], size) for i in range(demo_dst.shape[0])]

        # Compute SIFT descriptors at these keypoints.
        keypoints_image1, descriptors_image1 = sift.compute(img1, kp_src)
        keypoints_image2, descriptors_image2 = sift.compute(img2, kp_dst)

        # Difference points.
        diffLeft = []
        diffRight = []

        # Compare descriptors: a large L2 distance indicates a local change.
        for i in range(searchNum):
            nowShreshood = shreshood
            difference = 0
            for j in range(128):
                d = abs(descriptors_image1[i][j] - descriptors_image2[i][j])
                difference = difference + d * d
            difference = math.sqrt(difference)

            # Keep only points whose mapped position lies inside the right image.
            if (demo_dst[i][0][1] >= 0) & (demo_dst[i][0][0] >= 0):
                if difference > nowShreshood:
                    if func == 2:
                        diffLeft.append([demo_src[i][0][0], demo_src[i][0][1]])
                        diffRight.append([demo_dst[i][0][0], demo_dst[i][0][1]])

        # Cluster the difference points and draw the output.
        if func == 2:
            outLeft = MeanShift(diffLeft, windowSize)

            left = np.float32([[0] * 2] * len(outLeft) * 1).reshape(-1, 1, 2)
            for i in range(len(outLeft)):
                left[i][0][0] = outLeft[i][0]
                left[i][0][1] = outLeft[i][1]
            # Map the cluster centers into the right image (hoisted out of the
            # loop above, where the original recomputed them every iteration).
            right = cv2.perspectiveTransform(left, M)
            outRight = [[0 for x in range(3)] for y in range(len(outLeft))]
            for i in range(len(outLeft)):
                outRight[i][0] = right[i][0][0]
                outRight[i][1] = right[i][0][1]
                outRight[i][2] = outLeft[i][2]

            # Draw only clusters with more than `thres` member points;
            # smaller clusters are treated as noise.
            thres = 50
            time = 5
            output3 = np.zeros_like(output2)
            for i in range(len(outLeft)):
                if outLeft[i][2] > thres:
                    cv2.circle(output3, (int(outLeft[i][0]), int(outLeft[i][1])), int(np.sqrt(outLeft[i][2])) * time, (255, 255, 255), 2)
            for i in range(len(outRight)):
                if outRight[i][2] > thres:
                    cv2.circle(output3, (int(outRight[i][0]) + width, int(outRight[i][1])), int(np.sqrt(outRight[i][2])) * time, (255, 255, 0), 2)

            # Return only the left half of the difference visualization.
            left_out = output3[0:img1.shape[0], 0:img1.shape[1], :]
            return left_out

def sift_mask(sift_res):
    gray = cv2.cvtColor(sift_res, cv2.COLOR_BGR2GRAY)
    binary = np.where(gray == 0, 0, 1)
    binary = binary.astype(np.uint8)
    contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    mask = np.zeros_like(sift_res, dtype=np.uint8)
    for i in range(len(contours)):
        cv2.fillPoly(mask, [contours[i]], (255, 255, 255))
    mask=cv2.cvtColor(mask,cv2.COLOR_BGR2GRAY)
    return mask

def crop(sift_res,res_list):
    siftMask = sift_mask(sift_res)
    # cv2.imshow('mask', siftMask)
    for i in range(len(res_list)):
        if res_list[i][0] == 'truck':
            x1, y1, x2, y2 = res_list[i][1][0], res_list[i][1][1], res_list[i][1][2], res_list[i][1][3]
            true_truck_mask = res_list[i][2][y1:y2, x1:x2]
            true_truck_mask = cv2.resize(true_truck_mask, dsize=(sift_res.shape[1], sift_res.shape[0]), fx=None, fy=None)
            # main_mask: the largest connected region of siftMask, taken as the
            # main change region, later intersected (bitwise_and) with true_truck_mask.
            binary = np.where(siftMask == 0, 0, 1)
            binary = binary.astype(np.uint8)
            contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            cnt = contours[np.argmax([cv2.contourArea(c) for c in contours])]  # largest contour = main change region
            main_mask = np.zeros_like(binary, dtype=np.uint8)
            cv2.fillPoly(main_mask, [cnt], (255, 255, 255))

            bitwise_mask = cv2.bitwise_and(main_mask, true_truck_mask)
            binary1 = np.where(bitwise_mask == 0, 0, 1)
            binary1 = binary1.astype(np.uint8)
            contours1, _ = cv2.findContours(binary1, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            x, y, w, h = cv2.boundingRect(contours1[0])

            bitwise_mask = cv2.cvtColor(bitwise_mask, cv2.COLOR_GRAY2BGR)
            cv2.rectangle(bitwise_mask, (x, y), (x + w, y + h), (0, 255, 0), 2)
            cv2.imshow('bitwise_mask', bitwise_mask)
            # cv2.waitKey(0)
            return x, y, x + w, y + h  # x1, y1, x2, y2

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s: %(message)s')
    # path
    # file_path = r'D:\liu_projects\tanggang\data\video\20220706.mp4'
    file_path = r'D:\liu_projects\tanggang\data\video\0f58f563d3e44297909b57e8c3e89356-2.avi'
    tmp_save_path = r'D:\liu_projects\tanggang\data\video\video\2022_07_04_11_05-0220704094132983_master1_avi\\'

    save_path=r'.\raw_pictures\\'
    PATH_TO_FROZEN_GRAPH ='frozen_model/frozen_inference_graph.pb'
    PATH_TO_TEST_IMAGES_DIR = 'test_imgs'
    PATH_TO_TEST_RESULT_DIR = 'test_results'

    # Parameters.
    count = 0  # number of sampled frames processed
    frame_interval = 50  # sampling interval in frames
    frame_interval_count = 0
    start_frame = 100

    is_overlap_flag = False  # whether truck and crane/claw currently overlap
    hasin_flag = False  # whether the crane has already entered the truck
    count_num = 0  # delay counter used when the crane leaves the truck
    truck_roi_queue = []
    before_in, after_leave = np.zeros((10, 10)), np.zeros((10, 10))

    # Make output directories (recreate them if they already exist).
    for d in (PATH_TO_TEST_RESULT_DIR, tmp_save_path):
        if os.path.exists(d):
            shutil.rmtree(d)
        os.makedirs(d)

    vc = cv2.VideoCapture(file_path)
    if vc.isOpened():
        ret, frame = vc.read()
    else:
        ret = False

    # Load the frozen model and fetch its input/output nodes.
    image_tensor,tensor_dict,detection_graph=load_pb_and_get_input_output_node(PATH_TO_FROZEN_GRAPH)
    with detection_graph.as_default():
        with tf.Session() as sess:
            while ret:
                ret, frame = vc.read()
                if not ret:  # end of stream
                    break
                if frame_interval_count < start_frame:  # skip the leading frames
                    frame_interval_count += 1
                    continue

                if frame_interval_count % frame_interval == 0:
                    # save_image(save_path, count, frame)
                    t1 = time()
                    # Run detection on this frame.
                    main_truck_rect, crane_rect, claw_rect, image_np, multiplying, res_list = maskrcnn_algorithm(frame, frame_interval_count)

                    if (crane_rect != [] and main_truck_rect) or (claw_rect != [] and main_truck_rect):  # proceed only when a truck plus a crane or claw is present
                        # Scale the boxes back up to the original frame.
                        truck_box, crane_box, claw_box = np.array(main_truck_rect) * multiplying, np.array(crane_rect) * multiplying, np.array(claw_rect) * multiplying
                        truck_box, crane_box, claw_box = truck_box.astype(int), crane_box.astype(int), claw_box.astype(int)

                        roi = frame[truck_box[1]:truck_box[3], truck_box[0]:truck_box[2], :]
                        # Remember the latest truck ROI.
                        truck_roi_queue.append(roi)

                        # Compute the overlap.
                        if len(crane_rect) != 0:
                            is_overlap_flag = is_overlap(main_truck_rect, crane_rect)
                        elif len(claw_rect) != 0:
                            is_overlap_flag = is_overlap(main_truck_rect, claw_rect)  # fixed: the claw branch mistakenly passed crane_rect

                        # Handle the first time the crane enters the truck.
                        if is_overlap_flag == True and hasin_flag == False:
                            hasin_flag = True
                            if len(truck_roi_queue) > 1:
                                before_in = truck_roi_queue[-2]
                            else:
                                before_in = truck_roi_queue[-1]
                            cv2.imshow('before_in', before_in)

                        if is_overlap_flag == False and hasin_flag == True:
                            # Handle the crane leaving the truck; deliberately wait
                            # one extra sampled frame so the crane has fully left.
                            count_num += 1
                            if count_num == 2:  # take the delayed frame
                                hasin_flag = False
                                count_num = 0
                                after_leave = truck_roi_queue[-1]
                                cv2.imshow('after_leave', after_leave)
                                # Locate the changed region with SIFT (currently slow).
                                sift_res = sift_detection(before_in, after_leave)
                                cv2.imshow('sift_res', sift_res)

                                # Crop the changed region (used for grading).
                                x1, y1, x2, y2 = crop(sift_res, res_list)
                                crop_img = before_in[y1:y2, x1:x2]
                                cv2.imshow('crop_img', crop_img)
                                cv2.waitKey(0)


                        image_copy = image_np.copy()
                        if main_truck_rect != []:
                            cv2.rectangle(image_copy, (main_truck_rect[0], main_truck_rect[1]),(main_truck_rect[2], main_truck_rect[3]), color=(0, 0, 255))
                        if crane_rect != []:
                            cv2.rectangle(image_copy, (crane_rect[0], crane_rect[1]),(crane_rect[2], crane_rect[3]), color=(255, 0, 0))
                        if claw_rect != []:
                            cv2.rectangle(image_copy, (claw_rect[0], claw_rect[1]), (claw_rect[2], claw_rect[3]),color=(255, 0, 0))
                        cv2.imshow('image_copy', image_copy)

                        if len(truck_roi_queue) > 5:  # cap the queue length to bound memory usage
                            del truck_roi_queue[0:3]  # keep only the most recent ROIs

                    else:  # required objects not detected
                        pass

                    print('elapsed time:', time() - t1)
                    logging.info("num:" + str(count) + ", frame: " + str(frame_interval_count))
                    count += 1
                    # cv2.waitKey(0)
                frame_interval_count += 1
                cv2.waitKey(1)
            vc.release()

  
