题解 matrix

传送门

无比毒瘤的dp题,而且伪装地好像很可做的样子
考场上我给它氪了差不多一个小时最后还是只能扔了个20pts状压走人

以下思路基本均来源于题解:
对于此题,题面中三个限制条件:
(1)第 i 行第 1~li 列恰好有 1 个 1。 (li+1到ri-1不能放1)
(2)第 i 行第 ri~m 列恰好有 1 个 1。
(3)每列至多有 1 个 1。
根本注意不到注意到条件(3)相对比较好转移,所以此题考虑dp列而不是dp行

那么考虑从中选择一列该如何转移
显然每列只有放1或不放1两种情况,但条件(1)(2)在干扰我们转移
所以尝试消除干扰: 仅考虑单一区间(以右区间为例)

显然只有左端点到达这一列的区间需要被考虑
那么令\(dp[i][j]\)为从左到右第i列,所有跨越第i列的右区间中有j个已放过1
对于不放1的情况,\(dp[i][j] += dp[i-1][j]\)
对于放1,先定义数组\(cntl[i]\)为第i列及其左侧左区间的右端点的个数,\(cntr[i]\)同理
则第i列可以放1的左端点有\(cntr[i]-(j-1)\)个,所以\(dp[i][j] += dp[i-1][j-1]\times (cntr[i]-(j-1))\)

上面仅考虑了跨越第i列的右区间的方案数,那么下面处理左区间
首先定义\(f[i][j]\)时j指的是「所有跨越第i列的右区间中有j个已放过1」
那就不必考虑第i列的1应该给左区间还是右区间了
对于这种类似左右两边对抗的方案数dp
(就是说类似只能 在某条线以左/以右/以此线为分界左右同时 进行某种操作)
(为什么我一想到这就想起alpha-beta对抗搜索啊,好像将答案区间划分的思路差不多?)
考虑拆分区间,拆分出对左侧产生影响的右区间
则其对左区间影响已知,可计算出左区间

对于此题,i列左侧总共能放\(i-(j-1)\)个1,需要放\(cntl[i]-cntl[i-1]\)
这个\(cntl[i]-cntl[i-1]\)其实是有多少个左区间在i位置结束,即「新增加的必须放1的区间个数」
是在满足i列以前必须放1的区间都满足的条件下转移
\(cntl[i]-cntl[i-1]\)个1放到\(i-(j-1)\)个行中,考虑不同方案,应该是排列数
所以\(dp[i][j] *= A^{cntl[i]-cntl[i-1]}_{i-(j-1)}\)
就可以转移了

一大坑点: 这类n, m混杂的题一定要分清n, m!我预处理阶乘逆元的时候习惯性打了i<=n直接调了一晚上最后还是战神帮忙指出来的谢谢战神小可爱啦大雾逃

Code:

#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 3010
#define ll long long 
#define ld long double
#define usd unsigned
#define ull unsigned long long
//#define int long long 

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
char buf[1<<21], *p1=buf, *p2=buf;
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, m;
int l[N], r[N], lcnt[N], rcnt[N];
ll dp[N][N], fac[N], inv[N];
const ll mod=998244353;

inline ll A(ll n, ll k) {return fac[n]*inv[n-k]%mod;}
inline ll md(ll a) {return a>=mod?a-mod:a;}

signed main()
{
	#ifdef DEBUG
	freopen("1.in", "r", stdin);
	#endif
	
	n=read(); m=read();
	if ((n<<1)>m) {puts("0"); return 0;}
	for (int i=1; i<=n; ++i) l[i]=read(), r[i]=read();
	sort(l+1, l+n+1); sort(r+1, r+n+1);
	fac[0]=fac[1]=1; inv[0]=inv[1]=1;
	for (int i=2; i<=m; ++i) fac[i]=fac[i-1]*i%mod;
	for (int i=2; i<=m; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
	for (int i=2; i<=m; ++i) inv[i]=inv[i-1]*inv[i]%mod;
	for (int i=1,p=1; i<=m; ++i) {rcnt[i]=rcnt[i-1]; while (p<=n&&r[p]==i) ++rcnt[i],++p;}
	for (int i=1,p=1; i<=m; ++i) {lcnt[i]=lcnt[i-1]; while (p<=n&&l[p]==i) ++lcnt[i],++p;}
	dp[0][0]=1;
	for (int i=1; i<=m; ++i) {
		dp[i][0] = dp[i-1][0]*A(i-lcnt[i-1], lcnt[i]-lcnt[i-1])%mod;
		for (int j=1; j<=min(i, n); ++j)
			dp[i][j] = md(dp[i-1][j]+dp[i-1][j-1]*max(rcnt[i]-(j-1), 0)%mod)*A(i-j-lcnt[i-1], lcnt[i]-lcnt[i-1])%mod; //, cout<<"A: "<<dp[i][j]<<endl;
	}
	printf("%lld\n", dp[m][n]);

	return 0;
}
posted @ 2021-06-08 06:41  Administrator-09  阅读(31)  评论(0编辑  收藏  举报