题解 乐
考虑总数减不合法的
就是减去含有 border 的方案数
考虑只在最短 border 处进行统计,可以证明最短 border 长度 \(\leqslant \frac{len}{2}\)
所以令 \(f_i\) 为长度为 \(i\) 不含 border 方案数
有
\[f_i=s_i-\sum\limits_{j=1}^{\lfloor\frac{i}{2}\rfloor}f_js_{i-j}s_j^{-1},\qquad s_i=\prod\limits_{k=1}^{i}v_k
\]
分治 NTT 即可……吗?
发现上界比较迷,那么从一个区间 \([l, r]\) 是可能转移到 \([l, l+2(r-l)]\) 的
然后就没有了……吗?
发现这个题 \(n=1e6\),然后做法是 \(O(n\log^2n)\) 的
所以只做 \([0, \lfloor\frac{i}{2}\rfloor)\) 再加 ull 优化可以信仰过
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 5000010
#define ll long long
#define ull unsigned long long
//#define int long long
// Fast stdin reader: input is pulled in 2 MiB chunks into `buf`, and
// [p1, p2) is the not-yet-consumed part of the current chunk.
char buf[1<<21], *p1=buf, *p2=buf;
// Shadows the standard getchar(): refills `buf` when it is exhausted and yields EOF
// once fread() returns nothing. Bytes are returned as unsigned char so <cctype>
// predicates never receive a negative value other than EOF.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))
/**
 * Reads the next decimal integer from stdin.
 * Non-digit characters before the number are skipped; each '-' among them flips the sign.
 * The single character that terminates the number is consumed.
 * @return the parsed value, or 0 if input ends before any digit is found.
 */
