(项目)数字仪表识别(定长/不定长)+三色灯分类

记录一下做数字仪表检测项目的过程,会附带部分代码

业务背景:四块仪表,每块表界面是4位(红色)数字,即要检测识别4个4位数字。在检测界面还有三个灯,三个灯都是3中颜色,红色、黄色、绿色。

     我要做的就是实时的检测出4个4位数字具体的数值,并且对3个灯进行分类。

解决思路:首先在摄像头所拍摄到的界面中定位到数字、灯所在的区域,然后进行识别或者分类

解决方法:

  数字识别,有以下解决方法:

    一、使用模板匹配的方法,因为数字都是规范格式的,所以使用模板匹配可以保证准确率,但是速度会比较慢;

    二、在定位到数字区域后,进行字符分割,优点是字符分割后可以使用很简单的神经网络就可以进行数字0-9的分类,缺点是要一个一个数字识别;

    三、直接使用CNN网络进行分类,类似验证码识别,为了能够将验证码图片的文本信息输入到卷积神经网络模型里面去训练,需要将文本信息向量化

    编码,参见https://my.oschina.net/u/876354/blog/3048523这篇博客。优点是使用的CNN网络很简单,缺点是需要大量的训练数据,否则模型预测效果很差;

    四、使用cnn+lstm+ctc这种比较成熟的深度学习在ocr的应用的模型组合。这里我采用的是第四种方法,模型有比较好的鲁棒性,由于ctc的存在,

    将来换成不定长的数字仪表也可以识别。具体原理参见https://my.oschina.net/u/876354/blog/3070699。这里我使用的网络是lstm+ctc,之所以没有使用

    cnn是因为,也使用了“有色灯分类”中基于颜色提取4位数字目标,直接确定了4位数字所在的区域,在训练数据采集时,得到的即是使用该方法摄像头

    采集到的数字图片,而在预测的过程中,摄像头拍摄的区域也会自动分隔出4张小的数字图片送入模型进行预测。基于颜色提取的目标比较稳定,因此没有

    使用cnn进行特征提取。

  有色灯分类,有以下解决方法:

    一、使用深度学习的目标检测方法,有one-stage和two-stage两种,比较经典的是fast-rcnn和ssd、yolo等,但是由于这里只是3种灯的分类,业务

    场景很简单,这里不使用这种方法。

    二、使用opencv中基于颜色提取目标,可以在摄像头具体采集数字仪表和有色灯的场景下,提取有色灯3种颜色的hsv色彩空间的信息,然后使用

    opencv中的一系列API进行目标提取和分类。(inrange得到二值图像,滤波,腐蚀,膨胀,得到梯度,根据梯度找到目标轮廓)。我使用的是第二种方法。

 

代码部分: 

1、基于颜色提取数字所在的矩形框

 

 

注意:读取图片,将RGB装换成hsv色彩空间,基于上图hsv的色彩值,利用opencv下面的api进行目标颜色的提取,从黑到紫都可以提取。核心api是cv2.inRange()

def extrace_object_demo(src):
    img = cv2.imread(src)
    
    # 1.将RGB装换成hsv色彩空间
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)  # # 通道数是 3
    img_binary = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # 通道数是 1
    
    # 2.定义数组,说明你要提取(过滤)的颜色目标,
    # 三通道,所以是三个参数
    # 红色
    lower_hsv_g = np.array([156, 43, 46])
    upper_hsv_g = np.array([180, 255, 255])
    
    # 3.进行过滤,提取,得到二值图像
    mask_red = cv2.inRange(hsv, lower_hsv_g, upper_hsv_g)  # 通道数是 1

    # 合并展示
    res = np.hstack((img_binary, mask_red))
    cv2.imshow("res", res)

    cv2.waitKey(0)  
    cv2.destroyAllWindows()
    
    return mask_red

原图:

 

 

 基于红色得到的二值图:(截图,请忽略图像size改变)

 

 

 2、基于二值图得到颜色目标的梯度

注意:高斯滤波的api采用的滤波器这里选的是3x3的,可以改其他尺寸,建议奇数,同时执行滤波的次数可以改变。

同理,腐蚀、膨胀的滤波器和滤波次数都可以改。

def img_preprocessing(src):
    # 调用基于颜色过滤的函数
    mask_red = extrace_object_demo(src)
    # 高斯滤波,去噪声
    gaussian = cv2.GaussianBlur(mask_red, (3, 3), 1)
    # 腐蚀
    kernel = np.ones((5, 5), np.uint8)
    erosion = cv2.erode(gaussian, kernel, iterations = 1)
    # 膨胀
    dige_dilate = cv2.dilate(erosion, kernel, iterations = 1)
    # (形态学)梯度运算 = 膨胀运算 - 腐蚀运算
    gradient = cv2.morphologyEx(dige_dilate, cv2.MORPH_GRADIENT, kernel)
    return gradient

提取到的梯度长这个样子,有点不平整,没关系,不影响最终的效果:

 

 

 

3、利用opencv的api,基于梯度找到其矩形轮廓

注意:这里num=4可以修改,如果以后要提取的矩形框目标是3个,则改为3,以此类推。

注意调用cv2.rectangle()这个api的时候,传入的img应该是原图,这样就会在原图上画矩形框了。其他的api都是opencv提供好的,学会其用法正确传参就可以。

