高维前缀和总结(sosdp)

前言

今天中午不知怎么的对这个东西产生了兴趣,感觉很神奇,结果花了一个中午多的时间来看QAQ
下面说下自己的理解。

高维前缀和一般解决这类问题:

对于所有的\(i,0\leq i\leq 2^n-1\),求解\(\sum_{j\subset i}a_j\)

显然,这类问题可以直接枚举子集求解,但复杂度为\(O(3^n)\)。如果我们施展高维前缀和的话,复杂度可以到\(O(n\cdot 2^n)\)

说起来很高级,其实代码就三行:

for(int j = 0; j < n; j++) 
    for(int i = 0; i < 1 << n; i++)
        if(i >> j & 1) f[i] += f[i ^ (1 << j)];

相信大家一开始学的时候就感觉很神奇,这是什么东西,这跟前缀和有什么关系?
好吧,其实看到后面就知道了。

正文

二维前缀和

一维前缀和就不说了,一般我们求二维前缀和时是直接容斥来求的:

\[sum_{i,j}=sum_{i-1,j}+sum_{i,j-1}-sum_{i-1,j-1}+a_{i,j} \]

但还有一种求法,就是一维一维来求,也可以得到二维前缀和:

for(int i = 1; i <= n; i++)
    for(int j = 1; j <= n; j++)
        a[i][j] += a[i - 1][j];
for(int i = 1; i <= n; i++)
    for(int j = 1; j <= n; j++)
        a[i][j] += a[i][j - 1];

模拟一下就很清晰了。

三维前缀和

同二位前缀和,我们也可以对每一维分别来求:

for(int i = 1; i <= n; i++)
    for(int j = 1; j <= n; j++)
        for(int k = 1; k <= n; k++) 
            a[i][j][k] += a[i - 1][j][k];
for(int i = 1; i <= n; i++)
    for(int j = 1; j <= n; j++)
        for(int k = 1; k <= n; k++)
            a[i][j][k] += a[i][j - 1][k];
for(int i = 1; i <= n; i++)
    for(int j = 1; j <= n; j++)
        for(int k = 1; k <= n; k++)
            a[i][j][k] += a[i][j][k - 1];

高维前缀和

接下来就步入正题啦。
求解高维前缀和的核心思想也就是一维一维来处理,可以类比二维前缀和的求法稍微模拟一下。
具体来说代码中的\(f[i] = f[i] + f[i\ xor\ (1 << j)]\),因为我们是正序枚举,所以\(i\ xor\ (1 << j)\)在当前层,而\(i\)还在上层,所以我们将两个合并一下就能求出当前层的前缀和了QAQ。
然后...就完了,好像没什么好说的。

应用

  • 子集

那这跟子集有啥关系?在二进制表示中,发现当\(i\subset j\)时,其实这存在一个偏序关系,对于每一位都是这样。而我们求出的前缀和就是满足这个偏序关系的。
回到开始那个问题,初始化\(f[i]=a_i\),直接求高维前缀和,那么最终得到的\(f\)就是答案数组了。

  • 超集

理解了子集过后,我们将二进制中的每一个\(1\)当作\(0\)对待,\(0\)当作\(1\)对待求出来的就是超集了~相当于从另一个角出发来求前缀和。
求超集代码如下:

for(int j = 0; j < n; j++) 
    for(int i = 0; i < 1 << n; i++)
        if(!(i >> j & 1)) f[i] += f[i ^ (1 << j)];

似乎\(FMT\)(快速莫比乌斯变换)就是借助高维前缀和这个东西来实现的。
虽然只有三行代码,但很神奇QAQ

upd:
这个东西其实和\(sos\ dp\)是一个东西,但感觉用\(dp\)的思想去理解要稍微好一些,就再说一下\(dp\)的想法。

  • 子集

我们还是来求子集,定义\(dp_{i,mask}\)为处理了状态为\(mask\),二进制最后\(i\)位的子集信息时的和。
那么我们枚举\(i+1\)位时,若当前\(mask\)这一位为\(1\),那么就从\(dp_{i,mask},dp_{i,mask-(1<<i)}\)转移过来,分别代表有当前这一位时的子集或者没这一位时的子集,合并一下即可;若当前这位不为\(1\),就从\(dp_{i,mask}\)转移过来。
最后在代码中我们一般习惯滚动掉一维。

  • 超集

类似地,定义\(dp_{i,mask}\)为当前状态为\(mask\),处理了后\(i\)位的超集信息时的和。
然后枚举第\(i+1\)位,若当前这一位为\(0\),就从\(dp_{i,mask},dp_{i,mask+(1<<(i+1))}\)转移;若当前这位为\(1\),就直接从\(dp_{i,mask}\)转移过来。

例题

arc 100E

题意:
给出\(2^n\)个数:\(a_0,a_1,\cdots,a_{2^n-1}\)
之后对于\(1\leq k\leq 2^n-1\),求出:\(a_i+a_j\)的最大值,同时\(i\ or\ j\leq k\)

