P8339-[AHOI2022]钥匙【虚树,扫描线】
正题
题目连接:https://www.luogu.com.cn/problem/P8339
题目大意
给出\(n\)个点的一棵树,每个点有钥匙或者宝箱,有不同的颜色。
\(m\)次询问,从\(x\)走到\(y\),走到钥匙时会拾取钥匙,走到宝箱时如果有同色的钥匙那么就会消耗一把钥匙打开宝箱,询问能打开多少个宝箱。
保证每一种颜色的钥匙不超过\(5\)把。
\(1\leq n\leq 5\times 10^5,1\leq m\leq 10^6\)
解题思路
先考虑同色的宝箱和钥匙都只有一个的情况,这是一个经典问题,假设分别为\(x,y\),那么删去\(x\leftrightarrow y\)的路径,\(x\)的联通块记为\(S\),\(y\)的联通块记为\(T\)。
如果询问节点起点在\(S\),终点在\(T\)就会产生贡献。
那么\(S\)和\(T\)要么两个都是子树,要么一个是子树,另一个是整棵树删去一个子树,也就是说它们都可以表示成\(dfs\)序上的一个或两个连续区间。
那么我们把两个区间视为一个二维平面上的正方形\(+1\),然后询问的视为查询一个点的值,实现方法就是把这些都离线下来用扫描线。
好现在考虑这一题,我们会发现一条路径上我们把单种颜色的拿出来,钥匙视为\((\),宝箱视为\()\),那么就是一个类似括号匹配的东西,每一对产生贡献的点都会满足中间是一个合法的括号序。
那么我们从这个性质入手,我们枚举所有颜色,把同色的点建一棵虚树,对于每个钥匙我们暴力扫全图,能找到很多个合法的贡献对\(x,y\),像上面的方法扫描线就好了。
实际上我们会发现这样枚举出来的贡献对其实是\(n\)个而不是\(5n\)个的。
时间复杂度:\(O((n+m)\log n)\)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<stack>
#define mp(x,y) make_pair(x,y)
#define lowbit(x) (x&-x)
using namespace std;
const int N=5e5+10;
struct node{
int to,next;
}a[N<<1];
int n,m,tot,Top,cnt,ls[N],t[N],c[N],s[N],ans[N];
int siz[N],dep[N],son[N],fa[N],top[N],dfn[N],rfn[N],ed[N];
vector<int> G[N],p[N];stack<int> cl;
vector<pair<int,int> >I[N],O[N],q[N];
void addl(int x,int y){
a[++tot].to=y;
a[tot].next=ls[x];
ls[x]=tot;return;
}
bool cmp(int x,int y)
{return rfn[x]<rfn[y];}
void dfs(int x){
siz[x]=1;dep[x]=dep[fa[x]]+1;
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;
if(y==fa[x])continue;
fa[y]=x;dfs(y);siz[x]+=siz[y];
if(siz[y]>siz[son[x]])son[x]=y;
}
return;
}
void dfs2(int x){
dfn[++cnt]=x;rfn[x]=cnt;
if(son[x]){
top[son[x]]=top[x];
dfs2(son[x]);
}
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;
if(y==fa[x]||y==son[x])continue;
top[y]=y;dfs2(y);
}
ed[x]=cnt;return;
}
int LCA(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])
swap(x,y);
x=fa[top[x]];
}
return (dep[x]<dep[y])?x:y;
}
int getTop(int x,int y){
while(top[y]!=top[x])
if(fa[top[y]]==x)
return top[y];
else y=fa[top[y]];
return dfn[rfn[x]+1];
}
void addG(int x,int y){
G[x].push_back(y);
G[y].push_back(x);
cl.push(x);cl.push(y);
return;
}
void Clear(){
Top=0;
while(!cl.empty())
{G[cl.top()].clear();cl.pop();}
}
void Ins(int x){
if(!Top){s[++Top]=x;return;}
int lca=LCA(s[Top],x);
while(Top>1&&dep[s[Top-1]]>=dep[lca])
addG(s[Top-1],s[Top]),Top--;
if(dep[s[Top]]>dep[lca])
addG(lca,s[Top]),Top--;
if(s[Top]!=lca)s[++Top]=lca;
s[++Top]=x;return;
}
void Build(vector<int> &p){
sort(p.begin(),p.end(),cmp);
if(p[0]!=1)Ins(1);
for(int i=0;i<p.size();i++)Ins(p[i]);
while(Top>1)addG(s[Top-1],s[Top]),Top--;
}
void Sets(int x,int y){
int lca=LCA(x,y);
if(lca==x){
x=getTop(x,y);
I[1].push_back(mp(rfn[y],ed[y]));
O[rfn[x]].push_back(mp(rfn[y],ed[y]));
I[ed[x]+1].push_back(mp(rfn[y],ed[y]));
}
else if(lca==y){
y=getTop(y,x);
if(rfn[y]>1)I[rfn[x]].push_back(mp(1,rfn[y]-1));
if(ed[y]<n)I[rfn[x]].push_back(mp(ed[y]+1,n));
if(rfn[y]>1)O[ed[x]+1].push_back(mp(1,rfn[y]-1));
if(ed[y]<n)O[ed[x]+1].push_back(mp(ed[y]+1,n));
}
else{
I[rfn[x]].push_back(mp(rfn[y],ed[y]));
O[ed[x]+1].push_back(mp(rfn[y],ed[y]));
}
return;
}
void calc(int x,int fa,int k,int &from,int &_){
if(c[x]==-_){k++;}
if(c[x]==_){
k--;
if(!k){
Sets(from,x);
return;
}
}
for(int i=0;i<G[x].size();i++)
if(G[x][i]!=fa)calc(G[x][i],x,k,from,_);
}
void Change(int x,int val){
while(x<=n){
t[x]+=val;
x+=lowbit(x);
}
return;
}
int Ask(int x){
int ans=0;
while(x){
ans+=t[x];
x-=lowbit(x);
}
return ans;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1,t;i<=n;i++){
scanf("%d%d",&t,&c[i]);
p[c[i]].push_back(i);
if(t==1)c[i]=-c[i];
}
for(int i=1,x,y;i<n;i++){
scanf("%d%d",&x,&y);
addl(x,y);addl(y,x);
}
dfs(1);dfs2(1);
for(int _=1;_<=n;_++){
if(p[_].empty())continue;
Build(p[_]);
for(int i=0;i<p[_].size();i++)
if(c[p[_][i]]==-_)
calc(p[_][i],0,0,p[_][i],_);
Clear();
}
for(int i=1,x,y;i<=m;i++)
scanf("%d%d",&x,&y),q[rfn[x]].push_back(mp(rfn[y],i));
for(int i=1;i<=n;i++){
for(int j=0;j<I[i].size();j++)
Change(I[i][j].first,1),Change(I[i][j].second+1,-1);
for(int j=0;j<O[i].size();j++)
Change(O[i][j].first,-1),Change(O[i][j].second+1,1);
for(int j=0;j<q[i].size();j++)
ans[q[i][j].second]=Ask(q[i][j].first);
}
for(int i=1;i<=m;i++)
printf("%lld\n",ans[i]);
return 0;
}