【集训队作业2018】count
Solution
虽然做的时候没考虑太多,但是仔细想想还是挺有意思的。
我们注意到似乎只有 \(1\to m\) 都出现过这个限制比较难以处理,但是我们打表发现不考虑这个限制答案是对的(\(m>n\) 时显然为 \(0\))。其实因为我们只关心每个区间最大值位置,那么对于一种情况如果我们能够造出 \([1,m]\) 都出现过即可,发现如果我们一开始先离散化,然后按值从大到小处理,值相同的按位置从小到大处理,然后顺序赋值即可,显然可以赋成 \([1,m]\)。
那么接下来的步骤就很简单了。我们设长度为 \(f_{s,l}\) 表示长度为 \(l\) 的序列填出 \([1,s]\) 的方案数。可以得到转移式:
\[f_{s,l}=\sum_{x=1} f_{s-1,x-1}f_{s,l-x}
\]
特别的 \(f_{s,0}=1\)。答案即是 \(f_{m,n}\)。
那么我们把 \(f_{s,l}\) 构造生成函数 \(F_s(x)\),那么可以得到:
\[F_s(x)=1+xF_{s-1}(x)F_s(x)
\]
\[\Rightarrow F_s=1/(1-xF_{s-1}(x))
\]
然后你发现如果我们设 \(F_s(x)=t_s(x)/g_s(x)\),那么我们有:
\[t_s(x)=g_{s-1}(x),g_s(x)=g_{s-1}(x)-xt_{s-1}(x)
\]
你发现这个玩意就可以用矩阵转移了。复杂度就是 \(\Theta(n\log^2 n)\) 的。
Code
#include <bits/stdc++.h>
using namespace std;
#define Int register int
#define mod 998244353
#define MAXN 400005
// char buf[1<<21],*p1=buf,*p2=buf;
// #define getchar() (p1==p2 && (p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
template <typename T> inline void read (T &t){t = 0;char c = getchar();int f = 1;while (c < '0' || c > '9'){if (c == '-') f = -f;c = getchar();}while (c >= '0' && c <= '9'){t = (t << 3) + (t << 1) + c - '0';c = getchar();} t *= f;}
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
template <typename T> inline void write (T x){if (x < 0){x = -x;putchar ('-');}if (x > 9) write (x / 10);putchar (x % 10 + '0');}
template <typename T> inline void chkmax (T &a,T b){a = max (a,b);}
template <typename T> inline void chkmin (T &a,T b){a = min (a,b);}
#define mod 998244353
#define Gi 332748118
#define G 3
int mul (int a,int b){return 1ll * a * b % mod;}
int dec (int a,int b){return a >= b ? a - b : a + mod - b;}
int add (int a,int b){return a + b >= mod ? a + b - mod : a + b;}
int qkpow (int a,int b){
int res = 1;for (;b;b >>= 1,a = mul (a,a)) if (b & 1) res = mul (res,a);
return res;
}
void Add (int &a,int b){a = add (a,b);}
void Sbu (int &a,int b){a = dec (a,b);}
typedef vector<int> poly;
int w[MAXN],rev[MAXN];
#define SZ(A) ((int)A.size())
void init_ntt (){
int lim = 1 << 18;
for (Int i = 0;i < lim;++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << 17);
int Wn = qkpow (3,(mod - 1) / lim);w[lim >> 1] = 1;
for (Int i = lim / 2 + 1;i < lim;++ i) w[i] = mul (w[i - 1],Wn);
for (Int i = lim / 2 - 1;i;-- i) w[i] = w[i << 1];
}
void ntt (poly &a,int type){
#define G 3
#define Gi 332748118
static int d[MAXN];int lim = a.size();
for (Int i = 0,z = 18 - __builtin_ctz(lim);i < lim;++ i) d[rev[i] >> z] = a[i];
for (Int i = 1;i < lim;i <<= 1)
for (Int j = 0;j < lim;j += i << 1)
for (Int k = 0;k < i;++ k){
int x = mul (w[i + k],d[i + j + k]);
d[i + j + k] = dec (d[j + k],x),d[j + k] = add (d[j + k],x);
}
for (Int i = 0;i < lim;++ i) a[i] = d[i] % mod;
if (type == -1){
reverse (a.begin() + 1,a.begin() + lim);
for (Int i = 0,Inv = qkpow (lim,mod - 2);i < lim;++ i) a[i] = mul (a[i],Inv);
}
#undef G
#undef Gi
}
poly operator * (poly A,poly B){
int lim = 1,l = 0,len = SZ(A) + SZ(B) - 1;
while (lim < SZ(A) + SZ(B)) lim <<= 1,++ l;
A.resize (lim),B.resize (lim),ntt (A,1),ntt (B,1);
for (Int i = 0;i < lim;++ i) A[i] = mul (A[i],B[i]);
ntt (A,-1),A.resize (len);
return A;
}
poly operator + (poly A,poly B){
int len = max (SZ(A),SZ(B));A.resize (len),B.resize (len);
for (Int i = 0;i < len;++ i) Add (A[i],B[i]);
return A;
}
poly operator - (poly A,poly B){
int len = max (SZ(A),SZ(B));A.resize (len),B.resize (len);
for (Int i = 0;i < len;++ i) A[i] = dec (A[i],B[i]);
return A;
}
poly operator * (poly A,int B){
for (Int i = 0;i < SZ(A);++ i) A[i] = mul (A[i],B);
return A;
}
poly operator * (int B,poly A){
for (Int i = 0;i < SZ(A);++ i) A[i] = mul (A[i],B);
return A;
}
poly inv (poly A,int n){
if (n == 1) return poly(1,qkpow (A[0],mod - 2));
poly F1,F0 = inv (A,(n + 1) >> 1);
F1 = F0 * 2 - F0 * F0 * A,F1.resize (n);
return F1;
}
poly inv(poly A){return inv(A,SZ(A));}
struct Matrix{
poly mat[2][2];
Matrix(){for (Int i = 0;i < 2;++ i) for (Int j = 0;j < 2;++ j) mat[i][j].clear ();}
poly * operator [] (const int &key){return mat[key];}
Matrix operator * (const Matrix &p)const{
Matrix New;
for (Int i = 0;i < 2;++ i)
for (Int j = 0;j < 2;++ j)
for (Int k = 0;k < 2;++ k) New[i][k] = New[i][k] + mat[i][j] * p.mat[j][k];
return New;
}
Matrix operator ^ (int b){
Matrix res,a = *this;
for (Int i = 0;i < 2;++ i) res[i][i].resize (1),res[i][i][0] = 1;
for (;b;b >>= 1,a = a * a) if (b & 1) res = res * a;
return res;
}
};
signed main(){
int n,m;
read (n,m);
if (m >= n + 1) return puts ("0") & 0;
init_ntt ();Matrix B;
B[0][0].resize (1),B[0][1].resize (2),B[0][1][1] = mod - 1,
B[1][0].resize (1),B[1][0][0] = 1,B[1][1].resize (1),B[1][1][0] = 1;
B = B ^ m;
poly up = B[0][0] + B[1][0],dow = B[0][1] + B[1][1];
poly S = up * inv(dow,n + 1);
write (S[n]),putchar ('\n');
return 0;
}