[BZOJ3162]独钓寒江雪
description
你要给一个树上的每个点黑白染色,要求白点不相邻。求本质不同的染色方案数。
两种染色方案本质相同当且仅当对树重新标号后对应节点的颜色相同。
\(n\le 5\times10^5\)
sol
首先考虑没有本质相同那个限制怎么做。
直接设\(f_{i,0/1}\)表示\(i\)点染成黑色/白色时子树内的方案数。
转移很简单:\(f_{i,0}=\prod_j (f_{j,0}+f_{j,1}),f_{i,1}=\prod_j f_{j,0}\)。
先在问题在于本质不同。那么如果重新标号之后同构的话方案数就会多算。
考虑重新标号后重心不会变,于是以重心为根处理子树。如果有两个重心就新建一个点连接这两个点,在输出方案的时候讨论一下即可。
在\(dp\)的时候,对于一个点\(i\)的若干个同构的子树,应该要一起计算贡献,设这种子树染色的方案数是\(x\)(就是\(dp\)值),这样的子树一共有\(k\)棵,那么这就是一个可重组合,方案数为\(\binom{x+k-1}{k}\)。
虽然\(x\)可能会很大,但是显然\(k\)是\(O(n)\)的,所以组合数暴力计算即可。
树\(Hash\)要写对啊qwq。
code
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int gi(){
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
#define ull unsigned long long
const int N = 5e5+5;
const int mod = 1e9+7;
const ull base1 = 20020415;
const ull base2 = 20011118;
int n,inv[N],to[N<<1],nxt[N<<1],head[N],cnt,sz[N],w[N],root,rt1,rt2,fg,f[2][N],tmp[N];
ull hsh[N];
void link(int u,int v){
to[++cnt]=v;nxt[cnt]=head[u];head[u]=cnt;
}
void getroot(int u,int fa){
sz[u]=1;w[u]=0;
for (int e=head[u];e;e=nxt[e])
if (to[e]!=fa){
getroot(to[e],u);sz[u]+=sz[to[e]];
w[u]=max(w[u],sz[to[e]]);
}
w[u]=max(w[u],n-sz[u]);
if (w[u]<w[root]) root=u;
}
int C(int n,int m){
int res=1;
for (int i=n-m+1;i<=n;++i) res=1ll*res*i%mod;
for (int i=1;i<=m;++i) res=1ll*res*inv[i]%mod;
return res;
}
bool cmp(int i,int j){return hsh[i]<hsh[j];}
void dfs(int u,int fa){
sz[u]=f[0][u]=f[1][u]=1;
for (int e=head[u];e;e=nxt[e])
if (to[e]!=fa) dfs(to[e],u),sz[u]+=sz[to[e]];
int len=0;
for (int e=head[u];e;e=nxt[e])
if (to[e]!=fa) tmp[++len]=to[e];
sort(tmp+1,tmp+len+1,cmp);
for (int i=1,j=1;i<=len;i=j){
while (j<=len&&hsh[tmp[j]]==hsh[tmp[i]]) ++j;
f[0][u]=1ll*f[0][u]*C(f[0][tmp[i]]+f[1][tmp[i]]+j-i-1,j-i)%mod;
f[1][u]=1ll*f[1][u]*C(f[0][tmp[i]]+j-i-1,j-i)%mod;
}
hsh[u]=base2*len+sz[u];
for (int i=1;i<=len;++i)
hsh[u]=(hsh[u]*base1)^hsh[tmp[i]];
}
int main(){
n=gi();inv[0]=inv[1]=1;
for (int i=2;i<=n;++i) inv[i]=1ll*inv[mod%i]*(mod-mod/i)%mod;
for (int i=1;i<n;++i){
int u=gi(),v=gi();
link(u,v),link(v,u);
}
w[0]=n;getroot(1,0);getroot(root,0);
for (int e=head[root],lst=0;e;lst=e,e=nxt[e])
if (sz[to[e]]*2==n){
++n;
if (e==head[root]) head[root]=nxt[e];
else nxt[lst]=nxt[e];
for (int i=head[to[e]],Lst=0;i;Lst=i,i=nxt[i])
if (to[i]==root){
if (i==head[to[e]]) head[to[e]]=nxt[i];
else nxt[Lst]=nxt[i];
break;
}
link(n,root);link(root,n);link(n,to[e]);link(to[e],n);
rt1=root;rt2=to[e];root=n;fg=1;break;
}
dfs(root,0);
if (!fg) printf("%d\n",(f[0][root]+f[1][root])%mod);
else if (hsh[rt1]==hsh[rt2]) printf("%d\n",(f[0][root]-C(f[1][rt1]+1,2)+mod)%mod);
else printf("%d\n",(1ll*f[0][rt1]*f[0][rt2]+1ll*f[0][rt1]*f[1][rt2]+1ll*f[1][rt1]*f[0][rt2])%mod);
return 0;
}