import torch
import net.bilstm
import net.transformer
from ptflops import get_model_complexity_info
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 统计Transformer模型的参数量和计算复杂度
model_transformer = net.transformer.AudioTransformer(80, 512, 6, 6) #填写的是模型的参数
model_transformer.to(device)
flops_transformer, params_transformer = get_model_complexity_info(model_transformer, (2, 40, 256), as_strings=True, print_per_layer_stat=False) #填写的是输入
#网络x张量形状
print('Transformer模型参数量:' + params_transformer)
print('Transformer模型计算复杂度:' + flops_transformer)
# 统计BiLSTM模型的参数量和计算复杂度
model_bilstm = net.bilstm.BiLSTM(80, 512, 2, 6)
model_bilstm.to(device)
flops_bilstm, params_bilstm = get_model_complexity_info(model_bilstm, (2, 40, 256), as_strings=True, print_per_layer_stat=False)
print('BiLSTM模型参数量:' + params_bilstm)
print('BiLSTM模型计算复杂度:' + flops_bilstm)
分类:
深度学习
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· winform 绘制太阳,地球,月球 运作规律
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· AI 智能体引爆开源社区「GitHub 热点速览」
· Manus的开源复刻OpenManus初探
· 写一个简单的SQL生成工具