POJ--3468 A Simple Problem with Integers(线段树/树状数组/分块)

记录
11:03 2024-2-25

http://poj.org/problem?id=3468

1. 线段树

区间增加 + 区间查询

点击查看代码
#include<iostream>
#include<cstdio>
using namespace std;
typedef long long ll;

const int MAXN = 100005;
const int DAT_SIZE = (1 << 18) - 1;
int N, Q;
ll a[MAXN];
// 线段树
// a 给这个节点对应的区间内的所有元素共同加上的值
// b 在这个节点对应的区间中除去a之外其他的值的和
ll datA[DAT_SIZE], datB[DAT_SIZE];

// 对区间[a, b)同时加x
// k是节点的编号, 对应的区间是[l, r)
void add(int a, int b, int x, int k, int l, int r) {
    if (a <= l && r <= b) {
        // a.给这个节点对应的区间内的所有元素共同加上的值
        datA[k] += x;
    } else if (l < b && a < r) {
        // b.在这个节点对应的区间中除去a之外其他的值的和
        datB[k] += (min(b, r) - max(a, l)) * x;
        add(a, b, x, k * 2 + 1, l, (l + r) / 2);
        add(a, b, x, k * 2 + 2, (l + r) / 2, r );
    }
}

// 计算[a, b)的和
// k是节点的编号 , 对应的区间是[l, r)
ll sum(int a, int b, int k, int l, int r) {
    if (b <= l||r <= a) {
        return 0;
    } else if (a <= l && r <= b) {
        //a l r b 全部包含
        return datA[k] * (r - l) + datB[k];
    } else {
        // 交集区间上同时加上的值 + 子集区间上加上的值
        // 如果对于父亲节点同时加了一个值, 那么这个值就不会在儿子节点被重复考虑。 
        // 在递归计算和时再把这一部分的值加到结果里面就可以了。
        ll res = (min(b, r) - max(a, l)) * datA[k];
        res += sum(a, b, k * 2 + 1, l, (l + r) / 2);
        res += sum(a, b, k * 2 + 2, (l + r) / 2, r);
        return res;
    }
}

int main() {
    cin >> N >> Q;
    for(int i = 0; i < N; i++) {
        scanf("%lld", &a[i]);
        add(i, i + 1, a[i], 0, 0, N);
    }
    char c[2];
    int l, r, v;
    for(int i = 0; i < Q; i++) {
        // %s 它会读入一个不含空格、TAB和回车符的字符串,存入字符数组
        // %c 会读入\n
        scanf("%s%d%d", c, &l, &r);
        if(c[0] == 'C') {
            scanf("%d", &v);
            add(l - 1, r, v, 0, 0, N);
        } else {
            printf("%lld\n" , sum(l - 1, r, 0, 0, N));
        }
    }
}

2. 树状数组

区间增加 + 区间查询

把区间增加转变为单点增加,利用两个树状数组\(c_0 和 c_1\)
将”C l r d" 转化为

  1. 在树状数组\(c_0\)中,把位置l上的数加d
  2. 在树状数组\(c_0\)中,把位置r + 1上的数减d
  3. 在树状数组\(c_1\)中,把位置l上的数加l * d
  4. 在树状数组\(c_1\)中,把位置r + 1上的数减(r + 1) * d

建立sum存储a的原始前缀和
将“Q l r” 转化为 1~r 和 1~l-1两部分进行相减
$ (sum[r] + (r + 1) * ask(c_0, r) - ask(c_1, r)) - (sum[l - 1] + l * ask(c_0, l - 1) - ask(c_1, l - 1)) $

点击查看代码
#include<iostream>
#include<cstdio>
using namespace std;
typedef long long ll;

const int MAXN = 100005;
int N, Q;
ll a[MAXN], sum[MAXN];
ll c[2][MAXN];

// k表示是哪个树状数组,i表示位置, v表示加入的值
void add(int k, int i, int v) {
    while (i <= N){
        c[k][i] += v;
        i += i & -i;
    }
}

// k表示是哪个树状数组,i表示位置
ll ask(int k, int i) {
    ll s = 0;
    while (i > 0) {
        s += c[k][i];
        i -= i & -i;
    }
    return s;
}

int main() {
    cin >> N >> Q;
    for(int i = 1; i <= N; i++) {
        scanf("%lld", &a[i]);
        sum[i] = sum[i - 1] + a[i];
    }
    char c[2];
    int l, r, v;
    for(int i = 0; i < Q; i++) {
        // %s 它会读入一个不含空格、TAB和回车符的字符串,存入字符数组
        // %c 会读入\n
        scanf("%s%d%d", c, &l, &r);
        if(c[0] == 'C') {
            scanf("%d", &v);
            add(0, l, v);
            add(0, r + 1, -v);
            add(1, l, l * v);
            add(1, r + 1, -(r + 1) * v);
        } else {
            ll result = (sum[r] + (r + 1) * ask(0, r) - ask(1, r))
                        - (sum[l - 1] + l * ask(0, l - 1) - ask(1, l - 1));
            printf("%lld\n", result);
        }
    }
}

