听风

题面:不存在的

我们先dfs一次,只保留每个分叉最上面的颜色,并求到根的前缀和,这就是每个点到根的和,记为s1,然后我们求子树里的,这里只保留了最上面的颜色,那么我们求树链的并,做树上前缀和,这就是子树颜色的和,记为s2,s1+s2就是子树和到根颜色的并,因为只保留了最上层的颜色,也就是说对于一个点,某种颜色只会存在在子树中或到根的路径上,不会重复统计,然后是风,我们先dfs一遍计算出访问序列,然后树上差分,如果这两个点是a和b,且a的访问时间早于b,那么我们让a=fa[a][0],因为size[u]-1,所以自然是第一个访问的不算,然后就是按访问顺序加入每个点的贡献,计算答案就行了。

这道题集合了很多树上方法,值得一做

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 5;
struct edge {
    int nxt, to;
} e[N << 1];
int n, m, top, dfs_clock, cnt = 1, C;
ll ans, sum;
int mir[N], head[N], vis[N], in[N], out[N], a[N], fa[N][21], tim[N], val[N], st[N], dep[N];
vector<int> G[N], c[N];
ll s1[N], s2[N], s3[N];
bool cmp1(const int &i, const int &j) {
    return val[i] < val[j];
}
bool cmp2(const int &i, const int &j) {
    return in[i] < in[j];
}
void link(int u, int v)
{
    e[++cnt].nxt = head[u];
    head[u] = cnt;
    e[cnt].to = v;
}
void dfs(int u, int last)
{
    in[u] = ++dfs_clock;
    mir[in[u]] = u;
    val[u] = u;
    if(!vis[a[u]]) vis[a[u]] = 1, ++s1[u];
    else a[u] = 0;
    for(int i = head[u]; i; i = e[i].nxt) if(e[i].to != last) 
    {
        G[u].push_back(e[i].to);
        fa[e[i].to][0] = u;
        dep[e[i].to] = dep[u] + 1;
        s1[e[i].to] += s1[u];
        dfs(e[i].to, u);
        val[u] = min(val[u], val[e[i].to]);
    }
    sort(G[u].begin(), G[u].end(), cmp1);
    vis[a[u]] = 0;
    out[u] = dfs_clock;
}
int lca(int u, int v)
{
    if(dep[u] < dep[v]) swap(u, v);
    int d = dep[u] - dep[v];
    for(int i = 20; i >= 0; --i) if(d & (1 << i)) u = fa[u][i];
    if(u == v) return u;
    for(int i = 20; i >= 0; --i) if(fa[u][i] != fa[v][i]) 
    {
        u = fa[u][i];
        v = fa[v][i];
    }
    return fa[u][0];
}
void dfs(int u)
{
    if(a[u]) c[a[u]].push_back(u);
    for(int i = 0; i < G[u].size(); ++i) 
    {
        int v = G[u][i];
        dfs(v);
    }
    st[++top] = u;
    tim[u] = top;
}
int main()
{
    scanf("%d%d%d", &n, &m, &C);
    for(int i = 1; i <= n; ++i) scanf("%d", &a[i]);
    for(int i = 1; i < n; ++i)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        link(u, v);
        link(v, u);
    }
    dfs(1, 0);
    for(int j = 1; j <= 20; ++j)
        for(int i = 1; i <= n; ++i)
            fa[i][j] = fa[fa[i][j - 1]][j - 1];
    dfs(1);
    for(int i = 1; i <= C; ++i) if(c[i].size()) 
    {
        sort(c[i].begin(), c[i].end(), cmp2);
        ++s2[in[c[i][0]]];
        for(int j = 1; j < c[i].size(); ++j) 
        {
            int x = lca(c[i][j - 1], c[i][j]);
            --s2[in[x]];
            ++s2[in[c[i][j]]];
        }
    }
    for(int i = 1; i <= m; ++i)
    {
        int x, y, z, t;
        scanf("%d%d%d", &x, &y, &z);
        t = lca(x, y);
        if(tim[x] > tim[y]) swap(x, y);
        x = fa[x][0];
        s3[in[x]] += z;
        s3[in[y]] += z;
        s3[in[t]] -= z;
        s3[in[fa[t][0]]] -= z; 
    }
    for(int i = 1; i <= n; ++i) s2[i] += s2[i - 1], s3[i] += s3[i - 1];
    for(int i = 1; i <= top; ++i) 
    {
//        printf("st[%d] = %d\n", i, st[i]);
        ll d1 = s2[out[st[i]]] - s2[in[st[i]] - 1], d2 = s3[out[st[i]]] - s3[in[st[i]] - 1];
        sum += s1[fa[st[i]][0]] + s2[out[st[i]]] - s2[in[st[i]] - 1] + s3[out[st[i]]] - s3[in[st[i]] - 1] - dep[st[i]];
        ans = max(ans, sum);        
    }
    printf("%lld\n", ans);
    return 0;
}
View Code

 

posted @ 2017-11-02 16:57  19992147  阅读(154)  评论(0编辑  收藏  举报