Loj #6503. 「雅礼集训 2018 Day4」Magic

Loj #6503. 「雅礼集训 2018 Day4」Magic

题目描述

前进!前进!不择手段地前进!——托马斯 · 维德

魔法纪元元年。

1453 年 5 月 3 日 16 时,高维碎片接触地球。

1453 年 5 月 28 日 21 时,碎片完全离开地球。

1453 年,君士坦丁堡被围城,迪奥娜拉接触到四维泡沫空间,成为魔法师,最终因高维碎片消失失去魔力而身死。

为了改写这段历史,你不惜耗费你珍藏已久的魔术卡来回到魔法纪元元年。

在使用这些魔术卡之前,你却对它们的排列起了兴趣...

桌面上摆放着 \(m\) 种魔术卡,共 \(n\) 张,第 \(i\) 种魔术卡数量为 \(a_i\),魔术卡顺次摆放,形成一个长度为 \(n\) 的魔术序列,在魔术序列中,若两张相邻魔术卡的种类相同,则它们被称为一个魔术对。

两个魔术序列本质不同,当且仅当存在至少一个位置,使得两个魔术序列这个位置上的魔术卡的种类不同,求本质不同的恰好包含 \(k\) 个魔术对的魔术序列的数量,答案对 \(998244353\) 取模。

输入格式

第一行三个整数 \(m, n, k\)

第二行 \(m\) 个正整数,第 \(i\) 个正整数表示 \(a_i\)

输出格式

一行一个整数表示答案。

数据范围与提示

对于 \(100 \%\) 的数据满足 \(1 \leq m \leq 20000, 0 \leq k \leq n \leq 100000, \sum_{i = 1}^{m} a_i = n\)


首先假设同种颜色的卡片是有标号的,因为这样要好做得多。最后算出来的答案还要乘上\(\frac{1}{\prod_{i=1}^m {a_i!}}\)

然后就可以考虑容斥了。设\(g_i\)表示强制有\(i\)个魔术对的方案数,答案为\(\sum_{i=k}^n(-1)^{i-k}\binom{k}{i}g_i\)

对于第\(i\)种卡片,我们要先求出强制要求有\(x\)对魔术对的方案数。这个问题等价于将\(a_i\)张卡片分成\(a_i-x\)个排列,答案为\(\frac{a_i!\cdot \binom{a_i-1}{a_i-x-1}}{(a_i-x)!}\)。所以对于第\(i\)中卡片,将其分为\(x\)个排列的生成函数为

\[f_i(x)=\sum_{j=1}^{a_i}\frac{a_i!\binom{a_i-1}{j-1}}{j!}x^j \]

将所有的\(f\)乘起来就得到了\(g\),具体实现可以用分治\(NTT\)

\[g(x)=\sum_{j=1}^n([x^j]\prod_{i=1}^m f_i(x))\cdot j!\cdot x^j \]

因为这些排列之间的相对位置也需要确定,所以\(p\)个排列的方案数要乘上\(p!\)

代码:

#include<bits/stdc++.h>
#define ll long long
#define N 100005
#define M 20005

using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}

const ll mod=998244353;
ll ksm(ll t,ll x) {
	ll ans=1;
	for(;x;x>>=1,t=t*t%mod)
		if(x&1) ans=ans*t%mod;
	return ans;
}


int m,n,k;
int a[M];
ll W[20][N<<2];
void pre(int s) {
	for(int i=1;i<=s;i++) {
		int len=1<<i;
		ll t=ksm(3,(mod-1)/len);
		W[i][0]=1;
		for(int j=1;j<=len;j++) W[i][j]=W[i][j-1]*t%mod;
	}
}
void NTT(ll *a,int d,int flag) {
	static int rev[N<<2];
	int n=1<<d;
	for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<d-1);
	for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int s=1;s<=d;s++) {
		int len=1<<s,mid=len>>1;
		for(int i=0;i<n;i+=len) {
			for(int j=0;j<mid;j++) {
				ll t=flag==1?W[s][j]:W[s][len-j];
				ll u=a[i+j],v=a[i+j+mid]*t%mod;
				a[i+j]=(u+v)%mod;
				a[i+j+mid]=(u-v+mod)%mod;
			}
		}
	}
	if(flag==-1) {
		ll inv=ksm(n,mod-2);
		for(int i=0;i<n;i++) a[i]=a[i]*inv%mod;
	}
}

ll fac[N],ifac[N];
ll C(int n,int m) {return n<m?0:fac[n]*ifac[m]%mod*ifac[n-m]%mod;}
vector<int>st[M],g;
int size[M];

int Find(int L,int R) {
	int l=L,r=R,mid;
	while(l<r) {
		mid=l+r+1>>1;
		if(size[mid]-size[L-1]<=size[R]-size[mid-1]) l=mid;
		else r=mid-1;
	}
	return l;
}

vector<int>solve(int l,int r) {
	static ll A[N<<2],B[N<<2];
	if(l==r) return st[l];
	int mid=Find(l,r),tot=size[r]-size[l-1];
	vector<int>L=solve(l,mid),R=solve(mid+1,r),ans;
	ans.clear();
	int d=ceil(log2(tot+1));
	for(int i=0;i<1<<d;i++) A[i]=B[i]=0;
	for(int i=0;i<L.size();i++) A[i]=L[i];
	for(int i=0;i<R.size();i++) B[i]=R[i];
	NTT(A,d,1),NTT(B,d,1);
	for(int i=0;i<1<<d;i++) A[i]=A[i]*B[i]%mod;
	NTT(A,d,-1);
	for(int i=0;i<=tot;i++) ans.push_back(A[i]);
	return ans;
}

bool cmp(int x,int y) {return x<y;}
ll ans;

int main() {
	pre(18);
	m=Get(),n=Get(),k=Get();
	fac[0]=1;
	for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%mod;
	ifac[n]=ksm(fac[n],mod-2);
	for(int i=n-1;i>=0;i--) ifac[i]=ifac[i+1]*(i+1)%mod;
	for(int i=1;i<=m;i++) a[i]=Get();
	sort(a+1,a+1+m,cmp);
	for(int i=1;i<=m;i++) {
		st[i].push_back(0);
		for(int j=1;j<=a[i];j++) {
			st[i].push_back(C(a[i],j)*fac[a[i]-1]%mod*ifac[j-1]%mod);
		}
	}
	for(int i=1;i<=m;i++) size[i]=size[i-1]+a[i];
	g=solve(1,m);
	for(int i=0;i<=n;i++) g[i]=g[i]*fac[i]%mod;
	for(int i=0;i<=n-k;i++) {
		if((k-(n-i))&1) ans-=C(n-i,k)*g[i]%mod;
		else ans+=C(n-i,k)*g[i]%mod;
	}
	ans=(ans%mod+mod)%mod;
	for(int i=1;i<=m;i++) ans=ans*ifac[a[i]]%mod;
	cout<<ans;
	return 0;
}

posted @ 2019-06-30 15:17  hec0411  阅读(519)  评论(1编辑  收藏  举报