【福建集训】果树
题意
有一棵 \(n\) 个点的树,每个点有一种 \([1,n]\) 内的整数颜色 \(c_i\),求有多少条不含重复颜色的路径(即路径上不能有两点颜色一样。当然,一个点也是一条合法路径)。
\(n\le 10^5\)
题解
我 只 会 点 分 治
考试时先睡了一小时觉,然后再看这题,想了半天发现是个 sb 题,又感觉其他两题都做不来,于是自信码农起来。
结果 zjr 不到 1h 就写完了这题 200 多行代码还 A 了,我 1h 写完后调了 3h,心态崩了……
后来才发现这题标解根本不是点分治,大概是扫描线+线段树……写得比我还简单……我菜得没救了,,,
顺便吐槽一下这种题暴力分好多啊,以后谁还自信当码农
扫描线+线段树
对于颜色的限制 等价于 \(O(Tn)\)(本题中 \(T\le 20\))组形如 “\(u,v\) 这两点不能同时出现在路径上” 的限制。
把一条路径 \((u,v)\) 对应到 \(n\times n\) 的二维平面上的点 \((dfn_u,dfn_v)\),则一条限制等价于二维平面上的 \(1\) 或 \(2\) 个障碍矩形。
最后就是求二维平面上有多少个点不被任何一个障碍矩形覆盖。扫描线+线段树做即可,复杂度 \(O(Tn\log n)\)。
我又忘了这种基础操作了,我在想什么啊,zbl
我的傻逼点分治
太恶心了,代码都是我糊出来的,不想细讲……
大概就是开两棵线段树,一棵记录以某个点为根的子树中有多少个点不能作为另一端点,另一棵记录某个点是否能作为另一端点。
实际上也可能不要线段树?我说了代码是糊的,我自己都不知道咋蒙过的
复杂度 \(O(Tn\log^2 n)\),理论上被标解吊打,实际上可能跑得比标解快……
#include<bits/stdc++.h>
#define ll long long
#define N 200005
#define inf 2147483647
using namespace std;
inline int read(){
int x=0; bool f=1; char c=getchar();
for(;!isdigit(c); c=getchar()) if(c=='-') f=0;
for(; isdigit(c); c=getchar()) x=(x<<3)+(x<<1)+(c^'0');
if(f) return x;
return 0-x;
}
int n,c[N]; ll ans;
struct edge{int v,nxt;}e[N<<1];
int hd[N],cnt;
inline void add(int u, int v){e[++cnt]=(edge){v,hd[u]}, hd[u]=cnt;}
int Siz,rt,rt_mxson,siz[N];
void getRoot(int u, int fa){
siz[u]=1; int mxson=0;
for(int i=hd[u]; i; i=e[i].nxt) if(~e[i].v && e[i].v!=fa)
getRoot(e[i].v,u), siz[u]+=siz[e[i].v], mxson=max(mxson,siz[e[i].v]);
mxson=max(mxson,Siz-siz[u]);
if(mxson<rt_mxson) rt_mxson=mxson, rt=u;
}
struct SegTree{
#define ls o<<1
#define rs o<<1|1
int sum[N<<2],tag[N<<2]; bool clr[N<<2];
inline void pushup(int o){sum[o]=sum[ls]+sum[rs];}
inline void pushdown(int o, int l, int r, int mid){
if(clr[o]){
sum[ls]=sum[rs]=tag[ls]=tag[rs]=0,
clr[ls]=clr[rs]=1,
clr[o]=0;
}
if(tag[o]){
sum[ls]+=(mid-l+1)*tag[o],
sum[rs]+=(r-mid)*tag[o],
tag[ls]+=tag[o],
tag[rs]+=tag[o],
tag[o]=0;
}
}
void mdf(int o, int l, int r, int L, int R, int v){
if(L<=l && r<=R){sum[o]+=(r-l+1)*v, tag[o]+=v; return;}
int mid=l+r>>1;
pushdown(o,l,r,mid);
if(L<=mid) mdf(ls,l,mid,L,R,v);
if(R>mid) mdf(rs,mid+1,r,L,R,v);
pushup(o);
}
int query(int o, int l, int r, int L, int R){
if(L<=l && r<=R) return sum[o];
int mid=l+r>>1, ret=0;
pushdown(o,l,r,mid);
if(L<=mid) ret+=query(ls,l,mid,L,R);
if(R>mid) ret+=query(rs,mid+1,r,L,R);
return ret;
}
#undef ls
#undef rs
}sgt1,sgt2;
int suf_sum;
int dfn[N],idx,siz2[N]; vector<int> p[N];
void dfs1(int u, int fa){
dfn[u]=++idx, siz2[u]=1;
for(int i=hd[u]; i; i=e[i].nxt) if(~e[i].v && e[i].v!=fa) dfs1(e[i].v,u), siz2[u]+=siz2[e[i].v];
}
bool flag[N][21]; int val[N][21],vis[N];
int V;
void dfs2(int u, int fa){
int x;
for(int i=0; i<p[c[u]].size(); ++i) if(!sgt2.query(1,1,n,dfn[p[c[u]][i]],dfn[p[c[u]][i]])){
//cout<<u<<' '<<p[c[u]][i]<<endl;
x = min(siz2[p[c[u]][i]],suf_sum+1) - sgt1.query(1,1,n,dfn[p[c[u]][i]],dfn[p[c[u]][i]]+siz2[p[c[u]][i]]-1);
//cout<<"dfs2:"<<u<<' '<<p[c[u]][i]<<' '<<x<<' '<<dfn[p[c[u]][i]]<<' '<<siz2[p[c[u]][i]]<<' '<<sgt1.query(1,1,n,dfn[p[c[u]][i]],dfn[p[c[u]][i]]+siz2[p[c[u]][i]]-1)<<endl;
sgt1.mdf(1,1,n,dfn[p[c[u]][i]],dfn[p[c[u]][i]],x);
sgt2.mdf(1,1,n,dfn[p[c[u]][i]],dfn[p[c[u]][i]]+siz2[p[c[u]][i]]-1,1); //cout<<i<<' '<<p[c[u]][i]<<endl;
flag[u][i]=1, val[u][i]=x;
}
ans-=sgt1.query(1,1,n,1,Siz)-sgt1.query(1,1,n,dfn[V],dfn[V]+siz[V]-1);
//cout<<u<<' '<<sgt1.query(1,1,n,1,Siz)<<' '<<sgt1.query(1,1,n,dfn[V],dfn[V]+siz[V]-1)<<endl;
//cout<<"dfs2:"<<u<<' '<<suf_sum+1<<' '<<sgt1.query(1,1,n,1,Siz)<<' '<<sgt1.query(1,1,n,dfn[V],dfn[V]+siz[V]-1)<<' '<<sgt1.query(1,1,n,1,1)<<endl;
vis[c[u]]=1, ans+=suf_sum+1;
for(int i=hd[u]; i; i=e[i].nxt) if(~e[i].v && e[i].v!=fa && !vis[c[e[i].v]]) dfs2(e[i].v,u);
vis[c[u]]=0;
for(int i=0; i<p[c[u]].size(); ++i) if(flag[u][i]){
sgt1.mdf(1,1,n,dfn[p[c[u]][i]],dfn[p[c[u]][i]],-val[u][i]);
sgt2.mdf(1,1,n,dfn[p[c[u]][i]],dfn[p[c[u]][i]]+siz2[p[c[u]][i]]-1,-1);
flag[u][i]=0;
}
}
void dfs3(int u, int fa){
p[c[u]].push_back(u);
if(vis[c[u]]==1 && !sgt2.query(1,1,n,dfn[u],dfn[u])) sgt1.mdf(1,1,n,dfn[u],dfn[u],siz2[u]), sgt2.mdf(1,1,n,dfn[u],dfn[u]+siz2[u]-1,1);
++vis[c[u]];
for(int i=hd[u]; i; i=e[i].nxt) if(~e[i].v && e[i].v!=fa) dfs3(e[i].v,u);
--vis[c[u]];
}
void dfs4(int u, int fa){
p[c[u]].pop_back();
for(int i=hd[u]; i; i=e[i].nxt) if(~e[i].v && e[i].v!=fa) dfs4(e[i].v,u);
}
int getAns(int u){
idx=0, dfs1(u,0); ++ans, suf_sum=0;
p[c[u]].push_back(u), vis[c[u]]=1;
for(int i=hd[u]; i; i=e[i].nxt) if(~e[i].v){
V=e[i].v;
//cout<<V<<endl;
if(!vis[c[e[i].v]]) dfs2(e[i].v,u);
dfs3(e[i].v,u);
//cout<<suf_sum<<' '<<siz2[e[i].v]<<endl;
suf_sum+=siz2[e[i].v];
//cout<<"sgt1:"<<sgt1.query(1,1,n,1,Siz)<<endl;
}
//cout<<sum<<endl;
vis[c[u]]=0, dfs4(u,0);
//cout<<u<<' '<<sum<<' '<<sub_ans<<endl;
sgt1.sum[1]=sgt1.tag[1]=0, sgt1.clr[1]=1;
sgt2.sum[1]=sgt2.tag[1]=0, sgt2.clr[1]=1;
}
void solve(int u, int s){
if(s==1){++ans; return;}
Siz=s, rt=0, rt_mxson=inf, getRoot(u,u), u=rt;
//printf("solve:%lld %lld\n",u,s);
getAns(u);
for(int i=hd[u]; i; i=e[i].nxt) if(~e[i].v){
int v=e[i].v;
e[i].v=e[i&1?i+1:i-1].v=-1;
solve(v, siz[v]<siz[u]?siz[v]:s-siz[u]);
}
}
signed main(){
//freopen("b.in","r",stdin);
//freopen("b2.out","w",stdout);
n=read();
for(int i=1; i<=n; ++i) c[i]=read();
int u,v;
for(int i=1; i<n; ++i) u=read(), v=read(), add(u,v), add(v,u);
solve(1,n);
cout<<ans<<endl;
return 0;
}
天天就会点分治,分到亲马不认 srO