集训队作业2018: 青春猪头少年不会梦到兔女郎学姐
提交的话目前只知道 luogu 可交。
我们考虑加入若干个隔板:首先相邻不同色则一定需要加入一个隔板,然后考虑相同颜色我们也可以加入一个隔板(或不加入)。然后在计算的时候对每个分割情况算每一段颜色相同的方案数。
这个时候每一段的区间是固定的。那么我们考虑给每一个长度加一个权值 \(a_i\),一个分割的权值是每段的权值之和。
那么设:
\[F = \sum_i a_i x^i
\\
G = \sum_i i*x_i = \frac{x}{(1-x)^2}
\]
则有:
\[\begin{aligned}
\frac{F}{1-F}&= G
\\
F&=\frac{G}{G+1}
\\
F&=\frac{x}{1-x+x^2}
\end{aligned}
\]
考虑一种分割:我们不断向后循环位移,直到序列上存在一个分割点在 \(i\) 到 \(n\) ,这样会唯一对应到一个分割情况。那么我们可以考虑将所有分割后的段分配出来,然后第一段需要多带一个长度的权值(表示实际循环位移了多少)。
那么我们只需要算出来每种颜色分成 \(i\) 段,每段权值是 \(T\) 的方案数即可(还需要分第一段是否是所有的第一段)。
如果不是,那么即求:
\[g(i) = [x^n]F^i
\]
如果是,即求:
\[g(i) = [x^n]F^{i-1}*F'*x\\
= [x^{n-1}](F^{i})'*\frac{1}{i}\\
= [x^{n}](F^{i})*\frac{n}{i}
\]
那么求出来第一个就行了。这里求一个 \(F\) 的复合逆就能算。
而 \(F\) 的复合逆 \(H\):
\[\frac{H}{1-H+H^2}=x
\\
x*H^2-(x+1)*H+1=0
\\
H = \frac{x+1-\sqrt{(x+1)^2-4x^2}}{2x}
\]
最后 分治FFT 即可。
复杂度是 \(O(\sum a_i\log a_i +n\log n\log m)\)。
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod=998244353;
inline int add(int a,int b){a+=b;return a>=mod?a-mod:a;}
inline int sub(int a,int b){a-=b;return a<0?a+mod:a;}
inline int mul(int a,int b){return 1ll*a*b%mod;}
inline int qpow(int a,int b){int ret=1;for(;b;b>>=1,a=mul(a,a))if(b&1)ret=mul(ret,a);return ret;}
const int inv_2=(mod+1)>>1;
/* math */
const int N = 5e5+5;
int fac[N], ifac[N], inver[N];
inline void init(int n=5e5){
fac[0]=ifac[0]=1;for(int i=1;i<=n;i++)fac[i]=mul(fac[i-1],i);
ifac[n]=qpow(fac[n],mod-2);for(int i=n-1;i;i--)ifac[i]=mul(ifac[i+1],i+1);
inver[1]=1;for(int i=2;i<=n;i++)inver[i]=mul(mod-mod/i, inver[mod%i]);
}
typedef vector<int> poly;
namespace polynomial{
const int Ntt_Lim = 1e6+5;
int rev[Ntt_Lim],_w[Ntt_Lim];
const int G_mod = 3;
poly deri(poly a){
for(int i=0;i+1<(int)a.size();i++)a[i]=mul(a[i+1],i+1);
a.pop_back();return a;
}
poly inte(poly a){
a.push_back(0);for(int i=(int)a.size()-2;~i;i--)a[i+1]=mul(a[i],inver[i+1]/* qpow(i+1,mod-2) */);
a[0]=0;return a;
}
poly p_add(poly a,poly b){a.resize(max(a.size(),b.size()));for(size_t i=0;i<b.size();i++)a[i]=add(a[i],b[i]);return a;}
poly p_sub(poly a,poly b){a.resize(max(a.size(),b.size()));for(size_t i=0;i<b.size();i++)a[i]=sub(a[i],b[i]);return a;}
inline void _w_init(){
for(int step=1;step*2<=Ntt_Lim;step<<=1){
int wn = qpow(G_mod, (mod-1)/(step<<1));
for(int j=step,w=1;j<step<<1;j++,w=mul(w,wn)){
_w[j]=w;
}
}
}
inline void dft(int *f,int len,int type){
int l=0;while(1<<l<len)++l;
for(int i=0;i<len;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
for(int i=0;i<len;i++)if(i<rev[i])swap(f[i],f[rev[i]]);
for(int step=1;step<len;step<<=1){
// int wn=_w[step];// int wn = qpow(G_mod, (mod-1)/(step<<1));
for(int i=0;i<len;i+=step<<1)for(int x,y,j=0;j<step;j++){
x=f[i+j],y=mul(_w[j+step],f[i+j+step]);
f[i+j]=add(x,y),f[i+j+step]=sub(x,y);
}
}
if(type==1)return;
int inv=qpow(len,mod-2);reverse(f+1,f+len);
for(int i=0;i<len;i++)f[i]=mul(f[i],inv);
}
poly ntt(poly a,poly b,int n,int m){
int l=1;while(l<n+m-1)l<<=1;
a.resize(l),b.resize(l);dft(&a[0],l,1),dft(&b[0],l,1);
for(int i=0;i<l;i++)a[i]=mul(a[i],b[i]);
dft(&a[0],l,-1);a.resize(n+m-1);
return a;
}
poly ntt(poly a,poly b){return ntt(a,b,a.size(),b.size());}
poly inv(poly &f,int n){
if(n==1)return poly(1,qpow(f[0],mod-2));
poly a(&f[0],&f[n]),b=inv(f,(n+1)/2);int l=1;while(l<n<<1)l<<=1;
a.resize(l),b.resize(l);
dft(&a[0],l,1),dft(&b[0],l,1);
for(int i=0;i<l;i++)a[i]=mul(b[i], sub(2,mul(a[i],b[i])));
dft(&a[0],l,-1);a.resize(n);
return a;
}
poly inv(poly a){return inv(a,a.size());}
poly sqrt(poly &f,int n){
if(n==1)return poly(1,1);
poly a(&f[0],&f[n]),b=sqrt(f,(n+1)/2);
b.resize(n);a=ntt(a,inv(b));a.resize(n);
for(int i=0;i<n;i++)a[i]=mul(inv_2, add(a[i],b[i]));
return a;
}
poly sqrt(poly a){return sqrt(a,a.size());}
poly ln(poly a){
int l=a.size();a=inte(ntt(deri(a),inv(a)));
a.resize(l);return a;
}
poly exp(poly& f,int n){
if(n==1)return poly(1,1);//f[0]=1
poly a(n,0),b=exp(f,(n+1)/2);
b.resize(n);a=ln(b);
for(int i=0;i<n;i++)a[i]=sub(f[i],a[i]);a[0]=add(a[0],1);
a=ntt(a,b);a.resize(n);
return a;
}
poly exp(poly a){return exp(a,a.size());}
pair<poly,poly> div(poly a,poly b){//assert(a.size()>=b.size())
if(a.size()<b.size())return make_pair(poly(1,0),a);
int n=a.size(),m=b.size();
poly ra=a,rb=b;
reverse(ra.begin(),ra.end()),reverse(rb.begin(),rb.end());
ra.resize(n-m+1),rb.resize(n-m+1);
poly c=ntt(ra,inv(rb));c.resize(n-m+1);reverse(c.begin(),c.end());
poly d=p_sub(a,ntt(b,c));d.resize(m-1);
return make_pair(c,d);
}
}
using namespace polynomial;
poly get_inverse(int n){
poly g(n+2);
g[0]=1,g[1]=2,g[2]=mod-3;
g=sqrt(g);
g[0]=sub(g[0], 1), g[1]=sub(g[1],1);
poly ret(n+1);
for(int i=0;i<=n;i++){
ret[i]=mul(sub(0,g[i+1]),inv_2);
}
return ret;
}
poly compinv;
poly f1[N], f2[N];
inline void solve(int id, int n){
poly g(&compinv[0], &compinv[0]+n);
for(int i=0;i<n;i++)g[i]=mul(g[i],n);
g=exp(g);
f1[id].resize(n+1);
f2[id].resize(n+1);
int iv=qpow(n,mod-2);
for(int i=1;i<=n;i++){
f1[id][i]=mul(mul(i, iv), g[n-i]);
f2[id][i-1]=g[n-i];
}
for(int i=0;i<=n;i++){
f1[id][i]=mul(f1[id][i],ifac[i]);
f2[id][i]=mul(f2[id][i],ifac[i]);
}
}
inline pair<poly,poly> dvd(int l,int r){
if(l==r)return make_pair(f1[l], f2[l]);
else{
int mid=(l+r)>>1;
pair<poly,poly> lp=dvd(l,mid), rp=dvd(mid+1,r);
return make_pair(ntt(lp.first,rp.first),p_add(ntt(lp.first,rp.second),ntt(lp.second,rp.first)));
}
}
int main()
{
_w_init();
init();
compinv=get_inverse(200000);
for(int i=0;i<200000;i++)compinv[i] = compinv[i+1];
compinv.resize(200000);
compinv = ln(inv(compinv));
int n;cin >> n;
for(int i=1;i<=n;i++){
int x;scanf("%d",&x);
solve(i,x);
}
poly ret=dvd(1,n).second;
int ans=0;
for(int i=0;i<(int)ret.size();i++){
ans=add(ans, mul(ret[i], fac[i]));
}
cout << ans << endl;
return 0;
}