思路:
挺奇妙的一个题,需要将问题转换。

  • 发现我们可以对每个\(k\),求出最大的\(a_i+a_j\)并且满足\(i\ or\ j=k\),最后答案就为一个前缀最大值。
  • 但这种形式也不好处理,我们可以将问题进一步转化为\(i\ or\ j\subset k\)。那么我们就将问题转化为了子集问题。
  • 所以接下来就对于每个\(k\),求出其所有子集的最大值和次大值就行了。
  • 直接枚举子集复杂度显然不能忍受,其实直接上高位前缀和搞一下就行~

注意一下细节,集合中一开始有一个数。
代码如下:

#include <bits/stdc++.h>
#define MP make_pair
#define fi first
#define se second
#define sz(x) (int)(x).size()
#define INF 0x3f3f3f3f3f
//#define Local
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
const int N = 20;

int n;
pii a[1 << N];

pii merge(pii A, pii B) {
    if(A.fi < B.fi) swap(A, B);
    pii ans = A;
    if(B.fi > ans.se) ans.se = B.fi;
    return ans;
}

void run() {
    for(int i = 0; i < 1 << n; i++) {
        int x; cin >> x;
        a[i] = MP(x, -INF);
    }
    for(int j = 0; j < n; j++) {
        for(int i = 0; i < 1 << n; i++) {
            if(i >> j & 1) a[i] = merge(a[i], a[i ^ (1 << j)]);
        }
    }
    int ans = 0;
    for(int i = 1; i < 1 << n; i++) {
        ans = max(ans, a[i].fi + a[i].se);
        cout << ans << '\n';
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cout << fixed << setprecision(20);
#ifdef Local
    freopen("../input.in", "r", stdin);
    freopen("../output.out", "w", stdout);
#endif
    while(cin >> n) run();
    return 0;
}

cf1208F
题意:
给出序列\(a_{1,2\cdots,n},n\leq 10^6\)
现在要找最大的\(a_i|(a_j\& a_k)\),其中\((i,j,k)\)满足\(i<j<k\)

思路:

  • 显然我们可以枚举\(a_i\),那么问题就转换为如何快速找\(a_j\& a_k\)
  • 因为最后要使得结果最大,我们二进制从高到底枚举时肯定是贪心来考虑的:即如果有两个数他们的与在这一位为\(1\),那么最后的答案中一定有这一位。
  • 那么我们逐位考虑,并且考虑是否有两个在右边的数他们"与"的结果为当前答案的超集即可,有的话答案直接加上这一位。
  • 那么可以用\(sos\ dp\)处理超集的信息,维护在最右端的两个位置,之后贪心来处理即可。

代码如下:

/*
 * Author:  heyuhhh
 * Created Time:  2020/2/27 10:51:39
 */
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#include <cmath>
#include <set>
#include <map>
#include <queue>
#include <iomanip>
#define MP make_pair
#define fi first
#define se second
#define sz(x) (int)(x).size()
#define all(x) (x).begin(), (x).end()
#define INF 0x3f3f3f3f
#define Local
#ifdef Local
  #define dbg(args...) do { cout << #args << " -> "; err(args); } while (0)
  void err() { std::cout << '\n'; }
  template<typename T, typename...Args>
  void err(T a, Args...args) { std::cout << a << ' '; err(args...); }
#else
  #define dbg(...)
#endif
void pt() {std::cout << '\n'; }
template<typename T, typename...Args>
void pt(T a, Args...args) {std::cout << a << ' '; pt(args...); }
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
//head
const int N = 2e6 + 5;
 
int n;
int a[N];
pii dp[N];
 
void add(int x, int id) {
    if(dp[x].fi == -1) {
        dp[x].fi = id;   
    } else if(dp[x].se == -1) {
        if(dp[x].fi == id) return;
        dp[x].se = id;   
        if(dp[x].fi < dp[x].se) swap(dp[x].fi, dp[x].se);
    } else if(dp[x].fi < id) {
        dp[x].se = dp[x].fi;
        dp[x].fi = id;   
    } else if(dp[x].se < id) {
        if(dp[x].fi == id) return;
        dp[x].se = id;
    }
}
 
void merge(int x1, int x2) {
    add(x1, dp[x2].fi);
    add(x1, dp[x2].se);
}
 
void run() {
    memset(dp, -1, sizeof(dp));
    cin >> n;
    for(int i = 1; i <= n; i++) {
        cin >> a[i];
        add(a[i], i);
    }
    for(int i = 0; i < 21; i++) {
        for(int j = 0; j < N; j++) if(j >> i & 1) {
            merge(j ^ (1 << i), j);
        }
    }
    int ans = 0;
    for(int i = 1; i <= n - 2; i++) {
        int lim = (1 << 21) - 1;
        int cur = a[i] ^ lim, res = 0;
        for(int j = 20; j >= 0; j--) if(cur >> j & 1) {
            if(dp[res ^ (1 << j)].se > i) {
                res ^= (1 << j);   
            }
        }
        ans = max(ans, res | a[i]);
    }
    cout << ans << '\n';
}
 
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cout << fixed << setprecision(20);
    run();
    return 0;
}
posted @ 2019-09-25 16:06  heyuhhh  阅读(12214)  评论(5编辑  收藏  举报