洛谷 5289 [十二省联考2019]皮配——分开决策的动态规划

题目:https://www.luogu.org/problemnew/show/P5289

考场上只写了 50 分的 DP 。并且没意识到只记录两个导师的人数就行了,所以记了 3 个。不过写的记搜,还是得了 50 分。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<map>
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=1005,mod=998244353;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}

int n,c,c0,c1,d0,d1,ct[N];
struct Node{
  int bh,s,p;
  bool operator< (const Node &b)const
  {return bh<b.bh;}
}a[N];
struct Dt{
  int cr,s0,s1,s2;bool fx;
  Dt(int c=0,bool f=0,int s0=0,int s1=0,int s2=0):
    cr(c),fx(f),s0(s0),s1(s1),s2(s2) {}
  bool operator< (const Dt &b)const
  {
    if(cr!=b.cr)return cr<b.cr; if(fx!=b.fx)return fx<b.fx;
    if(s0!=b.s0)return s0<b.s0; if(s1!=b.s1)return s1<b.s1;
    return s2<b.s2;
  }
};
map<Dt,int> mp;
int dfs(int cr,bool fx,int s0,int s1,int s2,int s3)
{
  if(cr>n)return 1;
  Dt tp=Dt(cr,fx,s0,s1,s2);
  if(mp.count(tp))return mp[tp];
  int ret=0, p=a[cr].p, s=a[cr].s;
  if(a[cr-1].bh==a[cr].bh)
    {
      if(!fx)
    {
      if(s2+s<=d0&&p!=1)
        ret=upt(ret+dfs(cr+1,fx,s0+s,s1,s2+s,s3));
      if(s3+s<=d1&&p!=2)
        ret=upt(ret+dfs(cr+1,fx,s0+s,s1,s2,s3+s));
    }
      else
    {
      if(s2+s<=d0&&p!=3)
        ret=upt(ret+dfs(cr+1,fx,s0,s1+s,s2+s,s3));
      if(s3+s<=d1&&p!=4)
        ret=upt(ret+dfs(cr+1,fx,s0,s1+s,s2,s3+s));
    }
    }
  else
    {
      if(s0+ct[a[cr].bh]<=c0)
    {
      if(s2+s<=d0&&p!=1)
        ret=upt(ret+dfs(cr+1,0,s0+s,s1,s2+s,s3));
      if(s3+s<=d1&&p!=2)
        ret=upt(ret+dfs(cr+1,0,s0+s,s1,s2,s3+s));
    }
      if(s1+ct[a[cr].bh]<=c1)
    {
      if(s2+s<=d0&&p!=3)
        ret=upt(ret+dfs(cr+1,1,s0,s1+s,s2+s,s3));
      if(s3+s<=d1&&p!=4)
        ret=upt(ret+dfs(cr+1,1,s0,s1+s,s2,s3+s));
    }
    }
  mp[tp]=ret; return ret;
}
int main()
{
  freopen("mentor.in","r",stdin);
  freopen("mentor.out","w",stdout);
  int T=rdn();
  while(T--)
    {
      n=rdn();c=rdn();c0=rdn();c1=rdn();d0=rdn();d1=rdn();
      mp.clear();
      for(int i=1;i<=n;i++)a[i].p=0;
      for(int i=1;i<=c;i++)ct[i]=0;
      int sm=0;
      for(int i=1;i<=n;i++)
    {
      a[i].bh=rdn(); a[i].s=rdn();
      ct[a[i].bh]+=a[i].s; sm+=a[i].s;
    }
      if(sm>min(c0+c1,d0+d1)){puts("0");continue;}
      int k=rdn();
      for(int i=1,d;i<=k;i++)
    d=rdn(), a[d].p=rdn()+1;
      sort(a+1,a+n+1);
      printf("%d\n",dfs(1,0,0,0,0,0));
    }
  return 0;
}

下面按官方题解的说法,把阵营看成红蓝两种颜色,把派系看成 0 / 1两种编号。

一个很好的思路是,当 k = 0 的时候,每个城市决策颜色与每个学校决策编号是独立的。即单独算出 c 个城市的染色方案,再单独算出 n 个学校的编号方案,把它们乘起来,一个学校就既有了颜色又有了编号,就是答案。

注意到 k 很小。所以考虑把没有限制的学校的编号按这种方法决策,把没有限制学校的城市的颜色按这种方法决策,剩下有限制的部分就用 50 分的那个做法,颜色和编号同时决策即可。

