树上莫队 SPOJ COT2
SOL:我们来讲一下树上莫队。
前置技能:莫队,括号序列
我们如果要维护子树的信息的话,只要把树展开成DFS序就好了。
那么如果是路径呢?
1
/ | \
2 3 4
|
5
我们在进入和退出一个节点时都把这个点扔到序列里,那么我们发现上面这个数的序列为:
1 2 2 3 5 5 3 4 4 1
我们记一个点的第一次出现的位置为st[i],第二次为ed[i]
那么比如我们要找1 到 5 的路径, 我们取 st[1] 到 st[5] 这一段区间,我们发现在 1-5上的点出现了一次,而不在这条路径的点出现了偶数次。
-> 一条 从 a 连向其子孙 b 的路径,我们可以取 st[a] 到 st[b] 的区间。(st[a]<=st[b])
若我们要从2 到 5 呢?
我们取ed[2]到st[5],我们发现在 2-5上的点出现了一次,而不在这条路径的点出现了偶数次。但是LCA 1却没有出现,那么我们在统计答案的时候把这玩意加上就好了。
-> 一般的, 一条 由 a 连向 非子孙 非祖宗点 b的路径,我们可以取 ed[a] 到 st[b] 的区间再加上lca 对答案的贡献。(st[a]<=st[b])
代码还是很simple的。
//#pragma GCC optimize("-O2") #include<bits/stdc++.h> #define eho(x) for (int i=head[x];i;i=net[i]) #define v fall[i] #define SIZ 19 #define N 300007 #define sight(c) ('0'<=c&&c<='9') using namespace std; int head[N],fall[N<<1],net[N<<1],tot,q,a[N],b[N],bol[N],anw,col; int pe[N],dep[N],f[N][SIZ],st[N],ed[N],n,m,x,y,l,r,L,R,la,ans[N]; inline void read(int &x){ static char c;static int b; for (b=1,c=getchar();!sight(c);c=getchar())if (c=='-') b=-1; for (x=0;sight(c);c=getchar())x=x*10+c-48; x*=b; } void write(int x){if (x<10) {putchar('0'+x); return;} write(x/10); putchar('0'+x%10);} inline void writeln(int x){ if (x<0) putchar('-'),x*=-1; write(x); putchar('\n'); } inline void writel(int x){ if (x<0) putchar('-'),x*=-1; write(x); putchar(' '); } inline void add(int x,int y){ fall[++tot]=y; net[tot]=head[x]; head[x]=tot; } bool usd[N]; inline void adds(int x) {q=a[x]; if (!bol[q]) ++anw; ++bol[q];} inline void dels(int x) {q=a[x]; --bol[q]; if (!bol[q]) --anw;} inline void Add(int x) {usd[x]?dels(x):adds(x);usd[x]^=1;} struct Node{ int x,y,bol,id,add; inline bool operator <(const Node& X) const { return bol==X.bol?y<X.y:bol<X.bol; } }qu[N]; int Tot; void dfs(int x,int fa) { pe[++Tot]=x; st[x]=Tot; f[x][0]=fa; dep[x]=dep[fa]+1; for (int i=1;i<SIZ;i++) f[x][i]=f[f[x][i-1]][i-1]; eho(x) if(v^fa) dfs(v,x); pe[++Tot]=x; ed[x]=Tot; } int get_lca(int x,int y){ if (dep[x]<dep[y]) swap(x,y); for (int i=SIZ-1;~i;--i) if (dep[f[x][i]]>=dep[y]) x=f[x][i]; if (x==y) return x; for (int i=SIZ-1;~i;--i) if (f[x][i]^f[y][i]) x=f[x][i],y=f[y][i]; return f[x][0]; } signed main () { freopen("a.in","r",stdin); freopen("a.out","w",stdout); read(n); read(m); for (int i=1;i<=n;i++) read(a[i]),b[i]=a[i]; sort(b+1,b+n+1); for (int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+n+1,a[i])-b; for (int i=1;i< n;i++) { read(x); read(y); add(x,y); add(y,x); } dfs(1,0); col=1.5*sqrt(2*n)+1; for (int i=1;i<=m;i++) { read(l); read(r); la=get_lca(l,r); if (st[r]<st[l]) swap(l,r); if (l==la) qu[i].x=st[l],qu[i].y=st[r],qu[i].id=i; else qu[i].x=ed[l],qu[i].y=st[r],qu[i].id=i,qu[i].add=la; qu[i].bol=qu[i].x/col; } sort(qu+1,qu+m+1); L=1; for (int i=1;i<=m;i++) { while (R<qu[i].y) Add(pe[++R]); while (L>qu[i].x) Add(pe[--L]); while (L<qu[i].x) Add(pe[L++]); while (R>qu[i].y) Add(pe[R--]); if (qu[i].add) Add(qu[i].add); ans[qu[i].id]=anw; if (qu[i].add) Add(qu[i].add); } for (int i=1;i<=m;i++) writeln(ans[i]); return 0; }