def get_box(gt, num=4):
    # 对前面得到的梯度,find其轮廓
    contours, _ = cv2.findContours(gradient, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    # 创建一个列表,用来保存矩形框坐标值
    list_box = []
    # 取前num个面积最大的
    cnt = sorted(contours, key=cv2.contourArea, reverse=True)[:num]

    for c in cnt:
        # 得到坐标
        x, y, w, h = cv2.boundingRect(c)
        # 存起来
        list_box.append((x, y, w, h))
        # 将矩形画出来
        draw_img = cv2.rectangle(img, (x, y), (x+w, y+h), (0, 255, 0), 1)
    cv_show("draw_img", draw_img)

    # 返回坐标列表
    return list(set(list_box))  # 有重复,去重

返回的坐标列表一般是这个样子: [(83, 252, 154, 67), (66, 50, 173, 76), (374, 51, 183, 78), (366, 265, 174, 71)]

其效果图:

 

 

 

4、利用上述返回的坐标,可以在原图依次裁剪4张数字表的小图片,进而可以满足后续的训练以及其他工作。

def create_crapimg(raw_img, list_box):
    img = cv2.imread(raw_img)
    # 新建一个列表,用于保存裁剪下来的图片
    list_crap = []

    # 把矩形框对应的目标区域图片裁剪出来
    for i,box in enumerate(list_box):
        x, y, w, h = box
        # 获得裁剪图片
        img_crap = img[y:y+h, x:x+w]
        
        # 修改图片形状
        img_crap = cv2.resize(img_crap, (256, 32), 3)
        
        # 保存裁剪后的图片
        list_crap.append(img_crap)
        
    return list_crap

得到的图片是这个样子的:(文件名后面会说为什么命名成这样)

 

 

 总结:到此为止,关于数字仪表基于颜色提取就做完了,后面就是如何训练和使用模型了。

 

关于模型、训练相关代码:

注意:直接运行下面代码中的train()即可,保证各个路径对即可,注意将每个图片数据命名为:label_随便.jpg,比如数字是1233,则可以命名为1233_随便起一个名字.jpg。

这个因为下面的函数会读取文件名,然后将label转化成稀疏矩阵节约存储空间,方便训练。

模型的代码是tensorflow封装好的,即

# 定义LSTM网络
cell = tf.contrib.rnn.LSTMCell(num_hidden, state_is_tuple=True)  # LSTM cell中的block数量
stack = tf.contrib.rnn.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
outputs, _ = tf.nn.dynamic_rnn(cell, inputs, seq_len, dtype=tf.float32)

这里的LSTMCell、MultiRNNCell是可以修改的,可以改为rnn模块下封装的其他cell看看效果会不会更好,同时num_hidden和num_layers也是可以改的(2的倍数或次方),越大则模型越复杂。

#coding:utf-8
# 基于 lstm ctc 训练识别不定长的文字

import numpy as np
import cv2
import os
import tensorflow as tf
import random
import time
import datetime
# from captcha.image import ImageCaptcha
from PIL import Image, ImageFont, ImageDraw

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# 定义一些常量
# 元数据集
DIGITS = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

# 图片大小
OUTPUT_SHAPE = (32, 256)

# 训练最大轮次
num_epochs = 10000
num_hidden = 128
num_layers = 2
num_classes = len(DIGITS) + 1

# 初始化学习速率
INITIAL_LEARNING_RATE = 1e-3
DECAY_STEPS = 5000
REPORT_STEPS = 100
LEARNING_RATE_DECAY_FACTOR = 0.9  
MOMENTUM = 0.9

BATCHES = 10
BATCH_SIZE = 64
TRAIN_SIZE = BATCHES * BATCH_SIZE

# # 命令行参数
# # 定义model训练的步数 step
# tf.app.flags.DEFINE_integer("max_step", 0, "训练模型的步数")
# # 定义model的路径 load + 名字
# tf.app.flags.DEFINE_string("model_dir", " ", "模型保存的路径+模型名字")

# # 获取上述二者, 在运行的时候指定--->下面的参数要修改对应的FLAGS.max_step和FLAGS.model_dir
# FLAGS = tf.app.flags.FLAGS

# 命令行指令, 一定要写模型名字。。。
# python xx.py --max_step=xx --load="xx+模型名字"

data_dir = './tmp/train_data/'
model_dir = './tmp/train_data_model/'


# 稀疏矩阵转序列
def decode_a_seq(indexes, spars_tensor):
    decoded = []
    for m in indexes:
        str = DIGITS[spars_tensor[1][m]]
        decoded.append(str)
    return decoded

def decode_sparse_tensor(sparse_tensor):
    decoded_indexes = list()
    current_i = 0
    current_seq = []

    for offset, i_and_index in enumerate(sparse_tensor[0]):
        i = i_and_index[0]
        if i != current_i:
            decoded_indexes.append(current_seq)
            current_i = i
            current_seq = list()
        current_seq.append(offset)
    decoded_indexes.append(current_seq)

    result = []
    for index in decoded_indexes:
        result.append(decode_a_seq(index, sparse_tensor))
    return result

# 准确性评估
# 输入:预测结果序列 decoded_list ,目标序列 test_targets
# 返回:准确率
def report_accuracy(decoded_list, test_targets):
    original_list = decode_sparse_tensor(test_targets)
    detected_list = decode_sparse_tensor(decoded_list)

    # 正确数量
    true_numer = 0

    # 预测序列与目标序列的维度不一致,说明有些预测失败,直接返回
    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list), "len(detected_list)", len(detected_list),
              " test and detect length desn't match")
        return

    # 比较预测序列与结果序列是否一致,并统计准确率        
    print("T/F: original(length) <-------> detectcted(length)")
    for idx, number in enumerate(original_list):
        detect_number = detected_list[idx]
        hit = (number == detect_number)
        print(hit, number, "(", len(number), ") <-------> ", detect_number, "(", len(detect_number), ")")
        if hit:
            true_numer = true_numer + 1
    accuracy = true_numer * 1.0 / len(original_list)
    print("Test Accuracy:", accuracy)    
    return accuracy

