题目链接

https://agc019.contest.atcoder.jp/tasks/agc019_f

题意简述

n+mn+m个问题,答案都是"Yes"或"No",其中nn个是"Yes",mm个是"No"。你回答一个问题后,不管是否正确,都可以得到这个问题的答案。假设你都不知道答案只能猜,按最有策略行动,求你期望答对问题的个数。

题解

容易得到一个O(n2)O(n^2)的dp
fi,j=ii+jfi1,j+ji+jfi,j1+max(i,j)i+j f_{i,j}=\frac{i}{i+j}f_{i-1,j}+\frac{j}{i+j}f_{i,j-1}+\frac{\max(i,j)}{i+j}
gi,j=fi,jmax(i,j)g_{i,j}=f_{i,j}-\max(i,j),假设i>ji>j
gi,j+i=ii+jfi1,j+ji+jfi,j1+ii+jgi,j+i=ii+j(gi1,j+i1)+ji+j(gi,j1+i)+ii+jgi,j=ii+jgi1,j+ji+jgi,j1 \begin{aligned} g_{i,j}+i & =\frac{i}{i+j}f_{i-1,j}+\frac{j}{i+j}f_{i,j-1}+\frac{i}{i+j}\\ g_{i,j}+i & =\frac{i}{i+j}(g_{i-1,j}+i-1)+\frac{j}{i+j}(g_{i,j-1}+i)+\frac{i}{i+j}\\ g_{i,j} & =\frac{i}{i+j}g_{i-1,j}+\frac{j}{i+j}g_{i,j-1} \end{aligned}
同理,i<ji<j时上式也成立。

假设i=ji=j
gi,i+i=12fi1,i+12fi,i1+12gi,i+i=12(gi1,i+i)+12(gi,i1+i)+12gi,j=12gi1,i+12gi,i1+12 \begin{aligned} g_{i,i}+i & =\frac{1}{2}f_{i-1,i}+\frac{1}{2}f_{i,i-1}+\frac{1}{2}\\ g_{i,i}+i & =\frac{1}{2}(g_{i-1,i}+i)+\frac{1}{2}(g_{i,i-1}+i)+\frac{1}{2}\\ g_{i,j} & =\frac{1}{2}g_{i-1,i}+\frac{1}{2}g_{i,i-1}+\frac{1}{2} \end{aligned}
综上
gi,j=ii+jgi1,j+ji+jgi,j1+12[i=j] g_{i,j}=\frac{i}{i+j}g_{i-1,j}+\frac{j}{i+j}g_{i,j-1}+\frac{1}{2}[i=j]
上式可以看作从(n,m)(n,m)出发,在(i,j)(i,j)每次有ii+j\frac{i}{i+j}向左走,有ji+j\frac{j}{i+j}向下走,最终到达(0,0)(0,0),求经过横坐标等于纵坐标的点的个数的期望的12\frac{1}{2}。这个可以对所有这样的点分开求
i=0min(n,m)(n+m2ini)(2ii)(n+mn) \sum_{i=0}^{\min(n,m)}\frac{\binom{n+m-2i}{n-i}\binom{2i}{i}}{\binom{n+m}{n}}
那么答案为
max(i,j)+i=0min(n,m)(n+m2ini)(2ii)(n+mn) \max(i,j)+\sum_{i=0}^{\min(n,m)}\frac{\binom{n+m-2i}{n-i}\binom{2i}{i}}{\binom{n+m}{n}}

代码

#include <cstdio>
#include <algorithm>

int read()
{
  int x=0,f=1;
  char ch=getchar();
  while((ch<'0')||(ch>'9'))
    {
      if(ch=='-')
        {
          f=-f;
        }
      ch=getchar();
    }
  while((ch>='0')&&(ch<='9'))
    {
      x=x*10+ch-'0';
      ch=getchar();
    }
  return x*f;
}

const int maxn=500000;
const int mod=998244353;
const int inv=499122177;

int n,m,ans,fac[maxn*2+10],ifac[maxn*2+10];

int main()
{
  n=read();
  m=read();
  fac[0]=ifac[0]=ifac[1]=1;
  for(int i=1; i<=std::max(n,m)*2; ++i)
    {
      fac[i]=1ll*fac[i-1]*i%mod;
    }
  for(int i=2; i<=std::max(n,m)*2; ++i)
    {
      ifac[i]=1ll*(mod-mod/i)*ifac[mod%i]%mod;
    }
  for(int i=1; i<=std::max(n,m)*2; ++i)
    {
      ifac[i]=1ll*ifac[i-1]*ifac[i]%mod;
    }
  ans=std::max(n,m);
  for(int i=1; i<=std::min(n,m); ++i)
    {
      ans=(ans+1ll*fac[n+m-2*i]*fac[2*i]%mod*fac[n]%mod*fac[m]%mod*ifac[n+m]%mod*ifac[n-i]%mod*ifac[m-i]%mod*ifac[i]%mod*ifac[i]%mod*inv)%mod;
    }
  printf("%d\n",ans);
  return 0;
}