codeforces 439 E. Devu and Birthday Celebration 组合数学 容斥定理

题意:

q个询问,每一个询问给出2个数sum,n

1 <= q <= 10^5, 1 <= n <= sum <= 10^5

对于每一个询问,求满足下列条件的数组的方案数

1.数组有n个元素,ai >= 1

2.sigma(ai) = sum

3.gcd(ai) = 1

 

solution:

这道题的做法类似bzoj2005能量采集

f(d) 表示gcd(ai) = d 的方案数

h(d) 表示d|gcd(ai)的方案数

令ai = bi * d

则有sigma(bi) = sum / n

  d | gcd(ai)

还要满足bi >= 1

则显然有h(d) = C(sum / d - 1,n - 1)

    h(d) = f(d) + f(2d) + ... + f(d_max)

 

这里的d满足:

1.d是sum 的约数

2.sum / d >= n

则f(d) = h(d) - sigma(f(j)) ,2d <=j<=sum/n

倒序遍历d

ans = f(1)

 

由于询问的次数太多,每次询问后,可以把(sum,n)放入map中,记录下来

 

                                            
  //File Name: cf439E.cpp
  //Author: long
  //Mail: 736726758@qq.com
  //Created Time: 2016年02月17日 星期三 14时58分16秒
                                   

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <map>
#include <cmath>
#include <cstdlib>
#include <vector>

#define LL long long
#define pb push_back

using namespace std;

const int MAXN = 1e5+5;
const int MOD = 1e9+7;

LL f[MAXN];
LL jie[MAXN];
bool is[MAXN];
vector<int> dive;
map< pair<int,int>,int > rem; 

void init()
{
    jie[0] = 1;
    for(int i=1;i<MAXN;i++){
        jie[i] = jie[i-1] * i % MOD;
    }
    rem.clear();
}

void get_dive(int sum,int n)
{
    int e = (int)sqrt(sum + 0.0);
    dive.clear();
    int j;
    for(int i=1;i<=e;i++){
        if(sum % i == 0){
            if(sum / i >= n)
                dive.pb(i);
            j = sum / i;
            if(j != i && sum / j >= n)
                dive.pb(j);
        }
    }
    sort(dive.begin(),dive.end());
    for(int i=0;i<dive.size();i++){
        is[dive[i]] = true;
    }
}

LL qp(LL x,LL y)
{
    LL res = 1LL;
    while(y){
        if(y & 1)
            res = res * x % MOD;
        x = x * x % MOD;
        y >>= 1;
    }
    return res;
}

LL comb(int x ,int y)
{
    if(y < 0 || y > x)
        return 0;
    if(y == 0 || y == x)
        return 1;
    return jie[x] * qp(jie[y] * jie[x-y] % MOD,MOD  - 2) % MOD;
}

void solve(int sum,int n)
{
    map< pair<int,int>,int >::iterator it;
    it = rem.find(make_pair(sum,n));
    if(it != rem.end()){
        printf("%d\n",(int)(it->second));
        return ;
    }
    memset(f,0,sizeof f);
    memset(is,false,sizeof is);
    get_dive(sum,n);
    int ma = dive.size();
    for(int i=ma-1;i>=0;i--){
        int d = dive[i];
        f[d] = comb(sum / d - 1,n - 1);
        for(int j=2*d;j<=dive[ma-1];j+=d){
            if(is[j]){
                f[d] = ((f[d] - f[j] + MOD) % MOD + MOD) % MOD;
            }
        }
    }
    printf("%d\n",(int)f[1]);
    rem[make_pair(sum,n)] = f[1];
    return ;
}

int main()
{
    init();
    int test;
    scanf("%d",&test);
    while(test--){
        int sum,n;
        scanf("%d %d",&sum,&n);
        solve(sum,n);
    }
    return 0;
}

 

posted on 2016-06-03 17:42  _fukua  阅读(478)  评论(0编辑  收藏  举报