CF739E Gosha is hunting(费用流/凸优化dp)

纪念合格考爆炸。

其实这个题之前就写过博客了,qwq但是不小心弄丢了,所以今天来补一下。

首先,一看到球的个数的限制,不难相当用网络流的流量来限制每个球使用的数量。

由于涉及到最大化期望,所以要使用最大费用最大流。

我们新建两个点\(ss,tt\),分别表示两种球。

那么我们现在考虑应该怎么计算期望呢。

首先,如果假设如果对于一个怪物用一个球,那么连边也就比较容易了
对于一个怪物\(x\)
我们\(ss -> x\),费用为\(p[i]\),流量为1。表示一个球在一个怪物上只能用一次。
\(tt\)也是同理。

然后对于每一个\(x->t\),费用是\(0\),流量是\(1\),表示一个怪物只能用一个球。

但是,要是每次不要求只能用一个球应该怎么做呢。

我们考虑,这条边的费用应该是多少。

两个球都用的期望应该是\(1-(1-p_i)(1-q_i)\)
经过展开,我们发现应该是\(p_i+q_i-p_i\times q_i\)

那么由于我们发现,由于用了两个球,所以已经获得了二者之和的收益,那么在这一侧,只需要在上述建图的基础上\(x->t\),费用是\(-p_i\times q_i\)即可。

最后跑一发最大费用最大流就能通过这个题qwq时间复杂度玄学。

#include<bits/stdc++.h>
#define pb push_back
#define mk make_pair
#define ll long long
#define db double

using namespace std;

inline int read()
{
   int x=0,f=1;char ch=getchar();
   while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
   while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
   return x*f;
}

const int maxn = 4010;
const int maxm = 3e6+1e2;
const double eps = 1e-10;

int point[maxn],nxt[maxm],to[maxm],pre[maxm],from[maxn];
double dis[maxn];
int vis[maxn];
double cost[maxm];
int flow[maxm];
double ans;
int n,m,cnt=1;
int s,t;

void addedge(int x,int y,db w,int f)
{
    nxt[++cnt]=point[x];
    pre[cnt]=x;
    to[cnt]=y;
    cost[cnt]=w;
    flow[cnt]=f;
    point[x]=cnt;
}

void insert(int x,int y,db w,int f)
{
    addedge(x,y,w,f);
    addedge(y,x,-w,0);
}

queue<int> q;

bool spfa(int s)
{
    for (int i=1;i<=maxn-3;i++) dis[i]=-1e9;
    memset(vis,0,sizeof(vis));
    q.push(s);
    dis[s]=0;
    while (!q.empty())
    {
        int x = q.front();
        q.pop();
        vis[x]=0;
        for (int i=point[x];i;i=nxt[i])
        {
            int p = to[i];
            if (dis[p]-(dis[x]+cost[i])<-eps && flow[i]>0)
            {
                from[p]=i;
                dis[p]=dis[x]+cost[i];
                if (!vis[p])
                {
                    q.push(p);
                    vis[p]=1;
                }
            }
        }
    }
    if (dis[t]==-1e9) return false;
    return true;
}

void mcf()
{
    double x = 1e9;
    for (int i=from[t];i;i=from[pre[i]]) x=min(x,1.0*flow[i]);
    for (int i=from[t];i;i=from[pre[i]])
    {
        flow[i]-=x;
        flow[i^1]+=x;
        ans+=x*cost[i];
    }
}

void solve()
{
    while (spfa(s)) mcf();
}

db a[maxn],b[maxn];
int ss,sss;
int aa,bb;

int main()
{
   n=read(),aa=read(),bb=read();
   s=maxn-10;
   ss=s+1;
   t=s+3;
   sss=ss+1;
   insert(s,ss,0,aa);
   insert(s,sss,0,bb);
   for (int i=1;i<=n;i++) scanf("%lf",&a[i]);
   for (int i=1;i<=n;i++) scanf("%lf",&b[i]);
   for (int i=1;i<=n;i++)
   {
   	  insert(ss,i,a[i],1);
   	  insert(sss,i,b[i],1);
   	  insert(i,t,0,1);
   	  insert(i,t,-a[i]*b[i],1);
   }
   solve();
   printf("%.4lf\n",ans);
   return 0;
}

但是其实这个题的正解是凸优化\(dp\)

首先,先做一个最\(naive\)的想法。

