PyTorch 拟合多项式

用 Adam 梯度下降拟合二次多项式 $y = ax^2 + bx + c$ 的参数 $a$、$b$、$c$。

# coding=utf-8
import torch
import numpy as np

class Net(torch.nn.Module):
    """Quadratic model ``y = a * x**2 + b * x + c`` with learnable a, b, c.

    The coefficients are registered as ``torch.nn.Parameter`` objects, so the
    stock ``torch.nn.Module`` machinery handles them:

    * ``parameters()`` yields them in the order a, b, c;
    * ``cuda()`` / ``cpu()`` move them and return the module itself;
    * ``net(x)`` dispatches to ``forward``.
    """

    def __init__(self):
        # Must run first: nn.Module sets up the registries (_parameters,
        # hooks, ...) that __setattr__, __call__, cuda() and parameters() use.
        super().__init__()
        # Fixed starting point. Replace with torch.rand(1) for a random
        # initialisation.
        self.a = torch.nn.Parameter(torch.tensor([1.0], dtype=torch.float32))
        self.b = torch.nn.Parameter(torch.tensor([1.0], dtype=torch.float32))
        self.c = torch.nn.Parameter(torch.tensor([10.0], dtype=torch.float32))

    def forward(self, inputs):
        """Evaluate ``a * inputs**2 + b * inputs + c`` element-wise."""
        return self.a * inputs ** 2 + self.b * inputs + self.c


def main():
    """Fit ``Net`` to samples of ``2 * x**2 + x + 13``.

    Uses 1000 evenly spaced points in [1, 50] and trains with Adam.
    Every 2000 steps it prints the step number, the loss and the current
    coefficients. It stops early once the summed squared error drops below
    1e-4. In either case the final coefficients are printed.
    """
    x = np.linspace(1, 50, 1000)
    a, b, c = 2, 1, 13
    y = a * x ** 2 + b * x + c
    x_ref = torch.from_numpy(x.astype(np.float32))
    y_ref = torch.from_numpy(y.astype(np.float32))

    net = Net()
    if torch.cuda.is_available():
        # Move the torch tensors, not the NumPy arrays (which have no
        # .cuda()). The inputs must live on the same device as the model.
        x_ref = x_ref.cuda()
        y_ref = y_ref.cuda()
        net = net.cuda()

    # Created after the device move, so the optimizer holds the tensors that
    # forward() actually uses. Note that weight_decay adds an L2 pull towards
    # zero, so the minimiser is slightly biased away from (2, 1, 13).
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=0.0005)
    loss_op = torch.nn.MSELoss(reduction='sum')

    def coefficients():
        # Current (a, b, c) as host NumPy arrays, for printing.
        return tuple(p.detach().cpu().numpy() for p in (net.a, net.b, net.c))

    for i in range(1, 100001):
        y_out = net.forward(x_ref)
        loss = loss_op(y_out, y_ref)  # MSELoss(input, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        if i % 2000 == 0:
            print(i, loss_value, *coefficients())
        if loss_value < 0.0001:
            break
    print(*coefficients())

if __name__ == "__main__":
    # Script entry point: run the polynomial fit.
    main()
View Code

 

错误、不规范的代码记录

 

# coding=utf-8
import torch
import numpy as np

