【知识点】点分治&动态点分治
点分治:
优雅地暴力解决一类不带修改的树上路径问题。
每次找原树的重心,以重心为根暴力枚举当前子树内的所有点算答案,然后继续递归子树。
这个东西最多会递归$\log{n}$层,所以复杂度是$O(n\log{n})$的。
#include<algorithm> #include<iostream> #include<cstring> #include<cstdio> #define maxn 100005 #define maxm 500005 #define inf 0x7fffffff #define ll long long using namespace std; int N,M,K,hd[maxn],to[maxn<<1],nxt[maxn<<1],cst[maxn<<1]; int rt,ans,num,siz[maxn],d[maxn],q[maxn],tot,mnsiz,cnt; bool vis[maxn]; inline int read(){ int x=0,f=1; char c=getchar(); for(;!isdigit(c);c=getchar()) if(c=='-') f=-1; for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f; } inline void addedge(int u,int v,int w){ to[++cnt]=v,cst[cnt]=w,nxt[cnt]=hd[u],hd[u]=cnt; to[++cnt]=u,cst[cnt]=w,nxt[cnt]=hd[v],hd[v]=cnt; } inline void getrt(int u,int fa){ siz[u]=1; int mx=0; for(int i=hd[u];i;i=nxt[i]){ int v=to[i]; if(v==fa || vis[v]) continue; getrt(v,u); siz[u]+=siz[v]; mx=max(mx,siz[v]); } mx=max(mx,tot-siz[u]); if(mx<mnsiz) rt=u,mnsiz=mx; return; } inline void getdis(int u,int fa){ q[++num]=d[u]; for(int i=hd[u];i;i=nxt[i]){ int v=to[i],w=cst[i]; if(v==fa || vis[v]) continue; d[v]=d[u]+w,getdis(v,u); } return; } inline int calc(int u,int val){ num=0,d[u]=val,getdis(u,0); sort(q+1,q+1+num); int l=1,r=num,res=0; while(l<=r){ if(q[l]+q[r]==K) res++,l++; else if(q[l]+q[r]>K) r--; else l++; } return res; } inline void dfs(int u){ ans+=calc(u,0),vis[u]=1; for(int i=hd[u];i;i=nxt[i]){ int v=to[i],w=cst[i]; if(vis[v]) continue; ans-=calc(v,w); mnsiz=inf,tot=siz[v]; getrt(v,u),dfs(rt); } return; } int main(){ N=read(),M=read(); for(int i=1;i<=N-1;i++){ int u=read(),v=read(),w=read(); addedge(u,v,w); } while(M--){ memset(vis,0,sizeof(vis)); K=read(),ans=0,mnsiz=inf,tot=N; getrt(1,0),dfs(rt); if(ans) printf("AYE\n"); else printf("NAY\n"); } return 0; }
动态点分治(点分树):
解决带修改的点分治问题。
按dfs的顺序把点分治的所有重心连成一棵树,每次修改在树上跳fa更新贡献。
注意这棵树破坏了原树的结构,所以需要容斥的地方可能要单独维护。
复杂度$O(n\log{n})$。(luogu这么喜欢模板题卡常啊?)
#include<algorithm> #include<iostream> #include<cstring> #include<cstdio> #define maxn 100005 #define maxm 500005 #define inf 0x7fffffff #define ll long long using namespace std; int N,M,K,hd[maxn],to[maxn<<1],nxt[maxn<<1],cst[maxn<<1]; int rt,ans,num,siz[maxn],d[maxn],q[maxn],tot,mnsiz,cnt; bool vis[maxn]; inline int read(){ int x=0,f=1; char c=getchar(); for(;!isdigit(c);c=getchar()) if(c=='-') f=-1; for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f; } inline void addedge(int u,int v,int w){ to[++cnt]=v,cst[cnt]=w,nxt[cnt]=hd[u],hd[u]=cnt; to[++cnt]=u,cst[cnt]=w,nxt[cnt]=hd[v],hd[v]=cnt; } inline void getrt(int u,int fa){ siz[u]=1; int mx=0; for(int i=hd[u];i;i=nxt[i]){ int v=to[i]; if(v==fa || vis[v]) continue; getrt(v,u); siz[u]+=siz[v]; mx=max(mx,siz[v]); } mx=max(mx,tot-siz[u]); if(mx<mnsiz) rt=u,mnsiz=mx; return; } inline void getdis(int u,int fa){ q[++num]=d[u]; for(int i=hd[u];i;i=nxt[i]){ int v=to[i],w=cst[i]; if(v==fa || vis[v]) continue; d[v]=d[u]+w,getdis(v,u); } return; } inline int calc(int u,int val){ num=0,d[u]=val,getdis(u,0); sort(q+1,q+1+num); int l=1,r=num,res=0; while(l<=r){ if(q[l]+q[r]==K) res++,l++; else if(q[l]+q[r]>K) r--; else l++; } return res; } inline void dfs(int u){ ans+=calc(u,0),vis[u]=1; for(int i=hd[u];i;i=nxt[i]){ int v=to[i],w=cst[i]; if(vis[v]) continue; ans-=calc(v,w); mnsiz=inf,tot=siz[v]; getrt(v,u),dfs(rt); } return; } int main(){ N=read(),M=read(); for(int i=1;i<=N-1;i++){ int u=read(),v=read(),w=read(); addedge(u,v,w); } while(M--){ memset(vis,0,sizeof(vis)); K=read(),ans=0,mnsiz=inf,tot=N; getrt(1,0),dfs(rt); if(ans) printf("AYE\n"); else printf("NAY\n"); } return 0; }