元素次数位图
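给定一个整型矩阵,对每个元素统计它在所在行中出现的次数,返回与输入同形状的矩阵。做法是先按取值展开成 0/1 位图(one-hot)张量再求和,全程尽量只用内建的张量运算(广播比较、求和),不写 Python 循环。由于中间张量的大小为 cls_num × batch_size × destN,元素取值区间 [0, cls_num) 越短,效率越高。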
import mxnet as mx;import numpy as np
def count_row(indices, cls_num):
    """Return, for every element, how many times its value occurs in its row.

    The result has the same shape as ``indices``: entry ``[b, j]`` is the
    number of positions ``k`` in row ``b`` with
    ``indices[b, k] == indices[b, j]``. Only built-in broadcasting and
    reduction operators are used (no Python loops). The intermediate 0/1
    tensor has shape ``(cls_num, batch_size, destN)``, so a small ``cls_num``
    keeps memory and time low.

    args:
        indices: (batch_size, destN) mx.NDArray of integer-valued entries.
            Values are expected to lie in ``[0, cls_num)``; a value outside
            that range matches no candidate and gets a count of 0.
        cls_num: int, exclusive upper bound of the values in ``indices``.
    return:
        mask: (batch_size, destN) mx.NDArray of counts, same dtype as
            ``indices``.
    """
    # Candidate values 0..cls_num-1, created with the same dtype and on the
    # same device as ``indices``. MXNet NDArray binary operators require
    # matching dtypes, so the float32 default of ``arange`` would break for
    # integer-typed ``indices``.
    values = mx.nd.arange(cls_num, ctx=indices.context, dtype=indices.dtype)
    values = values.reshape(shape=(-1, 1, 1))  # (cls_num, 1, 1)
    # one_hot[c, b, j] == 1 exactly when indices[b, j] == c.
    one_hot = mx.nd.broadcast_equal(
        mx.nd.expand_dims(indices, axis=0), values
    )  # (cls_num, batch_size, destN)
    # per_row[c, b, 0] is the number of occurrences of value c in row b.
    per_row = one_hot.sum(axis=-1, keepdims=True)  # (cls_num, batch_size, 1)
    # At each position only the slice of its own value is non-zero, so
    # summing over the candidate axis picks out that value's row count.
    mask = mx.nd.broadcast_mul(one_hot, per_row).sum(axis=0)  # (batch_size, destN)
    return mask
# Demo: a 12x10 matrix whose values are drawn from [0, cls_num); the same
# bound must be passed to count_row, hence the single named constant.
cls_num = 20
indices = mx.nd.array(np.random.randint(0, cls_num, (12, 10)))
mask = count_row(indices, cls_num)
# Single-argument print(...) is valid both as a Python 2 statement and as a
# Python 3 function call.
print(indices.asnumpy())
print(mask.asnumpy())