[ABC259Ex] Yet Another Path Counting
\(\text{Links}\)
[ABC259Ex] Yet Another Path Counting
题外话
-
淀粉质题单做不动了怎么办?来做一道根号题振奋一下精神吧/se!
-
我要饿死了,我要吃饭,以后在学校还是不要不吃早饭了/kk
题意
给一个 \(n\times n\) 的网格图,每个格子上有一个颜色。
每一步只能往右或者往下走,问有多少条路径的起点和终点的颜色相同,对 \(998244353\) 取模。
\(n\le 400\),\(2.00s\)。
题解
不同颜色的统计互不干扰,所以按颜色分开来统计。
考虑有用的信息只有起点和终点的颜色,所以枚举点对,组合数计算贡献,即 \(y2-y1+x2-x1\choose x2-x1\)。复杂度为 \(O(siz^2)\),其中 \(siz\) 为这种颜色的点数。
发现如果同种颜色的点数过大的话这个做法会 G。并且很难维护合并组合数的计算来降低复杂度。
但是点数有一个限制,即所有颜色的点数加起来为 \(n^2\),于是可以考虑根号分治了!
设置阈值 \(T\),当 \(siz\le T\) 时,直接用上面的暴力做法,此部分总时间复杂度为 \(O(\frac{n^2}{T}\times T^2)\),即 \(O(n^2T)\)。
当 \(siz\gt T\) 时,这样的颜色最多只有 \(\frac{n}{T}\) 种,那么对于每个颜色再搞个暴力做法。
考虑,这个暴力做法时间复杂度的正确性应该是不依赖于 \(siz\) 的,不然我们根分有什么用呢?全部用这个做法不就好了吗。
所以考虑 \(O(n^2)\) 的 \(dp\),钦定我们当前 \(solve\) 的颜色为 \(col\)。设 \(dp_{i,j}\) 表示从颜色为 \(col\) 的格子走到位置 \((i,j)\) 的方案数。
转移很简单:\(dp_{i,j}=dp_{i,j-1}+dp_{i-1,j}+[a_{i,j}=col]\)。
于是每个位置 \((i,j)\) 对 \(ans\) 的贡献应该是 \([a_{i,j}=col]\times f_{i,j}\)。此部分总时间复杂度为 \(O(\frac{n^2}{T}\times n^2)\),即 \(O(\frac{n^4}{T})\)。
然后就做完了。取 \(T=n\) 的时候达到平衡,总时间复杂度为 \(O(n^3)\)。
代码非常简单。(由于不怎么习惯用大量的 \(pair\),所以这篇码风可能比较诡异)
\(\text{Code}\)
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define il inline
#define re register
const int N=405,T=400,mod=998244353;
int n,a[N][N],ans,fac[N<<1],inv[N<<1],invfac[N<<1],f[N][N];
#define pii pair<int,int>
#define mp make_pair
vector<pii >v[N*N];
il int read(){
re int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*f;
}
il void Add(int &x,int y){
x=(x+y)%mod;
}
il int C(int n,int m){
if(n<0||m<0||n<m)return 0;
return fac[n]*invfac[m]%mod*invfac[n-m]%mod;
}
il bool cmp(pii x,pii y){
return y.first>=x.first&&y.second>=x.second;
}
il int disx(pii x,pii y){
return y.first-x.first;
}
il int disy(pii x,pii y){
return y.second-x.second;
}
#define nowi v[col][i]
#define nowj v[col][j]
il void solve1(int col){
int siz=(int)v[col].size();
for(re int i=0;i<siz;i++)
for(re int j=i;j<siz;j++)
if(cmp(nowi,nowj))Add(ans,C(disx(nowi,nowj)+disy(nowi,nowj),disx(nowi,nowj)));
}
il void solve2(int col){
for(re int i=1;i<=n;i++)
for(re int j=1;j<=n;j++){
f[i][j]=(f[i-1][j]+f[i][j-1]+(a[i][j]==col))%mod;
if(a[i][j]==col)Add(ans,f[i][j]);
}
}
il void GetInv(){
inv[1]=fac[1]=invfac[1]=fac[0]=invfac[0]=1;
for(re int i=2;i<=(n<<1);i++){
inv[i]=inv[mod%i]*(mod-mod/i)%mod;
fac[i]=fac[i-1]*i%mod;
invfac[i]=invfac[i-1]*inv[i]%mod;
}
}
signed main(){
n=read();
GetInv();
for(re int i=1;i<=n;i++)
for(re int j=1;j<=n;j++)
a[i][j]=read(),v[a[i][j]].push_back(mp(i,j));
for(re int col=1;col<=n*n;col++){
if(v[col].empty())continue;
if((int)v[col].size()<=T)solve1(col);
else solve2(col);
}
cout<<ans;
return 0;
}