inline int read() {
    int ans=0, f=1;
    int c=getchar();  // int, not char: EOF must stay distinguishable from data bytes
    while (c!=EOF && !isdigit(c)) {
        if (c=='-') f=-f;
        c=getchar();
    }
    while (isdigit(c)) {  // isdigit(EOF) is false, so this also stops at end of input
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans*f;
}
// Length of the sequence (first value of the input).
int n;
// a[i]: value at position i of the candidate sequence built by the disabled brute force.
int a[N];
// v[i]: largest value allowed at position i (1-indexed), read from the input.
int v[N];
// NTT-friendly prime modulus p = 998244353 = 119 * 2^23 + 1, the root rt = 3 used as the
// NTT generator, and phi = p - 1 (order of the multiplicative group, used for root exponents).
const long long mod=998244353, rt=3, phi=mod-1;
/**
 * Modular exponentiation by repeated squaring.
 * @param a base; any value (it is reduced into [0, mod) first, so negative bases and
 *          bases >= mod are handled and a*a can never overflow)
 * @param b exponent, expected to be >= 0
 * @return a^b mod `mod`, in [0, mod)
 */
inline long long qpow(long long a, long long b) {
    a%=mod;
    if (a<0) a+=mod;
    long long ans=1;
    for (; b; b>>=1) {
        if (b&1) ans=ans*a%mod;
        a=a*a%mod;
    }
    return ans;
}
// Disabled reference implementations (compiled out by #if 0), kept for cross-checking
// the real solvers on small inputs.
#if 0
// Brute force: enumerates every sequence a with 1 <= a[i] <= v[i] and counts those with
// no border, i.e. no proper prefix a[1..i] equal to the suffix a[n-i+1..n].
// Exponential in n. Borders are compared through 64-bit polynomial hashes (wrapping
// unsigned arithmetic), so a hash collision could in principle miscount.
namespace force{
// Number of border-free sequences found (not reduced modulo mod).
ll ans;
// p[k] = base^k; h[k] = hash of the prefix a[1..k].
ull p[N], h[N];
const ull base=13131;
// Hash of a[l..r] (1-indexed, inclusive).
inline ull hashing(int l, int r) {return h[r]-h[l-1]*p[r-l+1];}
// Counts the completed sequence a[1..n] if it is within bounds and has no border.
void check() {
for (int i=1; i<=n; ++i) if (a[i]>v[i]) return ;
for (int i=1; i<=n; ++i) h[i]=h[i-1]*base+a[i];
for (int i=1; i<n; ++i) if (hashing(1, i)==hashing(n-i+1, n)) return ;
++ans;
}
// Tries every value 1..v[u] at position u and recurses; calls check() once all n are set.
void dfs(int u) {
if (u>n) {check(); return ;}
for (int i=1; i<=v[u]; ++i) {
a[u]=i;
dfs(u+1);
}
}
// Precomputes the powers of base, runs the search and prints the raw count.
void solve() {
p[0]=1;
for (int i=1; i<=n; ++i) p[i]=p[i-1]*base;
dfs(1);
printf("%lld\n", ans);
}
}
// Reference DP, O(n^3) overall because prod() is a linear loop:
//   f[i] = prod(1, i) - sum_{j=1}^{floor(i/2)} f[j] * prod(j+1, i-j)
// i.e. all sequences of length i minus those whose shortest border has length j
// (a shortest border is at most half the length, and its prefix must be border-free).
// NOTE(review): the copied suffix a[i-j+1..i] is treated as always within bounds,
// which presumably relies on v being non-decreasing -- confirm against the statement.
namespace task1{
// f[i]: border-free count for length i; may hold negative representatives until printed.
ll f[N];
// Product v[l] * ... * v[r] modulo mod (1 for an empty range).
inline ll prod(int l, int r) {ll ans=1; for (int i=l; i<=r; ++i) ans=ans*v[i]%mod; return ans;}
void solve() {
for (int i=1; i<=n; ++i) {
f[i]=prod(1, i);
for (int j=1; j<=i/2; ++j) {
f[i]=(f[i]-f[j]*prod(j+1, i-j))%mod;
}
}
#ifdef DEBUG
cout<<"f: "; for (int i=1; i<=n; ++i) cout<<(f[i]%mod+mod)%mod<<' '; cout<<endl;
#endif
printf("%lld\n", (f[n]%mod+mod)%mod);
}
}
#endif
namespace task2{
ll f[N], bit[N], g[5010][5010];
inline void upd(int i, ll dat) {for (; i<=n; i+=i&-i) bit[i]=bit[i]*dat%mod;}
inline ll query(int i) {ll ans=1; for (; i; i-=i&-i) ans=ans*bit[i]%mod; return ans;}
//inline ll prod(int l, int r) {return query(r)*qpow(query(l-1), mod-2)%mod;}
inline ll prod(int l, int r) {return l>r?1:g[l][r];}
void solve() {
f[0]=1;
// for (int i=0; i<=n; ++i) bit[i]=1;
for (int i=1; i<=n; ++i) f[i]=f[i-1]*v[i]%mod;
for (int i=1; i<=n; ++i) {
g[i][i]=v[i];
for (int j=i+1; j<=n; ++j) g[i][j]=g[i][j-1]*v[j]%mod;
}
for (int i=1; i<=n; ++i) {
for (int j=1; j<=i/2; ++j) {
f[i]=(f[i]-f[j]*prod(j+1, i-j))%mod;
}
}
// cout<<"f: "; for (int i=1; i<=n; ++i) cout<<(f[i]%mod+mod)%mod<<' '; cout<<endl;
printf("%lld\n", (f[n]%mod+mod)%mod);
}
}
// Solver for large n (main() dispatches here when n > 5000).
// It only reads v[1]: with k = v[1] and g[i] = k^i it computes
//   f[i] = g[i] - sum_{j=1}^{floor(i/2)} f[j] * g[i-2j]
// by doubling the number of known f values each round with one NTT convolution.
// NOTE(review): this matches the general recurrence only if every v[i] equals v[1];
// presumably the large-n data guarantees that -- confirm against the problem statement.
namespace task3{
// Bit-reversal permutation for the current transform size (rebuilt every round).
int rev[N];
// f: results (entries may be negative representatives; normalized when printed).
// g[i] = v[1]^i. h, t: convolution work buffers.
ll f[N], g[N], h[N], t[N];
// In-place iterative NTT of a[0..len) modulo `mod`; len must be a power of two and
// rev[] must already hold the bit-reversal permutation for len.
// op = 1: forward transform; op = -1: inverse transform, including the division by len.
// phi/(i<<1) is exact for len up to 2^23, since mod-1 = 119*2^23.
// Results may be negative representatives in (-mod, mod), because C++ % keeps the sign.
void ntt(ll* a, int len, int op) {
for (int i=0; i<len; ++i) if (i<rev[i]) swap(a[i], a[rev[i]]);
// Note: this local t shadows the namespace-level array t inside ntt().
ll wn, w, t;
for (int i=1; i<len; i<<=1) {
// wn = rt^(phi/(2i)) for op = 1, or its inverse rt^(phi - phi/(2i)) for op = -1.
wn=qpow(rt, (op*phi/(i<<1)+phi)%phi);
for (int j=0,step=i<<1; j<len; j+=step) {
w=1;
// Butterfly: combine a[k] and a[k+i] with twiddle w = wn^(k-j).
for (int k=j; k<j+i; ++k,w=w*wn%mod) {
t=a[k+i]*w%mod;
a[k+i]=(a[k]-t)%mod;
a[k]=(a[k]+t)%mod;
}
}
}
if (op==-1) {
// Multiply by len^{-1} (via Fermat's little theorem) to finish the inverse transform.
ll inv=qpow(len, mod-2);
for (int i=0; i<len; ++i) a[i]=a[i]*inv%mod;
}
}
// Computes f[1..n] by doubling and prints f[n] normalized to [0, mod).
void solve() {
g[0]=1;
for (int i=1; i<=n; ++i) g[i]=g[i-1]*v[1]%mod;
// lim: smallest power of two strictly greater than n.
int lim;
for (lim=1; lim<=n; lim<<=1) ;
// Round for block size len (bct = log2(len)): f[0..len/2) are known; compute f[len/2..len).
// h places f[j] at index 2j and t holds g[0..len), so after the pointwise product
//   h[i] = sum over j < len/2 with 2j <= i of f[j] * g[i-2j].
// For i in [len/2, len) every j <= floor(i/2) is below len/2, so all terms are present.
// f[0] is never assigned in this namespace and stays 0, so j = 0 contributes nothing.
// Each round costs O(len log len), so all rounds together cost O(n log n).
for (int len=2,bct=1; len<=lim; len<<=1,++bct) {
for (int i=0; i<(len<<1); ++i) h[i]=t[i]=0;
for (int i=0; i<len; ++i) t[i]=g[i];
for (int i=0; i<(len>>1); ++i) h[i<<1]=f[i];
// Disabled naive O(len^2) convolution sketch from development:
// for (int i=0; i<len; ++i)
// for (int j=0; j<len; ++j)
// t[i+j]=(t[i+j]+h[i]*g[j])%mod;
// Bit reversal for transform size 2*len = 2^(bct+1).
for (int i=0; i<(len<<1); ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bct));
ntt(h, (len<<1), 1); ntt(t, (len<<1), 1);
for (int i=0; i<(len<<1); ++i) h[i]=h[i]*t[i]%mod;
ntt(h, (len<<1), -1);
for (int i=(len>>1); i<len; ++i) f[i]=(g[i]-h[i])%mod;
}
// cout<<"f: "; for (int i=1; i<=n; ++i) cout<<(f[i]%mod+mod)%mod<<' '; cout<<endl;
printf("%lld\n", (f[n]%mod+mod)%mod);
}
}
signed main()
{
    // Input comes from music.in and the answer goes to music.out.
    freopen("music.in", "r", stdin);
    freopen("music.out", "w", stdout);

    n = read();
    for (int pos = 1; pos <= n; ++pos) {
        v[pos] = read();
    }

    // Large inputs use the NTT-based doubling solver; small ones use the O(n^2) DP.
    if (n > 5000) {
        task3::solve();
    } else {
        task2::solve();
    }
    return 0;
}