# 转化一个序列列表为稀疏矩阵
def sparse_tuple_from(sequences, dtype=np.int32):
    indices = []
    values = []

    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)

    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)

    return indices, values, shape

# 将文件和标签读到数组
def get_file_text_array():
    file_name_array=[]
    text_array=[]

    for parent, dirnames, filenames in os.walk(data_dir):
        file_name_array=filenames

    for f in file_name_array:
        text = f.split('_')[0]
        text_array.append(text)

    return file_name_array,text_array

# 生成一个训练batch
def get_next_batch(file_name_array, text_array, batch_size=128):
    inputs = np.zeros([batch_size, OUTPUT_SHAPE[1], OUTPUT_SHAPE[0]])
    codes = []

    # 获取训练样本
    for i in range(batch_size):
        index = random.randint(0, len(file_name_array) - 1)
        image = cv2.imread(data_dir + file_name_array[index])
        image = cv2.resize(image, (OUTPUT_SHAPE[1], OUTPUT_SHAPE[0]), 3)
        image = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY)
        text = text_array[index]
        inputs[i, :] = np.transpose(image.reshape((OUTPUT_SHAPE[0], OUTPUT_SHAPE[1])))
        codes.append(list(text))

    targets = [np.asarray(i) for i in codes]
    sparse_targets = sparse_tuple_from(targets)
    seq_len = np.ones(inputs.shape[0]) * OUTPUT_SHAPE[1]

    return inputs, sparse_targets, seq_len

def get_train_model():
    inputs = tf.placeholder(tf.float32, [None, None, OUTPUT_SHAPE[0]])  # old
    targets = tf.sparse_placeholder(tf.int32)
    seq_len = tf.placeholder(tf.int32, [None])

    # 定义LSTM网络
    cell = tf.contrib.rnn.LSTMCell(num_hidden, state_is_tuple=True)  # LSTM cell中的block数量
    stack = tf.contrib.rnn.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
    outputs, _ = tf.nn.dynamic_rnn(cell, inputs, seq_len, dtype=tf.float32)

    shape = tf.shape(inputs)
    batch_s, max_timesteps = shape[0], shape[1]
    outputs = tf.reshape(outputs, [-1, num_hidden])
    W = tf.Variable(tf.truncated_normal([num_hidden,
                                         num_classes],
                                        stddev=0.1), name="W")
    b = tf.Variable(tf.constant(0., shape=[num_classes]), name="b")
    logits = tf.matmul(outputs, W) + b
    logits = tf.reshape(logits, [batch_s, -1, num_classes])

    # 转置矩阵
    logits = tf.transpose(logits, (1, 0, 2))

    return logits, inputs, targets, seq_len, W, b

def train():
    # with tf.variable_scope("train"):
    # 获取训练样本数据
    file_name_array, text_array = get_file_text_array()

    # 定义学习率
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                               global_step,
                                               DECAY_STEPS,
                                               LEARNING_RATE_DECAY_FACTOR,
                                               staircase=True)
    # 获取网络结构
    logits, inputs, targets, seq_len, W, b = get_train_model()

    # 设置损失函数
    loss = tf.nn.ctc_loss(labels=targets, inputs=logits, sequence_length=seq_len)
    cost = tf.reduce_mean(loss)

    # 设置优化器
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, global_step=global_step)
    decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len, merge_repeated=False)
    acc = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), targets))

    init = tf.global_variables_initializer()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session() as session:
        session.run(init)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
        # saver.restore(session, tf.train.latest_checkpoint(model_dir))
        for curr_epoch in range(num_epochs):
            train_cost = 0
            train_ler = 0
            # todo
            for batch in range(BATCHES):
            # for batch in range(FLAGS.max_step):
                # 训练模型
                train_inputs, train_targets, train_seq_len = get_next_batch(file_name_array, text_array, BATCH_SIZE)
                feed = {inputs: train_inputs, targets: train_targets, seq_len: train_seq_len}
                b_loss, b_targets, b_logits, b_seq_len, b_cost, steps, _ = session.run(
                    [loss, targets, logits, seq_len, cost, global_step, optimizer], feed)

                # 评估模型
                if steps > 0 and steps % REPORT_STEPS == 0:
                    test_inputs, test_targets, test_seq_len = get_next_batch(file_name_array, text_array, BATCH_SIZE)
                    test_feed = {inputs: test_inputs,targets: test_targets,seq_len: test_seq_len}
                    dd, log_probs, accuracy = session.run([decoded[0], log_prob, acc], test_feed)
                    report_accuracy(dd, test_targets)

                    # 保存识别模型
                    save_path = saver.save(session, model_dir + "lstm_ctc_model.ctpk", global_step=steps)
                    # save_path = saver.save(session, FLAGS.model_dir, global_step=steps)

                c = b_cost
                train_cost += c * BATCH_SIZE

            train_cost /= TRAIN_SIZE
            # 计算 loss
            train_inputs, train_targets, train_seq_len = get_next_batch(file_name_array, text_array, BATCH_SIZE)
            val_feed = {inputs: train_inputs,targets: train_targets,seq_len: train_seq_len}
            val_cost, val_ler, lr, steps = session.run([cost, acc, learning_rate, global_step], feed_dict=val_feed)

            # log = "{} Epoch {}/{}, steps = {}, train_cost = {:.3f}, val_cost = {:.3f}"
            log = "{} Epoch {}, steps = {}, train_cost = {:.3f}, val_cost = {:.3f}"
            print(log.format(curr_epoch + 1, num_epochs, steps, train_cost, val_cost))

 

