bzoj3162 独钓寒江雪
为了找一个合适的hash函数爆OJ23333.
网上其他题解已经很详细了,我记录一个自己感觉有点妙的细节:我们需要用方案数求组合数,但是方案数是对mod=10^9+7取过模的,这样真的对吗?
注意C(n,m)%mod=C(n%mod,m%mod)*C(n/mod,m/mod)%mod,而本题中的m都小于mod,所以C(n/mod,m/mod)一定是C(n/mod,0)=1,所以把方案数取个模不影响组合数的计算.
如果取模的数字小一些,比如对2333取模,这样直接做可能就会WA了.
#include<cstdio>
#include<algorithm>
#include<vector>
#include<map>
using namespace std;
typedef unsigned long long ll;
map<ll,int> f[2];
const int mod=1000000007;
const int maxn=500005;
int inv[maxn];
void init(){
inv[0]=1;
inv[1]=1;
for(int i=2;i<maxn;++i)inv[i]=inv[mod%i]*1ll*(mod-mod/i)%mod;
for(int i=1;i<maxn;++i)inv[i]=inv[i-1]*1ll*inv[i]%mod;
}
int C(int n,int m){
int ans=1;
for(int i=0;i<m;++i)ans=ans*1ll*(n-i)%mod;
ans=ans*1ll*inv[m]%mod;
return ans;
}
struct edge{
int to,next;
}lst[maxn<<1];int len=1,first[maxn];
void addedge(int a,int b){
lst[len].to=b;lst[len].next=first[a];first[a]=len++;
}
int n;
int cnt=0,ct[3];
int sz[maxn],g[maxn];
void get_root(int x,int p){
sz[x]=1;
for(int pt=first[x];pt;pt=lst[pt].next){
if(lst[pt].to==p)continue;
get_root(lst[pt].to,x);
sz[x]+=sz[lst[pt].to];
if(sz[lst[pt].to]>g[x])g[x]=sz[lst[pt].to];
}
if(n-sz[x]>g[x])g[x]=n-sz[x];
if(g[x]<=n/2)ct[cnt++]=x;
}
ll HA[maxn];
vector<ll> tmp[maxn];
void getHa(int x,int p){
for(int pt=first[x];pt;pt=lst[pt].next){
if(lst[pt].to==p)continue;
else{
getHa(lst[pt].to,x);
tmp[x].push_back(HA[lst[pt].to]);
}
}
HA[x]=19;
if(tmp[x].empty())return;
sort(tmp[x].begin(),tmp[x].end());
for(vector<ll>::iterator pt=tmp[x].begin();pt!=tmp[x].end();++pt){
HA[x]=((HA[x]*12312123123+(*pt))^(*pt))+(*pt)+2333;
}
}
void dp(int x,int p){
if(f[0][HA[x]])return;
for(int pt=first[x];pt;pt=lst[pt].next){
if(lst[pt].to==p)continue;
dp(lst[pt].to,x);
}
int ans1=1,ans0=1;
ll old=-1;int cnt=0;
int sz=tmp[x].size();
for(int i=0;i<sz;++i){
cnt++;
if(i==sz-1||tmp[x][i]!=tmp[x][i+1]){
ans1=ans1*1ll*C((f[0][tmp[x][i]]+cnt-1)%mod,cnt)%mod;
ans0=ans0*1ll*C((f[0][tmp[x][i]]+f[1][tmp[x][i]]+cnt-1)%mod,cnt)%mod;
cnt=0;
}
}
f[0][HA[x]]=ans0;f[1][HA[x]]=ans1;
}
int main(){
init();
scanf("%d",&n);
for(int i=1,a,b;i<n;++i){
scanf("%d%d",&a,&b);addedge(a,b);addedge(b,a);
}
get_root(1,0);
if(cnt==2){
getHa(ct[0],ct[1]);getHa(ct[1],ct[0]);
dp(ct[0],ct[1]);dp(ct[1],ct[0]);
int ans=(f[0][HA[ct[0]]]*1ll*f[1][HA[ct[1]]]%mod)%mod;
if(HA[ct[0]]!=HA[ct[1]]){
ans=(ans+f[1][HA[ct[0]]]*1ll*f[0][HA[ct[1]]]%mod)%mod;
ans=(ans+f[0][HA[ct[0]]]*1ll*f[0][HA[ct[1]]]%mod)%mod;
}else{
ans=(ans+(f[0][HA[ct[0]]]*1ll*f[0][HA[ct[1]]]%mod+f[0][HA[ct[0]]]+mod)%mod*1ll*inv[2]%mod)%mod;
}
printf("%d\n",ans);
}else{
getHa(ct[0],0);
dp(ct[0],0);
printf("%d\n",(f[0][HA[ct[0]]]+f[1][HA[ct[0]]])%mod);
}
return 0;
}