洛谷 P5608 [Ynoi2013] 文化课

Chimuzu 手上有一个数字序列 \(\{a_{1},a_{2},\ldots,a_{n}\}\) 和一个运算符序列 \(\{p_{1},p_{2},\ldots,p_{n-1}\}\)。其中 \(p_{i}\) 只能为 \(+\)\(\times\)

我们定义一个区间 \([l,r]\) 的权值 \(w(l,r)\) 为将字符串

\[a_{l}~p_{l}~a_{l+1}~p_{l+1} \cdots a_{r-1}~p_{r-1}~a_{r} \]

写下来之后,按照运算符的优先级计算出的结果。

下面给出一个运算的例子:

\(a=\{1,3,5,7,9,11\}\)\(p=\{+,\times,\times,+,\times\}\),那么有:

\[w(1,6)=1+3\times 5\times 7+9\times 11=205 \]

\[w(3,6)=5\times 7+9\times 11=134 \]

Chimuzu 需要你对这两个序列进行修改,同时查询某个给定区间的权值。

你需要维护这 \(3\) 个操作:

操作一:给定 \(l,r,x\),将\(a_{l},a_{l+1},\ldots,a_{r}\) 全部修改成 \(x\)

操作二:给定 \(l,r,y\):将 \(p_{l},p_{l+1},\ldots,p_{r}\) 全部修改成 \(y\)\(0\) 表示 \(+\)\(1\) 表示 \(\times\)

操作三:给定 \(l,r\):查询 \(w(l,r) \bmod 1000000007\) 的值。

\(n,m\leq 100000\)\(1\leq a_{i},x\lt 2^{32}\)\(a_{i},x\) 均为奇数,\(p_{i},y\in\{0,1\}\)\(1\leq l\leq r\leq n\),所有操作 \(2\) 还满足 \(r\lt n\)。时限为 3s


首先我们发现这个东西非常可以线段树维护,考虑合并两个区间 \([l,mid],[mid+1,r]\) ,如果是加法就直接加起来;如果是乘法,收到乘法影响的只有左区间的一段后缀和右区间的一段前缀乘积,把这两段乘起来再加上剩下的贡献就是这个区间的答案。

然后考虑区间修改符号,发现这个区间变成了区间和或者区间乘积,那么多维护一下区间和、乘积就可以了。

最后是区间修改数,改完数之后这段区间的数都一样了,所以新的区间变成了若干个连乘段的和。而我们注意到 \(1+2+3+\dots\sqrt{n}=n\) ,所以本质不同的连乘段只有 \(\sqrt{len}\) 个,于是我们记录每种长度为 \(d_i\) 段的个数 \(ds_i\) ,就可以直接计算 \(\sum_{i=1}^{cnt}ds_ix^{d_i}\) 来更新答案,这样子复杂度是 \(O(n\sqrt{n})\) 的。

然后你会发现,要计算这个式子需要快速幂,而如果直接快速幂的话一定会使复杂度多一个 \(\log\) ,而这题光速幂的空间开不下,所以我们考虑对 \(d\) 排好序之后 \(x^i\)\(i-1\) 快速幂过来,这样子复杂度会去掉这个 \(\log\) ,我们来证明一下(这里分析上界)。

现在我们有这个方程组:

\[\begin{cases}\sum_{i=1}^{cnt}x_i=n\\\sum_{i=1}^{cnt}\log(x_i-x_{i-1})\le O(\sqrt{n})\end{cases} \]

\(C_i=x_i-x_{i-1}\) ,那么变成:

\[\begin{cases}\sum_{i=1}^{cnt}C_i\times i=n\\\sum_{i=1}^{cnt}\log(C_i)\le O(\sqrt{n})\end{cases} \]

然后有:

\[\begin{cases}\sum_{i=1}^{cnt}C_i\times i=n\\log(\prod_{i=1}^{cnt}C_i)\le O(\sqrt{n})\end{cases} \]

\(\sum_{i=1}^{cnt}C_i\times i\)取均值不等式有:

\[cnt!\prod_{i=1}^{cnt}C_i\le (\frac{\sum_{i=1}^{cnt}C_i\times i}{cnt})^{cnt}=(\frac{n}{cnt})^{cnt} \]

两边取 \(\log\) 有:

