题解 心理阴影

传送门

确实写出心理阴影了

发现没有祖孙关系的两个点的子树之间互不干扰,于是尝试对这个东西进行 DP
转移的话发现形成的序列中第一个点一定是两个子树的根节点之一
枚举两种情况,发现变成了一个子树和一堆子树,就变成子问题了
先令 \(g_{u, v}\) 表示点 \(v\) 的子树内比 \(u\) 小的点的个数
来康康转移细节
发现如果我定义的 \(f_{u, v}\) 考虑了 \(u, v\) 子树内部的顺序关系的话是没办法转化子问题的(会算重)
(我因为这样定义调了一下午)
于是定义 \(f_{u, v}\) 为在不考虑 \(u, v\) 子树内部顺序的前提下合并两子树的逆序对数平均数
不考虑内部顺序的意思是我不需要知道现在 \(u, v\) 子树内到底是怎么排的
然后考虑转移系数:
考虑合并 \(u, v\) 子树,现在钦定 \(u\) 是拓扑序中的第一个,那么

\[逆序对平均数=\frac{合并 v 和 u 的儿子的逆序对平均数*合并 v 和 u 的儿子的合法方案数}{合并 u 和 v 子树的总方案数} \]

所以系数就是

\[\begin{aligned}\frac{令u在拓扑序第一个的合法方案数}{合法方案总数}&=\dfrac{\dbinom{siz_u-1+siz_v}{siz_u-1}}{\dbinom{siz_u+siz_v}{siz_u}}\end{aligned} \]

系数能先化简一下:\(\frac{(siz_u-1+sizv)!}{(siz_u-1)!siz_v!}*\frac{siz_u!siz_v!}{(siz_u+siz_v)!}=\frac{siz_u}{siz_u+siz_v}\)
于是有image
枚举兄弟合并,加上每个根节点对其子树的贡献即可

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 5010
#define ll long long
#define pb push_back
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, r;
int head[N], size;
vector<int> to[N];
const ll mod=1e9+7;
struct edge{int to, next;}e[N<<1];
inline void add(int s, int t) {e[++size]={t, head[s]}; head[s]=size;}
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
inline ll qinv(ll a) {return qpow(a, mod-2);}

namespace force{
	ll ans;
	int p[N], pos[N];
	bool dfs(int u, int fa) {
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			if (pos[u]>pos[v]) return 0;
			if (!dfs(v, u)) return 0;
		}
		return 1;
	}
	void solve() {
		for (int i=1; i<=n; ++i) p[i]=i;
		int cnt=0;
		do {
			for (int i=1; i<=n; ++i) pos[p[i]]=i;
			if (dfs(r, 0)) {
				// cout<<"p: "; for (int i=1; i<=n; ++i) cout<<p[i]<<' '; cout<<endl;
				int tem=0;
				for (int i=1; i<=n; ++i) 
					for (int j=i+1; j<=n; ++j)
						if (p[i]>p[j]) ++tem;
				ans=(ans+tem)%mod; ++cnt;
			}
		} while (next_permutation(p+1, p+n+1));
		printf("%lld\n", ans*qpow(cnt, mod-2)%mod);
	}
}

