[学习笔记]ST表
给狂妄自负以适当的绝望,这就是真理
基本概念
-
RMQ问题:
给定一个长度为N的区间,M个询问,每次询问Li到Ri这段区间元素的最大值/最小值。
如果暴力找最大值,复杂度是\(o(n)\)。但如果查询多次,这个复杂度就很大了。
解决这个问题的方法是离线ST表和支持在线修改的线段树。 -
ST表:一种利用dp求解区间最值的倍增算法。
-
定义:\(f[i][j]\)表示\(i\)到\(i+2^{j-1}\)这段区间的最大值。
-
预处理:\(f[i][0]=a[i]\)。即\(i\)到\(i\)区间的最大值就是\(a[i]\)。
-
状态转移:将\(f[i][j]\)平均分成两段,一段为\(f[i][j-1]\),另一段为\(f[i+2^{j-1}][j-1]\)。
-
两段的长度均为\(2^{j-1}\)。\(f[i][j]\)的最大值即这两段的最大值中的最大值。
-
- 查询:需要查询的区间为\([i,j]\),则需要找到两个覆盖这个闭区间的最小幂区间。
这两个区间可以重复,因为两个区间是否相交对区间最值没有影响。(如下图)
模板题:Balanced Lineup
- 题目描述:
For the daily milking, Farmer John's N cows (1 ≤ N ≤ 50,000) always line up in the same order. One day Farmer John decides to organize a game of Ultimate Frisbee with some of the cows. To keep things simple, he will take a contiguous range of cows from the milking lineup to play the game. However, for all the cows to have fun they should not differ too much in height.
Farmer John has made a list of Q (1 ≤ Q ≤ 200,000) potential groups of cows and their heights (1 ≤ height ≤ 1,000,000). For each group, he wants your help to determine the difference in height between the shortest and the tallest cow in the group.
- 输入:
Line 1: Two space-separated integers, N and Q.
Lines 2.. N+1: Line i+1 contains a single integer that is the height of cow i
Lines N+2.. N+ Q+1: Two integers A and B (1 ≤ A ≤ B ≤ N), representing the range of cows from A to B inclusive.
- 输出:
Lines 1.. Q: Each line contains a single integer that is a response to a reply and indicates the difference in height between the tallest and shortest cow in the range.
- 样例
6 3
1
7
3
4
2
5
1 5
4 6
2 2
6
3
0
代码及模板
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 1e5 + 7;
int stmax[maxn][20], stmin[maxn][20];
int poww[25], logg[maxn];
int n, q;
void init() {
poww[0] = 1;//预处理次方
for (int i = 1; i <= 20; i++) poww[i] = poww[i - 1] << 1;
for (int i = 2; i <= n; i++) logg[i] = logg[i >> 1] + 1;
int temp = 1;//temp=2^(j-1)
for (int j = 1; j <= logg[n]; j++, temp <<= 1) {
for (int i = 1; i <= n - temp - temp + 1; i++) {
stmax[i][j] = max(stmax[i][j - 1], stmax[i + temp][j - 1]);
stmin[i][j] = min(stmin[i][j - 1], stmin[i + temp][j - 1]);
}
}
}
inline int query_min(int l, int r) {
int len = r - l + 1;
int k = logg[len];
return min(stmin[l][k], stmin[r - poww[k] + 1][k]);
}
inline int query_max(int l, int r) {
int len = r - l + 1;
int k = logg[len];
return max(stmax[l][k], stmax[r - poww[k] + 1][k]);
}
int main() {
int a;
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; i++) {
scanf("%d", &a);
stmax[i][0] = stmin[i][0] = a;
}
init();
int l, r;
while (q--) {
scanf("%d%d", &l, &r);
printf("%d\n", query_max(l, r) - query_min(l, r));
}
return 0;
}
LCA+ST的模板
用dfs序表示2*n的数组,数组的值是深度。fir表示第一次遍历到这个点的dfs序,cur表示这个dfs序对应的节点。lca就是l到r区间深度最小的节点。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int maxn=5e5+10;//点的个数
int cur[maxn<<1];// cur 当前dfs序的点
int fir[maxn];//fir 第一次遍历的 dfs序
int rmq[maxn<<1];//rmq 深度
struct Edge{
int to,next;
}e[maxn<<1];
int head[maxn],tol,cnt;
struct St{
int stm[maxn<<1][21];//stm dfs序列上深度最小的点
int logg[maxn<<1],poww[21];
void init(int n) {
poww[0]=1;logg[0]=-1;
for(int i=1; i<=20; i++)
{
poww[i]=poww[i-1]<<1;
}
for (int i = 1; i <= n; i++) {
logg[i]=logg[i>>1]+1;
stm[i][0]=i;
}
int temp=1;
for (int j = 1; j <=logg[n]; ++j,temp<<=1) {
for (int i = 1; i <=n-temp-temp+1; ++i) {
stm[i][j]=rmq[stm[i][j-1]]<rmq[stm[i+temp][j-1]]?stm[i][j-1]:stm[i+temp][j-1];
}
}
}
int Query(int a,int b){
if(a>b) swap(a,b);
int k=logg[b-a+1];
return rmq[stm[a][k]]<=rmq[stm[b-poww[k]+1][k]]?stm[a][k]:stm[b-poww[k]+1][k];
}
}st;
void add(int u,int v){
e[++tol].to=v;
e[tol].next=head[u];
head[u]=tol;
}
void dfs(int u,int pre,int dep){
cur[++cnt]=u;
rmq[cnt]=dep;
fir[u]=cnt;
for(int i=head[u];i;i=e[i].next){
int v=e[i].to;
if(v==pre) continue;
dfs(v,u,dep+1);
cur[++cnt]=u;
rmq[cnt]=dep;
}
}
void cal(int roof,int n){
dfs(roof,roof,0);
st.init(n);
}
int query_lca(int a,int b){
return cur[st.Query(fir[a],fir[b])];
}
int main(){
int n,m,s,u,v;
scanf("%d%d%d",&n,&m,&s);
for (int i = 1; i < n; ++i) {
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
cal(s,(n*2)-1);
while (m--){
scanf("%d%d",&u,&v);
printf("%d\n",query_lca(u,v));
}
}