CF1553I Stairs 题解
Solution
虽然但是,这个sb题目真的很sb,不知道怎么评到3400的,也不知道为什么我又没有做出来😥
可以发现的是,对于序列上面一段连续的极长自然数序列是相互独立的。所以也就意味着对于 \([l,r]\) 如果 \(a_i\) 都相同,意味这它们会划分成 \((r-l+1)/a_l\) 段,每一段都是连续自然数段,且相邻两段之间不连续。比如样例 \([3,3,3,1,1,1]\) 就会划分成 \(\{3,1,1,1\}\) ,意味着 \([1,3],[4,4],[5,5],[6,6]\) 为连续段。
假设我们划分出了 \(m\) 段。这时候我们发现如果不需要保证划分出的相邻两段就非常好算,因为长度大于 \(1\) 的段从小到大或者从大到小就会产生 \(2\) 的贡献,确定段之间的相对关系又会产生 \(m!\) 的贡献。那么我们就可以很自然地想到容斥,即设 \(f_i\) 表示最后划分成 \(i\) 段的贡献,那么即是:
\[\sum_{i=1}^{m} (-1)^{m-i}f_ii!
\]
然后我们发现求 \(f_i\) 其实可以直接分治 FFT,需要记录的信息也就最左/右段的划分出的段长度是否 \(>1\) 。
复杂度 \(\Theta(n\log^2 n)\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define Int register int
#define mod 998244353
#define MAXN 400005
// char buf[1<<21],*p1=buf,*p2=buf;
// #define getchar() (p1==p2 && (p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
template <typename T> inline void read (T &t){t = 0;char c = getchar();int f = 1;while (c < '0' || c > '9'){if (c == '-') f = -f;c = getchar();}while (c >= '0' && c <= '9'){t = (t << 3) + (t << 1) + c - '0';c = getchar();} t *= f;}
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
template <typename T> inline void write (T x){if (x < 0){x = -x;putchar ('-');}if (x > 9) write (x / 10);putchar (x % 10 + '0');}
template <typename T> inline void chkmax (T &a,T b){a = max (a,b);}
template <typename T> inline void chkmin (T &a,T b){a = min (a,b);}
#define poly vector<int>
#define SZ(A) ((A).size())
#define Gi 332748118
#define G 3
int mul (int a,int b){return 1ll * a * b % mod;}
int dec (int a,int b){return a >= b ? a - b : a + mod - b;}
int add (int a,int b){return a + b >= mod ? a + b - mod : a + b;}
int qkpow (int a,int b){
int res = 1;for (;b;b >>= 1,a = mul (a,a)) if (b & 1) res = mul (res,a);
return res;
}
void Add (int &a,int b){a = add (a,b);}
void Sub (int &a,int b){a = dec (a,b);}
int rev[MAXN];
void ntt (poly &a,int type){
int lim = SZ(a);
for (Int i = 0;i < lim;++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (lim >> 1));
for (Int i = 0;i < lim;++ i) if (i < rev[i]) swap (a[i],a[rev[i]]);
for (Int i = 1;i < lim;i <<= 1){
int wn = qkpow (type == 1 ? G : Gi,(mod - 1) / (i << 1));
for (Int j = 0,r = i << 1;j < lim;j += r)
for (Int k = 0,w = 1;k < i;++ k,w = mul (w,wn)){
int x = a[j + k],y = mul (w,a[i + j + k]);
a[j + k] = add (x,y),a[i + j + k] = dec (x,y);
}
}
if (type == 1) return ;
int iv = qkpow (lim,mod - 2);
for (Int i = 0;i < lim;++ i) a[i] = mul (a[i],iv);
}
poly operator * (poly A,poly B){
int sz = SZ(A) + SZ(B) - 1,lim = 1;
while (lim < sz) lim <<= 1;
A.resize (lim),B.resize (lim),ntt (A,1),ntt (B,1);
for (Int i = 0;i < lim;++ i) A[i] = mul (A[i],B[i]);
ntt (A,-1),A.resize (sz);
return A;
}
int n,m,a[MAXN],b[MAXN];
struct node{
poly st[2][2];
poly * operator [] (const int &key){return st[key];}
void init (int len){for (Int i = 0;i < 2;++ i) for (Int j = 0;j < 2;++ j) st[i][j].clear (),st[i][j].resize (len);}
};
node divide (int L,int R){
node Now;Now.init (R - L + 2);
if (R - L <= 2){
if (R == L + 1) Now[b[L]][b[R]][2] = (b[L] + 1) * (b[R] + 1),Now[1][1][1] = 2;
else
Now[b[L]][b[R]][3] = (b[L] + 1) * (b[L + 1] + 1) * (b[R] + 1),
Add (Now[b[L]][1][2],(b[L] + 1) * 2),Add (Now[1][b[R]][2],2 * (b[R] + 1)),
Now[1][1][1] = 2;
return Now;
}
int mid = L + R >> 1;
node sL = divide (L,mid),sR = divide (mid + 1,R);
for (Int al = 0;al < 2;++ al) for (Int ar = 0;ar < 2;++ ar)
for (Int bl = 0;bl < 2;++ bl) for (Int br = 0;br < 2;++ br){
poly tmp = sL[al][ar] * sR[bl][br];
for (Int i = 0;i < SZ(tmp);++ i){
Add (Now[al][br][i],tmp[i]);
if (i){
if (ar + bl == 2) Add (Now[al][br][i - 1],mul (tmp[i],mod + 1 >> 1));
if (ar + bl == 1) Add (Now[al][br][i - 1],tmp[i]);
if (ar + bl == 0) Add (Now[al][br][i - 1],mul (tmp[i],2));
}
}
}
return Now;
}
signed main(){
read (n);
for (Int x = 1;x <= n;++ x) read (a[x]);
for (Int l = 1,r;l <= n;l = r + 1){
r = l;
while (r + 1 <= n && a[r + 1] == a[l]) ++ r;
if ((r - l + 1) % a[l]) return puts ("0") & 0;
for (Int t = 1;t <= (r - l + 1) / a[l];++ t) b[++ m] = (a[l] > 1);
}
if (m == 1) return puts ("1") & 0;
node Sum = divide (1,m);
int ans = 0,pre = 1;
for (Int i = 1;i <= m;++ i){
int res = 0;
for (Int sl = 0;sl < 2;++ sl) for (Int sr = 0;sr < 2;++ sr) Add (res,Sum[sl][sr][i]);
pre = mul (pre,i),ans = m - i & 1 ? dec (ans,mul (res,pre)) : add (ans,mul (res,pre));
}
write (ans),putchar ('\n');
return 0;
}