点分治学习笔记
点分治
采用分治思想。对树上路径问题进行查询时,把路径分成两部分,一部分是经过根节点的路径,一部分是不经过根节点的路径。
而在处理不经过根节点的路径时,可以才有分治思想,递归到左右子树进行求解。
这样复杂度是 \(O(n^2)\) 的,但是若我们每次选取的根节点都是要求解的子树的重心,则复杂度可以优化到 \(O(nlogn)\)。
实现思路
需要实现以下函数:
- solve:分治过程,不断取重心分治。
- getzx:求重心。
- calc:计算经过当前点的路径对答案的贡献。
例题
洛谷 P3806 【模板】点分治1
传送门
开一个数组记录当前的路径长度有哪些。
先读入所有询问,然后到达一个点就更新一下答案。
总复杂度 \(O(mnlogn)\)。
AC代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<vector>
using namespace std;
const int maxn=10010,bign=10001000;
int n,m,tmp[bign],judge[bign];
int sz[maxn],vis[maxn];
int head[maxn],q[maxn];
int size,maxp[maxn];
int tot,rt,dis[maxn];
int qqq[maxn],ynn[maxn],cnt,p[maxn];
struct node{
int v,next,w;
}e[maxn*2];
void insert(int u,int v,int w){
cnt++;
e[cnt].v=v;
e[cnt].w=w;
e[cnt].next=p[u];
p[u]=cnt;
}
void getzx(int u,int fa){
maxp[u]=0;
sz[u]=1;
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(vis[v]||v==fa) continue;
getzx(v,u);
sz[u]+=sz[v];
maxp[u]=max(maxp[u],sz[v]);
}
maxp[u]=max(maxp[u],tot-sz[u]);
if(maxp[u]<maxp[rt]) rt=u;
}
inline void getdis(int u,int fa){
tmp[++tmp[0]]=dis[u];
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(vis[v]||v==fa) continue;
dis[v]=dis[u]+e[i].w;
getdis(v,u);
}
}
inline void calc(int u){
int ppp=0;
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(vis[v]) continue;
tmp[0]=0;
dis[v]=e[i].w;
getdis(v,u);
for(int j=1;j<=tmp[0];j++){
for(int k=1;k<=m;k++){
if(q[k]>=tmp[j]) ynn[k]|=judge[q[k]-tmp[j]];
}
}
for(int j=1;j<=tmp[0];j++){
if(tmp[j]>=bign) continue;
qqq[++ppp]=tmp[j];
judge[tmp[j]]=1;
}
}
for(int i=1;i<=ppp;i++) judge[qqq[i]]=0;
}
inline void solve(int u){
vis[u]=judge[0]=1; calc(u);
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(vis[v]) continue;
tot=sz[v];
maxp[rt=0]=sz[v];
getzx(v,0);
solve(rt);
}
}
int main(){
ios::sync_with_stdio(false);
memset(p,-1,sizeof(p));
cin>>n>>m;
for(int i=1;i<n;i++){
int u,v,w;
cin>>u>>v>>w;
insert(u,v,w);
insert(v,u,w);
}
for(int i=1;i<=m;i++) cin>>q[i];
maxp[rt=0]=n;
tot=n;
getzx(1,0);
solve(rt);
for(int i=1;i<=m;i++){
if(ynn[i]) cout<<"AYE"<<endl;
else cout<<"NAY"<<endl;
}
return 0;
}
CF161D Distance in Tree
传送门
数组记录的内容变成当前长度为x的路径的数量。
其他和板子基本相同。
AC代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<vector>
using namespace std;
template<class T>inline void read(T &x)
{
x=0;register char c=getchar();register bool f=0;
while(!isdigit(c))f^=c=='-',c=getchar();
while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
if(f)x=-x;
}
template<class T>inline void print(T x)
{
if(x<0)putchar('-'),x=-x;
if(x>9)print(x/10);
putchar('0'+x%10);
}
const int maxn=50005;
long long ans;
int n,k,judge[maxn],tmp[maxn],siz[maxn],vis[maxn],p[maxn],cnt,tot,maxp[maxn],rt,dis[maxn],q[maxn];
struct node{
int v,next;
}e[maxn*2];
void insert(int u,int v){
cnt++;
e[cnt].v=v;
e[cnt].next=p[u];
p[u]=cnt;
}
void getzx(int u,int fa){
maxp[u]=0;
siz[u]=1;
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(vis[v]||v==fa) continue;
getzx(v,u);
siz[u]+=siz[v];
maxp[u]=max(maxp[u],siz[v]);
}
maxp[u]=max(maxp[u],tot-siz[u]);
if(maxp[u]<=maxp[rt]) rt=u;
}
void getdis(int u,int fa){
tmp[++tmp[0]]=dis[u];
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(v==fa||vis[v]) continue;
dis[v]=dis[u]+1;
getdis(v,u);
}
}
void calc(int u){
int cntq=0;
dis[u]=0;
judge[0]=1;
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(vis[v]) continue;
tmp[0]=0;
dis[v]=1;
getdis(v,u);
for(int j=1;j<=tmp[0];j++) if(k>=tmp[j]) ans+=judge[k-tmp[j]];
for(int j=1;j<=tmp[0];j++) judge[tmp[j]]++,q[++cntq]=tmp[j];
}
for(int i=1;i<=cntq;i++) judge[q[i]]--;
}
void solve(int u){
vis[u]=1;calc(u);
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(vis[v]) continue;
maxp[rt=0]=tot=siz[v];
getzx(v,-1);
solve(rt);
}
}
int main(){
memset(p,-1,sizeof(p));
read(n);read(k);
for(int i=1;i<n;i++){
int u,v;
read(u);read(v);
insert(u,v);
insert(v,u);
}
maxp[rt=0]=tot=n;
getzx(1,-1);
solve(rt);
print(ans);
return 0;
}
洛谷 P4149 [IOI2011]Race
开一个数组记录当前路径权值;
开一个数组记录当前权值和为x的路径的最少的边数。
两个数组同时求、清空。
AC代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<vector>
using namespace std;
template<class T>inline void read(T &x)
{
x=0;register char c=getchar();register bool f=0;
while(!isdigit(c))f^=c=='-',c=getchar();
while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
if(f)x=-x;
}
template<class T>inline void print(T x)
{
if(x<0)putchar('-'),x=-x;
if(x>9)print(x/10);
putchar('0'+x%10);
}
const int maxn=2e5+5;
const int maxm=1e6+5;
int n,k,judge[maxm],tmp[maxn],siz[maxn],vis[maxn],p[maxn],cnt,tot,maxp[maxn],rt,dis[maxn],q[maxn],dep[maxn],tmp2[maxn],anss[maxm],ans=0x3f3f3f3f;
struct node{
int v,next,w;
}e[maxn*2];
void insert(int u,int v,int w){
cnt++;
e[cnt].v=v;
e[cnt].w=w;
e[cnt].next=p[u];
p[u]=cnt;
}
void getzx(int u,int fa){
maxp[u]=0;
siz[u]=1;
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(vis[v]||v==fa) continue;
getzx(v,u);
siz[u]+=siz[v];
maxp[u]=max(maxp[u],siz[v]);
}
maxp[u]=max(maxp[u],tot-siz[u]);
if(maxp[u]<=maxp[rt]) rt=u;
}
void getdis(int u,int fa){
tmp[++tmp[0]]=dis[u];tmp2[tmp[0]]=dep[u];
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(v==fa||vis[v]) continue;
dis[v]=dis[u]+e[i].w;
dep[v]=dep[u]+1;
getdis(v,u);
}
}
void calc(int u){
int cntq=0;
dis[u]=0;
anss[0]=0;
judge[0]=1;
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(vis[v]) continue;
tmp[0]=0;
dep[v]=1;
dis[v]=e[i].w;
getdis(v,u);
for(int j=1;j<=tmp[0];j++) if(k>=tmp[j]&&judge[k-tmp[j]]) ans=min(ans,anss[k-tmp[j]]+tmp2[j]);
for(int j=1;j<=tmp[0];j++){
if(tmp[j]>k) continue;
judge[tmp[j]]=1;
anss[tmp[j]]=min(anss[tmp[j]],tmp2[j]);
q[++cntq]=tmp[j];
}
}
for(int i=1;i<=cntq;i++) judge[q[i]]=0,anss[q[i]]=0x3f3f3f3f;
}
void solve(int u){
vis[u]=1;calc(u);
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(vis[v]) continue;
maxp[rt=0]=tot=siz[v];
getzx(v,-1);
solve(rt);
}
}
int main(){
memset(p,-1,sizeof(p));
memset(anss,0x3f,sizeof(anss));
read(n);read(k);
for(int i=1;i<n;i++){
int u,v,w;
read(u);read(v);read(w);
u++;v++;
insert(u,v,w);
insert(v,u,w);
}
maxp[rt=0]=tot=n;
getzx(1,-1);
solve(rt);
print((ans==0x3f3f3f3f?-1:ans));
return 0;
}
洛谷 P4178 Tree
传送门
一种做法是充斥一下,但是感觉好麻烦而且常数很大,所以我采用树状数组。
加路径相当于单点修改,更新答案时查询前缀和。
注意先更新答案,后更新存路径数量的桶。
AC代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<vector>
using namespace std;
template<class T>inline void read(T &x)
{
x=0;register char c=getchar();register bool f=0;
while(!isdigit(c))f^=c=='-',c=getchar();
while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
if(f)x=-x;
}
template<class T>inline void print(T x)
{
if(x<0)putchar('-'),x=-x;
if(x>9)print(x/10);
putchar('0'+x%10);
}
const int maxn=4e4+5;
const int maxx=2e4+5;
int n,cnt,num,p[maxn],d[maxn],siz[maxn],maxp[maxn],tot,k,dis[maxn],tmp[maxn],ans,vis[maxn],rt;
struct node{
int v,next,w;
}e[maxn*2];
void insert(int u,int v,int w){
cnt++;
e[cnt].v=v;
e[cnt].w=w;
e[cnt].next=p[u];
p[u]=cnt;
}
inline int lowbit(int x){
return x&-x;
}
void update(int x,int v){
for(int i=x;i<maxx;i+=lowbit(i)) d[i]+=v;
}
int query(int x){
int res=0;
for(int i=x;i>=1;i-=lowbit(i)) res+=d[i];
return res;
}
void getzx(int u,int fa){
siz[u]=1;
maxp[u]=0;
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(v==fa||vis[v]) continue;
getzx(v,u);
siz[u]+=siz[v];
maxp[u]=max(maxp[u],siz[v]);
}
maxp[u]=max(maxp[u],tot-siz[u]);
if(maxp[u]<maxp[rt]) rt=u;
}
void getdis(int u,int fa){
tmp[++tmp[0]]=dis[u];
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(v==fa||vis[v]) continue;
dis[v]=dis[u]+e[i].w;
getdis(v,u);
}
}
void cal(int u){
int num=0,q[maxn];
dis[u]=0;
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(vis[v]) continue;
dis[v]=e[i].w;
tmp[0]=0;
getdis(v,u);
for(int j=1;j<=tmp[0];j++){
if(k>tmp[j]) ans+=query(k-tmp[j]);
if(k>=tmp[j]) ans++;
}
for(int j=1;j<=tmp[0];j++){
if(k>=tmp[j]){
q[++num]=tmp[j];
update(tmp[j],1);
}
}
}
for(int i=1;i<=num;i++) update(q[i],-1);
}
void solve(int u){
vis[u]=1;
cal(u);
for(int i=p[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(vis[v]) continue;
tot=maxp[rt=0]=siz[v];
getzx(v,u);
solve(v);
}
}
int main(){
memset(p,-1,sizeof(p));
read(n);
for(int i=1;i<n;i++){
int u,v,w;
read(u);read(v);read(w);
insert(u,v,w);
insert(v,u,w);
}
read(k);
maxp[rt=0]=tot=n;
getzx(1,-1);
solve(1);
print(ans);
return 0;
}