HW2

一个用于节省内存的技巧:(构建Dataset后可将原始数据删除!训练和验证完成后可将训练集和验证集的DataLoader删除!)

import gc
from torch.utils.data import DataLoader, Dataset, random_split

# preprocess data
train_X, train_y = pass
val_X, val_y = pass

# get dataset
train_set = MyDataset(train_X, train_y)
val_set = MyDataset(val_X, val_y)

# remove raw feature to save memory
del train_X, train_y, val_X, val_y
gc.collect()

# get dataloader
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

"""此处为训练和验证过程..."""

# after training, before testing
del train_loader, val_loader
gc.collect()

 

posted @ 2023-08-05 16:09  Peg_Wu  阅读(3)  评论(0编辑  收藏  举报