1.准备好网络模型代码
import torch
import torch.nn as nn
import torch.optim as optim
class BP_36(nn.Module):
def __init__(self):
super(BP_36, self).__init__()
self.fc1 = nn.Linear(2, 36)
self.fc2 = nn.Linear(36, 25)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
class BP_64(nn.Module):
def __init__(self):
super(BP_64, self).__init__()
self.fc1 = nn.Linear(2, 64)
self.fc2 = nn.Linear(64, 25)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
class Bi_LSTM(nn.Module):
def __init__(self):
super(Bi_LSTM, self).__init__()
self.lstm = nn.LSTM(input_size=2, hidden_size=36, bidirectional=True, batch_first=True)
self.fc1 = nn.Linear(72, 25)
def forward(self, x):
x, _ = self.lstm(x)
x = self.fc1(x)
return x
class Bi_GRU(nn.Module):
def __init__(self):
super(Bi_GRU, self).__init__()
self.gru = nn.GRU(input_size=2, hidden_size=36, bidirectional=True, batch_first=True)
self.fc1 = nn.Linear(72, 25)
def forward(self, x):
x, _ = self.gru(x)
x = self.fc1(x)
return x
2.运行计算参数量和复杂度的脚本
import torch
from net import Bi_GRU
from ptflops import get_model_complexity_info
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_transformer = Bi_GRU()
model_transformer.to(device)
flops_transformer, params_transformer = get_model_complexity_info(model_transformer, (256,2), as_strings=True, print_per_layer_stat=False)
print('模型参数量:' + params_transformer)
print('模型计算复杂度:' + flops_transformer)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· winform 绘制太阳,地球,月球 运作规律
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· AI 智能体引爆开源社区「GitHub 热点速览」
· Manus的开源复刻OpenManus初探
· 写一个简单的SQL生成工具