【UR #17】滑稽树前做游戏
tag:概率期望,多项式,积分
一个经典套路 \(E(f_{mx})=\sum_i P[f\ge i]\)(\(P\) 表示概率)
仔细思考了一下发现不太好求,转化成 \(E(G)=2-\sum_i[f(G)\le i]\)。
假设求出了 \(F(G,i)=P[f(G)\le i]\),那么 \(ans=2-\int_0^2F(G_0,i)di\)。(\(G_0\) 表示原图)
\(F(G,i)\) 还是不太好求,考虑增加第三维 \(F(G,i,j)\),表示 \(f(G)\le i,v_{mx}\le j\),即点权不超过 \(j\)。分两种情况:
-
所有点都不超过 \(\frac i2\),显然合法,概率为 \((\frac i2)^{|G|}\)
-
有一个点超过 \(\frac i2\),枚举超过的那个点 \(x\) 和它的点权 \(w\),然后删掉和它相连的点,递归下去。注意和它相连的点都不能超过 \(i-w\),所以要乘上 \((i-w)^{d_x}\),\(d_x\) 为 \(x\) 的度数。
注意到不可能有多于一个点超过 \(\frac i2\),所以考虑到了所有情况。
设 \(G_i\) 表示 \(G\) 删掉 \(i\) 和与之相连的点后的图,则:
不难发现 \(F(G)\) 是关于 \(i,j\) 的二元多项式。
进一步还能发现是关于 \(i,j\) 的齐次多项式。(可以归纳证明)
所以只需要维护一下这个多项式即可。
最后得到一个 \(F(G)\) 的答案多项式,由于要求的是 \(\int_0^2 F(G,\min(1,t),t)dt\)。
可以分为 \(\int_0^1F(G,t,t)dt+\int_1^2F(G,1,t)dt\)。
复杂度为 \(O(2^nn^3)\)。
关于多项式积分
多元函数对一个元求积分,就把其他的元当作系数看,而且此题是多项式的积分,所以很简单。
\(\int_a^bF(x)dx\) 就是对 \(F(x)\) 做不定积分,然后把 \(a,b\) 带进去的值相减。
inline void sig(poly &a){
int n = a.size(); a.resize(n+1);
for(int i=n; i; i--) a[i] = 1ll*a[i-1]*inv[i]%MOD;
a[0] = 0;
}
一个优化,如果图不连通,答案多项式等于所有连通块的答案多项式的乘积。然后就能过 \(n=25\) 了。
一些代码上的心得:
-
由于递归过程中都是对 \(j\) 积分,所以递归的时候令 \(a_x\) 表示 \(i^{n-x}j^x\) 的系数会比较方便。但是注意最后求答案是对 \(i\) 积分,所以要
reverse
一下。 -
此题暴力
map<pair<vector<int>,vector<pair<int,int>>>,int>
把整个图作为 key 值是能过的。 -
封几个多项式运算函数会很舒服(
当然后来注意到因为 \(n\le25\),而且边集无关,所以可以用 map<int,int>
#include<bits/stdc++.h>
using namespace std;
template<typename T>
inline void Read(T &n){
char ch; bool flag=false;
while(!isdigit(ch=getchar()))if(ch=='-')flag=true;
for(n=(ch^48);isdigit(ch=getchar());n=(n<<1)+(n<<3)+(ch^48));
if(flag)n=-n;
}
typedef vector<int> poly;
enum{
MAXN = 30,
MOD = 998244353
};
inline int inc(int a, int b){
a += b;
if(a>=MOD) a -= MOD;
return a;
}
inline int dec(int a, int b){
a -= b;
if(a<0) a += MOD;
return a;
}
inline int ksm(int base, int k=MOD-2){
int res=1;
while(k){
if(k&1) res = 1ll*res*base%MOD;
base = 1ll*base*base%MOD;
k >>= 1;
}
return res;
}
inline void iinc(int &a, int b){a = inc(a,b);}
inline void ddec(int &a, int b){a = dec(a,b);}
inline void upd(int &a, long long b){a = (a+b)%MOD;}
int inv[MAXN], C[MAXN][MAXN], pw2[MAXN];
inline poly pw(int k){
poly res(k+1);
for(int i=0; i<=k; i++) res[i] = (i&1)?dec(0,C[k][i]):C[k][i];
return res;
}
inline poly mul(poly a, poly b){
int n = a.size(), m = b.size();
poly res(n+m-1,0);
for(int i=0; i<n; i++) for(int j=0; j<m; j++) upd(res[i+j],1ll*a[i]*b[j]);
return res;
}
inline poly add(poly a, poly b){
int n = a.size(), m = b.size();
poly res(max(n,m));
for(int i=0; i<n; i++) res[i] = a[i];
for(int i=0; i<m; i++) iinc(res[i],b[i]);
return res;
}
inline void sig(poly &a){
int n = a.size(); a.resize(n+1);
for(int i=n; i; i--) a[i] = 1ll*a[i-1]*inv[i]%MOD; a[0] = 0;
}
inline void print(poly x){for(int i:x) printf("%d ",i);puts("");}
struct UFS{
int fa[MAXN];
int Find(int x){return fa[x]==x?x:fa[x]=Find(fa[x]);}
inline void merge(int u, int v){fa[Find(u)] = Find(v);}
}ufs[MAXN];
typedef pair<int,int> pii;
poly Split(vector<int>,vector<pii>);
int dep;
char vis[MAXN][MAXN];
poly Solve(vector<int>P, vector<pii>E){
static map<int,poly>mp;
int s=0; for(int i:P) s |= 1<<i;
if(mp.count(s)) return mp[s];
poly res(P.size()+1,0);
res[0] = pw2[P.size()];
if(!P.size()) return res;
for(int i:P){
for(int j:P) vis[dep][j] = true;
vis[dep][i] = false;
int cnt=0;
for(pii e:E){
if(e.first==i) vis[dep][e.second] = false, cnt++;
if(e.second==i) vis[dep][e.first] = false, cnt++;
}
vector<int>np; vector<pii>ne;
for(int j:P) if(vis[dep][j]) np.push_back(j);
for(pii e:E) if(vis[dep][e.first] and vis[dep][e.second]) ne.push_back(e);
poly tmp = mul(Split(np,ne),pw(cnt));
sig(tmp);
res = add(res,tmp);
for(int j=0, tp=tmp.size(); j<tp; j++) ddec(res[0],1ll*pw2[j]*tmp[j]%MOD);
}
for(int i:P) vis[dep][i] = false;
return mp[s]=res;
}
poly Split(vector<int>P, vector<pii>E){
static map<int,poly>mp;
int s=0; for(int i:P) s |= 1<<i-1;
if(mp.count(s)) return mp[s];
for(int i:P) ufs[dep].fa[i] = i;
for(pii e:E) ufs[dep].merge(e.first,e.second);
poly res(1,1);
for(int i:P) if(ufs[dep].Find(i)==i){
vector<int>np; vector<pii>ne;
for(int j:P) if(ufs[dep].Find(j)==i) np.push_back(j);
for(pii e:E) if(ufs[dep].Find(e.first)==i) ne.push_back(e);
dep++;
res = mul(res,Solve(np,ne));
dep--;
}
return mp[s]=res;
}
int main(){
// freopen("test.in","r",stdin);
// freopen("test.out","w",stdout);
int n, m;
for(int i=0; i<MAXN; i++){
C[i][0] = 1; inv[i] = ksm(i); pw2[i] = ksm(ksm(2),i);
for(int j=1; j<=i; j++) C[i][j] = inc(C[i-1][j-1],C[i-1][j]);
}
Read(n); Read(m);
vector<int>p; vector<pii>e;
for(int i=1; i<=n; i++) p.push_back(i);
for(int i=1; i<=m; i++){
int f, t;
Read(f); Read(t);
e.push_back(pii(f,t));
}
poly ans = Split(p,e); reverse(ans.begin(),ans.end());
int sum=0; for(int i:ans) iinc(sum,i);
sig(ans);
n = ans.size(); int res = 1ll*ksm(n-1)*sum%MOD;
for(int i=0, tmp=1; i<n; i++, iinc(tmp,tmp)) upd(res,1ll*tmp*ans[i]), ddec(res,ans[i]);
printf("%d\n",dec(2,res));
return 0;
}