计算正逆序对（正序对/逆序对数量及正逆序比 PNR）

 

1、利用 AUC 计算（AUC = 正序对 / 全部异标签对，因此 PNR = AUC / (1 − AUC)）

#encoding=utf8
from itertools import groupby
import sys
def calc_auc_and_pnr_fast(label, pred):
    """Compute AUC and the positive/negative-order ratio (PNR) in one pass.

    A (positive, negative) pair is "correctly ordered" when the positive has
    the higher score; a pair with equal scores counts as half correct and
    half wrong. AUC = correct / all pairs, PNR = correct / wrong.

    :param label: sequence of labels; 1 marks a positive, 0 a negative,
        any other value is ignored.
    :param pred: sequence of scores, same length as ``label``.
    :return: ``(auc, pnr)``. Both are NaN when there is no positive or no
        negative sample. ``pnr`` is ``inf`` when no pair is wrongly ordered.
    """
    # Sort by score, highest first, then handle each run of equal scores together
    # so ties are split 0.5/0.5 no matter how the tied samples are ordered.
    sample_sorted = sorted(zip(label, pred), key=lambda x: -x[1])
    pos = 0      # positives seen so far, all with a strictly higher score
    negs = 0     # negatives seen so far
    cnt = 0.0    # correctly ordered pairs (ties count 0.5)
    for _, group in groupby(sample_sorted, key=lambda x: x[1]):
        grp_pos = 0
        grp_neg = 0
        for l, _ in group:
            if l == 1:
                grp_pos += 1
            elif l == 0:
                grp_neg += 1
        # Each negative in this group ranks below every earlier positive and
        # ties with every positive in the same group.
        cnt += grp_neg * (pos + 0.5 * grp_pos)
        pos += grp_pos
        negs += grp_neg
    total = pos * negs
    if total == 0:
        # AUC/PNR are undefined without both classes.
        return float('nan'), float('nan')
    auc = cnt / total
    r_cnt = total - cnt  # wrongly ordered pairs (ties count 0.5)
    pnr = cnt / r_cnt if r_cnt else float('inf')
    return auc, pnr

if __name__ == '__main__':
    # Reads tab-separated "user<TAB>label<TAB>pred" lines from stdin.
    # groupby only merges consecutive lines, so the input must already be
    # grouped by user.
    for _user, lines in groupby(sys.stdin, key=lambda x: x.split('\t')[0]):
        fields = [x.strip().split('\t') for x in lines]
        trues = [float(f[1]) for f in fields]
        preds = [float(f[2]) for f in fields]
        auc, pnr = calc_auc_and_pnr_fast(trues, preds)
        # Sanity check: auc = 1 / (1 + 1 / pnr)  ==>  pnr = 1 / (1 / auc - 1).
        # The 1e-9 keeps auc == 1 finite; auc == 0 implies pnr == 0.
        if auc == 0:
            pnr_check = 0.0
        else:
            pnr_check = 1. / (1. / auc - 1 + 1e-9)
        # Same output as the Python 2 `print a, b, c`, but valid on Python 3 too.
        sys.stdout.write('%s %s %s\n' % (auc, pnr, pnr_check))

 

 

2、利用归并排序统计逆序对

https://www.jianshu.com/p/e9813ac25cb6

 

"""
inversecount
"""
from itertools import groupby
import sys


