可视化全连接层(蒙特卡洛法)
import random import torch import torch.nn as nn import matplotlib.pyplot as plt import numpy as np import math from torch.utils.data import DataLoader from torch.utils.data import Dataset epochs=1000 class pt: def __init__(self,x,y): self.x=x self.y=y class logistic(nn.Module): def __init__(self): super(logistic, self).__init__() self.w = torch.nn.Parameter(torch.randn(2, 1)) self.b = torch.nn.Parameter(torch.zeros([1])) self.line1=torch.nn.Linear(2,1) self.line2=nn.Sequential( nn.Linear(2,5000), nn.ReLU(), nn.Linear(5000, 1), ) self.pred=None def forward(self, X): #self.pred=torch.matmul(X,self.w)+self.b self.pred=self.line2(X) return torch.sigmoid(self.pred),self.pred def generate_point(th=0.4,start=50,end=90): class_list=[] point_list=[] for angle in range(360): theta=3.14/180.0*angle x=math.cos(theta)+random.random()*th y=math.sin(theta)+random.random()*th point_list.append(list([x,y])) if angle>start and angle<end or angle>180 and angle<230 or angle>250 and angle<300:#or angle>180 and angle<230 class_list.append(0) else: class_list.append(1) return np.array(point_list),np.array(class_list) class fdata(Dataset): # dirname 为训练/测试数据地址,使得训练/测试分开 def __init__(self, train=True): super(fdata, self).__init__() self.data,self.label = generate_point() def __len__(self): return self.data.shape[0] def __getitem__(self, index): image = self.data[index] image = image.astype(np.float32) image = torch.unsqueeze(torch.from_numpy(image),0) label = self.label[index] label = np.array(label.astype(np.float32)).reshape(1) label = torch.unsqueeze(torch.from_numpy(label), 0) return image,label def draw(pt_list,cls_list,module): plt.title('circle') pt_r,pt_b=[],[] for n in range(len(cls_list)): if(cls_list[n]==1): pt_r.append(pt_list[n]) else: pt_b.append(pt_list[n]) pt_r=np.array(pt_r) pt_b = np.array(pt_b) line_list=[] for n in range(-10,20): n=n*0.1 for m in range(-10,20): m=m*0.1 line_list.append([n,m]) line_array=[] line_tensor=torch.from_numpy(np.array(line_list)).reshape(-1,1,2).float() output,pred=module(line_tensor) pred=pred.squeeze().detach().numpy().tolist() for n in range(len(pred)): if pred[n]<0: line_array.append(line_list[n]) line_array=np.array(line_array) plt.scatter(line_array[:, 0], line_array[:, 1], c="g") plt.scatter(pt_r[:,0],pt_r[:,1],c="r") plt.scatter(pt_b[:, 0],pt_b[:, 1],c="b") plt.xlim(-1,2) plt.ylim(-1, 2) plt.show() criterion = nn.BCELoss() md=logistic() opt=torch.optim.Adam(md.parameters(),lr=0.001) pt_list,cls_list=generate_point() # input=torch.from_numpy(pt_list).reshape(-1,1,2).float() # label=torch.from_numpy(cls_list).reshape(-1,1,1).float() train_dataset = fdata() train_dataloder = DataLoader(train_dataset, batch_size=10, num_workers=0, drop_last=True,shuffle=True) for i in range(epochs): for input,label in train_dataloder: output,pred=md(input) loss=criterion(output,label) opt.zero_grad() loss.backward() opt.step() print("第"+str(i)+":"+str(loss)) draw(pt_list,cls_list,md)
若非线性激活函数在定义域上是连续光滑的,能够被n阶泰勒公式描述,他所具有的拟合精度便是n阶。激活函数本质上就是拟合所用到的基函数。
按照kolmogorov定理,三层就可以拟合所有连续函数(写在代码上,就是两层),并说明隐层结点数应该为输入特征数的两倍加一。
二次方程具有的精度是2阶,三次方程具有的精度是3阶,leaky relu的非线性程度也是低阶的,与alpha值有关。
如果我将代码中的relu改为ex,或者是二次函数,显然也是可以跑通的。但是当网络层数过深时,连乘效应会导致梯度爆炸和消失,激活函数的梯度应该等于一。
总结:
0.带有激活函数的全连接层(至少两层)越宽,其拟合能力越强。全连接层不应该太深,超过三层一般是无用的。经实验,多层连接层对于低阶激活函数是有效的,对于高阶(relu)有可能是无效的,加层数并不一定能增加拟合能力,而且多层隐层受木桶效应的影响很严重。
来自知乎的优质回答,也间接说明了单纯增加层数没什么用,短连接指的是残差连接:




1.在全连接层中,不要使用可以被低阶泰勒公式描述的激活函数(后称低阶激活函数),因为他的拟合精度是有限的。针对低阶激活函数,即使全连接层很宽,拟合精度也无法提升,因为低阶激活函数可训练参数数目是固定的,可训练参数数目就是全连接层可拟合参数的上限,宽度高于这个上限是无用的,没法提高拟合精度。(反正使用relu就对了)
2.为了使拟合精度提高、低维映射到高维可分(至少大于输入层、输出层宽度),隐层宽度一般会很大。
3.当网络层数过深时,连乘效应会导致梯度爆炸和消失,激活函数的梯度应该等于一。(反正使用relu就对了)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· 终于写完轮子一部分:tcp代理 了,记录一下
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理