数据结构的一些模板
数据结构
表达式求值
** 求中序表达式值 **
例:(2+2)*(1+1)
#include <iostream>
#include <unordered_map>
#include <stack>
using namespace std;
string s;
stack<int> nums;
stack<char> op;
void eval() {
char c = op.top(); op.pop();
int b = nums.top(); nums.pop();
int a = nums.top(); nums.pop();
int x = 0;
if (c == '+') x = a + b;
if (c == '-') x = a - b;
if (c == '*') x = a * b;
if (c == '/') x = a / b;
nums.push(x);
}
int main(void) {
cin >> s;
unordered_map<char, int> pr = { {'+', 1}, {'-', 1}, {'*', 2}, {'/', 2} };
for (int i = 0; i < s.size(); ++i) {
char c = s[i];
if (isdigit(c)) {
int x = 0, j = i;
while (j < s.size() && isdigit(s[j])) {
x = x * 10 + s[j] - '0';
j++;
}
i = j - 1;
nums.push(x);
} else {
if (c == '(') op.push('(');
else if (c == ')') {
while (op.size() && op.top() != '(') eval();
op.pop();
} else {
while (op.size() && op.top() != ')' && pr[op.top()] >= pr[c]) eval();
op.push(c);
}
}
}
while (op.size()) eval();
cout << nums.top();
return 0;
}
单调栈
输入n个数,输出每个数左边第一个比它小的数,不存在则输出-1
#include <iostream>
#include <stack>
using namespace std;
stack<int> st;
int main(void) {
int n;
scanf("%d", &n);
while (n--) {
int x;
scanf("%d", &x);
while (st.size() && st.top() >= x) st.pop();
if (st.size()) printf("%d ", st.top());
else printf("-1 ");
st.push(x);
}
return 0;
}
单调队列
滑动窗口求最大&最小
#include <iostream>
using namespace std;
const int N = 1000010;
int a[N], q[N];
int n, k;
int main(void) {
scanf("%d%d", &n, &k);
for (int i = 0; i < n; ++i) scanf("%d", &a[i]);
/* k中求最小 */
int hh = 0, tt = -1; // hh头,tt尾
for (int i = 0; i < n; ++i) {
if (hh <= tt && i - k + 1 > q[hh]) hh++;
while (hh <= tt && a[q[tt]] >= a[i]) tt--;
q[++tt] = i;
if (i - k + 1 >= 0) printf("%d ", a[q[hh]]);
}
puts("");
/* k中求最大 */
hh = 0, tt = -1;
for (int i = 0; i < n; ++i) {
if (hh <= tt && i - k + 1 > q[hh]) hh++;
while (hh <= tt && a[q[tt]] <= a[i]) tt--;
q[++tt] = i;
if (i - k + 1 >= 0) printf("%d ", a[q[hh]]);
}
return 0;
}
KMP
求P在S中所有出现位置的起始下标
#include <iostream>
using namespace std;
const int N = 100010, M = 1000010;
int n, m;
char p[N], s[M];
int ne[N];
int main(void) {
cin >> n >> p + 1;
cin >> m >> s + 1;
// next[]
for (int i = 2, j = 0; i <= n; ++i) {
while (j && p[i] != p[j + 1]) j = ne[j];
if (p[i] == p[j + 1]) j++;
ne[i] = j;
}
//kmp
for (int i = 1, j = 0; i <= m; ++i) {
while (j && s[i] != p[j + 1]) j = ne[j];
if (s[i] == p[j + 1]) j++;
if (j == n) {
cout << i - n + 1 - 1 << ' ';
j = ne[j];
}
}
return 0;
}
Trie
Trie字符串统计
统计字符串在集合中出现了多少次
输入的n个操作中,'I'表示插入字符串,'Q'表示查询字符串
#include <iostream>
using namespace std;
const int N = 20010;
int son[N][26], cnt[N], idx;
void insert(char str[]) {
int p = 0;
for (int i = 0; str[i]; ++i) {
int u = str[i] - 'a';
if (!son[p][u]) son[p][u] = ++idx;
p = son[p][u];
}
cnt[p]++;
}
int query(char str[]) {
int p = 0;
for (int i = 0; str[i]; ++i) {
int u = str[i] - 'a';
if (!son[p][u]) return 0;
p = son[p][u];
}
return cnt[p];
}
int main(void) {
int n;
scanf("%d", &n);
char str[N];
while (n--) {
char op[2];
scanf("%s%s", op, str);
if (*op == 'I')
insert(str);
else {
printf("%d\n", query(str));
}
}
return 0;
}
最大异或对
在n个数中,选两个数进行异或运算,求最大值
(相同为0,不同为1)
#include <iostream>
using namespace std;
const int N = 100010, M = 31 * N;
int son[M][2], idx;
void insert(int x) {
int p = 0;
for (int i = 31; i >= 0; --i) {
int u = x >> i & 1;
if (!son[p][u]) son[p][u] = ++idx;
p = son[p][u];
}
}
int query(int x) {
int p = 0;
int res = 0;
for (int i = 31; i >= 0; --i) {
int u = x >> i & 1;
if (son[p][!u]) {
p = son[p][!u];
res = res * 2 + !u;
} else {
p = son[p][u];
res = res * 2 + u;
}
}
return res;
}
int main(void) {
int n;
scanf("%d", &n);
int res = 0;
while (n--) {
int x;
scanf("%d", &x);
insert(x);
int t = query(x);
res = max(res, t ^ x);
}
printf("%d", res);
return 0;
}
并查集
路径压缩
int p[N], d[N];
int find(int x) {
if (p[x] != x) {
int u = find(p[x]);
d[x] += d[p[x]];
p[x] = u;
}
return p[x];
}
字符串哈希
输入字符串长度n和字符串s,给出q个询问[l1, r1],[l2,r2],判断两区间的字符串是否完全相同
#include <iostream>
using namespace std;
typedef unsigned long long ULL;
const int N = 100010, P = 131;
char s[N];
ULL sum[N], p[N];
ULL get(int l, int r){
return sum[r] - sum[l - 1] * p[r - l + 1];
}
int main(void) {
int n, q;
cin >> n >> q;
cin >> s + 1;
p[0] = 1;
for (int i = 1; i <= n; ++i) {
p[i] = p[i - 1] * P;
sum[i] = sum[i - 1] * P + s[i];
}
while (q--) {
int a, b, x, y;
cin >> a >> b >> x >> y;
if (get(a, b) == get(x, y)) puts("Yes");
else puts("No");
}
return 0;
}
线段树
最大值
给定一个正整数数列 a1,a2,…,an,每一个数都在 0∼p−1 之间。
可以对这列数进行两种操作:
添加操作:向序列后添加一个数,序列长度变成 n+1;
询问操作:询问这个序列中最后 L 个数中最大的数是多少。
程序运行的最开始,整数序列为空。
一共要对整数序列进行 m 次操作。
写一个程序,读入操作的序列,并输出询问操作的答案。
#include <iostream>
using namespace std;
const int N = 200010;
struct Node {
int l, r;
int v;
} tr[N * 4];
int m, p;
void pushup(int u) {
tr[u].v = max(tr[u << 1].v, tr[u << 1 | 1].v);
}
void build(int u, int l, int r) {
tr[u] = {l, r};
if (l == r) return ;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
int query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].v;
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
else if (l > mid) return query(u << 1 | 1, l, r);
else return max(query(u << 1, l, r), query(u << 1 | 1, l, r));
}
void modify(int u, int x, int v) {
if (tr[u].l == x && tr[u].r == x) tr[u].v = v;
else {
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}
int main(void) {
cin >> m >> p;
build(1, 1, m);
int n = 0;
int last = 0;
while (m--) {
char op[2];
int x;
scanf("%s%d", op, &x);
if (op[0] == 'Q') {
last = query(1, n - x + 1, n);
printf("%d\n", last);
}
else {
n++;
modify(1, n, (last + x) % p);
}
}
}
区间最大子段和
给定长度为 N 的数列 A,以及 M 条指令,每条指令可能是以下两种之一:
1 x y,查询区间 [x,y] 中的最大连续子段和
2 x y,把 A[x] 改成 y。
对于每个查询指令,输出一个整数表示答案。
#include <iostream>
using namespace std;
const int N = 500010;
struct Node {
int l, r;
int tmax, lmax, rmax, sum;
} tr[N * 4];
int n, q;
int w[N];
void pushup(Node &u, Node &l, Node &r) {
u.sum = l.sum + r.sum;
u.lmax = max(l.lmax, l.sum + r.lmax);
u.rmax = max(r.rmax, r.sum + l.rmax);
u.tmax = max(max(l.tmax, r.tmax), l.rmax + r.lmax);
}
void pushup(int u) {
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r) {
if (l == r) tr[u] = {r, r, w[r], w[r], w[r], w[r]};
else {
tr[u] = {l , r};
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
Node query(int u, int l, int r) {
if (tr[u].l >= l && tr[u]. r <= r) return tr[u];
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
else if (l > mid) return query(u << 1 | 1, l, r);
else {
Node res;
Node left = query(u << 1, l, r);
Node right = query(u << 1 | 1, l, r);
pushup(res, left, right);
return res;
}
}
void modify(int u, int x, int v) {
if (tr[u].l == x && tr[u].r == x) tr[u] = {x, x, v, v ,v ,v};
else {
int mid = tr[u].l + tr[u]. r >> 1;
if (x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}
int main(void) {
// freopen("1.txt", "r", stdin);
cin >> n >> q;
for (int i= 1; i <= n; ++i) scanf("%d", &w[i]);
build(1, 1, n);
while (q--) {
int k, x, y;
cin >> k >>x >>y;
if (k == 1) {
if (x > y) swap(x, y);
printf("%d\n", query(1, x, y).tmax);
}
else modify(1, x, y);
}
return 0;
}
区间最大公约数
给定一个长度为 N 的数列 A,以及 M 条指令,每条指令可能是以下两种之一:
C l r d,表示把 A[l],A[l+1],…,A[r] 都加上 d。
Q l r,表示询问 A[l],A[l+1],…,A[r] 的最大公约数(GCD)。
对于每个询问,输出一个整数表示答案。
#include <iostream>
using namespace std;
typedef long long LL;
const int N = 500010;
struct Node
{
int l, r;
LL sum, d;
}tr[N * 4];
LL w[N];
int n, q;
LL gcd(LL a, LL b)
{
return b ? gcd(b, a % b) : a;
}
void pushup(Node &u, Node &l, Node &r)
{
u.sum = l.sum + r.sum;
u.d = gcd(l.d, r.d);
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r)
{
if (l == r) {
LL d = w[r] - w[r - 1];
tr[u] = {l, r, d, d};
}
else {
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int x, LL v) // add v in x
{
if (tr[u].l == x && tr[u].r == x) {
LL d = tr[u].sum + v;
tr[u] = {x, x, d, d};
}
else {
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}
Node query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r) return tr[u];
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
else if (l > mid) return query(u << 1 | 1, l, r);
else {
Node left = query(u << 1, l, r);
Node right = query(u << 1 | 1, l, r);
Node res;
pushup(res, left, right);
return res;
}
}
int main(void)
{
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; ++i) scanf("%lld", &w[i]);
build(1, 1, n);
while (q--) {
char op[2];
int l, r;
scanf("%s%d%d", op, &l, &r);
if (*op == 'C') {
LL d;
scanf("%lld", &d);
modify(1, l, d);
if (r + 1 <= n) modify(1, r + 1, -d);
}
else {
Node left = query(1, 1, l);
Node right = {0, 0, 0, 0};
if (l + 1 <= r) right = query(1, l + 1, r);
LL d = gcd(left.sum, right.d);
printf("%lld\n", abs(d));
}
}
return 0;
}