牛客网字节跳动冬令营网络赛——点分治(简化条件)

题目:https://ac.nowcoder.com/acm/contest/296/J

可以点分治,每次处理经过重心的路径。

合法的形态有这几种:,其中 [ ] 的第一个表示小于号的个数,第二个表示大于号的个数。“2”表示有多个。如果左边是 1 、右边是 k 的话,3的合法条件是 w[1]<=w[k] , 4的合法条件是 w[1]>=w[k] 。

弄一个  f [0/1/2][0/1/2][N] 的桶,存当前重心的其他孩子里各种情况的个数; dfs 当前孩子的时候对于“重心到当前节点的路径”在桶里找一些东西匹配上更新答案,然后再 dfs 一遍当前孩子来更新桶;继续分治之前把所有孩子都 dfs 一遍清空桶(就是正常的点分治流程)。所以用树状数组实现桶。

然后开始各种转移。调了一晚上+一下午还是没调出来。严格递增与非严格递增好麻烦呀。那个第5种情况感觉有好多变种,比如 --_ 再配上一个 / 或者 -- 之类的。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=1e5+5;
int T,n,hd[N],xnt,to[N<<1],nxt[N<<1],w[N],lm;
int mn,rt,siz[N],f[3][3][N]; ll ans;
bool vis[N];
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
int Mx(int a,int b){return a>b?a:b;}
int Mn(int a,int b){return a<b?a:b;}
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
void init()
{
  xnt=0;memset(hd,0,sizeof hd);ans=0;lm=0;
  memset(vis,0,sizeof vis);
}
void init_dfs(int cr,int fa)
{
  siz[cr]=1;
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa)init_dfs(v,cr),siz[cr]+=siz[v];
}
void getrt(int cr,int fa,int s)
{
  int mx=0,sm=0;
  for(int i=hd[cr],v;i;i=nxt[i])
    if(!vis[v=to[i]]&&v!=fa)
      {
    getrt(v,cr,s);
    mx=Mx(mx,siz[v]);sm+=siz[v];
      }
  mx=Mx(mx,s-sm);
  if(mx<mn)mn=mx,rt=cr;
}
void add(int x,int k,int f[]){for(;x<=lm;x+=(x&-x))f[x]+=k;}//lm!!!
int qry(int x,int f[]){if(!x)return 0;int ret=0;for(;x;x-=(x&-x))ret+=f[x];return ret;}
int qry_s(int x,int f[]){return qry(lm,f)-qry(x-1,f);}//lm!!!not n
void solve(int s0,int s1,int tw)
{
  ans++;//with rt
  if(!s0&&!s1)//
    {
      for(int i=0;i<=2;i++)for(int j=0;j<=2;j++)ans+=qry(lm,f[i][j]);
      return;
    }
  ans+=qry(lm,f[0][0]);//
  if(!s0&&s1)
    {
      ans+=qry(lm,f[0][1]); ans+=qry(lm,f[0][2]);//1,1
      ans+=qry_s(tw,f[1][2]); ans+=qry_s(tw,f[1][1]); ans+=qry_s(tw+(s1==1),f[1][0]);//1,3
      //not only qry_s(tw,f[1][2]);!!!//not creat 5///if for !s0&&s1==1
      ans+=qry_s(tw,f[1][1]);//1,5[2]
    }
  if(s0&&!s1)
    {
      ans+=qry(lm,f[1][0]); ans+=qry(lm,f[2][0]);//2,2
      ans+=qry(tw,f[2][1]); ans+=qry(tw,f[1][1]);if(s0>1)ans+=qry(tw-(s0==1),f[0][1]);//2,4//
      ans+=qry(tw,f[1][1]);//2,5[1]
    }
  if(s0==1&&s1>1)
    {
      ans+=qry_s(tw,f[0][1]); ans+=qry_s(tw,f[0][2]);//3,1
    }
  if(s0>1&&s1==1)
    {
      ans+=qry(tw,f[1][0]); ans+=qry(tw,f[2][0]);//4,2
    }
  if(s0==1&&s1==1)//no w[rt]==tw is ok
    {
      if(w[rt]<tw){ ans+=qry(tw,f[1][0]); ans+=qry(tw,f[2][0]); }//5[1],2//back so w[rt]<tw
      else if(w[rt]>tw){ ans+=qry_s(tw,f[0][1]); ans+=qry_s(tw,f[0][2]); }//5[2],1
    }
  if(!s0&&s1==1)///////////// creat 5 and others!
    {
      ans+=qry(tw,f[1][0]); ans+=qry(tw,f[2][0]);//
    }
  if(s0==1&&!s1)
    {
      ans+=qry_s(tw,f[0][1]); ans+=qry_s(tw,f[0][2]);//
    }
}
void dfs(int cr,int fa,int lst,int s0,int s1,int op)
{
  if(op==1) { if(w[cr]>lst)s1++; else if(w[cr]<lst)s0++; }//cr_lst//back
  else { if(w[cr]>lst)s0++; else if(w[cr]<lst)s1++; }//lst_cr//go
  if(s0>1&&s1>1)return;
  if(s0==1&&s1>1){ if(op>1&&w[cr]<w[rt])return; if(op==1&&w[cr]>w[rt])return; }
  if(s0>1&&s1==1){ if(op>1&&w[cr]>w[rt])return; if(op==1&&w[cr]<w[rt])return; }
  if(op==1)solve(s0,s1,w[cr]),printf("cr=%d[%d,%d] ans=%lld\n",cr,s0,s1,ans);
  if(op==2)add(w[cr],1,f[s0>1?2:s0][s1>1?2:s1]);///can't >2
  if(op==3)add(w[cr],-1,f[s0>1?2:s0][s1>1?2:s1]);
  for(int i=hd[cr],v;i;i=nxt[i])
    if(!vis[v=to[i]]&&v!=fa)dfs(v,cr,w[cr],s0,s1,op);
}
void solve(int cr,int s)
{
  vis[cr]=1; printf("cr=%d s=%d\n",cr,s);
  for(int i=hd[cr],v;i;i=nxt[i])
    if(!vis[v=to[i]])
      {
    dfs(v,cr,w[cr],0,0,1);dfs(v,cr,w[cr],0,0,2);
      }
  printf(" ans=%lld\n",ans);
  for(int i=hd[cr],v;i;i=nxt[i])
    if(!vis[v=to[i]])
      {
    dfs(v,cr,w[cr],0,0,3);
      }
  for(int i=hd[cr],v;i;i=nxt[i])
    if(!vis[v=to[i]])
      {
    int ts;if(siz[v]<siz[cr])ts=siz[v];else ts=s-siz[cr];
    mn=N;getrt(v,cr,ts);solve(rt,ts);
      }
}
int main()
{
  T=rdn();
  while(T--)
    {
      init();n=rdn();for(int i=1;i<=n;i++)w[i]=rdn(),lm=Mx(lm,w[i]);
      for(int i=1,u,v;i<n;i++)u=rdn(),v=rdn(),add(u,v),add(v,u);
      init_dfs(1,0);mn=N;getrt(1,0,n);solve(rt,n);
      printf("%lld\n",ans+n);
    }
  return 0;
}
有时间再调调?

