【牛客7872 J】树上启发式合并
【牛客7872 J】树上启发式合并
题意
树上启发式合并,求有多少点对满足,这两个点x和y相互之间不是祖先和后代的关系
同时满足\(val[x]+val[y]=2 * val[ lca(x,y) ]\)
题解
根据两个点不能互为祖先的要求可知:
比较可行的方式是枚举这个作为lca的结点,对于一个作为lca的结点
什么样的结点会以它为lca呢,当然是以它的不同的儿子为根结点的子树中的结点
因此,统计答案的方式也比较巧妙,对于一个作为lca的结点u
- 首先遍历它的第一个儿子v1的那棵子树,用一个mp数组记录当前已经遍历过的结点中每个数出现的次数
遍历第1个儿子那棵子树时把mp维护好。 - 然后从第2个儿子开始,先对每一个结点v,获取到当前mp[2*val[u]-val[v]]的大小
这表示能和结点v一起组成符合条件的点对有多少。 - 这样查询完第2个儿子上所有节点后,再把第2个儿子子树上的所有结点的mp值维护好,依次循环这样一个过程
由于在做这个过程的时候必须保证mp值的准确,所以每次一个lca判断完后要清空该棵子树对mp值造成的影响。
那么考虑什么样的结点不用清空呢,那就是该结点作为父亲结点的最后一个儿子维护答案时不用清空。
那么我们怎样能使时间复杂度尽可能降低呢?那就是把所有儿子中最重的(子树大小最大的儿子)放在最后一个访问,这样就可以节省下清空它的时间复杂度,这就是启发式合并,运用最后一个儿子不需要清空的性质来降低时间复杂度。
Code
/****************************
* Author : W.A.R *
* Date : 2020-10-31-20:44 *
****************************/
/*
*/
#include<stdio.h>
#include<string.h>
#include<math.h>
#include<algorithm>
#include<queue>
#include<map>
#include<unordered_map>
#include<stack>
#include<string>
#include<set>
#define mem(a,x) memset(a,x,sizeof(a))
using namespace std;
typedef long long ll;
const int maxn=1e6+10;
const ll mod=1e9+7;
namespace Fast_IO{
const int MAXL((1 << 18) + 1);int iof, iotp;
char ioif[MAXL], *ioiS, *ioiT, ioof[MAXL],*iooS=ioof,*iooT=ioof+MAXL-1,ioc,iost[55];
char Getchar(){
if (ioiS == ioiT){
ioiS=ioif;ioiT=ioiS+fread(ioif,1,MAXL,stdin);return (ioiS == ioiT ? EOF : *ioiS++);
}else return (*ioiS++);
}
void Write(){fwrite(ioof,1,iooS-ioof,stdout);iooS=ioof;}
void Putchar(char x){*iooS++ = x;if (iooS == iooT)Write();}
inline int read(){
int x=0;for(iof=1,ioc=Getchar();(ioc<'0'||ioc>'9')&&ioc!=EOF;)iof=ioc=='-'?-1:1,ioc=Getchar();
if(ioc==EOF)exit(0);
for(x=0;ioc<='9'&&ioc>='0';ioc=Getchar())x=(x<<3)+(x<<1)+(ioc^48);return x*iof;
}
inline long long read_ll(){
long long x=0;for(iof=1,ioc=Getchar();(ioc<'0'||ioc>'9')&&ioc!=EOF;)iof=ioc=='-'?-1:1,ioc=Getchar();
if(ioc==EOF)exit(0);
for(x=0;ioc<='9'&&ioc>='0';ioc=Getchar())x=(x<<3)+(x<<1)+(ioc^48);return x*iof;
}
template <class Int>void Print(Int x, char ch = '\0'){
if(!x)Putchar('0');if(x<0)Putchar('-'),x=-x;while(x)iost[++iotp]=x%10+'0',x/=10;
while(iotp)Putchar(iost[iotp--]);if (ch)Putchar(ch);
}
void Getstr(char *s, int &l){
for(ioc=Getchar();ioc==' '||ioc=='\n'||ioc=='\t';)ioc=Getchar();
if(ioc==EOF)exit(0);
for(l=0;!(ioc==' '||ioc=='\n'||ioc=='\t'||ioc==EOF);ioc=Getchar())s[l++]=ioc;s[l] = 0;
}
void Putstr(const char *s){for(int i=0,n=strlen(s);i<n;++i)Putchar(s[i]);}
}
using namespace Fast_IO;
struct node{int to,nxt;}e[maxn];
int son[maxn],siz[maxn],cnt[maxn],head[maxn],val[maxn],ct;
ll ans;
unordered_map<int,int>mp;
void addE(int u,int v){e[++ct].to=v;e[ct].nxt=head[u];head[u]=ct;}
void dfs(int u,int fa){
siz[u]=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;if(v==fa)continue;
dfs(v,u);siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void add(int u,int fa,int value){
mp[val[u]]+=value;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa)continue;
add(v,u,value);
}
}
void calc(int u,int fa,int lca){
ans+=mp[2*val[lca]-val[u]];
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa)continue;
calc(v,u,lca);
}
}
void getAns(int u,int fa,bool heavy){
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa||v==son[u])continue;
getAns(v,u,0);
}
if(son[u])getAns(son[u],u,1);
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa||v==son[u])continue;
calc(v,u,u);
add(v,u,1);
}
mp[val[u]]++;
if(!heavy)add(u,fa,-1);
}
int main(){
int n=read();
for(int i=1;i<=n;i++)val[i]=read();
for(int i=1;i<n;i++){int u=read(),v=read();addE(u,v);addE(v,u);}
dfs(1,0);
getAns(1,0,0);
printf("%lld\n",ans<<1);
return 0;
}