P4927 [1007]梦美与线段树

毒瘤题!!!

这道题刷新了我对线段树懒标记维护的新认识。

显然可以得到我们要的\(p\)是线段树中所有点权值的平方和,\(q\)是原数字的和。

现在我们引入一些要用到的值:

  • val数组。表示一个点的权值。

  • ssum数组。表示该子树的所有节点权值的平方和。显然\(p=ssum[1]\)

  • len数组。表示一个节点的区间长度。

  • len_2数组。表示该子树的所有节点的区间长度平方和。

  • llen数组。表示该子树的所有节点的区间长度和。

  • lx数组。表示该子树的所有节点的区间长度 乘以 权值的和。

建树的时候这些标记都能比较简单地被维护,所以我们考虑如何维护区间加。

然后考虑如何维护我们想要的子树区间平方和。

我们考虑一个特殊情况:线段树只有三个节点。

设更新值为\(x\)\(val\)\(len\)分别为一个节点的权值和区间长度。

那么对于这么一个小树,他的子树权值平方和为\((val + len \times x)^2 + (val_l + len_l \times x)^2 + (val_r + len_r \times x)^2\)

化简为\(val^2+val_l^2+val_r^2+2x(len \times val + len_l \times val_l + len_r \times val_r) + (len^2 + len_l^2 + len_r^2)x^2\)

其实可以类比到一般情况。我们可以维护\(len \times val + len_l \times val_l + len_r \times val_r\),即子树区间长度乘以权值之和,而后面那一项就是子树区间长度的平方和。都可以建树的时候弄出来。

如何维护子树区间长度乘以权值之和?

对于单个节点,设增加\(x\),区间长度为\(len\)。那么新增的答案是\(len(val + len \times x) - len \times val = len^2x\)

对于子树也是差不多的道理,将里面单节点的len换成维护子树的即可。

然后注意:维护一个点的权值跟维护一个子树的权值完全不同!!!

而对于维护ssum数组,上面的那个公式中的所有东西都应该是子树角度的。

时间不够了。明天再填坑。

代码:(代码中的sum没有用到)

#include<cstdio>

#define ll long long
#define int128 __int128
const int maxn = 100005;
const int mod = 998244353;

#define lson (root << 1)
#define rson (root << 1 | 1)
int128 val[maxn << 2], sum[maxn << 2], ssum[maxn << 2], lazy[maxn << 2], len[maxn << 2], len_2[maxn << 2], ls[maxn << 2];
int128 llen[maxn << 2];
ll a[maxn], n, m;
ll read()
{
    ll ans = 0, s = 1;
    char ch = getchar();
    while(ch > '9' || ch < '0'){ if(ch == '-') s = -1; ch = getchar(); }
    while(ch >= '0' && ch <= '9') ans = (ans << 3) + (ans << 1) + ch - '0', ch = getchar();
    return s * ans;
}
void pushup(int root)
{
    val[root] = val[lson] + val[rson];// 节点权值 
    sum[root] = 2 * (sum[lson] + sum[rson]);// 子树sum之和 
    ssum[root] = val[root] * val[root] + ssum[lson] + ssum[rson];// 子树sum平方和 
    ls[root] = len[root] * val[root] + ls[lson] + ls[rson];// 子树len * sum之和 
    llen[root] = len[root] + llen[lson] + llen[rson];// 子树size和(不变) 
    len_2[root] = len[root] * len[root] + len_2[lson] + len_2[rson];// 子树size平方和(不变) 
}
void pushdown(int root, int l, int r)
{
    if(lazy[root] != 0)
    {
        int128 &x = lazy[root];
        ssum[lson] += 2 * ls[lson] * x + len_2[lson] * x * x;
        ls[lson] += len_2[lson] * x;
        sum[lson] += llen[lson] * x;
        val[lson] += len[lson] * x;
        lazy[lson] += x;
        
        ssum[rson] += 2 * ls[rson] * x + len_2[rson] * x * x;
        ls[rson] += len_2[rson] * x;
        sum[rson] += llen[rson] * x;
        val[rson] += len[rson] * x;
        lazy[rson] += x;
        x = 0;
    }
}
void build(int root, int l, int r)
{
    if(l == r)
    {
        val[root] = sum[root] = a[l];
        ssum[root] = a[l] * a[l];
        len[root] = len_2[root] = 1;
        ls[root] = len[root] * sum[root];
        llen[root] = 1;
    }
    else
    {
        int mid = (l + r) >> 1;
        len[root] = r - l + 1;
        build(lson, l, mid);
        build(rson, mid + 1, r);
        pushup(root);
    }
}
void update(int root, int l, int r, int x, int y, ll k)
{
    if(r < x || y < l) return;
    if(x <= l && r <= y)
    {
        ssum[root] += 2 * ls[root] * k + len_2[root] * k * k;
        ls[root] += len_2[root] * k;
        sum[root] += llen[root] * k;
        val[root] += len[root] * k;
        lazy[root] += k;
        return;
    }
    pushdown(root, l, r);
    int mid = (l + r) >> 1;
    update(lson, l, mid, x, y, k);
    update(rson, mid + 1, r, x, y, k);
    pushup(root);
}
int128 gcd(int128 x, int128 y)
{
    return y == 0 ? x : gcd(y, x % y);
}
int128 pow_mod(int128 x, int128 y, int128 z)
{
    int128 ans = 1, base = x;
    while(y)
    {
        if(y & 1) ans = ans * base % z;
        base = base * base % z;
        y >>= 1;
    }
    return ans % z;
}
void write(int128 x)
{
    if(x >= 10) write(x / 10);
    putchar(x % 10 + '0');
}
int main()
{
    n = read(), m = read();
    for(int i = 1; i <= n; i++) a[i] = read();
    build(1, 1, n);
    while(m--)
    {
        int opt = read();
        if(opt == 1)
        {
            ll l = read(), r = read(), v = read();
            update(1, 1, n, l, r, v);
        }
        else if(opt == 2)
        {
            int128 p = ssum[1];
            int128 q = val[1];
            //write(p); putchar('\n');
            //write(q); putchar('\n');
            while(q % mod == 0) q /= mod, p /= mod;
            p %= mod, q %= mod;
            //write(p); putchar('\n');
            //write(q); putchar('\n');
            int128 ans = p * pow_mod(q, mod - 2, mod) % mod;
            printf("%lld\n", (ll)ans);
        }
    }
    return 0;
}
posted @ 2018-10-11 22:03  Garen-Wang  阅读(167)  评论(0编辑  收藏  举报