13th东北四省赛D.Master of Data Structure——暴力+虚树
题目链接:
题目大意:
给一棵$n$个点的树,每个点有一个点权,初始为$0$,有$m$次操作。操作分为$7$种:
1、路径加
2、路径异或
3、路径减
4、求路径和
5、求路径异或和
6、求路径最大点权-最小点权
7、求路径上点权与$k$差的绝对值的最小值
所有操作都是路径操作,且操作只有$2000$次,那么有许多路径每次都会被一起修改或询问。将所有操作路径端点建虚树,记录虚树上相邻两点间的节点数,暴力完成操作即可。
#include<set> #include<stack> #include<queue> #include<cmath> #include<vector> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> #define ll long long using namespace std; int n,m,x,y,T; int st[500010]; int top,tot; int s[500010]; vector<int>v[500010]; int head[500010]; int to[500010]; int nex[500010]; int val[500010]; int sum[500010]; int w[500010]; int d[500010]; int dfn,cnt; int a[4010]; int f[500010][20]; int num[500010]; struct query { int opt,u,v,k; }q[2010]; int init() { dfn=top=cnt=tot=0; for(int i=1;i<=n;i++)v[i].clear(); memset(s,0,sizeof(s)); memset(w,0,sizeof(w)); memset(d,0,sizeof(d)); memset(f,0,sizeof(f)); memset(st,0,sizeof(st)); memset(head,0,sizeof(head)); memset(to,0,sizeof(to)); memset(nex,0,sizeof(nex)); memset(val,0,sizeof(val)); memset(num,0,sizeof(num)); memset(sum,0,sizeof(sum)); } bool cmp(int x,int y) { return s[x]<s[y]; } void add(int x,int y) { nex[++tot]=head[x]; head[x]=tot; to[tot]=y; sum[tot]=d[y]-d[x]-1; } void dfs(int x,int fa) { s[x]=++dfn; d[x]=d[fa]+1; f[x][0]=fa; for(int i=1;i<=19;i++)f[x][i]=f[f[x][i-1]][i-1]; int size=v[x].size(); for(int i=0;i<size;i++) { if(v[x][i]!=fa)dfs(v[x][i],x); } } int lca(int x,int y) { if(d[x]<d[y])swap(x,y); int dep=d[x]-d[y]; for(int i=0;i<=19;i++) { if((dep&(1<<i)))x=f[x][i]; } if(x==y)return x; for(int i=19;i>=0;i--) { if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i]; } return f[x][0]; } void insert(int x) { int anc=lca(st[top],x); while(top>1&&d[st[top-1]]>=d[anc]) { add(st[top-1],st[top]); top--; } if(anc!=st[top]) { add(anc,st[top]); st[top]=anc; } st[++top]=x; } void dfs2(int x,int fa) { d[x]=d[fa]+1; f[x][0]=fa; for(int i=head[x];i;i=nex[i]) { if(to[i]!=fa) { num[to[i]]=i; dfs2(to[i],x); } } } void do1(int u,int v,int k) { if(d[u]<d[v])swap(u,v); while(d[u]>d[v]) { w[u]+=k; val[num[u]]+=k; u=f[u][0]; } if(u==v) { w[u]+=k; return ; } while(u!=v) { w[u]+=k,w[v]+=k; val[num[u]]+=k,val[num[v]]+=k; u=f[u][0],v=f[v][0]; } w[u]+=k; } void do2(int u,int v,int k) { if(d[u]<d[v])swap(u,v); while(d[u]>d[v]) { w[u]^=k; val[num[u]]^=k; u=f[u][0]; } if(u==v) { w[u]^=k; return ; } while(u!=v) { w[u]^=k,w[v]^=k; val[num[u]]^=k,val[num[v]]^=k; u=f[u][0],v=f[v][0]; } w[u]^=k; } void do3(int u,int v,int k) { if(d[u]<d[v])swap(u,v); while(d[u]>d[v]) { w[u]>=k?w[u]-=k:1; val[num[u]]>=k?val[num[u]]-=k:1; u=f[u][0]; } if(u==v) { w[u]>=k?w[u]-=k:1; return ; } while(u!=v) { w[u]>=k?w[u]-=k:1; w[v]>=k?w[v]-=k:1; val[num[u]]>=k?val[num[u]]-=k:1; val[num[v]]>=k?val[num[v]]-=k:1; u=f[u][0],v=f[v][0]; } w[u]>=k?w[u]-=k:1; } void do4(int u,int v) { ll ans=0; if(d[u]<d[v])swap(u,v); while(d[u]>d[v]) { ans+=w[u]; ans+=1ll*val[num[u]]*sum[num[u]]; u=f[u][0]; } if(u==v) { ans+=w[u]; printf("%lld\n",ans); return ; } while(u!=v) { ans+=w[u],ans+=w[v]; ans+=1ll*val[num[u]]*sum[num[u]]; ans+=1ll*val[num[v]]*sum[num[v]]; u=f[u][0],v=f[v][0]; } ans+=w[u]; printf("%lld\n",ans); return ; } void do5(int u,int v) { int ans=0; if(d[u]<d[v])swap(u,v); while(d[u]>d[v]) { ans^=w[u]; ans^=(sum[num[u]]%2==0?0:val[num[u]]); u=f[u][0]; } if(u==v) { ans^=w[u]; printf("%d\n",ans); return ; } while(u!=v) { ans^=w[u],ans^=w[v]; ans^=(sum[num[u]]%2==0?0:val[num[u]]); ans^=(sum[num[v]]%2==0?0:val[num[v]]); u=f[u][0],v=f[v][0]; } ans^=w[u]; printf("%d\n",ans); return ; } void do6(int u,int v) { int mx,mn; mx=0,mn=1<<30; if(d[u]<d[v])swap(u,v); while(d[u]>d[v]) { mx=max(mx,w[u]),mn=min(mn,w[u]); if(sum[num[u]])mx=max(mx,val[num[u]]),mn=min(mn,val[num[u]]); u=f[u][0]; } if(u==v) { mx=max(mx,w[u]),mn=min(mn,w[u]); printf("%d\n",mx-mn); return ; } while(u!=v) { mx=max(mx,w[u]),mn=min(mn,w[u]); mx=max(mx,w[v]),mn=min(mn,w[v]); if(sum[num[u]])mx=max(mx,val[num[u]]),mn=min(mn,val[num[u]]); if(sum[num[v]])mx=max(mx,val[num[v]]),mn=min(mn,val[num[v]]); u=f[u][0],v=f[v][0]; } mx=max(mx,w[u]),mn=min(mn,w[u]); printf("%d\n",mx-mn); return ; } void do7(int u,int v,int k) { int ans=1<<30; if(d[u]<d[v])swap(u,v); while(d[u]>d[v]) { ans=min(ans,abs(w[u]-k)); if(sum[num[u]])ans=min(ans,abs(val[num[u]]-k)); u=f[u][0]; } if(u==v) { ans=min(ans,abs(w[u]-k)); printf("%d\n",ans); return ; } while(u!=v) { ans=min(ans,abs(w[u]-k)); ans=min(ans,abs(w[v]-k)); if(sum[num[u]])ans=min(ans,abs(val[num[u]]-k)); if(sum[num[v]])ans=min(ans,abs(val[num[v]]-k)); u=f[u][0],v=f[v][0]; } ans=min(ans,abs(w[u]-k)); printf("%d\n",ans); return ; } void work() { for(int i=1;i<n;i++) { scanf("%d%d",&x,&y); v[x].push_back(y); v[y].push_back(x); } dfs(1,0); for(int i=1;i<=m;i++) { scanf("%d%d%d",&q[i].opt,&q[i].u,&q[i].v); if(q[i].opt<4||q[i].opt>6)scanf("%d",&q[i].k); a[++cnt]=q[i].u,a[++cnt]=q[i].v; } sort(a+1,a+1+cnt,cmp); if(a[1]!=1) { st[++top]=1; } for(int i=1;i<=cnt;i++) { if(a[i]==a[i-1])continue; insert(a[i]); } while(top>1) { add(st[top-1],st[top]); top--; } dfs2(1,0); for(int i=1;i<=m;i++) { if(q[i].opt==1)do1(q[i].u,q[i].v,q[i].k); else if(q[i].opt==2)do2(q[i].u,q[i].v,q[i].k); else if(q[i].opt==3)do3(q[i].u,q[i].v,q[i].k); else if(q[i].opt==4)do4(q[i].u,q[i].v); else if(q[i].opt==5)do5(q[i].u,q[i].v); else if(q[i].opt==6)do6(q[i].u,q[i].v); else do7(q[i].u,q[i].v,q[i].k); } } int main() { scanf("%d",&T); while(T--) { scanf("%d%d",&n,&m); init(); work(); } }