『ZJOI2019 D2T2』语言

~~ 话说,本题考场想出三只\(log\)的暴力做法,被卡成暴力了。~~

题目分析

首先考虑枚举每一个点,计算这个点可以和多少点进行交易。

将所有经过该点的路径\(s,t\)拿出,那么这些极远的\(s,t\)构成的连通块大小\(sz - 1\)就是答案。

\(Codeforces\)\(异象石\)那题可以想到,若一些点集按照\(dfs\)序排序,那么这些点构成连通块大小就是

\(\frac{1}{2} (dist(a_1 , a_2) + dist(a_2,a_3) + ... + dist(a_{k-1} , a_k) + dist(a_k,a_1))\)

考虑对于每一个节点开一棵线段树,其叶子节点\(i\)表示\(dfs\)序为\(i\)的极远点出现次数。

线段树中存储\(3\)个值\(lp,rp,Sum\)分别表示当前存在的大于\(0\)的最小下标和最大下标,和不算头尾的连通块大小。

由于路径条数为\(m\),显然我们可以用可持久化线段树来维护这\(n\)棵线段树,使得空间复杂度为\(O(m log_2 n)\)

利用树上差分的思想,对于每一条\(s,t\)的路径,我们先在\(s\)\(t\)所在的线段树中将\(dfn[s]\)\(dfn[t]\)两个点单点\(+1\)

然后在\(father(lca(s,t))\)的节点,将\(dfn[s]\)\(dfn[t]\)两个点单点\(-2\)

于是,我们可以自下往上去统计每个节点的答案。

每一次,我们需要对该节点的所有子树进行线段树合并,然后询问这个节点的答案,将其累加进总个数中。

这样,我们就完成了无序数对的统计,那么此时答案除以\(2\)就是最终的答案。

复杂度分析

由于\(n\)次线段树合并节点总数是\(m\)个,所以需要时间复杂度为\(O(m log_2 n)\)

由于\(m\)次线段单点修改,使用\(O(1)\)\(LCA\)实现,所以需要时间复杂度为\(O(m log_2 n)\)

所以,本题的总时间复杂度就是\(O(m log_2 n)\)

