[USACO12NOV]Balanced Trees G

tag:点分治


怎么两篇题解都没了,我来补一篇

这种树上路径问题很容易想到是点分治,考虑如何计算两条路径拼起来的答案。

首先一定是(...(...(...(...这样一条路径和...)...)...)这样一条路径拼起来,然后因为是求 \(\max\),所以求出两边的最大深度再取 \(\max\) 就行了。这里以左边为例。


假设当前是一条合法路径,那么当前的答案就是处理括号序列时用的那个栈的历史最大值,所以在dfs的时候拿个mx变量记录一下。

如果当前是一个“(”,而且栈为空,就要把mx变量加一,然后可以忽略掉这个“(”【因为拼起来的路径是合法的,这里就可以默认右边有一个“)”把它匹配掉了】。

比如说当前是“(”,父亲到重心是“()(())”,mx\(2\),由于右边路径一定有一个“)”与当前的“(”匹配,所以两段路径拼起来是“( ()(()) ... )”。

相当于中间那部分括号的深度整体+1,所以让mx+1就行。

对于右边路径的处理也是类似的,可以结合代码理解一下:

/*
up为括号序列栈顶
mxup为历史最大值
cntl为多余出来的,需要用右边的")"去匹配的"("
*/
if(a[x]=='(') val[x].up++;
else val[x].up--;
val[x].mxup = max(val[x].mxup,-val[x].up);
if(val[x].up>0)
    val[x].cntl++,
    val[x].mxup++,
    val[x].up = 0;

然后可以用一个桶去记录最大值,以 cntl/cntr 为下标。


注意一些小细节

  • 在dfs的时候最好令一边包含重心,另一边不包含重心

  • 不要漏了到重心的本身就合法的链

  • 在拼路径的时候要判断是否存在对应的路径


#include<bits/stdc++.h>
using namespace std;

template<typename T>
inline void Read(T &n){
	char ch; bool flag=0;
	while(!isdigit(ch=getchar()))if(ch=='-')flag=1;
	for(n=ch^48;isdigit(ch=getchar());n=((n<<1)+(n<<3)+(ch^48)));
	if(flag)n=-n;
}

enum{
    MAXN = 40005
};

int n;

struct _{
    int nxt, to;
    _(int nxt=0, int to=0):nxt(nxt),to(to){}
}edge[MAXN<<1];
int fst[MAXN], tot;

inline void Add_Edge(int f, int t){
    edge[++tot] = _(fst[f], t); fst[f] = tot;
    edge[++tot] = _(fst[t], f); fst[t] = tot;
}

char a[MAXN];

inline void upd(int &x, int y){x = max(x,y);}

int sz[MAXN], Size, Weight, Center;
char vis[MAXN], Vis[MAXN];
void Get_Center(int x){
    vis[x] = true;
    sz[x] = 1;
    int w=0;
    for(register int u=fst[x]; u; u=edge[u].nxt){
        int v=edge[u].to;
        if(vis[v] or Vis[v]) continue;
        Get_Center(v);
        sz[x] += sz[v];
        upd(w,sz[v]);
    }   
    upd(w,Size-sz[x]);
    if(w < Weight)
        Weight = w, Center = x;
    vis[x] = false;
}

struct ele{int mxup, mxdown, up, down, cntl, cntr;}val[MAXN];

int q[MAXN], top;
int mxl[MAXN], mxr[MAXN];
void dfs(int x){
    vis[x] = true; q[++top] = x; sz[x] = 1;

    if(a[x]=='(') val[x].up++; else val[x].up--;
    upd(val[x].mxup,-val[x].up);
    if(val[x].up>0)
        val[x].cntl++,
        val[x].mxup++,
        val[x].up = 0;

    if(a[x]=='(') val[x].down++; else val[x].down--;
    upd(val[x].mxdown,val[x].down);
    if(val[x].down<0)
        val[x].cntr++,
        val[x].mxdown++,
        val[x].down = 0;

    for(register int u=fst[x]; u; u=edge[u].nxt){
        int v=edge[u].to;
        if(vis[v] or Vis[v]) continue;
        val[v] = val[x]; dfs(v);
        sz[x] += sz[v];
    }
    vis[x] = false;
}

int ans, dep;
void solve(int x){
    Weight = Size; Get_Center(x); x = Center; Vis[x] = true;

    ele base = (ele){0,0,0,0,0,0};
    if(a[x]=='(') base.cntl = 1; else base.up = -1; base.mxup = 1;
    if(!base.up) mxl[base.cntl] = base.mxup;
    mxr[0] = 0;

    int prv; top = 0;
    int ml=0, mr=0;
    for(register int u=fst[x]; u; u=edge[u].nxt){
        int v=edge[u].to;
        if(Vis[v]) continue;
        val[v] = base; prv = top+1;
        dfs(v);
        for(register int i=prv; i<=top; i++){
            ele cur = val[q[i]];
            upd(ml,cur.cntl); upd(mr,cur.cntr);
            if(!cur.up and ~mxr[cur.cntl]) upd(ans,max(mxr[cur.cntl],cur.mxup));
            if(!cur.down and ~mxl[cur.cntr]) upd(ans,max(mxl[cur.cntr],cur.mxdown));
        }
        for(register int i=prv; i<=top; i++){
            ele cur = val[q[i]];
            if(!cur.up) upd(mxl[cur.cntl],cur.mxup);
            if(!cur.down) upd(mxr[cur.cntr],cur.mxdown);
        }
    }
    fill(mxl,mxl+ml+1,-1); fill(mxr,mxr+mr+1,-1);

    for(register int u=fst[x]; u; u=edge[u].nxt){
        int v=edge[u].to;
        if(Vis[v]) continue;
        Size = sz[v]; solve(v);
    }
}

char tmp[10];

int main(){
    freopen("1.in","r",stdin);
	freopen("2.out","w",stdout);
    // double tt=clock();
    Read(n);
    for(register int i=2; i<=n; i++){
        int fa; scanf("%d",&fa);
        Add_Edge(i,fa);
    }
    memset(mxl,-1,sizeof mxl);
    memset(mxr,-1,sizeof mxr);
    for(register int i=1; i<=n; i++) scanf("%s",tmp), a[i] = tmp[0];
    Size = n; solve(1);
    cout<<ans<<endl;
    // printf("%.6lf\n",(clock()-tt)/CLOCKS_PER_SEC);
    return 0;
}
posted @ 2021-06-26 13:25  oisdoaiu  阅读(25)  评论(0编辑  收藏  举报