线段树优化 DP

线段树优化 DP

UPDATE:

upd on 2024.8.22:因为又更新了一题,所以文章标题就改力!

upd on 2024.8.27:终于把一直想补的 NOIP 2023 T4 补完力!


CF833B The Bakery

题目大意:

将一个长度为 n 的序列分为 m 段,使得总价值最大。

一段区间的价值表示为区间内不同数字的个数。

n35000,m50

(虽然原题是 k,但是我代码中写的是 m,所以就改成 m 了)


首先看到划分区间,算总价值最大的题,可以先考虑朴素的区间 DP。

dp[i][j] 表示前 j 个数划分为 i 段的最大总价值。val(l,r) 表示区间 [l,r] 的价值,即其中有多少个不同的数。

可得转移方程:

dp[i][j]=maxi1kj1{dp[i1][k]+val(k+1,j)}

i1kj1 的原因是最少 i1 个数字,最多 j1 个数字可划分为 i1 段。

我们发现对于每次 dp[i][j] 转移,只需用到 dp[i1][j],所以 dp 数组可以滚动掉第一维,但是没用。

时间复杂度为 O(n3k),显然超时。

考虑优化:

优化 1:发现每次计算 val(l,r) 需要 O(n) 的时间复杂度,其中每个数都会经过很多次重复计算。所以我们反过来考虑,对于每个数,它对于 val(l,r) 的贡献在哪里。

记这个数 a[i] 前面的第一个与它相等的数的位置为 pre[a[i]]
那么这个 a[i] 对于区间 [pre[a[i]]+1,i] 均有 1 的贡献。

优化 2:发现转移方程里有 max(),并且没有其它量(其实有已经确定的量也行),所以考虑用线段树来优化。

因为每次由第 i1 层转移到第 i 层,所以我们顺序 DP,先用上一次的 i1 的 DP 值建树。

然后每扫过一个数 a[j],根据上文,它会影响到 k[pre[a[j]]+1,j]的 DP 值,所以就在线段树上的区间 [pre[a[j]],j1] 全部加上 1
更新时的左右端点各减了 1 的原因是转移方程的 valk 加了 1。

每次更新,就在线段树上 [i1,j1] 的区间找 max 即可。

总时间复杂度为 O(nklogn)


朴素 DP 代码:

Code
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 105;

int a[N];
int dp[N][N]; // dp[i][j] 表示前 j 个数划分为 i 个连续区间的最大总价值
bool tong[N];

int calc(int l, int r){
    memset(tong, 0, sizeof(tong));
    int res = 0;
    for(int i=l; i<=r; i++){
        if(tong[a[i]]) continue;
        tong[a[i]] = 1;
        res++;
    }
    return res;
}

int main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    int n, t; cin>>n>>t;
    for(int i=1; i<=n; i++)
        cin>>a[i];
    for(int i=1; i<=t; i++){
        for(int j=1; j<=n; j++){
            for(int k=i-1; k<=j-1; k++){
                dp[i][j] = max(dp[i][j], dp[i-1][k]+calc(k+1, j));
            }
        }
    }
    cout<<dp[t][n];
    return 0;
}

正解代码:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 35005;

int a[N], pre[N], p[N];
int dp[55][N]; // dp[i][j] 表示前 j 个数划分为 i 个连续区间的最大总价值
struct node{
    int l, r;
    int maxn, tag;
    #define ls (x<<1)
    #define rs (x<<1|1)
}tr[N<<2];

void pushup(int x){
    tr[x].maxn = max(tr[ls].maxn, tr[rs].maxn);
}

void pushdown(int x){
    if(!tr[x].tag) return;
    tr[ls].maxn += tr[x].tag;
    tr[rs].maxn += tr[x].tag;
    tr[ls].tag += tr[x].tag;
    tr[rs].tag += tr[x].tag;
    tr[x].tag = 0;
}

void build(int x, int l, int r, int k){
    tr[x].l = l, tr[x].r = r, tr[x].maxn = 0; tr[x].tag = 0;
    if(l == r){
        tr[x].maxn = dp[k][l];
        return;
    }
    int mid = (l+r)>>1;
    build(ls, l, mid, k);
    build(rs, mid+1, r, k);
    pushup(x);
}

void update(int x, int l, int r, int v){
    if(tr[x].l>=l && tr[x].r<=r){
        tr[x].maxn += v;
        tr[x].tag += v;
        return;
    }
    int mid = (tr[x].l+tr[x].r)>>1;
    pushdown(x);
    if(l<=mid) update(ls, l, r, v);
    if(r>mid) update(rs, l, r, v);
    pushup(x);
}

