hdu - 4578(线段树)
题目:Yuanfang is puzzled with the question below:
There are n integers, a1, a2, …, an. The initial values of them are 0. There are four kinds of operations.
Operation 1: Add c to each number between $a_x$ and $a_y$ inclusive. In other words, do the transformation $a_k \leftarrow a_k + c$ for $k = x, x+1, \dots, y$.
Operation 2: Multiply each number between $a_x$ and $a_y$ inclusive by c. In other words, do the transformation $a_k \leftarrow a_k \times c$ for $k = x, x+1, \dots, y$.
Operation 3: Change the numbers between $a_x$ and $a_y$ inclusive to c. In other words, do the transformation $a_k \leftarrow c$ for $k = x, x+1, \dots, y$.
Operation 4: Get the sum of the p-th powers of the numbers between $a_x$ and $a_y$ inclusive. In other words, get the result of $a_x^p + a_{x+1}^p + \dots + a_y^p$.
Yuanfang has no idea of how to do it. So he wants to ask you to help him.
Input
There are no more than 10 test cases.
For each case, the first line contains two numbers n and m, meaning that there are n integers and m operations. 1 <= n, m <= 100,000.
Each of the following m lines contains an operation. Operations 1 to 3 are in the format "1 x y c", "2 x y c" or "3 x y c". Operation 4 is in the format "4 x y p". ($1 \le x \le y \le n$, $1 \le c \le 10{,}000$, $1 \le p \le 3$)
The input ends with 0 0.
Output
For each operation 4, output a single integer in one line representing the result. The answer may be quite large. You just need to calculate the remainder of the answer when divided by 10007.
事先声明:tag_1 为加法标记,tag_2 为乘法标记(值为 1 表示没有乘法),tag_3 为赋值标记,表示区间内所有数都变为 tag_3(值为 0 表示没有赋值)。
思路:每个结点同时维护区间的一次方和、平方和与立方和。
公式: 对一个区间都加上一个常数c
已知区间内有 $n$ 个数 $x_1, x_2, \dots, x_n$,记
$sum_1 = x_1 + x_2 + \dots + x_n$
$sum_2 = x_1^2 + x_2^2 + \dots + x_n^2$
$sum_3 = x_1^3 + x_2^3 + \dots + x_n^3$
每个数都加上 $c$ 之后:
一次方和: $sum_1' = sum_1 + n c$
二次方和: $sum_2' = sum_2 + 2 c \cdot sum_1 + n c^2$ (由 $(x_i + c)^2$ 展开求和可得)
三次方和: $sum_3' = sum_3 + 3 c \cdot sum_2 + 3 c^2 \cdot sum_1 + n c^3$ (由 $(x_i + c)^3$ 展开求和可得)
注意事项:区间乘以 c 时(即 tag_2 ≠ 1),若已有加法标记 tag_1,则它也要随之乘 c,即 tag_1 *= c;
区间赋值(tag_3)时,将 tag_1 置为 0、tag_2 置为 1,即清除之前的加法和乘法标记。下传标记时按"赋值 → 乘法 → 加法"的顺序处理。
代码:
#define _CRT_SECURE_NO_WARNINGS 1
#include<algorithm>
#include<fstream>
#include<iostream>
#include<cstdio>
#include<deque>
#include<string>
#include<stdlib.h>
#include<string.h>
#include<math.h>
#include<vector>
#include<stack>
#include<queue>
#include<map>
#include<set>
#include<bitset>
#include<unordered_map>
using namespace std;
#define INF 0x3f3f3f3f
#define MAXN 310000
#define N 200010          // the node array below is sized 4 * N
#define M 10007           // modulus required by the problem statement
#define endl '\n'
// Floating-point epsilon. Named EPS rather than `exp` so it cannot replace
// the standard exp() function for all code that follows.
#define EPS 1e-8
// Child indices of segment-tree node p. Fully parenthesized so expressions
// such as `lc + 1` or `rc * 2` group as intended.
#define lc (p << 1)
#define rc (p << 1 | 1)
#define lowbit(x) ((x)&-(x))
const double pi = acos(-1.0);
typedef long long LL;
typedef unsigned long long ULL;
// Fast reader for one integer from stdin. Any '-' seen while skipping
// non-digit characters marks the value as negative; a negative value is
// returned in two's-complement form (i.e. 2^64 - |value|).
// Returns 0 if end of input is reached before any digit is found.
inline ULL read() {
    ULL x = 0;
    bool negative = false;
    // int, not char: getchar() returns EOF (-1), which must stay
    // distinguishable from every byte value.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-')
            negative = true;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ULL)(ch - '0');
        ch = getchar();
    }
    return negative ? (ULL)0 - x : x;
}
// Write x in decimal to stdout, without a trailing newline.
void print(ULL x) {
    char digits[20];  // 2^64 - 1 has exactly 20 decimal digits
    int count = 0;
    do {
        digits[count++] = (char)('0' + x % 10);
        x /= 10;
    } while (x != 0);
    while (count > 0)
        putchar(digits[--count]);
}
// One segment-tree node: the index range it covers, the power sums of its
// elements (mod M), and its pending lazy tags. The tags are applied to the
// children in the order assignment, then multiplication, then addition.
struct node
{
    LL l, r;                // covered index range [l, r]
    LL sum_1, sum_2, sum_3; // sum of a_k, a_k^2, a_k^3 over [l, r], mod M
    LL tag_1;               // pending addition (0 = none)
    LL tag_2;               // pending multiplication (1 = none)
    LL tag_3;               // pending assignment (0 = none)
} tr[4 * N];
// Current test case: n elements and m operations. cnt is not referenced
// elsewhere in this file.
int n, m, cnt;
void pushup(int p)
{
tr[p].sum_1 = (tr[lc].sum_1 + tr[rc].sum_1)%M;
tr[p].sum_2 = (tr[lc].sum_2 + tr[rc].sum_2)%M;
tr[p].sum_3 = (tr[lc].sum_3 + tr[rc].sum_3)%M;
}
void pushdown(int p)
{
if (tr[p].tag_3)
{
tr[lc].sum_1 = ((tr[lc].r - tr[lc].l + 1) * tr[p].tag_3)%M;
tr[lc].sum_2 = ((tr[lc].r - tr[lc].l + 1) * tr[p].tag_3%M * tr[p].tag_3)%M;
tr[lc].sum_3 = ((tr[lc].r - tr[lc].l + 1) * tr[p].tag_3%M * tr[p].tag_3%M *tr[p].tag_3 )% M;
tr[rc].sum_1 = ((tr[rc].r - tr[rc].l + 1) * tr[p].tag_3 )% M;
tr[rc].sum_2 = ((tr[rc].r - tr[rc].l + 1) * tr[p].tag_3%M * tr[p].tag_3) % M;
tr[rc].sum_3 = ((tr[rc].r - tr[rc].l + 1) * tr[p].tag_3%M * tr[p].tag_3%M * tr[p].tag_3%M ) % M;
tr[lc].tag_3 = tr[p].tag_3;
tr[rc].tag_3 = tr[p].tag_3;
tr[lc].tag_2 = 1, tr[lc].tag_1 = 0;
tr[rc].tag_2 = 1, tr[rc].tag_1 = 0;
tr[p].tag_3 = 0;
}
if (tr[p].tag_2!=1)
{
tr[lc].sum_1 = (tr[lc].sum_1 * (tr[p].tag_2 % M)) % M;
tr[lc].sum_2 = (tr[lc].sum_2 * (tr[p].tag_2 % M) * (tr[p].tag_2 % M)) % M;
tr[lc].sum_3 = ((tr[lc].sum_3 * (tr[p].tag_2 % M)) * (tr[p].tag_2 % M) * (tr[p].tag_2 % M)) % M;
tr[rc].sum_1 = (tr[rc].sum_1 * (tr[p].tag_2 % M)) % M;
tr[rc].sum_2 = (tr[rc].sum_2 * (tr[p].tag_2 % M) * (tr[p].tag_2 % M)) % M;
tr[rc].sum_3 = ((tr[rc].sum_3 * (tr[p].tag_2 % M)) * (tr[p].tag_2 % M) * (tr[p].tag_2 % M)) % M;
tr[lc].tag_2 = tr[lc].tag_2 * tr[p].tag_2 % M;
tr[lc].tag_1 = tr[lc].tag_1 * tr[p].tag_2 % M;
tr[rc].tag_2 = tr[rc].tag_2 * tr[p].tag_2 % M;
tr[rc].tag_1 = tr[rc].tag_1 * tr[p].tag_2 % M;
tr[p].tag_2 = 1;
}
if (tr[p].tag_1)
{
LL a = tr[lc].sum_1, b = tr[lc].sum_2, c = tr[lc].sum_3;
tr[lc].sum_1 = (tr[lc].sum_1 + (tr[lc].r - tr[lc].l + 1) * tr[p].tag_1) % M;
tr[lc].sum_2 = (tr[lc].sum_2 + ((tr[lc].r - tr[lc].l + 1) * tr[p].tag_1 * tr[p].tag_1 + 2 * tr[p].tag_1 * a)) % M;
tr[lc].sum_3 = (tr[lc].sum_3 + 3 * tr[p].tag_1 * b + 3 * tr[p].tag_1 % M * tr[p].tag_1 % M * a % M + (tr[lc].r - tr[lc].l + 1) * tr[p].tag_1 % M * tr[p].tag_1 % M * tr[p].tag_1 % M) % M;
a = tr[rc].sum_1, b = tr[rc].sum_2, c = tr[rc].sum_3;
tr[rc].sum_1 = (tr[rc].sum_1 + (tr[rc].r - tr[rc].l + 1) * tr[p].tag_1) % M;
tr[rc].sum_2 = (tr[rc].sum_2 + ((tr[rc].r - tr[rc].l + 1) * tr[p].tag_1 % M * tr[p].tag_1 % M + 2 * tr[p].tag_1 * a)) % M;
tr[rc].sum_3 = (tr[rc].sum_3 + 3 * tr[p].tag_1 * b % M + 3 * tr[p].tag_1 % M * tr[p].tag_1 % M * a % M + (tr[rc].r - tr[rc].l + 1) * tr[p].tag_1 % M * tr[p].tag_1 % M * tr[p].tag_1 % M) % M;
tr[lc].tag_1 += tr[p].tag_1;
tr[rc].tag_1 += tr[p].tag_1;
tr[p].tag_1 = 0;
}
}
void build(int p, int l, int r)
{
tr[p] = { l,r,0,0,0,0,1,0};
if (l == r)return;
int m = l + r >> 1;
build(lc, l, m);
build(rc, m + 1, r);
pushup(p);
}
void update(int p, int x, int y, int c,LL k)
{
if (c == 1)
{
if (x <= tr[p].l && tr[p].r <= y)
{
LL a = tr[p].sum_1, b = tr[p].sum_2;
tr[p].sum_1 = (tr[p].sum_1 + (tr[p].r - tr[p].l + 1) * k) % M;
tr[p].sum_2 = (tr[p].sum_2 + ((tr[p].r - tr[p].l + 1) * k%M * k%M + 2 * k%M * a%M)) % M;
tr[p].sum_3 = (tr[p].sum_3 + 3 * k%M * b%M + 3 * k%M * k%M * a %M+ (tr[p].r - tr[p].l + 1) * k%M * k%M * k%M) % M;
tr[p].tag_1 = (tr[p].tag_1+k)%M;
return;
}
}
else if (c == 2)
{
if (x <= tr[p].l && tr[p].r <= y)
{
tr[p].sum_1 = (tr[p].sum_1 * k) % M;
tr[p].sum_2 = (tr[p].sum_2 * k % M * k % M);
tr[p].sum_3 = ((tr[p].sum_3 * k % M) * k % M * k % M);
tr[p].tag_2 = tr[p].tag_2 * k % M;
tr[p].tag_1 = tr[p].tag_1 * k % M;
return;
}
}
else
{
if (x <= tr[p].l && tr[p].r <= y)
{
tr[p].sum_1 = (tr[p].r - tr[p].l + 1) * k % M;
tr[p].sum_2 = ((tr[p].r - tr[p].l + 1) * k%M) % M * k % M;
tr[p].sum_3 = ((tr[p].r - tr[p].l + 1) * k%M) * ((k * k%M)) % M;
tr[p].tag_3 = k;
tr[p].tag_1 = 0;
tr[p].tag_2 = 1;
return;
}
}
pushdown(p);
int m = tr[p].l + tr[p].r >> 1;
if (x <= m)update(lc, x, y, c,k);
if (y > m)update(rc, x, y, c, k);
pushup(p);
}
// Return (sum over k in [x, y] of a_k^c) mod M, restricted to the subtree at
// p. c == 1 selects sum_1, c == 2 selects sum_2, anything else sum_3.
LL query(int p, int x, int y, int c)
{
    // Node fully inside the query range: answer from the cached sum.
    if (x <= tr[p].l && tr[p].r <= y) {
        switch (c) {
        case 1:
            return tr[p].sum_1;
        case 2:
            return tr[p].sum_2;
        default:
            return tr[p].sum_3;
        }
    }
    pushdown(p);
    int mid = (tr[p].l + tr[p].r) >> 1;
    LL from_left = (x <= mid) ? query(lc, x, y, c) : 0;
    LL from_right = (y > mid) ? query(rc, x, y, c) : 0;
    return (from_left + from_right) % M;
}
// Process test cases until the "0 0" terminator. Operations 1-3 update the
// tree; any other opcode (4) prints the requested power sum mod M.
int main()
{
    // Compare against 2: scanf returns EOF (-1, truthy) at end of input, and
    // the old "&& n" test would then loop forever on the previous n.
    while (scanf("%d%d", &n, &m) == 2 && n)
    {
        build(1, 1, n);
        while (m--)
        {
            int op, x, y, c;
            if (scanf("%d%d%d%d", &op, &x, &y, &c) != 4)
                return 0;  // truncated input: nothing more to process
            if (op >= 1 && op <= 3)
                update(1, x, y, op, c % M);
            else
                printf("%lld\n", query(1, x, y, c));
        }
    }
    return 0;
}