代码笔记27 numpy和pytorch中的多维数组切片
原来还可以用数组切数组,我算是长见识了。不多说了,直接上代码应该可以明白
import numpy as np
xyz = np.arange(36).reshape(3, 4, 3)
B, N, C = xyz.shape
farthest = np.random.randint(0, N, size=B) # torch.randint(0, N, (B,), dtype=torch.long) # 初始时随机选择一点 (B)
batch_indices = np.arange(B) # (0-batch_size)的数组
centroid = xyz[batch_indices, farthest, :]
compare = xyz[:, farthest, :]
print("xyz:", xyz)
print("farthest:", farthest)
print("batch_indices:", batch_indices)
print("Two dimension slice:", centroid)
print("equivalent to:")
print(xyz[batch_indices[0],farthest[0],:])
print(xyz[batch_indices[1],farthest[1],:])
print(xyz[batch_indices[2],farthest[2],:])
print("One dimension slice:", compare)
最后mark一位解读PointNet++的博主,我也是看而有感
https://blog.csdn.net/weixin_42707080/article/details/105279415