题解 young

传送门

挺神仙的题,边界极其恶心

  • 关于求所有情况中最小值的和:考虑转化为统计最小值 \(\geqslant i\) 的方案数,并对 \(i\leqslant n\) 求和
  • 关于图上对权值和的期望一类问题的处理雷点:
    一定要想好要不要带标号统计!除非不带标号的处理思路已经非常清晰否则还是尽量带上标号吧
  • 关于一类点的权值为 \([0, 2^m-1]\) 间的随机值的期望问题的切入点:
    考虑能否逐位确定,不止序列,甚至点集也是可以逐位确定的
    具体地,对点集进行逐位确定时考虑对每一位枚举点集中的哪些点在这一位上是1,依此划分为两个点集递归处理

image

用的是逐位确定的方法
先用期望的线性性将每条边的贡献拆出来考虑
于是枚举这条边两侧的点集
\(2^{(n-i)(m-1)}f(i, m-1)+2^{i(m-1)}f(n-i, m-1)\) 是在统计两边的点集的贡献
\(2^{m-1}\times 2^{n(m-1)}\) 是在统计第 \(m\) 位的贡献,前一部分是贡献的值,后一部分是方案数

然后重点是这个 \(p\) 函数
首先考虑一个边界:当 \(m<0\) 时,代表 \(k\) 已被完全满足,应返回1
然后注意这个 \(p\) 实际上是在满足 \(k\) 限制的前提下给点赋值的方案数
所以当某一个点集为空的时候相当于无限制(其实是个递归边界,贡献给向上到有限制的那一层),返回的是 \(2^{|T|}\),当然还要乘上这一位的贡献
然后考虑 \(k\) 这一位上是0还是1
若为1,显然 \(S\)\(T\) 这一位上完全不同,直接向下递归即可
若为0,那么枚举 \(S_0\)\(T_0\) 的大小,先钦定这一位上的最小值为0
注意到 \(k\) 的限制是作用于 \((S_0, T_0)\)\((S_1, T_1)\) 之间的
\(p\) 函数又是给点赋值的方案数
所以 \(p(i, j, m-1, k)\)\(p(s-i, t-j, m-1, k)\) 之间完全独立,是相乘关系
注意这里枚举 \(S_0,T_0\) 时不能有 \(S_0=S,T_1=T\)\(S_1=S,T_0=T\),否则就是上面那种情况了
再考虑这种情况下这一位的贡献
刚才钦定这一位上的最小值为0是不产生贡献的
现在钦定为1,因为是统计最小值 \(\geqslant k\) 的方案数所以后面 \(m-1\) 位都可以随便选了
于是产生 \(2^{m(|S|+|T|)}\) 的贡献
然后钦定了这一位是1那 \(S\)\(T\) 中有一个这一位都是1
选哪一个都行,所以再乘个2
复杂度大概是 \(O(n^4m2^m)\) 但记搜一下跑的挺快的

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 100010
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, m;
const ll mod=258280327;
inline ll qpow(ll a, int b) {ll ans=1; if (b<0) return 0; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}

namespace force{
	int val[N], dsu[N], siz;
	ll pre, ans;
	struct edge{int s, t, val;}e[N];
	inline bool operator < (edge a, edge b) {return a.val<b.val;}
	inline int find(int p) {return dsu[p]==p?p:find(dsu[p]);}
	void dfs(int u) {
		if (u>n) {
			siz=0;
			for (int i=1; i<=n; ++i)
				for (int j=i+1; j<=n; ++j)
					e[++siz]={i, j, val[i]^val[j]};
			sort(e+1, e+siz+1);
			for (int i=1; i<=n; ++i) dsu[i]=i;
			ll cnt=0, sum=0;
			for (int i=1,f1,f2; i<=siz&&cnt<n; ++i) {
				f1=find(e[i].s), f2=find(e[i].t);
				if (f1==f2) continue;
				dsu[f1]=f2;
				++cnt;
				sum+=e[i].val;
			}
			ans=(ans+sum*pre)%mod;
			return ;
		}
		int lim=1<<m;
		for (int i=0; i<lim; ++i) {
			val[u]=i;
			dfs(u+1);
		}
	}
	void solve() {
		pre=qpow(qpow(2, m*n), mod-2);
		// cout<<"pre: "<<pre<<endl;
		dfs(1);
		printf("%lld\n", ans);
		exit(0);
	}
}

namespace task{
	ll fac[N], inv[N], F[51][9], G[51][51][9], P[51][51][9][1<<8], ans;
	inline ll C(int n, int k) {return fac[n]*inv[k]%mod*inv[n-k]%mod;}
	ll p(int s, int t, int m, int k) {
		// cout<<"p: "<<s<<' '<<t<<' '<<m<<' '<<k<<endl;
		if (s>t) swap(s, t);
		if (m<0) {assert(!k); return 1;}
		if (!s) return qpow(2, t*(m+1));
		if (~P[s][t][m][k]) return P[s][t][m][k];
		ll ans=0;
		if (k&(1<<m)) ans=2*p(s, t, m-1, k^(1<<m))%mod;
		else {
			// cout<<"live: "<<endl;
			for (int i=0; i<=s; ++i)
				for (int j=0; j<=t; ++j) if ((i+t-j)*(j+s-i))
					ans=(ans+C(s, i)*C(t, j)%mod*p(i, j, m-1, k)%mod*p(s-i, t-j, m-1, k))%mod;
			ans=(ans+qpow(2, (s+t)*m+1))%mod;
		}
		return P[s][t][m][k]=ans;
	}
	ll g(int s, int t, int m) {
		// cout<<"g: "<<s<<' '<<t<<' '<<m<<endl;
		if (s>t) swap(s, t);
		if (!s) return 0;
		if (~G[s][t][m]) return G[s][t][m];
		ll ans=0;
		for (int i=1; i<(1<<m); ++i) ans=(ans+p(s, t, m-1, i))%mod;
		// cout<<"return: "<<ans<<endl;
		return G[s][t][m]=ans;
	}
	ll f(int n, int m) {
		// cout<<"f: "<<n<<' '<<m<<endl;
		if (!n||!m) return 0;
		if (~F[n][m]) return F[n][m];
		ll ans=0;
		for (int i=0; i<=n; ++i)
			ans=(ans+C(n, i)*((qpow(2, (n-i)*(m-1))*f(i, m-1)%mod+qpow(2, i*(m-1))*f(n-i, m-1)%mod+g(i, n-i, m-1)+(i&&n!=i?qpow(2, (m-1)*(n+1)):0))%mod))%mod;
		return F[n][m]=ans;
	}
	void solve() {
		memset(F, -1, sizeof(F));
		memset(G, -1, sizeof(G));
		memset(P, -1, sizeof(P));
		fac[0]=fac[1]=1; inv[0]=inv[1]=1;
		for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
		for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[i]%mod;
		printf("%lld\n", f(n, m)*qpow(qpow(2, n*m), mod-2)%mod);
		// cout<<f(n, m)<<endl;
		exit(0);
	}
}

signed main()
{
	n=read(); m=read();
	// force::solve();
	task::solve();
	
	return 0;
}
posted @ 2022-01-03 21:37  Administrator-09  阅读(0)  评论(0编辑  收藏  举报