洛谷 5291 [十二省联考2019]希望(52分)——思路+树形DP

题目:https://www.luogu.org/problemnew/show/P5291

考场上写了 16 分的。不过只得了 4 分。

对于一个救援范围,其中合法的点集也是一个连通块。 2n 枚举一个救援范围,然后换根 DP 一下范围内的每个点开始的最长链,那些最长链 <=L 的点就是该范围的合法点集。

这样得到每个合法点集出现的方案, 与卷积 k 次即可。卷积的时候先 FWT 成点值,然后快速幂一样乘 k 次,再 FWT 回来即可。

但只有 4 分。过不了大样例。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
int Mx(int a,int b){return a>b?a:b;}
int Mn(int a,int b){return a<b?a:b;}
const int N=1e6+5,mod=998244353;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}

int n,L,k,hd[N],xnt,to[N<<1],nxt[N<<1];
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
namespace S1{
  const int K=20,M=(1<<16)+5;
  int bin[K],dp[K],pr[K],sc[K],nd[K],tot;
  int ts,len,f[M],g[M]; bool vis[K],col[K];
  void chk_dfs(int cr,int fa)
  {
    vis[cr]=1;
    for(int i=hd[cr],v;i;i=nxt[i])
      if(col[v=to[i]]&&v!=fa)chk_dfs(v,cr);
  }
  void dfs(int cr,int fa)
  {
    dp[cr]=0;
    for(int i=hd[cr],v;i;i=nxt[i])
      if(col[v=to[i]]&&v!=fa)
    dfs(v,cr), dp[cr]=Mx(dp[cr],dp[v]+1);
  }
  void dfsx(int cr,int fa,int tmp)
  {
    if(Mx(dp[cr],tmp)<=L)ts|=bin[cr-1];
    int l=tot;
    for(int i=hd[cr],v;i;i=nxt[i])
      if(col[v=to[i]]&&v!=fa) nd[++tot]=v;
    int r=tot; if(l==r)return;
    pr[l+1]=dp[nd[l+1]]+1;
    for(int i=l+2;i<=r;i++)pr[i]=Mx(pr[i-1],dp[nd[i]]+1);
    sc[r]=dp[nd[r]]+1;
    for(int i=r-1;i>l;i--)sc[i]=Mn(sc[i+1],dp[nd[i]]+1);
    for(int i=l+1;i<=r;i++)
      {
    int tp=tmp;//=tmp
    if(i>l+1)tp=pr[i-1];if(i<r)tp=Mx(tp,sc[i+1]);
    dfsx(nd[i],cr,tp+1);
      }
  }
  void fwt(int *a,bool fx)
  {
    for(int R=2;R<=len;R<<=1)
      for(int i=0,m=R>>1;i<len;i+=R)
    for(int j=0;j<m;j++)
      {
        if(!fx)a[i+j]=upt(a[i+j]+a[i+m+j]);
        else a[i+j]=upt(a[i+j]-a[i+m+j]);
      }
  }
  void solve()
  {
    bin[0]=1;
    for(int i=1;i<=n;i++)bin[i]=bin[i-1]<<1;
    for(int s=1;s<bin[n];s++)
      {
    for(int i=1;i<=n;i++)
      {
        vis[i]=0;
        if(s&bin[i-1])col[i]=1; else col[i]=0;
      }
    int cr=0;
    for(int i=1;i<=n;i++)
      if(col[i]){chk_dfs(i,0);cr=i;break;}
    bool fg=0;
    for(int i=1;i<=n;i++)
      if(col[i]&&!vis[i]){fg=1;break;}
    if(fg)continue;
    ts=tot=0; dfs(cr,0); dfsx(cr,0,0);
    if(ts){f[ts]++; g[ts]++;}
      }
    k--; len=bin[n]; fwt(g,0); fwt(f,0);
    while(k)
      {
    if(k&1)
      {
        for(int i=0;i<len;i++)f[i]=(ll)f[i]*g[i]%mod;
      }
    for(int i=0;i<len;i++)g[i]=(ll)g[i]*g[i]%mod;
    k>>=1;
      }
    int ans=0; fwt(f,1);
    for(int s=1;s<bin[n];s++)ans=upt(ans+f[s]);
    printf("%d\n",ans);
  }
}
int main()
{
  freopen("hope.in","r",stdin);
  freopen("hope.out","w",stdout);
  n=rdn();L=rdn();k=rdn();
  for(int i=1,u,v;i<n;i++)
    u=rdn(),v=rdn(),add(u,v),add(v,u);
  if(n<=16){S1::solve();return 0;}
  return 0;
}

后来发现两个地方写错了:

