BZOJ 5287: [Hnoi2018]毒瘤 动态dp(LCT+矩阵乘法)

自己 yy 了一个动态 dp 做法,应该是全网唯一用 LCT 写的.    

code: 

#include <bits/stdc++.h>            
#define ll long long
#define lson tr[x].ch[0] 
#define rson tr[x].ch[1]  
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std;      
const int N=200005;   
const ll mod=998244353;          
vector<int>EDGE;  
int edges=1,n,m,F[N][2];       
int hd[N],to[N<<1],nex[N<<1],from[N<<1],mark[N<<1],vis[N],sta[N];  
int qpow(int x,int y) 
{
    int tmp=1; 
    while(y) 
    {
        if(y&1)  tmp=1ll*tmp*x%mod;   
        x=1ll*x*x%mod, y>>=1; 
    }
    return tmp;  
}
int INV(int x) { return qpow(x,mod-2); }   
void add(int u,int v) 
{
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v,from[edges]=u; 
}
void dfs(int u,int ff) 
{  
    vis[u]=1;    
    for(int i=hd[u];i;i=nex[i])  
    {
        int v=to[i];   
        if(v==ff||mark[i])   continue;    
        if(vis[v]) 
        {
            mark[i]=mark[i^1]=1, EDGE.push_back(i);                     
            continue; 
        } 
        dfs(v,u);     
    }
}    
struct Num
{
    int val,cnt;     
    Num() { val=cnt=0; }  
    void ins(int a)          
    {   
        val=a,cnt=0; 
        if(!val)    val=1,cnt=1;    
    } 
    Num operator*(Num b) const 
    {
        Num re;   
        re.val=1ll*val*b.val%mod;  
        re.cnt=cnt+b.cnt;      
        return re;   
    }
    Num operator/(Num b) const 
    {       
        Num re;               
        re.cnt=cnt-b.cnt;           
        re.val=1ll*val*INV(b.val)%mod;                 
        return re;  
    }   
    int get() { return cnt?0:val; }           
}tmp[N][2][2];          
struct matrix 
{
    int a[2][2];      
    matrix() { memset(a,0,sizeof(a)); }       
    void I() 
    {
        a[0][0]=a[1][1]=1; 
        a[0][1]=a[1][0]=0; 
    }    
    int *operator[](int x) { return a[x]; }       
}t[N],po[N];               
matrix operator*(matrix a,matrix b) 
{
    matrix c;     
    for(int i=0;i<2;++i)
    {
        for(int j=0;j<2;++j)
            for(int k=0;k<2;++k)               
                c[i][j]=(c[i][j]+1ll*a[i][k]*b[k][j]%mod)%mod; 
    }
    return c; 
}            
void cop(int x) 
{
    for(int i=0;i<2;++i)  
    { 
        for(int j=0;j<2;++j) 
            t[x][i][j]=tmp[x][i][j].get();  
    }
} 
struct node 
{  
    int ch[2],f,rev;        
}tr[N];  
int get(int x) 
{ 
    return tr[tr[x].f].ch[1]==x; 
}    
int isrt(int x) 
{
    return !(tr[tr[x].f].ch[0]==x||tr[tr[x].f].ch[1]==x);      
}   
void pushup(int x) 
{       
    cop(x);                     
    t[x]=po[x]*t[x];   
    if(lson)   t[x]=t[lson]*t[x];   
    if(rson)   t[x]=t[x]*t[rson];                                          
}
void rotate(int x) 
{    
    int old=tr[x].f,fold=tr[old].f,which=get(x);             
    if(!isrt(old))     tr[fold].ch[tr[fold].ch[1]==old]=x;   
    tr[old].ch[which]=tr[x].ch[which^1],tr[tr[old].ch[which]].f=old;   
    tr[x].ch[which^1]=old,tr[old].f=x,tr[x].f=fold; 
    pushup(old),pushup(x); 
}  
void splay(int x) 
{
    int u=x,v=0,fa;   
    for(sta[++v]=u;!isrt(u);u=tr[u].f)   sta[++v]=tr[u].f;    
    for(u=tr[u].f;(fa=tr[x].f)!=u;rotate(x))                             
    {         
        if(tr[fa].f!=u) 
        {           
            rotate(get(fa)==get(x)?fa:x);  
        }
    }
}
void Access(int x) 
{
    for(int y=0;x;y=x,x=tr[x].f) 
    {   
        splay(x);    
        if(rson) 
        {
            Num a,b;   
            a.ins(t[rson][0][0]);                 
            b.ins(t[rson][0][0]+t[rson][1][0]);           
            tmp[x][1][0]=tmp[x][1][0]*a;           
            tmp[x][0][0]=tmp[x][0][0]*b;   
            tmp[x][0][1]=tmp[x][0][1]*b;             
        }
        if(y) 
        {
            Num a,b;  
            a.ins(t[y][0][0]); 
            b.ins(t[y][0][0]+t[y][1][0]);     
            tmp[x][1][0]=tmp[x][1][0]/a;   
            tmp[x][0][0]=tmp[x][0][0]/b;    
            tmp[x][0][1]=tmp[x][0][1]/b;   
        }    
        rson=y;   
        pushup(x);   
    }   
}       
void prepare(int u,int ff) 
{ 
    po[u].I();    
    tr[u].f=ff;            
    tmp[u][1][1].ins(0);   
    tmp[u][0][0].ins(1); 
    tmp[u][0][1].ins(1); 
    tmp[u][1][0].ins(1);    
    for(int i=hd[u];i;i=nex[i]) 
    {
        int v=to[i]; 
        if(v==ff||mark[i])     continue;       
        prepare(v,u);                 
        Num a,b;   
        a.ins(tmp[v][0][0].get());                
        b.ins(tmp[v][0][0].get()+tmp[v][1][0].get());                              
        tmp[u][0][0]=tmp[u][0][0]*b;   
        tmp[u][0][1]=tmp[u][0][1]*b;            
        tmp[u][1][0]=tmp[u][1][0]*a;         
    }       
    pushup(u);   
}
int main() 
{    
    int i,j;   
    scanf("%d%d",&n,&m);                  
    for(i=1;i<=m;++i) 
    {
        int u,v; 
        scanf("%d%d",&u,&v);         
        add(u,v),add(v,u);   
    }   
    dfs(1,0); 
    prepare(1,0);                
    ll ans=0ll;   
    int sta=EDGE.size();        
    if(sta==0) 
    {
        Access(3),splay(3);     
        ans=(t[3][0][0]+t[3][1][0])%mod;    
    }
    else
    {    
        for(i=0;i<(1<<sta);++i) 
        {
            for(j=0;j<sta;++j) 
            {
                if(i&(1<<j)) 
                {    
                    int u=from[EDGE[j]];   
                    int v=to[EDGE[j]];    
                    Access(u),splay(u);       
                    po[u][0][0]=0;   
                    pushup(u); 

                    Access(v),splay(v);   
                    po[v][1][1]=0;   
                    pushup(v);     
                } 
                else 
                {   
                    int u=from[EDGE[j]];    
                    Access(u),splay(u);     
                    po[u][1][1]=0;   
                    pushup(u);   
                }
            } 
            Access(1),splay(1);    
            (ans+=(t[1][0][0]+t[1][1][0])%mod)%=mod;                          
            for(j=0;j<sta;++j) 
            {
                if(i&(1<<j)) 
                {    
                    int u=from[EDGE[j]];   
                    int v=to[EDGE[j]];    
                    Access(u),splay(u);       
                    po[u][0][0]=1;            
                    pushup(u);      
                    Access(v),splay(v); 
                    po[v][1][1]=1;       
                    pushup(v);    
                } 
                else 
                {   
                    int u=from[EDGE[j]];    
                    Access(u),splay(u);     
                    po[u][1][1]=1;   
                    pushup(u);   
                }
            }     
        }        
    }
    printf("%lld\n",ans);   
    return 0; 
}

  

posted @ 2019-12-05 20:25  EM-LGH  阅读(170)  评论(0编辑  收藏  举报