【CF833D】Red-Black Cobweb(点分治)
【CF833D】Red-Black Cobweb(点分治)
题面
有一棵树,每条边有一个颜色(黑白)和一个权值,定义一条路径是好的,当且仅当这条路径上所有边的黑白颜色个数a,b满足2min(a,b)>=max(a,b),一条路径的权值为路径上所有边的权值的乘积,求所有好的路径的权值乘积.
\(n<=10^5\)
题解
首先看到求所有路径相关的内容,不难想到点分治。
两个限制可以转化为需要同时满足:\(2a\ge b,2b\ge a\)。
对于两条路径\(a1,b1/a2,b2\)考虑如何合并。
需要满足的两个条件就变成了\(2(a1+a2)\ge b1+b2\)以及\(2(b1+b2)\ge a1+a2\)
再稍微拆开看看就变成了\(2a1-b1\ge b2-2a2\),另一个类似。
这里怎么计算总的方案数,那么就用总数减去不合法的,如果不合法显然只会有一个不等式不合法(因为另外一个不等式是由最大值的两倍大于较小值得到的,它无论如何都会是对的),那么只需要统计有一个不合法的所有链就好了。
#include<iostream>
#include<cstdio>
using namespace std;
#define MAX 100100
#define MOD 1000000007
inline int read()
{
int x=0;bool t=false;char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=true,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return t?-x:x;
}
int fpow(int a,int b){int s=1;while(b){if(b&1)s=1ll*s*a%MOD;a=1ll*a*a%MOD;b>>=1;}return s;}
struct Line{int v,next,w,c;}e[MAX<<1];
int h[MAX],cnt=1;
inline void Add(int u,int v,int w,int c){e[cnt]=(Line){v,h[u],w,c};h[u]=cnt++;}
int Size,mx,rt,size[MAX];bool vis[MAX];
int n,N,ans,ans1=1,ans2=1;
int lb(int x){return x&(-x);}
struct BIT
{
int c1[MAX<<3],c2[MAX<<3];
void pre(){for(int i=1;i<=N;++i)c1[i]=1,c2[i]=0;}
void Modify(int x,int w){while(x<=N)c1[x]=1ll*c1[x]*w%MOD,c2[x]+=1,x+=lb(x);}
void Clear(int x){while(x<=N)c1[x]=1,c2[x]=0,x+=lb(x);}
int Querys(int x){int s=1;while(x)s=1ll*s*c1[x]%MOD,x-=lb(x);return s;}
int Queryt(int x){int s=0;while(x)s+=c2[x],x-=lb(x);return s;}
int Querys(int l,int r){return 1ll*Querys(r)*fpow(Querys(l-1),MOD-2)%MOD;}
int Queryt(int l,int r){return Queryt(r)-Queryt(l-1);}
}c1,c2;
void Getroot(int u,int ff)
{
int ret=0;size[u]=1;
for(int i=h[u];i;i=e[i].next)
{
int v=e[i].v;if(v==ff||vis[v])continue;
Getroot(v,u);size[u]+=size[v];ret=max(ret,size[v]);
}
ret=max(ret,Size-size[u]);
if(ret<mx)mx=ret,rt=u;
}
struct Pair{int a,b,w;}S[MAX],T[MAX];
int top,sum,W,SW,py;
void dfs(int u,int ff,int a,int b,int w)
{
T[++top]=(Pair){a,b,w};W=1ll*w*W%MOD;
for(int i=h[u];i;i=e[i].next)
if(e[i].v!=ff&&!vis[e[i].v])
dfs(e[i].v,u,a+(e[i].c^1),b+e[i].c,1ll*w*e[i].w%MOD);
}
void Divide(int u)
{
vis[u]=true;sum=0;SW=1;S[++sum]=(Pair){0,0,1};
c1.Modify(py,1);c2.Modify(py,1);
for(int i=h[u];i;i=e[i].next)
{
int v=e[i].v;if(vis[v])continue;
top=0;W=1;dfs(e[i].v,u,e[i].c^1,e[i].c,e[i].w);
ans1=1ll*ans1*fpow(W,sum)%MOD*fpow(SW,top)%MOD;
SW=1ll*SW*W%MOD;
for(int j=1;j<=top;++j)
{
int A=T[j].b-2*T[j].a-1+py,B=T[j].a-2*T[j].b-1+py;
ans2=1ll*ans2*c1.Querys(2,A)%MOD*fpow(T[j].w,c1.Queryt(2,A))%MOD;
ans2=1ll*ans2*c2.Querys(2,B)%MOD*fpow(T[j].w,c2.Queryt(2,B))%MOD;
}
for(int j=1;j<=top;++j)
{
int A=2*T[j].a-T[j].b+py;c1.Modify(A,T[j].w);
int B=2*T[j].b-T[j].a+py;c2.Modify(B,T[j].w);
S[++sum]=T[j];
}
}
for(int j=1;j<=sum;++j)
{
int A=2*S[j].a-S[j].b+py;c1.Clear(A);
int B=2*S[j].b-S[j].a+py;c2.Clear(B);
}
for(int i=h[u];i;i=e[i].next)
{
int v=e[i].v;if(vis[v])continue;
Size=mx=size[v];Getroot(v,u);
Divide(rt);
}
}
int main()
{
n=read();py=n+n+2;N=5*n;c1.pre();c2.pre();
for(int i=1,u,v,w,c;i<n;++i)
u=read(),v=read(),w=read(),c=read(),Add(u,v,w,c),Add(v,u,w,c);
Size=mx=n;Getroot(1,0);Divide(rt);
ans=1ll*ans1*fpow(ans2,MOD-2)%MOD;
printf("%d\n",ans);
return 0;
}