题解 LGP8805【[蓝桥杯 2022 国 B] 机房】

出处

应该叫 "静态离线查询树链半群信息的一种并查集做法"

problem

一棵树,点上有一些满足结合律的信息,\(m\) 次询问求出一条链上的点权之“和”,允许离线\(n,m\leq 10^5\)

前置知识 1:Tarjan LCA

离线 dfs,做完一棵子树后将这棵子树的 \(fa\) 设为 \(u\),然后看一下 \(u\) 的所有的询问,如果 \(v\) 被访问过,那么直接 \(find(v)\) 就是 \(lca\)

void dfs(int u,int fa=0){
	vis[u]=1;
	for(int i=g.head[u];i;i=g.nxt[i]){
		int v=g[i].v; if(v==fa) continue;
		dfs(v,u),s.fa[v]=u;
	}
	for(int i=que.head[u];i;i=que.nxt[i]){
		int v=que[i].v;
		if(vis[v]) que[i].w=que[i^1].w=s.find(v);
	}
}

前置知识 2:带权并查集

一个并查集,额外维护一个 \(ans_x\) 表示 \(x\to fa_x\) 这条边的答案。路径压缩时更新 \(ans\)

int find(int x){
		if(fa[x]==x) return x;
		int g=find(fa[x]);
		ans[x]+=ans[fa[x]];//现在 ans[fa[x]] 存了 fa[x]->g 的信息
		return fa[x]=g;
	}

solution

将这两个东西拼起来!

具体地,你维护一个带权并查集,边权是深度更深的点的点权。

求一遍 Tarjan LCA,这时候把询问挂在 LCA 上,回溯到 LCA 时,询问的 \(u\to v\) 可拆成 \((u\to lca)+(lca)+(lca\to v)\) 三段,都是带权并查集维护过的,拼起来就是了。

因为有 \(lca\to v\) 这一段,你很有可能需要同时维护向上和向下两个方向的权值。这是简单的。

复杂度:\(O(n\alpha(n))\) 乘上一次“加法”的复杂度。常数是并查集常数,很小(不会有人卡并查集吧?)

应用:有结合律和单位元的东西都可以用,不需要交换律,包括但不限于:\(\sum,\prod,\min,\max,\gcd\),矩阵乘法,最大子段和,等等。通用的。搬到序列上貌似也可以。

example:[蓝桥杯 2022 国 B] 机房

其实是查询一条链的点权和,可以树上前缀和,但是我们要创新!

Code
#include <cstdio>
#include <cstring>
#include <cassert>
#include <algorithm>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr,##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
typedef long long LL;
template<int N,class T> struct dsy{
	int fa[N+10],siz[N+10],cnt; T ans[N+10][2];
	dsy(int n=N):cnt(n){for(int i=1;i<=N;i++) fa[i]=i,siz[i]=1;}
	int find(int x){
		if(fa[x]==x) return x;
		int g=find(fa[x]);
		ans[x][0]=ans[x][0]+ans[fa[x]][0];
		ans[x][1]=ans[fa[x]][1]+ans[x][1];
		return fa[x]=g;
	}
	T query(int x,int k){return find(x),ans[x][k];}
}; 
template<int N,int M,class T=int> struct graph{
    int head[N+10],nxt[M*2+10],cnt;
    struct edge{
        int u,v;T w;
        edge(int u=0,int v=0,T w=0):u(u),v(v),w(w){}
    } e[M*2+10];
    graph(){memset(head,cnt=0,sizeof head);}
    edge&operator[](int i){return e[i];}
    void add(int u,int v,T w=0){e[++cnt]=edge(u,v,w),nxt[cnt]=head[u],head[u]=cnt;}
    void link(int u,int v,T w=0){add(u,v,w),add(v,u,w);}
};
int n,m,ret[100010],a[100010];
bool vis[100010],svd[100010];
graph<100010,100010> g,que,sol;
dsy<100010,int> s;
void dfs(int u,int fa=0){
	vis[u]=1;
	for(int i=g.head[u];i;i=g.nxt[i]){
		int v=g[i].v; if(v==fa) continue;
		dfs(v,u),s.fa[v]=u,s.ans[v][0]=s.ans[v][1]=a[v];
	}
	for(int i=que.head[u];i;i=que.nxt[i]){
		int v=que[i].v,id=que[i].w; if(!vis[v]||svd[id]) continue;
		sol.add(s.find(v),u,id),svd[id]=1;
	}
	for(int i=sol.head[u];i;i=sol.nxt[i]){
		int id=sol[i].w,x=que[id*2-1].u,y=que[id*2-1].v;
		ret[id]=s.query(x,0)+a[u]+s.query(y,1);
	}
}
int main(){
//	#ifdef LOCAL
//	 	freopen("input.in","r",stdin);
//	#endif
	scanf("%d%d",&n,&m);
	for(int i=1,u,v;i<n;i++) scanf("%d%d",&u,&v),g.link(u,v),a[u]++,a[v]++;
	for(int i=1,u,v;i<=m;i++) scanf("%d%d",&u,&v),que.link(u,v,i);
	dfs(1);
	for(int i=1;i<=m;i++) printf("%d\n",ret[i]);
	return 0;
}

example:[CSP-S 2022] 数据传输

