题解 乐

传送门

考虑总数减不合法的
就是减去含有 border 的方案数
考虑只在最短 border 处进行统计,可以证明最短 border 长度 \(\leqslant \frac{len}{2}\)
所以令 \(f_i\) 为长度为 \(i\) 且不含 border 的方案数,记前缀积 \(s_i=\prod_{k=1}^{i}v_k\)(\(s_0=1\)),则

\[f_i=s_i-\sum\limits_{j=1}^{\lfloor\frac{i}{2}\rfloor}f_j\,s_{i-j}\,s_j^{-1} \]

分治 NTT 即可……吗?
发现上界比较迷,那么从一个区间 \([l, r]\) 是可能转移到 \([l, l+2(r-l)]\)
然后就没有了……吗?
发现这个题 \(n=10^6\),而做法是 \(O(n\log^2 n)\)
所以只做 \([0, \lfloor\frac{i}{2}\rfloor)\) 再加 ull 优化可以信仰过

点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Typed replacements for the former object-like macros. Every use site in the
// file (array bounds, type names) reads exactly as before, but these respect
// scope and carry real types.
constexpr int INF = 0x3f3f3f3f;
// Common capacity of the large global arrays (valid indices 0 .. N-1).
constexpr int N = 5000010;
using ll = long long;
using ull = unsigned long long;
//#define int long long

// Buffered input: stdin is pulled in 2 MiB chunks via fread and handed out byte by byte.
char buf[1<<21], *p1=buf, *p2=buf;
// Next input byte as a value in 0..255, or EOF once stdin is exhausted.
// The unsigned char cast keeps bytes >= 0x80 from turning negative, which would be
// undefined behavior for isdigit and could be mistaken for EOF.
// Shadows the standard getchar() for the rest of this file.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
// Reads one decimal integer, skipping any leading non-digit characters; every '-'
// seen while skipping flips the sign. Returns 0 if input ends before a digit appears
// (previously the skip loop spun forever on EOF).
inline int read() {
	int ans=0, f=1;
	int c=getchar();  // int, not char, so EOF stays distinguishable from byte 0xFF
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (c!=EOF && isdigit(c)) {ans=ans*10+(c^48); c=getchar();}
	return ans*f;
}

// Input size and the per-position upper bounds v[1..n] (1-indexed, filled in main).
int n;
// In the visible code a[] is referenced only by the disabled brute-force checker.
int a[N];
int v[N];
// NTT modulus 998244353 = 119 * 2^23 + 1, primitive root 3, and phi = mod - 1
// (the multiplicative group order, used for root-of-unity exponents).
constexpr ll mod = 998244353;
constexpr ll rt = 3;
constexpr ll phi = mod - 1;
// Binary exponentiation: returns a^b modulo `mod` for b >= 0.
// NOTE(review): a*a must fit in long long, so callers are expected to pass a < ~3e9.
inline ll qpow(ll a, ll b) {
	ll res = 1;
	while (b) {
		if (b & 1) res = res * a % mod;
		a = a * a % mod;
		b >>= 1;
	}
	return res;
}

#if 0
// Disabled brute-force reference (compiled out by the surrounding #if 0).
// Enumerates every sequence a[1..n] with 1 <= a[i] <= v[i] and counts those
// without a border; exponential in n, so only usable on tiny inputs.
namespace force{
	// Running count of accepted sequences (printed as-is, not reduced modulo mod).
	ll ans;
	// p[i] = base^i and h[i] = prefix hash of a[1..i]; arithmetic wraps modulo 2^64.
	ull p[N], h[N];
	const ull base=13131;
	// Hash of the subarray a[l..r], derived from the prefix hashes.
	inline ull hashing(int l, int r) {return h[r]-h[l-1]*p[r-l+1];}
	// Counts the current a[1..n] when it respects the bounds and no proper prefix
	// a[1..i] (1 <= i < n) equals the suffix a[n-i+1..n]. Equality is decided by
	// hash, so a collision could in principle reject a valid sequence.
	void check() {
		// Redundant with dfs(), which only assigns values in [1, v[u]].
		for (int i=1; i<=n; ++i) if (a[i]>v[i]) return ;
		for (int i=1; i<=n; ++i) h[i]=h[i-1]*base+a[i];
		for (int i=1; i<n; ++i) if (hashing(1, i)==hashing(n-i+1, n)) return ;
		++ans;
	}
	// Tries every value 1..v[u] at position u and recurses; calls check() on
	// each complete sequence (u > n).
	void dfs(int u) {
		if (u>n) {check(); return ;}
		for (int i=1; i<=v[u]; ++i) {
			a[u]=i;
			dfs(u+1);
		}
	}
	// Precomputes the powers of base, runs the enumeration and prints the count.
	void solve() {
		p[0]=1;
		for (int i=1; i<=n; ++i) p[i]=p[i-1]*base;
		dfs(1);
		printf("%lld\n", ans);
	}
}