class Net(torch.nn.Module):
    """Three-band model Hamiltonian with learnable parameters LL, MM, NN.

    For each k-point ``(kx, ky, kz)`` the model builds the real symmetric
    3x3 matrix

        H[0,0] = Ev + LL*kx^2 + MM*(ky^2 + kz^2)
        H[1,1] = Ev + LL*ky^2 + MM*(kx^2 + kz^2)
        H[2,2] = Ev + LL*kz^2 + MM*(kx^2 + ky^2)
        H[0,1] = H[1,0] = NN*kx*ky
        H[0,2] = H[2,0] = NN*kx*kz
        H[1,2] = H[2,1] = NN*ky*kz

    and returns its three eigenvalues. The parameters are ``nn.Parameter``s,
    so ``parameters()``, ``cuda()``, ``cpu()`` and ``net(k)`` come from
    ``torch.nn.Module``.
    """

    def __init__(self, ev=-0.5747985):
        """Create the model.

        Args:
            ev: constant energy offset ``Ev`` added to every diagonal entry
                (defaults to -0.5747985).
        """
        # Must run first so nn.Module can register the parameters below.
        super().__init__()
        self.ev = ev
        self.LL = torch.nn.Parameter(torch.tensor([-20.0], dtype=torch.float32))
        self.MM = torch.nn.Parameter(torch.tensor([-10.0], dtype=torch.float32))
        self.NN = torch.nn.Parameter(torch.tensor([-20.0], dtype=torch.float32))

    def forward(self, k_data):
        """Return the eigenvalues of H(k) for every k-point.

        Args:
            k_data: tensor or array of shape (N, >=3). Columns 0, 1 and 2 are
                kx, ky and kz; any further columns are ignored.

        Returns:
            Real tensor of shape (N, 3) on the parameters' device. Each row is
            sorted in ascending order.
            NOTE(review): comparing against reference bands assumes those are
            stored in ascending order per k-point too — confirm.
        """
        # Keep everything on the parameters' device and dtype. The original
        # allocated its outputs on the CPU, which broke the CUDA path.
        k = torch.as_tensor(k_data, dtype=self.LL.dtype, device=self.LL.device)
        kx, ky, kz = k[:, 0], k[:, 1], k[:, 2]
        kx2, ky2, kz2 = kx ** 2, ky ** 2, kz ** 2

        h00 = self.ev + self.LL * kx2 + self.MM * (ky2 + kz2)
        h11 = self.ev + self.LL * ky2 + self.MM * (kx2 + kz2)
        h22 = self.ev + self.LL * kz2 + self.MM * (kx2 + ky2)
        h01 = self.NN * kx * ky
        h02 = self.NN * kx * kz
        h12 = self.NN * ky * kz

        # Assemble all N matrices at once: shape (N, 3, 3).
        ham = torch.stack(
            (
                torch.stack((h00, h01, h02), dim=-1),
                torch.stack((h01, h11, h12), dim=-1),
                torch.stack((h02, h12, h22), dim=-1),
            ),
            dim=-2,
        )
        # H is real symmetric, so eigvalsh returns real eigenvalues in
        # ascending order, with stable gradients even at degenerate bands.
        # eigvals would return complex values in an unspecified order.
        return torch.linalg.eigvalsh(ham)

def reference(kpoints_path="kpoints.npy", bands_path="bands.npy", start=269, stop=339):
    """Load reference k-points and band energies and keep a window of rows.

    Args:
        kpoints_path: ``.npy`` file of k-points, presumably one (kx, ky, kz)
            row per k-point.
        bands_path: ``.npy`` file of band energies. It is transposed after
            loading so that its rows line up with the k-points (presumably it
            is stored as (n_bands, n_k) — TODO confirm).
        start: first row kept from both arrays.
        stop: end of the half-open row range ``[start, stop)``. The defaults
            reproduce the original 269:339 window.

    Returns:
        ``(k_data, E_data)`` as NumPy arrays.
    """
    k_data = np.load(kpoints_path)
    E_data = np.load(bands_path).T
    return k_data[start:stop], E_data[start:stop]


def main():
    """Fit LL, MM, NN so the model eigenvalues match the reference bands.

    Loads the reference window via ``reference()`` and trains ``Net`` with
    Adam on the summed squared error. Every 100 steps it prints the step
    number, the loss and the current parameters. It stops early once the
    loss drops below 1e-4. In either case the final parameters are printed.
    """
    k_data, E_data = reference()
    k_data = torch.from_numpy(k_data.astype(np.float32))
    E_ref = torch.from_numpy(E_data.astype(np.float32))

    net = Net()
    if torch.cuda.is_available():
        k_data = k_data.cuda()
        E_ref = E_ref.cuda()
        net = net.cuda()

    # Created after the device move, so the optimizer holds the tensors that
    # forward() actually uses.
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=0.0005)
    loss_op = torch.nn.MSELoss(reduction='sum')

    def current_params():
        # Current (LL, MM, NN) as host NumPy arrays, for printing.
        return tuple(p.detach().cpu().numpy() for p in (net.LL, net.MM, net.NN))

    for i in range(1, 100001):
        E_out = net.forward(k_data)
        # NOTE(review): assumes the columns of E_ref list the bands in the
        # same order as net.forward returns them — confirm.
        loss = loss_op(E_out, E_ref)  # MSELoss(input, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        if i % 100 == 0:
            print(i, loss_value, *current_params())
        if loss_value < 0.0001:
            break
    print(*current_params())

if __name__ == "__main__":
    # Script entry point: fit LL, MM, NN against the reference bands.
    main()
View Code

 

posted @ 2022-10-14 12:38  ghzphy  阅读(60)  评论(0)  编辑  收藏  举报