树上启发式合并

首先可以了解一下启发式合并,这个可以看我之前的博客虽然两者关系不大

该算法英文名为\(dsu\ on\ tree\),最先以成型的算法出现是在\(Codeforces\)这篇博客上

树上启发式合并可以在\(O(nlogn)\)的时间复杂度内离线解决很多无修改子树询问。

先由一个例子引入:树上每个点有一种颜色,询问子树颜色个数。

在线算法我们可以用\(dfs\)\(+\)主席树。

离线算法呢?

我们用\(vis_i\)表示子树内\(i\)是否出现,\(cnt_i\)表示颜色个数。这个东西是支持\(O(1)\)修改的。

先考虑暴力,对每一子树\(dfs\)一遍统计答案。时间复杂度\(O(n^2)\)

\(dfs\)\(+\)序列莫队,但复杂度是\(O(n\sqrt{n})\)的。

树上启发式合并怎么做?

我们发现答案可以从儿子节点获取,但不能直接获取,这样空间复杂度是\(O(n^2)\)的。

预处理重儿子(即子树节点最多的儿子)

先递归处理非重儿子的答案,并且不获取非重儿子的答案,即清空\(vis\)数组

然后处理重儿子的答案,并且获取重儿子的答案

最后再次递归计算非重儿子的答案,并且暴力合并得到该点的答案。

该算法的复杂度是什么?前面说了是\(O(nlogn)\)的。

我们需要证明一个引理。

根节点出发的任意路径上轻边(不连向重儿子的边)条数\(\leq logn\)

证明考虑每次到非重儿子子树大小减少一半以上,最多减\(logn\)次。

统计一个点的答案是,重儿子的子树内点的遍历次数是不需计入该点的(那些点自己本身也要遍历一次)。

考虑每个点被遍历的次数,即为到根的轻边数,复杂度为\(O(logn)\)

总复杂度为\(O(nlogn)\)

一道例题:CF600E

和上面那题做法差不多,就当模板题做啦。

#include<cstdio>
#include<vector>
#define rep(i, a, b) for (register int i=(a); i<=(b); ++i)
#define per(i, a, b) for (register int i=(a); i>=(b); --i)
using namespace std;
const int N=100005;
int size[N], son[N], cnt[N], col[N], skip[N], Max;
long long ans[N], sum;
vector<int> G[N];

