# Dataset and DataLoader
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
# Ten toy training samples, each with two features (float32).
data_x = torch.tensor(
    [[1.0, 2.0],
     [1.5, 2.5],
     [2.0, 3.0],
     [2.5, 2.0],
     [3.0, 2.5],
     [1.0, 2.5],
     [2.0, 1.5],
     [3.0, 3.0],
     [1.0, 1.0],
     [2.0, 2.0]],
    dtype=torch.float32,
)
# One scalar regression target per sample above (shape: 10 x 1, float32).
data_y = torch.tensor(
    [[3.0],
     [3.5],
     [4.0],
     [3.0],
     [3.5],
     [2.0],
     [2.0],
     [3.0],
     [1.0],
     [2.0]],
    dtype=torch.float32,
)
class MyDataset(Dataset):
    """Wrap feature (and optional label) data for use with a DataLoader.

    Args:
        x: features; anything convertible to a float32 tensor.
        y: labels with the same leading length as ``x``, or ``None`` for
           unlabeled (inference-time) data.
    """

    def __init__(self, x, y=None):
        # torch.as_tensor avoids a copy when the input is already a float32
        # tensor; the legacy torch.FloatTensor(...) constructor is discouraged.
        if y is None:
            self.y = None
        else:
            self.y = torch.as_tensor(y, dtype=torch.float32)
        self.x = torch.as_tensor(x, dtype=torch.float32)

    def __getitem__(self, index):
        # Labeled dataset yields (x, y) pairs; unlabeled yields bare x.
        if self.y is None:
            return self.x[index]
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.x)
train_dataset = MyDataset(data_x,data_y) # instantiate the dataset
# Default object repr — MyDataset defines no __repr__.
print(train_dataset)
print('train_dataset.x:', train_dataset.x, '\n',
'train_dataset.y:', train_dataset.y)
# Dunder calls shown explicitly for teaching; equivalent to
# train_dataset[0] and len(train_dataset).
print(train_dataset.__getitem__(0))
print(train_dataset.__len__())
# The triple-quoted string below is a transcript of the expected console
# output from the prints above (a no-op expression at runtime).
'''
<__main__.MyDataset object at 0x00000221C6B92110>
train_dataset.x: tensor([[1.0000, 2.0000],
[1.5000, 2.5000],
[2.0000, 3.0000],
[2.5000, 2.0000],
[3.0000, 2.5000],
[1.0000, 2.5000],
[2.0000, 1.5000],
[3.0000, 3.0000],
[1.0000, 1.0000],
[2.0000, 2.0000]])
train_dataset.y: tensor([[3.0000],
[3.5000],
[4.0000],
[3.0000],
[3.5000],
[2.0000],
[2.0000],
[3.0000],
[1.0000],
[2.0000]])
(tensor([1., 2.]), tensor([3.]))
10
'''
# pin_memory=True keeps batches in page-locked host memory, which can speed
# up later host-to-GPU transfers.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, pin_memory=True)
# Unpack each (features, labels) batch directly in the loop header.
for inputs, labels in train_loader:
    print(inputs)
    print(labels)
# Transcript of one run's console output (a no-op string at runtime);
# exact values vary because shuffle=True randomizes batch order.
'''
tensor([[3.0000, 2.5000],
[1.5000, 2.5000]])
tensor([[3.5000],
[3.5000]])
tensor([[2.5000, 2.0000],
[3.0000, 3.0000]])
tensor([[3.],
[3.]])
tensor([[2.0000, 3.0000],
[2.0000, 1.5000]])
tensor([[4.],
[2.]])
tensor([[2., 2.],
[1., 2.]])
tensor([[2.],
[3.]])
tensor([[1.0000, 2.5000],
[1.0000, 1.0000]])
tensor([[2.],
[1.]])
'''
# As shown: 10 samples with batch_size=2 produce 5 batches, and
# shuffle=True has randomized the sample order.
# When used inside a train() function:
# Typical training loop (num_epochs, train_loader, optimizer, model,
# criterion and device are assumed to be defined by the caller).
for epoch in range(num_epochs):
    for batch in train_loader:
        # Reset gradients accumulated from the previous step.
        optimizer.zero_grad()
        inputs, labels = batch  # unpack features and labels
        # BUG FIX: Tensor.to() is NOT in-place — it returns a new tensor,
        # so the result must be reassigned (the original discarded it and
        # kept training on the CPU tensors).
        # non_blocking=True pairs with pin_memory=True for async copies.
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        # Forward pass
        outputs = model(inputs)
        # Compute the loss
        loss = criterion(outputs, labels)
        # Backward pass
        loss.backward()
        # Update parameters
        optimizer.step()
    # Reports the loss of the last batch of the epoch.
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')