基于图像处理和tensorflow实现GTA5的车辆自动驾驶——第九节获取图像数据
代码已放到码云
作者主要说了一下几点
- 截止到上一节的伪AI,图像处理部分应该是结束了。作者不想深入做图像处理方面,从本节开始他准备使用神经网络训练
- 优化了代码
强烈建议从本节开始新建文件创建py文件,为了和以前的部分分开
制作训练集
- 利用pywin32获取窗口的数据集
- 利用getkeys.py文件获取用户输入的按键:W(前进),A(左转),B(右转),S(后退)
利用pywin32获取窗口的数据集
注:这个是有人在GitHub上传了代码,以前用的PIL库获取图像,但是帧率太低了,作者试了下这个新的代码,发现帧率比PIL库的高,于是采用这个代码了
实现代码见Gitee项目
使用方法,引用该python文件,直接调用方法就行
from grabscreen import grab_screen
screen = np.array(grab_screen(region=(0, 40, 800, 640)))
获取的图像转换为灰度图像(因为灰度图像相比彩色图像数据量少,计算代价较少)
screen = cv2.cvtColor(screen, cv2.COLOR_BGR2GRAY)
获取的图像改变大小,原来为800*600 转换为 80, 60。 这样图像的计算量也变少
screen = cv2.resize(screen, (80, 60))
至此利用pywin32获取窗口的数据集
步骤完成
利用getkeys.py文件获取用户输入的按键:W(前进),A(左转),B(右转),S(后退)
使用方法,引用该python文件,直接调用方法就行
from getkeys import key_check
key = key_check()
# key 为列表,包含用户press的按键(大写)
本次只考虑 前进,左转右转
# [A, W ,D]
outPutDatamat = [0, 0, 0]
if 'A' in key:
outPutDatamat[0] = 1
elif 'W' in key:
outPutDatamat[1] = 1
elif 'D' in key:
outPutDatamat[2] = 1
# 如果按下了`A`键,那么`outPutDatamat=[1, 0, 0]`
把训练集结合起来并保存
注:
- 倒计时完成后会对窗口进行记录
- 当得到2000帧后结束循环,并把200帧的训练集
test.npy
保存下来 - 尽量在有左转右转的地方进行跑,这样左转右转的数据会多些
i = 3
while i != 0:
print("time:", i)
time.sleep(0.5)
i -= 1
training_data = []
last_time = time.time()
paused = False
print('STARTING!!!')
while True:
if not paused:
screen = np.array(grab_screen(region=(0, 40, 800, 640)))
last_time = time.time()
# run a color convert:
screen = cv2.cvtColor(screen, cv2.COLOR_BGR2GRAY)
# resize to something a bit more acceptable for a CNN
screen = cv2.resize(screen, (80, 60))
# get key datamat
output = get_key()
training_data.append([screen, output])
print(len(training_data))
# cv2.imshow('window', cv2.cvtColor(screen, cv2.COLOR_BGR2RGB))
if len(training_data) % 2000== 0:
np.save('test.npy', np.array(training_data), allow_pickle=True)
break
if cv2.waitKey(25) & 0xFF == ord('q'):
cv2.destroyAllWindows()
break
读取训练集,查看是否可以运行
这段代码比较简单,就不解释了
import time
from grabscreen import grab_screen
import cv2
from getkeys import key_check
import numpy as np
train_data = np.load('test.npy', allow_pickle=True)
for data in train_data:
img = data[0]
choice = data[1]
cv2.imshow('test',img)
print(choice)
if cv2.waitKey(25) & 0xFF == ord('q'):
cv2.destroyAllWindows()
break