HDU-6333 Problem B. Harvest of Apples 莫队

HDU-6333

 

题意:

有n个不同的苹果,你最多可以拿m个,问有多少种取法,多组数据,组数和n,m都是1e5,所以打表也打不了。

思路:

这道题要用到组合数的性质,记S(n,m)为从n中最多取m个的方法总数,显然是C(n,0),C(n,1)……C(n,m)的和。

显然S(n,m+1) = S(n, m) + C(n,m+1);

还有一个等式就不那么明显了,S(n+1,m) = 2 * S(n,m) - C(n,m);

我也是在王神犇的指导下明白的。

 

既然知道了一组(n,m)是可以在很快的时间下转移到(n+1,m),(n-1,m),(n,m+1),(n,m-1)的,这个时候就要想到莫队。把每一组的n和m转化成区间的右端点和左端点,是不是很神奇。

那如何求组合数C(n,m)?可以先预处理出n的前缀阶乘,每次除一下,就可以得到,当然,因为是取模意义下,这里除一下就要去乘以这个数的逆元。

 

 

这题还有个细节就是要先更新作为n的右端点,为了防止右端点小于左端点的情况出现,即n 比 m 小。

 

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <string>
#include <vector>
#include <map>
#include <set>
#include <queue>
#include <list>
#include <cstdlib>
#include <iterator>
#include <cmath>
#include <iomanip>
#include <bitset>
#include <cctype>
#include <iostream>
using namespace std;
//#pragma comment(linker, "/STACK:102400000,102400000")  //c++
#define lson (l , mid , rt << 1)
#define rson (mid + 1 , r , rt << 1 | 1)
#define debug(x) cerr << #x << " = " << x << "\n";
#define pb push_back
#define pq priority_queue

typedef long long ll;
typedef unsigned long long ull;

typedef pair<ll ,ll > pll;
typedef pair<int ,int > pii;

//priority_queue<int> q;//这是一个大根堆q
//priority_queue<int,vector<int>,greater<int> >q;//这是一个小根堆q
#define fi first
#define se second
// #define endl '\n'

#define OKC ios::sync_with_stdio(false);cin.tie(0)
#define FT(A,B,C) for(int A=B;A <= C;++A)  //用来压行
#define REP(i , j , k)  for(int i = j ; i <  k ; ++i)
//priority_queue<int ,vector<int>, greater<int> >que;

const ll mos = 0x7FFFFFFF;  //2147483647
const ll nmos = 0x80000000;  //-2147483648
const int inf = 0x3f3f3f3f;
const ll inff = 0x3f3f3f3f3f3f3f3f; //18

template<typename T>
inline T read(T&x){
    x=0;int f=0;char ch=getchar();
    while (ch<'0'||ch>'9') f|=(ch=='-'),ch=getchar();
    while (ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
    return x=f?-x:x;
}
// #define _DEBUG;         //*//
#ifdef _DEBUG
freopen("input", "r", stdin);
// freopen("output.txt", "w", stdout);
#endif
/*-----------------------show time----------------------*/
                #define bel(x) ((x-1)/B+1)
                const int maxn = 1e5+9;
                const int B = 233;
                const int MOD = 1e9+7;
                ll ans[maxn];
                struct node {
                    int n,m;
                    int id;
                } p[maxn];

                ll X,Y;
                void exgcd(ll a,ll b){
                    if(b==0){
                        X = 1;Y = 0;
                        return;
                    }
                    exgcd(b,a%b);
                    ll tmp = X;
                    X = Y;
                    Y = tmp - a/b*Y;
                }

                bool cmp(const node &a,const node &b){
                    if(bel(a.m) == bel(b.m))
                        return a.n < b.n;
                    return bel(a.m) < bel(b.m);
                }
                ll pm[maxn],TWO,NY[maxn];       //NY是预处理的逆元,不出来会TLE
                void init(){
                    pm[0] = 1;
                    exgcd(1,MOD);
                    NY[0] = (X + MOD)%MOD;
                    for(int i=1; i<maxn; i++){
                        pm[i] = (pm[i-1] * i + MOD) % MOD;
                        exgcd(pm[i],MOD);
                        NY[i] = (X + MOD)%MOD;
                    }
                    
                    exgcd(2,MOD);
                    TWO = (X + MOD)%MOD;
                }
                ll get(int n, int x){
                        if(n-x < 0)return 0;
                        ll res = (pm[n] *NY[n-x])%MOD;
                        res = (res * NY[x])%MOD;
                        return res;
                }
                ll sum = 0;
                void del1(int x,int n){
                    sum =(sum- get(n,x)+MOD)%MOD;
                }

                void add1(int x,int n){
                    sum=(sum + get(n,x) +MOD)%MOD;
                }

                void del2(int x,int n){ 
                    sum = ((sum + get(n,x))*TWO + MOD)%MOD;
                }

                void add2(int x,int n){
                    sum = (sum * 2 - get(n,x)+MOD)%MOD;
                }

int main(){
                init();
                int q;
                scanf("%d", &q);
                for(int i=1; i<=q; i++){
                    scanf("%d%d",&p[i].n,&p[i].m);
                    p[i].id = i;
                }
                sort(p+1,p+1+q,cmp);
                int pl = p[1].m, pr = p[1].n;
                for(int i=0; i<=pl; i++)
                {
                    sum = (sum + get(pr,i) + MOD)%MOD;
                }
                ans[p[1].id] = sum;
               // cout<<"**"<<endl;
                for(int i=2; i<=q; i++){
                    while(pr < p[i].n) add2(pl,pr),pr++;//这里要先更新作为n的右区间,防止m>n;
                    while(pr > p[i].n) pr--,del2(pl,pr);

                    while(pl < p[i].m) pl++,add1(pl,pr);
                    while(pl > p[i].m) del1(pl,pr),pl--;
                    ans[p[i].id] = sum%MOD;
                }

                for(int i=1; i<=q; i++){
                    printf("%lld\n", ans[i]);
                }
        return 0;
}
HDU6333

 

posted @ 2018-08-07 20:59  ckxkexing  阅读(165)  评论(0编辑  收藏  举报