[九省联考2018]秘密袭击coat
题目描述
Luogu
题目大意:给一棵\(n\)个点的树,求所有联通块中第\(K\)大的权值\(W_k\)之和。
数据范围:\(K\leq n\leq 1666\) , \(W_{max}\leq 1666\),答案对\(64123\)取模,时限\(7sec\)。
题解
\(Ans = \sum_{S} Kth\ of\ S = \sum_{v = 1}^W v\sum_{S} [Kth\ of\ S\ = v]\)
\(Ans = \sum_{v = 1}^W \sum_{S} [Kth\ of\ S \ge v]\)
我们令\(cnt(S,v)\)表示\(S\)中权值大于等于\(v\)的节点个数。
\(Ans = \sum_{v=1}^W \sum_{S} [cnt(S,v)\ge K]\)
然后就可以设计一个\(dp\)了,设\(f_{u,v,j}\)表示\(u\)为根的联通块中,\(W\ge v\)联通块数。
转移显然:\(f_{u,v,j} = \prod_{son_i} f_{son_i,v,k_{son_i}}\),其中\(\sum_{son_i} k_{son_i} = j - [W_u\ge v]\)。
根据上述可得:\(Ans = \sum_{u} \sum_{v=1}^W \sum_{j=K}^n f_{u,v,j}\)。
卡一下\(j\)那一维,用树上\(lca\)那套分析一下,复杂度就是严格\(O(n^2W)\)的。
然后竟然就能成功AC原题数据了qaq......
不管了。
可以注意到\(j\)那一维是一个背包,所以自然就能想到生成函数。
设\(F_{u,v} = \sum_{j=0}^n f_{u,v,j} x^j\),设\(G_{u,v} = \sum_{s\in Tree_u} F_{s,v} = \sum_{j=0}^n g_{u,v,j}x^j\),那么\(F\)的转移就是一个卷积了。
即初始化后,\(F_{u,v} = \prod_{son_i} F_{son_i,v}\)。
\(Ans = \sum_{u} \sum_{v=1}^W \sum_{j=K}^n f_{u,v,j} = \sum_{v=1}^W \sum_{j=K}^n g_{1,v,j} = \sum_{j=K}^n \sum_{v=1}^W g_{1,v,j}\)。
我们的目标即求\(\sum_{v=1}^W G_{1,v}\)的每一项系数。
每次转移都卷积显然是傻子。
熟悉\(FFT\)原理的童鞋都知道先用点值表示,最后再拉格朗日插值回去即可得到每一项的系数。
外部枚举\(x = 1,2...n+1\),下面我们来考虑如何计算\(F\)、\(G\)的点值表示。
注意到由于转了点值表示,所以多项式乘法是对位乘法,就可以用线段树合并维护了。
线段树每个叶子节点\(v\)维护点值\(F_{u,v},G_{u,v}\),我们在每个点\(u\)要干这些事:
- 初始化:把区间\([1,W_u]\)的\(F\)加上\(x\),把区间\((W_u,W_{max}]\)的\(F\)加上\(1\)。
- 合并:把\(F_{son_i}\)对位相乘,\(G_{son_i}\)对位相加。
- 结束:把\(G_{u}\)加上\(F_u\),把\(F_u\)加\(1\),便于下次转移。
维护一个标记\((a,b,c,d)\),表示\((F,G)\ \to\ F(aF + b , cF + d + G)\)。
那么初始化对应标记\((1,x,0,0)\)、\((1,1,0,0)\)。结束对应标记\((1,1,1,0)\)。
合并的时候,线段树合并,设合并\(x\)、\(y\)。
若其中一个点(以\(y\)为例)没有儿子了,也就是说下面的节点的\((F,G)\)都是一样的了。
此时\(F_y = b\)、\(G_y = d\),对应\((F_x,G_x)\to (F_xF_y,G_x+G_y)\),修改\(x\)的标记,然后\(return\)即可。
最后把根节点的线段树遍历一遍,标记都放下去后叶子节点\(v\)的标记中的\(d\)即\(G_{1,v}\)。
最后套一下拉格朗日公式:\(H(x) = \sum_{i=1}^{n+1} H(i) \prod_{j\neq i} \frac{x-j}{i-j}\)。
你说每次算\(\prod(x - j)\)是\(O(n^2)\)的?
多项式除法了解一下蟹蟹qwq......先别管\(j\neq i\)就行了。
复杂度\(O(n^2logW)\),被暴力吊着打,代码其实挺短的。
实现代码
#include<bits/stdc++.h>
#define IL inline
#define _ 2005
#define ll long long
#define ld long double
using namespace std ;
IL ll gi(){
ll data = 0 , m = 1; char ch = 0 ;
while((ch != '-') && (ch < '0' || ch > '9')) ch = getchar() ;
if(ch == '-'){m = 0 ; ch = getchar() ; }
while(ch >= '0' && ch <= '9'){data = (data<<1) + (data<<3) + (ch^48) ; ch = getchar() ; }
return (m) ? data : -data ;
}
#define mod 64123
int n , K , W , oo , stk[_ * _] , ans[_] , f[_] , fz[_] , H[_] , Ans , inv[mod] , rt[_] , val[_] ;
struct _Edge{
int to , next ;
}Edge[_ << 1] ; int head[_] , CNT ;
IL void AddEdge(int u , int v) {
Edge[++ CNT] = (_Edge){v , head[u]} ; head[u] = CNT ; return ;
}
struct Target {
int a , b , c , d ;
IL Target() {a = 1 ; b = 0 ; c = 0 ; d = 0 ; }
IL Target(int s1,int s2,int s3,int s4) {a = s1 ; b = s2 ; c = s3 ; d = s4 ; }
} ;
IL Target operator + (Target A , Target B) {
Target C ;
C.a = 1ll * A.a * B.a % mod ; C.b = (1ll * B.a * A.b % mod + B.b) % mod ;
C.c = (1ll * A.a * B.c % mod + A.c) % mod ;
C.d = (1ll * A.b * B.c % mod + A.d + B.d) % mod ;
return C ;
}
struct Node {
int ls , rs ; Target tag ;
IL Node(){ls = rs = 0 ; tag = Target() ; return ; }
}t[_ * _] ;
IL int NewNode() {
if(stk[0]) {t[stk[stk[0]]] = Node() ; return stk[stk[0] --] ; }
else {t[++oo] = Node() ; return oo ; }
}
void PushDown(int o) {
if(!t[o].ls) t[o].ls = NewNode() ; if(!t[o].rs) t[o].rs = NewNode() ;
t[t[o].ls].tag = t[t[o].ls].tag + t[o].tag ;
t[t[o].rs].tag = t[t[o].rs].tag + t[o].tag ;
t[o].tag = Target() ;
return ;
}
void Insert(int &o , int l , int r , int ql , int qr , Target E) {
if(!o) o = NewNode() ; if(ql <= l && r <= qr) {t[o].tag = t[o].tag + E ; return ; }
int mid = (l + r) >> 1 ;
PushDown(o) ;
if(ql <= mid) Insert(t[o].ls , l , mid , ql , qr , E) ;
if(qr > mid) Insert(t[o].rs , mid + 1 , r , ql , qr , E) ;
return ;
}
int Merge(int o , int os) {
if(!o || !os) return o + os ;
if(!t[o].ls && !t[o].rs) swap(o , os) ;
if(!t[os].ls && !t[os].rs) {
t[o].tag.a = 1ll * t[o].tag.a * t[os].tag.b % mod ;
t[o].tag.b = 1ll * t[o].tag.b * t[os].tag.b % mod ;
t[o].tag.d = (t[o].tag.d + t[os].tag.d) % mod ;
stk[++stk[0]] = os ;
return o ;
}
PushDown(o) ; PushDown(os) ;
t[o].ls = Merge(t[o].ls , t[os].ls) ;
t[o].rs = Merge(t[o].rs , t[os].rs) ;
stk[++stk[0]] = os ;
return o ;
}
void Dfs(int u , int From , int x) {
Insert(rt[u] , 1 , W , 1 , val[u] , Target(1 , x , 0 , 0)) ;
if(val[u] + 1 <= W) Insert(rt[u] , 1 , W , val[u] + 1 , W , Target(1 , 1 , 0 , 0)) ;
for(int e = head[u] ; e ; e = Edge[e].next) {
int v = Edge[e].to ; if(v == From) continue ;
Dfs(v , u , x) ;
rt[u] = Merge(rt[u] , rt[v]) ;
}
t[rt[u]].tag = t[rt[u]].tag + Target(1 , 1 , 1 , 0) ; return ;
}
void GetAns(int o , int l , int r , int x) {
if(l == r) {H[x] = (H[x] + t[o].tag.d) % mod ; return ; }
PushDown(o) ;
int mid = (l + r) >> 1 ;
GetAns(t[o].ls , l , mid , x) ; GetAns(t[o].rs , mid + 1 , r , x) ;
return ;
}
IL void Solve(int x) {
oo = 0 ; stk[0] = 0 ; for(int i = 1; i <= n; i ++) rt[i] = 0 ;
Dfs(1 , 0 , x) ;
H[x] = 0 ; GetAns(rt[1] , 1 , W , x) ;
}
IL void Lagrange() {
inv[0] = 1 ; inv[1] = 1 ; for(int i = 2; i < mod; i ++) inv[i] = 1ll * inv[mod % i] * (mod - mod / i) % mod ;
fz[0] = 1 ;
for(int i = 1; i <= n + 1; i ++)
for(int j = n + 1; j >= 0; j --)
if(j) fz[j] = (fz[j - 1] + 1ll * fz[j] * (mod - i) % mod) % mod ; else fz[j] = 1ll * fz[j] * (mod - i) % mod ;
for(int i = 1; i <= n + 1; i ++) {
int coef = 1 ;
for(int j = 1; j <= n + 1; j ++) if(i != j) coef = 1ll * coef * (i + mod - j) % mod ;
for(int j = 0; j <= n + 1; j ++) f[j] = fz[j] ;
for(int j = 0; j <= n + 1; j ++) {
if(j) f[j] = (f[j] - f[j - 1] + mod) % mod ;
f[j] = 1ll * inv[mod - i] * f[j] % mod ;
}
coef = 1ll * inv[coef] * H[i] % mod ;
for(int j = 0; j <= n; j ++) ans[j] = (ans[j] + 1ll * coef * f[j] % mod) % mod ;
}
return ;
}
int main() {
n = gi() ; K = gi() ; W = gi() ;
for(int i = 1; i <= n; i ++) val[i] = gi() ;
for(int i = 1,u,v; i < n; i ++) u = gi() , v = gi() , AddEdge(u , v) , AddEdge(v , u) ;
Solve(1) ;
for(int i = 1; i <= n + 1; i ++) Solve(i) ;
//for(int i = 1; i <= n + 1; i ++) cout << "H("<<i<<") = " << H[i] << endl ;
Lagrange() ;
Ans = 0 ;
for(int j = K; j <= n; j ++) Ans = (Ans + ans[j]) % mod ;
cout << Ans << endl ;
return 0 ;
}
所以所谓的整体DP到底是啥啊,根本没看到什么虚树的影子啊?