[CF1096G]Lucky Tickets

\(\text{Problem}\)题目链接

\(\text{Solution}\)

考虑 \(dp\)。设 \(dp_{i,j}\) 表示前 \(i\) 个位置取数得到和为 \(j\) 的方案数,那么有个很朴素的转移方程就是:

\[\qquad dp_{i,j}=\sum\limits_{p=1}^{k}dp_{i-1,j-d_{p}} \qquad \]

答案即为 \(\sum\limits_{i}dp_{\frac{n}{2},i}\times dp_{\frac{n}{2},i}\)

这个 \(dp\)\(O(n^2)\) 的,需要优化。考虑到对答案有贡献的 \(i\)\(O(n)\) 级别的,故可以求出 \(\sum\limits_{i}dp_{\frac{n}{2},i}\) 的每一项,然后线性求和。

优化转移方程。设集合 \(S=\{d_{i}\}\),则可以构造函数 \(f_{i}=[i\in S]\),那么转移方程变为:

\[\qquad dp_{i,j}=\sum\limits_{p=0}^{j}dp_{i-1,j-p}\times f_{p} \qquad \]

至于集合 \(S\) 的拓展,把 \(f\) 可以自卷。现在考虑 \(dp_{i}=dp_{i-1}*f\),发现 \(dp_{0}=\{1,0,0,...\}\),故显然有 \(dp_{\frac{n}{2}}=f^{\frac{n}{2}}\)\(n\) 较小,直接朴素地做快速幂即可。注意到多项式长度要动态开,否则会 \(\text{TLE}\)

\(\text{Code}\)

#include <cstdio>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <set>
#include <vector>
#include <stack>
#include <map>
#include <bitset>
#define ri register
#define inf 0x7fffffff
#define E (1)
#define mk make_pair
//#define int long long
//#define double long double
using namespace std; const int N=4000010, Mod=998244353;
inline int read()
{
    int s=0, w=1; ri char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch-'0'), ch=getchar();
    return s*w;
}
void print(int x) { if(x<0) x=-x, putchar('-'); if(x>9) print(x/10); putchar(x%10+'0'); }
int n,K,T,rev[N],f[25][2],d[N],ans,g[N];
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
inline void Get_Rev() { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline void NTT(int *s,int type)
{
    for(ri int i=0;i<T;i++) if(i<rev[i]) swap(s[i],s[rev[i]]);
    for(ri int i=2, cnt=1;i<=T;cnt++, i<<=1)
    {
        int wn=f[cnt][type];
        for(ri int j=0, mid=(i>>1);j<T;j+=i)
        {
            for(ri int k=0, w=1;k<mid;k++, w=1ll*w*wn%Mod)
            {
                int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
                s[j+k]=(x+y)%Mod, s[j+mid+k]=(x-y+Mod)%Mod;
            }
        }
    }
    if(!type) for(ri int i=0, inv=ksc(T,Mod-2);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
inline void QPow(int p)
{
    g[0]=1;
    T=1; while(T<=20) T<<=1;
    Get_Rev();
    for(;p;p>>=1)
    {
        if(p&1)
        {
            int qwq=T-1;
            for(ri int i=T-1;~i;i--) if(d[i]) { qwq=i; break; }
            T=1; while(T<=qwq*2) T<<=1; Get_Rev();
            NTT(g,1), NTT(d,1);
            for(ri int i=0;i<T;i++) g[i]=1ll*g[i]*d[i]%Mod;
            NTT(g,0), NTT(d,0);
        }
        int qwq=T-1;
        for(ri int i=T-1;~i;i--) if(d[i]) { qwq=i; break; }
        T=1; while(T<=qwq*2) T<<=1; Get_Rev();
        NTT(d,1);
        for(ri int i=0;i<T;i++) d[i]=1ll*d[i]*d[i]%Mod;
        NTT(d,0);
    }
}
signed main()
{
    f[23][1]=ksc(3,119), f[23][0]=ksc(332748118,119);
    for(ri int i=22;~i;i--) f[i][1]=1ll*f[i+1][1]*f[i+1][1]%Mod, f[i][0]=1ll*f[i+1][0]*f[i+1][0]%Mod;
    n=read(), K=read();
    for(ri int i=1;i<=K;i++) { int x=read(); d[x]=1; }
    QPow(n/2);
    for(ri int i=0;i<=n*5;i++) ans=(ans+1ll*g[i]*g[i]%Mod)%Mod;
    printf("%d\n",ans);
    return 0;
}
posted @ 2020-08-07 18:31  zkdxl  阅读(80)  评论(0编辑  收藏  举报