选择函数index_select

书中(pytorch入门实战)讲:index_select(input, dim, index),指定维度dim上选取,未有示例。

 

查到相关资料后,

 

import torch as t  # 导入torch模块
c = t.randn(3, 6) # 定义tensor
print(c)
b = t.index_select(c, 0, t.tensor([0, 2]))
"""
第一个参数c为被引用tensor,
第二个参数0表示按行索引,1表示按列索引
第三个参数是一个tensor,表示tensor索引的序号,eg:b里面tensor[0, 2]表示第0行和第2行
"""
print(b)
print(c.index_select(0, t.tensor([0, 2]))) # 从输出结果看,此用法与b中方法等价
c = t.index_select(c, 1, t.tensor([1, 3])) # 按列索引,第1列到第3列
print(c)

输出:

tensor([[-1.3710, 0.0348, -0.0733, 0.1358, 1.2035, -0.5978],
            [ 0.4770, -0.0906, -0.7095, 0.3073, 0.2640, 1.9909],
            [-1.3719, -0.4406, 1.3095, -0.4160, 0.0700, 1.2667]])


tensor([[-1.3710, 0.0348, -0.0733, 0.1358, 1.2035, -0.5978],
            [-1.3719, -0.4406, 1.3095, -0.4160, 0.0700, 1.2667]])


tensor([[-1.3710, 0.0348, -0.0733, 0.1358, 1.2035, -0.5978],
            [-1.3719, -0.4406, 1.3095, -0.4160, 0.0700, 1.2667]])


tensor([[ 0.0348, 0.1358],
            [-0.0906, 0.3073],
            [-0.4406, -0.4160]])

posted @ 2020-05-12 11:22  珠江水手  阅读(1249)  评论(0编辑  收藏  举报