NowCode-Gene Tree/牛客练习赛81D - 小 Q 与树 (树上点分治)

 

又是向杨大佬学习的一天

Gene Tree题目大意:

求树上所有叶子节点距离的平方和

 

 就是求上面这个式子

小 Q 与树 题目大意:

 

 求上面这个式子

----------------------------------------------------------------------------------------------------------

这两个题挺相似的,如果都考虑朴素做法的话,都是两层for枚举两个节点,复杂度已经到n^2了

所以考虑树上点分治,一种logn的复杂度解决树上静态询问的问题的数据结构(我瞎说的

动态的话就要点分树了(学长跟我说的)

 

Gene Tree题目思路:

这个题我们可以先处理子树间的贡献,这样的话就不需要常规分治算法的合并的这个步骤了

点分治是要每次找一个新的重心

每次以新重心为根节点处理一部分子树

每次找子树的重心最多找logN次,那么需要一个O(N) 的计算贡献的方法

如果每次新来一个叶子节点,我需要和每个已经处理好的一部分叶子节点算贡献的话,如果枚举处理好的那部分

复杂度又是O(N^2)的

假设当前重心是root

因为我们选择先处理子树间的贡献,这样可以不用最后的合并操作

设已经处理好的一部分子树有叶子节点pre个,距离的平方和为sum1,距离和为sum2

我们展开一下(a[i]+a[j])^2这个式子(新来的是a[j]这个叶子节点)

a[i]^2 + 2*a[i]*a[j] + a[j]^2(下称第一项第二项第三项)

因为a[i]_1 a[i]_2....a[i]_cnt都要算一遍

第一项的贡献就是Sum1,第二项的贡献就是2*sum2*a[j] (这里O(n)枚举就行),第三项的贡献是a[j]*a[j]*pre(同第二项枚举计算就好)

CODE:

// Problem: Gene Tree
// Contest: NowCoder
// URL: https://ac.nowcoder.com/acm/contest/15644/B
// Memory Limit: 524288 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
typedef unsigned long long ull;
const int inf = 0x3f3f3f3f;
const long long INF = 1e18;
const int maxn = 2e5 + 7;
const ll mod = 1e9 + 7;

#define pb push_back
#define debug(x) cout << #x << ":" << x << endl;
#define mst(x, a) memset(x, a, sizeof(x))
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
#define dep(i, a, b) for (int i = (a); i >= (b); --i)

inline ll read() {
    ll x = 0;
    bool f = 0;
    char ch = getchar();
    while (ch < '0' || '9' < ch)
        f |= ch == '-', ch = getchar();
    while ('0' <= ch && ch <= '9')
        x = x * 10 + ch - '0', ch = getchar();
    return f ? -x : x;
}

void out(ll x) {
    int stackk[20];
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    if (!x) {
        putchar('0');
        return;
    }
    int top = 0;
    while (x)
        stackk[++top] = x % 10, x /= 10;
    while (top)
        putchar(stackk[top--] + '0');
}
ll qpow(ll a, ll b) {
    ll ans = 1;
    while (b) {
        if (b & 1)
            ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}
#define int ll
int n,cnt,head[maxn];
struct node{
    int u,v,w,next;
}e[maxn];
void add(int u,int v,int w)
{
    e[cnt].u=u,e[cnt].v=v,e[cnt].w=w;
    e[cnt].next = head[u],head[u]=cnt++;
}
int d[maxn];
int vis[maxn];// the root is visit
int maxson[maxn];
int siz[maxn],Smer,Mx,root;
int ans;
void  getroot(int u,int p)
{
    siz[u] = 1,maxson[u]=0;
    for(int i = head[u];~i;i=e[i].next)
    {
        int v = e[i].v;
        if(v==p||vis[v]) continue;
        getroot(v,u);
        siz[u]+=siz[v];
        if(maxson[u]<siz[v]) maxson[u] = siz[v];        
    }
    maxson[u] = max(maxson[u],Smer - siz[u]);
    if(maxson[u]<Mx) Mx = maxson[u],root = u;
}
int temp[maxn];
void  solve(int u,int p,int len)
{
    
    if(d[u]==1)
    {
        temp[++cnt] = len;
        return ;
    }
    
    for(int i = head[u];~i;i=e[i].next)
    {
        int v = e[i].v;
        if(vis[v]||v==p) continue;
        solve(v,u,len+e[i].w);
    }
}
void Divide(int tr)
{
    //solve(tr,0,0);
    vis[tr]=1;
    
    int sum1=0,sum2=0;
    int pre=0;        
    for(int i = head[tr];~i;i=e[i].next)
    {
        int v = e[i].v;
        if(vis[v]) continue;
            
        cnt=0;
        solve(v,0,e[i].w);
        

        for(int j=1;j<=cnt; j++) ans+=(sum1+2ll*temp[j]*sum2+temp[j]*temp[j]*pre);
        for(int j=1;j<=cnt; j++) sum1+=(temp[j]*temp[j]),sum2+=temp[j],pre++;
    
        Smer = siz[v];root=0;
        Mx=inf;getroot(v,0);
        Divide(root);    
    }
}
#define int int
int main() {
    // ios::sync_with_stdio(false);
    mst(head,-1);
    n = read();
    for(int i=1 ;i<n ;i++)
    {
        ll u,v,w;
        u = read(),v=read(),w=read();
        add(u,v,w),add(v,u,w);
        d[u]++,d[v]++;
    }
    Mx = inf,Smer = n;
    getroot(1,0),Divide(root);
    out(ans);
    return 0;
}
/*


*/
View Code

 

 

小 Q 与树 题目思路:

跟上面一样

step1:先处理子树之间的贡献

step2:写一个式子,如果新来了一个点,如何计算和之前所有点的贡献

step3:AC

具体操作:

计算贡献的方法,首先,我们对当前处理的所有子树的点,按照权值大小排序

假设一共有m个节点,新来的是二号节点,那么二号节点最为最小值影响到的是[3,m]的节点

val = (dis[2][root] + dis[k][root])*minn  -->这个式子可以用距离和去维护,然后维护一个前缀值就能算区间的了

But

这里我们处理子树之间的时候,因为不同于上一个题是叶子节点,这个题是所有节点,

所以会处理了一部分相同子树内部的节点,这部分贡献是要减去的

CODE:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
typedef unsigned long long ull;
const int inf = 0x3f3f3f3f;

const int maxn = 2e5 + 7;
const ll mod = 998244353 ;

#define pb push_back
#define debug(x) cout << #x << ":" << x << endl;
#define mst(x, a) memset(x, a, sizeof(x))
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
#define dep(i, a, b) for (int i = (a); i >= (b); --i)

inline ll read() {
    ll x = 0;
    bool f = 0;
    char ch = getchar();
    while (ch < '0' || '9' < ch)
        f |= ch == '-', ch = getchar();
    while ('0' <= ch && ch <= '9')
        x = x * 10 + ch - '0', ch = getchar();
    return f ? -x : x;
}

void out(ll x) {
    int stackk[20];
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    if (!x) {
        putchar('0');
        return;
    }
    int top = 0;
    while (x)
        stackk[++top] = x % 10, x /= 10;
    while (top)
        putchar(stackk[top--] + '0');
}
ll qpow(ll a, ll b) {
    ll ans = 1;
    while (b) {
        if (b & 1)
            ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}



int n,cnt,head[maxn*2];
struct node{
    int u,v,w,next;
}e[maxn*2];
void add(int u,int v,int w)
{
    e[cnt].u=u,e[cnt].v=v,e[cnt].w=w;
    e[cnt].next = head[u],head[u]=cnt++;
}
int vis[maxn];// the root is visit
int maxson[maxn];
int siz[maxn],Smer,Mx,root;
int ans;
void  getroot(int u,int p)
{
    siz[u] = 1,maxson[u]=0;
    for(int i = head[u];~i;i=e[i].next)
    {
        int v = e[i].v;
        if(v==p||vis[v]) continue;
        getroot(v,u);
        siz[u]+=siz[v];
        if(maxson[u]<siz[v]) maxson[u] = siz[v];        
    }
    maxson[u] = max(maxson[u],Smer - siz[u]);
    if(maxson[u]<Mx) Mx = maxson[u],root = u;
}
int val[maxn],sum[maxn];
struct Node{
    int id,dis,val;
}no[maxn];
int tot;
bool cmp(Node x ,Node y ){return x.val<y.val;}
void get_dis(int u,int p,int len)
{
    no[++tot] = {u,len,val[u]};
    for(int i=head[u];~i;i=e[i].next)
    {
        int v = e[i].v;
        if(v==p||vis[v]) continue;
        get_dis(v,u,len+e[i].w);
    }
}
ll solve(int u,int p,int len)
{
    tot=0;
    get_dis(u,p,len);
    sort(no+1,no+1+tot,cmp);
    ll temp=0;
    for(int i=1 ;i<=tot; i++) sum[i] = (sum[i-1] + no[i].dis)%mod;
    for(int i=1 ;i<=tot ; i++)
    {
        temp=((temp+((tot-i+1)%mod*(no[i].val*no[i].dis)%mod)%mod+(no[i].val*(sum[tot] - sum[i])))%mod+mod)%mod;
    }
    return temp*2%mod;
}
void Divide(int tr)
{
    ans+=solve(tr,0,0);
    ans%=mod;
    vis[tr]=1;   
    for(int i = head[tr];~i;i=e[i].next)
    {
        int v = e[i].v;
        if(vis[v]) continue;
        ans=((ans - solve(v,0,e[i].w))%mod + mod )%mod;
        Smer = siz[v];root=0;
        Mx=inf;getroot(v,0);
        Divide(v);    
    }
}

int main() {
    // ios::sync_with_stdio(false);
    mst(head,-1);
    n = read();
    rep(i,1,n) val[i] = read();
    for(int i=1 ;i<n ;i++)
    {
        ll u,v,w;
        u = read(),v=read(),w=1;
        add(u,v,w),add(v,u,w);
    }
    Mx = inf,Smer = n;
    getroot(1,0),Divide(root);
    out(ans);
    return 0;
}
/*


*/
View Code

 

posted @ 2021-05-07 18:12  UpMing  阅读(74)  评论(0编辑  收藏  举报