洛谷P2664 树上游戏 [点分治]

传送门


思路

首先,我们要脑洞大开想到点分治。

注意到每个点都要输出对应的\(ans\),所以想到在点分树上跳,得到答案。

一个点答案的贡献要分为好几个部分。

第一个

这是最容易统计的部分。你需要得到点分树上自己的子树对自己的贡献。

显然一个dfs即可搞定。

第二个

假设当前要求的点是\(x\),枚举它的祖先为\(y\)

我们需要得到\(y\rightarrow x\)这条路径上的贡献乘上\(size_y-size_x\),其中\(size_x\)表示x所在子树的大小,\(size_y\)表示\(y\)在点分树上子树大小。

显然这个可以在建点分树时用一个dfs搞定。

第三个

还是上面的\(x,y\),我们需要统计\(y\)子树中除去\(x\)子树的点对\(x\)的贡献。

考虑每种颜色分开考虑。

对于每个颜色,统计会出现它的路径,可以一个dfs搞定。

然后把\(x\)所在的子树的贡献删去之后就可以了。

第四个

发现第二部分和第三部分会算重,要把重复的去掉。

对于\(y\rightarrow x\)路径上出现过的颜色,要把它出现过的路径数删去。

也是建树时一个dfs可以搞定的。

最后

你发现上面这些全是建树时就可以算出来的,于是你根本不需要建树,直接搞就好了。

细节较多,祝你好运。

不过把我的缺省源删去后其实还挺短的,只有大概110行。


代码

#include<bits/stdc++.h>
clock_t t=clock();
namespace my_std{
    using namespace std;
    #define pii pair<int,int>
    #define fir first
    #define sec second
    #define MP make_pair
    #define rep(i,x,y) for (int i=(x);i<=(y);i++)
    #define drep(i,x,y) for (int i=(x);i>=(y);i--)
    #define go(x) for (int i=head[x];i;i=edge[i].nxt)
    #define templ template<typename T>
    #define sz 101010
    typedef long long ll;
    typedef double db;
    mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
    templ inline T rnd(T l,T r) {return uniform_int_distribution<T>(l,r)(rng);}
    templ inline bool chkmax(T &x,T y){return x<y?x=y,1:0;}
    templ inline bool chkmin(T &x,T y){return x>y?x=y,1:0;}
    templ inline void read(T& t)
    {
        t=0;char f=0,ch=getchar();double d=0.1;
        while(ch>'9'||ch<'0') f|=(ch=='-'),ch=getchar();
        while(ch<='9'&&ch>='0') t=t*10+ch-48,ch=getchar();
        if(ch=='.'){ch=getchar();while(ch<='9'&&ch>='0') t+=d*(ch^48),d*=0.1,ch=getchar();}
        t=(f?-t:t);
    }
    template<typename T,typename... Args>inline void read(T& t,Args&... args){read(t); read(args...);}
    char sr[1<<21],z[20];int C=-1,Z=0;
    inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
    inline void print(register int x)
    {
    	if(C>1<<20)Ot();if(x<0)sr[++C]='-',x=-x;
    	while(z[++Z]=x%10+48,x/=10);
    	while(sr[++C]=z[Z],--Z);sr[++C]='\n';
    }
    void file()
    {
        #ifndef ONLINE_JUDGE
        freopen("a.in","r",stdin);
        #endif
    }
    inline void chktime()
    {
        #ifndef ONLINE_JUDGE
        cout<<(clock()-t)/1000.0<<'\n';
        #endif
    }
    #ifdef mod
    ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x%mod) if (y&1) ret=ret*x%mod;return ret;}
    ll inv(ll x){return ksm(x,mod-2);}
    #else
    ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x) if (y&1) ret=ret*x;return ret;}
    #endif
//	inline ll mul(ll a,ll b){ll d=(ll)(a*(double)b/mod+0.5);ll ret=a*b-d*mod;if (ret<0) ret+=mod;return ret;}
}
using namespace my_std;

