题解 时代的眼泪

传送门

初看应该是个换根DP之类
但我是从分类讨论贡献入手的
先转化为每个点的贡献是通过这个点的出发点权值比这个点小的使者数
发现能对一个点产生贡献的点只有三种位置:子树内,到根节点链上,到根节点链上的点的其它子树内
于是令 \(val_i\)\(i\) 子树内权值比点 \(i\) 小的点的数量
发现第三类点的贡献可以先按val算,再在到根节点的链上维护树上前缀和减去多余的
第二类点的贡献可以直接dfs一遍树状数组算所有祖先的
然后对于计算子树内权值比点 \(i\) 小的点的数量的具体实现:
比较容易想到主席树或线段树合并,但OJ过慢常数过大
于是依然一遍dfs,记录dfs子树前和dfs子树后树状数组内小于 \(w_i\) 的点的数量的差值即为答案
亲测能快一倍

  • 对于每个点子树内值域上的一些信息的统计方法:
    常见的有利用dfs序连续挂到主席树上或者用线段树合并
    但还有一种常数极小的方法是对值域维护树状数组
    在进入子树之前记录原版本信息,dfs完子树后用新的信息减去旧的信息即为子树内信息

于是整体复杂度 \(O(nlogn)\)

Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000010
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, q;
int head[N], size;
int w[N], uni[N], usize;
struct edge{int to, next;}e[N<<1];
inline void add(int s, int t) {e[++size].to=t; e[size].next=head[s]; head[s]=size;}

namespace force{
	int ans, s;
	void dfs(int u, int fa, int lim, int sum) {
		if (u==s) {ans+=sum; return ;}
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			dfs(v, u, lim, sum+(w[v]>lim));
		}
	}
	void solve() {
		for (int i=1,u; i<=q; ++i) {
			s=u=read(); ans=0;
			for (int j=1; j<=n; ++j) if (j!=u) dfs(j, 0, w[j], 0);
			printf("%d\n", ans);
		}
		exit(0);
	}
}

namespace task1{
	ll sum[N], val[N], dlt[N], chain[N], tot;
	int bit[N];
	inline void upd(int i, int dat) {for (; i<=n; i+=i&-i) bit[i]+=dat;}
	inline int query(int i) {int ans=0; for (; i; i-=i&-i) ans+=bit[i]; return ans;}
	int dfs1(int u, int fa, int lim) {
		// cout<<"dfs1: "<<u<<' '<<w[u]<<' '<<lim<<endl;
		int ans=w[u]<lim;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v!=fa) ans+=dfs1(v, u, lim);
		}
		return ans;
	}
	void dfs2(int u, int fa) {
		tot+=(val[u]=dfs1(u, fa, w[u]));
		if (u!=1) dlt[u]=dfs1(u, fa, w[fa]);
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v!=fa) dfs2(v, u);
		}
	}
	void dfs3(int u, int fa, ll pre) {
		sum[u]=pre+dlt[u];
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v!=fa) dfs3(v, u, sum[u]);
		}
	}
	void dfs4(int u, int fa, ll pre) {
		chain[u]=pre+query(w[u]-1)-val[u];
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v!=fa) dfs4(v, u, chain[u]);
		}
	}
	void solve() {
		for (int i=1; i<=n; ++i) upd(w[i], 1);
		dfs2(1, 0); dfs3(1, 0, 0); dfs4(1, 0, 0);
		#if 0
		cout<<"tot: "<<tot<<endl;
		cout<<"sum: "<<sum[3]<<endl;
		cout<<"chain: "; for (int i=1; i<=n; ++i) cout<<chain[i]<<' '; cout<<endl;
		cout<<"val: "; for (int i=1; i<=n; ++i) cout<<val[i]<<' '; cout<<endl;
		cout<<dfs1(2, 0, w[2])<<endl;
		#endif
		for (int i=1,u; i<=q; ++i) {
			u=read();
			printf("%lld\n", tot-sum[u]+chain[u]);
		}
		exit(0);
	}
}

