P7581-「RdOI R2」路径权值【长链剖分,dp】

正题

题目链接:https://www.luogu.com.cn/problem/P7581


题目大意

给出\(n\)个点的有边权有根树,\(m\)次询问一个节点\(x\)的所有\(k\)级儿子两两之间路径长度。

\(1\leq n,m\leq 10^6\)


解题思路

有根长剖,无根点分治。所以这题应该是长剖(?,先离线一下询问

然后略微分析一下,两两的路径长度所以需要合并两棵子树向上的路径,合并的时候又需要记录子树的\(k\)级儿子到该节点的距离和,还有\(k\)级儿子个数。

所以要记录三个东西,\(f_{i,j}\)表示\(i\)节点的\(j\)级儿子个数,\(g_{i,j}\)表示\(i\)节点的\(j\)级儿子到根的距离和,\(h_{i,j}\)表示\(i\)节点的\(j\)级儿子两两之间的路径。

然后这三个用长剖转移就好了。

时间复杂度\(O(n)\)


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ll long long
#define mp(x,y) make_pair(x,y)
#define lowbit(x) (x&-x)
using namespace std;
const ll N=1e6+10,P=1e9+7;
struct node{
	ll to,next,w;
}a[N<<1];
ll n,m,tot,p[N],ls[N],ans[N],len[N],son[N],dis[N],F[N],G[N],H[N];
ll *now,*tmp,*buf,*f[N],*g[N],*h[N];
vector<pair<ll,ll> >v[N];
void addl(ll x,ll y,ll w){
	a[++tot].to=y;
	a[tot].next=ls[x];
	ls[x]=tot;a[tot].w=w;
	return;
}
void dfs(ll x,ll fa){
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(y==fa)continue;
		dis[y]=dis[x]+a[i].w;dfs(y,x);
		if(len[y]>len[son[x]])son[x]=y;
	}
	len[x]=len[son[x]]+1;
	return;
}
void solve(ll x,ll fa){
	f[x][0]=1;g[x][0]=dis[x];
	if(son[x]){
		f[son[x]]=f[x]+1;
		g[son[x]]=g[x]+1;
		h[son[x]]=h[x]+1;
		solve(son[x],x);
	}
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(y==fa||y==son[x])continue;
		f[y]=now;now+=len[y];
		g[y]=tmp;tmp+=len[y];
		h[y]=buf;buf+=len[y];
		solve(y,x);
		for(ll j=0;j<len[y];j++){
			ll t1=(g[x][j+1]-f[x][j+1]*dis[x])%P;
			ll t2=(g[y][j]-f[y][j]*dis[x])%P;
			(h[x][j+1]+=t1*f[y][j]%P+f[x][j+1]*t2%P)%=P;
			(h[x][j+1]+=h[y][j])%=P;
			(g[x][j+1]+=g[y][j])%=P;
			f[x][j+1]+=f[y][j];
		}
	}
	for(ll i=0;i<v[x].size();i++){
		ll k=v[x][i].first,id=v[x][i].second;
		if(k>=len[x])ans[id]=0;else ans[id]=h[x][k];
	}
	return;
}
signed main()
{
	scanf("%lld%lld",&n,&m);
	for(ll i=1;i<n;i++){
		ll x,y,w;
		scanf("%lld%lld%lld",&x,&y,&w);
		addl(x,y,w);addl(y,x,w);
	}
	for(ll i=1;i<=m;i++){
		ll x,k;
		scanf("%lld%lld",&x,&k);
		v[x].push_back(mp(k,i));
	}
	dfs(1,1);
	now=f[1]=F;now+=len[1];
	tmp=g[1]=G;tmp+=len[1];
	buf=h[1]=H;buf+=len[1];
	solve(1,1);
	for(ll i=1;i<=m;i++)
		printf("%lld\n",(ans[i]+P)%P);
	return 0;
}
posted @ 2021-05-04 21:52  QuantAsk  阅读(60)  评论(0编辑  收藏  举报