\[\log(\prod_{i=1}^{cnt}C_i)\le cnt\log n -cnt\log{cnt}-\log(cnt!) \]

因为斯特林近似有 \(n!\approx\sqrt{2\pi n}(\frac{n}{e})^n\) ,于是 \(\log(cnt!)=\frac{1}{2}\log {2\pi}+\frac{1}{2}\log{cnt}+cnt\log{cnt}-cnt\) ,带回原式得:

\[\log(\prod_{i=1}^{cnt}C_i)\le cnt\log n-\frac{1}{2}\log {2\pi}-\frac{1}{2}\log{cnt}-2cnt\log{cnt}+cnt \]

最大值也就是右式的最大值,舍掉部分小量和常数,那么就有一个 \(y\) 关于 \(cnt\) 的方程:

\[y=cnt\log n-2cnt\log{cnt}+cnt \]

求导之后可以得到最大值是 \(\le \sqrt{n}\) 的,于是上面的结论得证。

而至于有序,合并两个区间的时候归并就可以了。

Code

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
const int N = 1e5;
const int M = 317;
const int p = 1e9 + 7;
using namespace std;
int a[N + 5],opt[N + 5],n,m,t[N + 5];
vector <int>::iterator it;
struct node
{
    int sm,ans,mul,ls,rs,ll,rr,opt,lx,rx;
    vector <int> d,ds;
    void init(int x,int op)
    {
        sm = ans = mul = ls = rs = lx = rx = x;
        opt = op;
        ll = rr = 1;
    }
};
int mypow(int a,int x){int s = 1;for (;x;x & 1 ? s = 1ll * s * a % p : 0,a = 1ll * a * a % p,x >>= 1);return s;}
struct Seg
{
    node s[N * 4 + 5];
    int tag[N * 4 + 5],st[N * 4 + 5];
    #define zrt k << 1
    #define yrt k << 1 | 1
    void upd(node &s,node x,node y,int l,int mid,int r)
    {
        s.ll = x.ll;
        s.rr = y.rr;
        s.ls = x.ls;
        s.rs = y.rs;
        s.rx = y.rx;
        s.lx = x.lx;
        s.opt = y.opt;
        s.d.clear();
        s.ds.clear();
        if (x.opt)
        {
            s.ans = ((1ll * x.rs * y.ls % p + (x.ans - x.rs) % p) % p + (y.ans - y.ls) % p) % p;
            if (x.ll == mid - l + 1)
                s.ls = 1ll * x.ls * y.ls % p,s.ll += y.ll;
            if (y.rr == r - mid)
                s.rs = 1ll * y.rs * x.rs % p,s.rr += x.rr;
            s.sm = (x.sm + y.sm) % p;
            s.mul = 1ll * x.mul * y.mul % p;
        }
        else
        {
            s.ans = (x.ans + y.ans) % p;
            s.sm = (x.sm + y.sm) % p;
            s.mul = 1ll * x.mul * y.mul % p;
        }
        int j = 0,vis = (!x.opt),now = x.rr + y.ll;
        for (int i = 0;i < x.d.size();i++)
        {
            if (x.rr == x.d[i] && x.opt)
            {
                x.ds[i]--;
                if (!x.ds[i])
                    continue;
            }
            while (j < y.d.size())
            {
                if (y.d[j] > x.d[i])
                    break;
                if (y.ll == y.d[j] && x.opt)
                {
                    y.ds[j]--;
                    if (!y.ds[j])
                    {
                        j++;
                        continue;
                    }
                }
                if (!vis && now <= y.d[j])
                {
                    if (!s.d.size())
                        s.d.push_back(now),s.ds.push_back(1);
                    else
                        if (s.d[s.d.size() - 1] == now)
                            s.ds[s.d.size() - 1]++;
                        else
                            s.d.push_back(now),s.ds.push_back(1);
                    vis = 1;
                }
                if (!s.d.size())
                    s.d.push_back(y.d[j]),s.ds.push_back(y.ds[j]);
                else
                    if (s.d[s.d.size() - 1] == y.d[j])
                        s.ds[s.d.size() - 1] += y.ds[j];
                    else
                        s.d.push_back(y.d[j]),s.ds.push_back(y.ds[j]);
                j++;
            }
            if (!vis && now <= x.d[i])
            {
                if (!s.d.size())
                    s.d.push_back(now),s.ds.push_back(1);
                else
                    if (s.d[s.d.size() - 1] == now)
                        s.ds[s.d.size() - 1]++;
                    else
                        s.d.push_back(now),s.ds.push_back(1);
                vis = 1;
            }
            if (!s.d.size())
            {
                s.d.push_back(x.d[i]);
                s.ds.push_back(x.ds[i]);
            }
            else
                if (s.d[s.d.size() - 1] == x.d[i])
                    s.ds[s.d.size() - 1] += x.ds[i];
                else
                    s.d.push_back(x.d[i]),s.ds.push_back(x.ds[i]);
        }
        while (j < y.d.size())
        {
            if (y.ll == y.d[j] && x.opt)
            {
                y.ds[j]--;
                if (!y.ds[j])
                {
                    j++;
                    continue;
                }
            }
            if (!vis && now <= y.d[j])
            {
                if (!s.d.size())
                    s.d.push_back(now),s.ds.push_back(1);
                else
                    if (s.d[s.d.size() - 1] == now)
                        s.ds[s.d.size() - 1]++;
                    else
                        s.d.push_back(now),s.ds.push_back(1);
                vis = 1;
            }
            if (!s.d.size())
                s.d.push_back(y.d[j]),s.ds.push_back(y.ds[j]);
            else
                if (s.d[s.d.size() - 1] == y.d[j])
                    s.ds[s.d.size() - 1] += y.ds[j];
                else
                    s.d.push_back(y.d[j]),s.ds.push_back(y.ds[j]);
            j++;
        }
        if (!vis)
            s.d.push_back(now),s.ds.push_back(1);
    }
    void upd2(node &s,node x,node y,int l,int mid,int r)
    {
        s.ll = x.ll;
        s.rr = y.rr;
        s.ls = x.ls;
        s.rs = y.rs;
        s.rx = y.rx;
        s.lx = x.lx;
        s.opt = y.opt;
        if (x.opt)
        {
            s.ans = ((1ll * x.rs * y.ls % p + (x.ans - x.rs) % p) % p + (y.ans - y.ls) % p) % p;
            if (x.ll == mid - l + 1)
                s.ls = 1ll * x.ls * y.ls % p,s.ll += y.ll;
            if (y.rr == r - mid)
                s.rs = 1ll * y.rs * x.rs % p,s.rr += x.rr;
            s.sm = (x.sm + y.sm) % p;
            s.mul = 1ll * x.mul * y.mul % p;
        }
        else
        {
            s.ans = (x.ans + y.ans) % p;
            s.sm = (x.sm + y.sm) % p;
            s.mul = 1ll * x.mul * y.mul % p;
        }
    }
    node upd1(node x,node y,int l,int mid,int r)
    {
        if (x.opt)
        {
            x.ans = ((1ll * x.rs * y.ls % p + (x.ans - x.rs) % p) % p + (y.ans - y.ls) % p) % p;
            if (x.ll == mid - l + 1)
                x.ls = 1ll * x.ls * y.ls % p,x.ll += y.ll;
            if (y.rr == r - mid)
                x.rs = 1ll * y.rs * x.rs % p,x.rr += y.rr;
            else
                x.rs = y.rs,x.rr = y.rr;
            x.sm = (x.sm + y.sm) % p;
            x.mul = 1ll * x.mul * y.mul % p;
        }
        else
        {
            x.ans = (x.ans + y.ans) % p;
            x.sm = (x.sm + y.sm) % p;
            x.mul = 1ll * x.mul * y.mul % p;
            x.rs = y.rs;
            x.rr = y.rr;
        }
        x.rx = y.rx;
        x.opt = y.opt;
        return x;
    }
    void build(int k,int l,int r)
    {
        tag[k] = st[k] = -1;
        if (l == r)
        {
            s[k].init(a[l],opt[l]);
            s[k].d.push_back(1);
            s[k].ds.push_back(1);
            return;
        }
        int mid = l + r >> 1;
        build(zrt,l,mid);
        build(yrt,mid + 1,r);
        upd(s[k],s[zrt],s[yrt],l,mid,r);
    }
    void cha(int k,int l,int r,int z)
    {
        if (z == 0)
        {
            s[k].ans = s[k].sm;
            s[k].ls = s[k].lx;
            s[k].rs = s[k].rx;
            s[k].ll = s[k].rr = 1;
            s[k].d.clear();
            s[k].ds.clear();
            s[k].d.push_back(1);
            s[k].ds.push_back(r - l + 1);
        }
        else
        {
            s[k].ans = s[k].mul;
            s[k].ls = s[k].rs = s[k].mul;
            s[k].ll = s[k].rr = r - l + 1;
            s[k].d.clear();
            s[k].ds.clear();
            s[k].d.push_back(r - l + 1);
            s[k].ds.push_back(1);
        }
        tag[k] = z;
        s[k].opt = z;
    }
    void chan(int k,int l,int r,int z)
    {
        s[k].lx = s[k].rx = z;
        st[k] = z;
        int res = mypow(z,s[k].d[0]);
        s[k].ans = 1ll * s[k].ds[0] * res % p;
        for (int i = 1;i < s[k].d.size();i++)
        {
            res = 1ll * res * mypow(z,s[k].d[i] - s[k].d[i - 1]) % p;
            s[k].ans = (s[k].ans + 1ll * s[k].ds[i] * res % p) % p;
        }
        s[k].ls = mypow(z,s[k].ll);
        s[k].rs = mypow(z,s[k].rr);
        s[k].mul = mypow(z,r - l + 1);
        s[k].sm = 1ll * z * (r - l + 1) % p;
    }
    void pushdown(int k,int l,int r,int mid)
    {
        if (st[k] != -1)
        {
            chan(zrt,l,mid,st[k]);
            chan(yrt,mid + 1,r,st[k]);
            st[k] = -1;
        }
        if (tag[k] != -1)
        {
            cha(zrt,l,mid,tag[k]);
            cha(yrt,mid + 1,r,tag[k]);
            tag[k] = -1;
        }
    }    
    node query(int k,int l,int r,int x,int y)
    {
        if (l >= x && r <= y)
            return s[k];
        int mid = l + r >> 1;
        pushdown(k,l,r,mid);
        if (x > mid)
            return query(yrt,mid + 1,r,x,y);
        if (y <= mid)
            return query(zrt,l,mid,x,y);
        return upd1(query(zrt,l,mid,x,y),query(yrt,mid + 1,r,x,y),l,mid,r);
    }
    void modify1(int k,int l,int r,int x,int y,int z)
    {
        if (l >= x && r <= y)
        {
            cha(k,l,r,z);
            return;
        }
        int mid = l + r >> 1;
        pushdown(k,l,r,mid);
        if (x <= mid)
            modify1(zrt,l,mid,x,y,z);
        if (y > mid)
            modify1(yrt,mid + 1,r,x,y,z);
        upd(s[k],s[zrt],s[yrt],l,mid,r);
    }
    void modify2(int k,int l,int r,int x,int y,int z)
    {
        if (l >= x && r <= y)
        {
            chan(k,l,r,z);
            return;
        }
        int mid = l + r >> 1;
        pushdown(k,l,r,mid);
        if (x <= mid)
            modify2(zrt,l,mid,x,y,z);
        if (y > mid)
            modify2(yrt,mid + 1,r,x,y,z);
        upd2(s[k],s[zrt],s[yrt],l,mid,r);
    }
}tree;
inline long long read()
{
	long long X(0);int w(0);char ch(0);
	while (!isdigit(ch)) w |= ch == '-',ch = getchar();
	while (isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48),ch = getchar();
	return w ? -X : X;
}
int main()
{
    n = read();m = read();
    long long x;
    for (int i = 1;i <= n;i++)
    {
        x = read();
        a[i] = x % p;
    }
    for (int i = 1;i < n;i++)
        opt[i] = read();
    tree.build(1,1,n);
    int op,l,r;
    node ans;
    while (m--)
    {
        op = read();l = read();r = read();
        if (op == 3)
        {
            ans = tree.query(1,1,n,l,r);
            printf("%d\n",(ans.ans + p) % p);
        }
        else
            if (op == 2)
            {
                x = read();
                tree.modify1(1,1,n,l,r,x);
            }
            else
            {
                x = read();
                x %= p;
                tree.modify2(1,1,n,l,r,x);
            }
    }
    return 0;
}
posted @ 2021-01-13 18:57  eee_hoho  阅读(102)  评论(0编辑  收藏  举报