题解 生成树

传送门

首先可以确定是矩阵树定理
发现我们可以钦定每种颜色对应一个边权
于是矩阵树算出来的东西是关于每种颜色的边数的二维多项式
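\(\sum r_{i,j}X^iY^j\),其中 \(X, Y\) 分别是赋给颜色 2、颜色 3 的边的边权(形式变量),\(r_{i,j}\) 是恰好含 \(i\) 条颜色 2 边、\(j\) 条颜色 3 边的生成树个数
于是可以多带几个值进去消元得到系数 \(r_{i,j}\),答案即为 \(\sum\limits_{i\le g,\ j\le b} r_{i,j}\)(\(g, b\) 是两种颜色边数的上限)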
但是消元的复杂度不太对,可以考虑插值
于是二维插值与一维插值的原理是类似的
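需要代入约 \(n^2\) 个点的原因:生成树只有 \(n-1\) 条边,所以多项式关于 \(X\)、\(Y\) 的次数都不超过 \(n-1\)。一维时 \(n\) 个点唯一确定一个 \(n-1\) 次多项式,二维时在 \(n\times n\) 的网格上按行、按列各套一次一维插值即可唯一确定全部系数。代码中取了 \((n+1)\times(n+1)\) 个点,多取不影响结果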
公式是 \(f(x,y)=\sum\limits_{i} z_i\prod\limits_{j:\,x_j\neq x_i,\ y_j=y_i}\frac{x-x_j}{x_i-x_j}\prod\limits_{j:\,y_j\neq y_i,\ x_j=x_i}\frac{y-y_j}{y_i-y_j}\),其中 \((x_i, y_i)\) 是网格上的取样点,\(z_i\) 是该点处的取值
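实际写的话条件要写成 if (p[i].x!=p[j].x && p[i].y==p[j].y)(问原因被战神赶回来了)。原因是取样点构成网格:对固定的 \(i\),与它 \(y\) 相同、\(x\) 不同的点恰好把其余每个 \(x\) 值各枚举一次。如果只判 \(x_i\neq x_j\),同一个因子 \(\frac{x-x_j}{x_i-x_j}\) 会在每一行各乘一次,共乘 \(n+1\) 次,结果就错了。对 \(y\) 同理,要加上 \(x_i=x_j\)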
然后就模拟多项式乘法就好了
复杂度 \(O(n^5)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f // not referenced in the visible code
#define N 100010 // capacity of the global per-edge / per-point / DSU arrays
#define ll long long
#define fir first // not referenced in the visible code
#define sec second // not referenced in the visible code
//#define int long long

// Buffered input: getchar() below is redefined to refill `buf` from stdin in
// 2 MiB chunks and to yield EOF once fread returns no more data.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
// Reads one decimal integer (each '-' seen before the digits flips the sign).
// Separator skipping stops at EOF, so truncated input returns 0 instead of
// looping forever. Digits are tested by range rather than isdigit(), which is
// undefined for negative non-EOF char values (bytes >= 0x80 on signed-char targets).
inline int read() {
	int ans=0, f=1, c=getchar();
	while (c!=EOF && (c<'0' || c>'9')) {if (c=='-') f=-f; c=getchar();}
	while (c>='0' && c<='9') {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// n vertices, m edges; g / b are the budgets for colour-2 / colour-3 edges
// (force::dfs rejects a colour-2 edge once g reaches 0, likewise b for colour 3).
int n, m, g, b;
// Edge i (1-based) joins vertices s[i] and t[i] and has colour c[i].
int s[N], t[N], c[N];
// Filled by task1::solve (modular inverses, then turned into inverse factorials);
// the result is not consumed anywhere in the visible code.
ll inv[N];
// Modulus used for every count in this file.
const long long mod = 1000000007;
// In-place modular addition a = (a + b) mod `mod`; both inputs are expected to be
// in [0, mod), so a single conditional subtraction suffices.
inline void md(long long& a, long long b) {
	a += b;
	if (a >= mod) a -= mod;
}

namespace force{
	ll ans;
	int dsu[N];
	inline int find(int p) {return dsu[p]==p?p:find(dsu[p]);}
	void dfs(int u, int g, int b, int cnt) {
		if (cnt==n-1) {++ans; return ;}
		if (u>m) return ;
		int f1=find(s[u]), f2=find(t[u]);
		dfs(u+1, g, b, cnt);
		if (f1==f2 || (!g&&c[u]==2) || (!b&&c[u]==3)) return ;
		int tem=dsu[f1];
		dsu[f1]=f2;
		if (c[u]==2) --g;
		else if (c[u]==3) --b;
		dfs(u+1, g, b, cnt+1);
		dsu[f1]=tem;
	}
	void solve() {
		for (int i=1; i<=m; ++i) {s[i]=read(); t[i]=read(); c[i]=read();}
		for (int i=1; i<=n; ++i) dsu[i]=i;
		dfs(1, g, b, 0);
		printf("%lld\n", ans%mod);
	}
}

// Bitmask-DP attempt for small n (state dimension 1<<10, so n <= 10; budgets <= 100).
// NOTE(review): not called from main(), and it over-counts. Each spanning tree is
// counted once per way of growing it vertex-by-vertex (each prefix connected) from
// every possible start vertex; e.g. n=2 with a single edge yields 2 instead of 1.
// The inverse factorials computed into `inv` below are never applied. A constant
// divisor could not fix this anyway, because the number of growth orders depends
// on the tree's shape (for n=4 a path has 8, a star has 12).
namespace task1{
	// Indexed by [vertex set][colour-2 budget left][colour-3 budget left].
	bool vis[1<<10][101][101];
	ll dp[1<<10][101][101], ans;
	// Adjacency list (edges stored from index 1; -1 terminates a list).
	int head[N], size;
	struct edge{int to, next, val;}e[N<<1];
	// BFS state: current vertex set s and remaining colour budgets g, b.
	struct sit{int s, g, b; sit(){} sit(int x, int y, int z):s(x),g(y),b(z){}};
	queue<sit> q;
	// Appends the directed edge s->t with colour w.
	inline void add(int s, int t, int w) {e[++size]={t, head[s], w}; head[s]=size;}
	void solve() {
		memset(head, -1, sizeof(head));
		for (int i=1,u,v,w; i<=m; ++i) {
			// vertices are shifted to 0-based for the bitmask
			u=read()-1; v=read()-1; w=read();
			add(u, v, w); add(v, u, w);
		}
		// inv[i] = i^{-1}, then prefix products turn it into (i!)^{-1}; unused below.
		inv[0]=inv[1]=1;
		for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[i]%mod;

		int lim=1<<n;
		sit u;
		// Seed every single-vertex set with the full budgets.
		for (int i=0; i<n; ++i) q.push(sit(1<<i, g, b)), dp[1<<i][g][b]=1, vis[1<<i][g][b]=1;
		// The queue is processed level by level (each step adds exactly one vertex), so
		// a state's dp value is complete before it is popped.
		while (q.size()) {
			u=q.front(); q.pop();
			if (u.s==lim-1) {ans=(ans+dp[u.s][u.g][u.b])%mod; continue;}
			for (int i=0; i<n; ++i) if (u.s&(1<<i)) {
				for (int j=head[i],v; ~j; j=e[j].next) {
					v = e[j].to;
					// skip vertices already in the set and edges whose colour budget is exhausted
					if (u.s&(1<<v) || (u.g==0&&e[j].val==2) || (u.b==0&&e[j].val==3)) continue;
					md(dp[u.s|(1<<v)][u.g-(e[j].val==2)][u.b-(e[j].val==3)], dp[u.s][u.g][u.b]);
					if (!vis[u.s|(1<<v)][u.g-(e[j].val==2)][u.b-(e[j].val==3)]) {
						vis[u.s|(1<<v)][u.g-(e[j].val==2)][u.b-(e[j].val==3)]=1;
						q.push(sit(u.s|(1<<v), u.g-(e[j].val==2), u.b-(e[j].val==3)));
					}
				}
			}
		}
		printf("%lld\n", ans);
		exit(0);
	}
}

namespace task{
	int top;
	ll ans;
	inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
	struct matrix{
		int n, m;
		ll a[110][110];
		matrix(){}
		matrix(int x, int y) {n=x; m=y; memset(a, 0, sizeof(a));}
		inline void resize(int x, int y) {n=x; m=y; memset(a, 0, sizeof(a));}
		inline ll* operator [] (int t) {return a[t];}
		ll gauss() {
			ll det=1;
			for (int i=1; i<=n; ++i) {
				for (int j=i+1; j<=n; ++j) {
					while (a[j][i]) {
						ll t=a[i][i]/a[j][i];
						for (int k=i; k<=m; ++k) a[i][k]=((a[i][k]-a[j][k]*t)%mod+mod)%mod;
						det=-det;
						swap(a[i], a[j]);
					}
				}
			}
			for (int i=1; i<=n; ++i) det=det*a[i][i]%mod;
			return det;
		}
	}mat;
	struct point{ll x, y, z; point(){} point(ll a, ll b, ll c):x(a),y(b),z(c){}}p[N];
	struct poly{
		vector<vector<ll>> a;
		poly(){}
		poly(int x, int y) {a.clear(); a.resize(x, vector<ll>(y));}
		inline void set(int x, int y) {a.clear(); a.resize(x, vector<ll>(y));}
		inline int lenx() {return a.size();}
		inline int leny() {return a.size()?a[0].size():0;}
		inline vector<ll>& operator [] (int t) {return a[t];}
		inline void put() {for (int i=0; i<lenx(); ++i) {for (int j=0; j<leny(); ++j) cout<<a[i][j]<<' '; cout<<endl;} cout<<endl;}
		inline poly operator * (poly b) {
			poly ans(lenx()+b.lenx()-1, leny()+b.leny()-1);
			// cout<<"size: "<<lenx()+b.lenx()-1<<' '<<leny()+b.leny()-1<<endl;
			for (int i=0; i<lenx(); ++i)
				for (int j=0; j<leny(); ++j)
					for (int k=0; k<b.lenx(); ++k)
						for (int l=0; l<b.leny(); ++l)
							ans[i+k][j+l]=(ans[i+k][j+l]+a[i][j]*b[k][l])%mod;
			return ans;
		}
		inline poly operator + (poly b) {
			poly ans(max(lenx(), b.lenx()), max(leny(), b.leny()));
			for (int i=0; i<lenx(); ++i)
				for (int j=0; j<leny(); ++j)
					ans[i][j]=a[i][j];
			for (int i=0; i<b.lenx(); ++i)
				for (int j=0; j<b.leny(); ++j)
					ans[i][j]=(ans[i][j]+b[i][j])%mod;
			return ans;
		}
		ll qval(ll x, ll y) {
			ll ans=0;
			for (int i=0; i<lenx(); ++i)
				for (int j=0; j<leny(); ++j)
					ans=(ans+a[i][j]*qpow(x, i)%mod*qpow(y, j))%mod;
			return (ans%mod+mod)%mod;
		}
	}r, f;
	ll calc(ll x, ll y) {
		mat.resize(n-1, n-1);
		for (int i=1; i<=m; ++i) {
			ll val;
			switch (c[i]) {
				case 1: val=1; break;
				case 2: val=x; break;
				case 3: val=y; break;
			}
			mat[s[i]][t[i]]-=val; mat[s[i]][s[i]]+=val;
			mat[t[i]][s[i]]-=val; mat[t[i]][t[i]]+=val;
		}
		return mat.gauss();
	}
	void lagrange() {
		r.set(0, 0);
		poly t;
		for (int i=1; i<=top; ++i) {
			f.set(1, 1); f[0][0]=p[i].z;
			// cout<<f.lenx()<<' '<<f.leny()<<endl;
			// cout<<"val: "<<f.qval(p[i].x, p[i].y)<<endl;
			// assert(f.qval(p[i].x, p[i].y)==p[i].z);
			for (int j=1; j<=top; ++j) {
				if (p[i].x!=p[j].x&&p[i].y==p[j].y) {
					// cout<<"ij: "<<i<<' '<<j<<' '<<p[i].x<<' '<<p[j].x<<endl;
					ll inv=qpow(p[i].x-p[j].x, mod-2);
					t.set(2, 1);
					t[0][0]=-p[j].x*inv%mod;
					t[1][0]=inv;
					// assert(t.qval(p[j].x, p[j].y)==0);
					// cout<<f.qval(p[i].x, p[i].y)<<endl;
					// cout<<"f"<<endl; f.put(); cout<<"t"<<endl; t.put();
					f=f*t;
					// cout<<"f"<<endl;
					// f.put(); cout<<endl;
					// assert(f.qval(p[i].x, p[i].y)==p[i].z);
				}
				if (p[i].y!=p[j].y&&p[i].x==p[j].x) {
					ll inv=qpow(p[i].y-p[j].y, mod-2);
					t.set(1, 2);
					t[0][0]=-p[j].y*inv%mod;
					t[0][1]=inv;
					// assert(f.qval(p[i].x, p[i].y)==p[i].z);
					// assert(t.qval(p[i].x, p[i].y)==1);
					f=f*t;
				}
			}
			// f.put();
			// assert(f.qval(p[i].x, p[i].y)==p[i].z);
			r=r+f;
		}
	}
	void solve() {
		for (int i=1; i<=m; ++i) {s[i]=read(); t[i]=read(); c[i]=read();}
		for (int i=1; i<=n+1; ++i) for (int j=1; j<=n+1; ++j) p[++top]={i, j, calc(i, j)}; //, cout<<"p: "<<i<<' '<<j<<' '<<calc(i, j)<<endl;
		lagrange();
		for (int i=0; i<=min(r.lenx()-1, g); ++i)
			for (int j=0; j<=min(r.leny()-1, b); ++j)
				ans=(ans+r[i][j])%mod;
		// p[++top]={0, 1, 1}; p[++top]={0, 2, 3}; p[++top]={0, 3, 6};
		// lagrange();
		// cout<<r.qval(5, 0)<<endl;
		printf("%lld\n", (ans%mod+mod)%mod);
		exit(0);
	}
}

signed main()
{
	// Header: vertex count, edge count, then the colour-2 and colour-3 budgets.
	n = read();
	m = read();
	g = read();
	b = read();
	// The interpolation solver is used; force::solve() is a brute-force checker for
	// small inputs and stays disabled.
	task::solve();
	return 0;
}
posted @ 2022-01-11 20:49  Administrator-09  阅读(0)  评论(0编辑  收藏  举报