元素次数位图

给定一个整型矩阵,用位图表示每行中各元素出现的频数。主要是要尽可能使用内建函数。此处的程序要求元素生成的区间较短才好。

import mxnet as mx;import numpy as np
def count_row(indices, cls_num):
    """
        return same shape of indices, where each element represents how many times of the corresponding element in the indices appearances in that row.
        args:
            indices: (batch_size, destN) mx.NDArray of int
            cls_num:    int
        return:
            mask:   (batch_size, destN) mx.NDArray of count
    """
    indices_tmp = mx.nd.expand_dims(indices, axis=0) # (1, batch_size, destN)
    i = mx.nd.arange(cls_num, ctx=indices.context).reshape(shape=(-1,1,1)) # (cls_num, 1,  1)
    z= mx.nd.broadcast_sub(indices_tmp,i)==0 # (cls_num, batch_size, destN)
    tmp = z.sum( axis=-1, keepdims=True) # (cls_num, batch_size, 1)

    mask = mx.nd.broadcast_mul(z, tmp).sum(axis=0)
    return mask

indices= mx.nd.array( np.random.randint(0,20, (12, 10)) )
mask = count_row(indices, 20)
print indices.asnumpy(), mask.asnumpy()

posted @ 2018-03-16 21:06  rotxin  阅读(279)  评论(0编辑  收藏  举报