【atc abc159F】F - Knapsack for All Segments(dp优化)

传送门

题意:
给定序列\(a_1,a_2,...,a_n\)\(s\),定义\(f(L,R):\)

  • \((x_1,x_2,...,x_k)\)的对数且满足\(L\leq x_1<x_2<...<x_k\leq R,a_{x_1}+a_{x_2}+\cdots+a_{x_k}=s\)

现在要求\(\sum f(L,R),L\leq R\)
\(n,s,a_i\leq 3000\)

思路:

  • 我们很容易想到\(O(n^3)\)\(dp:dp[l,r,x]\)表示序列左右端点为\(l,r\),序列和为\(x\)的方案数。因为转移我们可以直接\(O(1)\)进行转移,所以时间复杂度为\(O(n^3)\)。但是显然时间空间都不能承受。
  • 观察性质:我们最直接的想法肯定是固定左右端点,假设为\(l,r\),那么方案数为\(l\cdot(n-r+1)\),现在假设固定右端点,左端点在进行改变,那么最终\(ans_r=\sum l_i\cdot(n-r+1)\)
  • 现在考虑给\(dp\)将维,我们考虑去掉一个左端点(貌似两个端点一起太浪费),那么我们考虑不直接记录方案数,而是记录左端点的和,之后可以直接通过这个和来计算方案数。
  • 所以优化过后的\(dp\)为:\(dp[r,x]\)表示固定右端点为\(r\),序列和为\(x\),左端点的和,那么这个转移为\(\displaystyle dp[r,x]=\sum_{k<r}dp[k,x-a_i]\)。这个\(dp\)看似也为\(O(n^3)\),但其实第一维是一个标准的前缀和形式,所以我们可以优化掉那一层枚举。
  • 总的时间复杂度为\(O(n^2)\)

这是一个挺有意思的\(dp\),有些时候\(dp\)不一定直接记录答案,我们可以记录一些可以直接计算出答案的量,这样可能在某些时候能够优化时间/空间或者方便转移。我记得之前cf有一道题也是这样。
代码如下:

/*
 * Author:  heyuhhh
 * Created Time:  2020/5/16 16:47:05
 */
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <vector>
#include <cmath>
#include <set>
#include <map>
#include <queue>
#include <iomanip>
#include <assert.h>
#define MP make_pair
#define fi first
#define se second
#define pb push_back
#define sz(x) (int)(x).size()
#define all(x) (x).begin(), (x).end()
#define INF 0x3f3f3f3f
#define Local
#ifdef Local
  #define dbg(args...) do { cout << #args << " -> "; err(args); } while (0)
  void err() { std::cout << std::endl; }
  template<typename T, typename...Args>
  void err(T a, Args...args) { std::cout << a << ' '; err(args...); }
  template <template<typename...> class T, typename t, typename... A> 
  void err(const T <t> &arg, const A&... args) {
  for (auto &v : arg) std::cout << v << ' '; err(args...); }
#else
  #define dbg(...)
#endif
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
//head
const int N = 3000 + 5, MOD = 998244353;

int n, s;
int a[N];
int dp[N][N], sum[N];

void run() {
    cin >> n >> s;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    for (int i = 1; i <= n; i++) {
        dp[i][a[i]] = i;
        for (int j = a[i]; j <= s; j++) {
            dp[i][j] += sum[j - a[i]];
            if (dp[i][j] >= MOD) dp[i][j] -= MOD;
        }
        for (int j = 0; j <= s; j++) {
            sum[j] += dp[i][j];
            if (sum[j] >= MOD) sum[j] -= MOD;
        }
    }
    int ans = 0;
    for (int i = 1; i <= n; i++) {
        ans += 1ll * dp[i][s] * (n - i + 1) % MOD;
        ans %= MOD;
    }
    cout << ans << '\n';
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cout << fixed << setprecision(20);
    run();
    return 0;
}

posted @ 2020-05-18 17:35  heyuhhh  阅读(259)  评论(0编辑  收藏  举报