51nod1868

思路

容易推得答案是 \((n-1)!\sum_i\sum_j[i\neq j]g(i,j)\),其中 \(g(i,j)\) 表示 \(i\) 到 \(j\) 的路径上不同颜色的个数。

后面的求和式显然是点分治(俗称"淀粉质")的形式,直接用点分治计算即可。

复杂度可以做到 \(O(n\log n)\),也可以是 \(O(n\log^2 n)\)。为了方便,下面的代码实现的是 \(O(n\log^2 n)\) 的版本(每层分治用 `std::map` 统计颜色)。

// Problem: 彩色树
// Contest: Virtual Judge - 51Nod
// URL: https://vjudge.net/problem/51Nod-1868
// Memory Limit: 131 MB
// Time Limit: 1000 ms

#include <algorithm>
#include <map>
#include <stdio.h>
#include <vector>
// Short type aliases used throughout this file.
using llt = long long;
using uint = unsigned;
using ullt = unsigned long long;
using bol = bool;
using chr = char;
using voi = void;
using dbl = double;

/// Raise a to b when b is larger; returns whether a changed.
template<typename T>
bol _max(T &a, T b)
{
	if (!(a < b)) return false;
	a = b;
	return true;
}

/// Lower a to b when b is smaller; returns whether a changed.
template<typename T>
bol _min(T &a, T b)
{
	if (!(b < a)) return false;
	a = b;
	return true;
}

/// base^index mod mod, by recursive squaring (base is not pre-reduced).
template<typename T>
T power(T base, T index, T mod)
{
	if (index > 1)
	{
		const T half = power(base * base % mod, index >> 1, mod);
		const T odd = power(base, index & 1, mod);
		return (half * odd) % mod;
	}
	return (index ? base : 1) % mod;
}

/// Lowest set bit of n.
template<typename T>
T lowbit(T n)
{
	return n & -n;
}

/// Greatest common divisor (Euclid, iterative form).
template<typename T>
T gcd(T a, T b)
{
	while (b)
	{
		const T r = a % b;
		a = b;
		b = r;
	}
	return a;
}

/// Least common multiple; 0 when both arguments are 0.
template<typename T>
T lcm(T a, T b)
{
	if (a == 0 && b == 0) return T(0);
	return a / gcd(a, b) * b;
}

/// Extended Euclid: returns gcd(a,b) and sets x, y with a*x + b*y == gcd(a,b).
template<typename T>
T exgcd(T a, T b, T &x, T &y)
{
	if (!b)
	{
		y = 0;
		x = 1;
		return a;
	}
	const T g = exgcd(b, a % b, y, x);
	y -= a / b * x;
	return g;
}
/**
 * Integer modulo the compile-time constant p; the stored value always lies in [0, p).
 *
 * Overflow note: products are a plain unsigned 64-bit multiplication followed by % p,
 * so (p-1)*(p-1) must fit in 64 bits (roughly p < 2^32) for operator* to be exact.
 * The non-explicit ullt constructor lets a mod_ullt mix with plain integer operands.
 * Division, inv() and sqrt() use Fermat-style exponents, so they assume p is prime.
 */