但是如果一个学校自己没有限制,可是它所在的城市里有有限制的学校,那么该学校的颜色受到限制。所以自己的想法是 dp[ i ][ j ] 表示没有限制的城市给蓝包 i 人、没有限制的学校给 0 包 j 人的方案,然后在把没有限制的城市的染色方案和没有限制的学校的编号方案合上去。

但是这种 “ 只有颜色有限制 ” 的学校可能很多?复杂度?感觉刷表会很艰难,所以按之前暴力的方法写了记搜。发现搜的次数还是合理的 ( 107 级别 ),但就是很慢。把 map 换成哈希表之后快了很多,但还是要跑数十秒。这应该是常数吧?……

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
int Mx(int a,int b){return a>b?a:b;}
int Mn(int a,int b){return a<b?a:b;}
const int N=1005,M=2505,mod=998244353;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}

int n,c,k,c0,c1,d0,d1,sc,sd,sm,ans;
int f[M],g[M],ct[N]; bool vis[N];
struct Node{
  int bh,s,p;
  bool operator< (const Node &b)const
  {return bh<b.bh;}
}a[N];
struct Dt{
  int cr,s0,s1;bool fx;
  Dt(int c=0,int s0=0,int s1=0,bool f=0):
    cr(c),s0(s0),s1(s1),fx(f) {}
  bool operator< (const Dt &b)const
  {
    if(cr!=b.cr)return cr<b.cr; if(s0!=b.s0)return s0<b.s0;
    if(s1!=b.s1)return s1<b.s1; return fx<b.fx;
  }
  bool operator== (const Dt &b)const
  {return cr==b.cr&&s0==b.s0&&s1==b.s1&&fx==b.fx;}
};
namespace H{
  const int bs=10007,md=1e7+3,M=24e6;
  int hd[md+5],xnt,to[M],nxt[M]; Dt c[M];
  void init()
  {
    xnt=0; memset(hd,0,sizeof hd);
  }
  int Hs(Dt t)
  {
    int ret=t.cr;
    ret=((ll)ret*bs+t.s0)%md;
    ret=((ll)ret*bs+t.s1)%md;
    ret=((ll)ret*bs+t.fx)%md;
    return ret;
  }
  void ins(Dt t,int k)
  {
    int ret=Hs(t);
    to[++xnt]=k;nxt[xnt]=hd[ret];hd[ret]=xnt;
    c[xnt]=t;
  }
  int qry(Dt t)
  {
    int ret=Hs(t);
    for(int i=hd[ret];i;i=nxt[i])
      if(c[i]==t)return to[i];
    return -1;
  }
}
void init()
{
  for(int i=1;i<=n;i++)a[i].p=0;
  for(int i=1;i<=c;i++)vis[i]=0;
  for(int i=1;i<=c;i++)ct[i]=0;//
  sc=sd=sm=ans=0;
  H::init();
}
void solve1()
{
  for(int i=1;i<=c0;i++)f[i]=0;//c0 not sm for too large
  for(int i=1;i<=d0;i++)g[i]=0;
  f[0]=1;//
  for(int i=1;i<=c;i++)
    {
      if(vis[i]||!ct[i])continue;//!ct[i]!!!
      sc+=ct[i]; sc=Mn(sc,c0);//
      for(int j=sc;j>=ct[i];j--)
    f[j]=upt(f[j]+f[j-ct[i]]);
    }
  for(int j=1;j<=sc;j++)f[j]=upt(f[j]+f[j-1]);
  g[0]=1;//
  for(int i=1;i<=n;i++)
    {
      if(a[i].p)continue;
      sd+=a[i].s; sd=Mn(sd,d0);
      for(int j=sd;j>=a[i].s;j--)
    g[j]=upt(g[j]+g[j-a[i].s]);
    }
  for(int j=1;j<=sd;j++)g[j]=upt(g[j]+g[j-1]);
}
int dfs(int cr,int s0,int s1,int ct0,int ct1,bool fx)
{
  Dt nw=Dt(cr,s0,s1,fx);
  int k=H::qry(nw); if(k!=-1)return k;
  if(!vis[a[cr].bh])
    {
      int i;
      for(i=cr+1;i<=n&&a[i].bh==a[cr].bh;i++);
      if(i<=n)
    {
      k=dfs(i,s0,s1,ct0,ct1,fx);
      H::ins(nw,k); return k;
    }
      else cr=i;
    }
  if(cr>n)
    {
      int l0=Mx(0,sm-s0-c1), r0=Mn(sc,c0-s0);//Mn
      int l1=Mx(0,sm-s1-d1), r1=Mn(sd,d0-s1);
      if(l0>r0||l1>r1){ H::ins(nw,0); return 0;}
      k=(ll)upt(f[r0]-(l0?f[l0-1]:0))*upt(g[r1]-(l1?g[l1-1]:0))%mod;
      H::ins(nw,k); return k;
    }
  int ret=0, s=a[cr].s, tc=ct[a[cr].bh];
  if(!a[cr].p)
    {
      if(a[cr].bh!=a[cr-1].bh)//O(c*m)
    {
      ct0+=tc;
      if(s0+tc<=c0)ret=dfs(cr+1,s0+tc,s1,ct0,ct1,0);
      if(ct0-s0<=c1)ret=upt(ret+dfs(cr+1,s0,s1,ct0,ct1,1));
      H::ins(nw,ret); return ret;
    }
      else
    {
      k=dfs(cr+1,s0,s1,ct0,ct1,fx);
      H::ins(nw,k); return k;
    }
    }
  int p=a[cr].p; ct1+=s;
  if(a[cr].bh!=a[cr-1].bh)//O(k*m*m)
    {
      ct0+=tc;
      if(p!=1&&s0+tc<=c0&&s1+s<=d0)
    ret=dfs(cr+1,s0+tc,s1+s,ct0,ct1,0);
      if(p!=2&&s0+tc<=c0&&ct1-s1<=d1)
    ret=upt(ret+dfs(cr+1,s0+tc,s1,ct0,ct1,0));
      if(p!=3&&ct0-s0<=c1&&s1+s<=d0)
    ret=upt(ret+dfs(cr+1,s0,s1+s,ct0,ct1,1));
      if(p!=4&&ct0-s0<=c1&&ct1-s1<=d1)
    ret=upt(ret+dfs(cr+1,s0,s1,ct0,ct1,1));
      H::ins(nw,ret); return ret;
    }
  else
    {
      if(!fx)
    {
      if(p!=1&&s1+s<=d0)
        ret=dfs(cr+1,s0,s1+s,ct0,ct1,fx);
      if(p!=2&&ct1-s1<=d1)
        ret=upt(ret+dfs(cr+1,s0,s1,ct0,ct1,fx));
    }
      else
    {
      if(p!=3&&s1+s<=d0)
        ret=dfs(cr+1,s0,s1+s,ct0,ct1,fx);
      if(p!=4&&ct1-s1<=d1)
        ret=upt(ret+dfs(cr+1,s0,s1,ct0,ct1,fx));
    }
      H::ins(nw,ret); return ret;
    }
}
int main()
{
  int T=rdn();
  while(T--)
  {
    n=rdn();c=rdn(); init();
    c0=rdn();c1=rdn();d0=rdn();d1=rdn();
    for(int i=1;i<=n;i++)
      {
    a[i].bh=rdn(); a[i].s=rdn();
    ct[a[i].bh]+=a[i].s; sm+=a[i].s;
      }
    k=rdn();
    for(int i=1,d;i<=k;i++)
      d=rdn(), a[d].p=rdn()+1, vis[a[d].bh]=1;
    sort(a+1,a+n+1);
    solve1();
    printf("%d\n",dfs(1,0,0,0,0,0));
  }
  return 0;
}
View Code