1.换根的时候做了前缀 max 和后缀 max ,其中后缀取 max 写成取 min 了;

2.往孩子换根的时候用了一个 tp 对父亲来的 tmp 、前缀 max 、后缀 max 取 max ,结果 tp=tmp 之后写成 tp = pr[ ] 而非 tp = Mx( tp , pr[ ] ) 。

改了这两个地方就有 16 分了。

希望以后写代码的时候更仔细。别走神或不集中之类的。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
int Mx(int a,int b){return a>b?a:b;}
int Mn(int a,int b){return a<b?a:b;}
const int N=1e6+5,mod=998244353;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}

int n,L,k,hd[N],xnt,to[N<<1],nxt[N<<1];
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
namespace S1{
  const int K=20,M=(1<<16)+5;
  int bin[K],dp[K],pr[K],sc[K],nd[K],tot;
  int ts,len,f[M],g[M]; bool vis[K],col[K];
  void chk_dfs(int cr,int fa)
  {
    vis[cr]=1;
    for(int i=hd[cr],v;i;i=nxt[i])
      if(col[v=to[i]]&&v!=fa)chk_dfs(v,cr);
  }
  void dfs(int cr,int fa)
  {
    dp[cr]=0;
    for(int i=hd[cr],v;i;i=nxt[i])
      if(col[v=to[i]]&&v!=fa)
    dfs(v,cr), dp[cr]=Mx(dp[cr],dp[v]+1);
  }
  void dfsx(int cr,int fa,int tmp)
  {
    if(Mx(dp[cr],tmp)<=L)ts|=bin[cr-1];
    int l=tot;
    for(int i=hd[cr],v;i;i=nxt[i])
      if(col[v=to[i]]&&v!=fa) nd[++tot]=v;
    int r=tot; if(l==r)return;
    pr[l+1]=dp[nd[l+1]]+1;
    for(int i=l+2;i<=r;i++)pr[i]=Mx(pr[i-1],dp[nd[i]]+1);
    sc[r]=dp[nd[r]]+1;
    for(int i=r-1;i>l;i--)sc[i]=Mx(sc[i+1],dp[nd[i]]+1);////mx not mn!!!
    for(int i=l+1;i<=r;i++)
      {
    int tp=tmp;//=tmp
    if(i>l+1)tp=Mx(tp,pr[i-1]);if(i<r)tp=Mx(tp,sc[i+1]);//mx!!!
    dfsx(nd[i],cr,tp+1);
      }
  }
  void fwt(int *a,bool fx)
  {
    for(int R=2;R<=len;R<<=1)
      for(int i=0,m=R>>1;i<len;i+=R)
    for(int j=0;j<m;j++)
      {
        if(!fx)a[i+j]=upt(a[i+j]+a[i+m+j]);
        else a[i+j]=upt(a[i+j]-a[i+m+j]);
      }
  }
  void solve()
  {
    bin[0]=1;
    for(int i=1;i<=n;i++)bin[i]=bin[i-1]<<1;
    for(int s=1;s<bin[n];s++)
      {
    for(int i=1;i<=n;i++)
      {
        vis[i]=0;
        if(s&bin[i-1])col[i]=1; else col[i]=0;
      }
    int cr=0;
    for(int i=1;i<=n;i++)
      if(col[i]){chk_dfs(i,0);cr=i;break;}
    bool fg=0;
    for(int i=1;i<=n;i++)
      if(col[i]&&!vis[i]){fg=1;break;}
    if(fg)continue;
    ts=tot=0; dfs(cr,0); dfsx(cr,0,0);
    if(ts){f[ts]++; g[ts]++;}
      }
    k--; len=bin[n]; fwt(g,0); fwt(f,0);
    while(k)
      {
    if(k&1)
      {
        for(int i=0;i<len;i++)f[i]=(ll)f[i]*g[i]%mod;
      }
    for(int i=0;i<len;i++)g[i]=(ll)g[i]*g[i]%mod;
    k>>=1;
      }
    int ans=0; fwt(f,1);
    for(int s=1;s<bin[n];s++)ans=upt(ans+f[s]);
    printf("%d\n",ans);
  }
}
int main()
{
  freopen("hope.in","r",stdin);
  freopen("hope.out","w",stdout);
  n=rdn();L=rdn();k=rdn();
  for(int i=1,u,v;i<n;i++)
    u=rdn(),v=rdn(),add(u,v),add(v,u);
  if(n<=16){S1::solve();return 0;}
  return 0;
}
View Code

然后参照题解写了 52 分的。