template<const ullt p=998244353>
class mod_ullt// arithmetic overflows silently, so the modulus must not be too large
{
	private:
		ullt v;
		// Reduce w from [0, 2p) to [0, p).
		ullt chg(ullt w){return(w<p)?w:w-p;}
		// Like chg(), but wraps the result in a mod_ullt without another % p.
		mod_ullt _chg(ullt w){mod_ullt ans;ans.v=(w<p)?w:w-p;return ans;}
		// Print v in decimal, most significant digit first.
		voi _print(ullt v){if(v>=10)_print(v/10);putchar('0'+v%10);}
	public:
		mod_ullt():v(0){}
		mod_ullt(ullt v):v(v%p){}
		bol empty(){return!v;}
		ullt val(){return v;}
		bol friend operator<(mod_ullt a,mod_ullt b){return a.v<b.v;}
		bol friend operator>(mod_ullt a,mod_ullt b){return a.v>b.v;}
		bol friend operator<=(mod_ullt a,mod_ullt b){return a.v<=b.v;}
		bol friend operator>=(mod_ullt a,mod_ullt b){return a.v>=b.v;}
		bol friend operator==(mod_ullt a,mod_ullt b){return a.v==b.v;}
		bol friend operator!=(mod_ullt a,mod_ullt b){return a.v!=b.v;}
		mod_ullt friend operator+(mod_ullt a,mod_ullt b){return a._chg(a.v+b.v);}
		mod_ullt friend operator-(mod_ullt a,mod_ullt b){return a._chg(a.v+a.chg(p-b.v));}
		mod_ullt friend operator*(mod_ullt a,mod_ullt b){return a.v*b.v;}
		mod_ullt friend operator/(mod_ullt a,mod_ullt b){return b._power(p-2)*a.v;}
		mod_ullt friend operator-(mod_ullt a){return a._chg(p-a.v);}
		/**
		 * Modular square root by a Tonelli-Shanks style search (assumes p prime).
		 * Returns the smaller of the two roots, or 0 when v is not a quadratic residue.
		 */
		mod_ullt sqrt()
		{
			// In GF(2) every element is its own square root; the non-residue search
			// below has no non-residue to find there and would never terminate.
			if(p==2)return*this;
			if(power(v,(p-1)>>1,p)!=1)return 0;
			mod_ullt b=1;do b++;while(b._power((p-1)>>1)==1);// b: a quadratic non-residue
			ullt t=p-1,s=0,k=1;while(!(t&1))s++,t>>=1;// p-1 == t * 2^s with t odd
			mod_ullt x=_power((t+1)>>1),e=_power(t);
			while(k<s)
			{
				if(e._power(1llu<<(s-k-1))!=1)x*=b._power((1llu<<(k-1))*t);
				e=_power(p-2)*x*x,k++;
			}
			return _min(x,-x),x;
		}
		/// Multiplicative inverse v^(p-2) (valid for prime p and v != 0).
		mod_ullt inv(){return _power(p-2);}
		/// v^index by binary exponentiation.
		mod_ullt _power(ullt index){mod_ullt ans(1),w(v);while(index){if(index&1)ans*=w;w*=w,index>>=1;}return ans;}
		/**
		 * Read the next unsigned decimal integer from stdin, reduced modulo p.
		 * Non-digit characters before the number are skipped. If EOF is reached before
		 * any digit, the value becomes 0 (instead of looping forever). The value is
		 * reduced after every digit, so numbers longer than 64 bits are read correctly.
		 */
		voi read()
		{
			v=0;
			int c;// int, not char, so EOF stays distinguishable from a real character
			do c=getchar();while(c!=EOF&&(c>'9'||c<'0'));
			while(c>='0'&&c<='9')v=(v*10+static_cast<ullt>(c-'0'))%p,c=getchar();
		}
		voi print(){_print(v);}
		voi println(){_print(v),putchar('\n');}
		mod_ullt operator++(int){mod_ullt ans=*this;return v=chg(v+1),ans;}
	public:
		mod_ullt&operator+=(mod_ullt b){return*this=_chg(v+b.v);}
		mod_ullt&operator-=(mod_ullt b){return*this=_chg(v+chg(p-b.v));}
		mod_ullt&operator*=(mod_ullt b){return*this=v*b.v;}
		mod_ullt&operator/=(mod_ullt b){return*this=b._power(p-2)*v;}
		mod_ullt&operator++(){return v=chg(v+1),*this;}
};
// Modulus of the final answer.
const ullt Mod=1000000007;
// Residue type used for every count in the solution.
using modint=mod_ullt<Mod>;
// Way[v]: neighbours of vertex v (vertices are 0-indexed).
std::vector<uint>Way[100005];
// Del[v]: set once v has served as a centroid, which removes it from later components.
bol Del[100005];
/// Number of live vertices in the subtree of p when the tree is rooted away from f.
/// Vertices marked in Del[] are not entered; the count includes p itself.
uint gotsiz(uint p,uint f)
{
	uint total=1;
	for(const uint nxt:Way[p])
	{
		if(nxt==f||Del[nxt])continue;
		total+=gotsiz(nxt,p);
	}
	return total;
}
// med: best centroid found so far; most: its largest remaining part.
// Callers reset `most` to the component size before calling gotmed().
uint med,most;
/// Returns the live-subtree size of p (rooted away from f) inside a component of n vertices.
/// Side effect: whenever the largest part left by removing p is below `most`,
/// stores that size in `most` and p in `med`.
uint gotmed(uint p,uint f,uint n)
{
	uint sz=1,heaviest=0;
	for(const uint nxt:Way[p])
	{
		if(nxt==f||Del[nxt])continue;
		const uint part=gotmed(nxt,p,n);
		if(heaviest<part)heaviest=part;
		sz+=part;
	}
	const uint above=n-sz;// part on the f side of p
	if(heaviest<above)heaviest=above;
	if(heaviest<most)
	{
		most=heaviest;
		med=p;
	}
	return sz;
}
// A[v]: colour of vertex v.
uint A[100005];
// Per-centroid scratch maps keyed by colour:
// M holds the counts for one child subtree; S1 and S2 accumulate the sum and the
// sum of squares of those counts over all child subtrees.
std::map<uint,modint>M;
std::map<uint,modint>S1;
std::map<uint,modint>S2;
// Init[c]: colour c already occurs on the path from the current centroid.
bol Init[100005];
/**
 * Walk the live subtree of p (rooted away from f) and return its size.
 * Whenever p's colour is not yet marked in Init[] (i.e. it is the first vertex with that
 * colour on the path from the current centroid), adds p's subtree size to M[colour].
 * Afterwards M[c] equals the number of vertices in the subtree whose path from the
 * centroid contains colour c. Init[] is restored before returning.
 */