然后观察一番 AC 代码。原来那种 “ 只有颜色有限制 ” 的学校就不用决策它的颜色了!因为它颜色的决策已经在 “有限制学校” 决策颜色的时候决策过了!

所以把没有限制的城市的颜色 DP 一番,把没有限制的学校的编号 DP 一番,再把有限制的 k 个学校的颜色和编号 DP 一番就行啦!这样就可以小常数刷表啦!复杂的是确实的 O( c*m2 + n*m ) 。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
int Mx(int a,int b){return a>b?a:b;}
int Mn(int a,int b){return a<b?a:b;}
const int N=1005,K=30,M=2505,mod=998244353;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
void cz(int &x,int y){x=upt(x+y);}

int n,c,c0,c1,d0,d1,sc,sd,sm;
int ct[N],f[M],g[M],dp[M][M][2];
bool vis[N];
struct Node{
  int t,s,p;
  bool operator< (const Node &b)const
  {
    if(p&&b.p)return t<b.t; else return p;
  }
}a[N];
void solve()
{
  sc=sd=0; f[0]=g[0]=1;
  for(int i=1;i<=c0;i++)f[i]=0;//c0 not sm for RE
  for(int i=1;i<=d0;i++)g[i]=0;
  for(int i=1;i<=c;i++)
    {
      if(vis[i]||!ct[i])continue;
      int tp=ct[i]; sc+=tp; sc=Mn(sc,c0);
      for(int j=sc;j>=tp;j--)
    cz(f[j],f[j-tp]);
    }
  for(int i=1;i<=sc;i++)cz(f[i],f[i-1]);
  for(int i=1;i<=n;i++)
    {
      if(a[i].p)continue;
      int tp=a[i].s; sd+=tp; sd=Mn(sd,d0);
      for(int j=sd;j>=tp;j--)
    cz(g[j],g[j-tp]);
    }
  for(int i=1;i<=sd;i++)cz(g[i],g[i-1]);
}
int cal(int j,int k)
{
  int l0=Mx(0,sm-j-c1), r0=Mn(sc,c0-j);
  int l1=Mx(0,sm-k-d1), r1=Mn(sd,d0-k);
  return (ll)upt(f[r0]-(l0?f[l0-1]:0))*upt(g[r1]-(l1?g[l1-1]:0))%mod;
}
int main()
{
  int T=rdn();
  while(T--)
    {
      n=rdn();c=rdn();
      c0=rdn();c1=rdn();d0=rdn();d1=rdn();
      sm=0; for(int i=1;i<=c;i++)ct[i]=0;
      for(int i=1;i<=n;i++)
    {
      a[i].t=rdn(); a[i].s=rdn();
      ct[a[i].t]+=a[i].s; sm+=a[i].s;
    }
      for(int i=1;i<=c;i++)vis[i]=0;
      for(int i=1;i<=n;i++)a[i].p=0;
      int k=rdn();
      for(int i=1,d;i<=k;i++)
    {
      d=rdn(); a[d].p=rdn()+1; vis[a[d].t]=1;
    }
      solve();
      sort(a+1,a+n+1); int s0=0, s1=0;
      memset(dp,0,sizeof dp); dp[0][0][0]=1;
      for(int i=1;i<=n;i++)
    {
      if(!a[i].p)break; int p=a[i].p;
      if(a[i].t!=a[i-1].t)
        {
          int tc=ct[a[i].t], ts=a[i].s;
          s0+=tc; s1+=ts;
          for(int j=Mn(s0,c0);j>=0;j--)
        for(int k=Mn(s1,d0);k>=0;k--)
          {
            int y0=dp[j][k][0], y1=dp[j][k][1];
            dp[j][k][0]=dp[j][k][1]=0;
            if(s0-j>c1||s1-k>d1)continue;
            if(p!=1&&j>=tc&&k>=ts)
              {
            cz(dp[j][k][0],dp[j-tc][k-ts][0]);
            cz(dp[j][k][0],dp[j-tc][k-ts][1]);
              }
            if(p!=2&&j>=tc)
              {
            cz(dp[j][k][0],dp[j-tc][k][0]);
            cz(dp[j][k][0],dp[j-tc][k][1]);
              }
            if(p!=3&&k>=ts)
              {
            cz(dp[j][k][1],dp[j][k-ts][0]);
            cz(dp[j][k][1],dp[j][k-ts][1]);
              }
            if(p!=4)
              {
            cz(dp[j][k][1],y0);
            cz(dp[j][k][1],y1);
              }
          }
        }
      else
        {
          int ts=a[i].s; s1+=ts;
          for(int j=Mn(s0,c0);j>=0;j--)
        for(int k=Mn(s1,d0);k>=0;k--)
          {
            int y0=dp[j][k][0], y1=dp[j][k][1];
            dp[j][k][0]=dp[j][k][1]=0;
            if(s0-j>c1||s1-k>d1)continue;
            if(p!=1&&k>=ts)
              cz(dp[j][k][0],dp[j][k-ts][0]);
            if(p!=2)
              cz(dp[j][k][0],y0);
            if(p!=3&&k>=ts)
              cz(dp[j][k][1],dp[j][k-ts][1]);
            if(p!=4)
              cz(dp[j][k][1],y1);
          }
        }
    }
      int ans=0;
      for(int j=Mn(c0,s0);j>=0;j--)
    for(int k=Mn(c1,s1);k>=0;k--)
      {
        if(s0-j>c1||s1-k>d1)continue;
        int tp=upt(dp[j][k][0]+dp[j][k][1]);
        if(tp)ans=(ans+(ll)tp*cal(j,k))%mod;
      }
      printf("%d\n",ans);
    }
  return 0;
}

 

posted on 2019-04-09 19:48  Narh  阅读(325)  评论(0编辑  收藏  举报

导航