pytorch基础2
下面是常见函数的代码例子
1 import torch 2 import numpy as np 3 print("分割线-----------------------------------------") 4 #加减乘除操作 5 a = torch.rand(3,4) 6 b = torch.rand(4) 7 print(a) 8 print(b) 9 print(torch.add(a, b)) 10 print(torch.sub(a, b)) 11 print(torch.mul(a, b)) 12 print(torch.div(a, b)) 13 print(torch.all(torch.eq(a - b,torch.sub(a,b))))#判断torch的减法和python的减法结果是否一致 14 print("分割线-----------------------------------------") 15 #矩阵乘法(点乘和叉乘)matmul mm @ * 16 a = torch.ones(2,2)*3 17 b = torch.ones(2,2) 18 print(a*b)#点乘积 19 print(a.matmul(b))#叉乘积 20 print(a@b)#叉乘积 21 print(a.mm(b))#叉乘积,相比于前两种,这一种只能适合二维数组的乘积 22 a = torch.rand(4,3,28,64) 23 b = torch.rand(4,3,64,32) 24 #torch.mm(a,b).shape#此时会报错,mm只适合二维 25 print(torch.matmul(a,b).shape)#torch.Size([4, 3, 28, 32]) 26 b = torch.rand(4,1,64,32) 27 torch.matmul(a, b).shape #torch.Size([4, 3, 28, 32]) 28 b = torch.rand(4,64,32) 29 #torch.matmul(a, b).shape ,报错,因为b的4对应a的3无法进行广播,所以报错 30 print("分割线-----------------------------------------") 31 #power的使用 32 a = torch.full([2,2],3) 33 print(a.pow(2)) 34 print(a**2) 35 aa = a**2 36 print(aa.sqrt()) 37 print(aa**(0.5)) 38 print(aa.rsqrt())#开根号后的倒数 39 print("分割线-----------------------------------------") 40 #floor(),ceil(),round(),trunc(),frac()的使用 41 a = torch.tensor(3.14) 42 print(a.floor(),a.ceil(),a.trunc(),a.frac())#后两个是取整,和取小数 43 print(a.round()) #四舍五入 44 print("分割线-----------------------------------------") 45 #clamp 和 dim,keepdim 46 grad = torch.rand(2,3)*15 47 print(grad.max(),grad.median(), grad.min()) 48 print(grad) 49 print(grad.clamp(10))#小于10的都变成10 50 print(grad.clamp(3,10))#不在3到10之间的变为3或者10 51 a = torch.randn(4,10) 52 print(a.max(dim=1))#返回每一列最大值组成的数组和对应的下标 53 print(a.argmax(dim = 1))#这个只返回最大值对应的下标 54 print(a.max(dim=1,keepdim=True))#keepdim的作用是使返回值维度是否仍然为原来的维度不变 55 print(a.argmax(dim=1,keepdim=True)) 56 print("分割线-----------------------------------------") 57 #topk和kthvalue 58 print(a.topk(3,dim=1))#返回每行最大的三个数和下标 59 print(a.topk(3,dim=1,largest=False))#返回每行最小的三个数和下标 60 print(a.kthvalue(5,dim=1))#返回每行第五大的数字和对应数字的下标 61 print("分割线-----------------------------------------") 62 #矩阵的比较,cat的使用 63 m = torch.rand(2,2) 64 n = torch.rand(2,2) 65 print(m,n) 66 print(m == n) 67 print(m.eq(n)) 68 print(m > n) 69 a = torch.rand(4,32,8) 70 b = torch.rand(5,32,8) 71 print(torch.cat([a,b],dim = 0).shape)#按行进行拼接,torch.Size([9, 32, 8]) 72 #更详细的拼接可以看下面的图 73 a1 = torch.rand(4,3,32,32) 74 a2 = torch.rand(4,1,32,32) 75 #torch.cat([a1,a2],dim = 0).shape#报错原因是如果进行维度0上进行拼接,则要保证其他维度必须一致 76 a1 = torch.rand(4,3,14,32) 77 a2 = torch.rand(4,3,14,32) 78 print(torch.cat([a1,a2],dim=2).shape)#torch.Size([4, 3, 28, 32]) 79 print("分割线-----------------------------------------") 80 #stack和split的使用 81 #用来进行维度的扩充,这个就是在dim =2进行扩充 82 print(torch.stack([a1,a2],dim=2).shape)#torch.Size([4, 3, 2, 14, 32]) 83 aa , bb =a1.split([2,2],dim=0)#拆分成两份每份数目是2,2 84 print(aa.shape,bb.shape) 85 aaa,bbb = a1.split(2,dim=0)#每份长度为2 86 print(aaa.shape,bbb.shape) 87 aa,bb = a1.chunk(2,dim = 0)#拆成两块,每块一个 88 print("分割线-----------------------------------------") 89 #where,gather的使用 90 cond = torch.rand(2,2) 91 print(cond) 92 a = torch.zeros(2,2) 93 b = torch.ones(2,2) 94 s = torch.where(cond>0.5,a,b) 95 print(s)#如果大于0.5对应位置为a的对应位置的值,否则为b的对应位置的值 96 prob = torch.randn(4,10) 97 idx = prob.topk(dim =1,k=3) 98 id = idx[1] 99 print(id)#索引下标 100 label= torch.arange(10)+100 101 d = torch.gather(label.expand(4,10),dim =1,index = id)#获取对应索引下标的值 102 print(d)
运行结果如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 | D:\anaconda\anaconda\pythonw.exe D:/Code/Python/龙良曲pytorch学习/高级操作.py 分割线----------------------------------------- tensor([[0.5581, 0.2369, 0.1379, 0.3702], [0.1565, 0.1022, 0.5839, 0.1778], [0.0204, 0.1498, 0.5276, 0.4219]]) tensor([0.7969, 0.9313, 0.0608, 0.0245]) tensor([[1.3551, 1.1682, 0.1988, 0.3947], [0.9535, 1.0335, 0.6448, 0.2023], [0.8173, 1.0811, 0.5884, 0.4464]]) tensor([[-0.2388, -0.6944, 0.0771, 0.3457], [-0.6404, -0.8291, 0.5231, 0.1533], [-0.7766, -0.7815, 0.4667, 0.3974]]) tensor([[0.4448, 0.2206, 0.0084, 0.0091], [0.1247, 0.0952, 0.0355, 0.0044], [0.0162, 0.1395, 0.0321, 0.0103]]) tensor([[ 0.7003, 0.2544, 2.2669, 15.1075], [ 0.1964, 0.1097, 9.5973, 7.2539], [ 0.0255, 0.1609, 8.6706, 17.2148]]) tensor(True) 分割线----------------------------------------- tensor([[3., 3.], [3., 3.]]) tensor([[6., 6.], [6., 6.]]) tensor([[6., 6.], [6., 6.]]) tensor([[6., 6.], [6., 6.]]) torch.Size([4, 3, 28, 32]) 分割线----------------------------------------- tensor([[9., 9.], [9., 9.]]) tensor([[9., 9.], [9., 9.]]) tensor([[3., 3.], [3., 3.]]) tensor([[3., 3.], [3., 3.]]) tensor([[0.3333, 0.3333], [0.3333, 0.3333]]) 分割线----------------------------------------- tensor(3.) tensor(4.) tensor(3.) tensor(0.1400) tensor(3.) 分割线----------------------------------------- tensor(14.8811) tensor(8.5843) tensor(5.4463) tensor([[10.3914, 14.8811, 8.5843], [10.6012, 5.4463, 5.7588]]) tensor([[10.3914, 14.8811, 10.0000], [10.6012, 10.0000, 10.0000]]) tensor([[10.0000, 10.0000, 8.5843], [10.0000, 5.4463, 5.7588]]) torch.return_types.max( values=tensor([1.1859, 0.7394, 1.2261, 0.5407]), indices=tensor([5, 1, 2, 4])) tensor([5, 1, 2, 4]) torch.return_types.max( values=tensor([[1.1859], [0.7394], [1.2261], [0.5407]]), indices=tensor([[5], [1], [2], [4]])) tensor([[5], [1], [2], [4]]) 分割线----------------------------------------- torch.return_types.topk( values=tensor([[ 1.1859, 0.8406, 0.7883], [ 0.7394, 0.4172, 0.2871], [ 1.2261, 0.9851, 0.9759], [ 0.5407, 0.1773, -0.2789]]), indices=tensor([[5, 4, 7], [1, 2, 4], [2, 8, 4], [4, 1, 8]])) torch.return_types.topk( values=tensor([[-1.7351, -0.3469, -0.3116], [-1.8399, -1.1521, -0.3790], [-1.3753, -0.6663, -0.2762], [-1.6875, -1.5461, -0.9697]]), indices=tensor([[0, 8, 6], [3, 5, 0], [7, 1, 5], [0, 2, 3]])) torch.return_types.kthvalue( values=tensor([-0.1758, 0.0470, -0.2039, -0.6223]), indices=tensor([2, 9, 3, 6])) 分割线----------------------------------------- tensor([[0.9107, 0.4905], [0.6499, 0.3425]]) tensor([[0.6911, 0.9619], [0.1428, 0.5437]]) tensor([[False, False], [False, False]]) tensor([[False, False], [False, False]]) tensor([[ True, False], [ True, False]]) torch.Size([9, 32, 8]) torch.Size([4, 3, 28, 32]) 分割线----------------------------------------- torch.Size([4, 3, 2, 14, 32]) torch.Size([2, 3, 14, 32]) torch.Size([2, 3, 14, 32]) torch.Size([2, 3, 14, 32]) torch.Size([2, 3, 14, 32]) 分割线----------------------------------------- tensor([[0.7541, 0.3861], [0.9605, 0.7175]]) tensor([[0., 1.], [0., 0.]]) tensor([[5, 4, 6], [8, 2, 3], [8, 6, 4], [6, 2, 1]]) tensor([[105, 104, 106], [108, 102, 103], [108, 106, 104], [106, 102, 101]]) Process finished with exit code 0 |
作者:你的雷哥
本文版权归作者和博客园共有,欢迎转载,但未经作者同意必须在文章页面给出原文连接,否则保留追究法律责任的权利。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列:基于图像分类模型对图像进行分类
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 25岁的心里话
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 闲置电脑爆改个人服务器(超详细) #公网映射 #Vmware虚拟网络编辑器
· 一起来玩mcp_server_sqlite,让AI帮你做增删改查!!
· 零经验选手,Compose 一天开发一款小游戏!
2018-11-08 实现迪杰斯特拉算法