题解 生成树
首先可以确定是矩阵树定理
发现我们可以钦定每种颜色对应一个边权
于是矩阵树算出来的东西是关于每种颜色的边数的二维多项式
\(\sum r_{a, b}X^aY^b\),其中 \(X, Y\) 是钦定给两种颜色的边权,\(a, b\) 是生成树中这两种颜色的边数,\(r_{a, b}\) 是恰好使用 \(a\) 条、\(b\) 条这两种颜色的边的生成树个数
于是可以多带几个值进去消元得到系数 \(r_{a, b}\),答案即为两种颜色的边数都不超过各自上限的那些 \(r_{a, b}\) 之和
但是消元的复杂度不太对,可以考虑插值
于是二维插值与一维插值的原理是类似的
但我并不知道为什么需要代进去 \(n^2\) 个点,虽然看起来好像显然(生成树只有 \(n-1\) 条边,所以 \(X\)、\(Y\) 的次数都不超过 \(n-1\),多项式至多有 \(n^2\) 个系数,取一个 \(n\times n\) 的网格点就能唯一确定它)
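公式是 \(f(x, y)=\sum\limits_{i} z_i\prod\limits_{j:\,x_j\neq x_i,\ y_j=y_i}\frac{x-x_j}{x_i-x_j}\prod\limits_{j:\,y_j\neq y_i,\ x_j=x_i}\frac{y-y_j}{y_i-y_j}\),其中 \((x_i, y_i)\) 是第 \(i\) 个采样点,\(z_i\) 是该点的取值
实际写的话条件要写成 if (p[i].x!=p[j].x && p[i].y==p[j].y)(关于 \(y\) 的那一项同理,要求 x 相同),这样网格上每个不同的横坐标恰好只被乘一次
,问原因被战神赶回来了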
然后就模拟多项式乘法就好了
复杂度 \(O(n^5)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Large sentinel constant; not referenced anywhere in the visible part of this file.
#define INF 0x3f3f3f3f
// Capacity of the file-level arrays (edge lists, DSU, adjacency heads, sample points).
#define N 100010
// Shorthand for `long long`, the type used for all modular arithmetic below.
#define ll long long
// Shorthands for the members of std::pair; not used in the visible code.
#define fir first
#define sec second
// Disabled switch that would turn every `int` into `long long` (main is declared `signed`).
//#define int long long
// Input buffer for the fast reader: [p1, p2) is the part of `buf` not consumed yet.
char buf[1<<21], *p1=buf, *p2=buf;
// Replacement for getchar(): refills `buf` from stdin with fread when it is exhausted
// and yields EOF once stdin is drained. The byte is widened through unsigned char so
// that a data byte can never be mistaken for EOF or reach isdigit() as a negative value.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(int)(unsigned char)*p1++)
// Reads the next decimal integer from stdin.
// Every non-digit character before the number is skipped; each '-' among them flips
// the sign. The character right after the last digit is consumed as well.
// Returns 0 when the input ends before any digit is found (instead of spinning forever).
inline int read() {
    int ans=0, f=1;
    int c=getchar();
    while (c!=EOF && !isdigit(c)) {
        if (c=='-') f=-f;
        c=getchar();
    }
    // isdigit(EOF) is false, so this loop also stops at end of input.
    while (isdigit(c)) {
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans*f;
}
// Problem parameters, read by main in this order: n and m bound the vertex and edge
// loops below; g and b are the remaining budgets for colour-2 and colour-3 edges.
int n, m, g, b;
// Edge list, 1-indexed: edge i joins s[i] and t[i] and has colour c[i] (1, 2 or 3).
int s[N], t[N], c[N];
// Scratch table filled by task1::solve (modular inverses, then their prefix products).
ll inv[N];
// Modulus used for every result in this file.
const ll mod=1e9+7;
// Adds b to a in place, subtracting `mod` once if the sum reaches it.
// The single subtraction is only a full reduction when both values are already in [0, mod).
inline void md(ll& a, ll b) {
    a+=b;
    if (a>=mod) a-=mod;
}
// Brute force: enumerates every subset of edges that forms a spanning tree while
// respecting the budgets for colour-2 and colour-3 edges, and counts them.
namespace force{
// Number of spanning trees found so far (reduced modulo `mod` only when printed).
ll ans;
// Union-find parent links. There is no path compression, so dfs can undo a union by
// restoring a single link.
int dsu[N];
// Returns the representative of p's component by walking the parent links.
inline int find(int p) {
    while (dsu[p]!=p) p=dsu[p];
    return p;
}
// Decides edges u..m one by one.
//   g, b - how many more colour-2 / colour-3 edges may still be taken
//   cnt  - number of edges taken so far
void dfs(int u, int g, int b, int cnt) {
    if (cnt==n-1) {
        ++ans;
        return;
    }
    if (u>m) return;
    const int rootS=find(s[u]);
    const int rootT=find(t[u]);
    // Branch 1: leave edge u out.
    dfs(u+1, g, b, cnt);
    // Branch 2: take edge u, unless it closes a cycle or its colour budget is used up.
    const bool closesCycle=(rootS==rootT);
    const bool colour2Spent=(c[u]==2 && g==0);
    const bool colour3Spent=(c[u]==3 && b==0);
    if (closesCycle || colour2Spent || colour3Spent) return;
    const int savedLink=dsu[rootS];
    dsu[rootS]=rootT;
    const int nextG=g-(c[u]==2);
    const int nextB=b-(c[u]==3);
    dfs(u+1, nextG, nextB, cnt+1);
    // Roll the union back before returning to the caller.
    dsu[rootS]=savedLink;
}
// Reads the m edges, runs the enumeration and prints the count modulo `mod`.
void solve() {
    for (int i=1; i<=m; ++i) {
        s[i]=read();
        t[i]=read();
        c[i]=read();
    }
    for (int v=1; v<=n; ++v) dsu[v]=v;
    dfs(1, g, b, 0);
    printf("%lld\n", ans%mod);
}
}
// Bitmask DP over (set of connected vertices, remaining colour-2 budget, remaining
// colour-3 budget), processed in BFS order. The tables are sized for at most 10
// vertices and budgets of at most 100.
// NOTE(review): each transition attaches one new vertex through one edge, so a tree
// appears to be counted once per growth order rather than once - confirm before use.
namespace task1{
// vis marks states already queued; dp holds the number of ways to reach a state.
bool vis[1<<10][101][101];
ll dp[1<<10][101][101], ans;
// Forward-star adjacency lists: head[v] is the index of v's first arc, -1 if none.
int head[N], size;
struct edge{int to, next, val;}e[N<<1];
// One DP state: vertex set s, remaining colour-2 budget g, remaining colour-3 budget b.
struct sit{int s, g, b; sit(){} sit(int x, int y, int z):s(x),g(y),b(z){}};
queue<sit> q;
// Prepends the arc s -> t with colour w to s's adjacency list.
inline void add(int s, int t, int w) {
    ++size;
    e[size]={t, head[s], w};
    head[s]=size;
}
// Reads the edges (converted to 0-based vertices), runs the DP, prints the total of
// all full-set states and terminates the program.
void solve() {
    memset(head, -1, sizeof(head));
    for (int i=1; i<=m; ++i) {
        const int u=read()-1;
        const int v=read()-1;
        const int w=read();
        add(u, v, w);
        add(v, u, w);
    }
    // Modular inverses followed by their prefix products.
    // NOTE(review): inv is not read again anywhere in this namespace.
    inv[0]=inv[1]=1;
    for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
    for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[i]%mod;
    const int full=(1<<n)-1;
    // Every single vertex is a starting state with the full budgets.
    for (int i=0; i<n; ++i) {
        const int start=1<<i;
        dp[start][g][b]=1;
        vis[start][g][b]=1;
        q.push(sit(start, g, b));
    }
    while (!q.empty()) {
        const sit cur=q.front();
        q.pop();
        if (cur.s==full) {
            ans=(ans+dp[cur.s][cur.g][cur.b])%mod;
            continue;
        }
        for (int i=0; i<n; ++i) {
            if (!(cur.s&(1<<i))) continue;
            for (int j=head[i]; j!=-1; j=e[j].next) {
                const int v=e[j].to;
                const int colour=e[j].val;
                if (cur.s&(1<<v)) continue;            // v is already connected
                if (colour==2 && cur.g==0) continue;   // colour-2 budget exhausted
                if (colour==3 && cur.b==0) continue;   // colour-3 budget exhausted
                const int ns=cur.s|(1<<v);
                const int ng=cur.g-(colour==2);
                const int nb=cur.b-(colour==3);
                md(dp[ns][ng][nb], dp[cur.s][cur.g][cur.b]);
                if (!vis[ns][ng][nb]) {
                    vis[ns][ng][nb]=1;
                    q.push(sit(ns, ng, nb));
                }
            }
        }
    }
    printf("%lld\n", ans);
    exit(0);
}
}
namespace task{
// Number of sample points currently stored in p[1..top].
int top;
// Running sum of the selected coefficients; printed by solve().
ll ans;
// Returns a^b modulo `mod` by binary exponentiation, always as a residue in [0, mod).
//   a - base; may be negative or larger than mod, it is reduced first so that the
//       result is canonical and a*a cannot overflow
//   b - exponent; expected to be non-negative (a negative exponent yields 1)
inline ll qpow(ll a, ll b) {
    a%=mod;
    if (a<0) a+=mod;
    ll ans=1;
    for (; b>0; b>>=1) {
        if (b&1) ans=ans*a%mod;
        a=a*a%mod;
    }
    return ans;
}
struct matrix{
int n, m;
ll a[110][110];
matrix(){}
matrix(int x, int y) {n=x; m=y; memset(a, 0, sizeof(a));}
inline void resize(int x, int y) {n=x; m=y; memset(a, 0, sizeof(a));}
inline ll* operator [] (int t) {return a[t];}
ll gauss() {
ll det=1;
for (int i=1; i<=n; ++i) {
for (int j=i+1; j<=n; ++j) {
while (a[j][i]) {
ll t=a[i][i]/a[j][i];
for (int k=i; k<=m; ++k) a[i][k]=((a[i][k]-a[j][k]*t)%mod+mod)%mod;
det=-det;
swap(a[i], a[j]);
}
}
}
for (int i=1; i<=n; ++i) det=det*a[i][i]%mod;
return det;
}
}mat;
struct point{ll x, y, z; point(){} point(ll a, ll b, ll c):x(a),y(b),z(c){}}p[N];
struct poly{
vector<vector<ll>> a;
poly(){}
poly(int x, int y) {a.clear(); a.resize(x, vector<ll>(y));}
inline void set(int x, int y) {a.clear(); a.resize(x, vector<ll>(y));}
inline int lenx() {return a.size();}
inline int leny() {return a.size()?a[0].size():0;}
inline vector<ll>& operator [] (int t) {return a[t];}
inline void put() {for (int i=0; i<lenx(); ++i) {for (int j=0; j<leny(); ++j) cout<<a[i][j]<<' '; cout<<endl;} cout<<endl;}
inline poly operator * (poly b) {
poly ans(lenx()+b.lenx()-1, leny()+b.leny()-1);
// cout<<"size: "<<lenx()+b.lenx()-1<<' '<<leny()+b.leny()-1<<endl;
for (int i=0; i<lenx(); ++i)
for (int j=0; j<leny(); ++j)
for (int k=0; k<b.lenx(); ++k)
for (int l=0; l<b.leny(); ++l)
ans[i+k][j+l]=(ans[i+k][j+l]+a[i][j]*b[k][l])%mod;
return ans;
}
inline poly operator + (poly b) {
poly ans(max(lenx(), b.lenx()), max(leny(), b.leny()));
for (int i=0; i<lenx(); ++i)
for (int j=0; j<leny(); ++j)
ans[i][j]=a[i][j];
for (int i=0; i<b.lenx(); ++i)
for (int j=0; j<b.leny(); ++j)
ans[i][j]=(ans[i][j]+b[i][j])%mod;
return ans;
}
ll qval(ll x, ll y) {
ll ans=0;
for (int i=0; i<lenx(); ++i)
for (int j=0; j<leny(); ++j)
ans=(ans+a[i][j]*qpow(x, i)%mod*qpow(y, j))%mod;
return (ans%mod+mod)%mod;
}
}r, f;
// Evaluates the weighted spanning-tree polynomial at (x, y).
// Each edge gets a weight by colour (colour 1 -> 1, colour 2 -> x, colour 3 -> y); the
// weighted Laplacian restricted to vertices 1..n-1 is built in the global `mat`, and
// its determinant modulo `mod` is returned.
// Edges with any other colour are ignored, as are self-loops (their Laplacian
// contribution cancels out).
ll calc(ll x, ll y) {
    const int dim=n-1;
    mat.resize(dim, dim);
    const ll wx=(x%mod+mod)%mod;
    const ll wy=(y%mod+mod)%mod;
    for (int i=1; i<=m; ++i) {
        ll val=0;
        switch (c[i]) {
            case 1: val=1; break;
            case 2: val=wx; break;
            case 3: val=wy; break;
            default: continue;   // unknown colour: skip instead of using an uninitialised weight
        }
        const int u=s[i], v=t[i];
        if (u==v) continue;
        // Only entries inside the (n-1) x (n-1) minor are stored; every entry is kept
        // reduced to [0, mod) so the determinant routine sees canonical residues.
        if (u<=dim) mat[u][u]=(mat[u][u]+val)%mod;
        if (v<=dim) mat[v][v]=(mat[v][v]+val)%mod;
        if (u<=dim && v<=dim) {
            mat[u][v]=(mat[u][v]-val+mod)%mod;
            mat[v][u]=(mat[v][u]-val+mod)%mod;
        }
    }
    return mat.gauss();
}
void lagrange() {
r.set(0, 0);
poly t;
for (int i=1; i<=top; ++i) {
f.set(1, 1); f[0][0]=p[i].z;
// cout<<f.lenx()<<' '<<f.leny()<<endl;
// cout<<"val: "<<f.qval(p[i].x, p[i].y)<<endl;
// assert(f.qval(p[i].x, p[i].y)==p[i].z);
for (int j=1; j<=top; ++j) {
if (p[i].x!=p[j].x&&p[i].y==p[j].y) {
// cout<<"ij: "<<i<<' '<<j<<' '<<p[i].x<<' '<<p[j].x<<endl;
ll inv=qpow(p[i].x-p[j].x, mod-2);
t.set(2, 1);
t[0][0]=-p[j].x*inv%mod;
t[1][0]=inv;
// assert(t.qval(p[j].x, p[j].y)==0);
// cout<<f.qval(p[i].x, p[i].y)<<endl;
// cout<<"f"<<endl; f.put(); cout<<"t"<<endl; t.put();
f=f*t;
// cout<<"f"<<endl;
// f.put(); cout<<endl;
// assert(f.qval(p[i].x, p[i].y)==p[i].z);
}
if (p[i].y!=p[j].y&&p[i].x==p[j].x) {
ll inv=qpow(p[i].y-p[j].y, mod-2);
t.set(1, 2);
t[0][0]=-p[j].y*inv%mod;
t[0][1]=inv;
// assert(f.qval(p[i].x, p[i].y)==p[i].z);
// assert(t.qval(p[i].x, p[i].y)==1);
f=f*t;
}
}
// f.put();
// assert(f.qval(p[i].x, p[i].y)==p[i].z);
r=r+f;
}
}
void solve() {
for (int i=1; i<=m; ++i) {s[i]=read(); t[i]=read(); c[i]=read();}
for (int i=1; i<=n+1; ++i) for (int j=1; j<=n+1; ++j) p[++top]={i, j, calc(i, j)}; //, cout<<"p: "<<i<<' '<<j<<' '<<calc(i, j)<<endl;
lagrange();
for (int i=0; i<=min(r.lenx()-1, g); ++i)
for (int j=0; j<=min(r.leny()-1, b); ++j)
ans=(ans+r[i][j])%mod;
// p[++top]={0, 1, 1}; p[++top]={0, 2, 3}; p[++top]={0, 3, 6};
// lagrange();
// cout<<r.qval(5, 0)<<endl;
printf("%lld\n", (ans%mod+mod)%mod);
exit(0);
}
}
// Entry point: reads n, m, g and b, then hands over to task::solve(), which prints the
// answer and exits. force::solve() is the brute-force alternative and is left disabled.
signed main()
{
    n=read();
    m=read();
    g=read();
    b=read();
    // force::solve();
    task::solve();
    return 0;
}