集合划分状压dp

 

给一个 $n$ 个点 $m$ 条边的无向图,每条边有 $p_i$ 的概率消失,求图连通的概率

$n \leq 9$

sol:

我们考虑一个 $dp$ 

$f_{(i,S)}$ 表示只考虑前 $i$ 条边,当前图连通的状态为 $S$ 的概率

设这条边没有消失,图的新连通状态为 $T$

那转移到 $T$ 的概率就是 $(1 - p_i)$

不变的概率是 $p_i$

然后一个滚动数组就做完了

 

然后我们考虑,怎么把“图的连通状态”这个东西状压出来

一个 idea 是,我们可以在状态里记录每个点所处的连通块里最小的点的编号,比如 123 是一个连通块,45 是一个连通块,我们这个状态就是 12344 

这样的话,因为每个点所处连通块里最小的点的编号不超过它的编号,状态数是 $O(n!)$ 的,但如果我们直接存一个 9 位数,数组显然开不下

1.我们可以哈希一下,把代码写成这样

int h;
map<vector<int>,int> hsh;
map<int,vector<int> > reh;
inline int gethsh(vector<int> v)
{
    if(!hsh[v])
    {
        hsh[v] = ++h;
        reh[h] = v;
    }
    return hsh[v];
}
inline vector<int> getreh(int h){return reh[h];}

2.或者我们可以动动脑子

显然,这个状态的最后一位数只有 $n$ 种情况

然后,倒数第二位只有 $n - 1$ 种情况(不能比最后一位大)

然后,倒数第三位只有 $n - 2$ 种情况

...

然后,最高位只有 $1$ 种情况($1$ 只能属于 $1$)

于是我们把最后一位数乘以 $n$ ,倒数第二位乘以 $(n-1) \times n$ ,倒数第三位乘以 $(n-2) \times (n-1) \times n$ ... 最高位乘以 $n!$

就把状态压到了 $O(n!)$ 个数里

如果想知道一个值对应的状态是什么,从高到低除再取余一下就可以了

于是可以 dp 了

 

学习了 yyc 同学的写法,预处理出了前 n 位所有状态

$state_(i,j)$ 表示第 $i$ 个方案的第 $j$ 位是什么,也就是 $j$ 点在第 $i$ 种状态里所属的连通块的编号最小的点

frustrated好强呀

%%%

#include<bits/stdc++.h>
#define LL long long
using namespace std;
inline int read()
{
    int x = 0,f = 1;char ch = getchar();
    for(;!isdigit(ch);ch = getchar())if(ch == '-')f = -f;
    for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0';
    return x * f;
}
int n,m;
double grid[15][15];
namespace p30
{
    int fa[15];inline int find(int x){return x == fa[x] ? x : fa[x] = find(fa[x]);}
    struct EDG{int u,v;double p;}es[100];
    int _main()
    {
        for(int i=1;i<=m;i++)
        {
            int a = read(),b = read();double p;
            scanf("%lf",&p);
            //grid[a][b] = grid[b][a] = (1.00 - p);
            es[i] = (EDG){a,b,1.00 - p};
        }
        int MAXSTATE = (1 << m) - 1;
        double zgl = 0.0;
        for(int S=0;S<=MAXSTATE;S++)
        {
            double curgl = 1.00;
            for(int i=1;i<=n;i++)fa[i] = i;
            for(int i=1;i<=m;i++)
            {
                if(S & (1 << (i - 1)))
                {
                    int fu = find(es[i].u),fv = find(es[i].v);
                    curgl *= es[i].p;if(fu == fv)continue;
                    fa[fu] = fv;
                }
                else curgl *= (1.00 - es[i].p);
            }
            int flg = 1;
            for(int i=1;i<=n;i++)if(find(i) != find(1))flg = 0;
            if(flg)zgl += curgl;
        }printf("%.3f\n",zgl);
        return 0;
    }
}
namespace prng
{
    int _main()
    {
        srand(time(0));
        int RNG = rand() % 1000 + 1;
        double rng = RNG / 1000.0;
        printf("%.3f\n",rng);
        return 0;
    }
}
namespace p100
{
    const int maxn = 382880;
    int fac[15],u[150],v[150],state[maxn][12],t[12];
    double w[150],f[maxn],g[maxn];
    inline int calc()
    {
        int res = 0;
        for(int i=1;i<=n;i++) res += (t[i] - 1) * (fac[n] / fac[i]);
        return res + 1;
    }
    int _main()
    {
        for(int i=1;i<=m;i++)
        {
            u[i] = read(),v[i] = read();
            scanf("%lf",&w[i]);
        }fac[0] = 1;
        for(int i=1;i<=10;i++)fac[i] = fac[i - 1] * i;
        for(int i=1;i<=n;i++)state[1][i] = 1;
        for(int i=2;i<=fac[n];i++)
        {
            int x = n;
            while(1)
            {
                state[i][x] = state[i - 1][x] + 1;
                if(state[i][x] > x)state[i][x] = 1,state[i][x - 1]++;
                else break;
                x--;
            }
            for(int j=1;j<x;j++)state[i][j] = state[i - 1][j];
        }
        f[fac[n]] = 1.0;
        for(int i=1;i<=m;i++)
        {
            for(int j=1;j<=fac[n];j++)
            {
                if(f[j])
                {
                    for(int kk=1;kk<=n;kk++)t[kk] = state[j][kk];
                    for(int kk=1;kk<=n;kk++)
                        if(state[j][v[i]] == t[kk] || t[kk] == state[j][u[i]]) t[kk] = min(state[j][u[i]],state[j][v[i]]);
                    g[calc()] += f[j] * (1 - w[i]);
                    g[j] += f[j] * w[i];
                }
            }
            for(int j=1;j<=fac[n];j++)f[j] = g[j],g[j] = 0;
        }
        printf("%.3f",f[1]);
    }
}
int main()
{
    //freopen("10.in","r",stdin);
    //freopen("10.out","w",stdout);
    n = read(),m = read();
    //if(n <= 8 || m <= 23)p30::_main();
    //else prng::_main();
    p100::_main();
}
View Code

 

posted @ 2018-11-06 19:23  探险家Mr.H  阅读(285)  评论(0编辑  收藏  举报