蓄水池抽样(转)
Init : a reservoir with the size: k
for i= k+1 to N
M=random(1, i);
if( M < k)
SWAP the Mth value and ith value
end for
网上有人给出了证明,先转过来:
【转】
证明:
每次都是以 k/i 的概率来选择
例: k=1000的话,从1001开始作选择,1001被选中的概率是1000/1001,1002被选中的概率是1000/1002,与我们直觉是相符的。
接下来证明:
假设当前是i+1, 按照我们的规定,i+1这个元素被选中的概率是k/i+1,也即第 i+1 这个元素在蓄水池中出现的概率是k/i+1
此时考虑前i个元素,如果前i个元素出现在蓄水池中的概率都是k/i+1的话,说明我们的算法是没有问题的。
对这个问题可以用归纳法来证明:k < i <=N
1.当i=k+1的时候,蓄水池的容量为k,第k+1个元素被选择的概率明显为k/(k+1), 此时前k个元素出现在蓄水池的概率为 k/(k+1), 很明显结论成立。
2.假设当 j=i 的时候结论成立,此时以 k/i 的概率来选择第i个元素,前i-1个元素出现在蓄水池的概率都为k/i。
证明当j=i+1的情况:
即需要证明当以 k/i+1 的概率来选择第i+1个元素的时候,此时任一前i个元素出现在蓄水池的概率都为k/(i+1).
前i个元素出现在蓄水池的概率有2部分组成, ①在第i+1次选择前得出现在蓄水池中,②得保证第i+1次选择的时候不被替换掉
①.由2知道在第i+1次选择前,任一前i个元素出现在蓄水池的概率都为k/i
②.考虑被替换的概率:
首先要被替换得第 i+1 个元素被选中(不然不用替换了)概率为 k/i+1,其次是因为随机替换的池子中k个元素中任意一个,所以不幸被替换的概率是 1/k,故
前i个元素中任一被替换的概率 = k/(i+1) * 1/k = 1/i+1
则没有被替换的概率为: 1 - 1/(i+1) = i/i+1
综合① ②,通过乘法规则
得到前i个元素出现在蓄水池的概率为 k/i * i/(i+1) = k/i+1
故证明成立
对于抽样问题,最近看见了一些方法,做个总结:
问题:要求从1,2,3..n中,以等概率的方式,抽取m个元素
1、使用上面的蓄水池抽样
void sample_pool(const int N, const int m)
{
int i, tmp,rd;
int* x = new int[N];
for(i = 0 ; i < N ; i ++)
x[i] = i + 1;
for(i = m ; i < N; i ++ )
{
rd = rand()%i;
if(rd < m)
swap(x[i],x[rd]);
}
for(i = 0 ; i < m; i ++)
cout<<x[i]<<" ";
delete []x;
x = NULL;
}空间和时间均为O(N)
2 、从N个中选取m个, 可以先确定一个后,然后从身下的N-1个中选取m-1个出来。
void sample_rand(const int N,const int m)
{
int select = m,i,rd;
int remain = N;
for(i = 0; i < N ; i++)
{
rd = rand()%remain;
if(rd < select)
{
cout<< i<<" ";
select--;
}
remaining--;
}
}
上面这个方法非常经典,是Knuth在the art of computer programming中提出的。使用的额外空间为O(1),时间为O(N)。其概率的证明也是非常简单的。简单推到可发现,是等概率选择每个元素的。而且,最后肯定会选择刚刚好m个元素,前面一直没选择的话,则当remaining == select时,就会都选择。
3、将抽样的看成是一个集合,则要从N中选择出m个不同的元素,存入到集合中,可用set来完成
利用STL中的set来完成这个功能。
void sample_set(const int N,const int m)
{
set<int>s;
while(s.size()<m)
{
s.insert(rand()%n);
}
for(set<int>::iterator it = s.begin();it!=s.end();it++)
cout<<*it<<" ";
}
4、扰乱一个递增序列。
for i =[0,N)
swap(x[i],x[rand(i,n-1)];
有人证明,只要扰乱前m个就可以。
void sample_shuf(const int N,const int m)
{
int i, j;
int *x = new int[N];
for(i = 0 ; i <N; i++) x[i]=i+1;
for(i = 0 ; i < m ; i ++)
{
j = rand(i,n-1);
swap(x[i],x[j]);
}
sort(x,x+m);
Print(x,m);
delete []x;x= NULL;
}