Live2D

CF1553I Stairs 题解

link

Solution

虽然但是,这个sb题目真的很sb,不知道怎么评到3400的,也不知道为什么我又没有做出来😥

可以发现的是,对于序列上面一段连续的极长自然数序列是相互独立的。所以也就意味着对于 \([l,r]\) 如果 \(a_i\) 都相同,意味这它们会划分成 \((r-l+1)/a_l\) 段,每一段都是连续自然数段,且相邻两段之间不连续。比如样例 \([3,3,3,1,1,1]\) 就会划分成 \(\{3,1,1,1\}\) ,意味着 \([1,3],[4,4],[5,5],[6,6]\) 为连续段。

假设我们划分出了 \(m\) 段。这时候我们发现如果不需要保证划分出的相邻两段就非常好算,因为长度大于 \(1\) 的段从小到大或者从大到小就会产生 \(2\) 的贡献,确定段之间的相对关系又会产生 \(m!\) 的贡献。那么我们就可以很自然地想到容斥,即设 \(f_i\) 表示最后划分成 \(i\) 段的贡献,那么即是:

\[\sum_{i=1}^{m} (-1)^{m-i}f_ii! \]

然后我们发现求 \(f_i\) 其实可以直接分治 FFT,需要记录的信息也就最左/右段的划分出的段长度是否 \(>1\)

复杂度 \(\Theta(n\log^2 n)\)

Code

#include <bits/stdc++.h>
using namespace std;
     
#define Int register int
#define mod 998244353
#define MAXN 400005

// char buf[1<<21],*p1=buf,*p2=buf;
// #define getchar() (p1==p2 && (p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
template <typename T> inline void read (T &t){t = 0;char c = getchar();int f = 1;while (c < '0' || c > '9'){if (c == '-') f = -f;c = getchar();}while (c >= '0' && c <= '9'){t = (t << 3) + (t << 1) + c - '0';c = getchar();} t *= f;}
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
template <typename T> inline void write (T x){if (x < 0){x = -x;putchar ('-');}if (x > 9) write (x / 10);putchar (x % 10 + '0');}
template <typename T> inline void chkmax (T &a,T b){a = max (a,b);}
template <typename T> inline void chkmin (T &a,T b){a = min (a,b);}

#define poly vector<int>
#define SZ(A) ((A).size())
#define Gi 332748118
#define G 3

int mul (int a,int b){return 1ll * a * b % mod;}
int dec (int a,int b){return a >= b ? a - b : a + mod - b;}
int add (int a,int b){return a + b >= mod ? a + b - mod : a + b;}
int qkpow (int a,int b){
	int res = 1;for (;b;b >>= 1,a = mul (a,a)) if (b & 1) res = mul (res,a);
	return res;
}
void Add (int &a,int b){a = add (a,b);}
void Sub (int &a,int b){a = dec (a,b);}

int rev[MAXN];

void ntt (poly &a,int type){
	int lim = SZ(a);
	for (Int i = 0;i < lim;++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (lim >> 1));
	for (Int i = 0;i < lim;++ i) if (i < rev[i]) swap (a[i],a[rev[i]]);
	for (Int i = 1;i < lim;i <<= 1){
		int wn = qkpow (type == 1 ? G : Gi,(mod - 1) / (i << 1));
		for (Int j = 0,r = i << 1;j < lim;j += r)
			for (Int k = 0,w = 1;k < i;++ k,w = mul (w,wn)){
				int x = a[j + k],y = mul (w,a[i + j + k]);
				a[j + k] = add (x,y),a[i + j + k] = dec (x,y);
			}
	}
	if (type == 1) return ;
	int iv = qkpow (lim,mod - 2);
	for (Int i = 0;i < lim;++ i) a[i] = mul (a[i],iv);
}

poly operator * (poly A,poly B){
	int sz = SZ(A) + SZ(B) - 1,lim = 1;
	while (lim < sz) lim <<= 1;
	A.resize (lim),B.resize (lim),ntt (A,1),ntt (B,1);
	for (Int i = 0;i < lim;++ i) A[i] = mul (A[i],B[i]);
	ntt (A,-1),A.resize (sz);
	return A;
}

int n,m,a[MAXN],b[MAXN];

struct node{
	poly st[2][2];
	poly * operator [] (const int &key){return st[key];}
	void init (int len){for (Int i = 0;i < 2;++ i) for (Int j = 0;j < 2;++ j) st[i][j].clear (),st[i][j].resize (len);} 
};

node divide (int L,int R){
	node Now;Now.init (R - L + 2);
	if (R - L <= 2){
		if (R == L + 1) Now[b[L]][b[R]][2] = (b[L] + 1) * (b[R] + 1),Now[1][1][1] = 2;
		else 
			Now[b[L]][b[R]][3] = (b[L] + 1) * (b[L + 1] + 1) * (b[R] + 1),
			Add (Now[b[L]][1][2],(b[L] + 1) * 2),Add (Now[1][b[R]][2],2 * (b[R] + 1)),
			Now[1][1][1] = 2;
		return Now;
	}
	int mid = L + R >> 1;
	node sL = divide (L,mid),sR = divide (mid + 1,R);
	for (Int al = 0;al < 2;++ al) for (Int ar = 0;ar < 2;++ ar)
		for (Int bl = 0;bl < 2;++ bl) for (Int br = 0;br < 2;++ br){
			poly tmp = sL[al][ar] * sR[bl][br];
			for (Int i = 0;i < SZ(tmp);++ i){
				Add (Now[al][br][i],tmp[i]);
				if (i){
					if (ar + bl == 2) Add (Now[al][br][i - 1],mul (tmp[i],mod + 1 >> 1));
					if (ar + bl == 1) Add (Now[al][br][i - 1],tmp[i]);
					if (ar + bl == 0) Add (Now[al][br][i - 1],mul (tmp[i],2));
				}
			}
		}
	return Now;
}

signed main(){
	read (n);
	for (Int x = 1;x <= n;++ x) read (a[x]);
	for (Int l = 1,r;l <= n;l = r + 1){
		r = l;
		while (r + 1 <= n && a[r + 1] == a[l]) ++ r;
		if ((r - l + 1) % a[l]) return puts ("0") & 0;
		for (Int t = 1;t <= (r - l + 1) / a[l];++ t) b[++ m] = (a[l] > 1);
	}
	if (m == 1) return puts ("1") & 0;
	node Sum = divide (1,m);
	int ans = 0,pre = 1;
	for (Int i = 1;i <= m;++ i){
		int res = 0;
		for (Int sl = 0;sl < 2;++ sl) for (Int sr = 0;sr < 2;++ sr) Add (res,Sum[sl][sr][i]);
		pre = mul (pre,i),ans = m - i & 1 ? dec (ans,mul (res,pre)) : add (ans,mul (res,pre));
	}
	write (ans),putchar ('\n');
	return 0;
}
posted @ 2022-11-08 21:15  Dark_Romance  阅读(28)  评论(0编辑  收藏  举报