BZOJ4540 [Hnoi2016]序列 【莫队 + ST表 + 单调栈】
题目
给定长度为n的序列:a1,a2,…,an,记为a[1:n]。类似地,a[l:r](1≤l≤r≤N)是指序列:al,al+1,…,ar-
1,ar。若1≤l≤s≤t≤r≤n,则称a[s:t]是a[l:r]的子序列。现在有q个询问,每个询问给定两个数l和r,1≤l≤r
≤n,求a[l:r]的不同子序列的最小值之和。例如,给定序列5,2,4,1,3,询问给定的两个数为1和3,那么a[1:3]有
6个子序列a[1:1],a[2:2],a[3:3],a[1:2],a[2:3],a[1:3],这6个子序列的最小值之和为5+2+4+2+2+2=17。
输入格式
输入文件的第一行包含两个整数n和q,分别代表序列长度和询问数。接下来一行,包含n个整数,以空格隔开
,第i个整数为ai,即序列第i个元素的值。接下来q行,每行包含两个整数l和r,代表一次询问。
输出格式
对于每次询问,输出一行,代表询问的答案。
输入样例
5 5
5 2 4 1 3
1 5
1 3
2 4
3 5
2 5
输出样例
28
17
11
11
17
提示
1 ≤N,Q ≤ 100000,|Ai| ≤ 10^9
题解
考虑莫队
我们只需要考虑如何\(O(1)\)扩展区间即可
扩展左右端点是类似的
以右端点为例
新产生的子区间一定是以新的右端点\(r\)为右端点的那些子区间
那么考虑所有左端点产生的影响
从区间最小值的位置到区间左端点,左端点在这个区间内的贡献一定是区间最小值
至于左端点在右半区间时产生的贡献,我们再考虑:
从\(r\)一直到一个比\(A[r]\)小的位置\(pos\),答案一直是\(A[r]\)
然后从\(pos\)开始,向左一直到一个比\(A[pos]\)小的位置,答案一直是\(A[pos]\)
一次类推
如果将每个位置向往前第一个比它小的位置连边,这就是一个树结构,我们只用\(O(1)\)求出树链两点距离即可解决右半区间的答案
建树可以用单调栈实现
我们再用\(ST\)表维护最小值就可以做到\(O(1)\)扩展区间了
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#define LL long long int
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define BUG(s,n) for (int i = 1; i <= (n); i++) cout<<s[i]<<' '; puts("");
using namespace std;
const int maxn = 100005,maxm = 100005,INF = 1000000000;
inline int read(){
int out = 0,flag = 1; char c = getchar();
while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();}
while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();}
return out * flag;
}
int h[maxn],ne = 1;
struct EDGE{int to,nxt;}ed[maxn];
inline void build(int u,int v){
ed[ne] = (EDGE){v,h[u]}; h[u] = ne++;
}
LL Ls[maxn],Rs[maxn];
LL mn[maxn][18],n,Q,A[maxn],fa[maxn],fa1[maxn],bin[30],Log[maxn],B;
LL st[maxn],top,vis[maxn];
void dfs1(int u){
vis[u] = true;
if (fa[u]) Ls[u] = Ls[fa[u]] + (u - fa[u]) * A[u];
Redge(u) {
fa[to = ed[k].to] = u;
dfs1(to);
}
}
void dfs2(int u){
vis[u] = true;
if (fa1[u]) Rs[u] = Rs[fa1[u]] + (fa1[u] - u) * A[u];
Redge(u) {
fa1[to = ed[k].to] = u;
dfs2(to);
}
}
void init(){
REP(j,17) REP(i,n){
if (i + bin[j] - 1 > n) break;
mn[i][j] = A[mn[i][j - 1]] <= A[mn[i + bin[j - 1]][j - 1]] ? mn[i][j - 1] : mn[i + bin[j - 1]][j - 1];
}
REP(i,n){
while (top && A[st[top]] > A[i]) top--;
if (top) build(st[top],i);
st[++top] = i;
}
REP(i,n) if (!vis[i]) fa[i] = 0,dfs1(i);
top = 0; memset(vis,0,sizeof(vis)); memset(h,0,sizeof(h)); ne = 1;
for (int i = n; i; i--){
while (top && A[st[top]] > A[i]) top--;
if (top) build(st[top],i);
st[++top] = i;
}
for (int i = n; i; i--) if (!vis[i]) fa1[i] = 0,dfs2(i);
}
int getmn(int l,int r){
int t = Log[r - l + 1];
return A[mn[l][t]] <= A[mn[r - bin[t] + 1][t]] ? mn[l][t] : mn[r - bin[t] + 1][t];
}
struct Que{int l,r,id,b;}q[maxn];
LL ans[maxn];
inline bool operator <(const Que& a,const Que& b){
return a.b == b.b ? a.r < b.r : a.l < b.l;
}
void solve(){
sort(q + 1,q + 1 + Q);
LL tot = 0,L = q[1].l,R = q[1].r,pos;
for (int i = L; i <= R; i++){
pos = getmn(L,i);
if (A[pos] == A[i]) tot += (i - L + 1) * A[i];
else tot += (pos - L + 1) * A[pos] + Ls[i] - Ls[pos];
}
ans[q[1].id] = tot;
for (int i = 2; i <= Q; i++){
while (L != q[i].l || R != q[i].r){
if (L > q[i].l){
L--;
pos = getmn(L,R);
if (A[pos] == A[L]) tot += (R - L + 1) * A[L];
else tot += (R - pos + 1) * A[pos] + Rs[L] - Rs[pos];
}
if (L < q[i].l){
pos = getmn(L,R);
if (A[pos] == A[L]) tot -= (R - L + 1) * A[L];
else tot -= (R - pos + 1) * A[pos] + Rs[L] - Rs[pos];
L++;
}
if (R < q[i].r){
R++;
pos = getmn(L,R);
if (A[pos] == A[R]) tot += (R - L + 1) * A[R];
else tot += (pos - L + 1) * A[pos] + Ls[R] - Ls[pos];
}
if (R > q[i].r){
pos = getmn(L,R);
if (A[pos] == A[R]) tot -= (R - L + 1) * A[R];
else tot -= (pos - L + 1) * A[pos] + Ls[R] - Ls[pos];
R--;
}
}
ans[q[i].id] = tot;
}
REP(i,Q) printf("%lld\n",ans[i]);
}
int main(){
bin[0] = 1; for (int i = 1; i <= 25; i++) bin[i] = bin[i - 1] << 1;
Log[0] = -1; for (int i = 1; i < maxn; i++) Log[i] = Log[i >> 1] + 1;
n = read(); Q = read();
REP(i,n) A[i] = read(),mn[i][0] = i;
init();
B = (int)sqrt(n) + 1;
REP(i,Q){
q[i].l = read(); q[i].r = read(); q[i].id = i; q[i].b = q[i].l / B;
}
solve();
return 0;
}