TensorDataset
导入相关包
from torch.utils.data import TensorDataset
特征与标签合并
# Pair the feature tensor X with the label tensor Y so that indexing or
# slicing the dataset yields an (x, y) tuple of matching rows.
# NOTE(review): TensorDataset requires X and Y to have the same size along
# dim 0 — presumably one row per sample; confirm where X / Y are built.
HRdataset = TensorDataset(X, Y)
模型训练
# Manual mini-batch training: take consecutive (unshuffled) slices straight
# out of the TensorDataset, then report the loss on the full data once per
# epoch.
# NOTE(review): num_batch is defined elsewhere; if it is
# len(HRdataset) // batch_size, the trailing partial batch is never trained on.
for epoch in range(epochs):
    for i in range(num_batch):
        # Slicing a TensorDataset returns (X[a:b], Y[a:b]).
        x, y = HRdataset[i * batch_size:(i + 1) * batch_size]
        y_pred = model(x)
        loss = loss_func(y_pred, y)
        optimizer.zero_grad()  # clear gradients accumulated by the last step
        loss.backward()
        optimizer.step()
    # Evaluate without recording an autograd graph; .item() already yields a
    # plain Python float, so the legacy .data accessor is unnecessary.
    with torch.no_grad():
        print('epoch: ', epoch, 'loss: ', loss_func(model(X), Y).item())
DataLoader
导入相关包
from torch.utils.data import DataLoader
加载数据
# Wrap the full feature/label tensors; the DataLoader hands out mini-batches
# of batch_size rows, reshuffled on every pass.
HR_ds = TensorDataset(X, Y)
HR_dl = DataLoader(
    HR_ds,
    shuffle=True,
    batch_size=batch_size,
)
模型训练
# Mini-batch training driven by the shuffled DataLoader HR_dl; after each
# epoch, report the loss over the whole dataset (X, Y).
for epoch in range(epochs):
    for x, y in HR_dl:
        y_pred = model(x)
        loss = loss_func(y_pred, y)
        optimizer.zero_grad()  # clear gradients accumulated by the last step
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        # .item() converts the scalar loss to a float; .data is not needed.
        print('epoch: ', epoch, 'loss: ', loss_func(model(X), Y).item())
划分数据集
导入相关包
from sklearn.model_selection import train_test_split
划分训练集与测试集
# Split features and labels into training and test parts using scikit-learn's
# default proportions and shuffling. No random_state is passed, so the split
# differs from run to run.
(train_x, test_x,
 train_y, test_y) = train_test_split(X_data, Y_data)
转换为张量并包装为 DataLoader
# Convert the four splits to float32 tensors.
# NOTE(review): torch.from_numpy expects NumPy arrays (not pandas objects);
# confirm the type of X_data / Y_data.
train_x, test_x, train_y, test_y = [
    torch.from_numpy(arr).type(torch.float32)
    for arr in (train_x, test_x, train_y, test_y)
]

# Training batches are reshuffled every epoch; test batches keep their order.
train_ds = TensorDataset(train_x, train_y)
train_dl = DataLoader(train_ds, shuffle=True, batch_size=batch_size)
test_ds = TensorDataset(test_x, test_y)
test_dl = DataLoader(test_ds, batch_size=batch_size)
定义准确率
def accuracy(y_pred, y_true, threshold=0.5):
    """Return the fraction of thresholded predictions that match the labels.

    Scores strictly greater than ``threshold`` become 1, all others 0, and the
    result is compared elementwise with ``y_true``.

    Args:
        y_pred: Tensor of predicted scores (presumably probabilities in
            [0, 1] — confirm against the model's output layer).
        y_true: Tensor of 0/1 labels, expected to have the same shape as
            ``y_pred``; mismatched shapes would broadcast silently.
        threshold: Decision boundary; defaults to 0.5.

    Returns:
        The mean match rate as a NumPy floating-point scalar (supports
        ``.item()``).
    """
    # detach() and cpu() let this accept tensors that require grad or live on
    # a GPU, where calling .numpy() directly raises.
    pred_labels = (y_pred.detach().cpu().numpy() > threshold).astype('int')
    true_labels = y_true.detach().cpu().numpy()
    return (pred_labels == true_labels).mean()
模型训练
# Train on shuffled mini-batches from train_dl; after each epoch, report loss
# and accuracy on the full training and test tensors.
for epoch in range(epochs):
    for x, y in train_dl:
        y_pred = model(x)
        loss = loss_func(y_pred, y)
        optimizer.zero_grad()  # clear gradients accumulated by the last step
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        # One forward pass per split, reused for both loss and accuracy, so
        # both metrics come from the same predictions and no pass is wasted.
        # NOTE(review): if model contains dropout/batch-norm layers, consider
        # model.eval() here and model.train() before the next epoch.
        train_pred = model(train_x)
        test_pred = model(test_x)
        epoch_accuracy = accuracy(train_pred, train_y)
        # Under no_grad these losses carry no graph, so .data is unnecessary.
        epoch_loss = loss_func(train_pred, train_y)
        epoch_test_accuracy = accuracy(test_pred, test_y)
        epoch_test_loss = loss_func(test_pred, test_y)
        print('epoch: ', epoch,
              'loss: ', round(epoch_loss.item(), 3),
              'accuracy: ', round(epoch_accuracy.item(), 3),
              'test_loss: ', round(epoch_test_loss.item(), 3),
              'test_accuracy: ', round(epoch_test_accuracy.item(), 3))