// Disabled O(n^3) reference (still inside the #if 0 region that ends below).
namespace task1{
	// f[i]: border-free count for length i, kept as a possibly negative residue mod `mod`.
	ll f[N];
	// v[l] * ... * v[r] modulo `mod` (1 for an empty range), recomputed in O(r - l) per call.
	inline ll prod(int l, int r) {ll ans=1; for (int i=l; i<=r; ++i) ans=ans*v[i]%mod; return ans;}
	// Direct evaluation of
	//   f_i = prod(1, i) - sum_{j=1}^{floor(i/2)} f_j * prod(j+1, i-j);
	// O(n^3) overall because prod() is recomputed inside the double loop.
	void solve() {
		for (int i=1; i<=n; ++i) {
			f[i]=prod(1, i);
			for (int j=1; j<=i/2; ++j) {
				f[i]=(f[i]-f[j]*prod(j+1, i-j))%mod;
			}
		}
		#ifdef DEBUG
		cout<<"f: "; for (int i=1; i<=n; ++i) cout<<(f[i]%mod+mod)%mod<<' '; cout<<endl;
		#endif
		// Normalize the residue into [0, mod) before printing.
		printf("%lld\n", (f[n]%mod+mod)%mod);
	}
}
// End of the disabled reference implementations.
#endif

namespace task2{
	ll f[N], bit[N], g[5010][5010];
	inline void upd(int i, ll dat) {for (; i<=n; i+=i&-i) bit[i]=bit[i]*dat%mod;}
	inline ll query(int i) {ll ans=1; for (; i; i-=i&-i) ans=ans*bit[i]%mod; return ans;}
	//inline ll prod(int l, int r) {return query(r)*qpow(query(l-1), mod-2)%mod;}
	inline ll prod(int l, int r) {return l>r?1:g[l][r];}
	void solve() {
		f[0]=1;
		// for (int i=0; i<=n; ++i) bit[i]=1;
		for (int i=1; i<=n; ++i) f[i]=f[i-1]*v[i]%mod;
		for (int i=1; i<=n; ++i) {
			g[i][i]=v[i];
			for (int j=i+1; j<=n; ++j) g[i][j]=g[i][j-1]*v[j]%mod;
		}
		for (int i=1; i<=n; ++i) {
			for (int j=1; j<=i/2; ++j) {
				f[i]=(f[i]-f[j]*prod(j+1, i-j))%mod;
			}
		}
		// cout<<"f: "; for (int i=1; i<=n; ++i) cout<<(f[i]%mod+mod)%mod<<' '; cout<<endl;
		printf("%lld\n", (f[n]%mod+mod)%mod);
	}
}

