[CTSC2018]暴力写挂

题面

给出两棵带边权的树 \(A, B\)（点集均为 \(1..n\)，都以 \(1\) 为根），求 \(\max_{x,y}\big(dep_A(x) + dep_A(y) - dep_A(\mathrm{lca}_A(x,y)) - dep_B(\mathrm{lca}_B(x,y))\big)\)，其中 \(dep\) 为带权深度，\(x\) 可以等于 \(y\)。

解法

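考虑改写式子：由于 \(dep_A(x) + dep_A(y) - dep_A(\mathrm{lca}_A(x,y)) = \frac{1}{2}\big(dis_A(x,y) + dep_A(x) + dep_A(y)\big)\)，答案为 \(\frac{1}{2}\max_{x,y}\big(dis_A(x,y) + dep_A(x) + dep_A(y) - 2\,dep_B(\mathrm{lca}_B(x,y))\big)\)。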

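考虑点分治：把当前分治中心 \(mid\) 所管辖连通块内的点拉出来，在 \(B\) 上建出虚树。关键点 \(x\) 的点权赋值为 \(val(x) = dis_A(mid,x) + dep_A(x)\)，同一棵子树内的点染同一种颜色（\(mid\) 自己单独一种颜色）。对颜色不同的 \(x, y\) 有 \(dis_A(x,y) = dis_A(mid,x) + dis_A(mid,y)\)，因此只需在虚树上 \(dp\)，求颜色不同的点对的 \(\max\big(val(x) + val(y) - 2\,dep_B(\mathrm{lca}_B(x,y))\big)\) 即可。每个虚树节点只需维护最大值，以及与其颜色不同的最大值这两个候选。\(x = y\) 的情况单独取 \(dep_A(x) - dep_B(x)\)。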

点击查看代码
//晦暗的宇宙,我们找不到光,看不见尽头,但我们永远都不会被黑色打倒。——Quinn葵因
#include <bits/stdc++.h>
// 64-bit type used for every weighted depth / distance.
using ll = long long;
// Capacity of every per-vertex array (vertices are numbered from 1).
constexpr int N = 1000000;
using std::pair;
using std::vector;

// Number of vertices (both trees share the vertex set 1..n).
int n;

// Adjacency-list entry: (neighbour, edge weight).
using pil = pair<int, ll>;
#define mp std::make_pair

// A: first tree, B: second tree, both as weighted adjacency lists.
vector<pil> A[N];
vector<pil> B[N];

// depa[u] / depb[u]: weighted depth of u in A / B.
ll depa[N];
ll depb[N];
// dep[u]: unweighted depth of u in B (the DFS start vertex gets depth 1).
int dep[N];

// Fills depa[x] (weighted depth in tree A) for every vertex below u, given
// that depa[u] is already set; fa is u's parent (0 for the root).
// Uses an explicit stack instead of recursion: a recursive DFS is O(n) deep
// on a path-shaped tree and can overflow a default-sized call stack when n is
// large. Visiting order does not matter here, only parent-before-child.
inline void dfsa(int u, int fa) {
    std::vector<std::pair<int, int>> stk;  // (vertex, parent) still to expand
    stk.emplace_back(u, fa);

    while (!stk.empty()) {
        const int x = stk.back().first;
        const int p = stk.back().second;
        stk.pop_back();

        for (const auto &e : A[x]) {
            if (e.first == p)
                continue;

            depa[e.first] = depa[x] + e.second;
            stk.emplace_back(e.first, x);
        }
    }
}

// Binary-lifting table for tree B: F[u][i] is the 2^i-th ancestor of u
// (0 once the jump goes above the root).
int F[N][20];

// dfn[u]: preorder index of u in tree B; cnt hands out the next index.
int dfn[N], cnt = 0;

// Preprocesses tree B starting at u (parent fa, 0 for the root). For every
// vertex x reached it fills F[x][] (binary lifting), dep[x] (unweighted
// depth), dfn[x] (preorder index) and depb[x] (weighted depth).
// Iterative to avoid O(n)-deep recursion on path-shaped trees. A vertex is
// numbered when it is popped; its descendants are pushed above everything
// still waiting, so each subtree receives a contiguous dfn range. This is
// still a valid DFS preorder; only the order of siblings differs.
inline void dfsb(int u, int fa) {
    std::vector<std::pair<int, int>> stk;  // (vertex, parent) still to visit
    stk.emplace_back(u, fa);

    while (!stk.empty()) {
        const int x = stk.back().first;
        const int p = stk.back().second;
        stk.pop_back();

        // p was handled before x was pushed, so its depth and ancestor
        // table are already final (p == 0 relies on the zeroed globals).
        F[x][0] = p;
        dep[x] = dep[p] + 1;

        for (int i = 1; i <= 19; ++i)
            F[x][i] = F[F[x][i - 1]][i - 1];

        dfn[x] = ++cnt;

        for (const auto &e : B[x]) {
            if (e.first == p)
                continue;

            depb[e.first] = depb[x] + e.second;
            stk.emplace_back(e.first, x);
        }
    }
}

// Centroid-search state: sum is the size of the component being split,
// root is the best centroid candidate so far (0 means none yet).
int sum;
int root;
// siz[u]: subtree size of u in the latest find() traversal.
// maxn[u]: largest remaining piece if u were removed (at least 1).
int siz[N];
int maxn[N];

// val[u]: weight fed to the virtual-tree DP, i.e. the A-distance from the
// current centroid plus depa[u].
ll val[N];
// c[u]: colour of u, named after the centroid child whose subtree holds u
// (the centroid itself uses its own id).
int c[N];
// vis[u]: nonzero once u has been used as a centroid.
int vis[N];

// Centroid search over the component containing u. Vertices with vis set
// act as walls, and fa is not entered. It recomputes siz[] and maxn[] for
// every vertex of the component and moves `root` to any vertex whose maxn is
// smaller than maxn[root]; `sum` must hold the component size beforehand.
// Iterative (BFS order, then a reverse pass so children are finished before
// their parents) to avoid O(n)-deep recursion on path-shaped trees.
inline void find(int u, int fa) {
    std::vector<std::pair<int, int>> order;  // (vertex, parent) in BFS order
    order.emplace_back(u, fa);

    for (std::size_t k = 0; k < order.size(); ++k) {
        // Copy out before emplace_back may reallocate `order`.
        const int x = order[k].first;
        const int p = order[k].second;
        siz[x] = 1;
        maxn[x] = 1;

        for (const auto &e : A[x]) {
            const int v = e.first;

            if (v == p || vis[v])
                continue;

            order.emplace_back(v, x);
        }
    }

    for (std::size_t k = order.size(); k-- > 0;) {
        const int x = order[k].first;
        const int p = order[k].second;

        // All children of x sit later in BFS order, so siz[x] is final here.
        maxn[x] = std::max(maxn[x], sum - siz[x]);

        if (maxn[x] < maxn[root])
            root = x;

        // The start vertex's parent lies outside the component: leave it alone.
        if (k > 0) {
            siz[p] += siz[x];
            maxn[p] = std::max(maxn[p], siz[x]);
        }
    }
}

// Vertices collected for the current centroid's virtual tree.
vector<int> P;

// Walks the part of the current component reachable from u without entering
// fa or any vis vertex. On entry val[u] must hold u's A-distance to the
// current centroid. For every reached vertex x, c[x] is copied from its
// parent, so the whole walk takes colour c[fa]. When fa == u, as the caller
// passes for a direct child of the centroid, that is u's own colour.
// val[x] becomes (A-distance to the centroid) + depa[x], and x is appended to P.
// Iterative to avoid O(n)-deep recursion on path-shaped trees. P now receives
// vertices in preorder rather than postorder; build() sorts P by dfn before use.
inline void dis(int u, int fa) {
    std::vector<std::pair<int, int>> stk;  // (vertex, parent) still to visit
    stk.emplace_back(u, fa);

    while (!stk.empty()) {
        const int x = stk.back().first;
        const int p = stk.back().second;
        stk.pop_back();

        c[x] = c[p];

        // val[x] is still the pure distance at this point, so the children
        // inherit it before depa[x] is folded in below.
        for (const auto &e : A[x]) {
            const int v = e.first;

            if (vis[v] || v == p)
                continue;

            val[v] = val[x] + e.second;
            stk.emplace_back(v, x);
        }

        val[x] += depa[x];
        P.push_back(x);
    }
}

inline int lca(int x, int y) {
    //  std::cout<<"LCA "<<x<<" "<<y<<"\n";
    if (dep[x] < dep[y])
        std::swap(x, y);

    for (int i = 19; i >= 0; --i) {
        if (dep[F[x][i]] >= dep[y])
            x = F[x][i];

        //      std::cout<<x<<"\n";
    }

    if (x == y)
        return x;

    for (int i = 19; i >= 0; --i) {
        if (F[x][i] != F[y][i])
            x = F[x][i], y = F[y][i];

        //      std::cout<<"UP "<<x<<" "<<y<<"\n";
    }

    return F[x][0];
}

// Strict weak order on vertices by their preorder index in tree B.
inline bool cmp(int x, int y) {
    const int lhs = dfn[x];
    const int rhs = dfn[y];
    return lhs < rhs;
}

ll f[N][2];

bool key[N];

int Fi[N];

ll ans = -1e18;

#define pii pair<ll,int>

vector<pii>M;

inline void merge(int x, int y) { //y -> x
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
            if (c[f[x][i]] != c[f[y][j]])
                ans = std::max(ans, val[f[x][i]] + val[f[y][j]] - 2 * depb[x]);

    //    bool tag = 0;
    //
    //    if (val[f[y][0]] < val[f[y][1]])
    //        std::swap(f[y][0], f[y][1]);
    //
    //    if (val[f[x][0]] < val[f[x][1]])
    //        std::swap(f[x][0], f[x][1]);
    //
    //    if (val[f[y][0]] > val[f[x][0]]) {
    //        if (c[f[y][0]] != c[f[x][0]])
    //            f[x][1] = f[x][0];
    //
    //        f[x][0] = f[y][0];
    //        tag = 1;
    //    } else {
    //        if (val[f[y][0]] > val[f[x][1]] && c[f[y][0]] != c[f[x][0]]) {
    //            f[x][1] = f[y][0];
    //            tag = 1;
    //        }
    //    }
    //
    //    if (!tag) {
    //        if (val[f[y][1]] > val[f[x][0]]) {
    //            if (c[f[y][1]] != c[f[x][0]])
    //                f[x][1] = f[x][0];
    //
    //            f[x][0] = f[y][1];
    //        } else {
    //            if (val[f[y][1]] > val[f[x][1]] && c[f[y][1]] != c[f[x][0]]) {
    //                f[x][1] = f[y][1];
    //            }
    //        }
    //    }

    //  std::cout<<"CAO NI MA DE MERGE"<<"\n";
    //  std::cout<<x<<" "<<y<<"\n";
    //  std::cout<<f[x][0]<<" "<<c[f[x][0]]<<" "<<val[f[x][0]]<<"\n";
    //  std::cout<<f[x][1]<<" "<<c[f[x][1]]<<" "<<val[f[x][1]]<<"\n";
    //  std::cout<<f[y][0]<<" "<<c[f[y][0]]<<" "<<val[f[y][0]]<<"\n";
    //  std::cout<<f[y][1]<<" "<<c[f[y][1]]<<" "<<val[f[y][1]]<<"\n";
    M.clear();
    M.push_back(mp(-val[f[x][0]], f[x][0]));
    M.push_back(mp(-val[f[x][1]], f[x][1]));
    M.push_back(mp(-val[f[y][0]], f[y][0]));
    M.push_back(mp(-val[f[y][1]], f[y][1]));
    std::sort(M.begin(), M.end());
    //  puts("CAO NI MA DE WO sort GET");
    //  for(int i = 0;i <= 3;++i)
    //  std::cout<<-M[i].first<<" "<<M[i].second<<"\n";
    f[x][0] = M[0].second;

    for (int i = 1; i <= 3; ++i) {
        if (c[M[i].second] != c[f[x][0]]) {
            f[x][1] = M[i].second;
            //  std::cout<<f[x][0]<<" "<<c[f[x][0]]<<" "<<val[f[x][0]]<<"\n";
            //  std::cout<<f[x][1]<<" "<<c[f[x][1]]<<" "<<val[f[x][1]]<<"\n";
            return ;
        }
    }
}

// Reads the next integer from stdin, skipping any non-digit characters
// before it. A '-' among the skipped characters makes the result negative
// (same convention as lread()). If EOF is reached before any digit it
// returns 0 instead of looping forever.
inline int iread(){
    int s = 0, w = 1;
    int ch = getchar();  // int, not char, so EOF stays distinguishable

    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-')
            w = -1;

        ch = getchar();
    }

    while (ch >= '0' && ch <= '9') {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }

    return s * w;
}

// Reads the next (possibly negative) integer from stdin as ll, skipping any
// non-digit characters before it. A '-' among the skipped characters makes
// the result negative. If EOF is reached before any digit it returns 0
// instead of looping forever.
inline ll lread(){
    ll s = 0, w = 1;
    int ch = getchar();  // int, not char, so EOF stays distinguishable

    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-')
            w = -1;

        ch = getchar();
    }

    while (ch >= '0' && ch <= '9') {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }

    return s * w;
}

// Builds the virtual tree (in tree B) over the vertices currently in P and
// runs the two-candidate DP on it. Pair scores reach `ans` through merge().
// On return, P holds the virtual-tree vertices sorted by dfn and key[] is
// cleared again.
inline void build() {
    // Vertex 0 is the "empty slot" used in f[][]: its value is far below any
    // real one and its colour (-1) differs from every real colour.
    val[0] = -1e18;
    c[0] = -1;

    std::sort(P.begin(), P.end(), cmp);

    for (int v : P)
        key[v] = 1;

    // Adding the LCA of every dfn-adjacent pair closes the set under LCA.
    const int keyCount = P.size();
    for (int i = 1; i < keyCount; ++i)
        P.push_back(lca(P[i], P[i - 1]));

    std::sort(P.begin(), P.end(), cmp);
    P.erase(std::unique(P.begin(), P.end()), P.end());

    // With P sorted by dfn, the virtual-tree parent of P[i] is lca(P[i], P[i-1]).
    for (std::size_t i = 1; i < P.size(); ++i)
        Fi[P[i]] = lca(P[i], P[i - 1]);

    // Collected vertices start as their own candidate; LCA-only vertices start empty.
    for (int v : P) {
        f[v][0] = key[v] ? v : 0;
        f[v][1] = 0;
    }

    // Reverse dfn order finishes every subtree before folding it into its parent.
    for (std::size_t i = P.size(); i-- > 1;)
        merge(Fi[P[i]], P[i]);

    for (int v : P)
        key[v] = 0;
}

// Processes centroid u. It marks u as used and collects every vertex of u's
// component into P. Each vertex gets val = A-distance to u plus depa and a
// colour per child subtree (u keeps its own colour). build() then scores all
// cross-colour pairs, and the function recurses into the centroid of each
// remaining sub-component.
inline void solve(int u) {
    vis[u] = 1;
    // Refresh siz[] around u; the second loop reads siz[v] as the size of the
    // sub-component behind neighbour v.
    find(u, 0);

    P.clear();
    val[u] = depa[u];
    c[u] = u;
    P.push_back(u);

    for (const auto &edge : A[u]) {
        const int v = edge.first;

        if (!vis[v]) {
            // Each unvisited neighbour opens a new colour class named after itself.
            val[v] = edge.second;
            c[v] = v;
            dis(v, v);
        }
    }

    build();

    for (const auto &edge : A[u]) {
        const int v = edge.first;
        sum = siz[v];
        root = 0;

        if (vis[v])
            continue;

        find(v, 0);
        solve(root);
    }
}

signed main() {
    //  freopen("q.in","r",stdin);
    //  freopen("q.out","w",stdout);
//    scanf("%d", &n);
	n = iread();
    for (int i = 1; i < n; ++i) {
        int x = iread(), y = iread();
        ll v = lread();
        A[x].push_back(mp(y, v));
        A[y].push_back(mp(x, v));
    }

    for (int i = 1; i < n; ++i) {
        int x = iread(), y = iread();
        ll v = lread();
        B[x].push_back(mp(y, v));
        B[y].push_back(mp(x, v));
    }

    dfsa(1, 0);
    dfsb(1, 0);
    maxn[0] = n * 2;
    root = 0;
    sum = n;
    find(1, 0);
    solve(root);
    ans = ans / 2;

    for (int i = 1; i <= n; ++i)
        ans = std::max(ans, 2ll * depa[i] - depa[i] - depb[i]);

    std::cout << ans << "\n";
}

/*
6
1 2 2
1 3 0
2 4 1
2 5 -7
3 6 0
1 2 -1
2 3 -1
2 5 3
2 6 -2
3 4 8
*/
posted @ 2022-03-09 19:13  fhq_treap  阅读(49)  评论(2编辑  收藏  举报