预测的代码:

这里之所以把加载模型单独封装成一个函数,是为了提高速度,获取网络结构值只加载一次。因为后面是在摄像头采集图像,一帧一帧进行预测,因此多次加载模型

会很浪费时间,同时在预测的主循环中tensorflow的session也只加载一次,可以大大提高时间。(session很占用资源)

# LSTM+CTC 文字识别能力封装
# 加载模型
def load_model():
    # 获取网络结构
    tf.reset_default_graph()
    logits, inputs, targets, seq_len, W, b = get_train_model()
    decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len, merge_repeated=False)
    saver = tf.train.Saver()
    # sess = tf.Session()
    # # 加载模型
    # saver.restore(sess, tf.train.latest_checkpoint(model_dir))

    return saver, inputs, seq_len, decoded, log_prob

# 输入:图片
# 输出:识别结果文字
def predict(images_path, saver, inputs, seq_len, decoded, log_prob, sess):
    # 加载模型
    # saver.restore(sess, tf.train.latest_checkpoint(model_dir))
    # 图像预处理
    result_dict = {}
    for i, image in images_path.items():
        image = cv2.resize(image, (OUTPUT_SHAPE[1], OUTPUT_SHAPE[0]), 3)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        pred_inputs = np.zeros([1, OUTPUT_SHAPE[1], OUTPUT_SHAPE[0]])
        pred_inputs[0, :] = np.transpose(image.reshape((OUTPUT_SHAPE[0], OUTPUT_SHAPE[1])))
        pred_seq_len = np.ones(1) * OUTPUT_SHAPE[1]
        # 模型预测
        pred_feed = {inputs: pred_inputs, seq_len: pred_seq_len}
        dd, log_probs = sess.run([decoded[0], log_prob], pred_feed)
        # 识别结果转换
        detected_list = decode_sparse_tensor(dd)[0]
        detected_text = ''
        for d in detected_list:
            detected_text = detected_text + d
        result_dict[i+1] = detected_text
    return result_dict  # 返回结果字典

 

基于PyQt5的界面展示:用一个摄像头不断采集图像,将采集到的图像和预测结果传给基于PyQt5搭建的界面,代码如下:

demo.py(这是PyQt5的界面代码)

# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'demo.ui'
#
# Created by: PyQt5 UI code generator 5.9.2
#
# WARNING! All changes made in this file will be lost!

from PyQt5 import QtCore, QtGui, QtWidgets