int n;
int col[sz];
struct hh{int t,nxt;}edge[sz<<1];
int head[sz],ecnt;
void make_edge(int f,int t)
{
	edge[++ecnt]=(hh){t,head[f]};
	head[f]=ecnt;
	edge[++ecnt]=(hh){f,head[t]};
	head[t]=ecnt;
}

ll ans[sz];
bool vis[sz];
int size[sz],mn,root,sum;
#define v edge[i].t
void calcsize(int x,int fa)
{
    size[x]=1;
    go(x) 
		if (v!=fa&&!vis[v])
        	calcsize(v,x),size[x]+=size[v];
}
void findroot(int x,int fa)
{
    int S=-1;
    go(x) 
		if (v!=fa&&!vis[v])
        	findroot(v,x),chkmax(S,size[v]);
    chkmax(S,sum-size[x]);
    if (chkmin(mn,S)) root=x;
}
void findroot(int x){calcsize(x,0);mn=1e9;sum=size[x];findroot(x,0);}
int cnt[sz]; // 每种颜色出现次数
int S1; // 出现过的颜色数量 
void dfs1(int x,int u,int fa) // 统计u往下的路径 
{
	++cnt[col[x]];if (cnt[col[x]]==1) ++S1;
	ans[u]+=S1;
	go(x) if (!vis[v]&&v!=fa) dfs1(v,u,x);
	--cnt[col[x]];if (cnt[col[x]]==0) --S1;
}
ll S[sz],SS; // 每种颜色出现过的路径数,\sum_i S[i] 
void dfs2(int x,int fa) // 统计每种颜色出现过的路径数
{
	++cnt[col[x]];if (cnt[col[x]]==1) S[col[x]]+=size[x],SS+=size[x];
	go(x) if (!vis[v]&&v!=fa) dfs2(v,x);
	--cnt[col[x]];
}
void clr(int x,int fa)
{
	++cnt[col[x]];if (cnt[col[x]]==1) S[col[x]]-=size[x],SS-=size[x];
	go(x) if (!vis[v]&&v!=fa) clr(v,x);
	--cnt[col[x]];
}
ll S2; // 已经重复需要被减去的路径数 
void dfs3(int x,int fa,int Size)
{
	++cnt[col[x]];if (cnt[col[x]]==1) S2+=S[col[x]],++S1;
	ans[x]+=SS+S1*Size-S2;
	go(x) if (!vis[v]&&v!=fa) dfs3(v,x,Size);
	--cnt[col[x]];if (cnt[col[x]]==0) S2-=S[col[x]],--S1;
}
void CLR(int x,int fa)
{
	S[col[x]]=0;
	go(x) if (v!=fa&&!vis[v]) CLR(v,x);
}
void calc(int x)
{
	calcsize(x,0);
	dfs1(x,x,0);
	dfs2(x,0);
	go(x) if (!vis[v])
	{
		cnt[col[x]]=1;SS-=size[v];S[col[x]]-=size[v];clr(v,x);cnt[col[x]]=0;
		dfs3(v,x,size[x]-size[v]);
		cnt[col[x]]=1;SS+=size[v];S[col[x]]+=size[v];dfs2(v,x);cnt[col[x]]=0;
	}
	CLR(x,0);SS=0;
}
void solve(int x)
{
	vis[x]=1;
	calc(x);
	go(x) 
		if (!vis[v])
			findroot(v),solve(root);
}
#undef v

int main()
{
	file();
	read(n);
	int x,y;
	rep(i,1,n) read(col[i]);
	rep(i,1,n-1) read(x,y),make_edge(x,y);
	findroot(1);solve(root);
	rep(i,1,n) printf("%lld\n",ans[i]);
	return 0;
}
posted @ 2019-03-06 13:52  p_b_p_b  阅读(230)  评论(0编辑  收藏  举报