[HEOI2018] 秘密袭击coat

Description

给定一棵 \(n\) 个点的树,每个点有点权 \(d_i\) ,请对于树上所有大于等于 \(k\) 个点的联通块,求出联通块中第 \(k\) 大的点权之和。\(n\le 1666,d_i\leq 1666\)。对 \(64123\) 取模。

Sol

先转化一下题目:

如果权值为 \(i\) 的点在某个联通块中是第 \(k\) 大,那么它对该联通块的贡献就是 \(i\),不妨对于 \(1\sim i\),统计一下有多少联通块的第 \(k\) 大是 \(\ge i\) 的,发现对于每个联通块,设它的第 \(k\) 大为 \(j\),那么这个 \(j\) 被统计了 \(j\) 遍,惊奇的发现这就是它本应有的贡献。这就等价于 求有多少个联通块,使得 \(\ge i\) 的数有 \(\ge k\) 个。

所以问题变成了,对于每个权值 \(i\),求有多少联通块有 \(\ge k\)\(\ge i\) 的数。

可以设计一个\(\text{DP}\)\(f[i][j][k]\) 表示点 \(i\) 为根的子树中,有 \(k\)\(\ge j\) 的包含点 \(i\) 的联通块数。

那么转移就是

\[\begin{align} f[i][j][k] &= \prod_{v \in son[i]} (f[v][j][k']+1) \ \ \ \ (d[i]<j,\sum k'=k)\\ f[i][j][k] &= \prod_{v \in son[i]} (f[v][j][k']+1) \ \ \ \ (d[i] \geqslant j,\sum k'=k-1) \end{align} \]

树上背包可以优化为 \(O(n^3)\)卡卡常即可通过。

考虑生成函数。

外层先枚举一个权值 \(j\) ,设 \(F_a(x)=\sum\limits_{i=0}^n f[a][j][i]\cdot x^i\) ,这是一个 \(n\) 次多项式。

求答案时可以令 \(G_a(x)\) 表示 \(a\) 子树中所有点的 \(F(x)\) 的和,然后求 \(G_{root}(x)\) 的第 \(k,\dots,n\) 项和。

那么转移就变成了

\[F_a(x)=\left(\prod_{b\in son_a}(1+F_b(x))\right)*\begin{cases}1 & d_a \ge i\\x&d_a\lt i\end{cases} \]

如果我们给定 \(x_0\),求出 \(F_a(x_0),G_a(x_0)\),那么可以直接 \(O(n)\;\text{DP}\)

那我们可以枚举 \(x_0=1\sim n+1\),每次\(\text{DP}\)一下求出点值,最后再用拉格朗日插值求答案不就好了。然而这样并没有跑的更快。

算一下现在的复杂度,外层枚举权值 \(O(n)\),枚举 \(x_0\) \(O(n)\),树形\(\text{DP}\;O(n)\),总复杂度 \(O(n^3)\)

写一下伪代码:

DP(now,i,x0)
	(f,g)=(1,0)
	for to in son(now)
		(f0,g0)=DP(to,i,x0)
		(f,g)=(f*(1+f0),g+g0)
	if(d[now]>=i)
		(f,g)=(f*x0,g)
	(f,g)=(f,g+f)
	return (f,g)

可不可以把枚举权值这个复杂度优化一下呢?

有个非常牛逼的科技叫整体\(\text{DP}\)。大概意思就是把许多次\(\text{DP}\)放在一起做。

一般用线段树维护每个询问,线段树的第 \(i\) 个叶子结点存储的值就是第 \(i\) 个询问的答案,在合并的时候使用线段树合并,更新一些\(\text{DP}\)值。

当然如果每个节点的线段树都是一颗满线段树的话复杂度显然不对,所以有个优化:如果线段树上一个节点 \(x\) 的子树中的所有询问的答案都一样,那么只需要保留 \(x\) 这个节点即可。

那回到这道题,就可以把枚举 \(W\) 个权值当做 \(W\) 次询问,然后就能用整体\(\text{DP}\)维护了。

具体来说,现在外层只需要枚举 \(x_0\),那么在树形\(\text{DP}\)到点 \(i\) 的时候,点 \(i\) 的线段树中第 \(j\) 个叶子结点存储的值 \(v_1,v_2\) 的含义就是,当 \(x=x_0\) 时,\(F_i(x)\) 的点值为 \(v_1\)\(G_i(x)\) 的点值为 \(v_2\)

那我们看一下伪代码中的每个操作都对应着线段树的什么操作:

  • (f,g)=(1,0) 整体赋值
  • (f,g)=(f*(1+f0),g+g0) 对应项合并
  • if d[a]>=i / (f,g)=(f*x0,g)\(1\sim d[a]\) 项整体打标记
  • (f,g)=(f,g+f) 整体打标记

所以问题就变成了,如何在线段树上维护好标记。

考虑我们需要做什么:

  1. 维护 (f,g)
  2. f整体加 \(1\)
  3. f乘上f0
  4. f加到g

因为对应项相乘并不好做,所以我们考虑定义一个类似于矩阵乘法一样的变换,\((a,b,c,d)\) 表示当前节点维护的(f,g)=(a+b*f,g+c+d*f) 为什么这么定义大概是xjb凑出来的?

然后变换的乘法就可以根据定义轻松推出来了懒得写了

单位变换就是 \((0,1,0,0)\),任何变换乘上该变换还为本身。

而每个点的(f,g)实际上就是维护出来的 \((b,d)\)\((a,c)\) 的存在大概是维护标记的需要?

于是有了这个就能求出 \(1\sim n+1\) 的点值来了。

最后一步,就是插值,求出原多项式的系数了。

这一步可以多项式快速插值实现,但是太难写,而且复杂度瓶颈不在这里。直接拉格朗日插值就好。

然而怎么求出来每项的系数呢?

拉格朗日的式子长这样:\(\sum\limits_{i=1}^{n+1} y_i\left(\prod_{j\ne i}\frac{(x-x_j)}{(x_i-x_j)} \right)\)

观察到分子之间的差别很小,可以提前背包算出来 \(\prod_j (x-x_j)\),转移是枚举当前选前面的 \(x\) 还是后边的 \(-x_j\)\(f[i]=f[i-1]-f[i]*x_j\),分母可以预处理逆元求出来。

那现在问题就只剩下了,分子多乘了一个 \(x-x_i\),我们要把这个退背包回去。

实际上也很简单,因为 \(f[j]=f'[j-1]-f'[j]*x_i\),我们实际上要求的是 \(f'[j]\),那把 \(j\) 从小到大枚举,然后移项一下就行了。

那求出来每项的系数之后,第 \(k\sim n\) 项的和就是答案了。

Code

#pragma GCC optimize(2)
#include<bits/stdc++.h>
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned int ui;
const int N=1670;
const int M=N*200;
const ui mod=64123;
#define ls ch[x][0]
#define rs ch[x][1]

int head[N],tot,dp[M],d[N];
ui ans[N],f[N],in[N],a[N],b[N];
int n,k,W,cnt,dc,rt[N],ch[M][2];

ui inc(ui x,ui y){
	return x+y>=mod?x+y-mod:x+y;
}

struct Node{
	ui a,b,c,d;
	Node(){}
	Node(ui aa,ui bb,ui cc,ui dd){a=aa,b=bb,c=cc,d=dd;}
	friend Node operator+(Node x,Node y){
		return Node(inc(y.a,x.a*y.b%mod),x.b*y.b%mod,inc(x.c+y.c,x.a*y.d%mod),inc(x.d,x.b*y.d%mod));
	}
}sum[M];

struct Edge{
	int to,nxt;
}edge[N<<1];

void add(int x,int y){
	edge[++cnt].to=y;
	edge[cnt].nxt=head[x];
	head[x]=cnt;
}

int newnode(){
	int t=dc?dp[dc--]:++tot;
	return sum[t]=Node(0,1,0,0),t;
}

void del(int x){
	if(!x) return;
	dp[++dc]=x;
	del(ls),del(rs);
	ls=rs=0;
}

void pushdown(int x){
	if(!ls) ls=newnode();
	if(!rs) rs=newnode();
	sum[ls]=sum[ls]+sum[x];
	sum[rs]=sum[rs]+sum[x];
	sum[x]=Node(0,1,0,0);
}

void merge(int &x,int &y){
	if(!ch[x][0] and !ch[x][1]) 
		swap(x,y);
	if(!ch[y][0] and !ch[y][1])
		 return sum[x]=sum[x]+Node(0,sum[y].a,sum[y].c,0),void();
	pushdown(x),pushdown(y);
	merge(ch[x][0],ch[y][0]),merge(ch[x][1],ch[y][1]);
}

void modify(int x,int l,int r,int ql,int qr,Node p){
	if(ql<=l and r<=qr) return sum[x]=sum[x]+p,void();
	int mid=l+r>>1; pushdown(x);
	if(ql<=mid) modify(ls,l,mid,ql,qr,p);
	if(mid<qr) modify(rs,mid+1,r,ql,qr,p);
}

void dfs(int now,int x0,int fa=0){
	rt[now]=newnode();sum[rt[now]]=Node(1,0,0,0);
	for(int i=head[now];i;i=edge[i].nxt){
		int to=edge[i].to;
		if(to==fa) continue;
		dfs(to,x0,now);
		merge(rt[now],rt[to]);
		del(rt[to]);
	}
	modify(rt[now],1,W,1,d[now],Node(0,x0,0,0));
	modify(rt[now],1,W,1,W,Node(0,1,0,1));
	modify(rt[now],1,W,1,W,Node(1,1,0,0));
}

ui query(int x,int l,int r){
	if(l==r) return sum[x].c;
	int mid=l+r>>1; pushdown(x);
	return (query(ls,l,mid)+query(rs,mid+1,r))%mod;
}

ui Lagrange(){
	in[1]=1; f[0]=1;
	for(int i=2;i<=n+1;i++)
		in[i]=(mod-mod/i)*in[mod%i]%mod;
	for(int i=1;i<=n+1;i++)
		for(int j=n+1;~j;j--){
			if(j) f[j]=inc(f[j-1],f[j]*(mod-i)%mod);
			else f[j]=f[j]*(mod-i)%mod;
		}
	for(int i=1;i<=n+1;i++){
		ui res=ans[i]; memcpy(b,f,sizeof 4*(n+2));
		for(int j=1;j<=n+1;j++){
			if(i==j) continue;
			if(i>j) res=res*in[i-j]%mod;
			else res=res*(mod-in[j-i])%mod;
		}
		for(int j=0;j<=n+1;j++){
			if(!j) b[j]=mod-b[j]*in[i]%mod;
			else b[j]=inc(mod,b[j-1]-b[j])*in[i]%mod;
		}
		for(int j=0;j<=n;j++)
			a[j]=inc(a[j],b[j]*res%mod);
	} ui ans=0;
	for(int i=k;i<=n;i++) ans=inc(ans,a[i]);
	return ans;
}

signed main(){
	scanf("%d%d%d",&n,&k,&W);
	for(int i=1;i<=n;i++) 
		scanf("%d",&d[i]);
	for(int x,y,i=1;i<n;i++)
		scanf("%d%d",&x,&y),add(x,y),add(y,x);
	for(int i=1;i<=n+1;i++){
		dfs(1,i);
		ans[i]=query(rt[1],1,W);
		del(rt[1]);
	} printf("%u\n",Lagrange()); return 0;
}
posted @ 2019-02-20 16:56  YoungNeal  阅读(532)  评论(0编辑  收藏  举报