# GAN example in PyTorch: train a generator to mimic a 1-D Gaussian distribution.

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable  # deprecated since PyTorch 0.4; Variable(x) is a no-op wrapper kept for compatibility

matplotlib_is_available = True
try:
    from matplotlib import pyplot as plt
except ImportError:
    print("Will skip plotting; matplotlib is not available.")
    matplotlib_is_available = False

# Parameters of the real (target) data distribution
data_mean = 4
data_stddev = 1.25

# Preprocessing: the discriminator sees only the first 4 moments of each batch,
# so its input size is a constant 4 regardless of the batch size
(name, preprocess, d_input_func) = ("Only 4 moments", lambda data: get_moments(data), lambda x: 4)

print("Using data [%s]" % (name))


def get_distribution_sampler(mu, sigma):
    # Real-data sampler: n Gaussian samples with the target mean/stddev, shape (1, n)
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))


def get_generator_input_sampler():
    # Generator-input sampler: uniform noise in [0, 1), shape (m, n)
    return lambda m, n: torch.rand(m, n)

class Generator(nn.Module):
    # Three-layer MLP mapping uniform noise to samples of the learned distribution
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f  # activation function

    def forward(self, x):
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return self.map3(x)  # no activation on the output layer

class Discriminator(nn.Module):
    # Three-layer MLP mapping a preprocessed batch to a real/fake probability
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f  # activation function

    def forward(self, x):
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return self.f(self.map3(x))  # final activation (sigmoid) squashes to (0, 1)

def extract(v):
    # Flatten a tensor's underlying storage into a plain Python list
    return v.data.storage().tolist()


def stats(d):
    return [np.mean(d), np.std(d)]


def get_moments(d):
    # Return the first 4 moments of the data provided: mean, std, skew, kurtosis
    mean = torch.mean(d)
    diffs = d - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0  # excess kurtosis: 0 for a Gaussian
    final = torch.cat((mean.reshape(1,), std.reshape(1,), skews.reshape(1,), kurtoses.reshape(1,)))
    return final


def decorate_with_diffs(data, exponent, remove_raw_data=False):
    # Alternative preprocessing (unused with the "Only 4 moments" variant above):
    # append each value's powered deviation from the batch mean to the raw data
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    if remove_raw_data:
        return torch.cat([diffs], 1)
    else:
        return torch.cat([data, diffs], 1)


def train():
    # Model parameters
    g_input_size = 1       # noise dimension fed into G
    g_hidden_size = 5      # generator complexity
    g_output_size = 1      # size of each generated sample
    d_input_size = 500     # samples drawn per minibatch
    d_hidden_size = 10     # discriminator complexity
    d_output_size = 1      # single scalar: 'real' vs. 'fake'
    minibatch_size = d_input_size

    d_learning_rate = 1e-3
    g_learning_rate = 1e-3
    sgd_momentum = 0.9

    num_epochs = 5000
    print_interval = 100
    d_steps = 20
    g_steps = 20

    dfe, dre, ge = 0, 0, 0
    d_real_data, d_fake_data, g_fake_data = None, None, None

    discriminator_activation_function = torch.sigmoid
    generator_activation_function = torch.tanh

    d_sampler = get_distribution_sampler(data_mean, data_stddev)
    gi_sampler = get_generator_input_sampler()
    G = Generator(input_size=g_input_size,
                  hidden_size=g_hidden_size,
                  output_size=g_output_size,
                  f=generator_activation_function)
    D = Discriminator(input_size=d_input_func(d_input_size),
                      hidden_size=d_hidden_size,
                      output_size=d_output_size,
                      f=discriminator_activation_function)

    criterion = nn.BCELoss()  # binary cross entropy
    d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)
    g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)

    for epoch in range(num_epochs):
        for d_index in range(d_steps):
            # 1. Train D on one real and one fake minibatch
            D.zero_grad()

            # 1A: Train D on real data
            d_real_data = Variable(d_sampler(d_input_size))
            d_real_decision = D(preprocess(d_real_data))
            d_real_error = criterion(d_real_decision, Variable(torch.ones([1])))  # target 1 = real
            d_real_error.backward()  # accumulate gradients; params updated in step() below

            # 1B: Train D on fake data
            d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            d_fake_data = G(d_gen_input).detach()  # detach so this D step does not train G
            # detach() returns a tensor that shares the original tensor's data but has
            # requires_grad=False, i.e. it carries no gradient. Backpropagation is cut
            # at the detach point: gradients never flow back through a detached tensor,
            # so the backward() call below updates only D, not the G that produced the data.
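            # A minimal sketch of detach() semantics (hypothetical tensors t and u, not
            # part of this script; kept as comments so the training loop is unchanged):
            #   t = torch.tensor([1.0, 2.0], requires_grad=True)
            #   u = t.detach()              # same data, but requires_grad=False
            #   (t * 3).sum().backward()    # ok: t.grad becomes tensor([3., 3.])
            #   (u * 3).sum().backward()    # RuntimeError: u is not part of any graph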
            d_fake_decision = D(preprocess(d_fake_data.t()))
            d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1])))  # target 0 = fake
            d_fake_error.backward()
            d_optimizer.step()  # update D's parameters from the gradients of both passes

            dre, dfe = extract(d_real_error)[0], extract(d_fake_error)[0]

        for g_index in range(g_steps):
            # 2. Train G on D's response (but do not train D on these labels)
            G.zero_grad()

            gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            g_fake_data = G(gen_input)
            dg_fake_decision = D(preprocess(g_fake_data.t()))
            g_error = criterion(dg_fake_decision, Variable(torch.ones([1])))  # train G to make D output 'real'

            g_error.backward()
            g_optimizer.step()  # update only G's parameters
            ge = extract(g_error)[0]

        if epoch % print_interval == 0:
            print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s), Fake Dist (%s)" %
                  (epoch, dre, dfe, ge, stats(extract(d_real_data)), stats(extract(d_fake_data))))

    if matplotlib_is_available:
        print("Plotting the generated distribution...")
        values = extract(g_fake_data)
        print(" Values: %s" % (str(values)))
        plt.hist(values, bins=50)
        plt.xlabel('Value')
        plt.ylabel('Count')
        plt.title('Histogram of Generated Distribution')
        plt.grid(True)
        plt.savefig(r'/root/data/pytorch_gpu/log/output/fenbu.jpg')
        plt.show()


train()