Fork me on GitHub

一维卷积对一维数据进行特征再提取

点击查看代码
# 第一步  读取csv文件(循环读取)
# 第二步  将数据转化为tensor形式
# 第三步  创建一个列表 将tensor逐个放入列表
# 第四步  写入标签
import csv
import numpy as np
import torch
from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import os
import pandas as pd
from tqdm import tqdm
import sys
import matplotlib.pyplot as plt
# Training set
# Directory that holds one CSV file per training sample.
path_train = r"D:\BaiduNetdiskDownload\pear\浅层特征\csv_hebing82\train"
# NOTE(review): labels below are assigned purely by position in this listing,
# and os.listdir() documents its ordering as arbitrary — confirm the files
# really come back grouped by class in the expected order on the target OS.
train_files = os.listdir(path_train)
train_feature_list = []

for i in train_files:
    # `with` closes each file as soon as it is parsed; the previous bare
    # open() leaked one file handle per sample.
    with open(path_train + '/' + i, 'r', newline='') as csv_handle:
        csv_file = csv.reader(csv_handle)
        # Only the first column is used. Every value is wrapped in its own
        # one-element list so a sample becomes a (rows, 1) nested list.
        # Blank rows (empty lists from csv.reader) are skipped instead of
        # raising IndexError on b[0].
        c = [[float(b[0])] for b in csv_file if b]
    train_feature_list.append(c)

# Expected number of training samples; the slices below cover exactly this range.
total_train = 1054
if len(train_feature_list) != total_train:
    # Fail here with a readable message rather than later inside TensorDataset.
    raise ValueError(
        "expected {} training files in {!r}, found {}".format(
            total_train, path_train, len(train_feature_list)
        )
    )

# Class labels 0..5 assigned by contiguous index ranges.
train_label = np.zeros(total_train)
train_label[0:135] = 0
train_label[135:288] = 1
train_label[288:368] = 2
train_label[368:593] = 3
train_label[593:867] = 4
train_label[867:1054] = 5

# CrossEntropyLoss needs integer (long) class indices.
train_label_tensor = torch.tensor(train_label, dtype=torch.long)

# Stack all samples into a single tensor; both names are kept for later code.
train_feature_list = torch.tensor(train_feature_list)
train_tensor = train_feature_list
# Validation set
# Directory that holds one CSV file per validation sample.
path_val = r"D:\BaiduNetdiskDownload\pear\浅层特征\csv_hebing82\val"
# NOTE(review): labels below are assigned purely by position in this listing,
# and os.listdir() documents its ordering as arbitrary — confirm the files
# really come back grouped by class in the expected order on the target OS.
val_files = os.listdir(path_val)
val_feature_list = []
for i in val_files:
    # `with` closes each file as soon as it is parsed; the previous bare
    # open() leaked one file handle per sample.
    with open(path_val + '/' + i, 'r', newline='') as csv_handle:
        csv_file = csv.reader(csv_handle)
        # Only the first column is used. Every value is wrapped in its own
        # one-element list so a sample becomes a (rows, 1) nested list.
        # Blank rows (empty lists from csv.reader) are skipped instead of
        # raising IndexError on b[0].
        c = [[float(b[0])] for b in csv_file if b]
    val_feature_list.append(c)

# Expected number of validation samples; the slices below cover exactly this range.
total_val = 260
if len(val_feature_list) != total_val:
    # Fail here with a readable message rather than later inside TensorDataset.
    raise ValueError(
        "expected {} validation files in {!r}, found {}".format(
            total_val, path_val, len(val_feature_list)
        )
    )

# Class labels 0..5 assigned by contiguous index ranges.
val_label = np.zeros(total_val)
val_label[0:33] = 0
val_label[33:71] = 1
val_label[71:90] = 2
val_label[90:146] = 3
val_label[146:214] = 4
val_label[214:260] = 5

# CrossEntropyLoss needs integer (long) class indices.
val_label_tensor = torch.tensor(val_label, dtype=torch.long)

# Stack all samples into a single tensor; both names are kept for later code.
val_feature_list = torch.tensor(val_feature_list)
val_tensor = val_feature_list
# Build the datasets first: each pairs a feature tensor with its label tensor.
train_dataset = TensorDataset(train_tensor, train_label_tensor)
val_dataset = TensorDataset(val_tensor, val_label_tensor)

