SubsetRandomSampler 是什么?
SubsetRandomSampler
是 PyTorch 中的一个采样器(sampler)。在训练神经网络时,通常我们会从整个数据集中随机采样一个小批次的数据进行训练。SubsetRandomSampler
允许你从数据集的给定索引子集中采样数据。
具体来说,你可以为 SubsetRandomSampler
提供一个索引列表,它将在训练过程中按照这个列表的顺序或随机顺序采样数据。
在你的代码中,通过创建 SubsetRandomSampler
对象,你可以使用它来创建训练集和验证集的数据加载器,确保它们分别按照给定的索引列表进行采样。这样可以灵活地控制训练集和验证集的划分,而不必改变原始数据集的顺序。
例子:
from torch.utils.data.sampler import SubsetRandomSampler
# 假设 indices 是包含数据集索引的列表
train_sampler = SubsetRandomSampler(indices_train)
valid_sampler = SubsetRandomSampler(indices_valid)
# 使用 sampler 创建 DataLoader
train_loader = DataLoader(dataset, batch_size=64, sampler=train_sampler)
valid_loader = DataLoader(dataset, batch_size=64, sampler=valid_sampler)
这样,train_loader
和 valid_loader
将按照 indices_train
和 indices_valid
的顺序或随机顺序采样数据。