随机选取算法 (有权重的记录中选取)
三类随机问题
1. 已有n条记录,从中选取m条记录,选取出来的记录前后顺序不管。
实现思路:按行遍历所有记录,约隔n/m条取一个数据即可
2. 在1类情况下,还要求选取出来的m条记录是随机排序的
实现思路: 给n条记录,分别增加一列标记,值为随机选取的1至n之间的不重复数据,
实现参考博文 将文件内容按行随机排列
3. 区别于1,2类问题, 如果记录是有权重的,如何结合权重去随机选取。 比如A的权重为10, B的权重股为5, C的权重为1, 则随机选取4个时可能应该出现AABB。
这第三类问题是本文重点,下面开始解决。
实现思路: 以 A:10, B:5, C:1 三条记录上随机选取4条为例,(是否以权重排序这个无所谓)
对于
A10
B5
C1
首先,将第n行的数值赋为第n行加第n-1行的,递归执行,如下:
A10
B15
C16
然后每次从[1,16]随机选取一个数,如果落在[1,10]之间,则选取A,如果落在(10,15]之间则选B,如果落在(16,16]之间则选取C, 图示如下,谁占的区间大(权重高),被选上的概率更大。
知道了思路,实现起来就比较方便了, 需要考虑的一点可能就是我随即选了一个数值,比如12,我怎么跟B对应上? 其实也比较简单,用二分法查找即可。
下面附上实现代码:
python 版本(myrandom.py)
#!/usr/bin/python import string import sys import os import random #[begin, end) def get_pos(wlist, begin, end, rnum): if begin >= end: return begin; mid = (begin+end)/2; if wlist[mid] >= rnum: return get_pos(wlist, begin, mid, rnum); else: return get_pos(wlist, mid+1, end, rnum); if __name__=='__main__': if len(sys.argv) < 3: print "help" exit(-1); orifile = sys.argv[1] samplenum = (int)(sys.argv[2]) #read input file allweight = 0; ruler_weights = [] ruler_keys = [] for line in open(orifile, "r"): lineinfo = line.rstrip().split('\t') key = lineinfo[0] weight = 1 if len(lineinfo) >= 2: weight = (int)(lineinfo[1]); allweight += weight; ruler_weights.append(allweight); #idx, weight ruler_keys.append(key); #idx, key pass #begin random allnum = len(ruler_weights); for i in range(1,samplenum): rnum = random.randint(1,allweight); #[1,allweight] pos = get_pos(ruler_weights, 0, allnum, rnum); #[0,allnum) key = ruler_keys[pos] print "%s\t%d\t%d"%(key, ruler_weights[pos], rnum); pass;
c版本
#include <string> #include <cstdlib> #include <vector> using namespace std; const int LEN = 4098; const int MAX_QUERY_LEN = 2048; //返回属于[p,q)的随机数 int rand(int p, int q) { int size = q-p+1; return p+ rand()%size; } //删除行尾换行符 int chomp(char *str) { int len = strlen(str); while(len > 0 && (str[len - 1] == '\n' || str[len - 1] == '\r')) { str[len - 1] = 0; len--; } return len; } //获取一个随机数会落在哪个区间 int get_pos(vector<int> vec_freq, int begin, int end, int rand_num) { if(begin >= end) { return begin; } int mid = (begin + end)/2; if( vec_freq[mid] >= rand_num ) { return get_pos(vec_freq, begin, mid, rand_num ); } else { return get_pos(vec_freq, mid+1, end, rand_num ); } } //主函数 int main(int argc, char *argv[]) { //输入记录文件,两列,第一列为记录,第二列为热度值 FILE* infile = fopen(argv[1], "r"); if( infile == NULL ) { printf("Cann't open file %s.", argv[1]); return -1; } FILE* outfile = fopen(argv[2], "w"); if( outfile == NULL) { printf("Cann't open file %s to write.", argv[2]); return -1; } //要获取的随机记录个数 int num = atoi(argv[3]); if( num <= 0) { printf("num [%s] <= 0."); return -1; } //这两个数组用下标关联 vector<string> vec_query; vector<int> vec_freq; vec_query.clear(); vec_freq.clear(); int freq = 0; char line[MAX_QUERY_LEN]={0}; while( !feof(infile) ) { if( !fgets(line, sizeof(line),infile)) { break; } line[sizeof(line)-1] = 0; chomp(line); char* p_tab = strchr(line, '\t'); if( NULL == p_tab ) { printf("line format error. [%s]\n"); continue; } *p_tab = 0; string query(line); freq += atoi(p_tab+1); vec_query.push_back(query); vec_freq.push_back(freq); //printf("%s\t%d\n", line, freq); } for(int i=0; i < num; ++i) { int rand_num = rand(1, freq+1); int pos = get_pos(vec_freq, 0, vec_freq.size(), rand_num); fprintf(outfile, "%s\t%d\t%d\n", vec_query[pos].c_str(), vec_freq[pos], rand_num); } fclose(infile); fclose(outfile); return 0; }
转载请注明出处: http://www.cnblogs.com/liyuxia713/