「九省联考 2018」秘密袭击
「九省联考 2018」秘密袭击
解题思路
设 \(a_i\) 为树上联通块第 \(k\) 大大于等于 \(i\) 的个数,那么答案就是
设 \(dp[u][i][j]\) 表示以 \(u\) 为根的联通子树,大于等于 \(i\) 的点有 \(j\) 个的方案数,把最后一维写成生成函数的形式
转移也可以用生成函数的形式表示
用三模数 \(\text{NTT}\) 转移可以 \(\mathcal O(n^3 \log n)\) ,后面一维子树大小有关点分优化一下可以 \(\mathcal O(n^2 \log^2n)\) ,甚至不如暴力,其实标算也跑不过暴力。
这种时候就可以往点值这方向考虑,这里的操作点值是直接相乘,所以用线段树合并维护非常方便,另外还需要记一下所有 \(f(u,i)\) 和的多项式的点值。
我们将 \(n+1\) 个点带进去算出其在每一个 \(\sum_u f(u,i)\) 下的点值,用线段树合并维护的话需要的操作有:
维护两个 \((f,g)\) 表示点值以及子树点值和
- \((f,g)\rightarrow(f+1,g)\)
- \((f,g)\rightarrow (f,g+f)\)
- \((f,g)\rightarrow(f*x,g)\)
- \((f,g)\rightarrow(f*f_v,g+g_v)\)
这些操作都可以写作函数 \(tr(a,b,c,d)\) 的形式,表示 \((f,g)\rightarrow (af+b,g+cf+d)\) 。
这个函数满足结合律并且是封闭的,推一下复合就可以合并标记了,然后线段树合并的时候因为要支持下传标记,所以要在某个点没有左右儿子的时候将已经确定的点值转移到另外一个点上,不然复杂度会挂。
听说这个套路叫整体 \(\text{DP}\) ,我不太会这一套理论QwQ。
最后需要用拉格朗日插值把系数全部求出来,考虑拉格朗日插值最基本的式子
令 \(w =\prod (x-x_i)\),就能得到
先求出 \(w\) 的每一项的系数,然后每次把 \((x-x_i)\) 除掉,退位维护一下系数乘上其它常数加到答案的多项式上就可以了,这东西也可以分治 \(\text{NTT}\) 优化,不过这题没有啥必要。
一通操作下来整道题复杂度是 \(\mathcal O (n^2\log n)\) ,跑的大概是暴力的 \(10\) 倍。
code
/*program by mangoyang*/
#include<bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int ch = 0, f = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
const int N = 5005, mod = 64123;
struct Node{
int a, b, c, d;
inline void init(){ a = 1, b = c = d = 0; }
Node operator * (const Node & A) const{
return (Node){
(int) (1ll * A.a * a % mod),
(int) ((1ll * A.a * b + A.b) % mod),
(int) ((1ll * A.c * a + c) % mod),
(int) ((1ll * A.c * b + A.d + d) % mod)
};
}
};
vector<int> g[N];
int Y[N], d[N], rt[N], n, W, k;
namespace Seg{
#define mid ((l + r) >> 1)
Node tag[N*32];
int lc[N*32], rc[N*32], st[N*32], top, size;
inline int newnode(){
return ++size, lc[size] = rc[size] = 0, tag[size].init(), size;
}
inline void clear(int x){
lc[x] = rc[x] = 0, tag[x].init(), st[++top] = x;
}
inline void pushdown(int u){
if(!lc[u]) lc[u] = newnode();
if(!rc[u]) rc[u] = newnode();
tag[lc[u]] = tag[lc[u]] * tag[u];
tag[rc[u]] = tag[rc[u]] * tag[u], tag[u].init();
}
inline void change(int &u, int l, int r, int L, int R, Node x){
if(!u) u = newnode();
if(l >= L && r <= R) return (void) (tag[u] = tag[u] * x);
pushdown(u);
if(L <= mid) change(lc[u], l, mid, L, R, x);
if(mid < R) change(rc[u], mid + 1, r, L, R, x);
}
inline int merge(int x, int y){
if(!x || !y) return x + y;
if(!lc[x] && !rc[x]) swap(x, y);
if(!lc[y] && !rc[y])
tag[x] = tag[x] * (Node){tag[y].b, 0, 0, tag[y].d};
else{
pushdown(x), pushdown(y);
lc[x] = merge(lc[x], lc[y]);
rc[x] = merge(rc[x], rc[y]);
}
return x;
}
inline void getnode(int u, int l, int r, int x){
if(!u) return;
if(l == r) return (void) ((Y[x] += tag[u].d) %= mod);
pushdown(u);
getnode(lc[u], l, mid, x), getnode(rc[u], mid + 1, r, x);
}
}
inline void dfs(int u, int fa, int x){
Seg::change(rt[u], 1, W, 1, W, (Node){0, 1, 0, 0});
for(int i = 0; i < (int) g[u].size(); i++){
int v = g[u][i];
if(v == fa) continue;
dfs(v, u, x), rt[u] = Seg::merge(rt[u], rt[v]), rt[v] = 0;
}
if(d[u]) Seg::change(rt[u], 1, W, 1, d[u], (Node){x, 0, 0, 0});
Seg::change(rt[u], 1, W, 1, W, (Node){1, 0, 1, 0});
Seg::change(rt[u], 1, W, 1, W, (Node){1, 1, 0, 0});
}
inline int Pow(int a, int b){
int ans = 1;
for(; b; b >>= 1, a = 1ll * a * a % mod)
if(b & 1) ans = 1ll * ans * a % mod;
return ans;
}
inline void dec(int *a, int *b, int x){
static int tmp[N];
for(int i = 0; i <= n + 1; i++) tmp[i] = a[i];
for(int i = n + 1; i >= 1; i--){
b[i-1] = tmp[i];
(tmp[i-1] += 1ll * x * tmp[i] % mod) %= mod;
}
}
inline int Lagrange() {
static int G[N], F[N], inv[N], ans; G[0] = 1;
for(int i = 1; i <= n; i++) inv[i] = Pow(i, mod - 2);
for(int i = n + 1; i >= 1;--i)
for(int j = n + 1; j >= 0; j--){
G[j] = 1ll * (mod - i) * G[j] % mod;
if(j) (G[j] += G[j-1]) %= mod;
}
for(int i = 1; i <= n + 1; i++){
dec(G, F, i); int res = 0;
for(int j = k; j <= n; j++) (res += F[j]) %= mod;
for(int j = 1; j <= n + 1; j++) if(i != j) {
if(j < i) res = 1ll * res * inv[i-j] % mod;
else res = 1ll * res * (mod - inv[j-i]) % mod;
}
res = 1ll * res * Y[i] % mod; (ans += res) %= mod;
}
return ans;
}
int main(){
read(n), read(k), read(W);
for(int i = 1; i <= n; i++) read(d[i]);
for(int i = 1, x, y; i < n; i++){
read(x), read(y);
g[x].push_back(y), g[y].push_back(x);
}
for(int i = 1; i <= n + 1; i++){
dfs(1, 0, i);
Seg::getnode(rt[1], 1, W, i);
rt[1] = Seg::size = 0;
}
cout << Lagrange() << endl;
return 0;
}