Endless Pallet(min-max容斥)
地址:传送门
分析:
设$x_i$表示第i个点被染成黑色的时间,所求即为$E(max \left \{x_i \right \})$
因为$E(X)=\sum_{k=1}^{\infty}i \times P(X=k)=\sum_{k=1}^{\infty}P(X\geqslant k)$,所以即求$\sum_{k=1}^{\infty}P(max\left \{ x_i \right \}\geqslant k)$
我们来考虑$P(max\left \{ x_i \right \}\geqslant k)$如何求
$P(max\left \{ x_i \right \}\geqslant k)=1-P(max\left \{ x_i \right \}< k)$
$P(max\left \{ x_i \right \}< k)=P(x_1<k,x_2<k,...,x_n<k)=\sum_{T \subseteq X}(-1)^{|T|}P(t_1\geqslant k,t_2 \geqslant k,...,t_{ |T|} \geqslant k )$
其中后面一部分是容斥原理
那么我们可以暴力枚举子集,算后面的那一坨的概率
假设固定了T,那么我们计算出从$\frac{n(n+1)}{2}$个染色方案中可以选取A个,使得没有包含T中任何一个元素
那么后面那一坨的概率就是$(\frac{2 \times A}{n(n+1)})^{k-1}$
我们发现,当我们枚举$k=1,2,...,\infty $的时候,这是一个等比数列,所以对答案的贡献是$\frac{1}{1-(\frac{2 \times A}{n(n+1)})}$
注意,当$A=\frac{n(n+1)}{2}$的时候,是不成立的,因为不是等比数列,这种情况下对于每一个k,这一项的贡献都是1,正好与$P(max\left \{ x_i \right \}\geqslant k)=1-P(max\left \{ x_i \right \}< k)$里面的1相抵消
于是我们得到了一个最终的式子:$\sum_{k=1}^{\infty}P(max\left \{ x_i \right \}\geqslant k)=\sum_{T \subseteq X,T\neq \varnothing }(-1)^{|T|+1}\frac{1}{1-(\frac{2 \times A_T}{n(n+1)})}$
其中$A_T$是子集T对应的A
我们需要想办法优化这个指数级别的算法,注意到对答案有影响的只有$A$和$|T|$,我们可以考虑用树形dp计算出有同样的A和|T|的有多少个集合
考虑一个子问题:给你一个树,有一些点是关键点,你需要统计出有多少条链不经过其中任何一个关键点
这个问题是可以树形dp解决的,状态需要保存以u为根的子树中,与u连通的非关键点形成的连通块的大小
于是对于我们这个问题,我们设$dp[i][j][A][0/1]$表示以i为根的子树,与i相连的没被选入点集的连通块大小为j,已经有了A条不经过点集中点的链,选取的点的数量的奇偶性是0/1情况下的点集有多少个
然后就可以对于dp出的值算贡献了
这个dp看似是$O(n^7)$的,但是它实际的复杂度是$O(n^5)$,原因与广外人知的树形背包的复杂度分析相同。(tip:不妨考虑我现在有n个数字1,每次需要把两个数字相加合成新的数字,代价是两个数字的乘积。问我最终合并成1个n,最少要花费多少代价。这个问题的答案是固定的,答案是$O(n^2)$级别的,这个问题等价于基于子树大小的树背包,所以树背包的复杂度是$O(n^2)$而不是$O(n^3)$)。
1 #include<bits/stdc++.h> 2 using namespace std; 3 const int maxn=50; 4 const int mod=998244353; 5 int dp[maxn+1][maxn+1][1300][2]; 6 int tmp[maxn+1][1300][2]; 7 int inv[maxn*maxn+5]; 8 int sz[maxn+5]; 9 vector<int> g[maxn+5]; 10 int n; 11 int mul(int a,int b) 12 { 13 return 1LL*a*b%mod; 14 } 15 int add(int a,int b) 16 { 17 return (a+b)%mod; 18 } 19 void merge(int a[maxn+1][1300][2],int b[maxn+1][1300][2],int n,int m) 20 { 21 for(int i=0;i<=n+m;++i) 22 for(int j=0;j<=(n+m)*(n+m+1)/2;++j) 23 tmp[i][j][0]=tmp[i][j][1]=0; 24 for(int i=0;i<=n;++i) 25 for(int j=0;j<=n*(n+1)/2;++j) 26 for(int k=0;k<2;++k) 27 if(a[i][j][k]) 28 for(int x=0;x<=m;++x) 29 for(int y=0;y<=m*(m+1)/2;++y) 30 for(int z=0;z<2;++z) 31 if(b[x][y][z]) 32 tmp[i?i+x:0][j+y+i*x][k^z]=add(tmp[i?i+x:0][j+y+i*x][k^z],mul(a[i][j][k],b[x][y][z])); 33 for(int i=0;i<=n+m;++i) 34 for(int j=0;j<=(n+m)*(n+m+1)/2;++j) 35 for(int k=0;k<2;++k) 36 a[i][j][k]=tmp[i][j][k]; 37 } 38 void dfs(int k,int fa) 39 { 40 sz[k]=1; 41 dp[k][1][1][0]=1; 42 dp[k][0][0][1]=1; 43 for(auto u:g[k]) 44 { 45 if(u==fa) continue; 46 dfs(u,k); 47 merge(dp[k],dp[u],sz[k],sz[u]); 48 sz[k]+=sz[u]; 49 } 50 } 51 int main() 52 { 53 inv[1]=1; 54 for(int i=2;i<=50*(50+1)/2;++i) inv[i]=mul(mod-mod/i,inv[mod%i]); 55 int T; 56 scanf("%d",&T); 57 for(int cas=1;cas<=T;++cas) 58 { 59 printf("Case #%d: ",cas); 60 scanf("%d",&n); 61 for(int i=0;i<=n;++i) g[i].clear(); 62 for(int i=0;i<=n;++i) 63 for(int j=0;j<=n;++j) 64 for(int k=0;k<=n*(n+1)/2;++k) 65 for(int l=0;l<2;++l) 66 dp[i][j][k][l]=0; 67 for(int i=1;i<n;++i) 68 { 69 int u,v; 70 scanf("%d%d",&u,&v); 71 g[u].push_back(v); 72 g[v].push_back(u); 73 } 74 dfs(1,0); 75 int ans=0; 76 int s=n*(n+1)/2; 77 for(int i=0;i<=n;++i) 78 for(int j=0;j<s;++j) 79 for(int k=0;k<2;++k) 80 { 81 int tmp=mul(s,inv[s-j]); 82 tmp=mul(tmp,dp[1][i][j][k]); 83 if(k==0) tmp=-tmp; 84 ans=add(ans,tmp); 85 } 86 if(ans<0) ans+=mod; 87 printf("%d\n",ans); 88 } 89 return 0; 90 }