TedNet:一个用于张量分解网络的Pytorch工具包

摘要

张量分解网络(Tensor Decomposition Networks,TDNs)因其固有的紧凑架构而流行。为了给更多的研究人员提供一种灵活的方式来利用TDNs,我们提出了一个名为TedNet的Pytorch工具包。TedNet实现了5种张量分解(即,CANDECOMP/PARAFAC(CP)、Block-Term Tucker(BTT)、Tucker-2、Tensor Train(TT)和Tensor Ring(TR)在传统的深度神经层、卷积层和全连接层上。通过利用基本层,可以简单地构造各种TDNs。TedNet获取链接为https://github.com/tnbar/tednet.

引言

张量分解网络(Tensor Decomposition Networks,TDNs)是通过用张量格式分解深层神经层来构建的。由于可以从张量分解核中恢复层的原始张量,TDNs通常被视为相应网络的压缩方法。与卷积神经网络(CNN)和递归神经网络(RNN)等传统网络相比,TDNs可以小得多,占用很少的内存。例如,TT-LSTM [1],BTT-LSTM [2,3],TR-LSTM [4,5]能够以比原始模型更高的精度减少17,554,17,414和34,192倍的参数。TDNs具有结构轻巧、性能优良等优点,有望应用于移动的设备、微型计算机等资源受限的应用场合。由于这些优点,TDNs通常可以在许多任务中实现相对高的准确度,其中参数减少很大,例如动作识别[6,7]。TDNs也已在FPGA中实现,用于快速推理,具有超内存减少[8]和多任务学习,以提高表示能力[9]。在此背景下,我们设计了TedNet软件包,为研究者在TDNs上的探索提供方便。

有几个相关的软件包,如T3F [10],Tensorly [11],TensorD [12],TensorNetwork [13],tntorch [14],OSTD [15]和TensorTools [16]。构造了用于低秩分解的最优分解图,并用MATLAB实现。基于NumPy [17]的TensorTools仅实现CP分解,而T3F明确设计用于Tensorflow上的Tensor Train分解[18]。类似地,基于Tensorflow,TensorD支持CP和Tucker分解。相比之下,TedNet使用后端Pytorch实现了五种张量分解[19]。TensorNetwork基于Tensorflow构建,并集成了丰富的张量计算工具。然而,TensorNetwork用于张量分解算法而不是TDNs。Tensorly支持各种后端,包括CuPy,Pytorch,Tensorflow和MXNet [20]。不幸的是,尽管Tensorly在处理张量代数、张量分解和张量回归方面功能强大,但它仍然缺乏对应用程序编程接口(API)的支持,无法直接构建张量神经网络。有趣的是,Tensorly可以通过其张量分解操作来帮助初始化TedNet网络模块。相比之下,TedNet可以通过直接调用API快速建立TDN层。此外,我们还提供了目前研究人员流行的三种深度TDNs。由于Pytorch的动态图形机制,TedNet也可以灵活地为程序员调试。

细节信息

TedNet的设计目标是通过调用相应的API来构建TDNs,这可以极大地简化构建TDNs的过程。如图1所示,TedNet采用Pytorch作为训练框架,因为它具有自动微分功能和构建DNN模型的方便性。此外,TedNet还使用NumPy [17]来辅助张量运算。TedNet的基本模块是TNBase,它是一个抽象类,继承自torch.nn.Module。因此,TedNet模型可以与其他Pytorch模型友好地结合起来。作为一个抽象类,TNBase需要子类来实现4个功能.在图1的右侧,我们展示了TedNet的两个主要深度架构,即TD ResNet和TD LSTM,它们可能分别是卷积神经网络和递归神经网络中最常用的主干。

通常,DNN由CNN和Linear构建。CNN的权重是4模式张量C ∈ RK×K×Cin×Cout,其中K表示卷积窗口,Cin表示输入通道,Cout表示对应的输出通道。线性是一个矩阵W ∈ RI×O,其中I和O分别是输入和输出特征的长度。与DNN类似,TDNs由TD-CNN和TDLinear组成(为了简化,TD-表示相应的张量分解模型),其权重CW通过张量分解进行因式分解。按照这种模式,有5种常用的张量分解(即CP、Tucker-2、Block-Term Tucker、Tensor Train和Tensor Ring),满足了大多数常见情况。值得注意的是,TedNet是一个支持张量环分解的开源包。此外,基于TD-CNN和TD-Linears,TedNet已经构建了一些基于张量分解的深度神经网络,例如。TD-ResNets,TD-RNNs。

测试案例

import tednet.tnn.tensor_ring as tr
import torch
import torch.nn as nn
from torch import Tensor
class TRClassifier(nn.Module):
def __init__(self):
super(TRClassifier, self).__init__()
self.tr_cnn = tr.TRConv2D([1], [4, 5], [6, 6, 6, 6], 3)
# in_shape(输入通道), out_shape(输出通道), ranks, kernel_size, stride=1(default), padding=0(default)
self.tr_fc = tr.TRLinear([20, 26, 26], [10], [6, 6, 6, 6])
# in_shape, out_shape, ranks,
def forward(self, inputs):
out = self.tr_cnn(inputs) # 1 * 20 * 26 * 26
out = torch.relu(out)
out = out.view(inputs.size(0), -1) # 1 * 13520
out = self.tr_fc(out) # 1 * 10
return out
x = torch.randn(1, 1, 28, 28)
model = TRClassifier()
OUT = model(x)
print(OUT.shape)
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from .base import TRConv2D, TRLinear
class TRLeNet5(nn.Module):
def __init__(self, num_classes: int, rs: Union[list, np.ndarray]):
"""LeNet-5 based on Tensor Ring.
Parameters
----------
num_classes : int
The number of classes
rs : Union[list, numpy.ndarray]
The ranks of network.
"""
super(TRLeNet5, self).__init__()
assert len(rs) == 4, "The length of the rank should be 4."
self.c1 = TRConv2D([1], [4, 5], [rs[0], rs[0], rs[0], rs[0]], 5, padding=2)
self.s2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
self.c3 = TRConv2D([4, 5], [5, 10], [rs[1], rs[1], rs[1], rs[1], rs[1]], 5)
self.s4 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
self.fc5 = TRLinear([5, 5, 5, 10], [5, 8, 8], [rs[2], rs[2], rs[2], rs[2], rs[2], rs[2], rs[2]])
self.fc6 = TRLinear([5, 8, 8], [num_classes], [rs[3], rs[3], rs[3], rs[3]])
def forward(self, inputs: Tensor) -> Tensor:
"""forwarding method.
Parameters
----------
inputs : torch.Tensor
tensor :math:`\in \mathbb{R}^{b \\times C \\times H \\times W}`
Returns
-------
torch.Tensor
tensor :math:`\in \mathbb{R}^{b \\times num\_classes}`
"""
out = self.c1(inputs)
out = F.relu(out)
out = self.s2(out)
out = self.c3(out)
out = F.relu(out)
out = self.s4(out)
out = out.view(inputs.size(0), -1)
out = self.fc5(out)
out = F.relu(out)
out = self.fc6(out)
return out

其他信息参考tednet的文档

参考链接

[1] TEDNET

posted @   信海  阅读(413)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 提示词工程——AI应用必不可少的技术
· 地球OL攻略 —— 某应届生求职总结
· 字符编码:从基础到乱码解决
· SpringCloud带你走进微服务的世界
点击右上角即可分享
微信分享提示