# include<bits/stdc++.h>
# define int long long
# define inf (1e9)
using namespace std;
const int N=1e5+10;
struct rec{ int pre,to;}a[N<<1];
int dep[N],head[N],dfn[N],root[N],acr[N],g[N];
int n,m,tot,ans;
namespace fast_IO{
    const int IN_LEN = 10000000, OUT_LEN = 10000000;
    char ibuf[IN_LEN], obuf[OUT_LEN], *ih = ibuf + IN_LEN, *oh = obuf, *lastin = ibuf + IN_LEN, *lastout = obuf + OUT_LEN - 1;
    inline char getchar_(){return (ih == lastin) && (lastin = (ih = ibuf) + fread(ibuf, 1, IN_LEN, stdin), ih == lastin) ? EOF : *ih++;}
    inline void putchar_(const char x){if(oh == lastout) fwrite(obuf, 1, oh - obuf, stdout), oh = obuf; *oh ++= x;}
    inline void flush(){fwrite(obuf, 1, oh - obuf, stdout);}
    int read(){
        int x = 0; int zf = 1; char ch = ' ';
        while (ch != '-' && (ch < '0' || ch > '9')) ch = getchar_();
        if (ch == '-') zf = -1, ch = getchar_();
        while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar_(); return x * zf;
    }
    void write(int x){
        if (x < 0) putchar_('-'), x = -x;
        if (x > 9) write(x / 10);
        putchar_(x % 10 + '0');
    }
}
using namespace fast_IO;
namespace LCA {
    int ST[N << 1][22], value[N << 1], depth[N << 1], first[N], dist[N], cnt;
    inline int calc(int x, int y) {
        return depth[x] < depth[y] ? x : y;
    }
    inline void dfs(int u, int p, int d) {
        value[++cnt] = u; depth[cnt] = d; first[u] = cnt;
        for (int i = head[u]; i; i = a[i].pre) {
            int v = a[i].to;
            if (v == p) continue;
            dist[v] = dist[u] + 1;
            dfs(v, u, d + 1);
            value[++cnt] = u; depth[cnt] = d;
        }
    }
    inline void init(int root, int node_cnt) {
    	cnt = 0; dist[root] = 0;
        dfs(root, 0, 1);
        int n = 2 * node_cnt - 1;
        for (int i = 1; i <= n; i++) ST[i][0] = i;
        for (int j = 1; j < 22; j++) 
            for (int i = 1; i + (1 << j) - 1 <= n; i++) 
                ST[i][j] = calc(ST[i][j - 1], ST[i + (1 << (j - 1))][j - 1]);
    }
    inline int query(int x, int y) {
        int l = first[x], r = first[y];
        if (l > r) std::swap(l, r);
        int k = log2(r - l + 1);
        return value[calc(ST[l][k], ST[r - (1 << k) + 1][k])];
    }
}
void adde(int u,int v) {
	a[++tot].pre=head[u];
	a[tot].to=v;
	head[u]=tot;
}
void dfs1(int u,int fa) {
	dfn[u]=++dfn[0]; acr[dfn[u]]=u;
	dep[u]=dep[fa]+1,g[u]=fa;
	for (int i=head[u];i;i=a[i].pre) {
		int v=a[i].to; if (v==fa) continue;
		dfs1(v,u);
	}
}
int lca(int u,int v) {
	return LCA::query(u,v);	
}
int dist(int u,int v) {
	int l=lca(u,v);
	return dep[u]+dep[v]-2*dep[l];
}
struct Seg {
	int ls,rs,lp,rp,sum,val;
	Seg() { sum=ls=rs=0; lp=inf; rp=-inf;}
}tr[N*70];
# define ls(x) tr[x].ls
# define rs(x) tr[x].rs
# define mid (l+r>>1)
# define lson ls(x),l,mid
# define rson rs(x),mid+1,r
int cnt=0;
void up(int &x) {
	if (ls(x)!=0) tr[x].lp=min(tr[x].lp,tr[ls(x)].lp),tr[x].rp=max(tr[x].rp,tr[ls(x)].rp);
	if (rs(x)!=0) tr[x].lp=min(tr[x].lp,tr[rs(x)].lp),tr[x].rp=max(tr[x].rp,tr[rs(x)].rp);
	int ret=0;
	if (ls(x)) ret+=tr[ls(x)].sum;
	if (rs(x)) ret+=tr[rs(x)].sum;
	if (ls(x) && rs(x) && tr[ls(x)].rp!=-inf && tr[rs(x)].lp!=inf) ret+=dist(acr[tr[ls(x)].rp],acr[tr[rs(x)].lp]);
	tr[x].sum=ret;
}
void update(int &x,int l,int r,int pos,int d) {
	if (!x) x=++cnt;
	if (l==r) {
		tr[x].val+=d;
		if (tr[x].val>0) tr[x].lp=tr[x].rp=l;
		else tr[x].lp=inf,tr[x].rp=-inf; 
		tr[x].sum=0; 
		return;
	}
	if (pos<=mid) update(lson,pos,d);
	else update(rson,pos,d);
	up(x);
}
void merge(int &x,int y,int l,int r) {
	if (!x || !y) {x=x+y; return;}
	if (l==r) {
		tr[x].val+=tr[y].val;
		if (tr[x].val>0) {
			tr[x].lp=min(tr[x].lp,tr[y].lp);
			tr[x].rp=max(tr[x].rp,tr[y].rp);
		} else tr[x].lp=inf,tr[x].rp=-inf; 
		tr[x].sum=0; 
		return;
	}
	merge(ls(x),ls(y),l,mid);
	merge(rs(x),rs(y),mid+1,r);
	up(x);
}
void dfs2(int u,int fa) {
	for (int i=head[u];i;i=a[i].pre) {
		int v=a[i].to; if (v==fa) continue;
		dfs2(v,u);merge(root[u],root[v],1,n);
	}
	ans+=tr[root[u]].sum;
	if (tr[root[u]].lp!=inf && tr[root[u]].rp!=-inf)
		ans+=dist(acr[tr[root[u]].lp],acr[tr[root[u]].rp]);
}
signed main()
{
	n=read();m=read();
	memset(root,0,sizeof(root));
	for (int i=2;i<=n;i++) {
		int u=read(),v=read(); 
		adde(u,v); adde(v,u);
	}
	dfs1(1,0);
	LCA::init(1,n);
	for (int i=1;i<=m;i++) {
		int u=read(),v=read(); 
		update(root[u],1,n,dfn[u],1);
		update(root[u],1,n,dfn[v],1);
		update(root[v],1,n,dfn[u],1);
		update(root[v],1,n,dfn[v],1);
		int l = g[lca(u,v)];
		if (l) {
			update(root[l],1,n,dfn[u],-2);
			update(root[l],1,n,dfn[v],-2);
		}
	}
	dfs2(1,0);
	write(ans/4); putchar_('\n');
	flush(); 
	return 0;
 } 
posted @ 2019-10-04 22:22  ljc20020730  阅读(174)  评论(0编辑  收藏  举报