BZOJ2152聪聪可可
bzoj传送门
luogu传送门
这题算是很sb的点分治了,最近在点分治复习,写了练练手,对于这个题只需要对统计0,1,2出现的次数就好了吧,然后发现答案不对,也就是每个点对需要算两遍嘛,0也算,所以答案再加个n就好啦
(用树状数组的\(O(nlog^2)\))代码:
#include<cstdio>
#define max(a,b) (a>b?a:b)
int ans,rt,sum,w,size[20001],f[4],son[20001],dis[20001],n,pre[40001],nxt[40001],h[20001],v[40001],cnt;
bool vis[20001];
void add(int x,int y,int z)
{
pre[++cnt]=y;nxt[cnt]=h[x];h[x]=cnt;v[cnt]=z;
pre[++cnt]=x;nxt[cnt]=h[y];h[y]=cnt;v[cnt]=z;
}
void getroot(int x,int fa)
{
size[x]=1;son[x]=0;
for(int i=h[x];i;i=nxt[i])
if(pre[i]!=fa&&!vis[pre[i]])
{
getroot(pre[i],x);
size[x]+=size[pre[i]];
son[x]=max(size[pre[i]],son[x]);
}
son[x]=max(son[x],sum-size[x]);
if(son[x]<son[rt])rt=x;
}
void getans(int x,int fa)
{
if(dis[x]!=0)ans+=f[3-dis[x]];
else ans+=f[0]+1;
for(int i=h[x];i;i=nxt[i])
if(!vis[pre[i]]&&pre[i]!=fa)
{
dis[pre[i]]=(dis[x]+v[i])%3;
getans(pre[i],x);
}
}
void getdeep(int x,int fa)
{
f[dis[x]]++;
for(int i=h[x];i;i=nxt[i])if(!vis[pre[i]]&&pre[i]!=fa)getdeep(pre[i],x);
}
void dfs(int x)
{
vis[x]=1;
for(int i=h[x];i;i=nxt[i])
if(!vis[pre[i]])
{
dis[pre[i]]=v[i]%3;
getans(pre[i],0),getdeep(pre[i],0);
}
f[0]=f[1]=f[2]=0;
for(int i=h[x];i;i=nxt[i])
if(!vis[pre[i]])
{
rt=0;sum=size[pre[i]];getroot(pre[i],0);
dfs(rt);
}
}
int gcd(int a,int b)
{
if(!b)return a;
return gcd(b,a%b);
}
int main()
{
scanf("%d",&n);
for(int i=1,x,y,z;i<n;i++)scanf("%d%d%d",&x,&y,&z),add(x,y,z);
sum=son[0]=n;getroot(1,0);dfs(rt);
ans=ans*2+n;w=n*n;
printf("%d/%d\n",ans/gcd(ans,w),w/gcd(ans,w));
}
(当年写的two pointer(也称尺取法)):
#include<cstdio>
#include<algorithm>
using namespace std;
int now,st[20001],n,cnt,sum,ans,pre[40001],rt,nxt[40001],v[40001],h[20001],son[20001],size[20001],dis[20001],t[3];
bool vis[20001];
void read(int &x)
{
int f=1;x=0;char s=getchar();
while(s<'0'||s>'9'){if(s=='-')f=-1;s=getchar();}
while(s>='0'&&s<='9'){x=x*10+s-'0';s=getchar();}
x*=f;
}
void add(int x,int y,int z)
{
pre[++cnt]=y;nxt[cnt]=h[x];h[x]=cnt,v[cnt]=z;
pre[++cnt]=x;nxt[cnt]=h[y];h[y]=cnt,v[cnt]=z;
}
void getroot(int x,int fa)
{
size[x]=1;son[x]=0;
for(int i=h[x];i;i=nxt[i])
if(!vis[pre[i]]&&pre[i]!=fa)
{
getroot(pre[i],x);
size[x]+=size[pre[i]];
son[x]=max(son[x],size[pre[i]]);
}
son[x]=max(son[x],sum-size[x]);
if(son[x]<son[rt])rt=x;
}
void get(int x,int fa)
{
t[dis[x]%3]++;
for(int i=h[x];i;i=nxt[i])
if(!vis[pre[i]]&&pre[i]!=fa)
{
dis[pre[i]]=dis[x]+v[i];
get(pre[i],x);
}
}
int getans(int x)
{
now=0;
get(x,0);
return t[0]*t[0]+t[1]*t[2]*2;
}
void dfs(int x)
{
vis[x]=1;dis[x]=0;
ans+=getans(x);
t[0]=t[1]=t[2]=0;
for(int i=h[x];i;i=nxt[i])
if(!vis[pre[i]])
{
dis[pre[i]]=v[i];
ans-=getans(pre[i]);
t[0]=t[1]=t[2]=0;
rt=0;sum=size[pre[i]];
getroot(pre[i],0);
dfs(rt);
}
}
int gcd(int a,int b){return !b?a:gcd(b,a%b);}
int main()
{
read(n);
for(int i=1,x,y,z;i<n;i++)
read(x),read(y),read(z),add(x,y,z);
sum=n;son[0]=n;
getroot(1,0);
dfs(rt);
int mod=gcd(ans,n*n);
printf("%d/%d",ans/mod,n*n/mod);
}