Loading

题解-CTS2019 珍珠

题面

CTS2019 珍珠

\(n\) 个在 \([1,d]\) 内的整数,求使可以拿出 \(2m\) 个整数凑成 \(m\) 个相等的整数对的方案数。

数据范围:\(0\le m\le 10^9\)\(1\le n\le 10^9\)\(1\le d\le 10^5\)


蒟蒻语

非常巧妙的题,主要要用到二项式反演、指数级生成函数和 NTT

做个广告,这是我读过最好的生成函数讲解:link


蒟蒻解

\(c_i\) 表示 \(i\) 这个数的出现次数。

\(odd=\sum [c_i\in {\rm odd} ]\),即 \(c_i\) 奇数个数。

很明显最多能凑成 \(\frac{n-odd}{2}\) 对,按题意:

\[\begin{aligned} \frac{n-odd}{2}&\ge m\\ odd&\le n-2m \end{aligned} \]

这里有两个特判,如果 \(n-2m<0\) 答案是 \(0\),如果 \(n-2m\ge d\) 答案是 \(d^n\)

\(g(i)\) 表示 \(odd=i\) 的方案数。

\(f(i)\) 表示钦定 \(i\)\(c_i\) 是奇数,剩下随意的方案数(不是 \(odd\ge i\) 的方案数,这里对一些排列会重复统计,但是反演完就没事了)。

\[f(i)=\sum_{x=i}^d{x\choose i}g(x)\Longleftrightarrow g(i)=\sum_{x=i}^d(-1)^{x-i}{x\choose i} f(x) \]

所以可以先求 \(f(i)\),用到了指数级生成函数,中间把每个 \(e\) 的幂次项展开,最后归成卷积形式:

\[\begin{aligned} f(i)=&{d \choose i} \left(\frac{e^x-e^{-x}}{2}\right)^i e^{(d-i)x} n![n]\\ =&{d \choose i} (e^x-e^{-x})^i e^{(d-i)x} \frac{n!}{2^i}[n]\\ =&{d \choose i} e^{(d-i)x} \sum_{j=0}^i(-1)^{i-j} {i\choose j}e^{(j-(i-j))x} \frac{n!}{2^i}[n]\\ =&{d \choose i} e^{(d-i)x} \sum_{j=0}^i(-1)^{i-j} {i\choose j}e^{(2j-i)x} \frac{n!}{2^i}[n]\\ =&{d \choose i} \frac{n!}{2^i}\sum_{j=0}^i(-1)^{j} {i\choose j}e^{(d-2j)x} [n]\\ =&{d \choose i} \frac{n!}{2^i}\sum_{j=0}^i(-1)^{j} {i\choose j} \frac{(d-2j)^n}{n!}\\ =&\frac{d!}{i!(d-i)!2^i}\sum_{j=0}^i(-1)^{j} \frac{i!}{j!(i-j)!} (d-2j)^n\\ =&\frac{d!}{(d-i)!2^i}\sum_{j=0}^i\frac{(-1)^{j}(d-2j)^n}{j!}\cdot \frac{1}{(i-j)!} \\ \end{aligned} \]

最后一个难点是如何求:

\[g(i)=\sum_{x=i}^d(-1)^{x-i}{x\choose i} f(x) \]

感觉可以凑成卷积形式,但总差一点。尝试把 \(f\)\(g\) 都反过来,即令 \(f'(x)=f(d-x)\)\(g'(x)=g(d-x)\)

\[\begin{aligned} g(i)=&\sum_{x=i}^d(-1)^{x-i}{x\choose i} f(x)\\ =&\sum_{x=i}^d(-1)^{x-i}{x\choose i} f'(d-x)\\ =&\sum_{x=0}^{d-i}(-1)^{d-x-i}{d-x\choose i} f'(x)\\ \end{aligned}\\ g'(d-i)=\sum_{x=0}^{d-i}(-1)^{d-x-i}{d-x\choose i} f'(x)\\ \begin{aligned} g'(i)=&\sum_{x=0}^{i}(-1)^{i-x}{d-x\choose d-i} f'(x)\\ =&\sum_{x=0}^{i}(-1)^{i-x}\frac{(d-x)!}{(d-i)!(i-x)!} f'(x)\\ =&\frac{1}{(d-i)!}\sum_{x=0}^{i} (-1)^{i-x}\frac{1}{(i-x)!}\cdot f'(x)(d-x)!\\ \end{aligned}\\ \]