class Ui_mainWindow(object):
    def setupUi(self, mainWindow):
        mainWindow.setObjectName("mainWindow")
        mainWindow.resize(1920, 1080)
        palette1 = QtGui.QPalette()
        palette1.setBrush(self.backgroundRole(), QtGui.QBrush(QtGui.QPixmap('images/bg2.jpg')))
        mainWindow.setPalette(palette1)


        self.centralwidget = QtWidgets.QWidget(mainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.graphicsView = PlotWidget(self.centralwidget)
        self.graphicsView.setGeometry(QtCore.QRect(160, 340, 250, 150))
        brush = QtGui.QBrush(QtGui.QColor(255, 255, 255, 0))
        brush.setStyle(QtCore.Qt.SolidPattern)
        self.graphicsView.setBackgroundBrush(brush)
        self.graphicsView.setObjectName("graphicsView")
        self.pushButton = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton.setGeometry(QtCore.QRect(170, 90, 81, 41))
        self.pushButton.setObjectName("pushButton")
        # palette = QtGui.QPalette()
        # brush = QtGui.QBrush(QtGui.QColor(255, 255, 0))
        # brush.setStyle(QtCore.Qt.SolidPattern)
        # palette.setBrush(QtGui.QPalette.Active, QtGui.QPalette.ButtonText, brush)
        # brush = QtGui.QBrush(QtGui.QColor(255, 255, 0))
        # brush.setStyle(QtCore.Qt.SolidPattern)
        # palette.setBrush(QtGui.QPalette.Inactive, QtGui.QPalette.ButtonText, brush)
        # brush = QtGui.QBrush(QtGui.QColor(120, 120, 120))
        # brush.setStyle(QtCore.Qt.SolidPattern)
        # palette.setBrush(QtGui.QPalette.Disabled, QtGui.QPalette.ButtonText, brush)
        # self.pushButton.setPalette(palette)

        self.graphicsView_2 = PlotWidget(self.centralwidget)
        self.graphicsView_2.setGeometry(QtCore.QRect(1460, 340, 250, 150))
        self.graphicsView_2.setObjectName("graphicsView_2")
        self.graphicsView_3 = PlotWidget(self.centralwidget)
        self.graphicsView_3.setGeometry(QtCore.QRect(160, 660, 250, 150))
        self.graphicsView_3.setObjectName("graphicsView_3")
        self.graphicsView_4 = PlotWidget(self.centralwidget)
        self.graphicsView_4.setGeometry(QtCore.QRect(1460, 660, 251, 151))
        self.graphicsView_4.setObjectName("graphicsView_4")
        self.imageLabel = QtWidgets.QLabel(self.centralwidget)
        self.imageLabel.setGeometry(QtCore.QRect(615, 340, 640, 480))
        self.imageLabel.setAutoFillBackground(False)
        self.imageLabel.setFrameShape(QtWidgets.QFrame.Box)
        self.imageLabel.setText("")
        self.imageLabel.setObjectName("imageLabel")
        self.label = QtWidgets.QLabel(self.centralwidget)
        self.label.setGeometry(QtCore.QRect(660, 130, 891, 111))
        font = QtGui.QFont()
        font.setFamily("Agency FB")
        font.setPointSize(20)
        font.setBold(True)
        font.setWeight(75)
        self.label.setFont(font)
        self.label.setObjectName("label")
        self.layoutWidget = QtWidgets.QWidget(self.centralwidget)
        self.layoutWidget.setGeometry(QtCore.QRect(170, 180, 231, 131))
        self.layoutWidget.setObjectName("layoutWidget")
        self.gridLayout = QtWidgets.QGridLayout(self.layoutWidget)
        self.gridLayout.setContentsMargins(0, 0, 0, 0)
        self.gridLayout.setObjectName("gridLayout")
        self.label_7 = QtWidgets.QLabel(self.layoutWidget)
        self.label_7.setText("")
        self.label_7.setObjectName("label_7")
        self.gridLayout.addWidget(self.label_7, 2, 1, 1, 1)
        self.label_2 = QtWidgets.QLabel(self.layoutWidget)
        self.label_2.setObjectName("label_2")
        self.gridLayout.addWidget(self.label_2, 0, 0, 1, 1)
        self.label_5 = QtWidgets.QLabel(self.layoutWidget)
        self.label_5.setText("")
        self.label_5.setObjectName("label_5")
        self.gridLayout.addWidget(self.label_5, 0, 1, 1, 1)
        self.label_4 = QtWidgets.QLabel(self.layoutWidget)
        self.label_4.setObjectName("label_4")
        self.gridLayout.addWidget(self.label_4, 2, 0, 1, 1)
        self.label_6 = QtWidgets.QLabel(self.layoutWidget)
        self.label_6.setText("")
        self.label_6.setObjectName("label_6")
        self.gridLayout.addWidget(self.label_6, 1, 1, 1, 1)
        self.label_3 = QtWidgets.QLabel(self.layoutWidget)
        self.label_3.setObjectName("label_3")
        self.gridLayout.addWidget(self.label_3, 1, 0, 1, 1)
        mainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(mainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1920, 30))
        self.menubar.setObjectName("menubar")
        mainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(mainWindow)
        self.statusbar.setObjectName("statusbar")
        mainWindow.setStatusBar(self.statusbar)

        self.retranslateUi(mainWindow)
        QtCore.QMetaObject.connectSlotsByName(mainWindow)

        self.label_2.setVisible(False)
        self.label_3.setVisible(False)
        self.label_4.setVisible(False)
        self.label_5.setVisible(False)
        self.label_6.setVisible(False)
        self.label_7.setVisible(False)

    def retranslateUi(self, mainWindow):
        _translate = QtCore.QCoreApplication.translate
        mainWindow.setWindowTitle(_translate("mainWindow", "数字仪表智能采集演示系统"))
        self.pushButton.setText(_translate("mainWindow", "开始采集"))
        self.label.setText(_translate("mainWindow", "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
"<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
"p, li { white-space: pre-wrap; }\n"
"</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n"
"<p style=\" margin-top:12px; margin-bottom:12px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:36pt; font-weight:600; color:#FFFF00;\">数字仪表智能采集演示系统</span></p></body></html>"))
        self.label_2.setText(_translate("mainWindow", "<html><head/><body><p><span style=\" color:#ffffff;\">1号灯</span></p></body></html>"))
        self.label_4.setText(_translate("mainWindow", "<html><head/><body><p><span style=\" color:#ffffff;\">3号灯</span></p></body></html>"))
        self.label_3.setText(_translate("mainWindow", "<html><head/><body><p><span style=\" color:#ffffff;\">2号灯</span></p></body></html>"))

from pyqtgraph import PlotWidget

 

最终的运行代码:

from PyQt5.QtWidgets import QApplication, QMainWindow
from PyQt5.QtCore import pyqtSignal, QThread
from PyQt5.QtGui import QImage, QPixmap
import sys
import array
import cv2
from demo import Ui_mainWindow
import numpy as np
from imutils.video import WebcamVideoStream
import readvc_box_03 as read_box
from PIL import Image, ImageDraw, ImageFont
import tensorflow as tf
from main import load_model, predict
# 创建一个图标
from PyQt5.QtGui import QIcon
import random
import time

model_dir = './tmp/train_data_model_13/'

class MeterData(object):
    def __init__(self):
        self.saved_points_num = 40
        self.meter_num = 4
        self.meters = [array.array('d') for i in range(self.meter_num)]
        self.cur_numbers = [0. for i in range(self.meter_num)]

    def add_meters_number(self, numbers):
        if len(numbers) == self.meter_num:
            self.cur_numbers = numbers
            for i in range(self.meter_num):
                self.add_meter_number(self.cur_numbers[i], i)

    def add_meter_number(self, number, id):
        if len(self.meters[id]) < self.saved_points_num:
            self.meters[id].append(number)
        else:
            self.meters[id][:-1] = self.meters[id][1:]
            self.meters[id][-1] = number

class ImageData(object):
    def __init__(self):
        self.pixmap = None

    def add_image(self, img):
        rgb_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        rgb_img = cv2.resize(rgb_img, (640, 480), interpolation=cv2.INTER_CUBIC)
        q_img = self.get_qimage(rgb_img)
        self.pixmap = QPixmap.fromImage(q_img)

    def get_qimage(self, image):
        height, width, colors = image.shape
        bytesPerLine = 3 * width
        image = QImage(image.data, width, height, bytesPerLine, QImage.Format_RGB888)
        image = image.rgbSwapped()
        return image

class LightData(object):
    def __init__(self):
        self.one = ''
        self.two = ''
        self.three = ''

    def set(self, red, orange, green):
        '''
        传入三组颜色字典:
        {'r': [(133, 138, 28, 26)]}
        {'o': [(183, 139, 29, 28)]}
        {'g': [(235, 140, 27, 29)]}
        '''
        self.one = self.get(red, orange, green)[0]
        self.two = self.get(red, orange, green)[1]
        self.three = self.get(red, orange, green)[2]

    def get(self, red, orange, green):
        a = sorted([red, orange, green], key=lambda x: list(x.values())[0][0][0])
        color_list = []
        for item in a:
            color_list.append(list(item.keys())[0])
        return color_list

meterData = MeterData()
imageData = ImageData()
lightData = LightData()

