代码笔记27 numpy和pytorch中的多维数组切片

原来还可以用数组切数组,我算是长见识了。不多说了,直接上代码应该可以明白

import numpy as np

xyz = np.arange(36).reshape(3, 4, 3)
B, N, C = xyz.shape
farthest = np.random.randint(0, N, size=B)  # torch.randint(0, N, (B,), dtype=torch.long) # 初始时随机选择一点 (B)
batch_indices = np.arange(B)  # (0-batch_size)的数组
centroid = xyz[batch_indices, farthest, :]
compare = xyz[:, farthest, :]
print("xyz:", xyz)
print("farthest:", farthest)
print("batch_indices:", batch_indices)
print("Two dimension slice:", centroid)

print("equivalent to:")
print(xyz[batch_indices[0],farthest[0],:])
print(xyz[batch_indices[1],farthest[1],:])
print(xyz[batch_indices[2],farthest[2],:])

print("One dimension slice:", compare)

最后mark一位解读PointNet++的博主,我也是看而有感
https://blog.csdn.net/weixin_42707080/article/details/105279415

posted @ 2023-05-01 21:31  The1912  阅读(65)  评论(0编辑  收藏  举报