我们令\(dp[i][j][k]\)表示前\(i\)个怪物,用了\(j\)一号球,用了\(k\)个二号球

那么转移也是显然的。
每次只需要讨论一下对于当前的怪物是用几个球,用哪个即可。

但是这样的复杂度是\(O(n^3)\)的。
显然没有办法通过。

考虑怎么优化。

由于题目中涉及到的正好用几个球,并且通过打表发现函数是凸的,那么我们就可以直接用凸优化来优化掉一维。

(其实是可以直接优化两个的,但是我太懒,所以没写。)

我们对于当前二分的值,表示每选一个二号球,就可以多得到\(mid\)的期望。不限制选的个数。

那么不难得到下面的这个转移式子。

dp[i][j]=dp[i-1][j];
dp[i][j]=max(dp[i][j],dp[i-1][j]+(ymh){bb[i],1});
if (j)
{
   	 dp[i][j]=max(dp[i][j],dp[i-1][j-1]+(ymh){a[i],0});
   	 dp[i][j]=max(dp[i][j],dp[i-1][j-1]+(ymh){both[i],1});
}

然后通过调整\(mid\),通过正好选到\(k\)个二号球。

最后求一个\(dp\)数组,然后记得把贡献减去就行。

时间复杂度\(n^2log\),非常优秀。

(其实是如果精度太小会\(WA\),精度太大会\(TLE\)

但是完全可以做到\(nlog^2\)的。

给代码。

#include<bits/stdc++.h>
#define pb push_back
#define mk make_pair
#define ll long long
#define db double 

using namespace std;

inline int read()
{
   int x=0,f=1;char ch=getchar();
   while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
   while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
   return x*f;
}

const int maxn = 2010;
const db eps = 1e-6;

struct ymh{
    db val;
    int num;
    ymh operator + (const ymh &b) const
    {
        return (ymh){val+b.val,num+b.num};
    }
};

ymh dp[maxn][maxn];
db a[maxn],b[maxn];
int n;
db l=-4,r=4;

inline int dcmp(double x,double y) 
{
  return x-y<-eps ? -1 : (x-y>eps ? 1 : 0);
}

inline ymh max(ymh a,ymh b)
{
    int now = dcmp(a.val,b.val);
    if (now==0)
    {
        if (a.num<b.num) return a;
        else return b;
    }
    else
    {
        if(now==-1) return b;
        else return a;
    }
}

int numa,numb;
db aa[maxn];
db bb[maxn];
db both[maxn];

bool check(db lim)
{
   for (int i=1;i<=n;i++) aa[i]=a[i];
   for (int i=1;i<=n;i++) bb[i]=b[i]+lim;
   for (int i=1;i<=n;i++) both[i]=1.0-(1.0-a[i])*(1.0-b[i])+lim;
   for (register int i=1;i<=n;++i)
   {
   	  for (register int j=0;j<=numa;++j)
   	  {
   	  	dp[i][j]=dp[i-1][j];
   	  	dp[i][j]=max(dp[i][j],dp[i-1][j]+(ymh){bb[i],1});
   	  	if (j)
   	  	{
   	  		dp[i][j]=max(dp[i][j],dp[i-1][j-1]+(ymh){a[i],0});
   	  		dp[i][j]=max(dp[i][j],dp[i-1][j-1]+(ymh){both[i],1});
        }
    //	if (dp[i][j].num>numb) return false;
      }
   }
   return dp[n][numa].num<=numb;
}

int main()
{
   n=read(),numa=read(),numb=read();
   for (int i=1;i<=n;i++) scanf("%lf",&a[i]);
   for (int i=1;i<=n;i++) scanf("%lf",&b[i]);
   double ans=0;
   while (r-l>=eps)
   {
   	  db mid = (l+r)/2;
   	 // memset(dp,0,sizeof(dp));
   	  
   	  if (check(mid)) l=mid,ans=mid;
   	  else r=mid;
   	  //printf("%.4lf %d\n",mid,dp[n][numa].num);
   }
   //cout<<1<<endl; 
   //printf("%.4lf\n",ans);
   //memset(dp,0,sizeof(dp));
   check(ans);
   //printf("")
   //printf("%.4lf %d\n",dp[n][numa].val,dp[n][numa].num);
   printf("%.4lf",dp[n][numa].val-1.0*numb*ans); 
   return 0;
}

posted @ 2019-01-14 20:56  y_immortal  阅读(235)  评论(0编辑  收藏  举报