class workThread(QThread):
    trigger = pyqtSignal()
    def __init__(self):
        super(workThread, self).__init__()
        
    def run(self):
        saver, inputs, seq_len, decoded, log_prob = load_model()
        sess = tf.Session()
        saver.restore(sess, tf.train.latest_checkpoint(model_dir))
        print("[INFO] camera sensor warming up...")
        vs = WebcamVideoStream(src=0).start()
        print('camera ok')
        while True:
            cap = cv2.VideoCapture(0)

            cap.set(cv2.CAP_PROP_FOCUS, 160)
            cap.set(cv2.CAP_PROP_EXPOSURE , -8)
            cap.set(cv2.CAP_PROP_BACKLIGHT , 0)
            #########读取摄像头并且得到处理结果
            frame = vs.read()
            # todo read_box在这里展开写
            # 读取视频的每一帧, 返回其梯度
            gradient = read_box.img_preprocessing(frame)
            # print(gradient)
            # 得到每四张小图片的坐标位置
            list_box = read_box.get_box(gradient, num=4)
            # 根据坐标位置画出矩形框
            for i, bbx in enumerate(list_box):
                x, y, w, h = bbx
                draw_img = cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 1)
            # 根据矩形框,将四张小图片从原图裁剪下来,得到图片字典
            # lb是排序后的坐标位置

            try:
                dict_crap, lb = read_box.create_crapimg(frame, list_box)
            except:
                import traceback
                traceback.print_exc()
            if not lb:
                continue
            
            # 进行预测,返回预测数字
            # sess作为变量,只初始化一次
            t1 = time.time()
            predict_img_dict = predict(dict_crap, saver, inputs, seq_len, decoded, log_prob, sess)
            t2 = time.time()
            print(t2 - t1)
            # todo 预测数字出现非四位的,将错误图片保存到本地
            # if len(predict_img_dict[1]) != 4:
            #     image = dict_crap[1 - 1]
            #     predict_img_dict = predict(dict_crap, saver, inputs, seq_len, decoded, log_prob)
            #     name = predict_img_dict[1]
            #     cv2.imwrite('./error_imgs/0/' + str(name) + '_' + str(random.randint(0, 1000)) + '.jpg', image)
            # if len(predict_img_dict[2]) != 4:
            #     image = dict_crap[2 - 1]
            #     predict_img_dict = predict(dict_crap, saver, inputs, seq_len, decoded, log_prob)
            #     name = predict_img_dict[2]
            #     cv2.imwrite('./error_imgs/1/' + str(name) + '_' + str(random.randint(0, 1000)) + '.jpg', image)
            # if len(predict_img_dict[3]) != 4:
            #     image = dict_crap[3 - 1]
            #     predict_img_dict = predict(dict_crap, saver, inputs, seq_len, decoded, log_prob)
            #     name = predict_img_dict[3]
            #     cv2.imwrite('./error_imgs/2/' + str(name) + '_' + str(random.randint(0, 1000)) + '.jpg', image)
            # if len(predict_img_dict[4]) != 4:
            #     image = dict_crap[4 - 1]
            #     predict_img_dict = predict(dict_crap, saver, inputs, seq_len, decoded, log_prob)
            #     name = predict_img_dict[4]
            #     cv2.imwrite('./error_imgs/3/' + str(name) + '_' + str(random.randint(0, 1000)) + '.jpg', image)
            if predict_img_dict is not None:
                num_list = [int(num) if num else int(8888) for num in predict_img_dict.values()]
                # print(num_list)
                if len(num_list) >= 3:
                    num_list[1], num_list[2] = num_list[2], num_list[1]
                    if len(str(num_list[0])) != 4 or len(str(num_list[1])) != 4 or len(str(num_list[2])) != 4 or len(
                            str(num_list[3])) != 4:
                        continue

                    for idx, box in enumerate(lb):
                        # 将预测数字显示在图片上
                        showimg = cv2.putText(draw_img, str(predict_img_dict[idx + 1]), (int((box[0])), int(box[1] - 10)),
                                              cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 1, cv2.LINE_AA)
            else:
                pass

            ##########
            global meterData, imageData, lightData
            # todo 红绿灯三色
            # get_color(frame)
            '''
            g [(237, 85, 31, 31)]
            r [(291, 85, 31, 31)]
            o [(345, 86, 32, 31)]
            '''
            # 画框,并得到坐标和颜色的三组函数
            list_box_red = get_color_red(frame)
            list_box_orange = get_color_orange(frame)
            list_box_green = get_color_green(frame)
            if list_box_red["红色"] != [] and list_box_orange["黄色"] != [] and list_box_green["绿色"] != []:
                # 不为空,再进行颜色判断
                lightData.set(list_box_red, list_box_orange, list_box_green)

            # todo 写中文
            # 加一个总的判断,如果没有读入识别的数字和颜色,则显示摄像头拍到的任何内容
            if draw_img is not None and frame is not None and gradient is not None \
                    and list_box is not None and showimg is not None:
                if list_box_red["红色"] != [] and list_box_orange["黄色"] != [] and list_box_green["绿色"] != []:
                    red_x, red_y = list_box_red["红色"][0][0], list_box_red["红色"][0][1]
                    orange_x, orange_y = list_box_orange["黄色"][0][0], list_box_orange["黄色"][0][1]
                    green_x, green_y = list_box_green["绿色"][0][0], list_box_green["绿色"][0][1]

                    img_PIL = Image.fromarray(cv2.cvtColor(showimg, cv2.COLOR_BGR2RGB))  # 图像从OpenCV格式转换成PIL格式
                    font = ImageFont.truetype('font/simsun.ttc', 20)  # 40为字体大小,根据需要调整
                    draw = ImageDraw.Draw(img_PIL)
                    draw.text((red_x, red_y-20), '', font=font, fill=(255, 0, 0))
                    draw.text((orange_x, orange_y-20), '', font=font, fill=(255, 255, 0))
                    draw.text((green_x, green_y-20), '绿', font=font, fill=(0, 255, 0))
                    frame = cv2.cvtColor(np.asarray(img_PIL), cv2.COLOR_RGB2BGR)  # 转换回OpenCV格式
                # 如果三色灯识别失败
                else:
                    frame = showimg
            else:
                frame = frame
            #
            values = num_list
            image = frame
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            meterData.add_meters_number(values)
            imageData.add_image(image)
            # todo 三色灯
            ###
            # lightData.set(list_box_red, list_box_orange, list_box_green)
            self.trigger.emit()
        sess.close()

class MeterMainWindow(QMainWindow, Ui_mainWindow):
    updateSignal = pyqtSignal()
    def __init__(self, parent=None):
        super(MeterMainWindow, self).__init__(parent)
        self.setupUi(self)
        self.initUi()

    def initUi(self):
        self.curves = [self.graphicsView.plot(pen='y'), self.graphicsView_2.plot(pen='y'),
                       self.graphicsView_3.plot(pen='y'), self.graphicsView_4.plot(pen='y')]
        # self.numbers = [self.lcdNumber, self.lcdNumber_2, self.lcdNumber_3, self.lcdNumber_4]
        self.workThread = workThread()
        self.pushButton.clicked.connect(self.start)

    def start(self):
        self.workThread.start()
        self.workThread.trigger.connect(self.update)

    def update(self):
        # todo 更新检测结果
        global meterData, imageData, lightData
        self.imageLabel.setPixmap(imageData.pixmap)
        # todo
        self.label_5.setText(lightData.one)
        self.label_6.setText(lightData.two)
        self.label_7.setText(lightData.three)
        for i, curve in enumerate(self.curves):
            curve.setData(meterData.meters[i])