然后就可以求 \(g(i)\) 了。

答案便是 \(\sum_{i=0}^{n-2m} g(i)\)


代码

#include <bits/stdc++.h>
using namespace std;

//Start
typedef long long ll;
typedef double db;
#define mp(a,b) make_pair((a),(b))
#define x first
#define y second
#define bg begin()
#define ed end()
#define sz(a) int((a).size())
#define pb(a) push_back(a)
#define R(i,a,b) for(int i=(a),i##E=(b);i<i##E;i++)
#define L(i,a,b) for(int i=(b)-1,i##E=(a)-1;i>i##E;i--)
const int iinf=0x3f3f3f3f;
const ll linf=0x3f3f3f3f3f3f3f3f;

//Data
const int N=1e5;
int n,m,d,ans;

//Math
const int mod=998244353;
int Pow(int a,int x){
    int res=1; for(;x;x>>=1,a=1ll*a*a%mod)
    if(x&1) res=1ll*res*a%mod; return res;
}
const int mN=N+1;
int fac[mN],ifac[mN];
void math_init(){
    fac[0]=1;R(i,1,d+1) fac[i]=1ll*fac[i-1]*i%mod;
    ifac[d]=Pow(fac[d],mod-2);
    L(i,0,d) ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
}

//Poly
const int pN=mN<<2;
int f[pN],g[pN];
const int G=3,iG=Pow(3,mod-2);
int rev[pN],pn;
void poly_init(){
    pn=1<<int(ceil(log2(d*2+2)));
    R(i,0,pn) rev[i]=(rev[i>>1]>>1)|((i&1)*(pn>>1));
}
void NTT(int* a,int t){
    R(i,0,pn)if(i<rev[i]) swap(a[i],a[rev[i]]);
    for(int mid=1;mid<pn;mid<<=1){
        int wn=Pow(~t?G:iG,(mod-1)/(mid<<1));
        for(int i=0;i<pn;i+=(mid<<1)){
            int w=1;
            R(j,i,mid+i){
                int x=a[j],y=1ll*a[mid+j]*w%mod;
                a[j]=(x+y)%mod,a[mid+j]=(x-y+mod)%mod,w=1ll*w*wn%mod;
            }
        }
    }
    if(!~t){
        int in=Pow(pn,mod-2);
        R(i,0,pn) a[i]=1ll*a[i]*in%mod;
    }
}

//Main
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>d>>n>>m;
    if(n-2*m<0) return cout<<0<<'\n',0;
    if(d<=n-2*m) return cout<<Pow(d,n)<<'\n',0;
    math_init(),poly_init();
    R(i,0,d+1){
        f[i]=1ll*Pow((d-2*i+mod)%mod,n)*ifac[i]%mod;
        if(i&1) f[i]=(mod-f[i])%mod;
        g[i]=ifac[i];
    }
    R(i,d+1,pn) f[i]=g[i]=0;
    NTT(f,1),NTT(g,1);
    R(i,0,pn) f[i]=1ll*f[i]*g[i]%mod;
    NTT(f,-1);
    R(i,d+1,pn) f[i]=g[i]=0;
    R(i,0,d+1) f[i]=1ll*f[i]*fac[d]%mod*ifac[d-i]%mod*Pow(2,mod-1-i)%mod;
    reverse(f,f+d+1);
    R(i,0,d+1) f[i]=1ll*f[i]*fac[d-i]%mod;
    R(i,0,d+1){
        g[i]=ifac[i];
        if(i&1) g[i]=(mod-g[i])%mod;
    }
    R(i,d+1,pn) f[i]=g[i]=0;
    NTT(f,1),NTT(g,1);
    R(i,0,pn) g[i]=1ll*f[i]*g[i]%mod;
    NTT(g,-1);
    R(i,d+1,pn) f[i]=g[i]=0;
    R(i,0,d+1) g[i]=1ll*g[i]*ifac[d-i]%mod;
    reverse(g,g+d+1);
    R(i,0,n-2*m+1) (ans+=g[i])%=mod;
    cout<<ans<<'\n';
    return 0;
}

祝大家学习愉快!

posted @ 2020-10-15 20:08  George1123  阅读(190)  评论(1编辑  收藏  举报