int query(int x, int l, int r){
    if(tr[x].l>=l && tr[x].r<=r){
        return tr[x].maxn;
    }
    int mid = (tr[x].l+tr[x].r)>>1, ret = 0;
    pushdown(x);
    if(l<=mid) ret = max(ret, query(ls, l, r));
    if(r>mid) ret = max(ret, query(rs, l, r));
    return ret;
}

int main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    int n, m; cin>>n>>m;
    for(int i=1; i<=n; i++){
        cin>>a[i];
        pre[i] = p[a[i]];
        p[a[i]] = i;
    }
    for(int i=1; i<=m; i++){
        build(1, 0, n-1, i-1); // 求的是由前 [i-1, j-1] -> [0, n-1] 个数字已经划分为 i-1 个连续区间
        for(int j=1; j<=n; j++){
            update(1, pre[j], j-1, 1);
            dp[i][j] = query(1, i-1, j-1);
        }
    }
    cout<<dp[m][n];
    return 0;
}

CF474E Pillars

题目大意:

给出一个长度为 n 的序列 a 和一个参数 d,要求求出 a 中一个最长的子序列 b,满足 b 中任意相邻元素的差大于等于 d


题目形似最长上升子序列,我们可以先考虑一个朴素 DP:

dp[i] 表示以 a 中第 i 个数(下标为 i)结尾的最长子序列长度。

可得转移方程:

dp[i]=maxj<i|aiaj|d{dp[j]}+1

时间复杂度为 O(n2)

考虑如何用线段树优化:

发现能转移过来的值域其实只有两个区间(比 a[i] 小至少 d,比 a[i] 多至少 d)。所以只要先找到下标最大的 l,满足 a[l]a[i]d,再找到下标最小的 r,满足 a[r]a[i]+d,求 max(max(1,l),max(r,n)) 即可。可以用权值线段树维护,每次转移完 i 就把 dp[i] 放到线段树的 a[i] 位置即可。最后 DP 更新的时候记录从哪个元素转移过来的即可。

a[i] 的值域很大,需要离散化或者动态开点。

有个小细节:需要处理掉不存在 lr 的情况。


朴素 DP 代码:

Code
#include<bits/stdc++.h>
using namespace std;
#define DEBUG(a) cout<<"Dline[ "<<__LINE__<<" ]: "<<(a)<<"\n";
#define ll long long
const int N = 1e5+5;

ll a[N];

ll pre[N], dp[N]; // dp[i] 表示以第 i 个数结尾的最长子序列长度
ll ans, pos;

void print(int x){
    if(pre[x]) print(pre[x]);
    cout<<x<<" ";
}

int main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    int n, d; cin>>n>>d;
    for(int i=1; i<=n; i++) 
        cin>>a[i];
    dp[1] = 1;
    for(int i=2; i<=n; i++){
        for(int j=1; j<i; j++){
            if(abs(a[i]-a[j])>=d){
                if(dp[j]+1 > dp[i]){
                    dp[i] = dp[j]+1;
                    pre[i] = j;
                }
            }
        }
        if(dp[i] > ans){
            ans = dp[i];
            pos = i;
        }
    }
    cout<<ans<<"\n";
    print(pos);
    return 0;
}

正解代码:

#include<bits/stdc++.h>
using namespace std;
#define DEBUG(a) cout<<"Dline[ "<<__LINE__<<" ]: "<<(a)<<"\n";
#define ll long long
#define pii pair<int, int>
#define val first
#define id second
#define ls x<<1
#define rs x<<1|1

const int N = 1e5+5;

ll a[N], b[N];
ll pre[N], dp[N]; // dp[i] 表示以第 i 个数结尾的最长子序列长度
pii mx[N<<2];
ll ans, pos;

void update(int x, int l, int r, int k, pii v){
    if(l == r){
        mx[x] = max(mx[x], v);
        return;
    }
    int mid = (l+r)>>1;
    if(k<=mid) update(ls, l, mid, k, v);
    else update(rs, mid+1, r, k, v);
    mx[x] = max(mx[ls], mx[rs]);
}

pii query(int x, int l, int r, int ql, int qr){
    if(ql > qr) return {-1, 0}; 
    if(l>=ql && r<=qr) return mx[x];
    int mid = (l+r)>>1;
    pii ret = {-1, 0};
    if(ql<=mid) ret = max(ret, query(ls, l, mid, ql, qr));
    if(qr>mid) ret = max(ret, query(rs, mid+1, r, ql, qr));
    return ret;
}

