tensordot 的源码解读

 

我一直以为,谈论数学计算只能用数学语言,就是用数学公式一步步推导

这个世界是,可知的,简洁的

复杂的是现象,不是本质

在线性代数中,介绍的矩阵乘法只有一种,可以退出,在程序中,不论多么高维度的矩阵乘法运算,其本质也是二维矩阵运算

import numpy as np
np.random.seed(10)
size_A=size=(3,4,2)
A = np.random.randint(5, size=size_A)
size_B=(4,2,3)
B = np.random.randint(12, size=size_B)
axes1=[[1,2], [0,1]]

# axes1=[[2], [1]]
C=np.tensordot(A, B, axes1)
print(C)
[[ 76 100  76]
 [ 71  39  36]
 [112  94  94]]
def transe_shape(arr,axes,ff,size):
    nda=arr.ndim
    
    notin = [k for k in range(nda) if k not in list(axes)]
    
    def mul(a):
        d=1
        for p in a :
            d=d*size[p]
        return d
    axes=list(axes)
    
    axm=mul(axes)
    nom=mul(notin)
    if ff:
        newaxes = notin + axes
        
        return newaxes,notin,(nom,axm)
    else:
        newaxes =axes+ notin 
        return newaxes,notin,(axm,nom)
        

axes_A,notin_A,shape_fA=transe_shape(A,axes1[0],1,size_A)
# A1=A.transpose([1, 2, 0]).reshape((3*1,2))

A1=A.transpose(axes_A).reshape(shape_fA)

axes_B,notin_B,shape_fB=transe_shape(B,axes1[1],0,size_B)
# B1=B.transpose( [1, 0, 2]).reshape((2,3*2))
B1=B.transpose( axes_B).reshape(shape_fB)


#矩阵相乘
C1=A1@B1


out_shape=[size_A[k] for k in notin_A ] +[size_B[k] for k in notin_B ]
# print(C1.reshape((1,3,3,2))==C)
print(C1.reshape(out_shape)==C)
[[ True  True  True]
 [ True  True  True]
 [ True  True  True]]
posted @   luoganttcc  阅读(49)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek “源神”启动!「GitHub 热点速览」
· 微软正式发布.NET 10 Preview 1:开启下一代开发框架新篇章
· 我与微信审核的“相爱相杀”看个人小程序副业
· C# 集成 DeepSeek 模型实现 AI 私有化(本地部署与 API 调用教程)
· spring官宣接入deepseek,真的太香了~
点击右上角即可分享
微信分享提示