BZOJ2286:[SDOI2011]消耗战
虚树模板题
建树的时候记录最小边就可以了
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define ll long long
const int maxn=2.5e5+5;
const ll inf=2e15;
ll w[maxn],f[maxn];
int n,tot,m,top,cnt;
int que[maxn],dfn[maxn],sta[maxn];
int dep[maxn],bin[maxn],Log[maxn*2],st[maxn*2][20];
int now[maxn],pre[maxn*2],son[maxn*2],val[maxn*2];
int read() {
int x=0,f=1;char ch=getchar();
for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
return x*f;
}
void add(int a,int b,int v) {
pre[++tot]=now[a];
now[a]=tot,son[tot]=b;
val[tot]=v;
}
void dfs(int fa,int u) {
dep[u]=dep[fa]+1;
dfn[u]=++tot;st[tot][0]=u;
for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
if(v!=fa) {
w[v]=min(w[u],1ll*val[p]);
dfs(u,v);
st[++tot][0]=u;
}
}
int cala(int u,int v) {
if(dep[u]<dep[v])return u;
else return v;
}
void prepare_st() {
bin[0]=1,Log[1]=0;
for(int i=1;i<=20;i++)
bin[i]=bin[i-1]*2;
for(int i=1;i<=tot/2;i++)
Log[i<<1]=Log[i<<1|1]=Log[i]+1;
for(int j=1;j<=20;j++)
for(int i=1;i<=tot-bin[j]+1;i++)
st[i][j]=cala(st[i][j-1],st[i+bin[j-1]][j-1]);
}
int lca(int u,int v) {
if(dfn[u]>dfn[v])swap(u,v);
int len=dfn[v]-dfn[u]+1,lg=Log[len];
return cala(st[dfn[u]][lg],st[dfn[v]-bin[lg]+1][lg]);
}
void dp(int u) {
ll res=0;
for(int p=now[u],v=son[p];p;p=pre[p],v=son[p]) {
dp(v);
res+=f[v];
}
if(!res)f[u]=w[u];
else f[u]=min(res,w[u]);
now[u]=0;
}
bool cmp(int a,int b) {
return dfn[a]<dfn[b];
}
void print(int u) {
printf("%d\n",u);
for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
printf("son%d:",u),print(v);
}
void solve() {
int k=read();
for(int i=1;i<=k;i++)
que[i]=read();
sort(que+1,que+k+1,cmp);
top=cnt=0;que[++cnt]=que[1];
for(int i=2;i<=k;i++)
if(lca(que[i],que[cnt])!=que[cnt])que[++cnt]=que[i];
sta[++top]=1;tot=0;
for(int i=1;i<=cnt;i++) {
int grand=lca(sta[top],que[i]);
while(1) {
if(dep[sta[top-1]]<=dep[grand]) {
if(grand!=sta[top])add(grand,sta[top],0);
if(sta[--top]!=grand)sta[++top]=grand;
break;
}
if(top>1)add(sta[top-1],sta[top],0);top--;
}
if(sta[top]!=que[i])sta[++top]=que[i];
}
top--;
while(top)add(sta[top],sta[top+1],0),top--;
/*print(1);puts("OK");*/dp(1);
//printf("%lld %lld\n",f[2],f[4]);
printf("%lld\n",f[1]);
}
int main() {
n=read();
for(int i=1;i<n;i++) {
int x=read(),y=read(),v=read();
add(x,y,v);add(y,x,v);
}tot=0;w[1]=inf;dfs(0,1);
prepare_st();m=read();
memset(now,0,sizeof(now));
for(int i=1;i<=m;i++)solve();
return 0;
}