题解 T2

传送门

一个暴力做法是枚举四元组的前三个点,再用 bitset 确定最后一个。

image

  • 树上连通性/链相交一类问题记得试试 边-点=1 的容斥
    实现时的一个技巧是边化点

那就枚举钦定第一棵树 A 中删去的一个边/点,将删去后形成的几个连通块分别染色
在第二棵树上同样用 边-点 容斥算出合法四元划分数
直接计算就可以了
用 vector 会获得 TLE40 的好成绩哦
WA40?先别怀疑模数写错了——这题根本不取模(答案直接用 long long 输出),我调了很久才发现
复杂度 \(O(n^2)\):A 中共有 \(O(n)\) 个点/边需要枚举,每个都要在 B 上做一次 \(O(n)\) 的 DFS

点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Contest-template shorthands. INF, fir, sec and pb are not referenced
// anywhere in this solution.
#define INF 0x3f3f3f3f
// Vertex capacity of each tree after every edge is subdivided by a new
// vertex: n leaves (ids 1..n) + internal nodes (ids n+1..2n-2) + one point
// per edge (ids 2n-1..4n-5, see the edge-reading loops in main).
#define N 4010
#define fir first
#define sec second
#define pb push_back
#define ll long long
//#define int long long

// 2 MiB input buffer; [p1, p2) is the unread part.
char buf[1<<21], *p1=buf, *p2=buf;
// Buffered getchar: refills buf from stdin when empty, yields EOF once fread returns nothing.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
// Reads one (optionally negative) decimal integer from the buffered stdin.
// Every '-' seen before the first digit flips the sign.
// Returns 0 when the input runs out before any digit is found.
// `c` is an int and digits are tested by range: `*p1++` yields a possibly
// negative char, and std::isdigit is undefined for negative values other than EOF.
inline int read() {
	int ans=0, f=1, c=getchar();
	while (c!=EOF && (c<'0' || c>'9')) {if (c=='-') f=-f; c=getchar();}
	while (c>='0' && c<='9') {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// Number of leaves; in both trees the leaves are the vertices 1..n.
int n;
// col[leaf] = index (1..3) of the component of tree A that contains the leaf
// after tr1::solve removes one vertex / edge point; 0 means "not painted".
int col[1010];
// fac and inv are not read by the live code (leftovers of a modular variant).
ll fac[N];
ll inv[N];
// Binomial table; the counting code only ever reads C[x][2].
ll C[1010][1010];
// Signed inclusion-exclusion accumulator filled by tr2::solve.
ll ans;
// Call-syntax alias: C(x, k) reads the table above (C[...] is not a macro call).
#define C(n, k) C[n][k]

// Second tree (B). Given the leaf colouring produced from tree A in `col`,
// counts pairs of same-coloured leaf pairs that B separates, using the
// "edges minus vertices" inclusion-exclusion over B's vertices and edge points.
namespace tr2{
	// Forward-star adjacency list; head[] must be set to -1 before add().
	int head[N], ecnt;
	// Colour histogram: a[c] = number of leaves with colour c (c in 1..3).
	struct vec{
		int a[4];
		vec() : a{0, 0, 0, 0} {}
		inline void clear() {for (int &x : a) x=0;}
		inline int& operator [] (int t) {return a[t];}
	}bkp[N];
	struct edge{int to, next;}e[N<<1];
	inline void add(int s, int t) {
		++ecnt;
		e[ecnt].to=t;
		e[ecnt].next=head[s];
		head[s]=ecnt;
	}
	// Stores in bkp[u] the colour histogram of all leaves below u, where the
	// tree hangs from the vertex the outermost call starts at.
	void dfs1(int u, int fa) {
		vec here;
		if (u<=n) {
			here[col[u]]=1;
		} else {
			for (int i=head[u]; i!=-1; i=e[i].next) {
				int v=e[i].to;
				if (v==fa) continue;
				dfs1(v, u);
				for (int c=1; c<=3; ++c) here[c]+=bkp[v][c];
			}
		}
		bkp[u]=here;
	}
	// g is the histogram of every leaf outside u's subtree. Returns the signed
	// contribution of u's whole subtree: for each deleted vertex/edge point, the
	// number of (pair, pair) choices that are monochromatic, lie in different
	// components, and have different colours; vertices count negative.
	ll dfs2(int u, int fa, vec g) {
		// Histograms of the components left after deleting u (unused slots stay 0).
		vec part[4];
		int parts=0;
		if (fa) part[++parts]=g;
		for (int i=head[u]; i!=-1; i=e[i].next)
			if (e[i].to!=fa) part[++parts]=bkp[e[i].to];
		// mono[i]: number of same-colour leaf pairs inside component i.
		ll mono[4]={0, 0, 0, 0};
		for (int i=1; i<=3; ++i)
			for (int c=1; c<=3; ++c)
				mono[i]+=C(part[i][c], 2);
		ll res=0;
		for (int i=1; i<=3; ++i) {
			if (!mono[i]) continue;
			for (int j=i+1; j<=3; ++j) {
				if (!mono[j]) continue;
				// Drop the combinations where both pairs share one colour.
				ll same=0;
				for (int c=1; c<=3; ++c) same+=C(part[i][c], 2)*C(part[j][c], 2);
				res+=mono[i]*mono[j]-same;
			}
		}
		// Ids up to 2n-2 are real vertices (subtracted); larger ids are edge points.
		if (u<=2*n-2) res=-res;
		for (int i=head[u]; i!=-1; i=e[i].next) {
			int v=e[i].to;
			if (v==fa) continue;
			// Leaves outside v's subtree: outside u, plus u's other children.
			vec up=g;
			for (int j=head[u]; j!=-1; j=e[j].next) {
				int w=e[j].to;
				if (w==fa || w==v) continue;
				for (int c=1; c<=3; ++c) up[c]+=bkp[w][c];
			}
			res+=dfs2(v, u, up);
		}
		return res;
	}
	// Runs the count for the current colouring and adds it to ans with weight k.
	void solve(ll k) {
		dfs1(n+1, 0);
		ans+=k*dfs2(n+1, 0, vec());
	}
}

// First tree (A). Deleting one vertex or edge point splits A into components;
// each component's leaves get their own colour before tree B is processed.
namespace tr1{
	// Forward-star adjacency list; head[] must be set to -1 before add().
	int head[N], ecnt;
	struct edge{int to, next;}e[N<<1];
	inline void add(int s, int t) {
		++ecnt;
		e[ecnt].to=t;
		e[ecnt].next=head[s];
		head[s]=ecnt;
	}
	// Paints with colour c every leaf reachable from u without going back through fa.
	void dfs1(int u, int fa, int c) {
		if (u>n) {
			for (int i=head[u]; i!=-1; i=e[i].next)
				if (e[i].to!=fa) dfs1(e[i].to, u, c);
			return ;
		}
		col[u]=c;
	}
	// Deletes u from A, colours each resulting component, and lets tree B add
	// the count with weight -1 for a vertex (id <= 2n-2) or +1 for an edge point.
	void solve(int u) {
		memset(col, 0, sizeof(col));
		int colour=0;
		for (int i=head[u]; i!=-1; i=e[i].next) dfs1(e[i].to, u, ++colour);
		int sign = u<=2*n-2 ? -1 : 1;
		tr2::solve(sign);
	}
}

// Reads n and the two trees (each given as 2n-3 edges over vertices 1..2n-2),
// runs the edge-minus-vertex inclusion-exclusion over every vertex and edge
// point of tree A, and prints 2*C(n,4) - 2*ans.
signed main()
{
	freopen("b.in", "r", stdin);
	freopen("b.out", "w", stdout);

	n=read();
	// Both adjacency lists use -1 as the end-of-list sentinel.
	memset(tr1::head, -1, sizeof(tr1::head));
	memset(tr2::head, -1, sizeof(tr2::head));
	// Only C(x, 2) is ever queried, so build just columns 0..2 of Pascal's
	// triangle. Filling the whole table overflows long long from about row 67
	// onward, which is undefined behaviour for signed integers.
	for (int i=0; i<=n; ++i) {
		C[i][0]=1;
		for (int j=1; j<=i && j<=2; ++j) C[i][j]=C[i-1][j-1]+C[i-1][j];
	}
	// Every edge (u, v) is subdivided by a fresh vertex id = 2n-2+i, so edges
	// can be enumerated exactly like vertices.
	for (int i=1,u,v,id; i<=2*n-3; ++i) {
		u=read(); v=read(); id=2*n-2+i;
		tr1::add(u, id); tr1::add(id, u);
		tr1::add(id, v); tr1::add(v, id);
	}
	for (int i=1,u,v,id; i<=2*n-3; ++i) {
		u=read(); v=read(); id=2*n-2+i;
		tr2::add(u, id); tr2::add(id, u);
		tr2::add(id, v); tr2::add(v, id);
	}
	// Internal vertices of A (weight -1), then its edge points (weight +1).
	for (int i=n+1; i<=2*n-2; ++i) tr1::solve(i);
	for (int i=1; i<=2*n-3; ++i) tr1::solve(2*n-2+i);
	// No modulus: the exact count fits in long long.
	cout<<2ll*n*(n-1)*(n-2)*(n-3)/24-2*ans<<'\n';

	return 0;
}
posted @ 2022-03-12 21:18  Administrator-09  阅读(3)  评论(0编辑  收藏  举报