树分治
点分治
适合处理大规模的树上路径信息。
实现
取一个中心点计算跨过中心点的贡献。(lsy说得精辟但抽象)
先随便指定一个根 ,我们能将树上的路径分为经过 的路径和包含于 的某棵子树内的路径(不经过 )。
对于经过当前根节点的路径,又可以分为两种,一种是以根节点为一个端点的路径,另一种是两个端点都不为根节点的路径。而后者又可以由两条属于前者的链合并得到。
接着我们枚举 ,先计算在其子树中且经过该节点的路径对答案的贡献,再递归其子树对不经过该节点的路径进行求解。
OI-wiki说着有点抽象,其实就是不断地在上一个中心点的子树中去取一个新中心点,然后一直去求解某一类路径,就是一个分治的思想。
然后随便找一个中心点的话如果树是一条链这个分治就会退化为 ,所以我们要让递归层数尽量少,管的什么证明反正就是取树的重心。重心定义不再赘述,写下求法:记录 表示删掉 之后产生的最大子树大小,重心就是 最小的点。下面是求法:
点击查看代码
void find(ll u,ll fa){
siz[u]=1,maxp[u]=0;
for(int i=head[u];i;i=e[i].next){
ll v=e[i].to;
if(v==fa||vis[v]) continue;
find(v,u);
siz[u]+=siz[v];
maxp[u]=max(maxp[u],siz[v]);
}
maxp[u]=max(maxp[u],sums-siz[u]); //sums是会变的子树总大小,会在分治过程中变为siz[v]
if(maxp[u]<maxp[rt]) rt=u;
}
结合具体题目来分治。
题
P3806 【模板】点分治 1
首先离线询问省得每次都搞一遍淀粉质,然后先找到重心然后开始分治。每次求答案就把重心子树内的点到重心的距离求出来,然后枚举询问看有没有点的距离满足询问,这里具体是这样判的:开一个数组 表示有没有点到重心的距离为 ,if(q[k]>=dist[j]&&!ans[k]) ans[k]=fl[q[k]-dist[j]];
。
然后我们要记得清空这个 数组,于是考虑开一个桶记录哪些 修改过。但是这里有些细节问题, 可能太大了就会爆掉桶,所以要特判一下。
还有一些细节:我们每一次重新找重心都是先设为 ,然后树总的大小就变成了 ,因为我们是进入 的子树来分治的;一开始 就要设成 ,毕竟也可能一个点和中心的距离直接就是询问值;然后数组开大,还有点细节看代码吧。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const ll N=114514,M=10000005; //注意数据范围
struct xx{
ll next,to,val;
}e[2*N];
ll head[2*N],cnt;
void add(ll x,ll y,ll z){
e[++cnt].next=head[x];
e[cnt].to=y;
e[cnt].val=z;
head[x]=cnt;
}
ll n,m,q[N],rt,sums;
ll siz[N],dis[N],maxp[N];
bool vis[N],fl[M],ans[M];
ll bu[M],tot;
void find(ll u,ll fa){
siz[u]=1,maxp[u]=0;
for(int i=head[u];i;i=e[i].next){
ll v=e[i].to;
if(v==fa||vis[v]) continue;
find(v,u);
siz[u]+=siz[v];
maxp[u]=max(maxp[u],siz[v]);
}
maxp[u]=max(maxp[u],sums-siz[u]);
if(maxp[u]<maxp[rt]) rt=u;
}
ll res,dist[N];
void dfs_dis(ll u,ll fa){
dist[++res]=dis[u];
for(int i=head[u];i;i=e[i].next){
ll v=e[i].to,w=e[i].val;
if(v==fa||vis[v]) continue;
dis[v]=dis[u]+w;
dfs_dis(v,u);
}
}
void calc(ll u){
tot=0;
for(int i=head[u];i;i=e[i].next){
ll v=e[i].to,w=e[i].val;
if(vis[v]) continue;
res=0,dis[v]=w;
dfs_dis(v,u);
for(int j=1;j<=res;++j) //枚举点
for(int k=1;k<=m;++k)//枚举询问
if(q[k]>=dist[j]&&!ans[k]) ans[k]=fl[q[k]-dist[j]];
for(int j=1;j<=res;++j)
if(dist[j]<=1e7) bu[++tot]=dist[j],fl[dist[j]]=1;
}
for(int i=1;i<=tot;++i) fl[bu[i]]=0;
}
void solve(ll u){
vis[u]=1;
calc(u);
for(int i=head[u];i;i=e[i].next){
ll v=e[i].to;
if(vis[v]) continue;
rt=0,maxp[rt]=N,sums=siz[v];
find(v,0),solve(rt);
}
}
int main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>m; sums=n;
for(int i=1;i<n;++i){
ll a,b,c;
cin>>a>>b>>c;
add(a,b,c),add(b,a,c);
}
for(int i=1;i<=m;++i) cin>>q[i];
rt=0,sums=maxp[rt]=n,fl[0]=1;
find(1,0),solve(rt);
for(int i=1;i<=m;++i)
if(ans[i]) cout<<"AYE\n";
else cout<<"NAY\n";
return 0;
}
P4178 Tree
既然是小于等于 ,那就珂以用一个树状数组来维护小于等于 的点对距离数量。
这个题对于上个题就 calc 函数改了一下。这里面就先求出子树内的距离,然后先查询 ,然后再对每个 修改加一,也要用桶记录然后清零,注意应该先判 是否小于等于 ,然后就没了,比双指针更好理解。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const ll N=114514,M=10000005;
struct xx{
ll next,to,val;
}e[2*N];
ll head[2*N],cnt;
void add(ll x,ll y,ll z){
e[++cnt].next=head[x];
e[cnt].to=y;
e[cnt].val=z;
head[x]=cnt;
}
ll n,K,rt,sums,ans;
ll siz[N],dis[N],maxp[N];
bool vis[N];
ll bu[M],tot,c[N];
ll lowbit(ll x){return x&-x;}
void update(ll x,ll k){
while(x<=K){ //注意这里是加到K
c[x]+=k;
x+=lowbit(x);
}
}
ll query(ll x){
ll ans=0;
while(x){
ans+=c[x];
x-=lowbit(x);
}
return ans;
}
void find(ll u,ll fa){
siz[u]=1,maxp[u]=0;
for(int i=head[u];i;i=e[i].next){
ll v=e[i].to;
if(v==fa||vis[v]) continue;
find(v,u);
siz[u]+=siz[v];
maxp[u]=max(maxp[u],siz[v]);
}
maxp[u]=max(maxp[u],sums-siz[u]);
if(maxp[u]<maxp[rt]) rt=u;
}
ll res,dist[N];
void dfs_dis(ll u,ll fa){
dist[++res]=dis[u];
for(int i=head[u];i;i=e[i].next){
ll v=e[i].to,w=e[i].val;
if(v==fa||vis[v]) continue;
dis[v]=dis[u]+w;
dfs_dis(v,u);
}
}
void calc(ll u){
tot=0;
for(int i=head[u];i;i=e[i].next){
ll v=e[i].to,w=e[i].val;
if(vis[v]) continue;
dis[v]=w,res=0;
dfs_dis(v,u);
for(int j=1;j<=res;++j)
if(dist[j]<=K) ans+=query(K-dist[j]);
for(int j=1;j<=res;++j)
if(dist[j]<=K) update(dist[j],1),bu[++tot]=dist[j],++ans;
}
for(int i=1;i<=tot;++i) update(bu[i],-1);
}
void solve(ll u){
vis[u]=1;
calc(u);
for(int i=head[u];i;i=e[i].next){
ll v=e[i].to;
if(vis[v]) continue;
rt=0,maxp[rt]=N,sums=siz[v];
find(v,0),solve(rt);
}
}
int main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n; sums=n;
for(int i=1;i<n;++i){
ll a,b,c;
cin>>a>>b>>c;
add(a,b,c),add(b,a,c);
}
cin>>K;
rt=0,sums=maxp[rt]=n;
find(1,0),solve(rt);
cout<<ans;
return 0;
}
我发现你为什么对淀粉质半懂不懂呢?因为你之前连分治都不懂/jk/jk
不过现在算是懂了,就相当于是把分治过程从区间搬到树上,区间中点变成了树的重心,然后计算跨过中点的区间造成的贡献。
不就这样,有啥好说的?证明以前真不知道是啥东西了。初一初二落下的东西影响太大了,你像这样的题本身就没有打好基础,你能保证说他能在CSP-S2023或者NOIP2024中拿个一等?
P6626 [省选联考 2020 B 卷] 消息传递
多次询问与 距离为 的点的个数。
考虑点分治,把每个询问离线下来挂到点上。令当前的分治中心为 ,我们记录下 子树中每种深度的点有多少个,对于 上的询问来说贡献就是 。当我们进入儿子 时, 子树中的点的深度相对于 子树中的要小一,所以进入之前要先更新 ,出来之后也是。具体来说更新方法就是遍历一遍子树:
void dfs_upd(ll u,ll fa,ll k){
cnt[dept[u]]+=k; //k=1或-1
for(int v:g[u]){
if(v==fa||vis[v]) continue;
dfs_upd(v,u,k);
}
}
然后按正常点分治框架来就行了,注意多测清空。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define ll int
#define eb emplace_back
const ll N=114514,M=1919810;
ll T;
vector <ll> g[N];
struct que{ll k,id;};
vector <que> q[N];
ll n,m,rt,sums,ans[N];
ll siz[N],maxd[N];
bool vis[N];
void find(ll u,ll fa){
siz[u]=1,maxd[u]=0;
for(int v:g[u]){
if(v==fa||vis[v]) continue;
find(v,u);
siz[u]+=siz[v];
maxd[u]=max(maxd[u],siz[v]);
}
maxd[u]=max(maxd[u],sums-siz[u]);
if(maxd[u]<maxd[rt]) rt=u;
}
ll dept[N],cnt[N]; //每个深度的个数
void dfs_dis(ll u,ll fa){
++cnt[dept[u]];
for(int v:g[u]){
if(v==fa||vis[v]) continue;
dept[v]=dept[u]+1;
dfs_dis(v,u);
}
}
void dfs_upd(ll u,ll fa,ll k){
cnt[dept[u]]+=k;
for(int v:g[u]){
if(v==fa||vis[v]) continue;
dfs_upd(v,u,k);
}
}
void dfs_calc(ll u,ll fa){
for(auto x:q[u]){
ll k=x.k-dept[u];
if(k<0) continue;
ans[x.id]+=cnt[k];
}
for(int v:g[u]){
if(v==fa||vis[v]) continue;
dfs_calc(v,u);
}
}
void calc(ll u){
dept[u]=0,dfs_dis(u,0);
for(auto x:q[u]){
ll k=x.k-dept[u];
if(k<0) continue;
ans[x.id]+=cnt[k];
}
for(int v:g[u]){
if(vis[v]) continue;
dfs_upd(v,u,-1);
dfs_calc(v,u);
dfs_upd(v,u,1);
}
dfs_upd(u,0,-1);
//注意每换一个点深度就要更改
}
void solve(ll u){
vis[u]=1;
calc(u);
for(int v:g[u]){
if(vis[v]) continue;
rt=0,maxd[rt]=N,sums=siz[v];
find(v,0),solve(rt);
}
}
void solve_main(){
cin>>n>>m;
for(int i=1;i<=n;++i) g[i].clear(),q[i].clear(),vis[i]=0; //询问没清空/jk
for(int i=1;i<=m;++i) ans[i]=0;
for(int i=1;i<n;++i){
ll a,b;
cin>>a>>b;
g[a].eb(b),g[b].eb(a);
}
for(int i=1,x,k;i<=m;++i){
cin>>x>>k;
q[x].eb(que{k,i});
}
sums=maxd[rt]=n;
find(1,0),solve(rt);
for(int i=1;i<=m;++i) cout<<ans[i]<<'\n';
}
int main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>T;
while(T--) solve_main();
return 0;
}