inline int read()
{
    int x=0,f=1;char ch=getchar();
    for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
    for (;ch>='0'&&ch<='9';ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
    return x*f;
}

void dfs(int u, int fa)
{
	size[u]=1;
	for (int v: G[u]) if (v^fa)
	{
		dfs(v, u); size[u]+=size[v];
		if (!son[u] || size[v]>size[son[u]]) son[u]=v;
	}
}

void modify(int u, int fa, int k)
{
	cnt[col[u]]+=k;
	if (~k && cnt[col[u]]>=Max)
	{
		if (cnt[col[u]]>Max) sum=0, Max=cnt[col[u]];
		sum+=col[u];
	}
	for (int v: G[u]) if (v^fa && !skip[v]) modify(v, u, k);
}

void solve(int u, int fa, bool flag)
{
	for (int v: G[u]) if (v^fa && v^son[u]) solve(v, u, 0);
	if (son[u]) solve(son[u], u, 1), skip[son[u]]=1;
	modify(u, fa, 1); ans[u]=sum;
	if (son[u]) skip[son[u]]=0; 
	if (!flag) modify(u, fa, -1), Max=sum=0;
}

int main()
{
	int n=read();
	rep(i, 1, n) col[i]=read();
	rep(i, 1, n-1)
	{
		int u=read(), v=read();
		G[u].push_back(v); G[v].push_back(u);	
	}
	dfs(1, 0); solve(1, 0, 0);
	rep(i, 1, n) printf("%lld ", ans[i]);
	return 0;
}

CF570D

\(cnt_i\)\(i\)点子树内每个字母奇偶性的二进制状态。

只有\(cnt_i=0/2^k\)时合法,这个用\(lowbit\)检验即可。

然后就是树上启发式合并模板啦。

#include<cstdio>
#include<vector>
#define rep(i, a, b) for (register int i=(a); i<=(b); ++i)
#define per(i, a, b) for (register int i=(a); i>=(b); --i)
using namespace std;
const int N=500005;
int size[N], son[N], cnt[N], skip[N], dep[N], ans[N];
vector<pair<int, int> >q[N];
vector<int> G[N];
char s[N];

inline int read()
{
    int x=0,f=1;char ch=getchar();
    for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
    for (;ch>='0'&&ch<='9';ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
    return x*f;
}

void dfs(int u, int fa)
{
    dep[u]=dep[fa]+1; size[u]=1;
    for (int v: G[u])
    {
        dfs(v, u); size[u]+=size[v];
        if (!son[u] || size[son[u]]<size[v]) son[u]=v;
    }
}

bool check(int x){return !(x&(x-1));}

void modify(int u)
{
    cnt[dep[u]]^=1<<(s[u]-'a');
    for (int v: G[u]) if (!skip[v]) modify(v);
}

void solve(int u, int flag)
{
    for (int v: G[u]) if (v^son[u]) solve(v, 0);
    if (son[u]) solve(son[u], 1), skip[son[u]]=1;
    modify(u); skip[son[u]]=0;
    for (auto i: q[u]) ans[i.second]=check(cnt[i.first]);
    if (!flag) modify(u);
}

int main()
{
    int n=read(), m=read();
    rep(i, 2, n) G[read()].push_back(i);
    scanf("%s", s+1); 
    rep(i, 1, m) {int v=read(), h=read(); q[v].push_back(make_pair(h, i));}
    dfs(1, 0); solve(1, 0);
    rep(i, 1, m) puts(ans[i]?"Yes":"No");
    return 0;
}

CF741D

算法发明人出的题。据说坑了很多人

还是记一个上题的\(cnt_i\)一样的东西,不过记录的是到根的路径。

然后开一个桶\(f_i\)记录\(cnt\)\(i\)的最大深度,然后按照点分治的思路统计答案。

然后统计答案的时候就需要用到\(dsu\ on\ tree\)了。

#include<cstdio>
#include<vector>
#define rep(i, a, b) for (register int i=(a); i<=(b); ++i)
#define per(i, a, b) for (register int i=(a); i>=(b); --i)
using namespace std;
inline void chkmax(int &x, int y){x<y?(x=y):0;}
const int N=500005;
vector<pair<int, int> > G[N];
int size[N], in[N], out[N], id[N], dep[N], son[N], tot;
int Xor[N], f[1<<22], ans[N];

inline int read()
{
 	int x=0,f=1;char ch=getchar();
    for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
    for (;ch>='0'&&ch<='9';ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
    return x*f;
}

#define v i.first
#define w i.second

void dfs(int u)
{
    size[u]=1; id[in[u]=++tot]=u;
    for (auto i: G[u]) 
    {
        dep[v]=dep[u]+1; Xor[v]=Xor[u]^w;
        dfs(v); size[u]+=size[v]; 
        if (size[v]>size[son[u]]) son[u]=v;
    }
    out[u]=tot;
}

void solve(int u, int flag)
{
    for (auto i: G[u]) if (v^son[u]) solve(v, 0), chkmax(ans[u], ans[v]);
    if (son[u]) solve(son[u], 1), chkmax(ans[u], ans[son[u]]);
    if (f[Xor[u]]) chkmax(ans[u], f[Xor[u]]-dep[u]);
    rep(i, 0, 21) if (f[Xor[u]^(1<<i)]) 
        chkmax(ans[u], f[Xor[u]^(1<<i)]-dep[u]);
    chkmax(f[Xor[u]], dep[u]);
    for (auto i: G[u]) if (v^son[u])
    {
        rep(j, in[v], out[v])
        {
            if (f[Xor[id[j]]]) 
                chkmax(ans[u], f[Xor[id[j]]]+dep[id[j]]-(dep[u]<<1));
            rep(k, 0, 21) if (f[Xor[id[j]]^(1<<k)])
                chkmax(ans[u], f[Xor[id[j]]^(1<<k)]+dep[id[j]]-(dep[u]<<1));
        }
        rep(j, in[v], out[v]) chkmax(f[Xor[id[j]]], dep[id[j]]);
    }
    if (!flag) rep(i, in[u], out[u]) f[Xor[id[i]]]=0;
}

#undef v
#undef w

int main()
{
    int n=read();
    rep(i, 2, n)
    {
        int p=read(); char c=getchar();
        G[p].push_back(make_pair(i, 1<<(c-'a')));
    }
    dep[1]=1; dfs(1); solve(1, 0);
    rep(i, 1, (1<<22)-1) if (f[i]) printf("%d %d\n", i, f[i]);
    rep(i, 1, n) printf("%d ", ans[i]);
    return 0;
}
posted @ 2019-04-25 09:54  OIerC  阅读(451)  评论(0编辑  收藏  举报