[2016北京集训测试赛1]奇怪的树-[树链剖分]
Description
Solution
对于操作1,不论选了哪个点为a,最后反转颜色的点集都只有两种类型(显然啦)。
暴力解法:对每个操作3,从a向上直到根节点,每到一个节点记录(它父亲的黑点数减去自己的黑点数)*父亲节点的编号。另外,还要记录a子树内的黑点。。
这种O(n2)的做法肯定会爆,考虑优化。由于这是一棵静态树,考虑树链剖分。
需要记录一个数组re[x][0/1][0/1]。第2维表示深度的奇偶,第3维表示点的颜色。例如:re[x][0][0]记录的是初始情况下以x为根的子树中深度为偶数的点有多少个为白色。
为了能够顺利剖分,需要记录一个num[0/1][0/1],它是树状数组,同样一维为深度,一维为颜色。num[0][0].tree[dfn[x]](这里的tree[dfn[x]]是单纯这个节点的值而不是该点所表示区间的值)表示的是(除了x的重儿子外其他孩子子树中深度为偶数的点为白色的个数)*x。(在计算途中,假如有操作1,则num[0][0]或者num[1][0]的定义可能会改变,即num[0/1][0]最后一维的定义可能由白色变为黑色,需要开一个数组flag[2]记录)
最后,我们还需要一个树状数组sum[0/1][0/1],两维所表示意义同上。它记录单独某个点的颜色。那知道了某个点x在dfs2中的dfn(也可以把它理解为in)和out后,就可以用sum查询x点子树内有多少个深度为奇(偶)的点颜色为白(黑)。(PS:当有操作1时sum的定义也可能改变,num和sum的定义是一起变的,所以只开一个数组flag记录就好)
Code
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> using namespace std; const long long MAXN=200100; struct dd{long long y,next1; }g[2*MAXN]; long long m; long long re[MAXN][2][2],h[MAXN],size[MAXN],tot,jsq,n,val[MAXN]; long long fa[MAXN],son[MAXN],dep[MAXN],top[MAXN]; long long out[MAXN],dfn[MAXN];bool flag[2]; struct tree { long long c[MAXN]; inline long long lowbit(long long x){ return x&-x; } inline void add(long long x,long long k) { for (;x<=n;x+=lowbit(x)) c[x]+=k; } inline long long find(long long x) {long long ans1=0; for (;x>=1;x-=lowbit(x)) ans1+=c[x];return ans1; } inline long long query(long long l,long long r) { return find(r)-find(l-1); } }sum[2][2],num[2][2]; void dfs1(long long x,long long fa1) { fa[x]=fa1; long long maxx=0,id=0; dep[x]=dep[fa1]+1; size[x]=1; re[x][dep[x]&1][val[x]]=1; for (long long i=h[x];i!=-1;i=g[i].next1) { if (g[i].y==fa[x]) continue; dfs1(g[i].y,x); size[x]+=size[g[i].y]; for (long long j=0;j<=1;j++) for (long long k=0;k<=1;k++) re[x][j][k]+=re[g[i].y][j][k]; if (size[g[i].y]>maxx) maxx=size[g[i].y],id=g[i].y; } son[x]=id; } void dfs2(long long x,long long u) { top[x]=u; jsq++;dfn[x]=jsq; sum[dep[x]&1][val[x]].add(dfn[x],1); if (son[x]!=0) dfs2(son[x],u); for (long long i=h[x];i!=-1;i=g[i].next1) { if (g[i].y==fa[x]||g[i].y==son[x]) continue; dfs2(g[i].y,g[i].y); } out[x]=jsq; for (long long i=0;i<=1;i++) for (long long j=0;j<=1;j++) num[i][j].add(dfn[x],(re[x][i][j]-re[son[x]][i][j])*x); } int main() { scanf("%lld%lld",&n,&m); memset(sum,0,sizeof(sum)); memset(num,0,sizeof(num)); memset(re,0,sizeof(re)); memset(h,-1,sizeof(h)); long long a,b,t; tot=jsq=0; for (long long i=1;i<=n;i++) scanf("%lld",&val[i]); for (long long i=1;i<n;i++) { scanf("%lld%lld",&a,&b); tot++;g[tot].y=b;g[tot].next1=h[a];h[a]=tot; tot++;g[tot].y=a;g[tot].next1=h[b]; h[b]=tot; } dfs1(1,0);dfs2(1,1); flag[0]=flag[1]=false; for (long long i=1;i<=m;i++) { scanf("%lld%lld",&t,&a); long long le;le=a; if (t==1) flag[(dep[a]&1)^1]^=1; if (t==2) { sum[dep[a]&1][val[a]].add(dfn[a],-1); while (a>0) num[dep[le]&1][val[le]].add(dfn[a],-a),a=fa[top[a]]; a=le; val[a]^=1; sum[dep[a]&1][val[a]].add(dfn[a],1); while (a>0) num[dep[le]&1][val[le]].add(dfn[a],a),a=fa[top[a]]; } if (t==3) { long long res,ans; res=sum[0][flag[0]^1].query(dfn[a],out[a])+sum[1][flag[1]^1].query(dfn[a],out[a]); ans=res*a; while (a!=0) { if (top[a]!=a) ans+=num[0][flag[0]^1].query(dfn[top[a]],dfn[a]-1)+num[1][flag[1]^1].query(dfn[top[a]],dfn[a]-1),a=top[a]; if (fa[a]==0) break; res=sum[0][flag[0]^1].query(dfn[fa[a]],out[fa[a]])+sum[1][flag[1]^1].query(dfn[fa[a]],out[fa[a]]); res-=sum[0][flag[0]^1].query(dfn[a],out[a])+sum[1][flag[1]^1].query(dfn[a],out[a]); ans+=res*fa[a]; a=fa[a]; } printf("%lld\n",ans); } } }