[luogu9171]染色数组

定义集合\(S\)由同时满足以下条件的\(x\)构成:

  • \([1,x)\)\(\le a_{x}\)的元素 和 \((x,n]\)\(\ge a_{x}\)的元素 构成递增子序列
  • \([1,x)\)\(\ge a_{x}\)的元素 和 \((x,n]\)\(\le a_{x}\)的元素 构成递减子序列

性质1:\(a\)为完美数组当且仅当\(S\ne \empty\)

必要性:注意\(x\in S\)\(x\)可以染成红色和绿色的必要条件

充分性:任取\(x\in S\),将条件中第\(1\)类元素染成红色,其余染成绿色即可

性质2:\(x\in S\),则\(x+1\in S\iff \forall i\in [1,x),a_{i}\not\in [a_{x},a_{x+1}]\cup [a_{x+1},a_{x}]\)

性质3:\(\begin{cases}l=\min_{x\in S}x\\r=\max_{x\in S}x\end{cases}\),则\(S=[l,r]\cap Z\)且满足以下条件之一

  • \(l+1=r\)\(a_{l}=a_{r}\)
  • \(a_{l}<a_{l+1}<...<a_{r}\)\(a_{l}>a_{l+1}>...>a_{r}\)

(证明可以自行分类讨论得到)

为了方便,以下均假设为第\(2\)种情况,第\(1\)种情况是类似的


\(r\)处统计方案数,注意到对于\(a_{[1,r)}\),恰存在一种染色方式使得红色的结尾\(<a_{r}\)且绿色的结尾\(>a_{r}\)

定义\(f_{i,x,y}\)表示前\(i\)个位置中两序列结尾分别为\(x,y\)的方案数,转移易优化至\(O(nm^{2})\),后缀类似

枚举\(r,a_{r}\)\(a_{r+1}\)后,即求形如\(\sum_{x\le x_{0}}\sum_{y\ge y_{0}}f_{i,x,y}\),预处理即可,时间复杂度为\(O(Cnm^{2})\)


性质4:得分最大的染色方案形如

  • \([1,r)\)\(\le a_{r}\)的元素 和 \((r,n]\)\(\ge a_{r}\)的元素 染成红色
  • \([1,r)\)\(\ge a_{r}\)的元素 和 \((r,n]\)\(\le a_{r}\)的元素 染成绿色
  • \(r\)的颜色取染红色或绿色中的较大值

\(a_{l}<a_{l+1}<...<a_{r}\)的情况为例,即在其中选至多一个元素染成绿色

注意到染成红色/绿色的得分单调递减/递增,显然可以贪心,最终即形如结论

仍枚举\(r,a_{r}\)\(a_{r+1}\),将贡献拆开,得分分为以下几部分——

  • \([1,r)\):用\(g/g0/g1_{i,x,y}\)表示……的得分和、红色数和 和 绿色数和 即可

  • \(r\):设\([1,r)\)中剩余\(k\)个位置,且\(<a_{r},>a_{r}\)的可填数有\(x,y\)个,答案即形如

    \[\sum_{i=0}^{k}{k\choose i}{x\choose i}{y\choose k-i}\max(iw+C,(k-i)(m-w+1)+C) \]

  • \((r,n]\):染成红色的\(a_{i}\)为例,即统计\(a_{[1,r)}\)\(>a_{i}\)的元素个数

    • \(r<t\),则前者中\(a_{r}\)\(a_{r+1}\)均已确定,外层枚举仅\(O(nm)\)

      注意到组合数中每个数出现次数相同,而出现次数之和即\(\sum_{i=0}^{k}i{k\choose i}{x\choose i}{y\choose k-i}\),同样预处理出后取平均即可

      已经确定的数出现次数即方案数,得到所有数出现次数后即可得到答案

    • \(r\ge t\),则后缀完全未确定,其中出现次数不同的数仅有\([1,a_{r+1}),a_{r+1},(a_{r},m]\)\(3\)

      此时,仅需在前缀修改时快速计算即可,修改次数共\(O(n)\)