转移矩阵略。那么变成维护矩阵乘法乘积。\(O(k^3\alpha(n)n)\)

Code
#include <cstdio>
#include <cstring>
#include <cassert>
#include <algorithm>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr,##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
typedef long long LL;
template<int N,int M,class T=LL> struct matrix{
	T a[N][M];
	matrix(T flag=1e18){memset(a,0x3f,sizeof a);for(int i=0;i<N&&i<M;i++) a[i][i]=flag;}
	T* operator[](int i){return a[i];}
//	void print(const char*s){
//		debug("matrix %s:\n",s);
//		for(int i=0;i<N;i++){
//			for(int j=0;j<M;j++){
//				debug("%lld%c",a[i][j]," \n"[j==M-1]);
//			}
//		}
//	}
};
template<int N,int M,int R,class T=LL> matrix<N,R,T> operator*(matrix<N,M,T> a,matrix<M,R,T> b){
	matrix<N,R,T> c=1e18;
	for(int i=0;i<N;i++){
		for(int j=0;j<M;j++){
			for(int k=0;k<R;k++){
				c[i][k]=min(c[i][k],a[i][j]+b[j][k]);
			}
		}
	}
	return c;
};
template<int N,class T> struct dsy{
	int fa[N+10],siz[N+10],cnt; T ans[N+10][2];
	dsy(int n=N):cnt(n){for(int i=1;i<=N;i++) fa[i]=i,siz[i]=1,ans[i][0]=ans[i][1]=0;}
	int find(int x){
		if(fa[x]==x) return x;
		int g=find(fa[x]);
		ans[x][0]=ans[x][0]*ans[fa[x]][0];
		ans[x][1]=ans[fa[x]][1]*ans[x][1];
		return fa[x]=g;
	}
	T query(int x,int k){return find(x),ans[x][k];}
}; 
template<int N,int M,class T=int> struct graph{
    int head[N+10],nxt[M*2+10],cnt;
    struct edge{
        int u,v;T w;
        edge(int u=0,int v=0,T w=0):u(u),v(v),w(w){}
    } e[M*2+10];
    graph(){memset(head,cnt=0,sizeof head);}
    edge&operator[](int i){return e[i];}
    void add(int u,int v,T w=0){e[++cnt]=edge(u,v,w),nxt[cnt]=head[u],head[u]=cnt;}
    void link(int u,int v,T w=0){add(u,v,w),add(v,u,w);}
};
int n,m,sshwy;
LL ret[200010],a[200010],b[200010];
bool vis[200010],svd[200010];
graph<200010,200010> g,que,sol;
dsy<200010,matrix<3,3>> s;
matrix<1,3> f0;
matrix<3,3> gettrans(int i){
	matrix<3,3> c=1e18; 
	switch(sshwy){
		case 1: c[0][0]=a[i]; break;
		case 2: c[0][1]=0,c[0][0]=c[1][0]=a[i]; break;
		case 3: c[0][0]=c[1][0]=c[2][0]=a[i],c[0][1]=c[1][2]=0,c[1][1]=b[i]; break;
	}
	return c;
}
matrix<3,1> getend(int i){
	matrix<3,1> c;
	c[0][0]=0,c[1][0]=a[i],c[2][0]=a[i];
	return c;
}
void dfs(int u,int fa=0){
	vis[u]=1;
	for(int i=g.head[u];i;i=g.nxt[i]){
		int v=g[i].v; if(v==fa) continue;
		dfs(v,u),s.fa[v]=u,s.ans[v][0]=s.ans[v][1]=gettrans(v);
	}
	for(int i=que.head[u];i;i=que.nxt[i]){
		int v=que[i].v,id=que[i].w; if(!vis[v]||svd[id]) continue;
		sol.add(s.find(v),u,id),svd[id]=1;
	}
	for(int i=sol.head[u];i;i=sol.nxt[i]){
		int id=sol[i].w,x=que[id*2-1].u,y=que[id*2-1].v;
//		debug("calculating query (%d,%d), lca=%d...\n",x,y,u);
//		f0.print("f0");
//		s.query(x,0).print("s.query(x,0)");
//		gettrans(u).print("gettrans(u)");
//		s.query(y,1).print("s.query(y,1)");
//		getend(y).print("getend(y)");
//		(f0*s.query(x,0)*gettrans(u)*s.query(y,1)*getend(y)).print("tot");
		ret[id]=(f0*s.query(x,0)*gettrans(u)*s.query(y,1)*getend(y))[0][0];
	}
}
int main(){
//	#ifdef LOCAL
//	 	freopen("input.in","r",stdin);
//	#endif
	scanf("%d%d%d",&n,&m,&sshwy);
	f0[0][sshwy-1]=0,a[0]=1e18;
	for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
	memset(b,0x3f,sizeof b);
	for(int i=1,u,v;i<n;i++) scanf("%d%d",&u,&v),g.link(u,v),b[u]=min(b[u],a[v]),b[v]=min(b[v],a[u]);
	for(int i=1,u,v;i<=m;i++) scanf("%d%d",&u,&v),que.link(u,v,i);
	dfs(1);
	for(int i=1;i<=m;i++) printf("%lld\n",ret[i]);
	return 0;
}
posted @ 2022-11-11 21:09  caijianhong  阅读(77)  评论(0编辑  收藏  举报