动态规划的树状数组优化(M 元上升子序列)

我们先来看简单版:P1637 三元上升子序列

这道题显然可以用 DP 来做,状态转移方程也很好写。

设 \(f_{i,j}\) 表示以 \(a_j\) 结尾、长度为 \(i\) 的(严格)上升子序列个数,初值 \(f_{1,j}=1\)(单个元素本身就是长度为 1 的上升子序列)。显然答案就是 \(\sum\limits_{k=1}^{n} f_{3,k}\)。

\[f_{i,j}=\sum\limits_{k=1}^{j-1}[a_k<a_j]\,f_{i-1,k} \]

(当 \(k<i-1\) 时必有 \(f_{i-1,k}=0\),所以求和下界取 \(1\) 或 \(i-1\) 都可以,但不能取 \(i\),否则会漏掉 \(k=i-1\) 的情况。)

代码:

#include <bits/stdc++.h>
// Competitive-programming shorthands.
// NOTE: the `register` storage class was removed in C++17 (clang rejects it by
// default), so the counter-declaring macros below use plain `int`.
#define rei int
#define LL long long
#define IOS ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
#define cvar int n, m, T;
// rep(i, s, n, c): i = s, s+c, s+2c, ... while i <= n.
// Bounds and step are parenthesised so expressions like `x & mask` bind as a whole.
#define rep(i, s, n, c) for (int i = (s); i <= (n); i += (c))
// repd(i, s, n, c): i = s, s-c, s-2c, ... while i >= n.
#define repd(i, s, n, c) for (int i = (s); i >= (n); i -= (c))
// Debug marker: prints "WALKED" and flushes.
#define CHECK cout<<"WALKED"<<endl;
// Reads the next integer from stdin via getchar().
// Skips any non-digit characters first; every '-' seen while skipping makes the
// result negative. Returns 0 when end of input is reached before any digit.
// `ch` must be an int so that EOF stays distinguishable from real characters.
// Otherwise the skip loop never terminates at end of input.
inline int read() {
	int x = 0, f = 1;
	int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9')) {
		if (ch == '-') f = -1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * f;
}
#define pb push_back
// Left/right child of node `id` in a 1-based heap-style layout (e.g. a segment tree).
// Parenthesised so that `ls + 1` or `rs * 2` keep the intended precedence.
#define ls (id<<1)
#define rs (id<<1|1)
const int INF = INT_MAX;
// Computes a^b mod `mod` by binary exponentiation, returning a value in [0, mod).
// Preconditions: mod > 0 and (mod - 1)^2 fits in long long; b < 0 is treated as b == 0.
// The base is reduced (and made non-negative) first, so a*a cannot overflow even for
// large or negative a, and `1 % mod` gives the correct result 0 when mod == 1.
long long binpow(long long a, long long b, long long mod) {
	a %= mod;
	if (a < 0) a += mod;
	long long res = 1 % mod;
	while (b > 0) {
		if (b & 1) res = res * a % mod;
		a = a * a % mod;
		b >>= 1;
	}
	return res;
}
using namespace std;

// Array capacity (positions are 1-based, so n must stay below this).
constexpr int maxn = 30010;
// Sequence length.
int n;
// Input sequence a[1..n].
long long a[maxn];
// f[len][j]: number of strictly increasing subsequences of length `len` that end
// at position j. Row 0 is never assigned and stays zero.
long long f[4][maxn];

int main()
{
	n = read();
	rep(i, 1, n, 1) {
		a[i] = read();
		f[1][i] = 1;
	}
	rep (i, 1, 3, 1) {
		rep (j, 1, n, 1) {
			rep(k, 1, j - 1, 1) {
				if (a[k] < a[j])
					f[i][j] += f[i - 1][k];
			}
		}
	}
	LL ans = 0;
	rep (i, 1, n, 1) ans += f[3][i];
	cout << ans << endl;
	return 0;
}

别急着交:这样做的复杂度是 \(O(n^2)\),会 TLE(超时),需要优化。

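发现可以优化的地方是每次的求和。但单纯用前缀和不行,因为求和时还要满足条件 \(a_k<a_j\)。所以先将数组离散化,使数组 a 中存储的是每个数的排名(\(1\sim m\)),再以排名为下标把 \(f_{i-1,k}\) 存进树状数组。具体做法是从左到右枚举 \(j\):先查询下标小于 \(a_j\) 的前缀和,得到 \(f_{i,j}\);再把 \(f_{i-1,j}\) 加到下标 \(a_j\) 上。这样树状数组里只有位置在 \(j\) 之前的元素,而查询的又只是数值比 \(a_j\) 小的部分,就完美地解决了大小关系的限制。每一层的复杂度降为 \(O(n\log n)\)。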

#include <bits/stdc++.h>
// Competitive-programming shorthands.
// NOTE: the `register` storage class was removed in C++17 (clang rejects it by
// default), so the counter-declaring macros below use plain `int`.
#define rei int
#define LL long long
#define IOS ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
#define cvar int n, m, T;
// rep(i, s, n, c): i = s, s+c, s+2c, ... while i <= n.
// Bounds and step are parenthesised so expressions like `x & mask` bind as a whole.
#define rep(i, s, n, c) for (int i = (s); i <= (n); i += (c))
// repd(i, s, n, c): i = s, s-c, s-2c, ... while i >= n.
#define repd(i, s, n, c) for (int i = (s); i >= (n); i -= (c))
// Debug marker: prints "WALKED" and flushes.
#define CHECK cout<<"WALKED"<<endl;
// Reads the next integer from stdin via getchar().
// Skips any non-digit characters first; every '-' seen while skipping makes the
// result negative. Returns 0 when end of input is reached before any digit.
// `ch` must be an int so that EOF stays distinguishable from real characters.
// Otherwise the skip loop never terminates at end of input.
inline int read() {
	int x = 0, f = 1;
	int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9')) {
		if (ch == '-') f = -1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * f;
}
#define pb push_back
// Left/right child of node `id` in a 1-based heap-style layout (e.g. a segment tree).
// Parenthesised so that `ls + 1` or `rs * 2` keep the intended precedence.
#define ls (id<<1)
#define rs (id<<1|1)
const int INF = INT_MAX;
// Computes a^b mod `mod` by binary exponentiation, returning a value in [0, mod).
// Preconditions: mod > 0 and (mod - 1)^2 fits in long long; b < 0 is treated as b == 0.
// The base is reduced (and made non-negative) first, so a*a cannot overflow even for
// large or negative a, and `1 % mod` gives the correct result 0 when mod == 1.
long long binpow(long long a, long long b, long long mod) {
	a %= mod;
	if (a < 0) a += mod;
	long long res = 1 % mod;
	while (b > 0) {
		if (b & 1) res = res * a % mod;
		a = a * a % mod;
		b >>= 1;
	}
	return res;
}
using namespace std;

const int maxn = 30010;
// n: sequence length; m: number of distinct values after coordinate compression.
int n, m;

// ---- Fenwick tree (binary indexed tree) ----
// c[x] holds the sum over positions (x - lowbit(x), x]. One tree is enough:
// main() clears it with memset before every DP layer, and positions run 1..n,
// so an array of maxn entries suffices.
long long c[maxn];

// Lowest set bit of x.
inline int lowbit(long long x) {
	return x & -x;
}

// Sum of positions 1..x. Any x <= 0 yields 0 instead of walking out of bounds.
inline long long query(int x) {
	long long sum = 0;
	for (; x > 0; x -= lowbit(x))
		sum += c[x];
	return sum;
}

// Adds val at position x.
// Positions are bounded by n. After compression every rank is at most m, and
// m <= n, so bounding by n is safe; bounding by m would also work.
// x <= 0 is ignored: lowbit(0) == 0, so the loop would never advance.
inline void add(int x, long long val) {
	if (x <= 0)
		return;
	for (; x <= n; x += lowbit(x))
		c[x] += val;
}


LL a[maxn];
LL f[4][maxn]; // 长度i,a[j]为结尾的个数

// 离散化
LL s[maxn];
void disp()
{
	stable_sort(s + 1, s + n + 1);
	m = unique(s + 1, s + n + 1) - s - 1;
	rep (i, 1, n, 1)
		a[i] = lower_bound(s + 1, s + m + 1, a[i]) - s;
	
}
//

int main()
{
	n = read();
	rep(i, 1, n, 1) {
		a[i] = read();
		s[i] = a[i];
		f[1][i] = 1;
	}
	
	disp();
	
	rep (i, 2, 3, 1) {
		memset(c, 0, sizeof(c));
		rep (j, 1, n, 1) {
			f[i][j] = query(a[j] - 1);
			add(a[j], f[i - 1][j]);
		}
	}
	LL ans = 0;
	rep (i, 1, n, 1) ans += f[3][i];
	cout << ans << endl;
	return 0;
}

这样这道题就能通过了,总复杂度为 \(O(n\log n)\)。

显然,这个方法可以直接扩展到 M 元上升子序列:把 DP 的层数从 3 改为 M 即可,总复杂度为 \(O(Mn\log n)\)。下面这道例题还要求答案对 \(10^9+7\) 取模。

UVA12983 The Battle of Chibi

#include <bits/stdc++.h>
// Competitive-programming shorthands.
// NOTE: the `register` storage class was removed in C++17 (clang rejects it by
// default), so the counter-declaring macros below use plain `int`.
#define rei int
#define LL long long
#define IOS ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
#define cvar int n, m, T;
// rep(i, s, n, c): i = s, s+c, s+2c, ... while i <= n.
// Bounds and step are parenthesised so expressions like `x & mask` bind as a whole.
#define rep(i, s, n, c) for (int i = (s); i <= (n); i += (c))
// repd(i, s, n, c): i = s, s-c, s-2c, ... while i >= n.
#define repd(i, s, n, c) for (int i = (s); i >= (n); i -= (c))
// Debug marker: prints "WALKED" and flushes.
#define CHECK cout<<"WALKED"<<endl;
// Reads the next integer from stdin via getchar().
// Skips any non-digit characters first; every '-' seen while skipping makes the
// result negative. Returns 0 when end of input is reached before any digit.
// `ch` must be an int so that EOF stays distinguishable from real characters.
// Otherwise the skip loop never terminates at end of input.
inline int read() {
	int x = 0, f = 1;
	int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9')) {
		if (ch == '-') f = -1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * f;
}
#define pb push_back
// Left/right child of node `id` in a 1-based heap-style layout (e.g. a segment tree).
// Parenthesised so that `ls + 1` or `rs * 2` keep the intended precedence.
#define ls (id<<1)
#define rs (id<<1|1)
const int INF = INT_MAX;
// Computes a^b mod `mod` by binary exponentiation, returning a value in [0, mod).
// Preconditions: mod > 0 and (mod - 1)^2 fits in long long; b < 0 is treated as b == 0.
// The base is reduced (and made non-negative) first, so a*a cannot overflow even for
// large or negative a, and `1 % mod` gives the correct result 0 when mod == 1.
long long binpow(long long a, long long b, long long mod) {
	a %= mod;
	if (a < 0) a += mod;
	long long res = 1 % mod;
	while (b > 0) {
		if (b & 1) res = res * a % mod;
		a = a * a % mod;
		b >>= 1;
	}
	return res;
}
using namespace std;

const int maxn = 1010;
// Modulus for all counts.
const long long mod = 1e9 + 7;
// n: sequence length; m: number of distinct values after coordinate compression.
int n, m;

// ---- Fenwick tree (binary indexed tree), values kept modulo `mod` ----
// c[x] holds the sum over positions (x - lowbit(x), x]. One tree is enough:
// main() clears it with memset before every DP layer.
long long c[maxn];

// Lowest set bit of x.
inline int lowbit(long long x) {
	return x & -x;
}

// Sum of positions 1..x, reduced modulo `mod`.
// Any x <= 0 yields 0 instead of walking out of bounds.
inline long long query(int x) {
	long long sum = 0;
	for (; x > 0; x -= lowbit(x))
		sum = (sum + c[x]) % mod;
	return sum;
}

// Adds val at position x, keeping entries reduced modulo `mod`.
// Positions are bounded by m because compressed ranks lie in 1..m.
// x <= 0 is ignored: lowbit(0) == 0, so the loop would never advance.
inline void add(int x, long long val) {
	if (x <= 0)
		return;
	for (; x <= m; x += lowbit(x))
		c[x] = (c[x] + val) % mod;
}


LL a[maxn];
LL f[maxn][maxn]; // 长度i,a[j]为结尾的个数

// 离散化
LL s[maxn];
void disp()
{
	stable_sort(s + 1, s + n + 1);
	m = unique(s + 1, s + n + 1) - s - 1;
	rep (i, 1, n, 1)
		a[i] = lower_bound(s + 1, s + m + 1, a[i]) - s;
	
}
//

// UVA12983: for each test case, count strictly increasing subsequences of
// length kkk (the problem's M) modulo `mod`, using one Fenwick-tree pass per
// length.
int main()
{
	int T = read(), cnt = 1;
	int kkk;
	while (T--)
	{
		n = read(); kkk = read();
		rep(i, 1, n, 1) {
			a[i] = read();
			s[i] = a[i];
			f[1][i] = 1; // every element alone has length 1
		}

		LL ans = 0;
		// No increasing subsequence can be longer than the sequence itself, so
		// kkk > n (or a non-positive kkk) means the answer is 0. Skipping the DP
		// there also keeps f[kkk] inside its bounds.
		if (kkk >= 1 && kkk <= n) {
			disp(); // a[] now holds ranks 1..m
			rep (i, 2, kkk, 1) {
				memset(c, 0, sizeof(c)); // fresh tree for this layer
				rep (j, 1, n, 1) {
					// Query before inserting j: only positions < j are in the tree.
					f[i][j] = query(a[j] - 1);
					add(a[j], f[i - 1][j]);
				}
			}
			rep (i, 1, n, 1) ans = (ans + f[kkk][i]) % mod;
		}
		printf("Case #%d: %lld\n", cnt, ans);
		cnt++;
	}
	return 0;
}
posted @ 2022-10-29 20:13  Vegdie  阅读(41)  评论(0编辑  收藏  举报