题解 树拓扑序

传送门

考场上整场都在做这道题,却只想到了 \(O(n^5)\) 的做法;赛后一看,\(O(n^4)\) 可以直接写,而且这还是一道能用 \(O(n^2)\) 爆标的原题(

好了,原题
注意:先算出逆序对数的平均值(即期望),最后再乘上方案数,会方便很多;
否则在合并两棵子树的时候,并不容易计算其中一个根节点对总方案的贡献。

点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Large int sentinel (0x3f3f3f3f == 1061109567).
#define INF 0x3f3f3f3f
// Capacity of the fixed-size arrays below; they are indexed 1..n, so n must stay below N.
#define N 510
// Shorthands for the pair members and for push_back.
#define fir first
#define sec second
#define pb push_back
// Shorthand for the 64-bit integer type used in all modular arithmetic.
#define ll long long
// Disabled switch that would turn every `int` into `long long`.
//#define int long long

// Input buffer of the fast reader: stdin is consumed in chunks of up to 2 MiB.
// p1 is the read cursor and p2 the end of the valid data inside buf.
char buf[1<<21], *p1=buf, *p2=buf;
// Replacement for getchar(): refills buf via fread once it is exhausted and
// yields EOF when stdin has no more data. Bytes are returned as unsigned char
// values so that a 0xFF byte is never confused with EOF and every result is a
// valid argument for isdigit().
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<int>(static_cast<unsigned char>(*p1++)))
// Reads the next decimal integer from stdin.
// Skips everything up to the first digit; every '-' met while skipping flips
// the sign. Consumes the digits and the single character that follows them.
// Returns 0 when the input ends before any digit is found (the previous
// version spun forever in that case).
inline int read() {
	int ans=0, f=1;
	int c=getchar(); // int, not char, so that EOF stays distinguishable
	while (c!=EOF&&!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (c!=EOF&&isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

// Number of tree nodes, filled in by main().
int n;
// Input edges; main() stores them in e[1..n-1].
pair<int, int> e[N];
// Prime modulus for every answer, and the modular inverse of 2.
const ll mod=1e9+7, inv2=(mod+1)>>1;
// a = (a + b) mod `mod`. Both operands are expected to lie in [0, mod), so a
// single conditional subtraction is enough.
inline void md(ll& a, ll b) {a+=b; a=a>=mod?a-mod:a;}
// Returns a^b modulo `mod` by binary exponentiation; b must be non-negative.
// The base is first reduced into [0, mod): without that a large |a| overflows
// the 64-bit product a*a and a negative a leaks a negative result.
inline ll qpow(ll a, ll b) {
	a%=mod;
	if (a<0) a+=mod;
	ll ans=1;
	for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod;
	return ans;
}

namespace force{
	ll ans;
	int p[N], q[N];
	// Brute-force reference solver.
	// Walks through every permutation p of 1..n, keeps the ones in which
	// e[i].fir is not placed after e[i].sec for any input edge, and adds the
	// inversion count of each kept permutation to ans (modulo mod).
	// Prints ans followed by a newline.
	void solve() {
		for (int v=1; v<=n; ++v) p[v]=v;
		do {
			// q is the inverse of p: q[x] is the position at which x appears.
			for (int pos=1; pos<=n; ++pos) q[p[pos]]=pos;
			bool valid=true;
			for (int k=1; k<n&&valid; ++k) valid=!(q[e[k].fir]>q[e[k].sec]);
			if (!valid) continue; // jumps straight to next_permutation
			// Count pairs of positions whose values are out of order.
			int cnt=0;
			for (int hi=2; hi<=n; ++hi)
				for (int lo=1; lo<hi; ++lo)
					if (p[lo]>p[hi]) ++cnt;
			ans=(ans+cnt)%mod;
		} while (next_permutation(p+1, p+n+1));
		printf("%lld\n", ans);
	}
}

namespace task1{
	bool del[N];
	queue<int> q;
	bitset<N> s[N];
	int siz[N], cnt[N];
	vector<int> to[N], son[N];
	vector<pair<int, int>> sta;
	ll fac[N], inv[N], dp[N], ans;
	inline ll C(int n, int k) {return fac[n]*inv[k]%mod*inv[n-k]%mod;}
	void dfs(int u) {
		siz[u]=0; dp[u]=1;
		for (auto v:son[u]) dfs(v), dp[u]=dp[u]*C(siz[u]+=siz[v], siz[v])%mod*dp[v]%mod;
		++siz[u];
	}
	ll calc() {
		for (int i=1; i<=n; ++i) s[i].reset(), to[i].clear(), son[i].clear(), cnt[i]=0;
		for (auto it:sta) to[it.fir].pb(it.sec), ++cnt[it.sec];
		for (int i=1; i<=n; ++i) if (!cnt[i]) q.push(i);
		while (q.size()) {
			int u=q.front(); q.pop();
			s[u][u]=1;
			for (auto v:to[u]) {
				s[v]|=s[u];
				if (--cnt[v]==0) q.push(v);
			}
		}
		for (int i=1; i<=n; ++i) if (cnt[i]) return 0;
		int cnt=0;
		for (int u=1; u<=n; ++u) {
			for (int i=0; i<to[u].size(); ++i) del[i]=0;
			for (int i=0; i<to[u].size(); ++i)
				for (int j=0; j<to[u].size(); ++j)
					if (!del[i]&&!del[j]&&i!=j&&s[to[u][j]][to[u][i]])
						del[j]=1;
			for (int i=0; i<to[u].size(); ++i) if (!del[i]) son[to[u][i]].pb(u), ++cnt;
		}
		assert(cnt==n-1);
		// cout<<"cnt: "<<cnt<<endl;
		dfs(1);
		// cout<<"dp: "; for (int i=1; i<=n; ++i) cout<<dp[i]<<' '; cout<<endl;
		return dp[1];
	}
	void solve() {
		fac[0]=fac[1]=1; inv[0]=inv[1]=1;
		for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
		for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[i]%mod;
		for (int i=1; i<n; ++i) sta.pb(e[i]);
		for (int i=1; i<=n; ++i)
			for (int j=1; j<i; ++j) {
				sta.pb({i, j});
				// cout<<"ij: "<<i<<' '<<j<<endl;
				// cout<<"ij: "<<i<<' '<<j<<' '<<calc()<<endl;
				ans=(ans+calc())%mod;
				sta.pop_back();
			}
		printf("%lld\n", ans);
	}
}

namespace task2{
	int cnt[N];
	ll f[1<<20], g[1<<20], ans;
	vector<int> to[N];
	vector<pair<int, int>> sta;
	// Subset DP over 0-based vertices; the tables hold 2^20 entries, so this
	// solver is only usable for n <= 20.
	// For a set s of already placed vertices, g[s] is the number of ways to
	// reach s and f[s] the summed inversion count of those ways. A vertex may
	// be appended once no unplaced vertex still has an edge into it; appending
	// v adds one inversion per placed vertex with a larger index.
	// Prints f for the full set.
	void solve() {
		for (int k=1; k<n; ++k) to[e[k].fir-1].pb(e[k].sec-1);
		f[0]=0; g[0]=1;
		const int lim=1<<n;
		for (int s=0; s<lim; ++s) {
			// cnt[w] = number of unplaced vertices with an edge into w.
			for (int v=0; v<n; ++v) cnt[v]=0;
			for (int v=0; v<n; ++v) {
				if ((s>>v)&1) continue;
				for (int w:to[v]) ++cnt[w];
			}
			for (int v=0; v<n; ++v) {
				if (((s>>v)&1)||cnt[v]) continue;
				const int nxt=s|(1<<v);
				// Bit v of s is clear, so s>>v keeps exactly the placed indices above v.
				f[nxt]=(f[nxt]+f[s]+__builtin_popcount(s>>v)*g[s])%mod;
				g[nxt]=(g[nxt]+g[s])%mod;
			}
		}
		printf("%lld\n", f[lim-1]);
	}
}

namespace task{
	vector<int> to[N];
	int siz[N], fa[N];
	ll fac[N], inv[N], inv2[N], rec[N][N], g[N][N], cnt[N], ans;
	inline ll C(int n, int k) {return fac[n]*inv2[k]%mod*inv2[n-k]%mod;}
	void dfs(int u, int pa) {
		fa[u]=pa; siz[u]=cnt[u]=1;
		for (auto v:to[u]) {
			dfs(v, u);
			siz[u]+=siz[v];
			cnt[u]=cnt[u]*inv2[siz[v]]%mod*cnt[v]%mod;
			for (int i=1; i<=n; ++i) g[i][u]+=g[i][v];
		}
		cnt[u]=cnt[u]*fac[siz[u]-1]%mod;
		for (int i=1; i<=n; ++i) g[i][u]+=(i<u);
	}
	ll f(int u, int v) {
		if (u>v) swap(u, v);
		if (~rec[u][v]) return rec[u][v];
		ll *t=&rec[u][v]; *t=0;
		ll t1=g[v][u], t2=g[u][v];
		for (auto w:to[v]) t1=(t1+f(u, w))%mod;
		for (auto w:to[u]) t2=(t2+f(v, w))%mod;
		// cout<<"f: "<<u<<' '<<v<<' '<<*t<<endl;
		t1=t1*siz[v]%mod; t2=t2*siz[u]%mod;
		return *t=(t1+t2)*inv[siz[u]+siz[v]]%mod;
	}
	void solve() {
		fac[0]=fac[1]=1; inv[0]=inv[1]=1; inv2[0]=inv2[1]=1;
		for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
		for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		for (int i=2; i<=n; ++i) inv2[i]=inv2[i-1]*inv[i]%mod;
		for (int i=1; i<n; ++i) to[e[i].fir].pb(e[i].sec);
		dfs(1, 0);
		// cout<<"siz: "; for (int i=1; i<=n; ++i) cout<<siz[i]<<' '; cout<<endl;
		memset(rec, -1, sizeof(rec));
		for (int i=1; i<=n; ++i)
			for (int j=i+1; j<=n; ++j) if (fa[i]==fa[j])
				ans=(ans+f(i, j))%mod;
		// cout<<"cnt: "<<cnt[1]<<endl;
		for (int i=1; i<=n; ++i) ans=(ans+g[i][i])%mod;
		printf("%lld\n", ans*cnt[1]%mod);
	}
}

// Entry point: reads n and the n-1 edges into e[1..n-1], then runs the chosen
// solver. The other solvers (force::solve, task1::solve, task2::solve) can be
// called here instead of task::solve.
signed main()
{
	n=read();
	for (int i=1; i<n; ++i) {
		// Two separate statements keep the read order explicit.
		const int u=read();
		const int v=read();
		e[i].first=u;
		e[i].second=v;
	}
	task::solve();

	return 0;
}
posted @ 2022-05-25 08:13  Administrator-09  阅读(1)  评论(0)  编辑  收藏  举报