// Large-input solver (main dispatches here when n > 5000).
// It builds the prefix products as g[i] = v[1]^i, so it assumes every v[i] equals v[1].
// NOTE(review): nothing checks that assumption; presumably the large subtask guarantees
// constant bounds. Confirm against the problem statement.
namespace task3{
	// Bit-reversal permutation for the current transform size.
	int rev[N];
	// f: border-free counts (residues, possibly negative); g: powers v[1]^i;
	// h, t: NTT work buffers. Transforms use up to 2*lim entries, so 2*lim must stay <= N.
	ll f[N], g[N], h[N], t[N];
	// In-place iterative radix-2 NTT modulo `mod` over a[0..len), len a power of two.
	// op == 1: forward transform; op == -1: inverse transform, including the 1/len scaling.
	// Expects rev[] to already hold the bit-reversal permutation for this len.
	// Values are kept in (-mod, mod), possibly negative, rather than normalized.
	void ntt(ll* a, int len, int op) {
		for (int i=0; i<len; ++i) if (i<rev[i]) swap(a[i], a[rev[i]]);
		// The local `t` is a butterfly temporary; it shadows the task3::t buffer.
		ll wn, w, t;
		for (int i=1; i<len; i<<=1) {
			// rt^(phi/(2i)) for the forward transform, rt^(-phi/(2i)) for the inverse
			// (the exponent is reduced modulo phi to stay non-negative).
			wn=qpow(rt, (op*phi/(i<<1)+phi)%phi);
			for (int j=0,step=i<<1; j<len; j+=step) {
				w=1;
				for (int k=j; k<j+i; ++k,w=w*wn%mod) {
					// Butterfly: (x, y) -> (x + w*y, x - w*y).
					t=a[k+i]*w%mod;
					a[k+i]=(a[k]-t)%mod;
					a[k]=(a[k]+t)%mod;
				}
			}
		}
		if (op==-1) {
			// Scale by len^{-1}, computed via Fermat since mod is prime.
			ll inv=qpow(len, mod-2);
			for (int i=0; i<len; ++i) a[i]=a[i]*inv%mod;
		}
	}
	// Computes f by doubling. The round with size len fixes f[len/2 .. len) from the
	// already-known f[0 .. len/2) via
	//   f[i] = g[i] - sum_j f[j] * g[i - 2j]   (only terms with 2j <= i; f[0] == 0),
	// then prints f[n] normalized into [0, mod).
	// g is only filled up to n (higher entries stay 0), so f[i] is meaningful only for i <= n.
	void solve() {
		g[0]=1;
		for (int i=1; i<=n; ++i) g[i]=g[i-1]*v[1]%mod;
		int lim;
		// lim: smallest power of two strictly greater than n.
		for (lim=1; lim<=n; lim<<=1) ;
		// bct = log2(len); the transforms in this round have size 2*len = 2^(bct+1).
		for (int len=2,bct=1; len<=lim; len<<=1,++bct) {
			for (int i=0; i<(len<<1); ++i) h[i]=t[i]=0;
			for (int i=0; i<len; ++i) t[i]=g[i];
			// Place f[j] at index 2j so that (h*t)[i] collects f[j] * g[i-2j].
			for (int i=0; i<(len>>1); ++i) h[i<<1]=f[i];
			// (Old scratch sketch of a naive convolution, not what the code below computes.)
			// for (int i=0; i<len; ++i)
			// 	for (int j=0; j<len; ++j)
			// 		t[i+j]=(t[i+j]+h[i]*g[j])%mod;
			for (int i=0; i<(len<<1); ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bct));
			ntt(h, (len<<1), 1); ntt(t, (len<<1), 1);
			for (int i=0; i<(len<<1); ++i) h[i]=h[i]*t[i]%mod;
			ntt(h, (len<<1), -1);
			// The linear convolution has length < 2*len, so the cyclic product has no
			// wrap-around. Only indices [len/2, len) are new in this round.
			for (int i=(len>>1); i<len; ++i) f[i]=(g[i]-h[i])%mod;
		}
		// cout<<"f: "; for (int i=1; i<=n; ++i) cout<<(f[i]%mod+mod)%mod<<' '; cout<<endl;
		printf("%lld\n", (f[n]%mod+mod)%mod);
	}
}

signed main()
{
	// The judge mandates file I/O. If a redirect fails, freopen has already closed the
	// original stream, so reading or writing through it would be invalid. Bail out
	// explicitly instead (previously read() would spin forever on the closed stdin).
	if (!freopen("music.in", "r", stdin)) return 1;
	if (!freopen("music.out", "w", stdout)) return 1;

	n=read();
	for (int i=1; i<=n; ++i) v[i]=read();
	// Small inputs use the general O(n^2) recurrence. Large inputs use the NTT solver,
	// which builds its prefix products from v[1] alone and so assumes all v[i] are equal.
	if (n<=5000) task2::solve();
	else task3::solve();

	return 0;
}
posted @ 2022-02-12 21:14  Administrator-09  阅读(1)  评论(0编辑  收藏  举报