Sweety

Practice makes perfect

导航

琐碎的区间(线段树区间更新 + 技巧!)

Posted on 2017-05-03 10:30  蓝空  阅读(122)  评论(0)  编辑  收藏  举报

琐碎的区间

时间限制: 4 Sec  内存限制: 256 MB
提交: 131  解决: 26
[提交][状态][讨论版]

题目描述

给出一个长度为 n 的整数序列 A[1..n],有三种操作: 
1 l r x :  把 [l, r] 区间的每个数都加上 x 
2 l r :  把 [l, r] 区间的每个 A[i] 变为 sqrt(A[i]) 的整数部分 
3 l r :  求 [l, r] 区间所有数的和 
其中 l、r 和 x 都是整数 

输入

第一行一个 T,表示数据组数。 
对于每组数据 
Line1:两个数 n m,表示整数序列长度和操作数 
Line2:n 个数,表示 A[1..n] 
Line3…Line3+m-1:每行一个询问,对于第三种询问,请输出答案。 
对于每一种询问,先给出操作的编号,再给出相应的操作,编号与题目描述对应。 
数据约定: 
1<=T<=5 
n,m <= 100000 
1<= A[i], x<=100000 

输出

对于第三种询问,输出答案。每个答案占一行。 

样例输入

1
5 5
1 2 3 4 5
1 3 5 2
2 1 4
3 2 4
2 3 5
3 1 5

样例输出

5
6

提示

来源

[提交][状态]

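思路:用线段树维护区间和,同时再记录每个区间的最大值和最小值。对区间开方时,如果某个区间的最大值等于最小值,那么开方后区间内每个数减少的量都相同,直接把它当作一次区间加法(打懒标记)处理即可,不必继续往下递归;另外,如果最大值比最小值大 1,且两者开方后仍相差 1,那么区间内每个数减少的量同样相同,也可以按区间加法处理。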
#include <bits/stdc++.h>
   
// 64-bit element/sum type: range sums can exceed 32 bits.
typedef long long ll;
const int maxn=100000+10;   // maximum sequence length plus slack
const int N=200000+10;      // not used by this solution
using namespace std;
// Left / right child of node rt in the 1-based heap layout of the tree.
#define ls (rt<<1)
#define rs (rt<<1|1)
int n,m,k,t;                // n, m, t are read in main; k is unused
int a[maxn];                // input sequence A[1..n]
// Per-node segment-tree data (4 * maxn nodes):
ll tag[maxn<<2];            // pending addition not yet pushed to the children
ll ma[maxn<<2];             // maximum of the node's range
ll mi[maxn<<2];             // minimum of the node's range
ll sum[maxn<<2];            // sum of the node's range
// Buffered stdin reader: refills `buffer` with fread in 64 KiB chunks.
// [head, tail) is the unread part of the buffer.
const int BufferSize=1<<16;
char buffer[BufferSize],*head,*tail;

// Returns the next byte of stdin, or '\0' once the input is exhausted.
// The sentinel matters: returning the stale buffer[0] at end of input (and
// leaving head past tail) made a number at the very end of the input pick up
// garbage digits and let later calls read beyond the end of `buffer`.
inline char Getchar() {
    if(head==tail) {
        int l=(int)fread(buffer,1,BufferSize,stdin);
        head=buffer;
        tail=buffer+l;
        if(l==0) return '\0';
    }
    return *head++;
}

// Reads the next decimal integer from stdin. Non-digit characters before it
// are skipped; a '-' among them makes the result negative. Returns 0 if the
// input ends (or a NUL byte is met) before any digit.
inline int read() {
    int x=0,f=1;
    char c=Getchar();
    // isdigit needs a value representable as unsigned char.
    for(;!isdigit((unsigned char)c);c=Getchar()) {
        if(c=='\0') return 0;
        if(c=='-') f=-1;
    }
    for(;isdigit((unsigned char)c);c=Getchar()) x=x*10+(c-'0');
    return x*f;
}
void pup(int l,int r,int rt)
{
    int mid=l+r>>1;
    sum[rt]=sum[ls]+sum[rs];
    ma[rt]=max(ma[ls],ma[rs]);
    mi[rt]=min(mi[ls],mi[rs]);
    tag[rt]=0;
}
void pdw(int l,int r,int rt)
{
    int mid=l+r>>1;
    sum[ls]+=tag[rt]*(mid-l+1);
    ma[ls]+=tag[rt];
    mi[ls]+=tag[rt];
    tag[ls]+=tag[rt];
    sum[rs]+=tag[rt]*(r-mid);
    ma[rs]+=tag[rt];
    mi[rs]+=tag[rt];
    tag[rs]+=tag[rt];
    tag[rt]=0;
}
// Builds node rt over a[l..r]. Leaves copy a[l] and get an empty tag;
// internal nodes are combined from their children by pup.
void build(int l,int r,int rt)
{
    if(l!=r)
    {
        const int mid=(l+r)>>1;
        build(l,mid,ls);
        build(mid+1,r,rs);
        pup(l,r,rt);
        return;
    }
    tag[rt]=0;
    sum[rt]=a[l];
    mi[rt]=a[l];
    ma[rt]=a[l];
}
// Adds v to every element of [L, R] inside node rt, which covers [l, r].
// Fully covered nodes are updated in place and remember v in tag[].
void upd(int L,int R,ll v,int l,int r,int rt)
{
    const bool covered=(L<=l&&r<=R);
    if(!covered)
    {
        const int mid=(l+r)>>1;
        if(tag[rt]!=0) pdw(l,r,rt);
        if(L<=mid) upd(L,R,v,l,mid,ls);
        if(mid<R) upd(L,R,v,mid+1,r,rs);
        pup(l,r,rt);
        return;
    }
    tag[rt]+=v;
    mi[rt]+=v;
    ma[rt]+=v;
    sum[rt]+=v*(r-l+1);
}
void qsqrt(int L,int R,int l,int r,int rt)
{
    if(L<=l&&r<=R)
    {
        if(ma[rt]==mi[rt])
        {
            tag[rt]-=ma[rt];
            ma[rt]=sqrt(ma[rt]);
            tag[rt]+=ma[rt];
            mi[rt]=ma[rt];
            sum[rt]=(r-l+1)*ma[rt];
            return;
        }
        else if(ma[rt]==mi[rt]+1)
        {
            if((ll)sqrt(ma[rt])==(ll)sqrt(mi[rt])+1)
            {
                tag[rt]-=ma[rt];
                sum[rt]-=(r-l+1)*(ma[rt]-(ll)sqrt(ma[rt]));
                ma[rt]=sqrt(ma[rt]);
                tag[rt]+=ma[rt];
                mi[rt]=ma[rt]-1;
                return;
            }
        }
    }
    int mid=l+r>>1;
    if(tag[rt])pdw(l,r,rt);
    if(L<=mid)qsqrt(L,R,l,mid,ls);
    if(R>mid)qsqrt(L,R,mid+1,r,rs);
    pup(l,r,rt);
}
// Returns the sum of the elements of [L, R] that lie in node rt (covering
// [l, r]). Pending additions on the visited path are pushed down.
ll gao(int L,int R,int l,int r,int rt)
{
    if(L<=l&&r<=R) return sum[rt];
    if(tag[rt]) pdw(l,r,rt);
    const int mid=(l+r)>>1;
    const ll fromLeft=(L<=mid)?gao(L,R,l,mid,ls):0;
    const ll fromRight=(R>mid)?gao(L,R,mid+1,r,rs):0;
    return fromLeft+fromRight;
}
// Reads t test cases. Each gives n, m and A[1..n], then m operations:
//   1 l r x  -> add x to [l, r]
//   2 l r    -> replace each element of [l, r] by floor(sqrt(element))
//   other    -> print the sum of [l, r]
int main()
{
    t=read();
    while(t--)
    {
        n=read();
        m=read();
        for(int i=1;i<=n;++i) a[i]=read();
        build(1,n,1);
        while(m--)
        {
            const int op=read();
            const int lo=read();
            const int hi=read();
            switch(op)
            {
            case 1:
                upd(lo,hi,read(),1,n,1);
                break;
            case 2:
                qsqrt(lo,hi,1,n,1);
                break;
            default:
                printf("%lld\n",gao(lo,hi,1,n,1));
                break;
            }
        }
    }
    return 0;
}

PS:下面是超时的代码:
#include <bits/stdc++.h>
using namespace std;

// Buffered stdin reader: refills `buffer` with fread in 64 KiB chunks.
// [head, tail) is the unread part of the buffer.
const int BufferSize=1<<16;
char buffer[BufferSize],*head,*tail;

// Returns the next byte of stdin, or '\0' once the input is exhausted.
// The sentinel matters: returning the stale buffer[0] at end of input (and
// leaving head past tail) made a number at the very end of the input pick up
// garbage digits and let later calls read beyond the end of `buffer`.
inline char Getchar() {
    if(head==tail) {
        int l=(int)fread(buffer,1,BufferSize,stdin);
        head=buffer;
        tail=buffer+l;
        if(l==0) return '\0';
    }
    return *head++;
}

// Reads the next decimal integer from stdin. Non-digit characters before it
// are skipped; a '-' among them makes the result negative. Returns 0 if the
// input ends (or a NUL byte is met) before any digit.
inline int read() {
    int x=0,f=1;
    char c=Getchar();
    // isdigit needs a value representable as unsigned char.
    for(;!isdigit((unsigned char)c);c=Getchar()) {
        if(c=='\0') return 0;
        if(c=='-') f=-1;
    }
    for(;isdigit((unsigned char)c);c=Getchar()) x=x*10+(c-'0');
    return x*f;
}

// 64-bit element/sum type.
typedef long long LL;

// Index of the left / right child of a node in the 1-based heap layout.
#define L(root) ((root) << 1)
#define R(root) (((root) << 1) | 1)

const int MAXN = 100000 + 5;   // maximum sequence length plus slack
int numbers[MAXN];             // input sequence, 1-indexed

// One segment-tree node, covering numbers[left..right].
struct Node {
    int left;     // first index covered
    int right;    // last index covered
    LL delay;     // pending addition not yet pushed to the children
    LL sum;       // sum of the covered elements
    LL mx;        // maximum of the covered elements
    LL mn;        // minimum of the covered elements

    // Split point between the two children, computed as left + half the
    // width so that left + right is never formed.
    int mid()
    {
        const int halfWidth = (right - left) >> 1;
        return left + halfWidth;
    }
};

Node tree[MAXN * 4];

// Recomputes node root's sum/max/min from its two children and clears its
// pending addition.
void pushUp(int root)
{
    const int lc = L(root);
    const int rc = R(root);
    tree[root].delay = 0;
    tree[root].sum = tree[lc].sum + tree[rc].sum;
    tree[root].mx = max(tree[lc].mx, tree[rc].mx);
    tree[root].mn = min(tree[lc].mn, tree[rc].mn);
}

void pushDown(int root, int l, int r)
{
        LL mid = (r + l) >> 1;
        tree[L(root)].delay += tree[root].delay;
        tree[R(root)].delay += tree[root].delay;
        tree[L(root)].sum += tree[root].delay * (mid - l + 1);
        tree[R(root)].sum += tree[root].delay * (r - mid);
        tree[L(root)].mx += tree[root].delay;
        tree[R(root)].mx += tree[root].delay;
        tree[L(root)].mn += tree[root].delay;
        tree[R(root)].mn += tree[root].delay;
        tree[root].delay = 0;

}

// Builds the subtree rooted at `root` over numbers[left..right], storing each
// node's range. Leaves copy their number; internal nodes use pushUp.
void build(int root, int left, int right)
{
    tree[root].left = left;
    tree[root].right = right;
    if (left != right) {
        const int mid = tree[root].mid();
        build(L(root), left, mid);
        build(R(root), mid + 1, right);
        pushUp(root);
        return;
    }
    const int value = numbers[left];
    tree[root].delay = 0;
    tree[root].sum = value;
    tree[root].mx = value;
    tree[root].mn = value;
}

// Returns the sum of numbers[left..right]; the range must lie within node
// root's range. Pending additions on the visited path are pushed down.
LL query(int root, int left, int right)
{
    const int lo = tree[root].left;
    const int hi = tree[root].right;
    if (lo == left && hi == right)
        return tree[root].sum;
    if (tree[root].delay)
        pushDown(root, lo, hi);
    const int mid = tree[root].mid();
    if (right <= mid)
        return query(L(root), left, right);
    if (left > mid)
        return query(R(root), left, right);
    return query(L(root), left, mid) + query(R(root), mid + 1, right);
}

// Adds `add` to numbers[left..right]; the range must lie within node root's
// range. An exactly matching node is updated in place and remembers `add`.
void update(int root, int left, int right, LL add)
{
    const int lo = tree[root].left;
    const int hi = tree[root].right;
    if (lo == left && hi == right) {
        tree[root].mx += add;
        tree[root].mn += add;
        tree[root].sum += add * (right - left + 1);
        tree[root].delay += add;
        return;
    }
    if (tree[root].delay)
        pushDown(root, lo, hi);
    const int mid = tree[root].mid();
    if (right <= mid) {
        update(L(root), left, right, add);
    } else if (left > mid) {
        update(R(root), left, right, add);
    } else {
        update(L(root), left, mid, add);
        update(R(root), mid + 1, right, add);
    }
    pushUp(root);
}

// Replaces every element of numbers[left..right] with floor(sqrt(element));
// the range must lie within node root's range.
//
// An exactly matching node is finished in O(1) when all of its elements drop
// by the same amount, i.e. the square root is a plain range addition:
//   * mx == mn: every element becomes sqrt(mx);
//   * mx == mn + 1 and sqrt(mx) == sqrt(mn) + 1: every element drops by
//     mx - sqrt(mx).
// Any other node is split and handled in its children.
void sq(int root, int left, int right)
{
    if (tree[root].left == left && tree[root].right == right) {
        LL mx = sqrt(tree[root].mx);
        LL mn = sqrt(tree[root].mn);
        if (tree[root].mx == tree[root].mn) {
            tree[root].delay -= (tree[root].mx - mx);
            tree[root].sum = mx * (right - left + 1);
            tree[root].mx = mx;
            tree[root].mn = mn;
            return;
        } else if ((tree[root].mx == tree[root].mn + 1) && (mx == mn + 1)) {
            tree[root].delay -= (tree[root].mx - mx);
            tree[root].sum -= (tree[root].mx - mx) * (right - left + 1);
            tree[root].mx = mx;
            tree[root].mn = mn;
            // The node is fully updated. Falling through would push the new
            // delay into the children and then square-root them a second
            // time: wrong sums, and the pruning is lost (TLE).
            return;
        }
    }
    if (tree[root].delay) pushDown(root, tree[root].left, tree[root].right);
    int mid = tree[root].mid();
    if (right <= mid) {
        sq(L(root), left, right);
    } else if (left > mid) {
        sq(R(root), left, right);
    } else {
        sq(L(root), left, mid);
        sq(R(root), mid + 1, right);
    }
    pushUp(root);
}

int main()
{
    int t;
    int n, m;
    int i;
    int op, l, r, x;

//    scanf("%d", &t);
    t = read();
    while (t--) {
//        scanf("%d%d", &n, &m);
        n = read(), m = read();
        for (i = 1; i <= n; ++i) {
//            scanf("%d", &numbers[i]);
            numbers[i] = read();
        }
        build(1, 1, n);
        for (i = 0; i < m; ++i) {
//            scanf("%d", &op);
            op = read();
            if (op == 1) {
//                scanf("%d%d%d", &l, &r, &x);
                l = read(), r = read(), x = read();
                update(1, l, r, x);
                //printf("debug op = 1\n");
            } else if (op == 2) {
//                scanf("%d%d", &l, &r);
                l = read(), r = read();
                sq(1, l, r);
                //printf("debug op = 2\n");
            } else {
//                scanf("%d%d", &l, &r);
                l = read(), r = read();
                printf("%lld\n", query(1, l, r));
                //printf("debug op = 3\n");
            }
        }
    }
    return 0;
}


// Alternative driver that parses the same input as main() with scanf instead
// of the fast reader; it is not called anywhere in this file.
// Every scanf result is checked. On malformed or truncated input the
// function returns 1 instead of continuing with uninitialized values; an
// unread `t` could otherwise drive an arbitrarily long loop.
int main2()
{
    int t;
    if (scanf("%d", &t) != 1)
        return 1;
    while (t--) {
        int n, m;
        if (scanf("%d%d", &n, &m) != 2)
            return 1;
        for (int i = 1; i <= n; ++i) {
            if (scanf("%d", &numbers[i]) != 1)
                return 1;
        }
        build(1, 1, n);
        for (int i = 0; i < m; ++i) {
            int op, l, r;
            if (scanf("%d%d%d", &op, &l, &r) != 3)
                return 1;
            if (op == 1) {
                int x;
                if (scanf("%d", &x) != 1)
                    return 1;
                update(1, l, r, x);
            } else if (op == 2) {
                sq(1, l, r);
            } else {
                printf("%lld\n", query(1, l, r));
            }
        }
    }
    return 0;
}