P7720 Estahv 题解

国赛前的最后一道多项式大题。

我承诺过我不会再做梦了。仍然放不下,但是只是仍然在死去之前看到一些东西。

言多必失。我会少说话的。只是为了——


第一眼观察数据范围和出题人知道是多项式题。考虑列生成函数。

第二眼知道这个卡特兰数 \(C(z)\) 一定是复合进某个函数里边的。那设这个函数是 \(F(x)\)。同时由 Alpha1022 老师特别喜欢多元函数和拉反,考虑用两个元刻画信息:用 \(z\) 刻画数字和,\(w\) 刻画黑格个数。每次填一段黑和一个白上去。一段黑考虑其选法:放 \(i\) 个黑格的选法数是 \(i+1\),即序列 \(0,2,3,4,\cdots\)。每个黑格为 \(zw\),于是生成函数显然 \(G(z)=\dfrac{zw(2-zw)}{(1-zw)^2}\)。那么放若干段就是 \(F(z)=\dfrac 1{1-2zG(z)}\),最后要求的即 \([z^n]F(C(z))\)

另类拉反一下得到 \([z^n]F(C(z))=[z^n]F(z)(1-z)(1+z)^{2n-1}\)。现在即要求 \(F(z)\)

整个拆开观察形式:

\[F(z)=\frac{1-2zw+z^2w^2}{1-2zw+z^2w^2-4z^2w+2z^3w^2} \]

首先这玩意显然是个微分有限的东西,直接整式递推艹过去,跑的飞快。

我直接把代码先贴在这里,虽然完全不懂它在干什么。回头去学习一下,顺便压压长度。

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <iostream>
#include <vector>
using namespace std;
const int mod=998244353;
int n,m;
int qpow(int a,int b){
    int ans=1;
    while(b){
        if(b&1)ans=1ll*ans*a%mod;
        a=1ll*a*a%mod;
        b>>=1;
    }
    return ans;
}
int C[100010],inv[100010],f[100010],g[100010];
struct poly{
    int n,a[20];
    int calc(int x){
        int ans=0;
        for(int i=0,pw=1;i<=n;i++,pw=1ll*pw*x%mod)ans=(ans+1ll*pw*a[i])%mod;
        return ans;
    }
}p[20];
int cnt,deg;
void get(int n,int d){
    n++;
    static int a[310][310],tmp[310];
    int B=(n+2)/(d+2),R=n-B+1,C=B*(d+1);
    for(int i=0;i<R;i++){
        for(int j=0;j<B;j++){
            int val=g[i+j];
            for(int k=0;k<=d;k++){
                a[i][j*(d+1)+k]=val;
                val=1ll*val*(i+j)%mod;
            }
        }
    }
    int c=0;
    for(int i=0;i<C;i++){
        int mx=-1;
        for(int j=c;j<R;j++){
            if(a[j][i]){
                mx=j;break;
            }
        }
        if(mx==-1)break;
        for(int j=0;j<C;j++)swap(a[c][j],a[mx][j]);
        int inv=qpow(a[c][c],mod-2);
        for(int j=i;j<C;j++)a[c][j]=1ll*a[c][j]*inv%mod;
        for(int j=c+1;j<R;j++){
            int rate=a[j][i];
            for(int k=i;k<C;k++)a[j][k]=(a[j][k]-1ll*a[i][k]*rate%mod+mod)%mod;
        }
        c++;
    }
    for(int i=c-1;i>=0;i--){
        for(int j=i-1;j>=0;j--){
            a[j][c]=(a[j][c]-1ll*a[i][c]*a[j][i]%mod+mod)%mod;
        }
    }
    int od=c/(d+1);
    p[0].a[c%(d+1)]=1;
    for(int i=c-1;i>=0;i--)p[od-i/(d+1)].a[i%(d+1)]=(mod-a[i][c])%mod;
    for(int i=0;i<=od;i++){
        for(int j=0;j<=d;j++)tmp[j]=0;
        for(int k=0;k<=d;k++){
            int val=1;
            for(int j=k;j<=d;j++){
                tmp[k]=(tmp[k]+1ll*val*p[i].a[j])%mod;
                val=1ll*val*qpow(j-k+1,mod-2)%mod*(mod-i)%mod*(j+1)%mod;
            }
        }
        for(int j=0;j<=d;j++)p[i].a[j]=tmp[j];
    }
    cnt=od;deg=d;
}
int suf[100010],val[100010];
int main(){
    scanf("%d",&n);inv[1]=1;m=min(n>>1,80);suf[n+1]=1;
    for(int i=2;i<=n;i++)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
    C[0]=1;
    for(int i=1;i<=n;i++)C[i]=1ll*C[i-1]*(2*n-i+1)%mod*inv[i]%mod;
    for(int i=1;i<=n;i++)C[i]=1ll*C[i]*(n-i)%mod*inv[n]%mod;
    f[0]=1;
    for(int i=1;i<=m;i++){
        for(int j=n;j>=3;j--)f[j]=(4ll*f[j-2]-2ll*f[j-3]%mod+mod)%mod;
        f[2]=4ll*f[0]%mod;
        f[0]=f[1]=0;
        for(int j=2;j<=n;j++)f[j]=(f[j]+2ll*f[j-1]-f[j-2]+mod)%mod;
        for(int j=(i<<1);j<=n;j++)g[j-i]=(g[j-i]+1ll*f[j]*C[n-j])%mod;
    }
    if(n<=70){
        for(int i=0;i<=n;i++)printf("%d ",g[i]);puts("");
        return 0;
    }
    get(70,9);
    p[0].n=4;
    for(int i=1;i<=cnt;i++)p[i].n=deg;
    for(int i=71;i<=n;i++)val[i]=p[0].calc(i);
    for(int i=n;i>=71;i--)suf[i]=1ll*suf[i+1]*val[i]%mod;
    int inv=qpow(suf[71],mod-2);
    for(int i=71;i<=n;i++){
        int tmp=0;
        for(int j=1;j<=cnt;j++){
            tmp=(tmp-1ll*p[j].calc(i)*g[i-j]%mod+mod)%mod;
        }
        g[i]=1ll*tmp*suf[i+1]%mod*inv%mod;
        inv=1ll*inv*val[i]%mod;
    }
    for(int i=0;i<=n;i++)printf("%d ",g[i]);puts("");
    return 0;
}

