【bzoj3589】动态树

Portal --> bzoj3589

Description

  给你一棵\(n\)个节点的树,总共有\(q\)次操作,每次操作是以下两种中的一种:

操作\((0,x,delta)\):给以\(x\)为根的子树中每个节点的点权\(+delta\)

操作\((1,k,(x_1,y_1),(x_2,y_2)...(x_k,y_k))\):求\(k\)条链的并的权值之和,一个节点集合的权值之和定义为该集合中所有节点的点权之和

  点权初始为\(0\),链保证是某个节点到根的路径上的某一段

​  数据范围:\(1<=n,q<=2*10^5,1<=k<=5\)

Solution

  这题的话。。虽然说网上好像有根本不需要用到容斥和链的特性直接一个线段树就可以搞定的做法。。但是我不会qwq所以还是容斥吧qwq

​  因为这题中的链有十分优秀的性质,都是到根路径上的某一段,所以我们考虑维护每一个节点到根的路径上的点权和,这样一旦求得并问题就很好解决了

​  子树加操作的话直接用树剖+线段树或者树状数组来搞就好了

​  线段树直接搞就不讲了,如果用树状数组的话就是开两个支持区间修改单点查询的树状数组

​  我们考虑对于一次修改操作\((x,delta)\)\(y\)\(x\)子树中的一个节点,考虑这次修改对\(y\)这个位置的值的影响,应该是加上\((dep[y]-(dep[x]-1))*delta\)

​  那么我们可以考虑前面的\(dep[y]*delta\)\(-(dep[x]-1)*delta\)分开维护

​  我们在一个BIT(记为BIT1)中的\(x\)子树范围内的每个节点\(+(dep[x]-1)*delta\),然后在另一个BIT(记为BIT2)中的\(x\)子树范围内每个节点\(+delta\)

​  最后查询就直接BIT2::query(x)*dep[x]-BIT1::query(x),就能够得到\(x\)到根路径上所有节点的权值和了(当然你也可以第二个BIT修改的时候是\(+dep[x]*delta\),查询的时候就不用在外面乘\(dep[x]\),一样的)

​​  
​  接下来就是求并了
  
​  首先注意到每次询问的这个链的数量十分少,并且因为这个链必定是某个节点到根的路径上的一段,所以求两条这样的链的交集其实是很容易的(求个lca然后判断一下什么的就好了)

​  但是求并是一个。。很困难的过程。。

  所以我们考虑用容斥将求并转化为求交,具体一点的话就是集合容斥的这条式子:

\[|\bigcup\limits_{i=1}^{n}A_i|=\sum\limits_{k=1}^{n}(-1)^{k-1}\sum\limits_{1<=j_1<j_2<...<j_k<=n}|A_{j_1}\cap A_{j_2}\cap ...\cap A_{j_k}| \]

  只不过我们这里用来容斥的东西不是集合的大小而是集合中元素的和:

\[sum(\bigcup\limits_{i=1}^{n}A_i)=\sum\limits_{k=1}^{n}(-1)^{k-1}\sum\limits_{1<=j_1<j_2<...<j_k<=n}sum(A_{j_1}\cap A_{j_2}\cap ...\cap A_{j_k}) \]

  然后链的数量又特别小所以直接\(2^{num}\)枚举一下就好了(其中\(num\)表示的是链的数量)

  (这里提醒一下自己。。枚举的时候可以用状压的方式来操作,这样比dfs不知道高到哪里去了qwq)

  最后还有一点就是,对\(2^{31}\)取模可以int自然溢出最后再和maxlongint取个与就好了

  

​  代码大概长这个样子

