# Torch - Dataset 和 DataLoader

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn.functional as F


class DiabetesDataset(Dataset):
    """In-memory dataset read from a comma-separated numeric text file.

    Every column except the last is exposed as the input features
    (``x_data``); the last column is exposed as the target (``y_data``) and
    is kept two-dimensional, i.e. one column per row.
    """

    def __init__(self, filepath):
        """Parse ``filepath`` as float32 and split it into features/targets."""
        table = np.loadtxt(filepath, dtype=np.float32, delimiter=",")
        features = table[:, :-1]
        # Indexing with a list keeps the column axis, so the result has
        # shape (rows, 1) instead of (rows,).
        targets = table[:, [-1]]
        self.x_data = torch.from_numpy(features)
        self.y_data = torch.from_numpy(targets)
        self.len = len(self.x_data)

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair stored at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of rows loaded from the file."""
        return self.len


# Load the data file named "filepath" and wrap it in a loader that serves
# shuffled mini-batches of 32 samples, fetched with num_workers=2.
dataset = DiabetesDataset("filepath")
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)


class Model(torch.nn.Module):
    """Three-layer fully connected network: 8 -> 6 -> 4 -> 1.

    The two hidden layers use ReLU; the output layer is squashed with a
    sigmoid, so every output value lies in the interval (0, 1).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is so that
        # state_dict keys and seeded weight initialisation do not change.
        self.linear1 = torch.nn.Linear(in_features=8, out_features=6)
        self.linear2 = torch.nn.Linear(in_features=6, out_features=4)
        self.linear3 = torch.nn.Linear(in_features=4, out_features=1)

    def forward(self, x):
        """Run ``x`` (last dimension of size 8) through the network."""
        hidden = F.relu(self.linear1(x))
        hidden = F.relu(self.linear2(hidden))
        logits = self.linear3(hidden)
        return F.sigmoid(logits)


# Instantiate the network together with the optimiser and loss used by train().
model = Model()
# The bare name `optim` is never imported in this file, so the optimiser is
# reached through the already-imported `torch` package instead.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Binary cross-entropy; it is applied to the model's sigmoid outputs.
loss_fn = torch.nn.BCELoss()


def train(epochs=100):
    """Train the module-level ``model`` on ``train_loader``.

    For every batch the loss from ``loss_fn`` is back-propagated and
    ``optimizer`` takes one step. After each epoch the loss of the last
    batch processed in that epoch is printed.

    Args:
        epochs: Number of passes over ``train_loader``. Defaults to 100.
    """
    for epoch in range(epochs):
        # Stays None when the loader yields no batches, so that the report
        # below does not call .item() on a plain integer.
        last_loss = None
        for x, y in train_loader:
            y_pred = model(x)
            last_loss = loss_fn(y_pred, y)
            optimizer.zero_grad()
            last_loss.backward()
            optimizer.step()
        # Fall back to 0 for an empty epoch.
        shown = last_loss.item() if last_loss is not None else 0
        print("Epoch:", epoch, "loss=", shown)


# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    train()
from torch.utils.data import Dataset
import os
import cv2 as cv

class MyData(Dataset):
    """Image dataset backed by a single ``root_dir/label_dir`` folder.

    Every entry of that folder is one sample, and the folder name
    (``label_dir``) is used as the label of every sample.
    """

    def __init__(self, root_dir, label_dir):
        """Record the folder location and list its entries.

        Args:
            root_dir: Directory that contains the per-label folders.
            label_dir: Name of the label folder inside ``root_dir``; also
                returned as the label of each sample.
        """
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> image mapping reproducible across runs and machines.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, index):
        """Return ``(image, label)`` for the entry at position ``index``.

        The image is whatever ``cv.imread`` returns for the entry's path.
        """
        img_name = self.img_path[index]
        # self.path already holds root_dir/label_dir, so reuse it.
        img_item_path = os.path.join(self.path, img_name)
        img = cv.imread(img_item_path)
        return img, self.label_dir

    def __len__(self):
        """Return the number of entries found in the label folder."""
        return len(self.img_path)


root_dir = "dataset/val"  # root directory of the data
ants_label_dir = "ants"   # folder holding the ant images
bees_label_dir = "bees"   # folder holding the bee images
ants_dataset = MyData(root_dir, ants_label_dir)  # dataset of ant images
# Use a dedicated name for the bee dataset so that `bees_label_dir` keeps
# referring to the folder name instead of being rebound to a dataset.
bees_dataset = MyData(root_dir, bees_label_dir)  # dataset of bee images

print("ants_dataset:", len(ants_dataset))
print("bees_dataset:", len(bees_dataset))

# Adding two datasets concatenates them: ants first, then bees.
train_dataset = ants_dataset + bees_dataset
print("train_dataset:", len(train_dataset))

img1, label1 = train_dataset[69]  # 70th image and its label

img2, label2 = train_dataset[70]  # 71st image and its label

# Show both images and wait for a key press before closing the windows.
cv.imshow("image1", img1)
cv.imshow("image2", img2)
cv.waitKey(0)
cv.destroyAllWindows()
# posted @ 2023-04-27 00:31  X1OO  阅读(38)  评论(0)    收藏  举报