【JZOJ6082】染色问题

Description

有n个格子,现在用m种颜色按顺序染m次,每次可以染一段区间(如果区间内有别的颜色将会被这种颜色覆盖),问最终所有格子都有颜色的情况下,不同的颜色序列有多少种。

Solution

最终序列肯定是一段一段的颜色,其实每次染色相当于从原有的颜色段中插入一段颜色。
f i , j f_{i,j} fi,j表示前 i i i次染色,颜色段长度为 j j j的方案数,容易得到转移就是:
f i , j = f i − 1 , j + ∑ k = 0 j − 1 ( k + 1 ) f i − 1 , k f_{i,j}=f_{i-1,j}+\sum_{k=0}^{j-1}(k+1)f_{i-1,k} fi,j=fi1,j+k=0j1(k+1)fi1,k
f i , j f_{i,j} fi,j看成 f i ( x ) [ x j ] f_i(x)[x^j] fi(x)[xj],转移可以看成 f i ( x ) = f i − 1 ( x ) ( i x + 1 ) f_i(x)=f_{i-1}(x)(ix+1) fi(x)=fi1(x)(ix+1),最后再乘上个组合数,答案可以写成(这里m颜色必须染):
∑ i = 1 m C m − 1 i − 1 ( [ x i ] x ∏ j = 2 n ( j x + 1 ) ) \sum_{i=1}^mC_{m-1}^{i-1}([x^i]x\prod_{j=2}^n(jx+1)) i=1mCm1i1([xi]xj=2n(jx+1))
分治NTT会T,考虑倍增:
f m ( x ) = ∏ i = 1 m ( ( i + 1 ) x + 1 ) f_m(x)=\prod\limits_{i=1}^m((i+1)x+1) fm(x)=i=1m((i+1)x+1), f m ′ ( x ) = ∏ i = m + 1 2 m ( ( i + 1 ) x + 1 ) f_m'(x)=\prod\limits_{i=m+1}^{2m}((i+1)x+1) fm(x)=i=m+12m((i+1)x+1)
那么 f 2 m ( x ) = f m ( x ) f m ′ ( x ) f_{2m}(x)=f_m(x)f_m'(x) f2m(x)=fm(x)fm(x)
考虑求 f m ′ ( x ) f_m'(x) fm(x),改写一下变成 f m ′ ( x ) = ∏ i = 1 m ( ( i + 1 ) x + m x + 1 ) f_m'(x)=\prod\limits_{i=1}^m((i+1)x+mx+1) fm(x)=i=1m((i+1)x+mx+1)
a i = f m ( x ) [ x i ] a_i=f_m(x)[x^i] ai=fm(x)[xi] b i = f m ′ ( x ) [ x i ] b_i=f_m'(x)[x^i] bi=fm(x)[xi],枚举有 f m ( x ) f_m(x) fm(x)有多少个 1 1 1变成了 m m m,可以得到:
b i + j = C m − i j m j a i b_{i+j}=C_{m-i}^jm^ja_i bi+j=Cmijmjai
把组合数拆开,多项式乘法即可。

Code

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define fo(i,j,k) for(int i=j;i<=k;++i)
#define fd(i,j,k) for(int i=j;i>=k;--i)
using namespace std;
typedef long long ll;
const int N=1e6+10,M=3e6+10,mo=998244353;
int qpow(int x,int y){
	int s=1;
	for(;y;y>>=1,x=(ll)x*x%mo) if(y&1) s=(ll)s*x%mo;
	return s;
}
int fn;
int jc[N],ny[N];
int rev[M];
int pl(int x,int y){
	return x+y>=mo?x+y-mo:x+y;
}
void NTT(int *a,int sig){
	fo(i,1,fn-1) if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int m=2;m<=fn;m<<=1){
		int half=m>>1,w0=qpow(3,(mo-1)/m);
		if(sig<0) w0=qpow(w0,mo-2);
		for(int i=0;i<fn;i+=m)
		for(int j=i,w=1;j<i+half;++j,w=(ll)w*w0%mo){
			int u=a[j],v=(ll)a[j+half]*w%mo;
			a[j]=pl(u,v),a[j+half]=pl(u,mo-v);
		}
	}
	if(sig<0){
		int nf=qpow(fn,mo-2);
		fo(i,0,fn-1) a[i]=(ll)a[i]*nf%mo;
	}
}
void mul(int *a,int *b,int ln,int nd){
	int cnt=0;
	for(fn=1;fn<=(ln<<1);fn<<=1) ++cnt;
	fo(i,1,fn-1) rev[i]=rev[i>>1]>>1|(i&1)<<(cnt-1);
	fo(i,ln+1,fn-1) a[i]=b[i]=0;
	NTT(a,1),NTT(b,1);
	fo(i,0,fn-1) a[i]=(ll)a[i]*b[i]%mo;
	NTT(a,-1);
	fo(i,nd+1,fn-1) a[i]=0;
}
int b[M],c[M],d[M];
void solve(int *a,int n){
	if(n==1){
		a[0]=1,a[1]=2;
		return;
	}
	int m=n>>1;
	solve(a,m);
	fo(i,0,m) b[i]=(ll)a[i]*jc[m-i]%mo,c[i]=(ll)qpow(m,i)*ny[i]%mo;
	mul(b,c,m,m);
	fo(i,0,m) b[i]=(ll)b[i]*ny[m-i]%mo;
	mul(a,b,m,m<<1);
	if(n&1){
		c[0]=0;
		fo(i,0,m<<1) c[i+1]=(ll)a[i]*(n+1)%mo;
		fo(i,0,n) a[i]=pl(a[i],c[i]);
	}
}
int C(int m,int n){
	return (ll)jc[m]*ny[n]%mo*ny[m-n]%mo;
}
int a[M];
int main()
{
	freopen("color.in","r",stdin);
	freopen("color.out","w",stdout);
	int n,m,mx;
	scanf("%d %d",&n,&m),mx=max(n,m);
	jc[0]=1;
	fo(i,1,mx) jc[i]=(ll)jc[i-1]*i%mo;
	ny[mx]=qpow(jc[mx],mo-2);
	fd(i,mx,1) ny[i-1]=(ll)ny[i]*i%mo;
	solve(a,n-1);
	int ans=0;
	fo(i,0,m-1) ans=pl(ans,(ll)a[i]*C(m-1,i)%mo);
	printf("%d",ans);
}

posted @ 2019-03-28 22:16  sadstone  阅读(44)  评论(0编辑  收藏  举报