好了说正经的。

事实上假如设 \(s_k=[z^{n-k}](1-z)(1+z)^{2n-1}\),那么我们要求的就是

\[\sum_{k=0}^ns_k[z^k]F(z) \]

转置:

\[\sum_{k=0}^ns_k[w^k]F(z) \]

换元 \(x=zw\),后边的东西就是

\[z^k[x^k]\frac{(1-x)^2}{1-(1+2z)2x+(1+2z)x^2} \]

根据线性递推的知识我们知道这是个线性递推,设它的第 \(k\) 项为 \(f_k\)。则有显然的递推关系:

\[f_k=(1+2z)2zf_{k-1}-(1+2z)z^2f_{k-2}+[k=0]-2z[k=1]+z^2[k=2] \]

可以列出转移矩阵 \(W_i\)。于是我们现在求的就是

\[\sum_{k=0}^ns_k\prod_{i=0}^kW_i \]

这是 Feux Follets。转置直接算,复杂度 \(O(n\log^2n)\)。但是常数太大(转置乘带个 \(27\) 的常数),过不去。(感觉开 10s 能过)

int n;
poly h;
struct node{
    poly a[3][3];
    void resize(int n){
        a[0][0].resize(n);a[0][1].resize(n);a[0][2].resize(n);
        a[1][0].resize(n);a[1][1].resize(n);a[1][2].resize(n);
        a[2][0].resize(n);a[2][1].resize(n);a[2][2].resize(n);
    }
    node operator*(const node&s)const{
        node ans;
        ans.a[0][0]=a[0][0]*s.a[0][0]+a[0][1]*s.a[1][0]+a[0][2]*s.a[2][0];
        ans.a[0][1]=a[0][0]*s.a[0][1]+a[0][1]*s.a[1][1]+a[0][2]*s.a[2][1];
        ans.a[0][2]=a[0][0]*s.a[0][2]+a[0][1]*s.a[1][2]+a[0][2]*s.a[2][2];
        ans.a[1][0]=a[1][0]*s.a[0][0]+a[1][1]*s.a[1][0]+a[1][2]*s.a[2][0];
        ans.a[1][1]=a[1][0]*s.a[0][1]+a[1][1]*s.a[1][1]+a[1][2]*s.a[2][1];
        ans.a[1][2]=a[1][0]*s.a[0][2]+a[1][1]*s.a[1][2]+a[1][2]*s.a[2][2];
        ans.a[2][0]=a[2][0]*s.a[0][0]+a[2][1]*s.a[1][0]+a[2][2]*s.a[2][0];
        ans.a[2][1]=a[2][0]*s.a[0][1]+a[2][1]*s.a[1][1]+a[2][2]*s.a[2][1];
        ans.a[2][2]=a[2][0]*s.a[0][2]+a[2][1]*s.a[1][2]+a[2][2]*s.a[2][2];
        return ans;
    }
    node operator^(const node&s)const{
        node ans;
        ans.a[0][0]=(a[0][0]^s.a[0][0])+(a[0][1]^s.a[0][1])+(a[0][2]^s.a[0][2]);
        ans.a[0][1]=(a[0][0]^s.a[1][0])+(a[0][1]^s.a[1][1])+(a[0][2]^s.a[1][2]);
        ans.a[0][2]=(a[0][0]^s.a[2][0])+(a[0][1]^s.a[2][1])+(a[0][2]^s.a[2][2]);
        ans.a[1][0]=(a[1][0]^s.a[0][0])+(a[1][1]^s.a[0][1])+(a[1][2]^s.a[0][2]);
        ans.a[1][1]=(a[1][0]^s.a[1][0])+(a[1][1]^s.a[1][1])+(a[1][2]^s.a[1][2]);
        ans.a[1][2]=(a[1][0]^s.a[2][0])+(a[1][1]^s.a[2][1])+(a[1][2]^s.a[2][2]);
        ans.a[2][0]=(a[2][0]^s.a[0][0])+(a[2][1]^s.a[0][1])+(a[2][2]^s.a[0][2]);
        ans.a[2][1]=(a[2][0]^s.a[1][0])+(a[2][1]^s.a[1][1])+(a[2][2]^s.a[1][2]);
        ans.a[2][2]=(a[2][0]^s.a[2][0])+(a[2][1]^s.a[2][1])+(a[2][2]^s.a[2][2]);
        return ans;
    }
}P[20],Q[400010];
#define lson rt<<1
#define rson rt<<1|1
void solve1(int rt,int l,int r){
    if(l==r){
        Q[rt].a[0][0].f.emplace_back(0);Q[rt].a[0][0].f.emplace_back(2);Q[rt].a[0][0].f.emplace_back(4);
        Q[rt].a[0][1].f.emplace_back(0);Q[rt].a[0][1].f.emplace_back(0);Q[rt].a[0][1].f.emplace_back(mod-1);Q[rt].a[0][1].f.emplace_back(mod-2);
        if(l==0){
            Q[rt].a[0][2].f.emplace_back(1);
        }
        else if(l==1){
            Q[rt].a[0][2].f.emplace_back(0);Q[rt].a[0][2].f.emplace_back(mod-2);
        }
        else if(l==2){
            Q[rt].a[0][2].f.emplace_back(0);Q[rt].a[0][2].f.emplace_back(0);Q[rt].a[0][2].f.emplace_back(1);
        }
        else Q[rt].a[0][2].f.emplace_back(0);
        Q[rt].a[1][0].f.emplace_back(1);
        Q[rt].a[1][1].f.emplace_back(0);
        Q[rt].a[1][2].f.emplace_back(0);
        Q[rt].a[2][0].f.emplace_back(0);
        Q[rt].a[2][1].f.emplace_back(0);
        Q[rt].a[2][2].f.emplace_back(1);
        return;
    }
    int mid=(l+r)>>1;
    solve1(lson,l,mid);solve1(rson,mid+1,r);
    if(r<n)Q[rt]=Q[rson]*Q[lson];
}
int ans[100010];
void solve2(int rt,int l,int r,int dep){
    P[dep].resize(min((int)P[dep].a[0][2].size(),2*(r-l+2)));
    if(l==r){
        P[dep]=P[dep]^Q[rt];
        ans[l]=P[dep].a[0][0][0];
        return;
    }
    int mid=(l+r)>>1;
    P[dep+1]=P[dep];
    solve2(lson,l,mid,dep+1);
    P[dep+1]=P[dep]^Q[lson];
    solve2(rson,mid+1,r,dep+1);
}
int main(){
    n=read();init(n+1<<1);
    poly f(2*n);
    h.resize(n+1);
    for(int i=0;i<=2*n-1;i++)f[i]=C(2*n-1,i);
    for(int i=2*n-1;i>=1;i--)f[i]=sub(f[i],f[i-1]);
    for(int i=0;i<=n;i++)h[i]=f[n-i];
    solve1(1,0,n);
    P[0].a[0][0].f.emplace_back(0);
    P[0].a[0][1].f.emplace_back(0);
    P[0].a[0][2]=h;
    P[0].a[1][0].f.emplace_back(0);
    P[0].a[1][1].f.emplace_back(0);
    P[0].a[1][2].f.emplace_back(0);
    P[0].a[2][0].f.emplace_back(0);
    P[0].a[2][1].f.emplace_back(0);
    P[0].a[2][2].f.emplace_back(0);
    solve2(1,0,n,0);
    for(int i=0;i<=n;i++)print(ans[i]),putchar('\n');
    return 0;
}
posted @ 2023-07-11 21:32  gtm1514  阅读(40)  评论(0编辑  收藏  举报