题解 超级加倍

传送门

链上的部分分:可以用单调栈求出每个位置的范围 \(l, r\),再用主席树维护
其实也可以在求出 \(l, r\) 后转化为三维偏序求解
题解说可以先忽略其中一个条件计数,再减去算重的部分

然后是正解:

  • 对于形如"路径经过点中的最大值""起点为全路径最大值"之类的问题,序列上可以考虑笛卡尔树,树上可以考虑 Kruskal 重构树
  • 对点权建立 Kruskal 重构树(使两点的 lca 为路径上的最大值):
    从小到大扫描点,扫描当前点 \(u\) 的所有出边 \((u, v)\),仅当 \(v<u\) 时在 \(u\) 与 \(find(v)\) 间连一条边,并把 \(find(v)\) 所在集合并入 \(u\)

于是可以依据点权建两棵 Kruskal 重构树: \(T1\) 中两点的 lca 为路径最大值(从小到大扫描建出), \(T2\) 中两点的 lca 为路径最小值(从大到小扫描建出)
现在问题变为求满足 \(x\) 在 \(T1\) 中是 \(y\) 的祖先、\(y\) 在 \(T2\) 中是 \(x\) 的祖先的点对 \((x, y)\) 数量
这是个二维偏序问题:可以先求出 \(T1\) 的 dfs 序,再在 \(T2\) 上 dfs
dfs 时维护:当前点在 \(T2\) 中的每个祖先,都在其 \(T1\) dfs 序的位置上贡献 1
于是用树状数组查询当前点在 \(T1\) 中的子树(对应 dfs 序上的一段区间)内 1 的个数即可

Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 2000010
#define ll long long
//#define int long long

