loj3315 「ZJOI2020」抽卡

  • 前置知识: \(min-max\) 容斥(期望形式):\(E(\min(S))=\sum_{T \subseteq S}(-1)^{|T|+1}E(\max(T))\)\(\min(S)=\sum_{T \subseteq S}(-1)^{|T|+1}\max(T)\)。本题中的 \(E(\min(S))\) 表示 \(S\) 中至少有一段长度为 \(k\) 的连续数字都出现的期望时间,\(E(\max(T))\) 表示 \(T\) 中每一段长度为 \(k\) 的连续数字都出现的期望时间( \(S,T\) 为段的集合)。

  • 我们可以发现一些性质:如果说我选了两个 \(len=k\) 的段,它们有交集(1.)或者紧挨着一起(2.),并且第二段的开头和第一段的开头不相邻,那么它们的贡献实际上是没有的。如下图:

img

  • 为啥呢,由于 \(i..i+k-1\) 是连续的,\(i+t..i+t+k-1\) 也是连续的,那么 \(i..i+t+k-1\) 也是连续的(\(t\leq k\))。我们考虑这样:

img

  • 如果我们固定最上面的和最下面的必选,那么相当于固定了 \(min-max\) 容斥中的 \(\max(S)\),那么中间的 \(i+1..i+t-1\) 开头的这些段选或不选对这个 \(\max\) 是没有影响的,影响的只是前面的容斥系数。
  • \(i+1..i+t-1\) 开头的这些段选或不选的容斥系数之和,通过二项式定理,应该为 \(\binom{t-1}0(-1)^3+\binom{t-1}1(-1)^4+\binom{t-1}2(-1)^5+\cdots+\binom{t-1}{t-1}(-1)^{t+2}\),也就等于 \((-1)^3(1-1)^{t-1}=0\)。(每一项系数是 \((-1)^{|T|+1}\)
  • 要注意 \(t=1\) 的时候是不能被抵消的。也就是只有可能是单独的 \(len=k\) 的或开头连续的 \(len=k+1\) 的(也就是两个 \(len=k\) 开头连续的拼起来)。且所有段不能相交,两个 \(len=k\) 的段之间不能连续。
  • 现在考虑连续的 \(len=k+1\) 的段和 \(len=k\) 的段/ \(len=k\) 的段和 \(len=k+1\) 的段有没有贡献,这样的话相当于求长度为 \(2k+1\) 的会不会全被抵消:

img

  • 显然,我们发现它的贡献并不能被消掉,因为 \(i+k\) 这个数它是必须被选的,也就是右端点在 \(i+k..i+2k-1\) 之间的段必须选择一个。
  • 如果这些段都可以选或不选,通过二项式定理,那么贡献应该为 \(0\)
  • 但是必须选一个的话,那么 \(\binom k0(-1)^3\) 这一项消失了,那么贡献应该为 \(1\)。也就是我们可以强制钦定 \(k+1/k\)\(k/k+1\)\(k+1/k\) 有贡献。
  • 再考虑两个连续的 \(len=k+1\) 的段有没有贡献。同样的:

img

  • 情况1:右端点在 \(i+k+1..i+2k-1\) 之间的段必须选择一个。情况2:右端点在 \(i+k\)\(i+2k\) 都选了。

  • 两种情况满足其一才能算入答案。若均不满足,也就是 \(i+k+1..i+2k-1\) 都没选, \(i+k\)\(i+2k\) 只选了一个或都没选。

  • 如果这些段都可以选或不选,通过二项式定理,那么贡献应该为 \(0\)。全不满足的情况:

    1. \(i+k..i+2k\) 都没选,其它都选了,贡献 \((-1)^3=-1\)
    2. \(i+k..i+2k-1\) 都没选,其它都选了,贡献 \((-1)^4=1\)
    3. \(i+k+1..i+2k\) 都没选,其它都选了,贡献 \((-1)^4=1\)
  • 故总贡献即为 \(-1\),也就是 \((-1)^5\)。(而 \(k+1\) 本来就是长度为 \(k\)\(4\) 个段,贡献即为 \(-1\)

  • 所以现在可以转化为,只选 \(len=k/k+1\) 的段,其中 \(len_1=k,len_2=k/k+1\) 的两个段不能的开头不能相邻,所有的段不能相交,问方案数。

  • 我们可以在原序列中选出一些不相交的长为 \(k+1\) 的段,然后考虑段中最后一个位置选/不选。可以发现,这样选出来不会出现 \(len_1=k,len_2=k/k+1\) 的开头相邻的两个段。记长度为 \(n\) 的连续段,这个算出来的生成函数为 \(g_n(x)\)

  • 但是有个问题,如果我选择了 \(n-k+1..n\) 这个长度为 \(k\) 的段,它不能通过 \(n-k+1..n+1\) 这个段然后不选 \(n+1\) 来得到,因为 \(n+1\) 越界了。但是我们会发现,如果选了这个段之后,就必定不能选 \(n-2k+1..n-k\) 这个长度为 \(k\) 的段了,这个的生成函数恰好为 \(-x^kg_{n-k}(x)\)

  • \(f_n(x)\) 表示长度为 \(n\) 的连续段的答案生成函数。有

\[f_n(x)=g_n(x)-x^kg_{n-k}(x)\\ g_n(x)=\sum_{i=0}^{\lfloor \frac n{k+1}\rfloor} \binom{n-ik}{i} (x^{k+1}-x^k)^i \]

  • 暴力计算时间复杂度为 \(\sum\limits_{i=0}^{\lfloor \frac n{k+1}\rfloor}i=O(\frac{n^2}{k^2})\)。然后它就水过了,水过了。。。。
  • 当然,算 \(g_n\) 的这个可以优化,我们考虑通过分治来计算。先令 \(val_i=\binom{n-ik}{i}\),则有 \(ans(l,r)=\sum\limits_{i=l}^r val_i (x^{k+1}-x^k)^i=ans(l,mid)+ans(mid+1,r)\times(x^{k+1}-x^k)^{mid-l+1}\),其中 \((x^{k+1}-x^k)^y=(x-1)^yx^{ky}\),而计算 \((x-1)^y\) 的时间复杂度为 \(O(y)\),故这个分治的总时间复杂度为 \(T(l)=2T(\frac l2)+O(lk \log lk)\),而这里的 \(l\) 为原先的 \(\lfloor \frac n{k+1}\rfloor\),把 \(k\) 提出来可得时间复杂度为 \(O(k\times \frac nk\log^2 \frac nk+k\times\frac nk\log \frac nk\log k)\)。设 \(n,k\) 同阶,故总时间复杂度为 \(O(n \log^2 n)\)
  • 最后把这些生成函数乘起来,可以用分治 \(NTT\),我这里用了一个类似于合并果子的东西,每次取次数最小的两个多项式乘起来。故总时间复杂度为 \(O(m\log^2 m)\)
#include<cstdio>
#include<vector>
#include<algorithm>
#include<queue>
using namespace std;
typedef long long ll;
typedef vector<int> vec;
typedef pair<int,int> pii;
const int Mod=998244353;
const int G=3; const int invG=(Mod+1)/3;
int m,k,a[210000],Ans[210000];
int fac[210000],inv[210000],invfac[210000];
vec poly[210000],v; int cnt;
priority_queue<pii,vector<pii>,greater<pii> > que;
inline int add(int x,int y){ return x+y>=Mod?x+y-Mod:x+y;}
inline int dec(int x,int y){ return x-y<0?x-y+Mod:x-y;}
inline int mul(int x,int y){ return 1ll*x*y%Mod;}
char Getchar(){
    static char now[1<<20],*S,*T;
    if (T==S){
        T=(S=now)+fread(now,1,1<<20,stdin);
        if (T==S) return EOF;
    }
    return *S++;
}
int read(){
    int x=0,f=1;
    char ch=Getchar();
    while (ch<'0'||ch>'9'){
        if (ch=='-') f=-1;
        ch=Getchar();
    }
    while (ch<='9'&&ch>='0') x=x*10+ch-'0',ch=Getchar();
    return x*f;
}
ll qpow(ll x,ll a){
    ll res=1;
    while (a){
        if (a&1) res=res*x%Mod;
        x=x*x%Mod; a>>=1;
    }
    return res;
}
inline ll getinv(int x){ return qpow(x,Mod-2);}
int rev[1100000];
int GPow[2][19][1100000];
void initG(){
    for (int p=1;p<=18;p++){
        int buf1=qpow(G,(Mod-1)/(1<<p));
        int buf0=qpow(invG,(Mod-1)/(1<<p));
        GPow[1][p][0]=GPow[0][p][0]=1;
        for (int i=1;i<(1<<p);i++){
            GPow[1][p][i]=mul(GPow[1][p][i-1],buf1);
            GPow[0][p][i]=mul(GPow[0][p][i-1],buf0);
        }
    }
}
void NTT(vec &a,int len,int inv){
    a.resize(len);
    for (int i=0;i<len;i++)
        if (i<rev[i]) swap(a[i],a[rev[i]]);
    for (int l=2,cnt=1;l<=len;l<<=1,cnt++){
        int m=l>>1;
        for (int i=0;i<len;i+=l){
            int *buf=GPow[inv][cnt];
            for (int j=0;j<m;j++,buf++) {
                int x=a[i+j],y=1ll*(*buf)*a[i+j+m]%Mod;
                a[i+j]=add(x,y),a[i+j+m]=dec(x,y);
            }
        }
    }
    if (inv!=1){
        ll inv=getinv(len);
        for (int i=0;i<len;i++) a[i]=mul(a[i],inv);
    }
}
void mult(vec &a,vec &b){
	int n=(int)a.size()+(int)b.size()-1;
    int bit=0; while ((1<<bit)<n) bit++;
    int len=1<<bit;
    for (int i=0;i<len;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    NTT(a,len,1); NTT(b,len,1);
    a.resize(len);
    for (int i=0;i<len;i++) a[i]=mul(a[i],b[i]);
    NTT(a,len,0);
}

vec tmp3;
int val[210000];
inline int C(int x,int y){
	if (x<y) return 0;
	return mul(fac[x],mul(invfac[y],invfac[x-y]));
}
vec solve(int l,int r){
	if (l==r){
		vec tmp;
		tmp.resize(1); tmp[0]=val[l];
		return tmp;
	}
	int mid=(l+r)>>1;
	vec tmp1=solve(l,mid),tmp2=solve(mid+1,r);
	int len=mid-l+1,t=k*len;
	tmp3.clear(); tmp3.resize(t+len+1);
	for (int i=0;i<=len;i++)
		if ((len-i)&1) tmp3[i+t]=dec(0,C(len,i));
		else tmp3[i+t]=C(len,i);
	mult(tmp3,tmp2);
	tmp3.resize(max(tmp3.size(),tmp1.size()));
	for (int i=0;i<(int)tmp1.size();i++) tmp3[i]=add(tmp3[i],tmp1[i]);
	return tmp3;
}
vec v1,v2;
void getans(int n){
	if (n<k) return;
	int i,s;
	for (i=0,s=0;s+i<=n;i++,s+=k) val[i+1]=C(n-s,i);
	v1=solve(1,i);
	for (i=0,s=k;s+i<=n;i++,s+=k) val[i+1]=C(n-s,i);
	v2=solve(1,i);
	v1.resize(max(v1.size(),v2.size()+k));
	for (int i=0;i<(int)v2.size();i++) v1[i+k]=dec(v1[i+k],v2[i]);
	while ((int)v1.size()>1&&!v1.back()) v1.pop_back();
	poly[++cnt]=v1;
}
int main(){
	m=read(); k=read(); initG();
	fac[0]=1; for (int i=1;i<=m+1;i++) fac[i]=mul(fac[i-1],i);
	inv[1]=1; for (int i=2;i<=m+1;i++) inv[i]=mul((Mod-Mod/i),inv[Mod%i]);
    for (int i=1;i<=m;i++) Ans[i]=add(Ans[i-1],mul(m,inv[i]));
	invfac[0]=1; for (int i=1;i<=m+1;i++) invfac[i]=mul(invfac[i-1],inv[i]);
	for (int i=1;i<=m;i++) a[i]=read();
	sort(a+1,a+m+1);
	for (int i=1,j=0;i<=m;i=j+1){
		j=i;
		while (j<m&&a[j+1]==a[j]+1) j++;
		getans(j-i+1);
	}
	for (int i=1;i<=cnt;i++) que.push(pii((int)poly[i].size(),i));
	while (que.size()>1){
		int x=que.top().second; que.pop();
		int y=que.top().second; que.pop();
		mult(poly[x],poly[y]);
		que.push(pii((int)poly[x].size(),x));
	}
	int x=que.top().second,ans=0;
	for (int i=1;i<(int)poly[x].size()&&i<=m;i++) ans=add(ans,mul(Ans[i],dec(0,poly[x][i])));
	printf("%d\n",ans);
	return 0;
}
posted @ 2020-08-05 21:10  hydd  阅读(200)  评论(0编辑  收藏  举报