BZOJ4538 HNOI2016网络(树链剖分+线段树+堆/整体二分+树上差分)
某两个点间的请求只对不在这条路径上的询问有影响。那么容易想到每次修改除该路径上的所有点的答案。对每个点建个两个堆,其中一个用来删除,线段树维护即可。由于一条路径在树剖后的dfs序中是log个区间,所以其补集也是log个区间。
#include<iostream> #include<cstdio> #include<cmath> #include<cstdlib> #include<cstring> #include<algorithm> #include<queue> using namespace std; int read() { int x=0,f=1;char c=getchar(); while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();} while (c>='0'&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar(); return x*f; } #define N 100010 int n,m,p[N],t,cnt; int id[N],tag[N],top[N],fa[N],deep[N],size[N],son[N]; int L[N<<2],R[N<<2]; struct data{int to,nxt; }edge[N<<1]; struct data2{int x,y,z; }q[N<<1]; struct data3 { int l,r; bool operator <(const data3&a) const { return r<a.r; } }a[N]; struct heap { priority_queue<int> a,b; void check(){while (!b.empty()&&!a.empty()&&a.top()==b.top()) a.pop(),b.pop();} void ins(int x){a.push(x);check();} void del(int x){b.push(x);check();} }tree[N<<2]; void addedge(int x,int y){t++;edge[t].to=y,edge[t].nxt=p[x],p[x]=t;} void dfs1(int k) { size[k]=1; for (int i=p[k];i;i=edge[i].nxt) if (edge[i].to!=fa[k]) { deep[edge[i].to]=deep[k]+1; fa[edge[i].to]=k; dfs1(edge[i].to); size[k]+=size[edge[i].to]; if (size[son[k]]<size[edge[i].to]) son[k]=edge[i].to; } } void dfs2(int k,int from) { top[k]=from; id[k]=++cnt;tag[cnt]=k; if (son[k]) dfs2(son[k],from); for (int i=p[k];i;i=edge[i].nxt) if (edge[i].to!=fa[k]&&edge[i].to!=son[k]) dfs2(edge[i].to,edge[i].to); } void build(int k,int l,int r) { L[k]=l,R[k]=r;tree[k].ins(-1); if (l==r) return; int mid=l+r>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r); } void down(int k) { while (!tree[k].a.empty()) { int x=tree[k].a.top();tree[k].a.pop(); tree[k<<1].ins(x);tree[k<<1|1].ins(x); tree[k].check(); } while (!tree[k].b.empty()) { int x=tree[k].b.top();tree[k].b.pop(); tree[k<<1].del(x);tree[k<<1|1].del(x); } } void ins(int k,int l,int r,int x) { if (L[k]==l&&R[k]==r) {tree[k].ins(x);return;} down(k); int mid=L[k]+R[k]>>1; if (r<=mid) ins(k<<1,l,r,x); else if (l>mid) ins(k<<1|1,l,r,x); else ins(k<<1,l,mid,x),ins(k<<1|1,mid+1,r,x); } void del(int k,int l,int r,int x) { if (L[k]==l&&R[k]==r) {tree[k].del(x);return;} down(k); int mid=L[k]+R[k]>>1; if (r<=mid) del(k<<1,l,r,x); else if (l>mid) del(k<<1|1,l,r,x); else del(k<<1,l,mid,x),del(k<<1|1,mid+1,r,x); } int query(int k,int x) { if (L[k]==R[k]) return tree[k].a.top(); down(k); int mid=L[k]+R[k]>>1; if (x<=mid) return query(k<<1,x); else return query(k<<1|1,x); } void modify(int x,int y,int z) { int m=0; while (top[x]!=top[y]) { if (deep[top[x]]<deep[top[y]]) swap(x,y); m++,a[m].l=id[top[x]],a[m].r=id[x]; x=fa[top[x]]; } if (deep[x]<deep[y]) swap(x,y); m++,a[m].l=id[y],a[m].r=id[x]; sort(a+1,a+m+1); a[0].l=a[0].r=0,a[m+1].l=a[m+1].r=n+1; for (int i=0;i<=m;i++) if (a[i].r+1<=a[i+1].l-1) if (z>0) ins(1,a[i].r+1,a[i+1].l-1,z); else del(1,a[i].r+1,a[i+1].l-1,-z); } int main() { #ifndef ONLINE_JUDGE freopen("bzoj4538.in","r",stdin); freopen("bzoj4538.out","w",stdout); const char LL[]="%I64d\n"; #else const char LL[]="%lld\n"; #endif n=read(),m=read(); for (int i=1;i<n;i++) { int x=read(),y=read(); addedge(x,y),addedge(y,x); } dfs1(1); dfs2(1,1); build(1,1,n); for (int i=1;i<=m;i++) { int op=read(); switch(op) { case 0: { q[i].x=read(),q[i].y=read(),q[i].z=read(); modify(q[i].x,q[i].y,q[i].z); break; } case 1: { int x=read(); modify(q[x].x,q[x].y,-q[x].z); break; } case 2: { int x=read(); printf("%d\n",query(1,id[x])); } } } return 0; }
然而由于复杂度是O(nlog3n)的以及蒟蒻自带大常数,在luogu上T了两个点,bzoj甚至直接MLE了。于是考虑有没有更好的做法。
感觉上修改时对于某一条路径之外的都要修改过于暴力,考虑是否能改成修改这条路径。如果只有一次询问,可以二分答案,统计出有多少个修改的权值不小于该答案,并统计上述种类的修改覆盖了查询点多少次,若两者相等则说明答案不可行,否则可行。修改路径可以通过树上差分维护子树和完成。
现在有多组询问,可以想到整体二分。对整体二分一个答案,按时间顺序依次修改、查询,每次做完之后都缩小了每一个询问的答案范围,将询问按照答案是否大于mid、修改按照权值是否大于mid扔在两边即可。于是复杂度O(nlog2n)。
#include<iostream> #include<cstdio> #include<cmath> #include<cstdlib> #include<cstring> #include<algorithm> using namespace std; int read() { int x=0,f=1;char c=getchar(); while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();} while (c>='0'&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar(); return x*f; } #define N 100010 #define inf 1000000000 int n,m,p[N],t,tree[N],id[N],size[N],fa[N][18],deep[N],cnt; struct data{int to,nxt; }edge[N<<1]; struct data2{int op,x,y,z,l,i,ans; }q[N<<1],tmp[N<<1]; void addedge(int x,int y){t++;edge[t].to=y,edge[t].nxt=p[x],p[x]=t;} bool cmp(const data2&a,const data2&b) { return a.i<b.i; } void dfs(int k,int from) { id[k]=++cnt;size[k]=1; for (int i=p[k];i;i=edge[i].nxt) if (edge[i].to!=from) { deep[edge[i].to]=deep[k]+1; fa[edge[i].to][0]=k; dfs(edge[i].to,k); size[k]+=size[edge[i].to]; } } int lca(int x,int y) { if (deep[x]<deep[y]) swap(x,y); for (int j=17;~j;j--) if (deep[fa[x][j]]>=deep[y]) x=fa[x][j]; if (x==y) return x; for (int j=17;~j;j--) if (fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j]; return fa[x][0]; } void add(int k,int x){while (k<=n){tree[k]+=x;k+=k&-k;}} int sum(int k){int s=0;while (k){s+=tree[k];k-=k&-k;}return s;} void solve(int l,int r,int low,int high) { if (l>r) return; if (low==high) { for (int i=l;i<=r;i++) if (q[i].op==2) q[i].ans=low; return; } int mid=low+high>>1,s=0; for (int i=l;i<=r;i++) if (q[i].op!=2) { if (q[i].z>mid) { int t=q[i].op==0?1:-1; s+=t; add(id[q[i].x],t),add(id[q[i].y],t); add(id[q[i].l],-t); if (q[i].l>1) add(id[fa[q[i].l][0]],-t); } } else if (sum(id[q[i].x]+size[q[i].x]-1)-sum(id[q[i].x]-1)==s) q[i].ans=low; else q[i].ans=high; int cut=l; for (int i=l;i<=r;i++) if (q[i].op==2&&q[i].ans<=mid||q[i].op!=2&&q[i].z<=mid) tmp[cut++]=q[i]; int t=cut; for (int i=l;i<=r;i++) if (q[i].op==2&&q[i].ans>mid||q[i].op!=2&&q[i].z>mid) tmp[t++]=q[i]; for (int i=l;i<=r;i++) q[i]=tmp[i]; for (int i=cut;i<=r;i++) if (q[i].op!=2) { int t=q[i].op==0?-1:1; add(id[q[i].x],t),add(id[q[i].y],t); add(id[q[i].l],-t); if (q[i].l>1) add(id[fa[q[i].l][0]],-t); } solve(l,cut-1,low,mid); solve(cut,r,mid+1,high); } int main() { #ifndef ONLINE_JUDGE freopen("bzoj4538.in","r",stdin); freopen("bzoj4538.out","w",stdout); const char LL[]="%I64d\n"; #else const char LL[]="%lld\n"; #endif n=read(),m=read(); for (int i=1;i<n;i++) { int x=read(),y=read(); addedge(x,y),addedge(y,x); } fa[1][0]=1;dfs(1,1); for (int j=1;j<18;j++) for (int i=1;i<=n;i++) fa[i][j]=fa[fa[i][j-1]][j-1]; for (int i=1;i<=m;i++) { q[i].op=read();q[i].i=i; switch(q[i].op) { case 0: { q[i].x=read(),q[i].y=read(),q[i].z=read(),q[i].l=lca(q[i].x,q[i].y); break; } case 1: { int x=read(); q[i].x=q[x].x,q[i].y=q[x].y,q[i].z=q[x].z,q[i].l=q[x].l; break; } case 2:q[i].x=read(),q[i].ans=-1; } } solve(1,m,-1,inf); sort(q+1,q+m+1,cmp); for (int i=1;i<=m;i++) if (q[i].op==2) printf("%d\n",q[i].ans); return 0; }