很重要的转化是令 \( f[i] \) 表示 i 是合法点的救援范围个数,那么 k 个救援范围包含 i 的方案就是 \( f[i]^k \) ;考虑到一个方案的合法点集是连通块,即点数比边数大一,所以令 \( g[i] \) 表示边 i 的两端点是合法点的救援范围个数,答案就是 \( \sum\limits_{i=1}^{n}f[i]^k - \sum\limits_{i=1}^{n-1}g[i]^k \) 。

然后就可以写 n*L 的 DP 了。再把链和 L=n 的部分做一下就有 52 分。

不太会 k=1 时候的长链剖分。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
int Mx(int a,int b){return a>b?a:b;}
int Mn(int a,int b){return a<b?a:b;}
const int N=1e6+5,mod=998244353;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}

int n,L,k,hd[N],xnt=1,to[N<<1],nxt[N<<1],rd[N],f[N],g[N];
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;rd[y]++;}
namespace S1{
  const int N=1005;
  int dfs(int cr,int fa,int lm)
  {
    int ret=1; if(!lm)return ret;
    for(int i=hd[cr],v;i;i=nxt[i])
      if((v=to[i])!=fa)
    ret=(ll)ret*(dfs(v,cr,lm-1)+1)%mod;
    return ret;
  }
  void dfsx(int cr,int fa)
  {
    for(int i=hd[cr],v;i;i=nxt[i])
      if((v=to[i])!=fa)
    {
      int ret=dfs(cr,v,L-1);
      ret=(ll)ret*dfs(v,cr,L-1)%mod;
      g[i>>1]=ret; dfsx(v,cr);
    }
  }
  void solve()
  {
    for(int i=1;i<=n;i++) f[i]=dfs(i,0,L);
    dfsx(1,0); int ans=0;
    for(int i=1;i<=n;i++)ans=upt(ans+pw(f[i],k));
    for(int i=1;i<n;i++)ans=upt(ans-pw(g[i],k));
    printf("%d\n",ans);
  }
}
namespace S2{
  const int N=1e5+5,M=105;
  int nd[N],tot;
  struct Node{
    int v[M],s[M],cd;
    void init(){v[0]=s[0]=1;}
    void frs()
    {
      for(int i=1;i<=cd;i++)
    s[i]=upt(s[i-1]+v[i]);
    }
    void cz()
    {
      cd=Mn(cd+1,L);
      for(int i=cd;i;i--)v[i]=v[i-1];
      frs();
    }
  }dp[N],pr[N],sc[N],up[N];
  void mrg(Node &d0,Node d1)
  {
    int yc=d0.cd, lm=d1.cd, tc=Mn(L,Mx(yc,lm+1));
    d0.cd=tc;
    for(int j=yc+1;j<=tc;j++)
      d0.v[j]=0, d0.s[j]=d0.s[yc];//0 not 1
    for(int j=1;j<=tc;j++)
      {
    int tp;
    if(j-1<=lm)tp=d1.s[j-1]; else tp=d1.s[lm];
    tp++;///for choosen't
    d0.v[j]=(ll)d0.v[j]*tp%mod;
    if(j-1<=lm)
      d0.v[j]=(d0.v[j]+(ll)d0.s[j-1]*d1.v[j-1])%mod;
      }
    d0.frs();
  }
  void mg2(Node &d0,Node d1)
  {
    int yc=d0.cd, lm=d1.cd, tc=Mn(L,Mx(yc,lm));
    d0.cd=tc;
    for(int j=yc+1;j<=tc;j++)
      d0.v[j]=0, d0.s[j]=d0.s[yc];//0 not 1
    for(int j=0;j<=tc;j++)
      {
    int tp;
    if(j<=lm)tp=d1.s[j]; else tp=d1.s[lm];
    tp++;///for choosen't
    d0.v[j]=(ll)d0.v[j]*tp%mod;
    if(j&&j<=lm)
      d0.v[j]=(d0.v[j]+(ll)d0.s[j-1]*d1.v[j])%mod;
      }
    d0.frs();
  }
  void dfs(int cr,int fa)
  {
    dp[cr].init();
    for(int i=hd[cr],v;i;i=nxt[i])
      if((v=to[i])!=fa)
    {
      dfs(v,cr);
      mrg(dp[cr],dp[v]);
    }
  }
  void dfsx(int cr,int fa)
  {
    int tp=up[cr].cd;
    f[cr]=(tp>=L?up[cr].s[L]:up[cr].s[tp]);
    tp=dp[cr].cd;
    f[cr]=(ll)f[cr]*(tp>=L?dp[cr].s[L]:dp[cr].s[tp])%mod;
    int l=tot;
    for(int i=hd[cr],v;i;i=nxt[i])
      if((v=to[i])!=fa)
    {
      nd[++tot]=i;
      if(tot==l+1)pr[tot].init();
      else pr[tot]=pr[tot-1];
      mrg(pr[tot],dp[v]);
    }
    int r=tot;
    for(int i=r;i>l;i--)
      {
    if(i==r)sc[i].init();
    else sc[i]=sc[i+1];
    mrg(sc[i],dp[to[nd[i]]]);
      }
    for(int i=l+1;i<=r;i++)
      {
    pr[i].v[0]=pr[i].s[0]=0;pr[i].frs();
    sc[i].v[0]=sc[i].s[0]=0;sc[i].frs();
      }
    for(int i=l+1;i<=r;i++)
      {
    int v=to[nd[i]],bh=nd[i]>>1;
    up[v]=up[cr];
    if(i>l+1) mg2(up[v],pr[i-1]);
    if(i<r) mg2(up[v],sc[i+1]);
    int tp=up[v].cd;
    g[bh]=(tp>=L-1?up[v].s[L-1]:up[v].s[tp]);
    tp=dp[v].cd;
    g[bh]=(ll)g[bh]*(tp>=L-1?dp[v].s[L-1]:dp[v].s[tp])%mod;
    up[v].cz();
    dfsx(v,cr);
      }
  }
  void solve()
  {
    dfs(1,0); up[1].init(); dfsx(1,0);
    int ans=0;
    for(int i=1;i<=n;i++)
      ans=upt(ans+pw(f[i],k));
    for(int i=1;i<n;i++)
      ans=upt(ans-pw(g[i],k));
    printf("%d\n",ans);
  }
}
namespace S3{
  const int N=2e5+5;
  int dp[N],nd[N],pr[N],sc[N],tot;
  void dfs(int cr,int fa)
  {
    dp[cr]=1;
    for(int i=hd[cr],v;i;i=nxt[i])
      if((v=to[i])!=fa)
    {
      dfs(v,cr); dp[cr]=(ll)dp[cr]*(dp[v]+1)%mod;
    }
  }
  void dfsx(int cr,int fa,int tmp)
  {
    f[cr]=(ll)dp[cr]*(tmp+1)%mod;
    int l=tot;
    for(int i=hd[cr],v;i;i=nxt[i])
      if((v=to[i])!=fa)
    {
      nd[++tot]=i;
      if(tot==l+1)pr[tot]=1;
      else pr[tot]=pr[tot-1];
      pr[tot]=(ll)pr[tot]*(dp[v]+1)%mod;
    }
    int r=tot;
    for(int i=r;i>l;i--)
      {
    if(i==r)sc[i]=1;
    else sc[i]=sc[i+1];
    sc[i]=(ll)sc[i]*(dp[to[nd[i]]]+1)%mod;
      }
    for(int i=l+1;i<=r;i++)
      {
    int v=to[nd[i]], tp=tmp+1, bh=nd[i]>>1;
    if(i>l+1)tp=(ll)tp*pr[i-1]%mod;
    if(i<r)tp=(ll)tp*sc[i+1]%mod;
    g[bh]=(ll)tp*dp[v]%mod;
    dfsx(v,cr,tp);
      }
  }
  void solve()
  {
    dfs(1,0); dfsx(1,0,0); int ans=0;
    for(int i=1;i<=n;i++)
      ans=upt(ans+pw(f[i],k));
    for(int i=1;i<n;i++)
      ans=upt(ans-pw(g[i],k));
    printf("%d\n",ans);
  }
}
namespace S4{
  void solve()
  {
    int ans=0;
    for(int i=1;i<=n;i++)
      {
    int t0=Mn(L+1,i), t1=Mn(L+1,n-i+1);
    ans=upt(ans+pw((ll)t0*t1%mod,k));
      }
    for(int i=1;i<n;i++)
      {
    int t0=Mn(L,i), t1=Mn(L,n-i);
    ans=upt(ans-pw((ll)t0*t1%mod,k));
      }
    printf("%d\n",ans);
  }
}
int main()
{
  n=rdn();L=rdn();k=rdn();
  for(int i=1,u,v;i<n;i++)
    { u=rdn();v=rdn();add(u,v);add(v,u);}
  if(n<=1000){S1::solve();return 0;}
  if((ll)n*L<=1e7){S2::solve();return 0;}
  if(L==n){S3::solve();return 0;}
  bool fg=0;
  for(int i=1;i<=n;i++)if(rd[i]>2){fg=1;break;}
  if(!fg){S4::solve();return 0;}
  return 0;
}
View Code

 

posted on 2019-04-10 21:56  Narh  阅读(242)  评论(0编辑  收藏  举报

导航