题解 星图

传送门

赛时折腾一个能优化的 \(O(n^8)\) DP 假了
听说状态数到 \(O(n^{11})\) 才能写?

赛时想法是枚举一个点,再枚举一个区间作为其子树
每个点在其所有祖先统计一次答案
但是这样是三个变量
有一种两个变量的统计方法:枚举两个点,钦定这两个点之间的点权比这两个点都小,再给这两个点钦定一个大小关系
那么可以枚举最高点所在的段,枚举最高点的位置,按顺序枚举另一个点,用背包的方式计算有 \(k\) 个点在当前段内的概率
这样是 \(O(n^4)\)
然后发现只需要统计答案,所以可以整体 DP,同时对每个最高点进行 DP
具体的,若枚举了最高点 \(i\),则另一点 \(j>i\) 在每个 \(i\) 背包时都需要加入一遍
那么在每个 \(i\) 处加入 \(i\) 的初始状态,把所有这 \(n\) 个背包一起做即可
复杂度 \(O(n^3)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 610
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n;
const ll mod=998244353;
int l[N], r[N], uni[N], usiz;
ll f[N], inv[N], invp[N], ans;
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}

ll down(int i, int w) {
	if (r[i]<=w) return 1;
	else if (l[i]>=w) return 0;
	else return (uni[w]-uni[l[i]])*invp[i]%mod;
}

ll in(int i, int w) {
	return (down(i, w+1)-down(i, w))%mod;
}

void add(int i, int w) {
	for (int j=n-1; ~j; --j) if (f[j]) {
		f[j+1]=(f[j+1]+f[j]*in(i, w))%mod;
		f[j]=(f[j]*down(i, w))%mod;
	}
}

signed main()
{
	freopen("starmap.in", "r", stdin);
	freopen("starmap.out", "w", stdout);

	n=read();
	for (int i=1; i<=n; ++i) {
		uni[++usiz]=(l[i]=read());
		uni[++usiz]=(r[i]=read());
	}
	inv[0]=inv[1]=1;
	for (int i=1; i<=n; ++i) invp[i]=qpow(r[i]-l[i], mod-2);
	for (int i=2; i<=n+1; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
	sort(uni+1, uni+usiz+1);
	usiz=unique(uni+1, uni+usiz+1)-uni-1;
	for (int i=1; i<=n; ++i) {
		l[i]=lower_bound(uni+1, uni+usiz+1, l[i])-uni;
		r[i]=lower_bound(uni+1, uni+usiz+1, r[i])-uni;
	}
	for (int w=1; w<usiz; ++w) {
		// cout<<"w: "<<w<<endl;
		memset(f, 0, sizeof(f));
		for (int i=1; i<n; ++i) {
			f[0]=(f[0]+in(i, w))%mod, add(i+1, w);
			for (int i=0; i<=n; ++i) ans=(ans+f[i]*inv[i+1])%mod;
		}
		memset(f, 0, sizeof(f));
		for (int i=n; i>1; --i) {
			f[0]=(f[0]+in(i, w))%mod, add(i-1, w);
			for (int i=0; i<=n; ++i) ans=(ans+f[i]*inv[i+1])%mod;
		}
	}
	printf("%lld\n", (ans%mod+mod)%mod);

	return 0;
}
posted @ 2022-07-16 19:31  Administrator-09  阅读(3)  评论(0编辑  收藏  举报