HDU 4578 线段树各种区间操作
HDU 4578
【题意】:初始一个长度为n的数组全为0,有m个操作,输入op, l, r, x。
- op = 1时,把 [l, r] 中的所有数加上x
- op = 2时, 把 [l, r] 中的所有树乘上x
- op = 3时, 把[l, r]中的所有数全置为x
- op = 4时, 输出 [l, r] 中所有数的 x 方的和
【思路】:令\(x = a * x + b\),即对一个x,令他乘以a, 再加上b,观察各个和的情况:
\[sum1=∑x⇒∑(a×x+b)=a*sum1+b*length
\]
\[sum2=∑x^2⇒∑(a×x+b)^2=a^2 *sum2+2*a*b*sum1+b^2*ength
\]
\[sum3=∑x^3⇒∑(a×x+b)^3=a^3*sum3+3*a^2*b*sum2+3*a*b^2*sum1+b^3×length
\]
只执行加法时:a = 1, b = x; 只执行乘法时, a = x, b = 0; 一起执行时,\(a = t[cur].mul, b = t[cur].add\).
在置数时直接将add和mul懒标记恢复到初始值即可。
#include <bits/stdc++.h>
#define debug(x) cout << #x << " = " << x << endl;
#define ls cur<<1
#define rs cur<<1|1
using namespace std;
typedef long long LL;
const int maxn = 2e5 + 10;
const int inf = 0x3f3f3f3f;
const int mod = 10007;
const double eps = 1e-6, pi = acos(-1.0);
struct Tree{
int l, r, len;
LL sum1, sum2, sum3;
LL add, mul, st;
}t[maxn<<1];
int n, m;
inline void pushup(int cur){
t[cur].sum1 = (t[ls].sum1 + t[rs].sum1) % mod;
t[cur].sum2 = (t[ls].sum2 + t[rs].sum2) % mod;
t[cur].sum3 = (t[ls].sum3 + t[rs].sum3) % mod;
}
inline void pushdown(int cur){
if(t[cur].st){ //只有置数
t[ls].mul = t[rs].mul = 1;
t[ls].add = t[rs].add = 0;
t[ls].st = t[rs].st = t[cur].st;
t[ls].sum3 = t[cur].st * t[cur].st % mod * t[cur].st % mod * t[ls].len % mod;
t[rs].sum3 = t[cur].st * t[cur].st % mod * t[cur].st % mod * t[rs].len % mod;
t[ls].sum2 = t[cur].st * t[cur].st % mod * t[ls].len % mod;
t[rs].sum2 = t[cur].st * t[cur].st % mod * t[rs].len % mod;
t[ls].sum1 = t[cur].st * t[ls].len % mod;
t[rs].sum1 = t[cur].st * t[rs].len % mod;
t[cur].st = 0;
}
if(t[cur].add != 0 || t[cur].mul != 1){ //加和乘都有
LL a = t[cur].mul, b = t[cur].add;
t[ls].mul = t[cur].mul * t[ls].mul % mod;
t[rs].mul = t[cur].mul * t[rs].mul % mod;
t[ls].add = (t[ls].add * t[cur].mul % mod + t[cur].add) % mod;
t[rs].add = (t[rs].add * t[cur].mul % mod + t[cur].add) % mod;
t[ls].sum3 = (a * a % mod * a % mod * t[ls].sum3 % mod + 3 * a * a % mod * b % mod * t[ls].sum2 % mod + 3 * a * b % mod * b % mod * t[ls].sum1 % mod + b * b % mod * b % mod * t[ls].len % mod) % mod;
t[rs].sum3 = (a * a % mod * a % mod * t[rs].sum3 % mod + 3 * a * a % mod * b % mod * t[rs].sum2 % mod + 3 * a * b % mod * b % mod * t[rs].sum1 % mod + b * b % mod * b % mod * t[rs].len % mod) % mod;
t[ls].sum2 = (a * a % mod * t[ls].sum2 % mod + 2 * a * b % mod * t[ls].sum1 + b * b % mod * t[ls].len % mod) % mod;
t[rs].sum2 = (a * a % mod * t[rs].sum2 % mod + 2 * a * b % mod * t[rs].sum1 + b * b % mod * t[rs].len % mod) % mod;
t[ls].sum1 = (a * t[ls].sum1 % mod + b * t[ls].len % mod) % mod;
t[rs].sum1 = (a * t[rs].sum1 % mod + b * t[rs].len % mod) % mod;
t[cur].add = 0; t[cur].mul = 1;
}
}
void build(int l, int r, int cur){
t[cur].l = l; t[cur].r = r; t[cur].len = r - l + 1;
t[cur].sum1 = t[cur].sum2 = t[cur].sum3 = 0;
t[cur].add = t[cur].st = 0; t[cur].mul = 1;
if(l == r) return;
int mid = l + r >> 1;
build(l, mid, ls);
build(mid + 1, r, rs);
pushup(cur);
}
void change(int l, int r, int op, int x, int cur){
if(l <= t[cur].l && t[cur].r <= r){
LL tmp = x * x % mod * x % mod;
if(op == 1){ //只加
t[cur].add = (t[cur].add + x) % mod;
t[cur].sum3 = (t[cur].sum3 + 3 * x * t[cur].sum2 % mod + 3 * x * x * t[cur].sum1 + tmp * t[cur].len % mod) % mod;
t[cur].sum2 = (t[cur].sum2 + 2 * x * t[cur].sum1 % mod + x * x % mod * t[cur].len);
t[cur].sum1 = (t[cur].sum1 + x * t[cur].len) % mod;
}
else if(op == 2){ //只乘
t[cur].add = (t[cur].add * x) % mod;
t[cur].mul = (t[cur].mul * x) % mod;
t[cur].sum3 = tmp * t[cur].sum3 % mod;
t[cur].sum2 = x * x % mod * t[cur].sum2 % mod;
t[cur].sum1 = x * t[cur].sum1 % mod;
}
else if(op == 3){ //只置数
t[cur].mul = 1;
t[cur].add = 0;
t[cur].st = x;
t[cur].sum3 = tmp * t[cur].len % mod;
t[cur].sum2 = x * x % mod * t[cur].len % mod;
t[cur].sum1 = x * t[cur].len % mod;
}
return;
}
int mid = t[cur].l + t[cur].r >> 1;
pushdown(cur);
if(l <= mid) change(l, r, op, x, ls);
if(mid < r) change(l, r, op, x, rs);
pushup(cur);
}
LL query(int l, int r, int x, int cur){
if(l <= t[cur].l && t[cur].r <= r){
if(x == 1) return t[cur].sum1;
if(x == 2) return t[cur].sum2;
if(x == 3) return t[cur].sum3;
}
int mid = t[cur].l + t[cur].r >> 1;
pushdown(cur);
LL ans = 0;
if(l <= mid) ans = (ans + query(l, r, x, ls)) % mod;
if(mid < r) ans = (ans + query(l, r, x, rs)) % mod;
return ans;
}
int main()
{
while(~scanf("%d %d", &n, &m), n + m){
build(1, n, 1);
while(m--){
int l, r, op, x;
scanf("%d %d %d %d", &op, &l, &r, &x);
if(op != 4) change(l, r, op, x % mod, 1);
else printf("%lld\n", query(l, r, x, 1));
}
}
getchar(); getchar();
}