读取数据集时标签多于数据:跳过没有对应数据的项

应该有比我的方法更好的思路,欢迎讨论。

(更新)我认为那个 while True 可以去掉,因为 __getitem__ 方法本身就会被循环调用。目前还没有尝试,但理论上应该没有问题。

class VideoDataset(Dataset):
    """Read videos from the original dataset for feature extraction.

    The label list may contain more entries than there are video files on
    disk. Items whose video file does not exist are skipped: ``__getitem__``
    falls through to the next index (wrapping around to the start) until an
    existing video is found, so a real sample is always returned.

    Args:
        videos_dir: Directory containing the video files.
        video_names: Sequence of file names, relative to ``videos_dir``.
        score: Sequence of labels, indexed in parallel with ``video_names``.
        video_format: ``'RGB'`` or ``'YUV420'``. For ``'YUV420'`` the file is
            read with ``-pix_fmt yuvj420p`` and ``height``/``width`` are passed
            to ``skvideo.io.vread``.
        width: Frame width; only used when ``video_format == 'YUV420'``.
        height: Frame height; only used when ``video_format == 'YUV420'``.

    Raises:
        ValueError: If ``video_format`` is not ``'RGB'`` or ``'YUV420'``.
    """

    _SUPPORTED_FORMATS = ('YUV420', 'RGB')

    def __init__(self, videos_dir, video_names, score, video_format='RGB', width=None, height=None):

        super(VideoDataset, self).__init__()
        # Validate up front: the original assert sat inside a bare except,
        # so a bad format silently turned every item into a "missing" video.
        if video_format not in self._SUPPORTED_FORMATS:
            raise ValueError(
                f"video_format must be one of {self._SUPPORTED_FORMATS}, "
                f"got {video_format!r}"
            )
        self.videos_dir = videos_dir
        self.video_names = video_names
        self.score = score
        self.format = video_format
        self.width = width
        self.height = height

    def __len__(self):
        return len(self.video_names)

    def __getitem__(self, idx):
        """Return ``{'video': tensor, 'score': label}`` for ``idx``.

        If the video file for ``idx`` is missing, the following indices are
        tried in turn (wrapping around to index 0), so the DataLoader never
        receives ``None``.

        Raises:
            IndexError: If ``idx`` is out of range.
            FileNotFoundError: If none of the dataset's video files exist.
        """
        count = len(self)
        # range(...)[idx] normalizes negative indices and raises IndexError
        # for out-of-range ones, matching list-indexing semantics.
        start = range(count)[idx]
        for offset in range(count):
            i = (start + offset) % count
            video_name = self.video_names[i]
            path = os.path.join(self.videos_dir, video_name)
            if not os.path.isfile(path):
                print(f"{video_name}视频不存在,跳过!")
                continue
            return {'video': self._load_video(path),
                    'score': self.score[i]}
        raise FileNotFoundError(
            f"No video files from video_names exist in {self.videos_dir!r}"
        )

    def _load_video(self, path):
        """Decode the video at ``path`` into a normalized float tensor.

        ``skvideo.io.vread`` output is indexed as (frames, height, width,
        channels); the result is laid out as (frames, channels, height, width).
        """
        if self.format == 'YUV420':
            video_data = skvideo.io.vread(path, self.height, self.width,
                                          inputdict={'-pix_fmt': 'yuvj420p'})
        else:
            video_data = skvideo.io.vread(path)

        # Standard ImageNet channel mean/std normalization.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        video_length, video_height, video_width, video_channel = video_data.shape
        transformed_video = torch.zeros([video_length, video_channel, video_height, video_width])
        for frame_idx, frame in enumerate(video_data):
            transformed_video[frame_idx] = transform(Image.fromarray(frame))
        return transformed_video

 

posted @ 2021-07-07 10:49  小筱痕  阅读(163)  评论(0)  编辑  收藏  举报