然后想起讲课学长说的。之所以记录小于号个数之类的,其实是为了不根据形态转移,而是根据小于号和大于号的个数来转移。

记 s0 表示小于号个数, s1 表示大于号个数。对于重心到当前节点的路径的 s0 和 s1 ,枚举 i 和 j 表示在哪个桶里,然后只要看看 s0' = s0+i 和 s1' = s1+j 属于上面情况中的哪一种,如果都不属于(当且仅当 s0' > 1 && s1' > 1)就跳过;如果是第3种或第4种就查 w[cr] 为止的前缀/后缀和加到答案里;不然就把整个桶的值加到答案里。

真是简明的思想!应该注意到路径形态对于转移的影响可以归约为不同的大于号、小于号个数对于转移的影响,而大于号、小于号的个数比路径形态好维护得多!

然后发现自己一直写的错误的点分治(当然当初学习的时候写的是正确的)。改一改就好啦。

注意每组数据开头的 init( ) 里别写 memset ,不然有可能 n2 ; 把点值离散化一下可以让树状数组复杂度更正确。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=1e5+5;
int T,n,w[N],tp[N],hd[N],xnt,to[N<<1],nxt[N<<1],siz[N],rt,mn,lm;
int f[3][3][N]; ll ans; bool vis[N];
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
int Mx(int a,int b){return a>b?a:b;}
int Mn(int a,int b){return a<b?a:b;}
void init()
{
  xnt=0;for(int i=1;i<=n;i++)hd[i]=0;//memset(hd,0,sizeof hd);
  ans=0;for(int i=1;i<=n;i++)vis[i]=0;//memset(vis,0,sizeof vis);
  sort(tp+1,tp+n+1);lm=unique(tp+1,tp+n+1)-tp-1;///
  for(int i=1;i<=n;i++)w[i]=lower_bound(tp+1,tp+lm+1,w[i])-tp;
}
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
/*
void init_dfs(int cr,int fa)
{
  siz[cr]=1;
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa)init_dfs(v,cr),siz[cr]+=siz[v];
}
*/
void getrt(int cr,int fa,int s)
{
  int mx=0;siz[cr]=1;
  for(int i=hd[cr],v;i;i=nxt[i])
    if(!vis[v=to[i]]&&v!=fa)
      {
    getrt(v,cr,s);siz[cr]+=siz[v];
    mx=Mx(mx,siz[v]);
      }
  mx=Mx(mx,s-siz[cr]);if(mx<mn)mn=mx,rt=cr;
}
void add(int x,int k,int s0,int s1){for(;x<=lm;x+=(x&-x))f[s0][s1][x]+=k;}
int qry(int x,int s0,int s1){if(!x)return 0;int ret=0;for(;x;x-=(x&-x))ret+=f[s0][s1][x];return ret;}
int qry_s(int x,int s0,int s1){return qry(lm,s0,s1)-qry(x-1,s0,s1);}
int cal(int s0,int s1,int i,int j,int tw)
{
  if(s0>1&&s1>1)return 0;//
  if(s0==1&&s1>1)return qry_s(tw,i,j);
  if(s0>1&&s1==1)return qry(tw,i,j);
  return qry(lm,i,j);
}
void calc(int tw,int s0,int s1)
{
  ans++;//with rt
  for(int i=0;i<=2;i++)for(int j=0;j<=2;j++)ans+=cal(s0+i,s1+j,i,j,tw);
}
void dfs(int cr,int fa,int lst,int s0,int s1,int op)
{
  if(op==1){ if(w[cr]>lst)s1++; if(w[cr]<lst)s0++; }
  else{ if(lst>w[cr])s1++; if(lst<w[cr])s0++; }
  if(s0>1&&s1>1)return;
  if(s0==1&&s1>1){ if(op==1&&w[cr]>w[rt])return; if(op>1&&w[cr]<w[rt])return; }
  if(s1==1&&s0>1){ if(op==1&&w[cr]<w[rt])return; if(op>1&&w[cr]>w[rt])return; }
  if(op==1)calc(w[cr],s0,s1);//,printf("cr=%d[%d,%d]ans=%lld\n",cr,s0,s1,ans);
  if(op==2)add(w[cr],1,s0>1?2:s0,s1>1?2:s1);
  if(op==3)add(w[cr],-1,s0>1?2:s0,s1>1?2:s1);
  for(int i=hd[cr],v;i;i=nxt[i])
    if(!vis[v=to[i]]&&v!=fa)dfs(v,cr,w[cr],s0,s1,op);
}
void solve(int cr,int s)
{
  vis[cr]=1;// printf("cr=%d s=%d\n",cr,s);
  for(int i=hd[cr],v;i;i=nxt[i])
    if(!vis[v=to[i]])
      {
    dfs(v,cr,w[cr],0,0,1);dfs(v,cr,w[cr],0,0,2);
      }
  //  printf("ans=%lld\n",ans);
  for(int i=hd[cr],v;i;i=nxt[i])
    if(!vis[v=to[i]])dfs(v,cr,w[cr],0,0,3);
  for(int i=hd[cr],v,ts;i;i=nxt[i])
    if(!vis[v=to[i]])
      {
    mn=N;ts=(siz[v]<siz[cr]?siz[v]:s-siz[cr]);
    getrt(v,cr,ts);solve(rt,ts);
      }
}
int main()
{
  T=rdn();
  while(T--)
    {
      n=rdn();lm=0;for(int i=1;i<=n;i++)w[i]=rdn(),tp[i]=w[i];
      init();
      for(int i=1,u,v;i<n;i++)
    u=rdn(),v=rdn(),add(u,v),add(v,u);
      /*init_dfs(1,0);*/mn=N;getrt(1,0,n);solve(rt,n);
      printf("%lld\n",ans+n);
    }
  return 0;
}

 

posted on 2018-12-26 21:07  Narh  阅读(333)  评论(0编辑  收藏  举报

导航