namespace task2{
	ll sum[N], val[N], dlt[N], chain[N], tot;
	int bit[N], rot[N], num;
	int dat[N*35], lson[N*35], rson[N*35];
	#define dat(p) dat[p]
	#define l(p) lson[p]
	#define r(p) rson[p]
	#define pushup(p) dat(p)=dat(l(p))+dat(r(p))
	inline void upd(int i, int dat) {for (; i<=n; i+=i&-i) bit[i]+=dat;}
	inline int query(int i) {int ans=0; for (; i; i-=i&-i) ans+=bit[i]; return ans;}
	void upd(int& p, int tl, int tr, int pos, int val) {
		if (!p) p=++num;
		if (tl==tr) {dat(p)+=val; return ;}
		int mid=(tl+tr)>>1;
		if (pos<=mid) upd(l(p), tl, mid, pos, val);
		else upd(r(p), mid+1, tr, pos, val);
		pushup(p);
	}
	int query(int p, int tl, int tr, int ql, int qr) {
		if (!p) return 0;
		if (ql<=tl&&qr>=tr) return dat(p);
		int mid=(tl+tr)>>1, ans=0;
		if (ql<=mid && l(p)) ans+=query(l(p), tl, mid, ql, qr);
		if (qr>mid && r(p)) ans+=query(r(p), mid+1, tr, ql, qr);
		return ans;
	}
	void merge(int& p1, int p2, int tl, int tr) {
		if (!(p1&&p2)) {p1|=p2; return ;}
		if (tl==tr) {dat(p1)+=dat(p2); return ;}
		int mid=(tl+tr)>>1;
		merge(l(p1), l(p2), tl, mid);
		merge(r(p1), r(p2), mid+1, tr);
		pushup(p1);
	}
	#if 0
	int dfs1(int u, int fa, int lim) {
		// cout<<"dfs1: "<<u<<' '<<w[u]<<' '<<lim<<endl;
		int ans=w[u]<lim;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v!=fa) ans+=dfs1(v, u, lim);
		}
		return ans;
	}
	#endif
	void dfs2(int u, int fa) {
		upd(rot[u], 1, n, w[u], 1);
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v!=fa) {dfs2(v, u); merge(rot[u], rot[v], 1, n);}
		}
		tot+=(val[u]=query(rot[u], 1, n, 1, w[u]-1));
		if (u!=1) dlt[u]=query(rot[u], 1, n, 1, w[fa]-1);
	}
	void dfs3(int u, int fa, ll pre1, ll pre2) {
		sum[u]=pre1+dlt[u];
		chain[u]=pre2+query(w[u]-1)-val[u];
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v!=fa) dfs3(v, u, sum[u], chain[u]);
		}
	}
	void solve() {
		// cout<<double(sizeof(head)*5+sizeof(e)+sizeof(sum)*4+sizeof(lson)*3)/1000/1000<<endl;
		for (int i=1; i<=n; ++i) upd(w[i], 1);
		dfs2(1, 0); dfs3(1, 0, 0, 0);
		#if 0
		cout<<"tot: "<<tot<<endl;
		cout<<"sum: "<<sum[3]<<endl;
		cout<<"chain: "; for (int i=1; i<=n; ++i) cout<<chain[i]<<' '; cout<<endl;
		cout<<"val: "; for (int i=1; i<=n; ++i) cout<<val[i]<<' '; cout<<endl;
		cout<<dfs1(2, 0, w[2])<<endl;
		#endif
		for (int i=1,u; i<=q; ++i) {
			u=read();
			printf("%lld\n", tot-sum[u]+chain[u]);
		}
		exit(0);
	}
}

signed main()
{
	freopen("tears.in", "r", stdin);
	freopen("tears.out", "w", stdout);

	n=read(); q=read();
	memset(head, -1, sizeof(head));
	for (int i=1; i<=n; ++i) w[i]=uni[i]=read();
	for (int i=1,u,v; i<n; ++i) {
		u=read(); v=read();
		add(u, v); add(v, u);
	}
	sort(uni+1, uni+n+1);
	usize=unique(uni+1, uni+n+1)-uni-1;
	for (int i=1; i<=n; ++i) w[i]=lower_bound(uni+1, uni+usize+1, w[i])-uni;
	// force::solve();
	task2::solve();
	
	return 0;
}
posted @ 2021-10-26 15:52  Administrator-09  阅读(0)  评论(0编辑  收藏  举报