【bzoj3522/4543】[POI2014]Hotel加强版(长链剖分+dp)

传送门

神仙题。。简单版本很好做,做法也很多。
加强版\(n\leq 10^5\),显然之前的\(O(n^2)\)的做法时间、空间复杂度都不能承受。
考虑维护以深度有关的\(dp\):

  • \(f[i][j]\)表示以\(i\)为根节点的子树中,深度为\(j\)的点有多少个。

显然这个很好维护,转移\(\displaystyle f[i][j]=\sum_{k}f[k][j-1]\),我们可以用长链剖分加速。
因为我们要枚举\(3\)个点,现在还需要一个\(dp\)维护另外两个点的信息。

  • \(g[i][j]\)表示以\(i\)为根节点的子树中,点对\((x,y)\)的个数有多少个,点对要满足\(x,y,x\not ={y}\)\(lca\)的距离相等,并且从\(lca\)\(i\)这段距离为\(d-j\)。也就是说还需要一条长度为\(j\)的链进行匹配。

考虑如何转移:

  • 显然可以直接从儿子进行转移,即\(g[i][j]=g[k][j+1]\)
  • 从不同儿子子树中选取两个:\(g[i][j]=f[k_1][j-1]*f[k_2][j-1]\)。此时两个结点的\(lca\)一定为\(i\)

注意第一种转移跟深度有关系,但是和之前有点区别,此处我们还是可以通过长链剖分来进行优化;第二种转移可以直接进行枚举,这里枚举的深度会受到限制,总的枚举次数为\(O(长链长度)\)

之后考虑如何维护答案。
显然最终的答案有两种情况:

  • 中心点为某一个结点,此时答案为\(g[i][0]\)
  • 中心点不为某一个结点,此时答案为\(f[i][j]*g[k][j+1]+f[k][j-1]*g[i][j]\)。主要就考虑了\((2,1),(1,2)\)这两种情况,其实\((1,1,1)\)这种也考虑了的,但已经被包含入\((2,1)\)了。

以上过程我们在一边枚举轻儿子时一边进行转移&统计答案。
注意\(g[i][0]\)要先加上,否则可能会重复统计。
细节见代码:

/*
 * Author:  heyuhhh
 * Created Time:  2020/6/10 23:23:27
 */
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <vector>
#include <cmath>
#include <set>
#include <map>
#include <queue>
#include <iomanip>
#include <assert.h>
#include <functional>
#include <numeric>
#define MP make_pair
#define fi first
#define se second
#define pb push_back
#define sz(x) (int)(x).size()
#define all(x) (x).begin(), (x).end()
#define INF 0x3f3f3f3f
#define Local
#ifdef Local
  #define dbg(args...) do { cout << #args << " -> "; err(args); } while (0)
  void err() { std::cout << std::endl; }
  template<typename T, typename...Args>
  void err(T a, Args...args) { std::cout << a << ' '; err(args...); }
  template <template<typename...> class T, typename t, typename... A> 
  void err(const T <t> &arg, const A&... args) {
  for (auto &v : arg) std::cout << v << ' '; err(args...); }
#else
  #define dbg(...)
#endif
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
//head
const int N = 1e5 + 5;

int n;
vector <int> G[N];

ll *f[N], *g[N], ans;
ll tmp[N << 2], *id = tmp;

int len[N], bson[N];
void dfs(int u, int fa) {
    int Max = 0;
    for (auto v : G[u]) if (v != fa) {
        dfs(v, u);
        if (len[v] > Max) {
            Max = len[v];
            bson[u] = v;
        }
    }
    len[u] = len[bson[u]] + 1;
}
void dfs2(int u, int fa) {
    f[u][0] = 1;
    if (bson[u]) {
        //处理重链
        int v = bson[u];
        f[v] = f[u] + 1;
        g[v] = g[u] - 1;
        dfs2(v, u);
    }
    ans += g[u][0];
    for (auto v : G[u]) {
        if (v == fa || v == bson[u]) continue;
        //分配空间
        f[v] = id, id += (len[v] << 1);
        g[v] = id, id += (len[v] << 1);
        dfs2(v, u);
        //从轻链转移
        for (int i = 0; i < len[v]; i++) {
            ans += f[v][i] * g[u][i + 1];
            if (i) {
                ans += f[u][i - 1] * g[v][i];
            }
        }
        for (int i = 1; i <= len[v]; i++) {
            if (i < len[v]) {
                g[u][i - 1] += g[v][i];
            }
            g[u][i] += f[u][i] * f[v][i - 1];
            f[u][i] += f[v][i - 1];
        }
    }
}
void run() {
    cin >> n;
    for (int i = 1; i < n; i++) {
        int u, v; cin >> u >> v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs(1, 0);
    f[1] = id, id += (len[1] << 1);
    g[1] = id, id += (len[1] << 1);
    dfs2(1, 0);
    cout << ans << '\n';
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cout << fixed << setprecision(20);
    run();
    return 0;
}
posted @ 2020-06-13 11:57  heyuhhh  阅读(254)  评论(0编辑  收藏  举报