P9555 「CROI · R1」浣熊的阴阳鱼
思路
暴力
比赛的时候想过树链剖分,然后想不出来怎么处理区间合并,再加上树链剖分代码量比较大,我又比较懒,就随手写了个暴力拿了40pts。
思路就是暴力求得 \(u\) 到 \(v\) 的简单路径,然后暴力枚举模拟一遍。
40pts 代码
#include<bits/stdc++.h>
using namespace std;
int n,q,a[100005],e[200005],ne[200005],h[100005],idx=1,k,x,y;
int dep[100005],fa[100005][25],z[100005],cnt;
int f[2];
void dfs1(int u,int f)
{
for(int i=h[u];i;i=ne[i]) if(e[i]!=f) fa[e[i]][0]=u,dep[e[i]]=dep[u]+1,dfs1(e[i],u);
}
inline void add(int a,int b)
{
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
int LCA(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
for(int i=20;i>=0;--i) if(dep[fa[x][i]]>=dep[y]) x=fa[x][i];
if(x==y) return x;
for(int i=20;i>=0;--i) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int main()
{
scanf("%d%d",&n,&q);
for(int i=1;i<=n;++i) scanf("%d",&a[i]);
for(int i=1;i<n;++i) scanf("%d%d",&x,&y),add(x,y),add(y,x);
fa[1][0]=1,dep[1]=1,dfs1(1,0);
for(int i=1;i<=20;++i) for(int u=1;u<=n;++u) fa[u][i]=fa[fa[u][i-1]][i-1];
while(q--)
{
scanf("%d%d",&k,&x);
if(k==1) a[x]^=1;
else
{
int f[2],lca,ans,cnt,sum;
scanf("%d",&y),f[0]=f[1]=cnt=ans=0,lca=LCA(x,y);
while(x!=lca) z[++cnt]=x,x=fa[x][0];
z[++cnt]=x,cnt+=dep[y]-dep[lca],sum=cnt;
while(y!=lca) z[sum--]=y,y=fa[y][0];
sum=0;
for(int i=1;i<=cnt;++i)
{
if(f[a[z[i]]^1]) ans++,f[a[z[i]]^1]--,sum--;
else if(sum<2) f[a[z[i]]]++,sum++;
}
printf("%d\n",ans);
}
}
return 0;
}
正解:树链剖分
显然,暴力时间复杂度非常高,最后一个 subtask 会 TLE。
看了 MaxBlazeResFire 巨佬的题解后恍然大悟,是类似状压的思路,因为篮子里不可能放一阴一阳,因为遇到就会被吃掉,所以篮子里的状态只有 \(5\) 个。分别是:啥都没有,有一条阴,有两条阴,有一条阳,有两条阳,分别用数字 \(0,1,2,3,4\) 对应。
因为前面区间过了后,篮子里的鱼会对后续区间产生影响,所以我们自然而然地想到,需要用一个变量储存某个状态经过该区间后的状态,用 \(suf_i\) 表示状态 \(i\) 经过该区间后状态为 \(suf_i\),然后再随便用一个 \(sum_i\) 代表该区间能吃几条阴阳鱼,也就是答案贡献。
区间合并就很容易推出来了,假设我们用区间 \(a\) 和区间 \(b\) 合并为区间 \(c\)。
用代码表示就是,
for(long long i=0;i<=4;i++) c.sum[i]=a.sum[i]+b.sum[a.suf[i]],c.suf[i]=b.suf[a.suf[i]];
需要注意的是,因为对于同一个区间,方向不同,答案也不一样,所以需要一个对应的反转的数组。
题目要求里有两个操作,一个是单点修改,一个是区间查询。
单点修改很容易,这里就不展开了。
重点是查询,其实查询和上面的暴力思路差别不大,只不过上面是老实的挨个找,这里可以用树链剖分分成一段一段,这就是这道题用树链剖分的原因。
还不会树链剖分的,可以去这道模板题。
对于点 \(u\), \(v\)。我们设它们的 LCA 为 \(l\)。
整个路径就分成了两部分:\(u\to l\) 和 \(l\to v\)。
由于树的性质,点往上找很容易,往下找就很麻烦。
所以我们就都往上找,用一个 vector 存下每段。
需要注意的是,两边方向不一样,记得要颠转。
因为树链剖分是从上往下,所以到时候需要注意那些部分需要用反转的数组,哪些部分用没反转的数组。
AC 代码
#include<bits/stdc++.h>
using namespace std;
long long n,q,a,b,k,siz[100005],son[100005],dep[100005],top[100005],f[100005][17],c[100005];
long long dfn[100005],wc[100005],cnt;
vector<long long>v[100005];
struct node{long long sum[5],suf[5];}t[400005],ft[400005],s[2];
inline void init()//s是单位数组,用来直接赋值,这个可以自己按照定义推出来
{
s[0].sum[0]=0,s[0].sum[1]=1,s[0].sum[2]=1,s[0].sum[3]=0,s[0].sum[4]=0;
s[0].suf[0]=3,s[0].suf[1]=0,s[0].suf[2]=1,s[0].suf[3]=4,s[0].suf[4]=4;
s[1].sum[0]=0,s[1].sum[1]=0,s[1].sum[2]=0,s[1].sum[3]=1,s[1].sum[4]=1;
s[1].suf[0]=1,s[1].suf[1]=2,s[1].suf[2]=2,s[1].suf[3]=0,s[1].suf[4]=3;
}
/*树链剖分部分*/
void dfs1(long long u,long long fa)
{
siz[u]=1,dep[u]=dep[fa]+1;
for(long long j:v[u])
if(j!=fa)
{
dfs1(j,u),siz[u]+=siz[j],f[j][0]=u;
if(siz[j]>siz[son[u]]) son[u]=j;
}
}
void dfs2(long long u)
{
dfn[u]=++cnt,wc[cnt]=c[u];
if(!top[u]) top[u]=u;
if(son[u]) top[son[u]]=top[u],dfs2(son[u]);
for(long long j:v[u]) if(j!=son[u]&&j!=f[u][0]) dfs2(j);
}
/*线段树部分*/
inline node merge(node a,node b){node c;for(long long i=0;i<=4;i++) c.sum[i]=a.sum[i]+b.sum[a.suf[i]],c.suf[i]=b.suf[a.suf[i]];return c;}//区间合并
inline void pushup(long long p){t[p]=merge(t[p<<1],t[p<<1|1]),ft[p]=merge(ft[p<<1|1],ft[p<<1]);}//注意:反向的合并是右+左
void build(long long p,long long l,long long r)
{
if(l==r)
{
t[p]=ft[p]=s[wc[l]];
return;
}
long long mid=l+r>>1;
build(p<<1,l,mid),build(p<<1|1,mid+1,r);
pushup(p);
}
void update(long long p,long long l,long long r,long long x,long long k)//单点修改
{
if(l==r)
{
t[p]=ft[p]=s[k];
return;
}
long long mid=l+r>>1;
if(x<=mid) update(p<<1,l,mid,x,k);
else update(p<<1|1,mid+1,r,x,k);
pushup(p);
}
node ask(long long p,long long l,long long r,long long L,long long R)//正向区间查询
{
if(L<=l&&r<=R) return t[p];
long long mid=l+r>>1;
if(R<=mid) return ask(p<<1,l,mid,L,R);
if(L>mid) return ask(p<<1|1,mid+1,r,L,R);
return merge(ask(p<<1,l,mid,L,R),ask(p<<1|1,mid+1,r,L,R));
}
node revask(long long p,long long l,long long r,long long L,long long R)//反向区间查询
{
if(L<=l&&r<=R) return ft[p];
long long mid=l+r>>1;
if(R<=mid) return revask(p<<1,l,mid,L,R);
if(L>mid) return revask(p<<1|1,mid+1,r,L,R);
return merge(revask(p<<1|1,mid+1,r,L,R),revask(p<<1,l,mid,L,R));//区间右+左
}
inline long long LCA(long long x,long long y)//求LCA
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=f[top[x]][0];
}
return (dep[x]<dep[y])?x:y;
}
inline long long nson(long long rt,long long u)//求rt的那个儿子是u的祖宗
{
for(long long i=16;i>=0;--i) if(dep[f[u][i]]>dep[rt]) u=f[u][i];
return (u==rt)?0:u;//记得特判特殊情况
}
inline long long gask(long long u,long long v)
{
long long lca=LCA(u,v),j=nson(lca,v),len;node ans;
vector<pair<long long,long long>>l,r;//用两个vector存两边的路径
while(top[u]!=top[lca]) l.push_back(make_pair(-dfn[top[u]],-dfn[u])),u=f[top[u]][0];//左侧路径是深度高到低,所以需要反转,赋称负值标记
l.push_back(make_pair(-dfn[lca],-dfn[u]));
if(j)
{
while(top[v]!=top[j]) r.push_back(make_pair(dfn[top[v]],dfn[v])),v=f[top[v]][0];//右侧路径是深度低到高,不需要反转
r.push_back(make_pair(dfn[j],dfn[v]));
}
reverse(r.begin(),r.end());//颠转右侧路径
for(pair<long long,long long>i:r) l.push_back(i);//把左右侧路径合并
/*先算一端区间,让后面好写一点*/
if(l[0].first<0) ans=revask(1,1,n,-l[0].first,-l[0].second);//需要反转
else ans=ask(1,1,n,l[0].first,l[0].second);
len=l.size();
for(long long i=1;i<len;++i)
if(l[i].first<0) ans=merge(ans,revask(1,1,n,-l[i].first,-l[i].second));//需要反转
else ans=merge(ans,ask(1,1,n,l[i].first,l[i].second));
return ans.sum[0];//求得是以状态0(什么鱼也没有)开始的答案
}
int main()
{
init();
scanf("%lld%lld",&n,&q);
for(long long i=1;i<=n;++i) scanf("%lld",&c[i]);
for(long long i=1;i<n;++i) scanf("%lld%lld",&a,&b),v[a].push_back(b),v[b].push_back(a);
f[1][0]=1,dfs1(1,0),dfs2(1),build(1,1,n);
for(long long j=1;j<=16;++j) for(long long i=1;i<=n;i++) f[i][j]=f[f[i][j-1]][j-1];
for(long long i=1;i<=q;++i)
{
scanf("%lld%lld",&k,&a);
if(k==1) c[a]^=1,update(1,1,n,dfn[a],c[a]);
else scanf("%lld",&b),printf("%lld\n",gask(a,b));
}
return 0;
}