U2-Net prediction script
Supports both single-image detection and video detection.
import os
import time
import subprocess

import cv2
import numpy as np
import torch
from torchvision import transforms

from src import u2net_full

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

def time_synchronized():
    # Wait for all queued CUDA kernels to finish before reading the clock
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

# Query GPU name and memory usage via nvidia-smi
def get_gpu_info():
    try:
        cmd_out = subprocess.check_output(
            'nvidia-smi --query-gpu=name,memory.used,memory.total --format=csv,noheader',
            shell=True)
        gpu_info = cmd_out.decode().strip().split('\n')
        gpu_info = [info.split(', ') for info in gpu_info]
        return gpu_info
    except subprocess.CalledProcessError as e:
        print("Error while invoking nvidia-smi:", e)
        return None

# Print GPU model and memory usage
def print_gpu_usage():
    gpu_info = get_gpu_info()
    if gpu_info:
        total_memory = 0
        used_memory = 0
        # Aggregate used/total memory over all GPUs reported by nvidia-smi
        for name, used, total in gpu_info:
            used_memory += int(used.strip().split()[0])
            total_memory += int(total.strip().split()[0])
        memory_usage_percent = round(used_memory / total_memory * 100, 2)
        # name/used/total refer to the last GPU listed (the only one on a single-GPU machine)
        print(f"GPU: {name.strip()}, Memory used: {used.strip()}, Memory total: {total.strip()}"
              f", Memory usage: {memory_usage_percent}%")

# Blend the original image with the segmentation result
def Image_Blend(src, res):
    # Recolour every segmented (non-black) pixel of the mask image
    mask = ~(res == [0, 0, 0]).all(axis=2)
    res[mask] = [0, 0, 255]
    # Weighted blend of the original image and the coloured mask
    img = cv2.addWeighted(src, 0.8, res, 0.2, 0, dtype=cv2.CV_8UC3)
    return img

# Describe the CUDA devices visible to PyTorch
def gpu_info() -> str:
    info = ''
    for id in range(torch.cuda.device_count()):
        p = torch.cuda.get_device_properties(id)
        info += f'CUDA:{id} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n'
    return info[:-1]

# Single-image prediction
def pic_predict(threshold, device, data_transform, origin_img, model):
    h, w = origin_img.shape[:2]
    img = data_transform(origin_img)
    img = torch.unsqueeze(img, 0).to(device)  # [C, H, W] -> [1, C, H, W]

    with torch.no_grad():
        # Warm up the model with a dummy forward pass
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        # Inference
        pred = model(img)
        pred = torch.squeeze(pred).to("cpu").numpy()  # [1, 1, H, W] -> [H, W]

    pred = cv2.resize(pred, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
    pred_mask = np.where(pred > threshold, 1, 0).astype(np.uint8)

    origin_img = np.array(origin_img, dtype=np.uint8)
    seg_img = origin_img * pred_mask[..., None]
    img_res = Image_Blend(origin_img, seg_img)
    cv2.imwrite("result/pred_result11.png", cv2.cvtColor(img_res.astype(np.uint8), cv2.COLOR_RGB2BGR))

# Per-frame prediction for video detection
def video_pre(threshold, device, data_transform, origin_img, model):
    h, w = origin_img.shape[:2]
    img = data_transform(origin_img)
    img = torch.unsqueeze(img, 0).to(device)  # [C, H, W] -> [1, C, H, W]

    with torch.no_grad():
        # Warm up the model with a dummy forward pass
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        # Inference
        pred = model(img)
        # Report GPU memory usage for this frame
        print_gpu_usage()
        pred = torch.squeeze(pred).to("cpu").numpy()  # [1, 1, H, W] -> [H, W]

    pred = cv2.resize(pred, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
    pred_mask = np.where(pred > threshold, 1, 0).astype(np.uint8)

    origin_img = np.array(origin_img, dtype=np.uint8)
    seg_img = origin_img * pred_mask[..., None]
    img_res = Image_Blend(origin_img, seg_img)
    return img_res

def main():
    weights_path = "model_best.pth"
    img_path = "test/video.mp4"
    threshold = 0.5

    # Make sure the input file exists
    assert os.path.exists(img_path), f"image file {img_path} does not exist."
    # Make sure the output directory exists
    os.makedirs("result", exist_ok=True)

    # Print GPU information if CUDA is available
    if torch.cuda.is_available():
        print(gpu_info())
    # Select the device: GPU if available, otherwise CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(320),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])

    # Load the model and its weights
    model = u2net_full()
    weights = torch.load(weights_path, map_location='cpu')
    if "model" in weights:
        model.load_state_dict(weights["model"])
    else:
        model.load_state_dict(weights)
    model.to(device)
    model.eval()

    ext = os.path.splitext(img_path)[-1]
    if ext == ".mp4":
        print("Reading video input")
        # Open the video file
        cap = cv2.VideoCapture(img_path)
        i = 0
        # Frame size of the input video (as integers)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Writer for the output video: codec, frame rate and frame size
        out = cv2.VideoWriter('result/OutPut2.avi', cv2.VideoWriter_fourcc(*'FFV1'), 30,
                              (frame_width, frame_height))
        while cap.isOpened():
            # Grab the next frame; stop when the video ends
            ret, frame = cap.read()
            if not ret:
                break
            i += 1
            start_time = time.time()  # start of per-frame processing
            img_res = video_pre(threshold, device, data_transform, frame, model)
            # Write the processed frame to the output file
            out.write(img_res)
            end_time = time.time()
            cost_time = end_time - start_time
            print("Frame {} took {:.8f}s.".format(i, cost_time))
        cap.release()
        out.release()
        cv2.destroyAllWindows()
    elif ext == ".jpg":
        print("Reading image input")
        start_time = time.time()  # start of image processing
        origin_img = cv2.cvtColor(cv2.imread(img_path, flags=cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        pic_predict(threshold, device, data_transform, origin_img, model)
        end_time = time.time()
        cost_time = end_time - start_time
        print("Image detection took {:.8f}s.".format(cost_time))
    else:
        print("Please provide a .jpg image or a .mp4 video.")

if __name__ == '__main__':
    main()
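
A minimal way to try the script, assuming it is saved as predict.py (the file name is an assumption) at the repository root, next to src/u2net_full and the trained weights model_best.pth referenced in main():

# Video detection: img_path points at test/video.mp4, the annotated video is written to result/OutPut2.avi
python predict.py

# Single-image detection: change img_path in main() to a .jpg file;
# the blended result is written to result/pred_result11.png
python predict.py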