// Input buffer for the fast reader below: p1 is the next unread byte, p2 one past the
// last byte currently loaded.
char buf[1<<21], *p1=buf, *p2=buf;
// Buffered replacement for getchar(): refills `buf` from stdin with fread once it is
// exhausted and yields EOF when no more input is available. Bytes are yielded as
// unsigned char values so they are always valid arguments for <cctype> functions.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))
/// Reads the next decimal integer from stdin.
/// Non-digit characters before the number are skipped; every '-' skipped flips the sign.
/// @return the parsed value, or 0 if the input ends before any digit is seen.
inline int read() {
	// int (not char) so that EOF stays distinguishable from a real byte.
	int ans=0, f=1, c=getchar();
	// Stop at EOF: without this check a truncated input made this loop spin forever.
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n;
// Input tree as an adjacency list; each undirected edge is stored as two arcs.
// head[u] is the first arc leaving u (-1 = none; head[] must be pre-filled with -1),
// e[k].next chains to the next arc from the same vertex, ecnt counts the arcs used.
// dep/mdep: depth and deepest depth below each vertex; deg: vertex degree;
// a[1..tot]: vertex sequence produced by a traversal.
// The arc counter was previously named `size`, which is ambiguous with std::size
// under `using namespace std;` in C++17 and later.
int head[N], dep[N], mdep[N], deg[N], a[N], tot, ecnt;
struct edge{int to, next;}e[N<<1];
// Prepends the directed arc s -> t to s's adjacency list.
inline void add(int s, int t) {e[++ecnt].to=t; e[ecnt].next=head[s]; head[s]=ecnt;}

// Walks the subtree of u (fa is u's parent, so the walk never goes back up).
// Each child's dep[] is set to one more than its parent's; dep[u] itself must already
// be set by the caller. Afterwards mdep[u] holds the largest dep[] in u's subtree.
void dfs1(int u, int fa) {
	mdep[u] = dep[u];
	for (int arc = head[u]; arc != -1; arc = e[arc].next) {
		const int child = e[arc].to;
		if (child == fa) continue;
		dep[child] = dep[u] + 1;
		dfs1(child, u);
		mdep[u] = std::max(mdep[u], mdep[child]);
	}
}

// Pre-order walk from u (never stepping back to its parent fa): appends each visited
// vertex to a[] at position ++tot. When started from an end of a path, this writes
// the path's vertices into a[] in path order.
void dfs3(int u, int fa) {
	a[++tot] = u;
	for (int arc = head[u]; arc != -1; arc = e[arc].next) {
		const int next_vertex = e[arc].to;
		if (next_vertex == fa) continue;
		dfs3(next_vertex, u);
	}
}

namespace force{
	int ans;  // number of qualifying (start, end) pairs found so far
	// Explores paths that start at `rot` and only step onto vertices larger than rot,
	// so rot stays the path minimum. `maxn` is the maximum of the path before u
	// (the walk starts with maxn == rot). u is counted when it is larger than every
	// earlier vertex, i.e. when u is the maximum of the path rot..u.
	void dfs2(int u, int rot, int fa, int maxn) {
		if (u > maxn) ++ans;
		const int best_so_far = std::max(maxn, u);
		for (int arc = head[u]; arc != -1; arc = e[arc].next) {
			const int v = e[arc].to;
			if (v == fa || v <= rot) continue;
			dfs2(v, rot, u, best_so_far);
		}
	}
	// O(n^2) brute force: tries every vertex as the path minimum, prints the count, exits.
	void solve() {
		for (int start = 1; start <= n; ++start) dfs2(start, start, 0, start);
		printf("%d\n", ans);
		exit(0);
	}
}

namespace task1{
	ll ans;
	// Star case: the first vertex of degree n-1 is taken as the centre `rot`.
	// Every centre-leaf path qualifies: n-1 pairs.
	// A leaf-leaf path i - rot - j (i < j) qualifies only when i < rot < j:
	// (rot-1) choices for i times (n-rot) choices for j.
	// Prints the total and exits.
	void solve() {
		int rot = 0;
		for (int v = 1; v <= n && !rot; ++v)
			if (deg[v] == n-1) rot = v;
		ans += (n-1) + 1LL*(rot-1)*(n-rot);
		printf("%lld\n", ans);
		exit(0);
	}
}

namespace task2{
	ll ans;
	int l[N], r[N];     // per-position range bounds produced by the monotonic stacks
	int q[N], ql, qr;   // monotonic stack storage (q[ql..qr])
	// Persistent segment tree over values [1, n]; version rot[i] contains l[1..i].
	// Node 0 is the shared empty node (its sum/ls/rs are never written, so stay 0).
	// NOTE(review): every upd() allocates one fresh node per tree level, i.e. about
	// ceil(log2 n)+1 nodes per insertion, so N*10 slots only suffice for n up to
	// roughly 9e5. Presumably the chain subtask is smaller than the full N -- confirm.
	int ls[N*10], rs[N*10], sum[N*10], rot[N], now;
	inline void pushup(int p) {sum[p]=sum[ls[p]]+sum[rs[p]];}
	// Builds the new version rooted at p1 from the old version p2 by adding `val`
	// at position `pos`. Nodes off the updated path are shared with p2.
	void upd(int& p1, int p2, int tl, int tr, int pos, int val) {
		if (!p1) {p1=++now;}
		if (tl==tr) {sum[p1]=sum[p2]+val; return ;}
		int mid=(tl+tr)>>1;
		if (pos<=mid) upd(ls[p1], ls[p2], tl, mid, pos, val), rs[p1]=rs[p2];
		else upd(rs[p1], rs[p2], mid+1, tr, pos, val), ls[p1]=ls[p2];
		pushup(p1);
	}
	// Number of values in [ql, qr] inserted after version p1 up to version p2
	// (p1 = older root, p2 = newer root). Counts never decrease from one version to
	// the next, so an empty subtree in the newer version implies an empty one in the
	// older version -- but not the other way round. The early-out must therefore test
	// p2; testing p1 would drop values that exist only in the newer version.
	int query(int p1, int p2, int tl, int tr, int ql, int qr) {
		if (!p2) return 0;
		if (ql<=tl && qr>=tr) {return sum[p2]-sum[p1];}
		int mid=(tl+tr)>>1, ans=0;
		if (ql<=mid) ans+=query(ls[p1], ls[p2], tl, mid, ql, qr);
		if (qr>mid) ans+=query(rs[p1], rs[p2], mid+1, tr, ql, qr);
		return ans;
	}
	// One left-to-right pass over a[1..n]:
	//  - r[i]: last index such that every element of a(i..r[i]] is larger than a[i]
	//    (next-smaller via a monotonic stack, sentinel a[n+1]=0);
	//  - l[j]: first index such that every element of a[l[j]..j) is smaller than a[j]
	//    (previous-greater via a monotonic stack, sentinel a[0]=n+1).
	// It then adds to ans the number of pairs i < j <= r[i] with l[j] <= i, i.e.
	// a[i] is the minimum and a[j] the maximum of a[i..j].
	void run_pass() {
		ql=1; qr=0; a[n+1]=0;
		for (int i=1; i<=n+1; ++i) {
			while (ql<=qr && a[q[qr]]>a[i]) r[q[qr--]]=i-1;
			q[++qr]=i;
		}
		ql=1; qr=0; a[0]=n+1;
		for (int i=n; ~i; --i) {
			while (ql<=qr && a[q[qr]]<a[i]) l[q[qr--]]=i+1;
			q[++qr]=i;
		}
		for (int i=1; i<=n; ++i) upd(rot[i], rot[i-1], 1, n, l[i], 1);
		for (int i=1; i<=n; ++i) if (r[i]>i) ans+=query(rot[i], rot[r[i]], 1, n, 1, i);
	}
	// Clears all per-pass state (bounds and the whole persistent tree).
	void reset_state() {
		now=0;
		memset(l, 0, sizeof(l));
		memset(r, 0, sizeof(r));
		memset(ls, 0, sizeof(ls));
		memset(rs, 0, sizeof(rs));
		memset(sum, 0, sizeof(sum));
		memset(rot, 0, sizeof(rot));
	}
	// Chain subtask: a[1..n] holds the path's vertices in order. The same count is
	// done in both directions (the second pass on the reversed sequence); then the
	// total is printed and the program exits.
	void solve() {
		run_pass();
		reset_state();
		reverse(a+1, a+n+1);
		run_pass();
		printf("%lld\n", ans);
		exit(0);
	}
}

namespace task{
	ll ans;
	int dsu[N], id[N], siz[N], bit[N], now;
	inline void upd(int i, int val) {for (; i<=n; i+=i&-i) bit[i]+=val;}
	inline int query(int i) {int ans=0; for (; i; i-=i&-i) ans+=bit[i]; return ans;}
	inline int find(int p) {return dsu[p]==p?p:dsu[p]=find(dsu[p]);}
	namespace tr1{
		int head[N], size;
		struct edge{int to, next;}e[N<<1];
		inline void add(int s, int t) {e[++size].to=t; e[size].next=head[s]; head[s]=size;}
	}
	namespace tr2{
		int head[N], size;
		struct edge{int to, next;}e[N<<1];
		inline void add(int s, int t) {e[++size].to=t; e[size].next=head[s]; head[s]=size;}
	}
	void dfs1(int u) {
		using namespace tr1;
		siz[u]=1;
		id[u]=++now;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			dfs1(v);
			siz[u]+=siz[v];
		}
	}
	void dfs2(int u) {
		using namespace tr2;
		ans+=query(id[u]+siz[u]-1)-query(id[u]);
		upd(id[u], 1);
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			dfs2(v);
		}
		upd(id[u], -1);
	}
	void solve() {
		memset(tr1::head, -1, sizeof(tr1::head));
		memset(tr2::head, -1, sizeof(tr2::head));
		for (int i=1; i<=n; ++i) dsu[i]=i;
		for (int i=1; i<=n; ++i) {
			for (int j=head[i],v; ~j; j=e[j].next) {
				v = e[j].to;
				if (v<i && find(v)!=find(i)) {
					tr1::add(i, find(v));
					dsu[find(v)]=i;
				}
			}
		}
		dfs1(n);
		for (int i=1; i<=n; ++i) dsu[i]=i;
		for (int i=n; i; --i) {
			for (int j=head[i],v; ~j; j=e[j].next) {
				v = e[j].to;
				if (v>i && find(v)!=find(i)) {
					tr2::add(i, find(v));
					dsu[find(v)]=i;
				}
			}
		}
		dfs2(1);
		printf("%lld\n", ans);
		exit(0);
	}
}

