[Luogu2664]树上游戏

题面戳我

sol

点分。我们面临的最主要一个问题,就是如何在\(O(n)\)的时间内算出所有LCA为根的点对的贡献,还要分别累加到它们自己的答案中去。
\(num_i\):每一种颜色的数量。你可以认为这就是一个桶。从根到叶子遍历,相当于每次都只维护一条链上的颜色情况。以便于得到\(tot_i\)
\(fst_i\)\(i\)号点上的颜色是不是从根往下第一次出现。如果是,就会加到\(col_i\)里面取算贡献
\(col_i\):每一种颜色的贡献
\(tot_i\):每个点到根的路径上有多少种颜色

鉴于点对之间计算答案不太现实,我们考虑计算每种颜色对答案的贡献。

如果一个节点\(i\),它的颜色是从根往下第一次出现的(即\(fst_i=1\)),那么这种颜色就一定会给其他子树中的每个节点贡献\(sz_i\)的答案。这个答案就累加在\(col_i\)中。然后在对这个\(col_i\)求和,就是总贡献。

一个节点\(i\)的答案的初始值应该是\(tot_i*(sz_u-sz_v)\)(就是总\(sz\)除去自己所在的子树外的部分),然后还要加上一些\(col_i\)的值,但是要保证加上的\(col_i\)不能是自己到根已经有过的颜色(不然就重复计算了)。

多做几遍dfs,维护以上提到的东西就行了。
复杂度是\(O(nlog_2n)\)的,带点小常数

code

我之前干了件非常傻逼的事情
我没写\(clear\)\(solve\)函数里面写了个memset
然后复杂度变成了严格\(O(n^2)\)
然后就只有暴力分。。。

#include<cstdio>
#include<algorithm>
using namespace std;
#define ll long long
const int N = 100005;
int gi()
{
    int x=0,w=1;char ch=getchar();
    while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
    if (ch=='-') w=0,ch=getchar();
    while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    return w?x:-x;
}
struct edge{int to,next;}a[N<<1];
int n,c[N],head[N],cnt,sz[N],w[N],vis[N],sum,root;
int num[N],fst[N],col[N],tot[N];
ll sigma,ans[N];
void getroot(int u,int f)
{
    sz[u]=1;w[u]=0;
    for (int e=head[u];e;e=a[e].next)
    {
        int v=a[e].to;if (v==f||vis[v]) continue;
        getroot(v,u);
        sz[u]+=sz[v];w[u]=max(w[u],sz[v]);
    }
    w[u]=max(w[u],sum-sz[u]);
    if (w[u]<w[root]) root=u;
}
void dfs(int u,int f,ll &Ans)
{
    sz[u]=1;num[c[u]]++;
    if (num[c[u]]==1) fst[u]=1,cnt++;else fst[u]=0;
    tot[u]=cnt;Ans+=tot[u];
    for (int e=head[u];e;e=a[e].next)
    {
        int v=a[e].to;if (v==f||vis[v]) continue;
        dfs(v,u,Ans);sz[u]+=sz[v];
    }
    if (fst[u]) col[c[u]]+=sz[u],sigma+=sz[u],cnt--;
    num[c[u]]--;
}
void change(int u,int f,int b)
{
    if (fst[u]) col[c[u]]+=b*sz[u],sigma+=b*sz[u];
    for (int e=head[u];e;e=a[e].next)
    {
        int v=a[e].to;if (v==f||vis[v]) continue;
        change(v,u,b);
    }
}
void calc(int u,int f,int k)
{
    if (fst[u]) sigma-=col[c[u]];
    ans[u]+=1ll*tot[u]*k+sigma;
    for (int e=head[u];e;e=a[e].next)
    {
        int v=a[e].to;if (v==f||vis[v]) continue;
        calc(v,u,k);
    }
    if (fst[u]) sigma+=col[c[u]];
}
void clear(int u,int f)
{
	col[c[u]]=0;
	for (int e=head[u];e;e=a[e].next)
	{
		int v=a[e].to;if (v==f||vis[v]) continue;
		clear(v,u);
	}
}
void solve(int u)
{
    vis[u]=1;
    dfs(u,0,ans[u]);
    col[c[u]]-=sz[u];sigma-=sz[u];
    for (int e=head[u];e;e=a[e].next)
    {
        int v=a[e].to;if (vis[v]) continue;
        change(v,0,-1);
        calc(v,0,sz[u]-sz[v]);
        change(v,0,1);
    }
    clear(u,0);sigma=0;
    for (int e=head[u];e;e=a[e].next)
    {
        int v=a[e].to;if (vis[v]) continue;
        sum=sz[v];
        root=0;
        getroot(v,0);
        solve(root);
    }
}
int main()
{
    n=gi();
    for (int i=1;i<=n;i++) c[i]=gi();
    for (int i=1;i<n;i++)
    {
        int u=gi(),v=gi();
        a[++cnt]=(edge){v,head[u]};head[u]=cnt;
        a[++cnt]=(edge){u,head[v]};head[v]=cnt;
    }
    sum=w[0]=n;cnt=0;
    getroot(1,0);
    solve(root);
    for (int i=1;i<=n;i++) printf("%lld\n",ans[i]);
    return 0;
}
posted @ 2018-01-18 17:10  租酥雨  阅读(581)  评论(1编辑  收藏  举报