读取标签数量多于视频数据的数据集时,跳过没有对应视频数据的项
应该有比我方法好的思路,希望可以讨论。
(更新)我觉得那个 while True 可以去掉,因为 __getitem__ 方法直接就可以循环了。我还没有尝试,但理论上我觉得没有问题。
class VideoDataset(Dataset):
    """Read data from the original dataset for feature extraction.

    The label list may contain more entries than there are readable video
    files. When the video at the requested index cannot be loaded, the
    dataset falls forward to the following indices until one loads.

    Args:
        videos_dir: Directory that contains the video files.
        video_names: Sequence of video file names, relative to ``videos_dir``.
        score: Sequence of scores, indexed in parallel with ``video_names``.
        video_format: Either ``'RGB'`` or ``'YUV420'``.
        width: Frame width, forwarded to the reader for ``'YUV420'`` input.
        height: Frame height, forwarded to the reader for ``'YUV420'`` input.

    Raises:
        ValueError: If ``video_format`` is not a supported format.
    """

    _SUPPORTED_FORMATS = ('YUV420', 'RGB')

    def __init__(self, videos_dir, video_names, score, video_format='RGB',
                 width=None, height=None):
        super().__init__()
        # Fail fast on a bad format instead of reporting every video as
        # "missing" later on (and instead of relying on a strippable assert).
        if video_format not in self._SUPPORTED_FORMATS:
            raise ValueError(
                "video_format must be 'YUV420' or 'RGB', got %r"
                % (video_format,)
            )
        self.videos_dir = videos_dir
        self.video_names = video_names
        self.score = score
        self.format = video_format
        self.width = width
        self.height = height

    def __len__(self):
        """Return the number of entries in ``video_names``."""
        return len(self.video_names)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, skipping forward past unreadable videos.

        Args:
            idx: Index into ``video_names`` / ``score``. Negative values count
                from the end.

        Returns:
            A dict ``{'video': tensor, 'score': score}`` for the first video at
            or after ``idx`` that loads successfully, or ``None`` when no video
            from ``idx`` to the end of the list can be loaded.

        Raises:
            IndexError: If ``idx`` itself is outside the dataset.
        """
        total = len(self.video_names)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError('VideoDataset index out of range')

        while idx < total:
            video_name = self.video_names[idx]
            try:
                return self._load_sample(idx)
            except Exception as err:
                # Deliberate best-effort: any failure to load this entry is
                # treated as "no data for this label" and the next entry is
                # tried. Only Exception is caught, so KeyboardInterrupt and
                # SystemExit still propagate.
                print(video_name + "视频不存在,跳过!", repr(err))
                idx += 1

        # Every remaining entry failed to load.
        return None

    def _load_sample(self, idx):
        """Load the video and score at exactly ``idx`` (no skipping)."""
        video_data = self._read_video(self.video_names[idx])
        video_score = self.score[idx]
        return {'video': self._frames_to_tensor(video_data),
                'score': video_score}

    def _read_video(self, video_name):
        """Read one video file from ``videos_dir`` with ``skvideo.io.vread``."""
        video_path = os.path.join(self.videos_dir, video_name)
        if self.format == 'YUV420':
            return skvideo.io.vread(video_path, self.height, self.width,
                                    inputdict={'-pix_fmt': 'yuvj420p'})
        return skvideo.io.vread(video_path)

    @staticmethod
    def _frames_to_tensor(video_data):
        """Convert frames to a normalized ``[length, channel, height, width]`` tensor."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # video_data is indexed as (length, height, width, channel).
        video_length = video_data.shape[0]
        video_height = video_data.shape[1]
        video_width = video_data.shape[2]
        video_channel = video_data.shape[3]

        # Pre-allocate so that a zero-frame video still yields an empty tensor.
        transformed_video = torch.zeros(
            [video_length, video_channel, video_height, video_width]
        )
        for frame_idx in range(video_length):
            frame = Image.fromarray(video_data[frame_idx])
            transformed_video[frame_idx] = transform(frame)
        return transformed_video
随心随我