2534. 树上计数2
题目链接
2534. 树上计数2
给定一棵 \(N\) 个节点的树,节点编号从 \(1\) 到 \(N\),每个节点都有一个整数权值。
现在,我们要进行 \(M\) 次询问,格式为 u v
,对于每个询问你需要回答从 \(u\) 到 \(v\) 的路径上(包括两端点)共有多少种不同的点权值。
输入格式
第一行包含两个整数 \(N,M\)。
第二行包含 \(N\) 个整数,其中第 \(i\) 个整数表示点 \(i\) 的权值。
接下来 \(N-1\) 行,每行包含两个整数 \(x,y\),表示点 \(x\) 和点 \(y\) 之间存在一条边。
最后 \(M\) 行,每行包含两个整数 \(u,v\),表示一个询问。
输出格式
共 \(M\) 行,每行输出一个询问的答案。
数据范围
\(1 \le N \le 40000\),
\(1 \le M \le 10^5\),
\(1 \le x,y,u,v \le N\),
各点权值均在 \(int\) 范围内。
输入样例:
8 2
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5
3 8
输出样例:
4
4
解题思路
树上莫队,lca
先将整棵树的欧拉序求出来,记录每个点第一次出现的位置 \(first[i]\) 和最后一次出现的位置 \(last[i]\),然后观察树中的路径 \([l,r](first[l]<first[r])\) 可以发现两种情况:
- 如果路径是一条从上往下的直链,则其所有点对应欧拉序中 \(first[l]\) 到 \(first[r]\) 中出现一次的点
- 否则其所有点对应欧拉序中 \(first[l]\) 到 \(last[r]\) 中出现一次的点加上 \(lca(l,r)\)
理解一下会发现的确这样,然后问题就转化为普通莫队问题了
- 时间复杂度:\(O(n\sqrt{n})\)
代码
// Problem: 树上计数2
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/2536/
// Memory Limit: 64 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
// %%%Skyqwq
#include <bits/stdc++.h>
//#define int long long
#define help {cin.tie(NULL); cout.tie(NULL);}
#define pb push_back
#define fi first
#define se second
#define mkp make_pair
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
template <typename T> bool chkMax(T &x, T y) { return (y > x) ? x = y, 1 : 0; }
template <typename T> bool chkMin(T &x, T y) { return (y < x) ? x = y, 1 : 0; }
template <typename T> void inline read(T &x) {
int f = 1; x = 0; char s = getchar();
while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
x *= f;
}
const int N=1e5+5;
int n,m,cnt[N],st[N],len,w[N],ans[N];
int seq[N],first[N],last[N],top;
int f[N][20],d[N],t;
vector<int> nums,adj[N];
struct query
{
int id,l,r,p;
}q[N];
void dfs(int x,int fa)
{
seq[++top]=x;
first[x]=top;
for(int y:adj[x])
{
if(y==fa)continue;
dfs(y,x);
}
seq[++top]=x;
last[x]=top;
}
void bfs()
{
queue<int> q;
q.push(1);
memset(d,0x3f,sizeof d);
d[0]=0,d[1]=1;
while(q.size())
{
int x=q.front();
q.pop();
for(int y:adj[x])
{
if(d[y]>d[x]+1)
{
d[y]=d[x]+1;
f[y][0]=x;
for(int i=1;i<=t;i++)f[y][i]=f[f[y][i-1]][i-1];
q.push(y);
}
}
}
}
int lca(int x,int y)
{
if(d[x]<d[y])swap(x,y);
for(int i=t;i>=0;i--)
if(d[f[x][i]]>=d[y])x=f[x][i];
if(x==y)return x;
for(int i=t;i>=0;i--)
if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
return f[x][0];
}
int get(int x)
{
return x/len;
}
void add(int x,int &res)
{
st[x]^=1;
if(st[x])
{
if(!cnt[w[x]])res++;
cnt[w[x]]++;
}
else
{
cnt[w[x]]--;
if(!cnt[w[x]])res--;
}
}
int main()
{
cin>>n>>m;
t=__lg(n);
for(int i=1;i<=n;i++)cin>>w[i],nums.pb(w[i]);
sort(nums.begin(),nums.end());
nums.erase(unique(nums.begin(),nums.end()),nums.end());
for(int i=1;i<=n;i++)w[i]=lower_bound(nums.begin(),nums.end(),w[i])-nums.begin();
for(int i=1;i<n;i++)
{
int x,y;
cin>>x>>y;
adj[x].pb(y),adj[y].pb(x);
}
dfs(1,0);
bfs();
for(int i=1;i<=m;i++)
{
int l,r;
cin>>l>>r;
if(first[l]>first[r])swap(l,r);
int LCA=lca(l,r);
if(LCA==l)q[i]={i,first[l],first[r]};
else
q[i]={i,last[l],first[r],LCA};
}
len=sqrt(top);
sort(q+1,q+1+m,[](const query &a,const query &b){
int al=get(a.l),bl=get(b.l);
if(al!=bl)return al<bl;
return a.r<b.r;
});
for(int i=0,j=1,res=0,k=1;k<=m;k++)
{
int l=q[k].l,r=q[k].r,id=q[k].id,p=q[k].p;
while(i<r)add(seq[++i],res);
while(i>r)add(seq[i--],res);
while(j<l)add(seq[j++],res);
while(j>l)add(seq[--j],res);
if(p)add(p,res);
ans[id]=res;
if(p)add(p,res);
}
for(int i=1;i<=m;i++)cout<<ans[i]<<'\n';
return 0;
}