Endless Pallet(min-max容斥)

地址:传送门

分析:

设$x_i$表示第i个点被染成黑色的时间,所求即为$E(max \left \{x_i  \right \})$

因为$E(X)=\sum_{k=1}^{\infty}i \times P(X=k)=\sum_{k=1}^{\infty}P(X\geqslant k)$,所以即求$\sum_{k=1}^{\infty}P(max\left \{ x_i \right \}\geqslant k)$

我们来考虑$P(max\left \{ x_i \right \}\geqslant k)$如何求

$P(max\left \{ x_i \right \}\geqslant k)=1-P(max\left \{ x_i \right \}<  k)$

$P(max\left \{ x_i \right \}<  k)=P(x_1<k,x_2<k,...,x_n<k)=\sum_{T \subseteq X}(-1)^{|T|}P(t_1\geqslant k,t_2 \geqslant k,...,t_{ |T|} \geqslant k )$

其中后面一部分是容斥原理

那么我们可以暴力枚举子集,算后面的那一坨的概率

假设固定了T,那么我们计算出从$\frac{n(n+1)}{2}$个染色方案中可以选取A个,使得没有包含T中任何一个元素

那么后面那一坨的概率就是$(\frac{2 \times A}{n(n+1)})^{k-1}$

我们发现,当我们枚举$k=1,2,...,\infty $的时候,这是一个等比数列,所以对答案的贡献是$\frac{1}{1-(\frac{2 \times A}{n(n+1)})}$

注意,当$A=\frac{n(n+1)}{2}$的时候,是不成立的,因为不是等比数列,这种情况下对于每一个k,这一项的贡献都是1,正好与$P(max\left \{ x_i \right \}\geqslant k)=1-P(max\left \{ x_i \right \}<  k)$里面的1相抵消

于是我们得到了一个最终的式子:$\sum_{k=1}^{\infty}P(max\left \{ x_i \right \}\geqslant k)=\sum_{T \subseteq X,T\neq \varnothing }(-1)^{|T|+1}\frac{1}{1-(\frac{2 \times A_T}{n(n+1)})}$

其中$A_T$是子集T对应的A

我们需要想办法优化这个指数级别的算法,注意到对答案有影响的只有$A$和$|T|$,我们可以考虑用树形dp计算出有同样的A和|T|的有多少个集合

考虑一个子问题:给你一个树,有一些点是关键点,你需要统计出有多少条链不经过其中任何一个关键点

这个问题是可以树形dp解决的,状态需要保存以u为根的子树中,与u连通的非关键点形成的连通块的大小

于是对于我们这个问题,我们设$dp[i][j][A][0/1]$表示以i为根的子树,与i相连的没被选入点集的连通块大小为j,已经有了A条不经过点集中点的链,选取的点的数量的奇偶性是0/1情况下的点集有多少个

然后就可以对于dp出的值算贡献了

这个dp看似是$O(n^7)$的,但是它实际的复杂度是$O(n^5)$,原因与广外人知的树形背包的复杂度分析相同。(tip:不妨考虑我现在有n个数字1,每次需要把两个数字相加合成新的数字,代价是两个数字的乘积。问我最终合并成1个n,最少要花费多少代价。这个问题的答案是固定的,答案是$O(n^2)$级别的,这个问题等价于基于子树大小的树背包,所以树背包的复杂度是$O(n^2)$而不是$O(n^3)$)。

 

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 const int maxn=50;
 4 const int mod=998244353;
 5 int dp[maxn+1][maxn+1][1300][2];
 6 int tmp[maxn+1][1300][2];
 7 int inv[maxn*maxn+5];
 8 int sz[maxn+5];
 9 vector<int> g[maxn+5];
10 int n;
11 int mul(int a,int b)
12 {
13     return 1LL*a*b%mod;
14 }
15 int add(int a,int b)
16 {
17     return (a+b)%mod;
18 }
19 void merge(int a[maxn+1][1300][2],int b[maxn+1][1300][2],int n,int m)
20 {
21     for(int i=0;i<=n+m;++i)
22         for(int j=0;j<=(n+m)*(n+m+1)/2;++j)
23             tmp[i][j][0]=tmp[i][j][1]=0;
24     for(int i=0;i<=n;++i)
25         for(int j=0;j<=n*(n+1)/2;++j)
26             for(int k=0;k<2;++k)
27                 if(a[i][j][k])
28                     for(int x=0;x<=m;++x)
29                         for(int y=0;y<=m*(m+1)/2;++y)
30                             for(int z=0;z<2;++z)
31                                 if(b[x][y][z])
32                                     tmp[i?i+x:0][j+y+i*x][k^z]=add(tmp[i?i+x:0][j+y+i*x][k^z],mul(a[i][j][k],b[x][y][z]));
33     for(int i=0;i<=n+m;++i)
34         for(int j=0;j<=(n+m)*(n+m+1)/2;++j)
35             for(int k=0;k<2;++k)
36                 a[i][j][k]=tmp[i][j][k];
37 }
38 void dfs(int k,int fa)
39 {
40     sz[k]=1;
41     dp[k][1][1][0]=1;
42     dp[k][0][0][1]=1;
43     for(auto u:g[k])
44     {
45         if(u==fa) continue;
46         dfs(u,k);
47         merge(dp[k],dp[u],sz[k],sz[u]);
48         sz[k]+=sz[u];
49     }
50 }
51 int main()
52 {
53     inv[1]=1;
54     for(int i=2;i<=50*(50+1)/2;++i) inv[i]=mul(mod-mod/i,inv[mod%i]);
55     int T;
56     scanf("%d",&T);
57     for(int cas=1;cas<=T;++cas)
58     {
59         printf("Case #%d: ",cas);
60         scanf("%d",&n);
61         for(int i=0;i<=n;++i) g[i].clear();
62         for(int i=0;i<=n;++i)
63             for(int j=0;j<=n;++j)
64                 for(int k=0;k<=n*(n+1)/2;++k)
65                     for(int l=0;l<2;++l)
66                         dp[i][j][k][l]=0;
67         for(int i=1;i<n;++i)
68         {
69             int u,v;
70             scanf("%d%d",&u,&v);
71             g[u].push_back(v);
72             g[v].push_back(u);
73         }
74         dfs(1,0);
75         int ans=0;
76         int s=n*(n+1)/2;
77         for(int i=0;i<=n;++i)
78             for(int j=0;j<s;++j)
79                 for(int k=0;k<2;++k)
80                 {
81                     int tmp=mul(s,inv[s-j]);
82                     tmp=mul(tmp,dp[1][i][j][k]);
83                     if(k==0) tmp=-tmp;
84                     ans=add(ans,tmp);
85                 }
86         if(ans<0) ans+=mod;
87         printf("%d\n",ans);
88     }
89     return 0;
90 }
View Code

 

posted @ 2018-08-22 00:39  Chellyutaha  阅读(341)  评论(0编辑  收藏  举报