bzoj3611:[Heoi2014]大工程
传送门
显然还是虚树,虚树之后树形dp
然后我没想到正确的树形dp,只想到一个错误的二次换根
写到一半发现错了,只能写了颗线段树+二次换根,线段树记的就是所有关键点到当前根的距离
写到一半又想到了正确的树形dp,然而真的不想再改了,感觉二次换根也能过:
就是每次对于当前走到的节点记一下它子树里深度最深的和深度最浅的关键点,以及深度和,然后考虑经过当前点的路径长度,更新就好了
然后还需要统计答案,所以还要记下当前点子树里的距离和,子树内的最短距离,子树内的最长距离
方程就不想写了
感谢出题人不卡我常数巨大的\(O(nlogn)\)
代码(二次换根):
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
void read(int &x) {
char ch; bool ok;
for(ok=0,ch=getchar(); !isdigit(ch); ch=getchar()) if(ch=='-') ok=1;
for(x=0; isdigit(ch); x=x*10+ch-'0',ch=getchar()); if(ok) x=-x;
}
#define rg register
const int maxn=1e6+10;
struct oo
{
int cnt,pre[maxn*2],nxt[maxn*2],h[maxn],v[maxn*2];
void add(int x,int y,int z)
{
pre[++cnt]=y,nxt[cnt]=h[x],h[x]=cnt,v[cnt]=z;
pre[++cnt]=x,nxt[cnt]=h[y],h[y]=cnt,v[cnt]=z;
}
}a,b;
struct segment{int l,r,mn,mx,size;long long sum,lazy;}s[maxn*4];
bool vis[maxn];long long ans;
int n,m,size[maxn],dep[maxn],dis[maxn],tot;
int st[maxn],k,dfn[maxn],id[maxn],nid[maxn],tmp,top,ans1,ans2,fa[maxn][20],w[maxn];
void dfs(int x)
{
dfn[x]=++tmp;
for(rg int i=1;i<20;i++)
{
if((1<<i)>dep[x])break;
fa[x][i]=fa[fa[x][i-1]][i-1];
}
for(rg int i=a.h[x];i;i=a.nxt[i])
if(a.pre[i]!=fa[x][0])
{
dep[a.pre[i]]=dep[x]+1,
fa[a.pre[i]][0]=x;
dfs(a.pre[i]);
}
}
void update(int x)
{
s[x].sum=s[x<<1].sum+s[x<<1|1].sum;
s[x].mn=min(s[x<<1].mn,s[x<<1|1].mn);
s[x].mx=max(s[x<<1].mx,s[x<<1|1].mx);
s[x].size=s[x<<1].size+s[x<<1|1].size;
}
int lca(int x,int y)
{
if(dep[x]>dep[y])swap(x,y);
int poor=dep[y]-dep[x];
for(rg int i=19;i>=0;i--)if(poor&(1<<i))y=fa[y][i];
if(x==y)return x;
for(rg int i=19;i>=0;i--)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
return x==y?x:fa[x][0];
}
bool cmp(int x,int y){return dfn[x]<dfn[y];}
void pushdown(int x)
{
int ls=x<<1,rs=x<<1|1;
s[ls].sum+=s[ls].size*s[x].lazy;
s[rs].sum+=s[rs].size*s[x].lazy;
s[ls].mn+=s[x].lazy,s[ls].mx+=s[x].lazy;
s[rs].mn+=s[x].lazy,s[rs].mx+=s[x].lazy;
s[ls].lazy+=s[x].lazy,s[rs].lazy+=s[x].lazy;
s[x].lazy=0;
}
void change(int x,int l,int r,int v,int c)
{
if(l>r)return ;
if(l<=s[x].l&&r>=s[x].r)
{
if(c)s[x].size=1,s[x].sum=s[x].mn=s[x].mx=v;
else s[x].sum+=v*s[x].size,s[x].mn+=v,s[x].mx+=v,s[x].lazy+=v;
return ;
}
if(s[x].lazy)pushdown(x);
int mid=(s[x].l+s[x].r)>>1;
if(l<=mid)change(x<<1,l,r,v,c);
if(r>mid)change(x<<1|1,l,r,v,c);
update(x);
}
void dp(int x,int fa)
{
if(vis[x])change(1,id[x],id[x],dis[x],1);
for(rg int i=b.h[x];i;i=b.nxt[i])
if(b.pre[i]!=fa)dis[b.pre[i]]=dis[x]+b.v[i],dp(b.pre[i],x);
}
void get(int x,int l,int r)
{
if(l>r)return ;
if(l<=s[x].l&&r>=s[x].r)
{
ans+=s[x].sum;
ans1=min(ans1,s[x].mn);
ans2=max(ans2,s[x].mx);
return ;
}
int mid=(s[x].l+s[x].r)>>1;
if(s[x].lazy)pushdown(x);
if(l<=mid)get(x<<1,l,r);
if(r>mid)get(x<<1|1,l,r);
}
int len(int x,int y){return dep[x]+dep[y]-2*dep[lca(x,y)];}
void dp1(int x,int fa)
{
if(vis[x])get(1,1,id[x]-1),get(1,id[x]+1,tot);
for(rg int i=b.h[x];i;i=b.nxt[i])
if(b.pre[i]!=fa)
{
change(1,id[b.pre[i]],nid[b.pre[i]],-b.v[i],0);
change(1,1,id[b.pre[i]]-1,b.v[i],0),change(1,nid[b.pre[i]]+1,tot,b.v[i],0);
dp1(b.pre[i],x);
change(1,id[b.pre[i]],nid[b.pre[i]],b.v[i],0);
change(1,1,id[b.pre[i]]-1,-b.v[i],0),change(1,nid[b.pre[i]]+1,tot,-b.v[i],0);
}
}
void prepare(int x,int fa)
{
id[x]=++tot;//printf("%d %d\n",x,fa);
for(rg int i=b.h[x];i;i=b.nxt[i])
if(b.pre[i]!=fa)prepare(b.pre[i],x);
nid[x]=tot;
}
void build(int x,int l,int r)
{
s[x].l=l,s[x].r=r;int mid=(l+r)>>1;s[x].lazy=s[x].size=0,s[x].mx=-1e9;
if(l==r){s[x].mn=1e9,s[x].sum=0;return ;}
build(x<<1,l,mid),build(x<<1|1,mid+1,r);
update(x);
}
void clear(int x,int fa)
{
for(rg int i=b.h[x];i;i=b.nxt[i])
if(b.pre[i]!=fa)clear(b.pre[i],x);
b.h[x]=0;
}
int main()
{
read(n);
for(rg int i=1,x,y;i<n;i++)read(x),read(y),a.add(x,y,0);
read(m),dfs(1);
for(rg int i=1;i<=m;i++)
{
read(k);ans=0,ans1=1e9,ans2=0;
for(rg int j=1;j<=k;j++)read(w[j]),vis[w[j]]=1;
sort(w+1,w+k+1,cmp);top=1;st[top]=1;
for(rg int j=1;j<=k;j++)
{
int e=0,o=0;
while(top&&lca(st[top],w[j])!=st[top])
{
if(e)b.add(e,st[top],len(e,st[top]));
e=st[top];top--;
}
if(e)o=lca(e,w[j]),b.add(e,o,len(e,o));
if(o&&st[top]!=o)st[++top]=o;
st[++top]=w[j];
}
while(top>1)
{
if(st[top]!=st[top-1])b.add(st[top],st[top-1],len(st[top],st[top-1]));
top--;
}
tot=0,prepare(1,0),build(1,1,tot);
dp(1,0);
dp1(1,0),printf("%lld %d %d\n",ans/2,ans1,ans2);
clear(1,0);
for(rg int j=1;j<=k;j++)vis[w[j]]=0;b.cnt=0;
}
}