def get_color_red(img):
    # 1.将RGB装换成hsv色彩空间
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)  # # 通道数是 3
    # 2.定义数组,说明你要提取(过滤)的颜色目标, 红色,橙色,绿色
    # lower_hsv = np.array([0, 69, 194])
    # upper_hsv = np.array([24, 141, 255])
    lower_hsv = np.array([0, 84, 250])
    upper_hsv = np.array([25, 111, 255])
    # 3.进行过滤,提取,得到二值图像
    mask_ = cv2.inRange(hsv, lower_hsv, upper_hsv)  # 通道数是 1
    # 高斯滤波,去噪声
    gaussian = cv2.GaussianBlur(mask_, (3, 3), 1)
    # 腐蚀
    kernel = np.ones((3, 3), np.uint8)
    erosion = cv2.erode(gaussian, kernel, iterations=1)
    # 膨胀
    dige_dilate = cv2.dilate(erosion, kernel, iterations=4)
    # (形态学)梯度运算 = 膨胀运算 - 腐蚀运算
    gradient = cv2.morphologyEx(dige_dilate, cv2.MORPH_GRADIENT, kernel)

    # 对前面得到的梯度,find其轮廓
    contours, _ = cv2.findContours(gradient, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    # 创建一个列表,用来保存矩形框坐标值
    list_box = []
    # 取前num个面积最大的
    cnt = sorted(contours, key=cv2.contourArea, reverse=True)[:1]
    for c in cnt:
        # 得到坐标
        x, y, w, h = cv2.boundingRect(c)
        # 存起来
        list_box.append((x, y, w, h))
        # 将矩形画出来
        draw_img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 1)
    list_box = list(set(list_box))
    return {"红色": list_box}

def get_color_orange(img):
    # 1.将RGB装换成hsv色彩空间
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)  # # 通道数是 3
    # 2.定义数组,说明你要提取(过滤)的颜色目标, 红色,橙色,绿色
    lower_hsv = np.array([15, 127, 147])
    upper_hsv = np.array([30, 180, 255])
    # 3.进行过滤,提取,得到二值图像
    mask_ = cv2.inRange(hsv, lower_hsv, upper_hsv)  # 通道数是 1
    # 高斯滤波,去噪声
    gaussian = cv2.GaussianBlur(mask_, (3, 3), 1)
    # 腐蚀
    kernel = np.ones((3, 3), np.uint8)
    erosion = cv2.erode(gaussian, kernel, iterations=1)
    # 膨胀
    dige_dilate = cv2.dilate(erosion, kernel, iterations=4)
    # (形态学)梯度运算 = 膨胀运算 - 腐蚀运算
    gradient = cv2.morphologyEx(dige_dilate, cv2.MORPH_GRADIENT, kernel)

    # 对前面得到的梯度,find其轮廓
    contours, _ = cv2.findContours(gradient, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    # 创建一个列表,用来保存矩形框坐标值
    list_box = []
    # 取前num个面积最大的
    cnt = sorted(contours, key=cv2.contourArea, reverse=True)[:1]
    for c in cnt:
        # 得到坐标
        x, y, w, h = cv2.boundingRect(c)
        # 存起来
        list_box.append((x, y, w, h))
        # 将矩形画出来
        draw_img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 1)
    list_box = list(set(list_box))
    return {"黄色": list_box}

def get_color_green(img):
    # 1.将RGB装换成hsv色彩空间
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)  # # 通道数是 3
    # 2.定义数组,说明你要提取(过滤)的颜色目标, 红色,橙色,绿色
    lower_hsv = np.array([72, 145, 149])
    upper_hsv = np.array([89, 255, 255])
    # 3.进行过滤,提取,得到二值图像
    mask_ = cv2.inRange(hsv, lower_hsv, upper_hsv)  # 通道数是 1
    # 高斯滤波,去噪声
    gaussian = cv2.GaussianBlur(mask_, (3, 3), 1)
    # 腐蚀
    kernel = np.ones((3, 3), np.uint8)
    erosion = cv2.erode(gaussian, kernel, iterations=1)
    # 膨胀
    dige_dilate = cv2.dilate(erosion, kernel, iterations=4)
    # (形态学)梯度运算 = 膨胀运算 - 腐蚀运算
    gradient = cv2.morphologyEx(dige_dilate, cv2.MORPH_GRADIENT, kernel)

    # 对前面得到的梯度,find其轮廓
    contours, _ = cv2.findContours(gradient, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    # 创建一个列表,用来保存矩形框坐标值
    list_box = []
    # 取前num个面积最大的
    cnt = sorted(contours, key=cv2.contourArea, reverse=True)[:1]
    for c in cnt:
        # 得到坐标
        x, y, w, h = cv2.boundingRect(c)
        # 存起来
        list_box.append((x, y, w, h))
        # 将矩形画出来
        draw_img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 1)
    list_box = list(set(list_box))
    return {"绿色": list_box}

def main():
    app = QApplication(sys.argv)
    # 给窗口设置一个图标
    app.setWindowIcon(QIcon('./images/mainimg.ico'))
    # 创建该类
    main_Window = MeterMainWindow()
    main_Window.show()
    sys.exit(app.exec_())

if __name__ == '__main__':
    main()

注意:这里的def get_color_red(img): 写了相似的三个函数,思想和提取红色的数字是一样的,做的工作就是对三色灯进行分类。

 

 

后续的优化思路:可以在lstm前加cnn,进行图像的特征提取,将提取后的特征序列化后传入lstm。

 

 

posted @ 2019-11-18 11:49  胖白白  阅读(2763)  评论(0编辑  收藏  举报