P2634 [国家集训队]聪聪可可
点分治入门题
首先可以直接枚举所有两点的lca强行dp
设 $f [ x ] [ 0/1/2 ]$ 表示节点 $x$ 在模3意义下,$x$ 的子树所有节点到 $x$ 的距离为 $0/1/2$ 时的方案数
初始 $f [ x ] [ 0 ] =1$ (本身到自己有一种方案)
转移就枚举所有儿子 $v$ ,设 $x$ 到 $v$ 的距离为 $w$,那么转移显然为:
$ f [ x ] [ (w+0)\%3 ] += f [ v ] [ 0 ] $
$ f [ x ] [ (w+1)\%3 ] += f [ v ] [ 1 ] $
$ f [ x ] [ (w+2)\%3 ] += f [ v ] [ 2 ] $
统计答案也十分显然,对 $w$ 分类讨论一下就好了:
inline void work(int x)//注意函数名不是"dfs",x就是我们枚举的lca { f[x][0]=1; f[x][1]=f[x][2]=0; for(int i=fir[x];i;i=from[i]) { int &v=to[i],&w=val[i]; if(vis[v]) continue; dfs(v,x);//dfs求出儿子的f if(w==0) ans+=f[x][0]*f[v][0]+f[x][1]*f[v][2]+f[x][2]*f[v][1]; if(w==1) ans+=f[x][0]*f[v][2]+f[x][1]*f[v][1]+f[x][2]*f[v][0]; if(w==2) ans+=f[x][0]*f[v][1]+f[x][1]*f[v][0]+f[x][2]*f[v][2];
//注意先统计ans再转移f f[x][w]+=f[v][0]; f[x][fk(w+1)]+=f[v][1]; f[x][fk(w+2)]+=f[v][2]; } }
但是最坏情况会被卡到 $O(n^2)$
所以上点分治,每次找重心作lca,这样每次子树大小至少减半
枚举lca复杂度$O(n)$,搞dp因为子树大小每次减半所以复杂度约为 $O(log_n)$
总复杂度 $O(nlog_n)$
注意long long
#include<iostream> #include<cstdio> #include<algorithm> #include<cmath> #include<cstring> #include<vector> using namespace std; typedef long long ll; inline int read() { int x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*f; } const int N=5e5+7,INF=1e9+7; int fir[N],from[N<<1],to[N<<1],val[N<<1],cntt; inline void add(int &a,int &b,int &c) { from[++cntt]=fir[a]; fir[a]=cntt; to[cntt]=b; val[cntt]=c; } inline int fk(int x) { return x>=3 ? x-3 : x; } int n,rt,tot; ll ans,f[N][3]; int sz[N],mx[N]; bool vis[N]; void find_rt(int x,int fa)//找重心 { mx[x]=0; sz[x]=1; for(int i=fir[x];i;i=from[i]) { int &v=to[i]; if(vis[v]||v==fa) continue; find_rt(v,x); sz[x]+=sz[v]; mx[x]=max(mx[x],sz[v]); } mx[x]=max(mx[x],tot-sz[x]); if(mx[x]<mx[rt]) rt=x; } void dfs(int x,int fa)//dfs先求出子树的f { f[x][0]=1; f[x][1]=f[x][2]=0; for(int i=fir[x];i;i=from[i]) { int &v=to[i],&w=val[i]; if(vis[v]||v==fa) continue; dfs(v,x); f[x][w]+=f[v][0]; f[x][fk(w+1)]+=f[v][1]; f[x][fk(w+2)]+=f[v][2]; } } inline void work(int x)//统计答案 { f[x][0]=1; f[x][1]=f[x][2]=0; for(int i=fir[x];i;i=from[i]) { int &v=to[i],&w=val[i]; if(vis[v]) continue; dfs(v,x); if(w==0) ans+=f[x][0]*f[v][0]+f[x][1]*f[v][2]+f[x][2]*f[v][1]; if(w==1) ans+=f[x][0]*f[v][2]+f[x][1]*f[v][1]+f[x][2]*f[v][0]; if(w==2) ans+=f[x][0]*f[v][1]+f[x][1]*f[v][0]+f[x][2]*f[v][2]; f[x][w]+=f[v][0]; f[x][fk(w+1)]+=f[v][1]; f[x][fk(w+2)]+=f[v][2]; } } void solve(int x)//点分治 { vis[x]=1; work(x); for(int i=fir[x];i;i=from[i]) { int &v=to[i]; if(vis[v]) continue; tot=sz[v]; rt=0; find_rt(v,0); solve(rt); } } inline ll gcd(ll a,ll b) { return b ? gcd(b,a%b) : a; } int main() { //freopen("data.in","r",stdin); //freopen("data.out","w",stdout); int a,b,c; n=read(); for(int i=1;i<n;i++) { a=read(),b=read(),c=read()%3; add(a,b,c); add(b,a,c); } tot=n; mx[rt]=INF; find_rt(1,0); solve(rt); ans=ans*2+n; ll d=gcd(ans,1ll*n*n); printf("%lld/%lld",ans/d,1ll*n*n/d); return 0; }
其实此题不用点分治
可以直接树形dp,转移同上...代码又短又好写
#include<iostream> #include<cstdio> #include<algorithm> #include<cmath> #include<cstring> #include<vector> using namespace std; typedef long long ll; inline int read() { int x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*f; } const int N=5e5+7; inline int fk(int x) { return x>=3 ? x-3 : x; } int fir[N],from[N<<1],to[N<<1],val[N<<1],cntt; inline void add(int &a,int &b,int &c) { from[++cntt]=fir[a]; fir[a]=cntt; to[cntt]=b; val[cntt]=c; } ll ans,f[N][3]; void dfs(int x,int fa) { f[x][0]=1; for(int i=fir[x];i;i=from[i]) { int &v=to[i],&w=val[i]; if(v==fa) continue; dfs(v,x); if(w==0) ans+=f[x][0]*f[v][0]+f[x][1]*f[v][2]+f[x][2]*f[v][1]; if(w==1) ans+=f[x][0]*f[v][2]+f[x][1]*f[v][1]+f[x][2]*f[v][0]; if(w==2) ans+=f[x][0]*f[v][1]+f[x][1]*f[v][0]+f[x][2]*f[v][2]; f[x][w]+=f[v][0]; f[x][fk(w+1)]+=f[v][1]; f[x][fk(w+2)]+=f[v][2]; } } int n; ll gcd(ll a,ll b) { return b ? gcd(b,a%b) : a; } int main() { int a,b,c; n=read(); for(int i=1;i<n;i++) { a=read(),b=read(),c=read()%3; add(a,b,c); add(b,a,c); } dfs(1,1); ans<<=1; ans+=n; ll d=gcd(ans,1ll*n*n); printf("%lld/%lld",ans/d,1ll*n*n/d); return 0; }