w_size = 512
h_size = 256
sub_input_label = torch.zeros([input_label.shape[0], input_label.shape[1], h_size, w_size], dtype=torch.float32, device=input_label.device)
w_size is the horizontal side (the width).
h_size is the vertical side (the height).
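
A minimal sketch to check the dimension order, assuming input_label is a 4-D tensor laid out as (batch, channels, height, width). The dummy input_label and its size are made up for illustration; only w_size, h_size and the torch.zeros call come from the snippet above.

```python
import torch

# Hypothetical stand-in for input_label: batch of 2, 3 channels, 512 x 1024.
input_label = torch.zeros(2, 3, 512, 1024)

w_size = 512  # horizontal side (width)
h_size = 256  # vertical side (height)

# Same batch and channel counts as input_label, new spatial size.
sub_input_label = torch.zeros(
    [input_label.shape[0], input_label.shape[1], h_size, w_size],
    dtype=torch.float32,
    device=input_label.device,
)

# Height (h_size) comes before width (w_size) in the shape.
batch, channels, height, width = sub_input_label.shape
print(batch, channels, height, width)
```

Note that h_size is passed before w_size, so the last dimension of sub_input_label is the width.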