时间复杂度为\(O(Cn^{2}m^{2})\)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=55,M=205,mod=998244353;
int t,n,m,n0,ans1,ans2,a[N];
int inv[M],C[M][M],Q[N][M][M],Qi[N][M][M];
int lp[N],visp[N][M],Xp[N][M],Yp[N][M],cx[N][M],cy[N][M],cntp[M];
int ls[N],viss[N][M],Xs[N][M],Ys[N][M],cnts[M];
int add(int x,int y){
	x+=y;
	return (x<mod ? x : x-mod);
}
struct Data{
	int f,g,g0,g1;
	Data upd0(int x){
		return Data{f,(int)((g+(ll)x*g0)%mod),g0,add(f,g1)};
	}
	Data upd1(int x){
		return Data{f,(int)((g+(ll)(m-x+1)*g1)%mod),add(f,g0),g1};
	}
}fp[N][M][M];
Data add(Data x,Data y){
	return Data{add(x.f,y.f),add(x.g,y.g),add(x.g0,y.g0),add(x.g1,y.g1)};
}
Data dec(Data x,Data y){
	return Data{add(x.f,mod-y.f),add(x.g,mod-y.g),add(x.g0,mod-y.g0),add(x.g1,mod-y.g1)};
}
void get_sum(Data a[M][M]){
	for(int i=0;i<=m;i++)
		for(int j=m+1;j>i;j--){
			if (i)a[i][j]=add(a[i][j],a[i-1][j]);
			if (j<=m)a[i][j]=add(a[i][j],a[i][j+1]);
			if ((i)&&(j<=m))a[i][j]=dec(a[i][j],a[i-1][j+1]);
		}
}
void get_fp(){
	memset(fp,0,sizeof(fp));
	for(int x=0;x<=m;x++)
		for(int y=x+1;y<=m+1;y++)fp[0][x][y]=Data{1,0,0,0};
	for(int i=1;i<=n;i++){
		for(int j=1;j<=m;j++){
			if ((i<=n0)&&(a[i]!=j))continue;
			for(int x=0;x<j;x++){
				fp[i][x][j]=add(fp[i][x][j],fp[i-1][x][j+1].upd0(j));
				if (x)fp[i][x][j]=dec(fp[i][x][j],fp[i-1][x-1][j+1].upd0(j));
			}
			for(int y=j+1;y<=m+1;y++){
				fp[i][j][y]=add(fp[i][j][y],fp[i-1][j-1][y].upd1(j));
				if (y<=m)fp[i][j][y]=dec(fp[i][j][y],fp[i-1][j-1][y+1].upd1(j));
			}
		}
		get_sum(fp[i]);
	}
}
int S2(int n){
	return n*(n+1)*((n<<1)+1)/6;
}
int get_sum(int l,int r,int a){
	return S2(r)-S2(l-1)+a*(l+r)*(r-l+1)/2;
}
void addp(int l,int r,int x){
	for(int i=l;i<=r;i++)cntp[i]=add(cntp[i],x);
}
void adds(int l,int r,int x){
	for(int i=l;i<=r;i++)cnts[i]=add(cnts[i],x);
}
int calc(int j){
	int s=0,ans=0;
	for(int i=1;i<j;i++){
		ans=(ans+(ll)i*s%mod*cnts[i])%mod;
		s=add(s,cntp[i]);
	}
	s=0;
	for(int i=m;i>j;i--){
		ans=(ans+(ll)(m-i+1)*s%mod*cnts[i])%mod;
		s=add(s,cntp[i]);
	}
	return ans;
}
int main(){
	inv[0]=inv[1]=1;
	for(int i=2;i<M;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
	for(int i=0;i<M;i++){
		C[i][0]=C[i][i]=1;
		for(int j=1;j<i;j++)C[i][j]=add(C[i-1][j],C[i-1][j-1]);
	}
	for(int k=0;k<N;k++)
		for(int x=0;x<M;x++)
			for(int y=0;y<M;y++)
				for(int i=0;i<=k;i++){
					int s=(ll)C[k][i]*C[x][i]%mod*C[y][k-i]%mod;
					Q[k][x][y]=add(Q[k][x][y],s);
					Qi[k][x][y]=(Qi[k][x][y]+(ll)i*s%mod*inv[x])%mod;
				}
	scanf("%d",&t);
	while (t--){
		scanf("%d%d%d",&n,&m,&n0);
		for(int i=1;i<=n0;i++)scanf("%d",&a[i]);
		ans1=ans2=0,get_fp();
		for(int i=1;i<=n;i++){
			lp[i]=max(i-n0-1,0),ls[i]=n-max(i,n0);
			for(int j=1;j<=m;j++){
				visp[i][j]=viss[i][j]=0;
				if ((i<=n0)&&(a[i]!=j))continue;
				bool flag=cx[i][j]=cy[i][j]=0;
				Xp[i][j]=0,Yp[i][j]=m+1;
				for(int k=1;(k<i)&&(k<=n0);k++){
					if (a[k]==j){flag=1;break;}
					if (a[k]<j){
						if (Xp[i][j]>=a[k]){flag=1;break;}
						Xp[i][j]=a[k],cx[i][j]++;
					}
					else{
						if (Yp[i][j]<=a[k]){flag=1;break;}
						Yp[i][j]=a[k],cy[i][j]++;
					}
				}
				if (!flag)visp[i][j]=1;
				flag=0,Xs[i][j]=Ys[i][j]=j;
				for(int k=i+1;k<=n0;k++){
					if (a[k]==j){flag=1;break;}
					if (a[k]<j){
						if (Xs[i][j]<=a[k]){flag=1;break;}
						Xs[i][j]=a[k];
					}
					else{
						if (Ys[i][j]>=a[k]){flag=1;break;}
						Ys[i][j]=a[k];
					}
				}
				if (!flag)viss[i][j]=1;
			}
		}
		for(int i=1;i<n;i++)
			for(int j=1;j<=m;j++)
				if ((visp[i][j])&&(viss[i+1][j])){
					int xp=j-Xp[i][j]-1,yp=Yp[i][j]-j-1;
					int xs=Xs[i+1][j]-1,ys=m-Ys[i+1][j],s=Q[ls[i+1]][xs][ys];
					ans1=(ans1+(ll)fp[i-1][j-1][j+1].f*s)%mod;
					ans2=(ans2+(ll)fp[i-1][j-1][j+1].g*s)%mod;
					ans2=(ans2+(ll)(j*cx[i][j]+(m-j+1)*(cy[i][j]+lp[i]))*Q[lp[i]][xp][yp]%mod*s)%mod;
					ans2=(ans2+(ll)((j<<1)-m-1+mod)*xp%mod*Qi[lp[i]][xp][yp]%mod*s)%mod;
					memset(cntp,0,sizeof(cntp));
					memset(cnts,0,sizeof(cnts));
					addp(Xp[i][j]+1,j-1,Qi[lp[i]][xp][yp]);
					addp(j+1,Yp[i][j]-1,Qi[lp[i]][yp][xp]);
					for(int k=1;(k<i)&&(k<=n0);k++)cntp[a[k]]=add(cntp[a[k]],Q[lp[i]][xp][yp]);
					adds(1,Xs[i+1][j]-1,Qi[ls[i+1]][xs][ys]);
					adds(Ys[i+1][j]+1,m,Qi[ls[i+1]][ys][xs]);
					for(int k=i+2;k<=n0;k++)cnts[a[k]]=add(cnts[a[k]],s);
					ans2=add(ans2,calc(j));
				}
		for(int i=1;i<=n;i++)
			for(int j=1;j<=m;j++)
				if ((visp[i][j])&&(viss[i][j])){
					int xp=j-Xp[i][j]-1,yp=Yp[i][j]-j-1;
					if (i==n){
						ans1=add(ans1,fp[i-1][j-1][j+1].f);
						ans2=add(ans2,fp[i-1][j-1][j+1].g);
						for(int z=0;z<=lp[i];z++)
							ans2=(ans2+(ll)max(j*(cx[i][j]+z),(m-j+1)*(cy[i][j]+lp[i]-z))*C[lp[i]][z]%mod*C[xp][z]%mod*C[yp][lp[i]-z])%mod;
						continue;
					}
					for(int w=1;w<j;w++){
						if ((i+1<=n0)&&(a[i+1]!=w))continue;
						int xs=Xs[i+1][w]-1,ys=m-Ys[i][j],s=Q[ls[i+1]][xs][ys];
						Data o=dec(fp[i-1][j-1][j+1],fp[i-1][w-1][j+1]);
						ans1=(ans1+(ll)o.f*s)%mod,ans2=(ans2+(ll)o.g*s)%mod;
						for(int z=0;z<=lp[i];z++){
							int s0=add(C[xp][z],(xp<j-w ? 0 : mod-C[xp-(j-w)][z]));
							ans2=(ans2+(ll)max(j*(cx[i][j]+z),(m-j+1)*(cy[i][j]+lp[i]-z))*C[lp[i]][z]%mod*s0%mod*C[yp][lp[i]-z]%mod*s)%mod;
						}
						if (i<n0){
							memset(cntp,0,sizeof(cntp));
							memset(cnts,0,sizeof(cnts));
							if (xp>=j-w)addp(Xp[i][j]+1,w-1,add(Qi[lp[i]][xp][yp],mod-Qi[lp[i]][xp-(j-w)][yp]));
							addp(j+1,Yp[i][j]-1,add(Qi[lp[i]][yp][xp],(xp<j-w ? 0 : mod-Qi[lp[i]][yp][xp-(j-w)])));
							for(int k=1;(k<i)&&(k<=n0);k++)cntp[a[k]]=add(cntp[a[k]],o.f);
							adds(1,Xs[i+1][w]-1,Qi[ls[i+1]][xs][ys]);
							adds(Ys[i][j]+1,m,Qi[ls[i+1]][ys][xs]);
							cnts[w]=add(cnts[w],s);
							for(int k=i+2;k<=n0;k++)cnts[a[k]]=add(cnts[a[k]],s);
							ans2=add(ans2,calc(j));
						}
						else{
							int s1=Qi[ls[i+1]][xs][ys],s2=Qi[ls[i+1]][ys][xs];
							if (xp>=j-w){
								int W=Xp[i][j]+1,s0=add(Qi[lp[i]][xp][yp],mod-Qi[lp[i]][xp-(j-w)][yp]);
								ans2=(ans2+(ll)get_sum(W,w-1,-W)*s0%mod*s1+(ll)w*(w-W)*s0%mod*s)%mod;
							}
							int W=Yp[i][j]-1,s0=add(Qi[lp[i]][yp][xp],(xp<j-w ? 0 : mod-Qi[lp[i]][yp][xp-(j-w)]));
							ans2=(ans2+(ll)get_sum(m-W+1,m-j,W-m-1)*s0%mod*s2)%mod;
							for(int k=1;(k<i)&&(k<=n0);k++){
								if (a[k]<w)ans2=(ans2+(ll)(a[k]+w)*(w-a[k]-1)/2*o.f%mod*s1+(ll)w*o.f%mod*s)%mod;
								if (a[k]>j)ans2=(ans2+(ll)((m-j)+(m-a[k]+2))*(a[k]-j-1)/2*o.f%mod*s2)%mod;
							}
						}
					}
					for(int w=j+1;w<=m;w++){
						if ((i+1<=n0)&&(a[i+1]!=w))continue;
						int xs=Xs[i][j]-1,ys=m-Ys[i+1][w],s=Q[ls[i+1]][xs][ys];
						Data o=dec(fp[i-1][j-1][j+1],fp[i-1][j-1][w+1]);
						ans1=(ans1+(ll)o.f*s)%mod,ans2=(ans2+(ll)o.g*s)%mod;
						for(int z=0;z<=lp[i];z++){
							int s0=add(C[yp][lp[i]-z],(yp<w-j ? 0 : mod-C[yp-(w-j)][lp[i]-z]));
							ans2=(ans2+(ll)max(j*(cx[i][j]+z),(m-j+1)*(cy[i][j]+lp[i]-z))*C[lp[i]][z]%mod*C[xp][z]%mod*s0%mod*s)%mod;
						}
						if (i<n0){
							memset(cntp,0,sizeof(cntp));
							memset(cnts,0,sizeof(cnts));
							addp(Xp[i][j]+1,j-1,add(Qi[lp[i]][xp][yp],(yp<w-j ? 0 : mod-Qi[lp[i]][xp][yp-(w-j)])));
							if (yp>=w-j)addp(w+1,Yp[i][j]-1,add(Qi[lp[i]][yp][xp],mod-Qi[lp[i]][yp-(w-j)][xp]));
							for(int k=1;(k<i)&&(k<=n0);k++)cntp[a[k]]=add(cntp[a[k]],o.f);
							adds(1,Xs[i][j]-1,Qi[ls[i+1]][xs][ys]);
							adds(Ys[i+1][w]+1,m,Qi[ls[i+1]][ys][xs]);
							cnts[w]=add(cnts[w],s);
							for(int k=i+2;k<=n0;k++)cnts[a[k]]=add(cnts[a[k]],s);
							ans2=add(ans2,calc(j));
						}
						else{
							int s1=Qi[ls[i+1]][xs][ys],s2=Qi[ls[i+1]][ys][xs];
							int W=Xp[i][j]+1,s0=add(Qi[lp[i]][xp][yp],(yp<w-j ? 0 : mod-Qi[lp[i]][xp][yp-(w-j)]));
							ans2=(ans2+(ll)get_sum(W,j-1,-W)*s0%mod*s1)%mod;
							if (yp>=w-j){
								int W=Yp[i][j]-1,s0=add(Qi[lp[i]][yp][xp],mod-Qi[lp[i]][yp-(w-j)][xp]);
								ans2=(ans2+(ll)get_sum(m-W+1,m-w,W-m-1)*s0%mod*s2+(ll)(m-w+1)*(W-w)*s0%mod*s)%mod;
							}
							for(int k=1;(k<i)&&(k<=n0);k++){
								if (a[k]<j)ans2=(ans2+(ll)(a[k]+j)*(j-a[k]-1)/2*o.f%mod*s1)%mod;
								if (a[k]>w)ans2=(ans2+(ll)((m-w)+(m-a[k]+2))*(a[k]-w-1)/2*o.f%mod*s2+(ll)(m-w+1)*o.f%mod*s)%mod;
							}
						}
					}
				}
		printf("%d %d\n",ans1,ans2);
	}
	return 0;
}
posted @ 2023-04-05 07:36  PYWBKTDA  阅读(867)  评论(9编辑  收藏  举报