#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
using namespace std;
const int N=2*(1e5)+10,TOP=20;
struct xxx{
	int y,nxt;
}a[N*2];
struct L{
	int x,y;
	L(){}
	L(int _x,int _y){x=_x; y=_y;}
}rec[6],tmp;
int h[N],lis[N*2],st[N],dep[N],pre[N];
int mn[N*2][TOP+1];
int dfn[N],dfned[N];
int n,m,tot,dfn_t,dfn_t1;
namespace BIT{/*{{{*/
	int c[N*2],rt[2];
	int n;
	void init(int _n){n=_n;rt[0]=0; rt[1]=n;}
	void _add(int St,int x,int delta){
		for (;x<=n;x+=x&-x) c[St+x]+=delta;
	}
	void add(int which,int x,int delta){_add(rt[which],dfn[x],delta);_add(rt[which],dfned[x]+1,-delta);}
	int _query(int St,int x){
		int ret=0;
		for (;x;x-=x&-x) ret+=c[St+x];
		return ret;
	}
	int query(int x){
		if (!x) return 0;
		return _query(rt[1],dfn[x])*dep[x]-_query(rt[0],dfn[x]);
	}
}/*}}}*/
void add(int x,int y){a[++tot].y=y; a[tot].nxt=h[x]; h[x]=tot;}
void dfs(int fa,int x,int d){
	int u;
	st[x]=++dfn_t; lis[dfn_t]=x; dep[x]=d; pre[x]=fa;
	dfn[x]=++dfn_t1;
	for (int i=h[x];i!=-1;i=a[i].nxt){
		u=a[i].y;
		if (u==fa) continue;
		dfs(x,u,d+1);
		lis[++dfn_t]=x;
	}
	dfned[x]=dfn_t1;
}
void prework(){
	BIT::init(n);
	lis[0]=dfn_t;
	for (int i=1;i<=lis[0];++i) mn[i][0]=lis[i];
	for (int j=1;j<=TOP;++j)
		for (int i=lis[0]-(1<<j)+1;i>=1;--i)
			if (dep[mn[i][j-1]]<dep[mn[i+(1<<j-1)][j-1]])
				mn[i][j]=mn[i][j-1];
			else 
				mn[i][j]=mn[i+(1<<j-1)][j-1];
}
int get_lca(int x,int y){
	x=st[x]; y=st[y];
	if (x>y) swap(x,y);
	int len=y-x+1,lg=(int)(log(1.0*len)/log(2.0));
	if (dep[mn[x][lg]]<dep[mn[y-(1<<lg)+1][lg]]) return mn[x][lg];
	else return mn[y-(1<<lg)+1][lg];
}
void merge(L &x,L y){
	if (x.x==0&&x.y==0) return;
	int lca;
	if (dep[x.x]<dep[y.x]){
		lca=get_lca(x.y,y.x);
		if (lca!=y.x){x=L(0,0); return;}
		lca=get_lca(x.y,y.y);
		x=L(y.x,lca);
	}
	else{
		lca=get_lca(y.y,x.x);
		if (lca!=x.x){x=L(0,0); return;}
		lca=get_lca(y.y,x.y);
		x=L(x.x,lca);
	}
}
int get_val(L x){
	return BIT::query(x.y)-BIT::query(pre[x.x]);
}
void solve(int num){
	int all=1<<num,mark,ans=0;
	for (int i=1;i<all;++i){
		mark=-1;
		tmp=L(-1,-1);
		for (int j=1;j<=num;++j)
			if (i>>(j-1)&1){
				mark*=-1;
				if (tmp.x==-1&&tmp.y==-1) 
					tmp.x=rec[j].x,tmp.y=rec[j].y;
				else
					merge(tmp,rec[j]);
			}
		ans+=mark*get_val(tmp);
	}
	printf("%d\n",ans&2147483647);
}
void debug(){
	for (int i=1;i<=n;++i) printf("%d ",BIT::query(i));
	printf("\n");
}

int main(){
#ifndef ONLINE_JUDGE
	freopen("a.in","r",stdin);
#endif
	int x,y,delta,num,op;
	scanf("%d",&n);
	memset(h,-1,sizeof(h));
	tot=0;
	for (int i=1;i<n;++i){
		scanf("%d%d",&x,&y);
		add(x,y); add(y,x);
	}
	dfn_t=0; dfn_t1=0;
	dfs(0,1,1);
	prework();
	scanf("%d",&m);
	for (int i=1;i<=m;++i){
		scanf("%d",&op);
		if (op==0){
			scanf("%d%d",&x,&delta);
			BIT::add(0,x,(dep[x]-1)*delta);
			BIT::add(1,x,delta);
		}
		else{
			scanf("%d",&num);
			for (int j=1;j<=num;++j){
				scanf("%d%d",&rec[j].x,&rec[j].y);
				if (dep[rec[j].x]>dep[rec[j].y]) swap(rec[j].x,rec[j].y);
			}
			solve(num);
		}
		//debug();
	}
}
posted @ 2018-07-30 10:24  yoyoball  阅读(332)  评论(0编辑  收藏  举报