[luogu9171]染色数组
定义集合\(S\)由同时满足以下条件的\(x\)构成:
- \([1,x)\)中\(\le a_{x}\)的元素 和 \((x,n]\)中\(\ge a_{x}\)的元素 构成递增子序列
- \([1,x)\)中\(\ge a_{x}\)的元素 和 \((x,n]\)中\(\le a_{x}\)的元素 构成递减子序列
性质1:\(a\)为完美数组当且仅当\(S\ne \empty\)
必要性:注意\(x\in S\)是\(x\)可以染成红色和绿色的必要条件
充分性:任取\(x\in S\),将条件中第\(1\)类元素染成红色,其余染成绿色即可
性质2:若\(x\in S\),则\(x+1\in S\iff \forall i\in [1,x),a_{i}\not\in [a_{x},a_{x+1}]\cup [a_{x+1},a_{x}]\)
性质3:记\(\begin{cases}l=\min_{x\in S}x\\r=\max_{x\in S}x\end{cases}\),则\(S=[l,r]\cap Z\)且满足以下条件之一
- \(l+1=r\)且\(a_{l}=a_{r}\)
- \(a_{l}<a_{l+1}<...<a_{r}\)或\(a_{l}>a_{l+1}>...>a_{r}\)
(证明可以自行分类讨论得到)
为了方便,以下均假设为第\(2\)种情况,第\(1\)种情况是类似的
在\(r\)处统计方案数,注意到对于\(a_{[1,r)}\),恰存在一种染色方式使得红色的结尾\(<a_{r}\)且绿色的结尾\(>a_{r}\)
定义\(f_{i,x,y}\)表示前\(i\)个位置中两序列结尾分别为\(x,y\)的方案数,转移易优化至\(O(nm^{2})\),后缀类似
枚举\(r,a_{r}\)和\(a_{r+1}\)后,即求形如\(\sum_{x\le x_{0}}\sum_{y\ge y_{0}}f_{i,x,y}\),预处理即可,时间复杂度为\(O(Cnm^{2})\)
性质4:得分最大的染色方案形如
- \([1,r)\)中\(\le a_{r}\)的元素 和 \((r,n]\)中\(\ge a_{r}\)的元素 染成红色
- \([1,r)\)中\(\ge a_{r}\)的元素 和 \((r,n]\)中\(\le a_{r}\)的元素 染成绿色
- \(r\)的颜色取染红色或绿色中的较大值
以\(a_{l}<a_{l+1}<...<a_{r}\)的情况为例,即在其中选至多一个元素染成绿色
注意到染成红色/绿色的得分单调递减/递增,显然可以贪心,最终即形如结论
仍枚举\(r,a_{r}\)和\(a_{r+1}\),将贡献拆开,得分分为以下几部分——
-
\([1,r)\):用\(g/g0/g1_{i,x,y}\)表示……的得分和、红色数和 和 绿色数和 即可
-
\(r\):设\([1,r)\)中剩余\(k\)个位置,且\(<a_{r},>a_{r}\)的可填数有\(x,y\)个,答案即形如
\[\sum_{i=0}^{k}{k\choose i}{x\choose i}{y\choose k-i}\max(iw+C,(k-i)(m-w+1)+C) \] -
\((r,n]\):染成红色的\(a_{i}\)为例,即统计\(a_{[1,r)}\)中\(>a_{i}\)的元素个数
-
若\(r<t\),则前者中\(a_{r}\)和\(a_{r+1}\)均已确定,外层枚举仅\(O(nm)\)
注意到组合数中每个数出现次数相同,而出现次数之和即\(\sum_{i=0}^{k}i{k\choose i}{x\choose i}{y\choose k-i}\),同样预处理出后取平均即可
已经确定的数出现次数即方案数,得到所有数出现次数后即可得到答案
-
若\(r\ge t\),则后缀完全未确定,其中出现次数不同的数仅有\([1,a_{r+1}),a_{r+1},(a_{r},m]\)共\(3\)类
此时,仅需在前缀修改时快速计算即可,修改次数共\(O(n)\)
-
时间复杂度为\(O(Cn^{2}m^{2})\)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=55,M=205,mod=998244353;
int t,n,m,n0,ans1,ans2,a[N];
int inv[M],C[M][M],Q[N][M][M],Qi[N][M][M];
int lp[N],visp[N][M],Xp[N][M],Yp[N][M],cx[N][M],cy[N][M],cntp[M];
int ls[N],viss[N][M],Xs[N][M],Ys[N][M],cnts[M];
int add(int x,int y){
x+=y;
return (x<mod ? x : x-mod);
}
struct Data{
int f,g,g0,g1;
Data upd0(int x){
return Data{f,(int)((g+(ll)x*g0)%mod),g0,add(f,g1)};
}
Data upd1(int x){
return Data{f,(int)((g+(ll)(m-x+1)*g1)%mod),add(f,g0),g1};
}
}fp[N][M][M];
Data add(Data x,Data y){
return Data{add(x.f,y.f),add(x.g,y.g),add(x.g0,y.g0),add(x.g1,y.g1)};
}
Data dec(Data x,Data y){
return Data{add(x.f,mod-y.f),add(x.g,mod-y.g),add(x.g0,mod-y.g0),add(x.g1,mod-y.g1)};
}
void get_sum(Data a[M][M]){
for(int i=0;i<=m;i++)
for(int j=m+1;j>i;j--){
if (i)a[i][j]=add(a[i][j],a[i-1][j]);
if (j<=m)a[i][j]=add(a[i][j],a[i][j+1]);
if ((i)&&(j<=m))a[i][j]=dec(a[i][j],a[i-1][j+1]);
}
}
void get_fp(){
memset(fp,0,sizeof(fp));
for(int x=0;x<=m;x++)
for(int y=x+1;y<=m+1;y++)fp[0][x][y]=Data{1,0,0,0};
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
if ((i<=n0)&&(a[i]!=j))continue;
for(int x=0;x<j;x++){
fp[i][x][j]=add(fp[i][x][j],fp[i-1][x][j+1].upd0(j));
if (x)fp[i][x][j]=dec(fp[i][x][j],fp[i-1][x-1][j+1].upd0(j));
}
for(int y=j+1;y<=m+1;y++){
fp[i][j][y]=add(fp[i][j][y],fp[i-1][j-1][y].upd1(j));
if (y<=m)fp[i][j][y]=dec(fp[i][j][y],fp[i-1][j-1][y+1].upd1(j));
}
}
get_sum(fp[i]);
}
}
int S2(int n){
return n*(n+1)*((n<<1)+1)/6;
}
int get_sum(int l,int r,int a){
return S2(r)-S2(l-1)+a*(l+r)*(r-l+1)/2;
}
void addp(int l,int r,int x){
for(int i=l;i<=r;i++)cntp[i]=add(cntp[i],x);
}
void adds(int l,int r,int x){
for(int i=l;i<=r;i++)cnts[i]=add(cnts[i],x);
}
int calc(int j){
int s=0,ans=0;
for(int i=1;i<j;i++){
ans=(ans+(ll)i*s%mod*cnts[i])%mod;
s=add(s,cntp[i]);
}
s=0;
for(int i=m;i>j;i--){
ans=(ans+(ll)(m-i+1)*s%mod*cnts[i])%mod;
s=add(s,cntp[i]);
}
return ans;
}
int main(){
inv[0]=inv[1]=1;
for(int i=2;i<M;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
for(int i=0;i<M;i++){
C[i][0]=C[i][i]=1;
for(int j=1;j<i;j++)C[i][j]=add(C[i-1][j],C[i-1][j-1]);
}
for(int k=0;k<N;k++)
for(int x=0;x<M;x++)
for(int y=0;y<M;y++)
for(int i=0;i<=k;i++){
int s=(ll)C[k][i]*C[x][i]%mod*C[y][k-i]%mod;
Q[k][x][y]=add(Q[k][x][y],s);
Qi[k][x][y]=(Qi[k][x][y]+(ll)i*s%mod*inv[x])%mod;
}
scanf("%d",&t);
while (t--){
scanf("%d%d%d",&n,&m,&n0);
for(int i=1;i<=n0;i++)scanf("%d",&a[i]);
ans1=ans2=0,get_fp();
for(int i=1;i<=n;i++){
lp[i]=max(i-n0-1,0),ls[i]=n-max(i,n0);
for(int j=1;j<=m;j++){
visp[i][j]=viss[i][j]=0;
if ((i<=n0)&&(a[i]!=j))continue;
bool flag=cx[i][j]=cy[i][j]=0;
Xp[i][j]=0,Yp[i][j]=m+1;
for(int k=1;(k<i)&&(k<=n0);k++){
if (a[k]==j){flag=1;break;}
if (a[k]<j){
if (Xp[i][j]>=a[k]){flag=1;break;}
Xp[i][j]=a[k],cx[i][j]++;
}
else{
if (Yp[i][j]<=a[k]){flag=1;break;}
Yp[i][j]=a[k],cy[i][j]++;
}
}
if (!flag)visp[i][j]=1;
flag=0,Xs[i][j]=Ys[i][j]=j;
for(int k=i+1;k<=n0;k++){
if (a[k]==j){flag=1;break;}
if (a[k]<j){
if (Xs[i][j]<=a[k]){flag=1;break;}
Xs[i][j]=a[k];
}
else{
if (Ys[i][j]>=a[k]){flag=1;break;}
Ys[i][j]=a[k];
}
}
if (!flag)viss[i][j]=1;
}
}
for(int i=1;i<n;i++)
for(int j=1;j<=m;j++)
if ((visp[i][j])&&(viss[i+1][j])){
int xp=j-Xp[i][j]-1,yp=Yp[i][j]-j-1;
int xs=Xs[i+1][j]-1,ys=m-Ys[i+1][j],s=Q[ls[i+1]][xs][ys];
ans1=(ans1+(ll)fp[i-1][j-1][j+1].f*s)%mod;
ans2=(ans2+(ll)fp[i-1][j-1][j+1].g*s)%mod;
ans2=(ans2+(ll)(j*cx[i][j]+(m-j+1)*(cy[i][j]+lp[i]))*Q[lp[i]][xp][yp]%mod*s)%mod;
ans2=(ans2+(ll)((j<<1)-m-1+mod)*xp%mod*Qi[lp[i]][xp][yp]%mod*s)%mod;
memset(cntp,0,sizeof(cntp));
memset(cnts,0,sizeof(cnts));
addp(Xp[i][j]+1,j-1,Qi[lp[i]][xp][yp]);
addp(j+1,Yp[i][j]-1,Qi[lp[i]][yp][xp]);
for(int k=1;(k<i)&&(k<=n0);k++)cntp[a[k]]=add(cntp[a[k]],Q[lp[i]][xp][yp]);
adds(1,Xs[i+1][j]-1,Qi[ls[i+1]][xs][ys]);
adds(Ys[i+1][j]+1,m,Qi[ls[i+1]][ys][xs]);
for(int k=i+2;k<=n0;k++)cnts[a[k]]=add(cnts[a[k]],s);
ans2=add(ans2,calc(j));
}
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
if ((visp[i][j])&&(viss[i][j])){
int xp=j-Xp[i][j]-1,yp=Yp[i][j]-j-1;
if (i==n){
ans1=add(ans1,fp[i-1][j-1][j+1].f);
ans2=add(ans2,fp[i-1][j-1][j+1].g);
for(int z=0;z<=lp[i];z++)
ans2=(ans2+(ll)max(j*(cx[i][j]+z),(m-j+1)*(cy[i][j]+lp[i]-z))*C[lp[i]][z]%mod*C[xp][z]%mod*C[yp][lp[i]-z])%mod;
continue;
}
for(int w=1;w<j;w++){
if ((i+1<=n0)&&(a[i+1]!=w))continue;
int xs=Xs[i+1][w]-1,ys=m-Ys[i][j],s=Q[ls[i+1]][xs][ys];
Data o=dec(fp[i-1][j-1][j+1],fp[i-1][w-1][j+1]);
ans1=(ans1+(ll)o.f*s)%mod,ans2=(ans2+(ll)o.g*s)%mod;
for(int z=0;z<=lp[i];z++){
int s0=add(C[xp][z],(xp<j-w ? 0 : mod-C[xp-(j-w)][z]));
ans2=(ans2+(ll)max(j*(cx[i][j]+z),(m-j+1)*(cy[i][j]+lp[i]-z))*C[lp[i]][z]%mod*s0%mod*C[yp][lp[i]-z]%mod*s)%mod;
}
if (i<n0){
memset(cntp,0,sizeof(cntp));
memset(cnts,0,sizeof(cnts));
if (xp>=j-w)addp(Xp[i][j]+1,w-1,add(Qi[lp[i]][xp][yp],mod-Qi[lp[i]][xp-(j-w)][yp]));
addp(j+1,Yp[i][j]-1,add(Qi[lp[i]][yp][xp],(xp<j-w ? 0 : mod-Qi[lp[i]][yp][xp-(j-w)])));
for(int k=1;(k<i)&&(k<=n0);k++)cntp[a[k]]=add(cntp[a[k]],o.f);
adds(1,Xs[i+1][w]-1,Qi[ls[i+1]][xs][ys]);
adds(Ys[i][j]+1,m,Qi[ls[i+1]][ys][xs]);
cnts[w]=add(cnts[w],s);
for(int k=i+2;k<=n0;k++)cnts[a[k]]=add(cnts[a[k]],s);
ans2=add(ans2,calc(j));
}
else{
int s1=Qi[ls[i+1]][xs][ys],s2=Qi[ls[i+1]][ys][xs];
if (xp>=j-w){
int W=Xp[i][j]+1,s0=add(Qi[lp[i]][xp][yp],mod-Qi[lp[i]][xp-(j-w)][yp]);
ans2=(ans2+(ll)get_sum(W,w-1,-W)*s0%mod*s1+(ll)w*(w-W)*s0%mod*s)%mod;
}
int W=Yp[i][j]-1,s0=add(Qi[lp[i]][yp][xp],(xp<j-w ? 0 : mod-Qi[lp[i]][yp][xp-(j-w)]));
ans2=(ans2+(ll)get_sum(m-W+1,m-j,W-m-1)*s0%mod*s2)%mod;
for(int k=1;(k<i)&&(k<=n0);k++){
if (a[k]<w)ans2=(ans2+(ll)(a[k]+w)*(w-a[k]-1)/2*o.f%mod*s1+(ll)w*o.f%mod*s)%mod;
if (a[k]>j)ans2=(ans2+(ll)((m-j)+(m-a[k]+2))*(a[k]-j-1)/2*o.f%mod*s2)%mod;
}
}
}
for(int w=j+1;w<=m;w++){
if ((i+1<=n0)&&(a[i+1]!=w))continue;
int xs=Xs[i][j]-1,ys=m-Ys[i+1][w],s=Q[ls[i+1]][xs][ys];
Data o=dec(fp[i-1][j-1][j+1],fp[i-1][j-1][w+1]);
ans1=(ans1+(ll)o.f*s)%mod,ans2=(ans2+(ll)o.g*s)%mod;
for(int z=0;z<=lp[i];z++){
int s0=add(C[yp][lp[i]-z],(yp<w-j ? 0 : mod-C[yp-(w-j)][lp[i]-z]));
ans2=(ans2+(ll)max(j*(cx[i][j]+z),(m-j+1)*(cy[i][j]+lp[i]-z))*C[lp[i]][z]%mod*C[xp][z]%mod*s0%mod*s)%mod;
}
if (i<n0){
memset(cntp,0,sizeof(cntp));
memset(cnts,0,sizeof(cnts));
addp(Xp[i][j]+1,j-1,add(Qi[lp[i]][xp][yp],(yp<w-j ? 0 : mod-Qi[lp[i]][xp][yp-(w-j)])));
if (yp>=w-j)addp(w+1,Yp[i][j]-1,add(Qi[lp[i]][yp][xp],mod-Qi[lp[i]][yp-(w-j)][xp]));
for(int k=1;(k<i)&&(k<=n0);k++)cntp[a[k]]=add(cntp[a[k]],o.f);
adds(1,Xs[i][j]-1,Qi[ls[i+1]][xs][ys]);
adds(Ys[i+1][w]+1,m,Qi[ls[i+1]][ys][xs]);
cnts[w]=add(cnts[w],s);
for(int k=i+2;k<=n0;k++)cnts[a[k]]=add(cnts[a[k]],s);
ans2=add(ans2,calc(j));
}
else{
int s1=Qi[ls[i+1]][xs][ys],s2=Qi[ls[i+1]][ys][xs];
int W=Xp[i][j]+1,s0=add(Qi[lp[i]][xp][yp],(yp<w-j ? 0 : mod-Qi[lp[i]][xp][yp-(w-j)]));
ans2=(ans2+(ll)get_sum(W,j-1,-W)*s0%mod*s1)%mod;
if (yp>=w-j){
int W=Yp[i][j]-1,s0=add(Qi[lp[i]][yp][xp],mod-Qi[lp[i]][yp-(w-j)][xp]);
ans2=(ans2+(ll)get_sum(m-W+1,m-w,W-m-1)*s0%mod*s2+(ll)(m-w+1)*(W-w)*s0%mod*s)%mod;
}
for(int k=1;(k<i)&&(k<=n0);k++){
if (a[k]<j)ans2=(ans2+(ll)(a[k]+j)*(j-a[k]-1)/2*o.f%mod*s1)%mod;
if (a[k]>w)ans2=(ans2+(ll)((m-w)+(m-a[k]+2))*(a[k]-w-1)/2*o.f%mod*s2+(ll)(m-w+1)*o.f%mod*s)%mod;
}
}
}
}
printf("%d %d\n",ans1,ans2);
}
return 0;
}