signed main()
{
	freopen("charity.in", "r", stdin);
	freopen("charity.out", "w", stdout);

	n = read();
	read();  // the second header value is read and discarded
	memset(head, -1, sizeof(head));  // -1 terminates every adjacency list
	// For each vertex child = 2..n the input supplies one neighbour u (presumably its
	// parent); the edge is stored in both directions and both degrees are bumped.
	for (int child = 2; child <= n; ++child) {
		const int u = read();
		add(u, child);
		add(child, u);
		++deg[u];
		++deg[child];
	}
	// Subtask dispatch (brute force / chain / star) is kept but disabled;
	// only the general solution runs.
	#if 0
	if (n<=1000) force::solve();
	dep[1]=1; dfs1(1, 0);
	int odd=0, beg;
	for (int i=1; i<=n; ++i) {
		if (deg[i]==1) ++odd, beg=i;
		else if (deg[i]>2) odd=INF;
	}
	if (odd==2) {dfs3(beg, 0); task2::solve();}
	int mdeg=0;
	for (int i=1; i<=n; ++i) mdeg=max(mdeg, deg[i]);
	if (mdeg==n-1) task1::solve();
	force::solve();
	#else
	task::solve();
	#endif

	return 0;
}
posted @ 2021-11-11 06:37  Administrator-09  阅读(0)  评论(0)  编辑  收藏  举报