牛客网字节跳动冬令营网络赛——点分治(简化条件)
题目:https://ac.nowcoder.com/acm/contest/296/J
可以点分治,每次处理经过重心的路径。
合法的形态有这几种:,其中 [ ] 的第一个表示小于号的个数,第二个表示大于号的个数。“2”表示有多个。如果左边是 1 、右边是 k 的话,3的合法条件是 w[1]<=w[k] , 4的合法条件是 w[1]>=w[k] 。
弄一个 f [0/1/2][0/1/2][N] 的桶,存当前重心的其他孩子里各种情况的个数; dfs 当前孩子的时候对于“重心到当前节点的路径”在桶里找一些东西匹配上更新答案,然后再 dfs 一遍当前孩子来更新桶;继续分治之前把所有孩子都 dfs 一遍清空桶(就是正常的点分治流程)。所以用树状数组实现桶。
然后开始各种转移。调了一晚上+一下午还是没调出来。严格递增与非严格递增好麻烦呀。那个第5种情况感觉有好多变种,比如 --_ 再配上一个 / 或者 -- 之类的。
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; const int N=1e5+5; int T,n,hd[N],xnt,to[N<<1],nxt[N<<1],w[N],lm; int mn,rt,siz[N],f[3][3][N]; ll ans; bool vis[N]; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } int Mx(int a,int b){return a>b?a:b;} int Mn(int a,int b){return a<b?a:b;} void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} void init() { xnt=0;memset(hd,0,sizeof hd);ans=0;lm=0; memset(vis,0,sizeof vis); } void init_dfs(int cr,int fa) { siz[cr]=1; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa)init_dfs(v,cr),siz[cr]+=siz[v]; } void getrt(int cr,int fa,int s) { int mx=0,sm=0; for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa) { getrt(v,cr,s); mx=Mx(mx,siz[v]);sm+=siz[v]; } mx=Mx(mx,s-sm); if(mx<mn)mn=mx,rt=cr; } void add(int x,int k,int f[]){for(;x<=lm;x+=(x&-x))f[x]+=k;}//lm!!! int qry(int x,int f[]){if(!x)return 0;int ret=0;for(;x;x-=(x&-x))ret+=f[x];return ret;} int qry_s(int x,int f[]){return qry(lm,f)-qry(x-1,f);}//lm!!!not n void solve(int s0,int s1,int tw) { ans++;//with rt if(!s0&&!s1)// { for(int i=0;i<=2;i++)for(int j=0;j<=2;j++)ans+=qry(lm,f[i][j]); return; } ans+=qry(lm,f[0][0]);// if(!s0&&s1) { ans+=qry(lm,f[0][1]); ans+=qry(lm,f[0][2]);//1,1 ans+=qry_s(tw,f[1][2]); ans+=qry_s(tw,f[1][1]); ans+=qry_s(tw+(s1==1),f[1][0]);//1,3 //not only qry_s(tw,f[1][2]);!!!//not creat 5///if for !s0&&s1==1 ans+=qry_s(tw,f[1][1]);//1,5[2] } if(s0&&!s1) { ans+=qry(lm,f[1][0]); ans+=qry(lm,f[2][0]);//2,2 ans+=qry(tw,f[2][1]); ans+=qry(tw,f[1][1]);if(s0>1)ans+=qry(tw-(s0==1),f[0][1]);//2,4// ans+=qry(tw,f[1][1]);//2,5[1] } if(s0==1&&s1>1) { ans+=qry_s(tw,f[0][1]); ans+=qry_s(tw,f[0][2]);//3,1 } if(s0>1&&s1==1) { ans+=qry(tw,f[1][0]); ans+=qry(tw,f[2][0]);//4,2 } if(s0==1&&s1==1)//no w[rt]==tw is ok { if(w[rt]<tw){ ans+=qry(tw,f[1][0]); ans+=qry(tw,f[2][0]); }//5[1],2//back so w[rt]<tw else if(w[rt]>tw){ ans+=qry_s(tw,f[0][1]); ans+=qry_s(tw,f[0][2]); }//5[2],1 } if(!s0&&s1==1)///////////// creat 5 and others! { ans+=qry(tw,f[1][0]); ans+=qry(tw,f[2][0]);// } if(s0==1&&!s1) { ans+=qry_s(tw,f[0][1]); ans+=qry_s(tw,f[0][2]);// } } void dfs(int cr,int fa,int lst,int s0,int s1,int op) { if(op==1) { if(w[cr]>lst)s1++; else if(w[cr]<lst)s0++; }//cr_lst//back else { if(w[cr]>lst)s0++; else if(w[cr]<lst)s1++; }//lst_cr//go if(s0>1&&s1>1)return; if(s0==1&&s1>1){ if(op>1&&w[cr]<w[rt])return; if(op==1&&w[cr]>w[rt])return; } if(s0>1&&s1==1){ if(op>1&&w[cr]>w[rt])return; if(op==1&&w[cr]<w[rt])return; } if(op==1)solve(s0,s1,w[cr]),printf("cr=%d[%d,%d] ans=%lld\n",cr,s0,s1,ans); if(op==2)add(w[cr],1,f[s0>1?2:s0][s1>1?2:s1]);///can't >2 if(op==3)add(w[cr],-1,f[s0>1?2:s0][s1>1?2:s1]); for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa)dfs(v,cr,w[cr],s0,s1,op); } void solve(int cr,int s) { vis[cr]=1; printf("cr=%d s=%d\n",cr,s); for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]) { dfs(v,cr,w[cr],0,0,1);dfs(v,cr,w[cr],0,0,2); } printf(" ans=%lld\n",ans); for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]) { dfs(v,cr,w[cr],0,0,3); } for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]) { int ts;if(siz[v]<siz[cr])ts=siz[v];else ts=s-siz[cr]; mn=N;getrt(v,cr,ts);solve(rt,ts); } } int main() { T=rdn(); while(T--) { init();n=rdn();for(int i=1;i<=n;i++)w[i]=rdn(),lm=Mx(lm,w[i]); for(int i=1,u,v;i<n;i++)u=rdn(),v=rdn(),add(u,v),add(v,u); init_dfs(1,0);mn=N;getrt(1,0,n);solve(rt,n); printf("%lld\n",ans+n); } return 0; }
然后想起讲课学长说的。之所以记录小于号个数之类的,其实是为了不根据形态转移,而是根据小于号和大于号的个数来转移。
记 s0 表示小于号个数, s1 表示大于号个数。对于重心到当前节点的路径的 s0 和 s1 ,枚举 i 和 j 表示在哪个桶里,然后只要看看 s0' = s0+i 和 s1' = s1+j 属于上面情况中的哪一种,如果都不属于(当且仅当 s0' > 1 && s1' > 1)就跳过;如果是第3种或第4种就查 w[cr] 为止的前缀/后缀和加到答案里;不然就把整个桶的值加到答案里。
真是简明的思想!应该注意到路径形态对于转移的影响可以归约为不同的大于号、小于号个数对于转移的影响,而大于号、小于号的个数比路径形态好维护得多!
然后发现自己一直写的错误的点分治(当然当初学习的时候写的是正确的)。改一改就好啦。
注意每组数据开头的 init( ) 里别写 memset ,不然有可能 n2 ; 把点值离散化一下可以让树状数组复杂度更正确。
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; const int N=1e5+5; int T,n,w[N],tp[N],hd[N],xnt,to[N<<1],nxt[N<<1],siz[N],rt,mn,lm; int f[3][3][N]; ll ans; bool vis[N]; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } int Mx(int a,int b){return a>b?a:b;} int Mn(int a,int b){return a<b?a:b;} void init() { xnt=0;for(int i=1;i<=n;i++)hd[i]=0;//memset(hd,0,sizeof hd); ans=0;for(int i=1;i<=n;i++)vis[i]=0;//memset(vis,0,sizeof vis); sort(tp+1,tp+n+1);lm=unique(tp+1,tp+n+1)-tp-1;/// for(int i=1;i<=n;i++)w[i]=lower_bound(tp+1,tp+lm+1,w[i])-tp; } void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} /* void init_dfs(int cr,int fa) { siz[cr]=1; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa)init_dfs(v,cr),siz[cr]+=siz[v]; } */ void getrt(int cr,int fa,int s) { int mx=0;siz[cr]=1; for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa) { getrt(v,cr,s);siz[cr]+=siz[v]; mx=Mx(mx,siz[v]); } mx=Mx(mx,s-siz[cr]);if(mx<mn)mn=mx,rt=cr; } void add(int x,int k,int s0,int s1){for(;x<=lm;x+=(x&-x))f[s0][s1][x]+=k;} int qry(int x,int s0,int s1){if(!x)return 0;int ret=0;for(;x;x-=(x&-x))ret+=f[s0][s1][x];return ret;} int qry_s(int x,int s0,int s1){return qry(lm,s0,s1)-qry(x-1,s0,s1);} int cal(int s0,int s1,int i,int j,int tw) { if(s0>1&&s1>1)return 0;// if(s0==1&&s1>1)return qry_s(tw,i,j); if(s0>1&&s1==1)return qry(tw,i,j); return qry(lm,i,j); } void calc(int tw,int s0,int s1) { ans++;//with rt for(int i=0;i<=2;i++)for(int j=0;j<=2;j++)ans+=cal(s0+i,s1+j,i,j,tw); } void dfs(int cr,int fa,int lst,int s0,int s1,int op) { if(op==1){ if(w[cr]>lst)s1++; if(w[cr]<lst)s0++; } else{ if(lst>w[cr])s1++; if(lst<w[cr])s0++; } if(s0>1&&s1>1)return; if(s0==1&&s1>1){ if(op==1&&w[cr]>w[rt])return; if(op>1&&w[cr]<w[rt])return; } if(s1==1&&s0>1){ if(op==1&&w[cr]<w[rt])return; if(op>1&&w[cr]>w[rt])return; } if(op==1)calc(w[cr],s0,s1);//,printf("cr=%d[%d,%d]ans=%lld\n",cr,s0,s1,ans); if(op==2)add(w[cr],1,s0>1?2:s0,s1>1?2:s1); if(op==3)add(w[cr],-1,s0>1?2:s0,s1>1?2:s1); for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa)dfs(v,cr,w[cr],s0,s1,op); } void solve(int cr,int s) { vis[cr]=1;// printf("cr=%d s=%d\n",cr,s); for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]) { dfs(v,cr,w[cr],0,0,1);dfs(v,cr,w[cr],0,0,2); } // printf("ans=%lld\n",ans); for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]])dfs(v,cr,w[cr],0,0,3); for(int i=hd[cr],v,ts;i;i=nxt[i]) if(!vis[v=to[i]]) { mn=N;ts=(siz[v]<siz[cr]?siz[v]:s-siz[cr]); getrt(v,cr,ts);solve(rt,ts); } } int main() { T=rdn(); while(T--) { n=rdn();lm=0;for(int i=1;i<=n;i++)w[i]=rdn(),tp[i]=w[i]; init(); for(int i=1,u,v;i<n;i++) u=rdn(),v=rdn(),add(u,v),add(v,u); /*init_dfs(1,0);*/mn=N;getrt(1,0,n);solve(rt,n); printf("%lld\n",ans+n); } return 0; }