QOJ #4812. Counting Sequence
首先显然有一个\(O(n^2)\)的dp:设 \(f_{i,j}\) 表示当前总和为 \(i\) ,结尾是 \(j\) 的方案数,转移是平凡的。
因为相邻两项差只有 \(1\) ,因此所有 \(a_i\) 和 \(a_1\) 的差不会超过 \(\sqrt {2n}+O(1)\),但是并没有什么用,因为我们不能直接记录每个数和 \(a_1\) 的差值,题目中有\(a_i>0\) 的限制。
观察发现当 \(a_1>\sqrt {2n}\) 的时候\(a\) 是不可能碰到 \(0\) 的,因此可以考虑将两者分开来算。
对于 \(a_1\leq \sqrt{2n}\),直接用 \(O(n^2)\) 的dp,时间复杂度 \(O(n\sqrt n)\)。
对于 \(a_1\geq \sqrt{2n}\) 的部分,可以记 \(dp_{i,j,S}\) 表示到了序列的第 \(i\) 位,当前和 \(a_1\) 的差值是 \(j\) ,总和为 \(S\) 的方案数,答案可以枚举 \(a_1\) 计算。
但是这个是 \(O(n^2)\) 的(雾
发现这种dp方式非常不优雅,可以倒着dp,每次在开头加上一个数,考虑对后面的数会造成什么影响,就可以做到 \(O(n\sqrt n)\) 。
#include<bits/stdc++.h>
#define Gc() getchar()
#define Me(x,y) memset(x,y,sizeof(x))
#define Mc(x,y) memcpy(x,y,sizeof(x))
#define d(x,y) ((m)*(x-1)+(y))
#define R(n) (rnd()%(n)+1)
#define Pc(x) putchar(x)
#define LB lower_bound
#define UB upper_bound
#define PB push_back
using ll=long long;using db=double;using lb=long db;using ui=unsigned;using ull=unsigned ll;
using namespace std;const int N=3e5+5,M=1.6e3+5,K=2e3+5,mod=998244353,Mod=mod-1;const db eps=1e-9;const int INF=1e9+7;mt19937 rnd(time(0));
int n,c,k,B;ll f[M][M],Ans,f1[N*2],f2[N*2];
int main(){
freopen("1.in","r",stdin);
int i,j;scanf("%d%d",&n,&c);k=sqrt(2*n)+10;B=2*k;
for(i=1;i<=k;i++) f[i][i]=1;for(i=1;i<=n;i++){
Me(f[i%B],0);if(i<=k) f[i][i]=1;
for(j=1;j<=min(i,B);j++)f[i%B][j]=(f[i%B][j]+f[(i-j)%B][j-1]+f[(i-j)%B][j+1]*c)%mod;
}
for(i=1;i<=B;i++) Ans+=f[n%B][i];
f1[n]=1;for(i=1;i<=2*n/(k+1);i++){
for(j=k+1;j*i<=2*n;j++) Ans+=f1[2*n-j*i];
Mc(f2,f1);Me(f1,0);
for(j=0;j<=2*n-i;j++) f1[j]=(f1[j]+f2[j+i]*c)%mod;
for(j=i;j<=2*n;j++) f1[j]=(f1[j]+f2[j-i])%mod;
}
printf("%lld\n",Ans%mod);
}