[bzoj4543][POI2014]Hotel加强版——长链剖分

题目大意:

传送门

思路:

每一个三元组必定是一个三叉。
考虑在三个点的lca处计算贡献。
考虑记\(f_{u,j}\)表示距离u深度为j的点一共有多少个,\(g_{u,j}\)表示在u的子树中,点对a,b距离lca 的距离为d,lca距离u的距离为d-j,也就是这两个点对还差一段长度为j的路径才成凑成一个合法的三元组。
于是每一个点都可以从子树转移过来,考虑记录前缀,然后在新子树添加进来的时候计算贡献。
于是有以下的DP方程:

\[\begin{aligned} ans&+=g_{u,j}\times f_{son,j-1}+f_{u,j}\times g_{son,j+1}\\ g_{u,j}&+=g_{son,j+1}+f_{u,j}\times f_{son,j-1}\\ f_{u,j}&+=f_{son,j-1} \end{aligned} \]

发现f,gd都之和深度有关,并且在第一颗子树的转移都只涉及位移,于是直接长链剖分+动态数组维护即可。

#include<bits/stdc++.h>
 
#define REP(i,a,b) for(int i=a,i##_end_=b;i<=i##_end_;++i)
#define DREP(i,a,b) for(int i=a,i##_end_=b;i>=i##_end_;--i)
#define debug(x) cout<<#x<<"="<<x<<" "
#define fi first
#define se second
#define mk make_pair
#define pb push_back
typedef long long ll;
 
using namespace std;
 
void File(){
    freopen("bzoj4543.in","r",stdin);
    freopen("bzoj4543.out","w",stdout);
}
 
template<typename T>void read(T &_){
    _=0; T f=1; char c=getchar();
    for(;!isdigit(c);c=getchar())if(c=='-')f=-1;
    for(;isdigit(c);c=getchar())_=(_<<1)+(_<<3)+(c^'0');
    _*=f;
}
 
const int maxn=1e5+10;
int n;
int beg[maxn],to[maxn<<1],las[maxn<<1],cnte=1;
int len[maxn],fa[maxn],son[maxn];
ll ans,*f[maxn],*g[maxn],ft[maxn],gt[maxn<<1],*fp=ft,*gp=gt;
 
void add(int u,int v){
    las[++cnte]=beg[u]; beg[u]=cnte; to[cnte]=v;
    las[++cnte]=beg[v]; beg[v]=cnte; to[cnte]=u;
}
 
void dfs(int u,int fh){
    fa[u]=fh;
    for(int i=beg[u];i;i=las[i]){
        int v=to[i];
        if(v==fh)continue;
        dfs(v,u);
        if(len[v]>len[u]){
            len[u]=len[v];
            son[u]=v;
        }
    }
    ++len[u];
}
 
void solve(int u){
    f[u][0]=1;
    if(son[u]){
        int v=son[u];
        g[v]=g[u]-1;
        f[v]=f[u]+1;
        solve(v);
        ans+=g[u][0];
    }
    for(int i=beg[u];i;i=las[i]){
        int v=to[i];
        if(v==fa[u] || v==son[u])continue;
        g[v]=gp+len[v]+1,gp+=len[v]<<1;
        f[v]=fp+1,fp+=len[v];
        solve(v);
        REP(j,0,len[v]){
            if(j)ans+=g[u][j]*f[v][j-1];
            ans+=f[u][j]*g[v][j+1];
            if(j)g[u][j]+=1ll*f[u][j]*f[v][j-1];
            g[u][j]+=g[v][j+1];
            if(j)f[u][j]+=f[v][j-1];
        }
    }
}
 
int main(){
    //File();
    read(n);
    int u,v;
    REP(i,1,n-1)read(u),read(v),add(u,v);
 
    dfs(1,0);
 
    g[1]=gp+len[1]+1,gp+=len[1]<<1;
    f[1]=fp+1,fp+=len[1];
 
    solve(1);
 
    printf("%lld\n",ans);
 
    return 0;
}

posted @ 2019-01-07 20:49  ylsoi  阅读(125)  评论(0编辑  收藏  举报