void print(int x){
    if(pre[x]) print(pre[x]);
    cout<<x<<" ";
}

int main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    int n, d; cin>>n>>d;
    for(int i=1; i<=n; i++){
        cin>>a[i];
        b[i] = a[i];
    }
    sort(b+1, b+1+n);
    int M = unique(b+1, b+1+n) - (b+1);
    for(int i=1; i<=n; i++){
        int x = lower_bound(b+1, b+1+M, a[i]) - b;
        int l = upper_bound(b+1, b+1+M, a[i]-d) - (b+1); // 1<=j<=l
        int r = lower_bound(b+1, b+1+M, a[i]+d) - b; // r<=j<=n
        pii res = max({{0, 0}, query(1, 1, M, 1, l), query(1, 1, M, r, M)});
        dp[i] = res.val+1, pre[i] = res.id;
        update(1, 1, M, x, {dp[i], i});
        if(dp[i] > ans){
            ans = dp[i];
            pos = i;
        }
    }
    cout<<ans<<"\n";
    print(pos);
    return 0;
}

CF597C Subsequences

题目大意:

给定一个 1n 的排列 a,求 a 中长度为 m+1 的上升子序列个数。


先让 mm+1,以下内容中的 m 均为更新后的值。

依旧先考虑朴素 DP:

dp[i][j] 表示以 i 结尾,子序列长度恰好为 j 的满足条件的子序列个数。

base case:dp[i][1]=1

j=2 开始枚举,可得转移方程:

dp[i][j]=k=1i1dp[k][j1](a[k]<a[i])

时间复杂度为 O(n2m)

考虑线段树(本题也可以使用树状数组)优化:

观察到瓶颈在于每个 i 都要枚举一次 k<i,可以考虑对于每个 j 构建一颗值域线段树,从前往后枚举 i 时,在 a[i] 的位置上加上上一次的 dp 值,最后加上线段树中 [1,a[i]1] 的和即可。

最后统计答案:

ans=i=mndp[i][m]

时间复杂度为 O(nmlogn)


朴素 DP 代码:

Code
#include<bits/stdc++.h>
using namespace std;
#define DEBUG(a) cout<<"Dline[ "<<__LINE__<<" ]: "<<(a)<<"\n";
#define ll long long

ll dp[100005][12];
int a[100005];

int main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    int n, m; cin>>n>>m;
    m++;
    for(int i=1; i<=n; i++)
        cin>>a[i];
    for(int i=1; i<=n; i++)
        dp[i][1] = 1;
    for(int j=2; j<=m; j++)
        for(int i=1; i<=n; i++)
            for(int k=1; k<i; k++)
                if(a[k] < a[i]) dp[i][j] += dp[k][j-1];
    ll ans = 0;
    for(int i=m; i<=n; i++)
        ans += dp[i][m];
    cout<<ans;
    return 0;
}

正解代码:

#include<bits/stdc++.h>
using namespace std;
#define DEBUG(a) cout<<"Dline[ "<<__LINE__<<" ]: "<<(a)<<"\n";
#define ll long long
#define ls (x<<1)
#define rs (x<<1|1)
const int N = 100005;

ll dp[N][12];
int a[N];
ll trsum[N<<2];

void build(int x, int l, int r, int k){
    trsum[x] = 0;
    if(l == r) return;
    int mid = (l+r)>>1;
    build(ls, l, mid, k);
    build(rs, mid+1, r, k);
}

void update(int x, int l, int r, int q, ll v){
    if(l == r){
        trsum[x] += v;
        return; 
    }
    int mid = (l+r)>>1;
    if(q<=mid) update(ls, l, mid, q, v);
    else update(rs, mid+1, r, q, v);
    trsum[x] = trsum[ls] + trsum[rs];
}

ll query(int x, int l, int r, int ql, int qr){
    if(ql > qr) return 0;
    if(l>=ql && r<=qr){
        return trsum[x];
    }
    int mid = (l+r)>>1; ll res = 0;
    if(ql<=mid) res += query(ls, l, mid, ql, qr);
    if(qr>mid) res += query(rs, mid+1, r, ql, qr);
    return res;
}

int main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    int n, m; cin>>n>>m;
    m++;
    for(int i=1; i<=n; i++)
        cin>>a[i];
    for(int i=1; i<=n; i++)
        dp[i][1] = 1;
    for(int j=2; j<=m; j++){
        build(1, 1, n, j-1);
        for(int i=1; i<=n; i++){
            dp[i][j] = query(1, 1, n, 1, a[i]-1);
            update(1, 1, n, a[i], dp[i][j-1]);
        }
    }
    ll ans = 0;
    for(int i=m; i<=n; i++)
        ans += dp[i][m];
    cout<<ans;
    return 0;
}

