2534. 树上计数2

题目链接

2534. 树上计数2

给定一棵 \(N\) 个节点的树,节点编号从 \(1\)\(N\),每个节点都有一个整数权值。

现在,我们要进行 \(M\) 次询问,格式为 u v,对于每个询问你需要回答从 \(u\)\(v\) 的路径上(包括两端点)共有多少种不同的点权值。

输入格式

第一行包含两个整数 \(N,M\)

第二行包含 \(N\) 个整数,其中第 \(i\) 个整数表示点 \(i\) 的权值。

接下来 \(N-1\) 行,每行包含两个整数 \(x,y\),表示点 \(x\) 和点 \(y\) 之间存在一条边。

最后 \(M\) 行,每行包含两个整数 \(u,v\),表示一个询问。

输出格式

\(M\) 行,每行输出一个询问的答案。

数据范围

\(1 \le N \le 40000\),
\(1 \le M \le 10^5\),
\(1 \le x,y,u,v \le N\),
各点权值均在 \(int\) 范围内。

输入样例:

8 2
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5
3 8

输出样例:

4
4

解题思路

树上莫队,lca

先将整棵树的欧拉序求出来,记录每个点第一次出现的位置 \(first[i]\) 和最后一次出现的位置 \(last[i]\),然后观察树中的路径 \([l,r](first[l]<first[r])\) 可以发现两种情况:

  1. 如果路径是一条从上往下的直链,则其所有点对应欧拉序中 \(first[l]\)\(first[r]\) 中出现一次的点
  2. 否则其所有点对应欧拉序中 \(first[l]\)\(last[r]\) 中出现一次的点加上 \(lca(l,r)\)

理解一下会发现的确这样,然后问题就转化为普通莫队问题了

  • 时间复杂度:\(O(n\sqrt{n})\)

代码

// Problem: 树上计数2
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/2536/
// Memory Limit: 64 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

// %%%Skyqwq
#include <bits/stdc++.h>
 
//#define int long long
#define help {cin.tie(NULL); cout.tie(NULL);}
#define pb push_back
#define fi first
#define se second
#define mkp make_pair
using namespace std;
 
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
 
template <typename T> bool chkMax(T &x, T y) { return (y > x) ? x = y, 1 : 0; }
template <typename T> bool chkMin(T &x, T y) { return (y < x) ? x = y, 1 : 0; }
 
template <typename T> void inline read(T &x) {
    int f = 1; x = 0; char s = getchar();
    while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
    while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
    x *= f;
}

const int N=1e5+5;
int n,m,cnt[N],st[N],len,w[N],ans[N];
int seq[N],first[N],last[N],top;
int f[N][20],d[N],t;
vector<int> nums,adj[N];
struct query
{
	int id,l,r,p;
}q[N];
void dfs(int x,int fa)
{
	seq[++top]=x;
	first[x]=top;
	for(int y:adj[x])
	{
		if(y==fa)continue;
		dfs(y,x);
	}
	seq[++top]=x;
	last[x]=top;
}
void bfs()
{
	queue<int> q;
	q.push(1);
	memset(d,0x3f,sizeof d);
	d[0]=0,d[1]=1;
	while(q.size())
	{
		int x=q.front();
		q.pop();
		for(int y:adj[x])
		{
			if(d[y]>d[x]+1)
			{
				d[y]=d[x]+1;
				f[y][0]=x;
				for(int i=1;i<=t;i++)f[y][i]=f[f[y][i-1]][i-1];
				q.push(y);	
			}
		}
	}
}
int lca(int x,int y)
{
	if(d[x]<d[y])swap(x,y);
	for(int i=t;i>=0;i--)
		if(d[f[x][i]]>=d[y])x=f[x][i];
	if(x==y)return x;
	for(int i=t;i>=0;i--)
		if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
	return f[x][0];
}

int get(int x)
{
	return x/len;
}
void add(int x,int &res)
{
	st[x]^=1;
	if(st[x])
	{
		if(!cnt[w[x]])res++;
		cnt[w[x]]++;
	}
	else
	{
		cnt[w[x]]--;
		if(!cnt[w[x]])res--;
	}
}
int main()
{
    cin>>n>>m;
    t=__lg(n);
    for(int i=1;i<=n;i++)cin>>w[i],nums.pb(w[i]);
    sort(nums.begin(),nums.end());
    nums.erase(unique(nums.begin(),nums.end()),nums.end());
    for(int i=1;i<=n;i++)w[i]=lower_bound(nums.begin(),nums.end(),w[i])-nums.begin();
    for(int i=1;i<n;i++)
    {
    	int x,y;
    	cin>>x>>y;
    	adj[x].pb(y),adj[y].pb(x);
    }
    dfs(1,0);
    bfs();
    for(int i=1;i<=m;i++)
    {
    	int l,r;
    	cin>>l>>r;
    	if(first[l]>first[r])swap(l,r);
    	int LCA=lca(l,r);
    	if(LCA==l)q[i]={i,first[l],first[r]};
    	else
    		q[i]={i,last[l],first[r],LCA};
    }
    len=sqrt(top);
    sort(q+1,q+1+m,[](const query &a,const query &b){
    	int al=get(a.l),bl=get(b.l);
    	if(al!=bl)return al<bl;
    	return a.r<b.r;
    });
    for(int i=0,j=1,res=0,k=1;k<=m;k++)
    {
    	int l=q[k].l,r=q[k].r,id=q[k].id,p=q[k].p;
    	while(i<r)add(seq[++i],res);
    	while(i>r)add(seq[i--],res);
    	while(j<l)add(seq[j++],res);
    	while(j>l)add(seq[--j],res);
    	if(p)add(p,res);
    	ans[id]=res;
    	if(p)add(p,res);
    }
    for(int i=1;i<=m;i++)cout<<ans[i]<<'\n';
    return 0;
}
posted @ 2022-10-08 20:37  zyy2001  阅读(37)  评论(0编辑  收藏  举报