# Then the loaders: batches of 32; only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# 1-D CNN classifier.
# Input is (batch, 765, 1): 765 channels with a single position each, which is
# what the 765-wide Linear layer after Flatten requires. Output is 6 logits.
models = torch.nn.Sequential(
    torch.nn.BatchNorm1d(765),
    # kernel_size=1 mixes information across the 765 channels at the one position.
    nn.Conv1d(in_channels=765, out_channels=765, kernel_size=1),
    torch.nn.BatchNorm1d(765),
    nn.Flatten(),
    torch.nn.Linear(765, 388),
    torch.nn.Dropout(0.5),
    torch.nn.ReLU(),
    torch.nn.Linear(388, 6),
)
# Use the GPU when one is present, otherwise fall back to the CPU; a hard-coded
# 'cuda' device makes models.to(device) fail on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LR = 0.001
epochs = 100
net = models.to(device)
criterion = nn.CrossEntropyLoss()
# Adam adaptive optimiser over all model parameters.
optimizer = optim.Adam(
    net.parameters(),
    lr=LR,
)

def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: network to optimise; switched to training mode here.
        optimizer: optimiser stepped once per batch.
        data_loader: iterable yielding ``(inputs, labels)`` batches.
        device: device the batches are moved to before the forward pass.
        epoch: zero-based epoch index, used only in the progress-bar text.

    Returns:
        Tuple ``(mean loss per batch, fraction of correctly classified samples)``.

    Exits the process with status 1 if a batch produces a non-finite loss.
    """
    model.train()
    criterion_fn = torch.nn.CrossEntropyLoss()
    running_loss = torch.zeros(1).to(device)     # summed batch losses
    running_correct = torch.zeros(1).to(device)  # summed correct predictions
    optimizer.zero_grad()

    seen = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, batch in enumerate(data_loader):
        inputs, targets = batch
        targets = targets.to(device)
        seen += inputs.shape[0]

        logits = model(inputs.to(device))
        # Collapse any trailing dimensions so logits are (batch, classes).
        logits = logits.reshape(len(logits), -1)
        predicted = logits.max(dim=1).indices
        running_correct += torch.eq(predicted, targets).sum()

        loss = criterion_fn(logits, targets)
        loss.backward()
        running_loss += loss.detach()

        mean_loss = running_loss.item() / (step + 1)
        accuracy = running_correct.item() / seen
        current_lr = optimizer.param_groups[0]["lr"]
        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}, lr: {:.5f}".format(
            epoch + 1, mean_loss, accuracy, current_lr
        )

        # Abort before the optimiser step so a NaN/Inf loss never updates weights.
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return running_loss.item() / (step + 1), running_correct.item() / seen

def evaluate(model, data_loader, device, epoch):
    """Evaluate ``model`` on ``data_loader`` without updating any weights.

    Args:
        model: network to evaluate; switched to eval mode here.
        data_loader: iterable yielding ``(inputs, labels)`` batches.
        device: device the batches are moved to before the forward pass.
        epoch: zero-based epoch index, used only in the progress-bar text.

    Returns:
        Tuple ``(mean loss per batch, fraction of correctly classified samples)``.
    """
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)   # running count of correct predictions
    accu_loss = torch.zeros(1).to(device)  # running sum of batch losses

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    # Disable autograd: nothing is back-propagated during evaluation, and without
    # this the accumulated loss keeps every batch's computation graph alive.
    with torch.no_grad():
        for step, data in enumerate(data_loader):
            images, labels = data
            labels = labels.to(device)
            sample_num += images.shape[0]

            pred = model(images.to(device))
            # Collapse any trailing dimensions so pred is (batch, classes).
            pred = pred.reshape(len(pred), -1)
            pred_classes = torch.max(pred, dim=1)[1]
            accu_num += torch.eq(pred_classes, labels).sum()

            loss = loss_function(pred, labels)
            accu_loss += loss

            data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(
                epoch+1,
                accu_loss.item() / (step + 1),
                accu_num.item() / sample_num
            )

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num


# Per-epoch metric history for training and validation.
train_loss_array = []
train_acc_array = []
val_loss_array = []
val_acc_array = []
for epoch in range(epochs):
    # Pass the current loop index, not the total `epochs`; otherwise every
    # progress bar is labelled with the same epoch number.
    # train
    train_loss, train_acc = train_one_epoch(model=net,
                                            optimizer=optimizer,
                                            data_loader=train_loader,
                                            device=device,
                                            epoch=epoch)
    # validate
    val_loss, val_acc = evaluate(model=net,
                                 data_loader=val_loader,
                                 device=device,
                                 epoch=epoch)
    train_loss_array.append(train_loss)
    train_acc_array.append(train_acc)

    val_loss_array.append(val_loss)
    val_acc_array.append(val_acc)
posted @ 2023-05-06 20:28  sy-  阅读(109)  评论(0)  编辑  收藏  举报