Loading

P7444 「EZEC-7」猜排列 (插入型 dp)

P7444 「EZEC-7」猜排列

dp

考虑 dp。从小到大插入数字,从小到大满足限制。假如现在想知道是否满足 \(f(l,r)=c_i\),发现我们只关心 包含 \(0\sim i-1\) 的最小区间的左右端点位置,于是可以设 \(f_{i,l,r}\) 表示填完了 \(i-1\) 个数,最小区间的左右端点为 \(l\)\(r\),考虑第 \(i\) 个数的位置的方案数。

转移可以分成:\(c_i=0\),那么第 \(i\) 个数的位置在 \([l,r]\) 中,转移易得;\(c_i\ne 0\),那么第 \(i\) 个数有可能在左侧和右侧,以左侧为例,假设位置为 \(L\),那么满足限制的区间数量为 \(c_i=(l-L)\times(n-r+1)\),所以 \(r\) 位置确定,\(L\) 位置也确定。

空间复杂度为 \(O(n^3)\),第一维可以滚动,又发现 \(l\) 位置确定,因为 \(\sum_{j\ge i}c_i\) 确定,所以 \(r\) 位置确定,复杂度降到 \(O(n)\)

关于分析时间复杂度,可以用柯西不等式证明大约为 \(O(n\sqrt n)\)。实现方面就是转移后放入队列等待下一次更新。

#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define fi first
#define se second
#define pb push_back

using i64 = long long;
using ull = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 5e5 + 10, mod = 998244353;
i64 n, ans;
i64 dp[2][N], c[N], pos[2][N], vis[N];
std::vector<int> st, pre, nxt;
i64 calc(int x) {return 1LL * x * (x + 1) / 2;}
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    
	std::cin >> n;
	for(int i = 0; i < n; i++) std::cin >> c[i];

	for(int i = 1; i <= n; i++) {
		if(calc(i - 1) + calc(n - i) == c[0]) st.pb(i);   
	}
	if(!st.size()) {
		std::cout << "0\n";
		return 0;
	}

	dp[0][st[0]] = 1;
	pos[0][st[0]] = st[0];
	pre.pb(st[0]);
	for(int i = 1; i < n - 1; i++) {
		nxt.clear();
		int lst = (i - 1) & 1;
		for(auto j : pre) {
			int l = j, r = pos[lst][l];
			if(vis[j]) continue;
			vis[j] = true;
			if(!c[i]) {
				dp[i & 1][l] = (dp[i & 1][l] + dp[lst][l] * (r - l + 1 - i) % mod) % mod;
				pos[i & 1][l] = r;
				nxt.pb(l);
			}
			else {
				int lef = l, rig = n - r + 1;
				if(c[i] % rig == 0) {
					int L = l - c[i] / rig;
					dp[i & 1][L] = (dp[i & 1][L] + dp[lst][l]) % mod;
					pos[i & 1][L] = r;
					nxt.pb(L);
				}	
				if(c[i] % lef == 0) {
					int R = r + c[i] / lef;
					dp[i & 1][l] = (dp[i & 1][l] + dp[lst][l]) % mod;
					pos[i & 1][l] = R;
					nxt.pb(l); 
				}
			}
		}
		for(auto j : pre) vis[j] = false, dp[lst][j] = 0;
		nxt.swap(pre);
	}
	for(int i = 1; i <= n; i++) ans = (ans + dp[(n - 2) & 1][i]) % mod;

	std::cout << 1LL * ans * st.size() % mod << "\n";
	return 0;
}
posted @ 2024-06-28 17:11  Fire_Raku  阅读(9)  评论(0编辑  收藏  举报