洛谷2664树上游戏-点分治

link:https://www.luogu.com.cn/problem/P2664
lrb 有一棵树,树的每个节点有个颜色。给一个长度为 \(n\) 的颜色序列,定义 \(s(i,j)\)\(i\)\(j\) 的颜色数量。以及

\[sum_i=\sum_{j=1}^n s(i, j) \]

现在他想让你求出所有的 \(sum_i\)


一个暴力的想法:因为是求和,所以可以拆开算贡献。枚举每个颜色 \(c\),将颜色 \(c\) 的点拿出来,会把原树划分成若干个连通块,对每个点的贡献即为 \(n-sz_i\) ,其中 \(sz_i\) 表示 \(i\) 所属连通块大小。连通块用并查集维护,这样是 \(O(n^2)\)的。

考虑点分治,对于当前的分治中心 \(x\),需要考虑:

  • \(x\) 为某个端点,延伸到某个子节点的产生的答案。
  • \(x\) 为LCA,从某个子树中的 \(y\) 出发(或直接从 \(x\) 出发)对另一子树中 \(z\) 的贡献。
    第一种情况直接一次dfs统计:
void dfs1(int x,int fa,int from){
    cnt_col[c[x]]++;
    if(cnt_col[c[x]]==1)ans[from]+=sz[x];
    for(auto to:G[x])if(to!=fa&&!vis[to])dfs1(to,x,from);
    cnt_col[c[x]]--;
}

对第二种情况,假设某种颜色 \(c\)\(x\to y\) 的路径上已经出现过了,统计 \(y\) 的答案时,贡献直接是 \(sz[x]-sz[p]\),否则,应该对 \(p\) 以外的 \(x\) 子树进行统计,如果某个颜色在一个结点里第一次出现,则会产生其子树大小的贡献,记一个 \(path[c]\) 表示颜色 \(c\) 的贡献

void dfs2(int x,int fa,int def_val,bool tag){
    cnt_col[c[x]]++;
    if(cnt_col[c[x]]==1){
        sum_col+=def_val;
        sum_path-=path[c[x]];
    }
    ans[x]+=sum_path;
    if(tag)ans[x]+=sum_col;
    for(auto to:G[x])if(to!=fa&&!vis[to])
        dfs2(to,x,def_val,tag);

    if(cnt_col[c[x]]==1){
        sum_col-=def_val;
        sum_path+=path[c[x]];
    }
    cnt_col[c[x]]--;
}

因为需要扣除掉 \(p\) 子树内的 \(path\) 数组,直接给数组做差不方便,这部分答案可以考虑对 \(x\) 的孩子正反做两次dfs


void calc_path(int x,int fa){
    cnt_col[c[x]]++;
    Q[++tot]=c[x];
    if(cnt_col[c[x]]==1){
        path[c[x]]+=sz[x];
        sum_path+=sz[x];
    }
    for(auto to:G[x])if(to!=fa&&!vis[to])calc_path(to,x);
    cnt_col[c[x]]--;
}
//...
auto work=[&](bool tag){
	clear();
	sum_col=sum_path=0;
	assert(cnt_col[c[x]]==0);
	for(auto to:G[x])if(!vis[to]){
		cnt_col[c[x]]++;
		sum_col=sz[x]-sz[to];
		dfs2(to,x,sz[x]-sz[to],tag);
		calc_path(to,x);
		sum_col=0;
		cnt_col[c[x]]--;
	}
};
work(0);
reverse(G[x].begin(),G[x].end());
work(1);

代码:

#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define endl '\n'
#define fastio ios_base::sync_with_stdio(false);cin.tie(0);cout.tie(0)
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
const int INF=0x3f3f3f3f;
const int N=1e5+5;
int n,c[N],mx_c;
bool vis[N];
vector<vector<int>> G;

int cnt_col[N],path[N];
int tot,Q[N];

ll sum_path,sum_col;
ll ans[N];
int rt,sz[N],mx_rt;
void get_rt(int x,int fa,int sum){
    sz[x]=1;
    int mx=0;
    for(auto to:G[x])if(to!=fa&&!vis[to]){
        get_rt(to,x,sum);
        mx=max(mx,sz[to]);
        sz[x]+=sz[to];
    }
    mx=max(mx,sum-sz[x]);
    if(mx<mx_rt){
        mx_rt=mx;
        rt=x;
    }
}

void dfs1(int x,int fa,int from){
    cnt_col[c[x]]++;
    if(cnt_col[c[x]]==1)ans[from]+=sz[x];
    for(auto to:G[x])if(to!=fa&&!vis[to])dfs1(to,x,from);
    cnt_col[c[x]]--;
}
void dfs2(int x,int fa,int def_val,bool tag){
    cnt_col[c[x]]++;
    if(cnt_col[c[x]]==1){
        sum_col+=def_val;
        sum_path-=path[c[x]];
    }
    ans[x]+=sum_path;
    if(tag)ans[x]+=sum_col;
    for(auto to:G[x])if(to!=fa&&!vis[to])
        dfs2(to,x,def_val,tag);

    if(cnt_col[c[x]]==1){
        sum_col-=def_val;
        sum_path+=path[c[x]];
    }
    cnt_col[c[x]]--;
}
void calc_path(int x,int fa){
    cnt_col[c[x]]++;
    Q[++tot]=c[x];
    if(cnt_col[c[x]]==1){
        path[c[x]]+=sz[x];
        sum_path+=sz[x];
    }
    for(auto to:G[x])if(to!=fa&&!vis[to])calc_path(to,x);
    cnt_col[c[x]]--;
}
void dfz(int x,int sum=n){
    mx_rt=INF;
    get_rt(x,-1,sum);
    x=rt;
    get_rt(x,-1,sum);
    
    auto clear=[&](){
        rep(i,1,tot)path[Q[i]]=0;
        tot=0;
    };
    dfs1(x,-1,x);
    auto work=[&](bool tag){
        clear();
        sum_col=sum_path=0;
        assert(cnt_col[c[x]]==0);
        for(auto to:G[x])if(!vis[to]){
            cnt_col[c[x]]++;
            sum_col=sz[x]-sz[to];
            dfs2(to,x,sz[x]-sz[to],tag);
            calc_path(to,x);
            sum_col=0;
            cnt_col[c[x]]--;
        }
    };
    work(0);
    reverse(G[x].begin(),G[x].end());
    work(1);

    vis[x]=true;
    for(auto to:G[x])if(!vis[to])dfz(to,sz[to]);
}
int main(){
    fastio;
    cin>>n;
    rep(i,1,n){
        cin>>c[i];
        mx_c=max(mx_c,c[i]);
    }
    G=vector<vector<int>>(n+1);
    rep(i,1,n-1){
        int u,v;
        cin>>u>>v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfz(1);
    rep(i,1,n)cout<<ans[i]<<endl;
    return 0;
}
posted @ 2024-04-29 22:09  yoshinow2001  阅读(17)  评论(0编辑  收藏  举报