【洛谷P5311】成都七中
题目
题目链接:https://www.luogu.com.cn/problem/P5311
给你一棵 \(n\) 个节点的树,每个节点有一种颜色,有 \(m\) 次查询操作。
查询操作给定参数 \(l\ r\ x\),需输出:
将树中编号在 \([l,r]\) 内的所有节点保留,\(x\) 所在连通块中颜色种类数。
每次查询操作独立。
\(n,m\leq 10^5\)。
思路
鬼能想到这道题是点分树啊。
点分树有一个性质:对于原树上的一个连通块,这个连通块一定存在一个点,使得点分树上这个点的子树内,包含了连通块内所有的点。
反证法。如果不存在这样的点,设连通块所有点在点分树内深度最小的点为 \(x\),那么必然存在另一个连通块内的点 \(y\) 不在 \(x\) 的子树内,那么原树从 \(x\) 到 \(y\) 的路径上,一定存在一个点,在点分树上的深度小于 \(x\) 的深度。矛盾。
那么可以把每一个询问对应到 \(x\) 所在连通块内点分树上深度最小的点。
然后对于一个点 \(x\),考虑求出所有对应到他的询问。可以遍历点分树子树内所有点,对于一个点 \(y\),求出原树中 \(x\) 到 \(y\) 路径上点的编号的最小值和最大值。分别记为 \(mn_y\) 和 \(mx_y\)。
然后对于一个询问 \(l,r\),满足 \(l\leq mn_y,r\geq mx_y\) 的不同颜色数。这个东西最暴力的做法是把颜色单独看作一维然后三维数点,算上点分树的复杂度是 \(O(n\log^3 n)\),无法接受。
把所有询问和点都扔到一起,按照 \(mn\)(询问是 \(l\))从大到小排序。然后依次枚举所有的点(询问),记录目前每个颜色的 \(mx\) 的最小值,遇到询问的时候就只需要查询每个颜色 \(mx\) 最小值 \(\leq r\) 的数量。树状数组维护即可。
时间复杂度 \(O(n\log^2n)\)。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=100010,Inf=1e9;
int n,m,Q,rt,tot,a[N],id[N],ans[N],dep[N],minn[N],head[N],siz[N],maxp[N],fat[N];
bool vis[N];
vector<int> qry[N];
struct edge
{
int next,to;
}e[N*2];
struct node
{
int l,r,id;
}b[N],c[N*2];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
void findrt(int x,int fa,int sum)
{
siz[x]=1; maxp[x]=0;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (!vis[v] && v!=fa)
{
findrt(v,x,sum);
siz[x]+=siz[v]; maxp[x]=max(maxp[x],siz[v]);
}
}
maxp[x]=max(maxp[x],sum-siz[x]);
if (!rt || maxp[x]<maxp[rt]) rt=x;
}
void dfs1(int x,int sum)
{
vis[x]=1;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (!vis[v])
{
int s=(siz[v]>siz[x])?(sum-siz[x]):siz[v];
rt=0; findrt(v,x,s);
fat[rt]=x; dep[rt]=dep[x]+1;
dfs1(rt,s);
}
}
}
void dfs2(int x,int fa,int d,int mn,int mx)
{
c[++m]=(node){mn,mx,-a[x]};
for (int i=0;i<(int)qry[x].size();i++)
if (qry[x][i] && b[qry[x][i]].l<=mn && b[qry[x][i]].r>=mx)
c[++m]=b[qry[x][i]],qry[x][i]=0;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (dep[v]>d && v!=fa)
dfs2(v,x,d,min(mn,v),max(mx,v));
}
}
bool cmp(node x,node y)
{
if (x.l!=y.l) return x.l>y.l;
return x.id<y.id;
}
bool cmp2(int x,int y)
{
return dep[x]<dep[y];
}
struct BIT
{
int c[N];
void add(int x,int v)
{
for (int i=x;i<=n;i+=i&-i)
c[i]+=v;
}
int query(int x)
{
int ans=0;
for (int i=x;i;i-=i&-i)
ans+=c[i];
return ans;
}
}bit;
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&Q);
for (int i=1;i<=n;i++)
scanf("%d",&a[i]);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
findrt(1,0,n);
dfs1(rt,n);
for (int i=1,x;i<=Q;i++)
{
scanf("%d%d%d",&b[i].l,&b[i].r,&x);
b[i].id=i;
qry[x].push_back(i);
}
for (int i=1;i<=n;i++) id[i]=i;
sort(id+1,id+1+n,cmp2);
memset(minn,0x3f3f3f3f,sizeof(minn));
for (int k=1;k<=n;k++)
{
int i=id[k]; m=0;
dfs2(i,0,dep[i],i,i);
sort(c+1,c+1+m,cmp);
for (int j=1;j<=m;j++)
if (c[j].id>0)
ans[c[j].id]=bit.query(c[j].r);
else if (c[j].r<minn[-c[j].id])
{
if (minn[-c[j].id]<Inf) bit.add(minn[-c[j].id],-1);
minn[-c[j].id]=c[j].r;
bit.add(c[j].r,1);
}
for (int j=0;j<=m;j++)
if (c[j].id<0 && minn[-c[j].id]<Inf)
{
bit.add(minn[-c[j].id],-1);
minn[-c[j].id]=Inf;
}
}
for (int i=1;i<=Q;i++)
cout<<ans[i]<<"\n";
return 0;
}