延迟标记

点击查看代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int SIZE=100010;
struct SegmentTree{
	int l,r;
	long long sum,add;
	#define l(x) tree[x].l
	#define r(x) tree[x].r
	#define sum(x) tree[x].sum
	#define add(x) tree[x].add
}tree[SIZE*4];
int a[SIZE],n,m;

void build(int p,int l,int r)//No.p, [l,r]
{
	l(p)=l,r(p)=r;
	if(l==r) { sum(p)=a[l]; return; }
	int mid=(l+r)/2;
	build(p*2,l,mid);
	build(p*2+1,mid+1,r);
	sum(p)=sum(p*2)+sum(p*2+1);
}

void spread(int p)
{
	if(add(p))
	{
		sum(p*2)+=add(p)*(r(p*2)-l(p*2)+1);
		sum(p*2+1)+=add(p)*(r(p*2+1)-l(p*2+1)+1);
		add(p*2)+=add(p);
		add(p*2+1)+=add(p);
		add(p)=0;
	}
}

void change(int p,int l,int r,int z)
{
	if(l<=l(p)&&r>=r(p))
	{
		sum(p)+=(long long)z*(r(p)-l(p)+1);
		add(p)+=z;
		return;
	}
	spread(p);
	int mid=(l(p)+r(p))/2;
	if(l<=mid) change(p*2,l,r,z);
    if(r>mid) change(p*2+1,l,r,z);
	sum(p)=sum(p*2)+sum(p*2+1);
}

long long ask(int p,int l,int r)
{
	if(l<=l(p)&&r>=r(p)) return sum(p);
	spread(p);
    int mid=(l(p)+r(p))/2;
	long long ans=0;
    if(l<=mid) ans+=ask(p*2,l,r);
    if(r>mid) ans+=ask(p*2+1,l,r);
    return ans;
}

int main()
{
	cin>>n>>m;
	for(int i=1;i<=n;i++)
		scanf("%d",&a[i]);
	build(1,1,n);
	while(m--)
	{
		char op[2]; int x,y,z;
		scanf("%s%d%d",op,&x,&y);
		if(op[0]=='C')
		{
			scanf("%d",&z);
			change(1,x,y,z);
		}
		else printf("%lld\n",ask(1,x,y));
	}
}

3. 分块

这个就是较为朴素的做法了,分块的思想:大段维护、局部朴素,直接复制书上的了

点击查看代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
long long a[100010], sum[100010], add[100010];
int L[100010], R[100010]; // 每段左右端点
int pos[100010]; // 每个位置属于哪一段
int n, m, t;

void change(int l, int r, long long d) {
	int p = pos[l], q = pos[r];
	if (p == q) {
		for (int i = l; i <= r; i++) a[i] += d;
		sum[p] += d*(r - l + 1);
	}
	else {
		for (int i = p + 1; i <= q - 1; i++) add[i] += d;
		for (int i = l; i <= R[p]; i++) a[i] += d;
		sum[p] += d*(R[p] - l + 1);
		for (int i = L[q]; i <= r; i++) a[i] += d;
		sum[q] += d*(r - L[q] + 1);
	}
}

long long ask(int l, int r) {
	int p = pos[l], q = pos[r];
	long long ans = 0;
	if (p == q) {
		for (int i = l; i <= r; i++) ans += a[i];
		ans += add[p] * (r - l + 1);
	}
	else {
		for (int i = p + 1; i <= q - 1; i++)
			ans += sum[i] + add[i] * (R[i] - L[i] + 1);
		for (int i = l; i <= R[p]; i++) ans += a[i];
		ans += add[p] * (R[p] - l + 1);
		for (int i = L[q]; i <= r; i++) ans += a[i];
		ans += add[q] * (r - L[q] + 1);
	}
	return ans;
}

int main() {
	cin >> n >> m;
	for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);
	// 分块
	t = sqrt(n*1.0);
	for (int i = 1; i <= t; i++) {
		L[i] = (i - 1)*sqrt(n*1.0) + 1;
		R[i] = i*sqrt(n*1.0);
	}
	if (R[t] < n) t++, L[t] = R[t - 1] + 1, R[t] = n;
	// 预处理
	for (int i = 1; i <= t; i++)
		for (int j = L[i]; j <= R[i]; j++) {
			pos[j] = i;
			sum[i] += a[j];
		}
	// 指令
	while (m--) {
		char op[3];
		int l, r, d;
		scanf("%s%d%d", op, &l, &r);
		if (op[0] == 'C') {
			scanf("%d", &d);
			change(l, r, d);
		}
		else printf("%lld\n", ask(l, r));
	}
}
posted @ 2024-02-25 11:37  57one  阅读(11)  评论(0编辑  收藏  举报