class InversionCounter(object):
    """Count correctly / wrongly ordered (label, score) pairs.

    Only pairs whose labels differ are considered. A pair is "right" when the
    sample with the larger label also has the larger score, and "wrong" when
    it has the smaller score. The strict counts come from a merge-sort
    inversion count in O(n log n).
    """

    @classmethod
    def merge_sort_count_sub(cls, vals):
        """Merge-sort ``vals`` by their second element and count inversions.

        :param vals: iterable of tuples; only ``v[1]`` is compared.
        :return: ``(sorted_list, inversions)``. ``sorted_list`` is ``vals``
            sorted ascending by ``v[1]``. ``inversions`` is the number of
            index pairs i < j with ``vals[i][1] > vals[j][1]``.
        """
        # Materialize once so any iterable (e.g. a Python 3 zip) works.
        vals = list(vals)
        n = len(vals)
        if n <= 1:
            return vals, 0

        mid = n // 2  # integer division: `n / 2` is a float on Python 3
        left_vals, left_cnt = cls.merge_sort_count_sub(vals[:mid])
        right_vals, right_cnt = cls.merge_sort_count_sub(vals[mid:])

        left_i = 0
        right_i = 0
        mid_cnt = 0
        new_vals = []
        while left_i < len(left_vals) and right_i < len(right_vals):
            if left_vals[left_i][1] <= right_vals[right_i][1]:
                new_vals.append(left_vals[left_i])
                left_i += 1
            else:
                # Every remaining left element is greater than this right one.
                # A plain `else` (not a second comparison) also guarantees
                # progress on unordered values such as NaN.
                mid_cnt += len(left_vals) - left_i
                new_vals.append(right_vals[right_i])
                right_i += 1
        new_vals.extend(left_vals[left_i:])
        new_vals.extend(right_vals[right_i:])

        return new_vals, left_cnt + mid_cnt + right_cnt

    @classmethod
    def merge_sort_count_strict_right(cls, trues, preds):
        """Count pairs with trues[i] < trues[j] and preds[i] < preds[j].

        Sorts by (label ascending, score descending), so equal-label samples
        never form an inversion. An inversion on the negated score between
        two different labels then means the larger label has the larger score.
        """
        # sorted() returns a new list; the previous `sorted(vals)` result was
        # discarded and consumed the Python 3 zip, yielding 0 every time.
        vals = sorted(zip(trues, [-p for p in preds]))
        return cls.merge_sort_count_sub(vals)[1]

    @classmethod
    def merge_sort_count_strict_wrong(cls, trues, preds):
        """Count pairs with trues[i] < trues[j] and preds[i] > preds[j].

        Sorts by (label ascending, score ascending), so equal-label samples
        never form an inversion. Each inversion on the score between two
        different labels is then a wrongly ordered pair.
        """
        vals = sorted(zip(trues, preds))
        return cls.merge_sort_count_sub(vals)[1]

    @classmethod
    def merge_sort_count_right(cls, trues, preds):
        """Count pairs with differing labels that are not strictly wrong.

        This is the strictly right pairs plus the pairs with tied scores.
        """
        return cls.merge_sort_count_pair(trues) - cls.merge_sort_count_strict_wrong(trues, preds)

    @classmethod
    def merge_sort_count_wrong(cls, trues, preds):
        """Count pairs with differing labels that are not strictly right.

        This is the strictly wrong pairs plus the pairs with tied scores.
        """
        return cls.merge_sort_count_pair(trues) - cls.merge_sort_count_strict_right(trues, preds)

    @classmethod
    def merge_sort_count_pair(cls, trues, preds=None):
        """Count the pairs of samples whose labels differ.

        preds: dummy variable, not used inside the function.
        """
        trues = sorted(trues)
        n = len(trues)
        acc_num = 0
        pair = 0
        for _, ks in groupby(trues):
            current_num = sum(1 for _ in ks)
            acc_num += current_num
            # Pair each sample of this label with every larger-label sample.
            pair += (n - acc_num) * current_num

        return pair

if __name__ == '__main__':
    # Reads tab-separated "user<TAB>label<TAB>pred" lines from stdin. Pairs
    # are formed only within one user's lines. groupby only merges
    # consecutive lines, so the input must already be grouped by user.
    right = 0.
    wrong = 0.
    for _user, lines in groupby(sys.stdin, key=lambda x: x.split('\t')[0]):
        fields = [x.strip().split('\t') for x in lines]
        trues = [float(f[1]) for f in fields]
        preds = [float(f[2]) for f in fields]
        right += InversionCounter.merge_sort_count_strict_right(trues, preds)
        wrong += InversionCounter.merge_sort_count_strict_wrong(trues, preds)
    # right / wrong is the overall PNR. Guard against having no wrong pairs
    # (e.g. a perfect ranking or empty input) instead of raising ZeroDivisionError.
    if wrong:
        ratio = right / wrong
    elif right:
        ratio = float('inf')
    else:
        ratio = float('nan')
    print (right, wrong, ratio)

 

posted @ 2022-05-12 16:50  乐乐章  阅读(217)  评论(0)  编辑  收藏  举报