洛谷 P5669 [SDOI2018] 原题识别-改 题解--zhengjun
鉴于这题目前还没题解,提供一种时间 \(\Theta(n\sqrt{m})\),空间 \(\Theta(n+m)\) 的做法。
询问 1
可以直接上树分块或者树上莫队,见 P6177 Count on a tree II/【模板】树分块。
但是因为本题询问 2 的做法,所以我采用了树上莫队的做法。
询问 2
方便起见:
- 记 \(\operatorname{path}(u,v)\) 表示 \(u,v\) 路径之间的所有点构成的集合;
- 记 \(f(u,v)\) 表示 \(u,v\) 路径上本质不同的颜色数。
这里直接考虑 \(u,v\) 不为祖先关系的情况(\(u,v\) 为祖先关系的情况显然严格弱于这个,特判一下即可)。
所以答案即为:
因为我们发现,答案的形式非常像对于一个区间的所有子区间求和,那么我们引入新的函数:
首先观察这个 \(F\),将 \(\operatorname{path}(u,v)\) 理解为一个区间 \([1,m]\)。
它的实际意义就是 \([1,m]\) 的所有子区间的本质不同的颜色数之和。
但是这样并不好计算,我们考虑另外一种实际意义:对于每种颜色,计算 \([1,m]\) 的所有子区间中包含该颜色的个数和。
如果把颜色 \(c\) 删去,序列剩下来长度为 \(l_1,l_2,\cdots ,l_{k_c}\) 的 \(k_c\) 段连续区间,那么该颜色的贡献就是 \(\binom{n+1}{2}-\sum \binom{l_i+1}{2}\)。
那么,如果在序列的后面加入一个元素,那么答案的增量就是 \(\sum\limits_{c}n'-l'_{c,k'_c}\)。
所以,我们如果我们维护出了 \(suf_c=l_{c,k_c}\) 以及它的和,那么我们就可以 \(\Theta(1)\) 向右边扩展了。
注意,我们同时可以 \(\Theta(1)\) 删除最后一个位置。
使用链表维护相同颜色的位置,并实时记录每个颜色的起始位置,维护出 \(\sum suf_c\) 和 \(\sum pre_c\),这样左右端点都能够 \(\Theta(1)\) 左右移动了。
现在,我们就可以使用树上莫队来计算 \(F(u,v)\) 了。
接下来考虑怎么计算答案。
设询问的两个节点分别为 \(u,v\)。
设 \(t\) 为 \(u,v\) 的最近公共祖先 \((t\ne u , t\ne v)\)。
设 \(p,q\) 为 \(t\) 的两个不同的儿子且 \(p\in \operatorname{path}(u,t),q\in \operatorname{path}(v,t)\)。
考虑对答案进行转化,这里直接给出结果:
其中后面一大坨的尾巴是 $ \operatorname{path}(u,p)$ 和 $ \operatorname{path}(v,q)$ 之间的贡献,即:
最后加一是因为 \(f(t,t)\) 被两边都减了一遍,类似于容斥。
而剩下的贡献就是 \(F(1,u)+F(1,v)-|\operatorname{path}(1,t)|-F(u,p)-F(v,q)\)。
- 减去 \(|\operatorname{path}(1,t)|\) 的就是 \(\forall i\in \operatorname{path}(1,t),f(i,i)\) 都算了两遍;
- 而其余的 \(\forall i,j\in \operatorname{path}(1,t) \land i\ne j,f(i,j)\) 本身就应该算两遍。
做到这里似乎已经做完了……
细节处理:
- 维护一条链的情况时,需要使用循环队列,因为左右端点都可能移动很远,但是任意时刻序列长度都不超过 \(n\)。
本人直接写完后不卡常最大点用时 4.05s,经过调整块长、对莫队的排序进行奇偶优化过后,最大点用时 1.96s,效率还行,毕竟询问 2 有个 \(5\) 倍常数。
代码
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
#ifdef DEBUG
template<class T>
ostream& operator << (ostream &out,vector<T> a){
out<<'[';
for(T x:a)out<<x<<',';
return out<<']';
}
template<class T>
vector<T> ary(T *a,int l,int r){
return vector<T>{a+l,a+1+r};
}
template<class T>
void debug(T x){
cerr<<x<<endl;
}
template<class T,class...S>
void debug(T x,S...y){
cerr<<x<<' ',debug(y...);
}
#else
#define debug(...) void()
#endif
const int N=1e5+10,V=N*2,M=2e5+10;
int n,q,a[N];
int dft,B,id[V],pos[V],dfn[N];
vector<int>to[N];
struct ques{
int l,r,id,w;
bool operator < (const ques &a)const{
return ::id[l]^::id[a.l]?::id[l]<::id[a.l]:(::id[l]&1?r<a.r:r>a.r);
}
}o1[M],o2[M*5];
int m1,m2;
void make(int u,int fa=0){
pos[dfn[u]=++dft]=u;
for(int v:to[u])if(v^fa){
make(v,u);
pos[++dft]=u;
}
}
namespace Path{
int top[N],fa[N],dep[N],siz[N],son[N];
void dfs1(int u){
siz[u]=1,dep[u]=dep[fa[u]]+1;
for(int v:to[u])if(v^fa[u]){
fa[v]=u,dfs1(v);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
int dft,dfn[N],pos[N];
void dfs2(int u,int t){
top[u]=t,pos[dfn[u]=++dft]=u;
if(son[u])dfs2(son[u],t);
for(int v:to[u])if(v^fa[u]&&v^son[u])dfs2(v,v);
}
void init(){
dfs1(1),dfs2(1,1);
}
int LCA(int u,int v){
for(;top[u]^top[v];u=fa[top[u]]){
if(dep[top[u]]<dep[top[v]])swap(u,v);
}
return dep[u]<dep[v]?u:v;
}
int jump(int u,int k){
for(;k>dep[u]-dep[top[u]];u=fa[top[u]])k-=dep[u]-dep[top[u]]+1;
return pos[dfn[u]-k];
}
}
using Path::dep;
namespace DS1{
int now,cnt[N];
void insert(int x){
now+=!cnt[x]++;
}
void erase(int x){
now-=!--cnt[x];
}
int query(){
return now;
}
}
namespace DS2{
struct Queue{
int a[N];
const int& operator [] (const int &x)const{
return a[(x%N+N)%N];
}
int& operator [] (const int &x){
return a[(x%N+N)%N];
}
}col,pre,nex;
int s,t;
int now,cnt[N],bg[N],ed[N];
ll s1,s2,ans;
void init(){
s=1e9,t=s-1,s1=s2=now=ans=0;
memset(bg,0,sizeof bg);
memset(ed,0,sizeof ed);
memset(cnt,0,sizeof cnt);
}
void push_back(int x){
// debug("push_back",x);
col[++t]=x,now+=!cnt[x]++;
s1+=n-now;
s2+=n-(ed[x]?t-ed[x]:t-s+1);
ans+=(t-s+1ll)*n-s2;
pre[t]=ed[x],nex[t]=0;
if(ed[x])nex[ed[x]]=t;
ed[x]=t;
if(!bg[x])bg[x]=t;
}
void pop_back(){
// debug("pop_back");
int x=col[t];
ed[x]=pre[t];
if(!ed[x])bg[x]=0;
else nex[ed[x]]=0;
ans-=(t-s+1ll)*n-s2;
s2-=n-(ed[x]?t-ed[x]:t-s+1);
s1-=n-now;
now-=!--cnt[col[t--]];
}
void push_front(int x){
// debug("push_front",x);
col[--s]=x,now+=!cnt[x]++;
s1+=n-(bg[x]?bg[x]-s:t-s+1);
s2+=n-now;
ans+=(t-s+1ll)*n-s1;
nex[s]=bg[x],pre[s]=0;
if(bg[x])pre[bg[x]]=s;
bg[x]=s;
if(!ed[x])ed[x]=s;
}
void pop_front(){
// debug("pop_front");
int x=col[s];
bg[x]=nex[s];
if(!bg[x])ed[x]=0;
else pre[bg[x]]=0;
ans-=(t-s+1ll)*n-s1;
s2-=n-now;
s1-=n-(bg[x]?bg[x]-s:t-s+1);
now-=!--cnt[col[s++]];
}
ll query(){
return ans;
}
}
ll f[N],ans[M];
void dfs(int u,int fa=0){
DS2::push_back(a[u]);
f[u]=DS2::query();
for(int v:to[u])if(v^fa){
dfs(v,u);
}
DS2::pop_back();
}
int vis[N];
void solve1(){
for(int i=1;i<=m1;i++){
if(o1[i].l>o1[i].r)swap(o1[i].l,o1[i].r);
}
B=max(1.0,dft/max(1.0,sqrt(m1))*3);
for(int i=1;i<=dft;i++)id[i]=(i-1)/B+1;
sort(o1+1,o1+1+m1);
auto go=[&](int u,int v){
if(!vis[v])DS1::insert(a[v]),vis[v]=1;
else DS1::erase(a[u]),vis[u]=0;
};
int l=1,r=0;
for(int i=1;i<=m1;i++){
for(;r<o1[i].r;r++)go(pos[r],pos[r+1]);
for(;r>o1[i].r;r--)go(pos[r],pos[r-1]);
for(;l<o1[i].l;l++)go(pos[l],pos[l+1]);
for(;l>o1[i].l;l--)go(pos[l],pos[l-1]);
ans[o1[i].id]+=DS1::query()*o1[i].w;
}
}
void solve2(){
memset(vis,0,sizeof vis);
for(int i=1;i<=m2;i++){
if(o2[i].l>o2[i].r)swap(o2[i].l,o2[i].r);
}
B=max(1.0,dft/max(1.0,sqrt(m2))*3);
for(int i=1;i<=dft;i++)id[i]=(i-1)/B+1;
sort(o2+1,o2+1+m2);
auto go_t=[&](int u,int v){
if(!vis[v])DS2::push_back(a[v]),vis[v]=1;
else DS2::pop_back(),vis[u]=0;
};
auto go_s=[&](int u,int v){
if(!vis[v])DS2::push_front(a[v]),vis[v]=1;
else DS2::pop_front(),vis[u]=0;
};
int l=1,r=0;
DS2::init();
for(int i=1;i<=m2;i++){
for(;r<o2[i].r;r++)go_t(pos[r],pos[r+1]);
for(;r>o2[i].r;r--)go_t(pos[r],pos[r-1]);
for(;l<o2[i].l;l++)go_s(pos[l],pos[l+1]);
for(;l>o2[i].l;l--)go_s(pos[l],pos[l-1]);
ans[o2[i].id]+=DS2::query()*o2[i].w;
}
}
int main(){
freopen(".in","r",stdin);
// freopen(".out","w",stdout);
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++)scanf("%d",&a[i]);
for(int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
to[u].push_back(v),to[v].push_back(u);
}
make(1),Path::init(),DS2::init(),dfs(1);
debug("f",ary(f,1,n));
for(int i=1,op,u,v;i<=q;i++){
scanf("%d%d%d",&op,&u,&v);
if(op==1){
o1[++m1]={dfn[u],dfn[v],i,1};
}else{
int t=Path::LCA(u,v);
ans[i]=f[u]+f[v]-dep[t];
if(u^t)o2[++m2]={dfn[u],dfn[Path::jump(u,dep[u]-dep[t]-1)],i,-1};
if(v^t)o2[++m2]={dfn[v],dfn[Path::jump(v,dep[v]-dep[t]-1)],i,-1};
if(u^t&&v^t){
ans[i]++;
o2[++m2]={dfn[u],dfn[v],i,1};
o2[++m2]={dfn[u],dfn[t],i,-1};
o2[++m2]={dfn[v],dfn[t],i,-1};
}
}
}
solve1(),solve2();
for(int i=1;i<=q;i++)printf("%lld\n",ans[i]);
return 0;
}