题解 [AGC019F] Yes or No

传送门

AGC 的优雅.jpg

首先有个 \(O(n^2)\) 的简单 DP
然后手玩一下转移路径发现答案是这个东西

\[\sum\limits_{i=0}^n\sum\limits_{j=0}^m\frac{\max(i, j)}{i+j}\frac{n^{\underline{n-i}}m^{\underline{m-j}}}{(n+m)^{\underline{n+m-i-j}}}\binom{n-i+m-j}{n-i} \]

尝试多项式处理
拆开发现是

\[\frac{n!m!}{(n+m)!}\sum\limits_{i=0}^n\sum\limits_{j=0}^m\frac{\max(i, j)}{i+j}\frac{(i+j)!}{i!j!}\frac{(n-i+m-j)!}{(n-i)!(m-j)!} \]

听说这个东西可以分治 NTT 做
口胡一下是额外维护一个不考虑那个带 \(\max\) 项的 NTT
这样在做一个分治 NTT 时在考虑跨过 mid 贡献时就可以取右边的值了

然后正解:
还是考虑放到格路上处理
这里从粉兔那里剽的图:

钦定 \(n\geqslant m\),那么相当于从点 \((n, m)\) 走到 \((0, 0)\)
每条路径的贡献是其经过的红色边数
那么求出所有路径的贡献之和再除以总路径数即可
所有路径的贡献之和怎么求呢?
考虑将每条路径在斜线 \(y=x\) 上面的部分都翻下来
可以发现除了恰好从斜线上往左走的部分贡献都是不变的
同时发现若忽略刚才没算的贡献那每条路径的贡献都恰好是 \(n\)
于是单独计算恰好从斜线上往左走的贡献
根据期望的线性性拆开,经过斜线上每个点的路径中有一半是往左走的
所以这部分对答案的贡献是

\[\frac{1}{2}\sum\limits_{i=1}^{\min(n, m)}\frac{\binom{n+m-2i}{n-i}\binom{2i}{i}}{\binom{n+m}{n}} \]

那么就做完了,复杂度 \(O(n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000010
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, m;
ll fac[N], inv[N], ans;
const ll mod=998244353;
inline ll C(int n, int k) {return fac[n]*inv[k]%mod*inv[n-k]%mod;}
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}

signed main()
{
	n=read(); m=read();
	fac[0]=fac[1]=1; inv[0]=inv[1]=1;
	for (int i=2; i<=n+m; ++i) fac[i]=fac[i-1]*i%mod;
	for (int i=2; i<=n+m; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
	for (int i=2; i<=n+m; ++i) inv[i]=inv[i-1]*inv[i]%mod;
	// for (int i=0; i<=n; ++i)
	// 	for (int j=0; j<=m; ++j)
	// 		ans=(ans+max(i, j)*qpow(i+j, mod-2)%mod*fac[i+j]*inv[i]%mod*inv[j]%mod*fac[n+m-i-j]%mod*inv[n-i]%mod*inv[m-j])%mod;
	// ans=fac[n]*fac[m]%mod*inv[n+m]%mod*ans%mod;
	for (int i=1; i<=min(n, m); ++i) ans=(ans+C(n+m-2*i, n-i)*C(2*i ,i)%mod*qpow(C(n+m, n), mod-2))%mod;
	printf("%lld\n", (max(n, m)+qpow(2, mod-2)*ans)%mod);

	return 0;
}
posted @ 2022-05-29 07:46  Administrator-09  阅读(2)  评论(0编辑  收藏  举报