题解 T2
一个暴力做法是:枚举四元组中的前三个点,再用 bitset 确定第四个点。
- 树上连通性/链相交一类的问题,记得试试利用「边数 − 点数 = 1」做容斥。
实现时的一个技巧是“边化点”:把每条边也建成一个新点,这样删点和删边就可以统一处理。
于是枚举第一棵树 A 中的一个点或一条边,钦定将其删去,并把由此形成的几个连通块分别染上不同的颜色。
在第二棵树上同样用「边 − 点」容斥,算出合法的四元划分数。
直接计算即可。
注意:用 vector 存每个点的颜色计数会 TLE(只有 40 分),改用定长数组即可。
如果 WA 40 分,先别怀疑模数写错了:这题的答案不取模,我调了很久才发现。
复杂度 \(O(n^2)\)
参考代码:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 4010
#define fir first
#define sec second
#define pb push_back
#define ll long long
//#define int long long
// Fast input: refill a 2 MiB buffer from stdin with fread and hand out one
// byte at a time. Yields EOF once stdin is exhausted. Bytes are returned as
// unsigned char values, so they are always valid arguments to isdigit.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
// Reads one (optionally negative) decimal integer from stdin.
// Non-digit characters before the number are skipped; every '-' seen while
// skipping flips the sign. Returns 0 if stdin ends before any digit appears,
// instead of spinning forever on EOF.
inline int read() {
    int ans = 0, f = 1;
    int c = getchar();   // int, not char: keeps EOF distinguishable from data
    while (c != EOF && !isdigit(c)) {
        if (c == '-') f = -f;
        c = getchar();
    }
    while (isdigit(c)) {   // isdigit(EOF) is false, so EOF ends the number
        ans = ans * 10 + (c - '0');
        c = getchar();
    }
    return ans * f;
}
// Number of leaves. Leaves are nodes 1..n in both trees.
int n;
// col[x]: colour of leaf x, i.e. which component of tree 1 the leaf falls
// into after the currently enumerated node is removed (0 = not set yet).
int col[1010];
// C[i][j] = binom(i, j), filled in main. All counts fit in long long, so no
// modulus is used anywhere: the answer must be printed exactly.
// ans: signed inclusion-exclusion total accumulated by tr2::solve.
long long C[1010][1010], ans;
#define C(n, k) C[n][k]
// Tree 2: the tree that is scanned once per deleted node of tree 1.
// Leaves are 1..n; main turns the i-th edge into node 2n-2+i, so nodes
// <= 2n-2 are original vertices and larger ids are edge nodes.
// head[] must be filled with -1 before add() is used (main does this).
namespace tr2{
int head[N], ecnt;

// Per-colour leaf counters; slots 1..3 are used, slot 0 is spare.
struct vec {
    int a[4];
    vec() : a{} {}
    void clear() { for (int &x : a) x = 0; }
    int& operator[](int t) { return a[t]; }
};
// bkp[u]: colour counts of the leaves in the subtree of u (root n+1).
vec bkp[N];

struct edge{int to, next;}e[N<<1];

// Adds the directed link s -> t to the adjacency list.
inline void add(int s, int t) {
    ++ecnt;
    e[ecnt].to = t;
    e[ecnt].next = head[s];
    head[s] = ecnt;
}

// Fills bkp[] for the subtree rooted at u (entered from fa).
void dfs1(int u, int fa) {
    vec acc;
    if (u <= n) {
        acc[col[u]] = 1;   // a leaf contributes one leaf of its own colour
    } else {
        for (int i = head[u]; i != -1; i = e[i].next) {
            const int v = e[i].to;
            if (v == fa) continue;
            dfs1(v, u);
            for (int c = 1; c <= 3; ++c) acc[c] += bkp[v][c];
        }
    }
    bkp[u] = acc;
}

// part[1..3] are the colour counts of the components left after deleting
// a node. Counts unordered pairs {p, q} of leaf pairs such that each pair is
// monochromatic, p and q lie in different components, and their colours
// differ. The same-colour combinations are subtracted from the product.
ll cross_pairs(vec part[4]) {
    ll mono[4] = {0, 0, 0, 0};   // monochromatic leaf pairs per component
    for (int i = 1; i <= 3; ++i)
        for (int c = 1; c <= 3; ++c)
            mono[i] += C(part[i][c], 2);
    ll res = 0;
    for (int i = 1; i <= 3; ++i) {
        if (!mono[i]) continue;
        for (int j = i + 1; j <= 3; ++j) {
            if (!mono[j]) continue;
            res += mono[i] * mono[j];
            for (int c = 1; c <= 3; ++c)
                res -= C(part[i][c], 2) * C(part[j][c], 2);
        }
    }
    return res;
}

// Sums cross_pairs over every node of the subtree of u, weighted -1 for an
// original vertex and +1 for an edge node. g holds the colour counts of all
// leaves outside u's subtree (the component on the parent side).
ll dfs2(int u, int fa, vec g) {
    vec part[4];
    int cnt = 0;
    if (fa) part[++cnt] = g;   // the root has no parent-side component
    for (int i = head[u]; i != -1; i = e[i].next)
        if (e[i].to != fa) part[++cnt] = bkp[e[i].to];

    ll total = cross_pairs(part);
    if (u <= 2*n - 2) total = -total;

    for (int i = head[u]; i != -1; i = e[i].next) {
        const int v = e[i].to;
        if (v == fa) continue;
        // Everything outside v's subtree: u's parent side plus v's siblings.
        vec rest = g;
        for (int j = head[u]; j != -1; j = e[j].next) {
            const int w = e[j].to;
            if (w == fa || w == v) continue;
            for (int c = 1; c <= 3; ++c) rest[c] += bkp[w][c];
        }
        total += dfs2(v, u, rest);
    }
    return total;
}

// Scans tree 2 for the current colouring col[] and adds k times the signed
// total to ans.
void solve(ll k) {
    dfs1(n + 1, 0);
    ans += k * dfs2(n + 1, 0, vec());
}
}
// Tree 1: the tree whose vertices and edge nodes are deleted one at a time.
// Uses the same numbering as tree 2: leaves 1..n, vertices up to 2n-2, and
// edge i as node 2n-2+i. head[] must be filled with -1 before add() is used.
namespace tr1{
int head[N], ecnt;
struct edge{int to, next;}e[N<<1];

// Adds the directed link s -> t to the adjacency list.
inline void add(int s, int t) {
    ++ecnt;
    e[ecnt].to = t;
    e[ecnt].next = head[s];
    head[s] = ecnt;
}

// Paints every leaf reachable from u, without going back through fa, with
// colour c.
void dfs1(int u, int fa, int c) {
    if (u <= n) {
        col[u] = c;
        return;
    }
    for (int i = head[u]; i != -1; i = e[i].next)
        if (e[i].to != fa) dfs1(e[i].to, u, c);
}

// Deletes node u from tree 1. The component behind each neighbour gets its
// own colour (1, 2, ...). Tree 2 is then scanned with weight -1 when u is an
// original vertex and +1 when u is an edge node.
void solve(int u) {
    memset(col, 0, sizeof(col));
    int colour = 0;
    for (int i = head[u]; i != -1; i = e[i].next)
        dfs1(e[i].to, u, ++colour);
    tr2::solve(u <= 2*n - 2 ? -1 : 1);
}
}
signed main()
{
freopen("b.in", "r", stdin);
freopen("b.out", "w", stdout);
n=read();
memset(tr1::head, -1, sizeof(tr1::head));
memset(tr2::head, -1, sizeof(tr2::head));
// fac[0]=fac[1]=1; inv[0]=inv[1]=1;
// for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
// for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
// for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[i]%mod;
for (int i=0; C[i][0]=1,i<=n; ++i) for (int j=1; j<=i; ++j) C[i][j]=C[i-1][j-1]+C[i-1][j];
for (int i=1,u,v,id; i<=2*n-3; ++i) {
u=read(); v=read(); id=2*n-2+i;
tr1::add(u, id); tr1::add(id, u);
tr1::add(id, v); tr1::add(v, id);
}
for (int i=1,u,v,id; i<=2*n-3; ++i) {
u=read(); v=read(); id=2*n-2+i;
tr2::add(u, id); tr2::add(id, u);
tr2::add(id, v); tr2::add(v, id);
}
for (int i=n+1; i<=2*n-2; ++i) tr1::solve(i);
for (int i=1; i<=2*n-3; ++i) tr1::solve(2*n-2+i);
// cout<<(ans%mod+mod)%mod<<endl;
cout<<2ll*n*(n-1)*(n-2)*(n-3)/24-2*ans<<endl;
return 0;
}