hdu 4578 Transformation
1 a b c 为 [a,b]内的元素+c ;2 a b c 为 [a,b]内的元素乘以c
3 a b c 为 [a,b]内的元素置c ;4 a b c 为 求[a,b]内的元素的c次方之和。
可以参考这篇博客
维护3个lazy:st(set), add, mul。set操作的时候:把add和mul标记取消(相当于add=0,mul=1)
mul操作的时候,因为元素e+add标记, 要乘以c,所以可以把add标记乘以c,然后元素e乘个c即可。
那么在push_down的时候的操作顺序是:有set->mul->add。
对于更新操作:
首先有2个公式:(a+c)^2=a^2+c^2+2ac;① (a+c)^3=a^3+c^3+3c(a^2+ac)。②
sum1是∑(ai),sum2是∑(ai^2),sum3是∑(ai^3)。
那么我们容易有:
如果是set c操作:直接把标记清零,上面3个sum都是各自长度 len*c
如果是mul c操作:更新标记,然后3个sum[rt]都是 sum[rt]=sum[rt]*c
如果是add c操作:我们可以看上面的①和②两个公式了:
如果是sum1,加c,直接 sum1[rt] += c*len 即可。
如果是sum2,看①公式,原本的sum2[rt]相当于a^2,元素要加c,那么元素添加之后相当于(a+c)^2。 a^2+c^2+2ac,其中a是什么?a是原本的∑(ai),其实就是sum1[rt]。 令a^2=sum2[rt],a=sum1[rt],c=当前的add[rt],带入①,就是新的sum2[rt]。
同理sum3,a^3=sum3[rt],c=add[rt],a^2=sum2[rt],a=sum1[rt],带入②就是新的sum3[rt]了。
因为涉及多次求模,所以我写函数来实现,f1是加法,f2是乘法,f3是a^b,f4是更新sum3或sum2的时候应用公式①和②。
遇到的坑的地方: 由于实在繁琐,建议尽量让程序清晰,即使牺牲一点效率。
二..query的时候,if(L<=m) ans+=query(左孩子), if(R>m) ans+=query(右孩子),之后的ans也要取模别忘了。
#include <stdio.h> #define ll long long #define FOR(i,a,b) for(int i=(a);i<=(b);++i) const int maxN=1e5+5,MOD=1e4+7; int N, M, K, T; #define lson l,m,rt*2 #define rson m+1,r,rt*2+1 ll add[maxN<<2], st[maxN<<2], mul[maxN<<2]; ll sum1[maxN<<2], sum2[maxN<<2], sum3[maxN<<2]; ll f1(ll a, ll b) {a %= MOD, b %= MOD; return (a + b) % MOD;} ll f2(ll a, ll b) {a %= MOD, b %= MOD; return (a * b) % MOD;} ll f3(ll a, ll b) { a %= MOD; ll ans = 1; FOR(i, 1, b) ans = (a * ans) % MOD; return ans; } // return (a+c)^3 ll f4(ll rt, ll c, int len, int op) { ll ans; if (op == 3) { ll a = sum3[rt]; a %= MOD, c %= MOD; ans = f1(a, len * f3(c, 3)) + f2(3 * c, sum2[rt] + sum1[rt] * c); } else { ll a = sum2[rt]; a %= MOD, c %= MOD; ans = a + f2(len, f3(c, 2)) + f2(2 * c, sum1[rt]); } return ans % MOD; } void push_up(int rt) { sum1[rt] = f1(sum1[rt * 2], sum1[rt * 2 + 1]); sum2[rt] = f1(sum2[rt * 2], sum2[rt * 2 + 1]); sum3[rt] = f1(sum3[rt * 2], sum3[rt * 2 + 1]); } void build(int l, int r, int rt) { add[rt] = st[rt] = 0; mul[rt] = 1; if (l == r) { sum1[rt] = sum2[rt] = sum3[rt] = 0; return; } int m = (l + r) / 2; build(lson); build(rson); push_up(rt); } void push_down(int rt, int L) { int lch = rt * 2, rch = lch + 1; if (st[rt]) { st[lch] = st[rch] = st[rt]; add[lch] = add[rch] = 0; mul[lch] = mul[rch] = 1; sum1[lch] = f2((L - L / 2), st[rt]); sum1[rch] = f2(L / 2, st[rt]); sum2[lch] = f2(L - L / 2, f3(st[rt], 2)); sum2[rch] = f2(L / 2, f3(st[rt], 2)); sum3[lch] = f2(L - L / 2, f3(st[rt],3)); sum3[rch] = f2(L / 2, f3(st[rt],3)); st[rt] = 0; } if (mul[rt] != 1) { mul[lch] = f2(mul[lch], mul[rt]); mul[rch] = f2(mul[rch], mul[rt]); if (add[lch]) add[lch] = f2(add[lch], mul[rt]); if (add[rch]) add[rch] = f2(add[rch], mul[rt]); sum1[lch] = f2(sum1[lch], mul[rt]); sum1[rch] = f2(sum1[rch], mul[rt]); sum2[lch] = f2(sum2[lch], f3(mul[rt], 2)); sum2[rch] = f2(sum2[rch], f3(mul[rt], 2)); sum3[lch] = f2(sum3[lch], f3(mul[rt],3)); sum3[rch] = f2(sum3[rch], f3(mul[rt],3)); mul[rt] = 1; } if (add[rt]) { add[lch] += add[rt]; add[rch] += add[rt]; sum3[lch] = f4(lch, add[rt], L - L / 2, 3); sum3[rch] = f4(rch, add[rt], L / 2, 3); sum2[lch] = f4(lch, add[rt], L - L/ 2, 2); sum2[rch] = f4(rch, add[rt], L / 2, 2); sum1[lch] = f1(sum1[lch], f2(L - L / 2, add[rt])); sum1[rch] = f1(sum1[rch], f2(L / 2, add[rt])); add[rt] = 0; } } void update(int L, int R, int c, int ch, int l, int r, int rt) { if (L <= l && r <= R) { if (ch == 3) { // 区间置c st[rt] = c; add[rt] = 0; mul[rt] = 1; sum1[rt] = f2(r - l + 1, c); sum2[rt] = f2(r - l + 1, f3(c, 2)); sum3[rt] = f2(r - l + 1, f3(c, 3)); } else if (ch == 2) { // 区间乘c mul[rt] = f2(mul[rt], c); if (add[rt]) add[rt] = f2(add[rt], c); sum1[rt] = f2(sum1[rt], c); sum2[rt] = f2(sum2[rt], f3(c, 2)); sum3[rt] = f2(sum3[rt], f3(c, 3)); } else if (ch == 1) { // 区间加c add[rt] += c; sum3[rt] = f4(rt, c, r - l + 1, 3); sum2[rt] = f4(rt, c, r - l + 1, 2); sum1[rt] = f1(sum1[rt], (r - l + 1) * c); } return; } push_down(rt, r - l + 1); int m = (l + r) / 2; if (L <= m) update(L, R, c, ch, lson); if (R > m) update(L, R, c, ch, rson); push_up(rt); } ll query(int L, int R, int p, int l, int r, int rt) { if (L <= l && r <= R) { if (p == 1) return sum1[rt] % MOD; if (p == 2) return sum2[rt] % MOD; else return sum3[rt] % MOD; } push_down(rt, r - l + 1); int m = (l + r) / 2; ll ans = 0; if (L <= m) ans += query(L, R, p, lson); if (R > m) ans += query(L, R, p, rson); return ans % MOD; } int main () { #ifndef ONLINE_JUDGE freopen("data.in", "r", stdin); #endif while (~scanf("%d%d", &N, &M) && N + M) { build(1, N, 1); int ch, a, b, c; while (M--) { scanf("%d%d%d%d", &ch, &a, &b, &c); if (ch != 4) update(a, b, c, ch, 1, N, 1); else printf("%lld\n", query(a, b, c, 1, N, 1)); } } return 0; }