P9871 [NOIP2023] 天天爱打卡

题目太长自己看。

很早之前 (刚考完) 就想补这道题了,虽然考过了等于 24 年不考,但是在考场上连“板子”都没写出来的我就是菜!(话说好像没有逻辑关系)


先考虑不同于正解的朴素 DP:(正解可以从下一个分割线开始看)

可以设 dp[i][j] 表示到第 i 天连续打卡了 j 天的最大值,可得转移方程:

dp[i][0]=max(dp[i][0],dp[i1][j])(0jmin(i,k))

dp[i][j]=dp[i1][j1]+s[i][j]d(1jmin(i,k))

其中 s[i][j] 表示到第 i 天连续打卡了 j 天可获得的任务奖励之和,可用前缀和优化。

时间复杂度为 O(nk),期望得分 36pts。

显然状态很多,最后内存肯定会超,于是考虑换一种状态定义。


基于正解的朴素 DP:

dp[i] 表示第 i打卡的能获得的最大能量,显然从上一天不打卡的位置 j 转移而来,可得转移方程:

dp[i]=maxij1k{dp[j]+w(j+1,i1)(ij1)d}

其中 w(l,r) 表示所有被 [l,r] 完全包含的任务的奖励之和。

时间复杂度为 O(n2k),期望得分 16pts。好像也可以前缀和优化,但是我没优化出来。

(另)性质 B 很水,其实是 1i<m,ri+1<li+1


考虑优化:

首先因为 n109,就算通过数据结构优化了转移,状态存储的空间肯定会爆。观察转移得到的 dp[i] 发现其中会有一堆连续段,这启示我们 dp 的有效决策点其实并不多。状态只会从不打卡的地方转移,那么对于每个任务 li,ri(即 xiyi+1,xi),决策点(不打卡的点)只有 li1,ri+1。于是我们可以直接离散化掉所有的 li1,ri+1,这样状态数可以优化到 2m

离散化之后的转移肯定是自左端点到右端点转移的,并且随右端点从小到大。考虑将所有线段发配到右端点 i 上,在之前 dp 的转移点的基础上更新以 i 为右端点的每个任务的价值。换句话说,我们可以扫描线地更新每条线段的贡献。每条线段对于区间的贡献即为 [1,li1],表示以 [1,li1] 中的一点为左端点,以 ri+1 为右端点的转移中可获得任务区间为 [li,ri] 的奖励。

观察转移方程中的几个限制:

  • 对于 k 的限制,只要在转移前让 limlim+1 直至 lsh[i]lsh[lim]1k 即可。

  • 对于 d 的限制,问题在于最后在线段树上返回的区间查询最大值不会告诉我们 j 的位置在哪(或许可以记录下来,但这样内存会变大)。可以考虑参变分离。把含 d 的与 i 有关的式子移到左边,可以得到:dp[i]+d×i=maxdp[j]+d×j+w(j+1,i1)d于是我们在线段树上只要维护 dp[i]+d×i 即可。最后得到 dp[i] 的答案要用区间查询返回的最大值减去 d×(i+1)

时间复杂度为 O(mlogm)我最后写完没有被卡常嘿嘿!


不同于正解的朴素 DP 代码:

Code
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e3+5, inf=1e18;

int c, T, n, m, k, d, f[N][N], s[N][N];
int x, y, v, ans;

signed main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    cin>>c>>T;
    while(T--){
        int x, y, v;
		cin>>n>>m>>k>>d;
		memset(s, 0, sizeof(s));
		for(int i=1; i<=m; i++)
			cin>>x>>y>>v, s[x][y]+=v;
		for(int i=1; i<=n; i++)
			for(int j=1; j<=n; j++)
				s[i][j] += s[i][j-1];
		for(int i=0; i<=n; i++)
			for(int j=1; j<=n; j++)
				f[i][j] = -inf;
		for(int i=1; i<=n; i++){
			f[i][0] = f[i-1][0];
			for(int j=1; j<=min(i,k); j++){
				f[i][j] = f[i-1][j-1]+s[i][j]-d;
				f[i][0] = max(f[i][0], f[i-1][j]);
            }
		}
        ans = 0;
		for(int i=0; i<=min(n,k); i++)
			ans = max(ans,f[n][i]);
		cout<<ans<<"\n";
	}
    return 0;
}

基于正解的朴素 DP + 性质 B 代码:

Code
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e5+5, INF = 1e18;

struct node{
    int l, r, v;
}ft[100005];

int n, m, k, d;
int dp[N];

int calc(int l, int r){
    int res = 0;
    if(l > r) return 0; 
    for(int i=1; i<=m; i++){
        if(l<=ft[i].l && ft[i].r<=r){
            res += ft[i].v;
        }
    }
    return res;
}

signed main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    int c, T; cin>>c>>T;
    while(T--){
        cin>>n>>m>>k>>d;
        for(int i=1; i<=m; i++){
            int x, y, v; cin>>x>>y>>v;
            ft[i] = {x-y+1, x, v};
        }
        if(n<=1e4){
            for(int i=1; i<=n+1; i++)   
                dp[i] = -INF;
            for(int i=1; i<=n+1; i++){
                for(int j=max(0ll, i-k-1); j<i; j++){
                    dp[i] = max(dp[i], dp[j]+calc(j+1, i-1)-(i-j-1)*d);
                }
            }
            cout<<dp[n+1]<<"\n";
        }
        else{
            int ans = 0;
            for(int i=1; i<=m; i++){
                if(ft[i].r-ft[i].l+1>k) continue;
                else if((ft[i].r-ft[i].l+1)*d>ft[i].v) continue;
                else{
                    ans += ft[i].v-(ft[i].r-ft[i].l+1)*d;
                }
            }
            cout<<ans<<"\n";
        }
    }
    return 0;
}

正解代码:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 2e5+5;
#define ls (x<<1)
#define rs (x<<1|1)

int n, m, k, tot; ll d;
int a[N], b[N], w[N], lsh[N];
ll dp[N], trmax[N<<2], trtag[N<<2];

void pushdown(int x){
    if(!trtag[x]) return;
    trmax[ls] += trtag[x]; trtag[ls] += trtag[x];
    trmax[rs] += trtag[x]; trtag[rs] += trtag[x];
    trtag[x] = 0;
}

void pushup(int x){
    trmax[x] = max(trmax[ls], trmax[rs]);
}

void update(int x, int l, int r, int ql, int qr, ll v){
    if(ql<=l && r<=qr){
        trmax[x] += v; trtag[x] += v;
        return;
    }
    pushdown(x);
    int mid = (l+r)>>1;
    if(ql<=mid) update(ls, l, mid, ql, qr, v);
    if(qr>mid) update(rs, mid+1, r, ql, qr, v);
    pushup(x);
}

ll query(int x, int l, int r, int ql, int qr){
    if(ql<=l && r<=qr)
        return trmax[x];
    pushdown(x);
    ll ret = -1e18;
    int mid = (l+r)>>1;
    if(ql<=mid) ret = max(ret, query(ls, l, mid, ql, qr));
    if(qr>mid) ret = max(ret, query(rs, mid+1, r, ql, qr));
    return ret;
}

void clear(){
    tot = 0;
    for(int i=0; i<=m*8; i++)
        trmax[i] = trtag[i] = 0;
}

int main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    int c, T; cin>>c>>T;
    while(T--){
        cin>>n>>m>>k>>d;
        clear();
        for(int i=1; i<=m; i++){
            int x, y; cin>>x>>y>>w[i];
            a[i] = x-y+1, b[i] = x;
            lsh[++tot] = a[i]-1;
            lsh[++tot] = b[i]+1;
        }
        sort(lsh+1, lsh+1+tot);
        tot = unique(lsh+1, lsh+1+tot) - (lsh+1);
        for(int i=1; i<=m; i++){
            a[i] = lower_bound(lsh+1, lsh+1+tot, a[i]-1) - lsh;
            b[i] = lower_bound(lsh+1, lsh+1+tot, b[i]+1) - lsh;
        }
        vector<vector<int>> qr(tot+1, vector<int>());
        for(int i=1; i<=m; i++)
            qr[b[i]].push_back(i);
        ll ans = 0, lim = 0;
        lsh[0] = 0;
        update(1, 1, tot, 1, 1, lsh[1]*d);
        for(int i=2; i<=tot; i++){
            for(int j : qr[i])
                update(1, 1, tot, 1, a[j], w[j]);
            while(lsh[i]-lsh[lim]-1>k) lim++;
            ans = max(ans, query(1, 1, tot, lim, i-1)+d-d*lsh[i]);
            update(1, 1, tot, i, i, ans+d*lsh[i]);
        }
        cout<<ans<<"\n";
    }
    return 0;
}
posted @   FlyPancake  阅读(103)  评论(2编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
// music
点击右上角即可分享
微信分享提示