uint gotd(uint p,uint f)
{
	const uint colour=A[p];
	const bol fresh=!Init[colour];
	if(fresh)Init[colour]=true;
	uint total=1;
	for(const uint nxt:Way[p])
		if(nxt!=f&&!Del[nxt])total+=gotd(nxt,p);
	if(fresh)
	{
		M[colour]+=total;
		Init[colour]=false;
	}
	return total;
}
/**
 * One level of centroid decomposition (树上点分治).
 *
 * The centroid is taken from the global `med`, set by the gotmed() call that precedes
 * each dfs(). `siz` is the number of live vertices in its component. Returns (mod Mod)
 * the sum, over ordered pairs x != y of the component, of the number of distinct
 * colours on the x-y path. Pairs separated by the centroid are counted here; the rest
 * are counted by the recursive calls on each child component.
 * Side effects: marks the centroid in Del[] and reuses the global scratch maps M, S1, S2.
 */
modint dfs(uint siz)
{
	const uint root=med;
	modint crossCount(0);// 2 * sum over subtrees s and colours c of M_s[c]*(siz-|s|)
	modint bothCount(0); // sum over c of ordered pairs in different subtrees that both contain c
	modint sizeSum(0);   // sum of child-subtree sizes
	modint sizeSqSum(0); // sum of squared child-subtree sizes

	// Hide the centroid's own colour from gotd(); it is accounted for separately below.
	Init[A[root]]=true;
	S1.clear();
	S2.clear();
	for(const uint child:Way[root])
	{
		if(Del[child])continue;
		M.clear();
		const uint part=gotd(child,root);
		sizeSum+=part;
		sizeSqSum+=modint(part)*part;
		for(const auto&entry:M)
		{
			S1[entry.first]+=entry.second;
			S2[entry.first]+=entry.second*entry.second;
			// Pairs (x in this subtree whose path has c, y outside it), in both orders.
			crossCount+=entry.second*(siz-part)*2;
		}
	}
	// (sum_s M_s[c])^2 - sum_s M_s[c]^2 counts pairs from distinct subtrees both seeing c.
	for(const auto&entry:S1)bothCount+=entry.second*entry.second-S2[entry.first];
	Init[A[root]]=false;
	Del[root]=true;

	// Inclusion-exclusion for the other colours. Add the centroid's colour once per
	// separated ordered pair: sizeSum^2 - sizeSqSum + 2*(siz-1).
	modint result=crossCount-bothCount+sizeSum*sizeSum-sizeSqSum+(siz-1)*2;

	// Recurse into each remaining child component around its own centroid.
	for(const uint child:Way[root])
	{
		if(Del[child])continue;
		const uint sub=gotsiz(child,root);
		most=sub;
		gotmed(child,root,sub);
		result+=dfs(sub);
	}
	return result;
}
int main()
{
	uint n,u,v;scanf("%u",&n);
	for(uint i=0;i<n;i++)scanf("%u",A+i);
	for(uint i=1;i<n;i++)scanf("%u%u",&u,&v),Way[--u].push_back(--v),Way[v].push_back(u);
	gotmed(0,0,most=n);
	modint ans=dfs(n);
	// ans.println();
	for(uint i=2;i<n;i++)ans*=i;
	ans.println();
	return 0;
}
posted @ 2022-02-14 15:12  myee  阅读(45)  评论(0)  编辑  收藏  举报