【UOJ388】配对树(dsu on tree+线段树)

传送门

题意:
给出一颗含有\(n\)个结点的无根树,之后给出一个长度为\(m\)的序列,每个元素在\([1,n]\)之间。
现在序列中每个长度为偶数的区间的完成时间定义为树上最小配对方法中每对匹配点间距离的总和。
现在要求所有长度为偶数的区间的完成时间的和。

思路:

  • 首先不妨将这颗树转化为有根树,最终不会影响答案。
  • 注意到性质:偶数个点的两两匹配方式是唯一的,都是最深的两个点相互匹配,这样才能保证没有重复计算的边。
  • 在子树内部直接计算不好算,要考虑很多东西(一开始就想偏了QAQ)。因为匹配方式唯一,所以子树中若全部匹配完成,那么至多只会剩下一个点。显然,若剩下一个点,此时子树及其父亲这条边肯定会算上。
  • 我们将问题转化为单独考虑一条边的贡献,一条边的贡献次数即为序列元素在该子树中出现奇数次的偶数区间个数。
  • 考虑暴力计算:对于当前的子树,暴力给序列上打上标记,然后做一个前缀和,若存在\(i,j,i<j\),满足\(i\equiv j(mod\ 2)\)\(s_j\equiv s_i+1(mod\ 2)\)即可。
  • 如果用线段树维护,那么只需要维护奇数位置、偶数位置上面前缀模\(2\)意义下\(1\)的个数即可快速求得当前子树内部的答案。
  • 时间复杂度\(O(n^2logn)\)
  • 因为这涉及到子树问题,且不带修改,所以可以直接施展\(dsu\ on\ tree\),最终复杂度为\(O(nlog^2n)\)

感觉这种将子树内部的问题转化为子树+一条边的问题似乎是一种套路?(上次atcoder有个题也是这样)。
以后这种计算贡献的题可以尝试转换一下思路,单独考虑每个元素的贡献。
这个题还可以直接线段树合并来写,复杂度是\(O(nlogn)\),不过(不会写)懒得写了。。
代码如下:

/*
 * Author:  heyuhhh
 * Created Time:  2019/11/16 8:54:36
 */
#include <bits/stdc++.h>
#define MP make_pair
#define fi first
#define se second
#define sz(x) (int)(x).size()
#define all(x) (x).begin(), (x).end()
#define INF 0x3f3f3f3f
//#define Local
#ifdef Local
  #define dbg(args...) do { cout << #args << " -> "; err(args); } while (0)
  void err() { std::cout << '\n'; }
  template<typename T, typename...Args>
  void err(T a, Args...args) { std::cout << a << ' '; err(args...); }
#else
  #define dbg(...)
#endif
void pt() {std::cout << '\n'; }
template<typename T, typename...Args>
void pt(T a, Args...args) {std::cout << a << ' '; pt(args...); }
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
//head
const int N = 1e5 + 5, MOD = 998244353;

template <class T>
inline void read(T& x) {
    static char c;
    x = 0;
    bool sign = 0;
    while (!isdigit(c = getchar()))
        if (c == '-')
            sign = 1;
    for (; isdigit(c); x = x * 10 + c - '0', c = getchar())
        ;
    if (sign)
        x = -x;
}

int n, m;
int a[N];
struct Edge{
    int v, w, next;   
}e[N << 1];
int head[N], tot;
void adde(int u, int v, int w) {
    e[tot].v = v; e[tot].w = w; e[tot].next = head[u]; head[u] = tot++;
}

int pre[N], last[N];
int sz[N], bson[N], son;

void dfs(int u, int fa) {
    int mx = 0; sz[u] = 1;
    for(int i = head[u]; i != -1; i = e[i].next) {
        int v = e[i].v;
        if(v != fa) {
            a[v] = e[i].w;
            dfs(v, u);
            sz[u] += sz[v];
            if(sz[v] > mx) mx = sz[v], bson[u] = v;
        }
    }  
}

int odd[N << 2], even[N << 2], rev[N << 2];
int ans;

void Reverse(int o, int l, int r) {
    if(l == r) {
        if(l & 1) odd[o] ^= 1;
        else even[o] ^= 1;
    } else {
        odd[o] = (r + 1) / 2 - l / 2 - odd[o];
        even[o] = r / 2 - (l - 1) / 2 - even[o];
    }
    rev[o] ^= 1;
}

void push_up(int o) {
    odd[o] = odd[o << 1] + odd[o << 1|1];   
    even[o] = even[o << 1] + even[o << 1|1];
}

void push_down(int o, int l, int r) {
    if(rev[o]) {
        int mid = (l + r) >> 1;
        Reverse(o << 1, l, mid);
        Reverse(o << 1|1, mid + 1, r);
        rev[o] = 0;
    }   
}

void upd(int o, int l, int r, int L, int R) {
    if(L <= l && r <= R) {
        Reverse(o, l, r);
        return;
    }
    push_down(o, l, r);
    int mid = (l + r) >> 1;
    if(L <= mid) upd(o << 1, l, mid, L, R);
    if(R > mid) upd(o << 1|1, mid + 1, r, L, R);
    push_up(o);
}

int Get() {
    int res = (1ll * (m / 2 + 1 - even[1]) * even[1] % MOD + 1ll * ((m + 1) / 2 - odd[1]) * odd[1] % MOD) % MOD;
    return res;
}

void calc(int u, int fa) {
    for(int i = last[u]; i; i = pre[i]) upd(1, 1, m, i, m);
    for(int i = head[u]; i != -1; i = e[i].next) {
        int v = e[i].v;
        if(v == fa || v == son) continue;
        calc(v, u);
    }   
}

void dfs2(int u, int fa, int op) {
    for(int i = head[u]; i != -1; i = e[i].next) {
        int v = e[i].v;
        if(v != fa && v != bson[u]) dfs2(v, u, 0);    
    }
    if(bson[u]) dfs2(bson[u], u, 1);
    son = bson[u];
    calc(u, fa);
    son = 0;
    dbg(u, get());
    ans = (ans + 1ll * a[u] * Get() % MOD) % MOD;
    if(!op) calc(u, fa);
}

void run(){
    memset(head, -1, sizeof(head));
    read(n), read(m);
    for(int i = 1; i < n; i++) {
        int u, v, w;
        read(u), read(v), read(w);
        adde(u, v, w); adde(v, u, w);
    }
    for(int i = 1; i <= m; i++) {
        int x; read(x);
        pre[i] = last[x];
        last[x] = i;
    }
    dfs(1, 0);
    dfs2(1, 0, 1);
    cout << ans << '\n';
}

int main() {
    run();
	return 0;
}
posted @ 2019-11-16 19:46  heyuhhh  阅读(283)  评论(0编辑  收藏  举报