namespace task1{
	int bit[N], siz[N], pa[N];
	ll fac[N], inv[N], F[N][N], g[N][N], q[N], k[N][N], ans;
	inline void upd(int i) {for (; i<=n; i+=i&-i) ++bit[i];}
	inline int query(int i) {int ans=0; for (; i; i-=i&-i) ans+=bit[i]; return ans;}
	void dfs(int u, int fa) {
		for (int i=1; i<=n; ++i) g[i][u]-=query(i-1);
		upd(u); siz[u]=1;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			pa[v]=u;
			dfs(v, u);
			siz[u]+=siz[v];
			for (int j=1; j<=n; ++j) k[u][j]=(k[u][j]+k[v][j])%mod;
		}
		for (int i=1; i<=n; ++i) g[i][u]+=query(i-1);
		for (int i=1; i<=n; ++i) k[u][i]=(k[u][i]+g[u][i])%mod;
	}
	void dfs2(int u, int fa) {
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			dfs2(v, u);
			for (int j=1; j<=n; ++j) k[u][j]=(k[u][j]+k[v][j])%mod;
		}
		for (int i=1; i<=n; ++i) k[u][i]=(k[u][i]+g[u][i])%mod;
	}
	ll f(int , int ) ;
	ll calc(int r, int fa) {
		ll ans=g[r][r];
		for (int i=0; i<to[r].size(); ++i)
			for (int j=0; j<to[r].size(); ++j) if (j!=i) {
				int u=to[r][i], v=to[r][j];
				if (u!=fa&&v!=fa) ans=(ans+k[u][v])%mod;
			}
		return ans;
	}
	ll f(int u, int v) {
		cout<<"f: "<<u<<' '<<v<<endl;
		if (u>v) swap(u, v);
		if (~F[u][v]) return F[u][v];
		ll* t=&F[u][v]; *t=0;
		ll t1=g[u][v], t2=g[v][u];
		for (int i=head[u],w; ~i; i=e[i].next) {
			w = e[i].to;
			if (w!=pa[u]) t1=(t1+f(w, v))%mod;
		}
		for (int i=head[v],w; ~i; i=e[i].next) {
			w = e[i].to;
			if (w!=pa[v]) t2=(t2+f(w, u))%mod;
		}
		if (u==3&&v==5) cout<<t1<<' '<<t2<<endl;
		t1=t1*siz[u]%mod; t2=t2*siz[v]%mod;
		if (u==3&&v==5) cout<<t1<<' '<<t2<<endl;
		cout<<"f: "<<u<<' '<<v<<' '<<(t1+t2)*qinv(siz[u]+siz[v])%mod<<endl;
		// ll tem=fac[siz[u]+siz[v]]*inv[siz[u]]%mod*inv[siz[v]]%mod;
		return *t=((t1+t2)*qinv(siz[u]+siz[v])+(calc(u, pa[u])+calc(v, pa[v])))%mod;
	}
	void solve() {
		dfs(r, 0); dfs2(r, 0);
		memset(F, -1, sizeof(F));
		fac[0]=fac[1]=1; inv[0]=inv[1]=1;
		for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
		for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[i]%mod;
		// cout<<"g"<<endl; for (int i=1; i<=n; ++i) {for (int j=1; j<=n; ++j) cout<<g[i][j]<<' '; cout<<endl;}
		while (to[r].size()-(pa[r]>0)==1) {
			ans=(ans+g[r][r]);
			r=to[r][0]==pa[r]?to[r][1]:to[r][0];
		}
		// cout<<"r: "<<r<<endl;
		// for (int i=0; i<to[r].size(); ++i)
		// 	for (int j=i+1; j<to[r].size(); ++j) {
		// 		int u=to[r][i], v=to[r][j];
		// 		if (u==pa[r]||v==pa[r]) continue;
		// 		ans=(ans+f(u, v))%mod;
		// 	}
		printf("%lld\n", (ans+calc(r, pa[r]))%mod);
		// cout<<f(2, 5)<<endl;
		// cout<<calc(5, pa[5])<<endl;
		// cout<<calc(2, pa[2])<<endl;
		cout<<f(3, 5)<<endl;
		cout<<f(2, 5)<<endl;
		// cout<<k[5][3]<<endl;
		cout<<calc(3, pa[3])<<endl;
		cout<<9*qinv(4)%mod<<endl;
		cout<<3*qinv(2)%mod<<endl;
	}
}

namespace task{
	int bit[N], siz[N], pa[N];
	ll fac[N], inv[N], F[N][N], g[N][N], q[N], k[N][N], ans;
	inline void upd(int i) {for (; i<=n; i+=i&-i) ++bit[i];}
	inline int query(int i) {int ans=0; for (; i; i-=i&-i) ans+=bit[i]; return ans;}
	void dfs(int u, int fa) {
		siz[u]=1;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			pa[v]=u;
			dfs(v, u);
			siz[u]+=siz[v];
			for (int j=1; j<=n; ++j) g[u][j]+=g[v][j];
		}
		for (int i=1; i<=n; ++i) g[u][i]+=(i>u);
	}
	ll f(int u, int v) {
		if (u>v) swap(u, v);
		if (~F[u][v]) return F[u][v];
		ll* t=&F[u][v]; *t=0;
		ll t1=g[v][u], t2=g[u][v];
		for (int i=head[u],w; ~i; i=e[i].next) {
			w = e[i].to;
			if (w!=pa[u]) t1=(t1+f(w, v))%mod;
		}
		for (int i=head[v],w; ~i; i=e[i].next) {
			w = e[i].to;
			if (w!=pa[v]) t2=(t2+f(w, u))%mod;
		}
		t1=t1*siz[u]%mod; t2=t2*siz[v]%mod;
		return *t=(t1+t2)*inv[siz[u]+siz[v]]%mod;
	}
	void solve() {
		dfs(r, 0);
		memset(F, -1, sizeof(F));
		fac[0]=fac[1]=1; inv[0]=inv[1]=1;
		for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
		for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		// cout<<"g"<<endl; for (int i=1; i<=n; ++i) {for (int j=1; j<=n; ++j) cout<<g[i][j]<<' '; cout<<endl;}
		for (int i=1; i<=n; ++i)
			for (int j=i+1; j<=n; ++j)
				if (pa[i]==pa[j])
					ans=(ans+f(i, j))%mod;
		for (int i=1; i<=n; ++i) ans=(ans+g[i][i])%mod;
		// cout<<f(2, 5)<<endl;
		printf("%lld\n", ans);
	}
}

signed main()
{
	freopen("nightmare.in", "r", stdin);
	freopen("nightmare.out", "w", stdout);

	n=read(); r=read();
	memset(head, -1, sizeof(head));
	for (int i=1,u,v; i<n; ++i) {
		u=read(); v=read();
		add(u, v); add(v, u);
		to[u].pb(v); to[v].pb(u);
	}
	// force::solve();
	task::solve();

	return 0;
}
posted @ 2022-01-20 10:49  Administrator-09  阅读(3)  评论(0编辑  收藏  举报