BZOJ3747 [POI2015]Kinoman 【线段树】
题目链接
题解
这种找区间最优的问题,一定是枚举一个端点,然后用数据结构维护另一个端点
我们枚举左端点,用线段树维护每个点作为右端点时的答案
当左端点为\(1\)时,我们能\(O(n)\)预处理出每个位置的答案初始化线段树
当左端点右移一位时,该位上的电影就从区间删除了,记\(nxt[i]\)为下一个同类电影的位置,那么从左端点到\(nxt[i] - 1\)的位置的权值都会减少掉\(w[f[i]]\),而\(nxt[i]\)到\(nxt[nxt[i]] - 1\)的位置都会增加\(w[f[i]]\)
就是一个简单的线段树维护区间最大值了
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<map>
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define mp(a,b) make_pair<int,int>(a,b)
#define cls(s) memset(s,0,sizeof(s))
#define cp pair<int,int>
#define LL long long int
#define ls (u << 1)
#define rs (u << 1 | 1)
using namespace std;
const int maxn = 1000005,maxm = 100005,INF = 1000000000;
inline int read(){
int out = 0,flag = 1; char c = getchar();
while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();}
while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();}
return out * flag;
}
LL mx[maxn << 2],tag[maxn << 2],C[maxn];
void upd(int u){mx[u] = max(mx[ls],mx[rs]);}
void pd(int u){
if (tag[u]){
mx[ls] += tag[u]; tag[ls] += tag[u];
mx[rs] += tag[u]; tag[rs] += tag[u];
tag[u] = 0;
}
}
void build(int u,int l,int r){
if (l == r){mx[u] = C[l]; return;}
int mid = l + r >> 1;
build(ls,l,mid);
build(rs,mid + 1,r);
upd(u);
}
void add(int u,int l,int r,int L,int R,LL v){
if (l >= L && r <= R){mx[u] += v; tag[u] += v; return;}
pd(u);
int mid = l + r >> 1;
if (mid >= L) add(ls,l,mid,L,R,v);
if (mid < R) add(rs,mid + 1,r,L,R,v);
upd(u);
}
LL query(int u,int l,int r,int L,int R){
if (l >= L && r <= R) return mx[u];
pd(u);
int mid = l + r >> 1;
if (mid >= R) return query(ls,l,mid,L,R);
if (mid < L) return query(rs,mid + 1,r,L,R);
return max(query(ls,l,mid,L,R),query(rs,mid + 1,r,L,R));
}
int nxt[maxn],last[maxn],bac[maxn];
int n,m,f[maxn],w[maxn];
LL sum,ans;
int main(){
n = read(); m = read();
REP(i,n) f[i] = read();
REP(i,m) w[i] = read();
for (int i = n; i; i--) nxt[i] = last[f[i]],last[f[i]] = i;
for (int i = 1; i <= n; i++){
if (!bac[f[i]]) sum += w[f[i]];
else if (bac[f[i]] == 1) sum -= w[f[i]];
bac[f[i]]++;
C[i] = sum;
}
build(1,1,n);
for (int i = 1; i <= n; i++){
ans = max(ans,query(1,1,n,i,n));
if (nxt[i]){
if (nxt[i] > i + 1) add(1,1,n,i + 1,nxt[i] - 1,-w[f[i]]);
add(1,1,n,nxt[i],nxt[nxt[i]] ? nxt[nxt[i]] - 1 : n,w[f[i]]);
}
else add(1,1,n,i + 1,n,-w[f[i]]);
}
printf("%lld\n",ans);
return 0;
}