[atARC137F]Overlaps
考虑将$2n$个位置离散(可以忽略同位置),问题即转换为以下模型
将$2n$个位置(等概率)配成$n$对,每一对左右分别打上$\pm 1$,任意前缀和$\le K$的概率
总方案数显然为${2n\choose n}\frac{n!}{2^{n}}$,下面考虑合法的方案数——
记$x_{i}$为第$i$个$-1$之前的前缀和(不包括该$-1$),那么答案即
$$
\sum_{\begin{matrix}\forall 1\le i\le n,1\le x_{i}\le K\\\forall 1\le i<n,x_{i}-1\le x_{i+1}\\x_{n}=1\end{matrix}}\prod_{i=1}^{n}x_{i}
$$
(前者的条件可以构造出$\pm 1$序列,后者即从左到右确定每一个$-1$所配对的$1$)
对第2个条件容斥,即要求相邻两个元素至少减小2,并定义以下信息——
$$
g_{n}=\sum_{\begin{matrix}\forall 1\le i\le n,1\le x_{i}\le K\\\forall 1\le i<n,x_{i+1}\le x_{i}+2\\x_{n}=1\end{matrix}}\prod_{i=1}^{n}x_{i}\ ,\
f_{n}=\sum_{\begin{matrix}\forall 1\le i\le n,1\le x_{i}\le K\\\forall 1\le i<n,x_{i+1}\le x_{i}+2\end{matrix}}\prod_{i=1}^{n}x_{i}\ ,\
F_{n}=\begin{cases}1&n=0\\\sum_{i=0}^{n-1}(-1)^{n-i-1}f_{n-i}F_{i}&n\ge 1\end{cases}
$$
关于$g$和$f$,可以对值域分治并记录$l$和$r$是否选择对应的4个序列,使用NTT合并,时间复杂度为$o(n\log^{2}n)$
记$\begin{cases}f(x)=\sum_{n\ge 1}(-1)^{n-1}f_{n}x^{n}\\F(x)=\sum_{n\ge 0}F_{n}x^{n}\end{cases}$,代入$F_{n}$的转移可得$F(x)=\frac{1}{1-f(x)}$,可以$o(n\log n)$多项式求逆计算
最终,答案即$\sum_{i=0}^{n-1}(-1)^{n-i-1}g_{n-i}F_{i}$,可以$o(n)$直接计算
总复杂度为$o(n\log^{2}n)$,可以通过
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 400005 4 #define L 19 5 #define mod 998244353 6 #define ll long long 7 #define vi vector<int> 8 int n,K,sum,ans,fac[N],inv[N]; 9 vi a,b; 10 struct Data{ 11 int n; 12 vi a[4]; 13 }g,f[L]; 14 int C(int n,int m){ 15 return (ll)fac[n]*inv[m]%mod*inv[n-m]%mod; 16 } 17 int qpow(int n,int m){ 18 int s=n,ans=1; 19 while (m){ 20 if (m&1)ans=(ll)ans*s%mod; 21 s=(ll)s*s%mod,m>>=1; 22 } 23 return ans; 24 } 25 namespace Poly{ 26 int n,inv,rev[1<<L],w[2][L][1<<L]; 27 void Init(){ 28 for(int i=0;i<L;i++){ 29 w[0][i][0]=w[1][i][0]=1; 30 w[0][i][1]=qpow(3,(mod-1>>i+1)); 31 w[1][i][1]=qpow(w[0][i][1],mod-2); 32 for(int j=2;j<(1<<i);j++){ 33 w[0][i][j]=(ll)w[0][i][j-1]*w[0][i][1]%mod; 34 w[1][i][j]=(ll)w[1][i][j-1]*w[1][i][1]%mod; 35 } 36 } 37 } 38 void init(int m){ 39 n=inv=1; 40 while (n<m)n<<=1,inv=(ll)inv*(mod+1>>1)%mod; 41 for(int i=0;i<n;i++)rev[i]=(rev[i>>1]>>1)+(i&1)*(n>>1); 42 } 43 void ntt(vi &a,int p=0){ 44 for(int i=0;i<n;i++) 45 if (i<rev[i])swap(a[i],a[rev[i]]); 46 for(int i=2,t=0;i<=n;i<<=1,t++) 47 for(int j=0;j<n;j+=i) 48 for(int k=0;k<(i>>1);k++){ 49 int x=a[j+k],y=(ll)w[p][t][k]*a[j+k+(i>>1)]%mod; 50 a[j+k]=(x+y)%mod,a[j+k+(i>>1)]=(x-y+mod)%mod; 51 } 52 if (p){ 53 for(int i=0;i<n;i++)a[i]=(ll)a[i]*inv%mod; 54 } 55 } 56 vi get_inv(vi a,int m){ 57 vi ans; 58 if (m==1){ 59 ans.push_back(qpow(a[0],mod-2)); 60 return ans; 61 } 62 vi s=get_inv(a,(m+1>>1)); 63 init(m<<1),ans=s,a.resize(n),s.resize(n),ans.resize(m); 64 for(int i=m;i<n;i++)a[i]=0; 65 ntt(a),ntt(s);for(int i=0;i<n;i++)s[i]=(ll)a[i]*s[i]%mod*s[i]%mod;ntt(s,1); 66 for(int i=0;i<m;i++)ans[i]=((ans[i]<<1)%mod-s[i]+mod)%mod; 67 return ans; 68 } 69 void calc(int d,int l,int r){ 70 if (l==r){ 71 f[d].n=2,f[d].a[0]=vi{1,0},f[d].a[1]=f[d].a[2]=vi{0,0},f[d].a[3]=vi{0,l}; 72 return; 73 } 74 int mid=(l+r>>1); 75 calc(d+1,l,mid),f[d]=f[d+1],calc(d+1,mid+1,r); 76 f[d].n+=f[d+1].n-1,init(f[d].n); 77 for(int i=0;i<4;i++){ 78 f[d].a[i].resize(n),f[d+1].a[i].resize(n); 79 ntt(f[d].a[i]),ntt(f[d+1].a[i]); 80 } 81 g=f[d]; 82 for(int i=0;i<4;i++) 83 for(int j=0;j<n;j++){ 84 f[d].a[i][j]=(ll)g.a[i&2][j]*f[d+1].a[i&1][j]%mod; 85 f[d].a[i][j]=(f[d].a[i][j]+(ll)g.a[i&2|1][j]*f[d+1].a[i&1][j])%mod; 86 f[d].a[i][j]=(f[d].a[i][j]+(ll)g.a[i&2][j]*f[d+1].a[i&1|2][j])%mod; 87 } 88 for(int i=0;i<4;i++)ntt(f[d].a[i],1),f[d].a[i].resize(f[d].n); 89 } 90 }; 91 int main(){ 92 fac[0]=inv[0]=inv[1]=1,Poly::Init(); 93 for(int i=1;i<N;i++)fac[i]=(ll)fac[i-1]*i%mod; 94 for(int i=2;i<N;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; 95 for(int i=1;i<N;i++)inv[i]=(ll)inv[i-1]*inv[i]%mod; 96 scanf("%d%d",&n,&K); 97 Poly::calc(0,1,K),a.resize(n+1),b.resize(n+1); 98 for(int i=0;i<4;i++)f[0].a[i].resize(n+1); 99 for(int i=0;i<=n;i++){ 100 a[i]=(f[0].a[2][i]+f[0].a[3][i])%mod; 101 for(int j=0;j<4;j++)b[i]=(b[i]+f[0].a[j][i])%mod; 102 if (i&1)b[i]=mod-b[i]; 103 else a[i]=mod-a[i]; 104 } 105 b=Poly::get_inv(b,n+1); 106 for(int i=0;i<n;i++)ans=(ans+(ll)a[n-i]*b[i])%mod; 107 sum=(ll)C((n<<1),n)*fac[n]%mod*qpow(2,mod-n-1)%mod; 108 ans=(ll)ans*qpow(sum,mod-2)%mod; 109 printf("%d\n",ans); 110 return 0; 111 }