题解 星图
赛时折腾一个能优化的 \(O(n^8)\) DP 假了
听说状态数到 \(O(n^{11})\) 才能写?
赛时想法是枚举一个点,再枚举一个区间作为其子树
每个点在其所有祖先统计一次答案
但是这样是三个变量
有一种两个变量的统计方法:枚举两个点,钦定这两个点之间的点权比这两个点都小,再给这两个点钦定一个大小关系
那么可以枚举最高点所在的段,枚举最高点的位置,按顺序枚举另一个点,用背包的方式计算有 \(k\) 个点在当前段内的概率
这样是 \(O(n^4)\) 的
然后发现只需要统计答案,所以可以整体 DP,同时对每个最高点进行 DP
具体的,若枚举了最高点 \(i\),则另一点 \(j>i\) 在每个 \(i\) 背包时都需要加入一遍
那么在每个 \(i\) 处加入 \(i\) 的初始状态,把所有这 \(n\) 个背包一起做即可
复杂度 \(O(n^3)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 610
#define ll long long
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n;
const ll mod=998244353;
int l[N], r[N], uni[N], usiz;
ll f[N], inv[N], invp[N], ans;
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
ll down(int i, int w) {
if (r[i]<=w) return 1;
else if (l[i]>=w) return 0;
else return (uni[w]-uni[l[i]])*invp[i]%mod;
}
ll in(int i, int w) {
return (down(i, w+1)-down(i, w))%mod;
}
void add(int i, int w) {
for (int j=n-1; ~j; --j) if (f[j]) {
f[j+1]=(f[j+1]+f[j]*in(i, w))%mod;
f[j]=(f[j]*down(i, w))%mod;
}
}
signed main()
{
freopen("starmap.in", "r", stdin);
freopen("starmap.out", "w", stdout);
n=read();
for (int i=1; i<=n; ++i) {
uni[++usiz]=(l[i]=read());
uni[++usiz]=(r[i]=read());
}
inv[0]=inv[1]=1;
for (int i=1; i<=n; ++i) invp[i]=qpow(r[i]-l[i], mod-2);
for (int i=2; i<=n+1; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
sort(uni+1, uni+usiz+1);
usiz=unique(uni+1, uni+usiz+1)-uni-1;
for (int i=1; i<=n; ++i) {
l[i]=lower_bound(uni+1, uni+usiz+1, l[i])-uni;
r[i]=lower_bound(uni+1, uni+usiz+1, r[i])-uni;
}
for (int w=1; w<usiz; ++w) {
// cout<<"w: "<<w<<endl;
memset(f, 0, sizeof(f));
for (int i=1; i<n; ++i) {
f[0]=(f[0]+in(i, w))%mod, add(i+1, w);
for (int i=0; i<=n; ++i) ans=(ans+f[i]*inv[i+1])%mod;
}
memset(f, 0, sizeof(f));
for (int i=n; i>1; --i) {
f[0]=(f[0]+in(i, w))%mod, add(i-1, w);
for (int i=0; i<=n; ++i) ans=(ans+f[i]*inv[i+1])%mod;
}
}
printf("%lld\n", (ans%mod+mod)%mod);
return 0;
}