【XSY2190】Alice and Bob VI 树形DP 树剖
题目描述
Alice和Bob正在一棵树上玩游戏。这棵树有\(n\)个结点,编号由\(1\)到\(n\)。他们一共玩\(q\)盘游戏。
在第\(i\)局游戏中,Alice从结点\(a_i\)出发,Bob从结点\(b_i\)出发。开始时,除了\(a_i\)和\(b_i\)这两个结点外,所有结点都没有染色。结点\(a_i\)被Alice染色,结点\(b_i\)被Bob染色。
接下来,两位玩家轮流移动,两位玩家移动步数之和为\(k_i\)步。Alice走第一步,Bob走第二步,Alice走第三步\(\cdots\)在每一步中,玩家可以移动到相邻的结点并把该结点染色。注意一个结点可以被多次染色:在任意时刻,每个被染过色的结点的颜色为最后到达过该结点的玩家染的颜色。
记游戏结束时Alice染色的结点数为\(A\) , Bob染色的结点数为\(B\) 。Alice 想要\((A - B)\)尽量大,Bob 想要\((A - B)\)尽量小。如果两个玩家都以最优策略玩的话,我们想知道最后的\((A - B)\)值是多少。
\(n,q\leq 20000\)
题解
设两人之间距离为\(d\)。
两个人可以相遇或不相遇。
相遇:
\(k\)奇\(d\)偶:\(1\)
\(k\)奇\(d\)奇:\(2\)
\(k\)偶\(d\)偶:\(-1\)
\(k\)偶\(d\)奇:\(0\)
不相遇:
\(k\)奇:1
\(k\)偶:0
如果\(k\)奇,那么对\(Bob\)来说不相遇更优,他就需要逃跑。
如果\(k\)偶,那么\(Alice\)就要逃跑。
判断逃跑是否成功就是看这个人在不经过另一个人能到达的点的情况下能走多少步。
可以发现逃跑路径的终点一定是直径的一个端点。
可以通过树形DP求出子树内直径和子树外直径。
时间复杂度:\(O(n+q\log n)\)
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
vector<int> g[200010];
int n,q;
struct node
{
int f,t,w,d,s,ms;
};
node a[200010];
int ti;
int w[200010];
void dfs(int x,int fa,int dep)
{
a[x].f=fa;
a[x].d=dep;
a[x].s=1;
int s=0;
for(vector<int>::iterator p=g[x].begin();p!=g[x].end();p++)
{
int v=*p;
if(v!=fa)
{
dfs(v,x,dep+1);
a[x].s+=a[v].s;
if(a[v].s>s)
{
s=a[v].s;
a[x].ms=v;
}
}
}
}
void dfs2(int x,int top)
{
a[x].t=top;
a[x].w=++ti;
w[ti]=x;
if(a[x].ms)
dfs2(a[x].ms,top);
for(vector<int>::iterator p=g[x].begin();p!=g[x].end();p++)
{
int v=*p;
if(v!=a[x].f&&v!=a[x].ms)
dfs2(v,v);
}
}
int getlca(int x,int y)
{
while(a[x].t!=a[y].t)
if(a[a[x].t].d>a[a[y].t].d)
x=a[a[x].t].f;
else
y=a[a[y].t].f;
return a[x].d<a[y].d?x:y;
}
int jump(int x,int d)
{
while(a[x].w-a[a[x].t].w<d)
{
d-=a[x].w-a[a[x].t].w+1;
x=a[a[x].t].f;
}
return w[a[x].w-d];
}
struct p1{int d,x;p1(int a=0,int b=0):x(a),d(b){}};
int operator <(p1 a,p1 b){return a.d<b.d;}
int operator >(p1 a,p1 b){return a.d>b.d;}
p1 operator +(p1 a,int b){a.d+=b;return a;}
struct p2{int d,x,y;p2(int a=0,int b=0,int c=0):x(a),y(b),d(c){}};
int operator <(p2 a,p2 b){return a.d<b.d;}
int operator >(p2 a,p2 b){return a.d>b.d;}
p2 operator +(p1 a,p1 b){p2 c;c.d=a.d+b.d;c.x=a.x;c.y=b.x;return c;}
p2 operator +(p2 a,int b){a.d+=b;return a;}
p1 f1[200010];
p2 f2[200010];
p1 fir[200010];
p1 sec[200010];
p1 thi[200010];
p2 fir1[200010];
p2 sec1[200010];
p1 g1[200010];
p2 g2[200010];
void dfs3(int x)
{
f1[x]=p1(x,1);
f2[x]=p2(x,x,1);
fir[x].d=sec[x].d=thi[x].d=fir1[x].d=sec1[x].d=-1;
for(vector<int>::iterator p=g[x].begin();p!=g[x].end();p++)
{
int v=*p;
if(v!=a[x].f)
{
dfs3(v);
f2[x]=max(f2[x],f1[x]+f1[v]);
f2[x]=max(f2[x],f2[v]);
f1[x]=max(f1[x],f1[v]+1);
if(f1[v]>fir[x])
{
thi[x]=sec[x];
sec[x]=fir[x];
fir[x]=f1[v];
}
else if(f1[v]>sec[x])
{
thi[x]=sec[x];
sec[x]=f1[v];
}
else if(f1[v]>thi[x])
thi[x]=f1[v];
if(f2[v]>fir1[x])
{
sec1[x]=fir1[x];;
fir1[x]=f2[v];
}
else if(f2[v]>sec1[x])
sec1[x]=f2[v];
}
}
}
void dfs4(int x)
{
for(vector<int>::iterator p=g[x].begin();p!=g[x].end();p++)
{
int v=*p;
if(v!=a[x].f)
{
g1[v]=p1(x,1);
g2[v]=p2(x,x,1);
if(f1[v].x==fir[x].x)
{
g2[v]=max(g2[v],max(max(sec[x]+thi[x],sec[x]+g1[x]),thi[x]+g1[x])+1);
g2[v]=max(g2[v],max(sec[x],g1[x])+p1(x,1));
g1[v]=max(g1[v],max(sec[x],g1[x])+1);
}
else if(f1[v].x==sec[x].x)
{
g2[v]=max(g2[v],max(max(fir[x]+thi[x],fir[x]+g1[x]),thi[x]+g1[x])+1);
g2[v]=max(g2[v],max(fir[x],g1[x])+p1(x,1));
g1[v]=max(g1[v],max(fir[x],g1[x])+1);
}
else
{
g2[v]=max(g2[v],max(max(fir[x]+sec[x],fir[x]+g1[x]),sec[x]+g1[x])+1);
g2[v]=max(g2[v],max(fir[x],g1[x])+p1(x,1));
g1[v]=max(g1[v],max(fir[x],g1[x])+1);
}
if(f2[v].x==fir1[x].x)
g2[v]=max(g2[v],max(sec1[x],g2[x]));
else
g2[v]=max(g2[v],max(fir1[x],g2[x]));
dfs4(v);
}
}
}
int getdist(int x,int y)
{
return a[x].d+a[y].d-2*a[getlca(x,y)].d;
}
int gao(int x,int y,int lca,int d)
{
int dist=a[x].d+a[y].d-2*a[lca].d;
if(dist<=d)
return -1;
if(d>=a[x].d-a[lca].d)
{
int z=jump(y,dist-d-1);
int x1=f2[z].x;
int x2=f2[z].y;
return max(getdist(y,x1),getdist(y,x2));
}
else
{
int z=jump(x,d);
int x1=g2[z].x;
int x2=g2[z].y;
return max(getdist(y,x1),getdist(y,x2));
}
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("c.in","r",stdin);
freopen("c.out","w",stdout);
#endif
scanf("%d%d",&n,&q);
int i,x,y;
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
g[x].push_back(y);
g[y].push_back(x);
}
dfs(1,0,1);
dfs2(1,1);
dfs3(1);
g1[1]=p1(1,-1);
g2[1]=p2(1,1,-1);
dfs4(1);
int k;
for(i=1;i<=q;i++)
{
scanf("%d%d%d",&x,&y,&k);
int lca=getlca(x,y);
int d=a[x].d+a[y].d-2*a[lca].d;
if(k&1)
{
if(gao(x,y,lca,(k+1)/2)>=k/2)
printf("1\n");
else if(d&1)
printf("2\n");
else
printf("1\n");
}
else
{
if(gao(y,x,lca,k/2)>=(k+1)/2)
printf("0\n");
else if(d&1)
printf("0\n");
else
printf("-1\n");
}
}
return 0;
}