二维树状数组导论

关于这篇文章

是一篇关于 树状数组导论 提到的二维树状数组的文章

如果不知道树状数组,出门左拐,树状数组导论

关于树状数组,你需要了解的

二维树状数组的定义

首先了解一下二维树状数组解决的是什么问题

二维树状数组和树状数组一样,都是用来维护一个数组内的前缀和

只不过树状数组维护的是一维数组,二维树状数组维护的是二维数组

也就是说二维树状数组支持单点修改和二位前缀和

二维前缀和

什么是二维前缀和

出门左转,OI-Wiki

简单的说,就是从 \(a_{1,1}\)\(a_{x,y}\) 的矩阵所包含的所有元素的和

在这篇文章中,\(pre(x,y)\) 就是二维数组 \(a\) 中到 \(a_{x,y}\) 为止的二位前缀和,

形式化地说,\(pre(x,y)=\sum\limits_{i=1}^x\sum\limits_{j=1}^ya_{i,j}\)

二维树状数组的思想

这里就要用到 《树状数组导论》里的内容了

树状数组中我们用 c[i] 表示以 \(a_i\) 长度为 \(lowbit(a_i)\) 的区间和

二维树状数组中我们可以类似的用 c[i][j] 表示以 \(a_{i,j}\) 结尾的高为 \(lowbit(x)\) ,长为 \(lowbit(y)\) 的矩阵和

二维树状数组的实现

矩阵求和

十分简单

维护 ret 变量,代表当前统计了的部分的和

\(x\)\(y\) 都不等于 \(0\) 时,

定义 i=x,j=y

然后进行循环操作

ret+=c[i][j] ,统计当前矩阵的和

j-=lowbit(x) ,减去当前的矩阵的高以及长,去查找下一个矩阵

如此循环,将所有 \(c[i][?]\) 的矩阵统计完,也就是 j\(0\)

然后 x-=lowbit(x)

代码如下:

int get(int x,int y)
{
    int ret=0,ry=y;
    while(x)
    {
        y=ry;
        while(y)
        {
            ret+=c[x][y];
            y-=lowbit(y);
        }
        x-=lowbit(x);
    }
    return ret;
}

单点修改

与上面的矩阵查询同样的想法

只不过将 ret+=c[x][y] 替换为 c[x][y]+=k

再将 -=lowbit 替换为 +=lowbit

代码如下:

void add(int x,int y,int k)
{
    int ry=y;
    while(x<=n)
    {
        y=ry;
        while(y<=m)
        {
            c[x][y]+=k;
            y+=lowbit(y);
        }
        x+=lowbit(x);
    }
}

二维树状数组的扩展

差分(矩阵修改+单点查询)

类似于树状数组差分的想法,

由于二维前缀和 \(pre(i,j)=pre(i-1,j)+pre(i,j-1)-pre(i-1,j-1)+a_{i,j}\)

由于差分是前缀和的逆运算,定义二维差分 \(dif(i,j)=a_{i,j}-a_{i-1,j}-a_{i,j-1}+a_{i-1,j-1}\)

然后用二维树状数组维护数组 \(a\) 的二维差分就好了

要查询单点,就直接将 \(get(x,y)\) 就好了

但如果要矩阵修改,就需要拆分成 \(4\) 个操作:

void add(int x1,int y1,int x2,int y2,int k)
{
    stdadd(x1,y1,k);
    stdadd(x1,y2+1,-k);
    stdadd(x2+1,y1,-k);
    stdadd(x2+1,y2+1,k);
}

完整代码如下:

int c[MAXN][MAXN];
inline int lowbit(int x){return x&-x;}

int get(int x,int y)
{
    int ret=0,ry=y;
    while(x)
    {
        y=ry;
        while(y)
        {
            ret+=c[x][y];
            y-=lowbit(y);
        }
        x-=lowbit(x);
    }
    return ret;
}

void stdadd(int x,int y,int k)
{
    int ry=y;
    while(x<=n)
    {
        y=ry;
        while(y<=m)
        {
            c[x][y]+=k;
            y+=lowbit(y);
        }
        x+=lowbit(x);
    }
}

void add(int x1,int y1,int x2,int y2,int k)
{
    stdadd(x1,y1,k);
    stdadd(x1,y2+1,-k);
    stdadd(x2+1,y1,-k);
    stdadd(x2+1,y2+1,k);
}

矩阵修改+矩阵查询

接下来,请欣赏美妙的公式推导:

首先根据前缀和以及差分的定义得:

\[\because pre(x,y)=\sum\limits_{i=1}^x\sum\limits_{j=1}^ya_{i,j}\\ \because a_{i,j}=\sum\limits_{i=1}^x\sum\limits_{j=1}^ydif(i,j)\\ \therefore pre(x,y)=\sum\limits_{i=1}^x\sum\limits_{j=1}^y\sum\limits_{k=1}^i\sum\limits_{r=1}^jdif(k,r) \]

然后,仿照一维树状数组的推导:

\[pre(x,y)=\sum\limits_{i=1}^x\sum\limits_{j=1}^y\sum\limits_{k=1}^i\sum\limits_{r=1}^jdif(k,r)\\ =dif(1,1) \times (x \times y) + dif(1,2) \times (x \times (y-1)) + \dots + dif(x,y) \times ((x-x+1) \times (y-y+1))\\ =\sum\limits_{i=1}^x\sum\limits_{j=1}^ydif(i,j) \times (x-i+1) \times (y-j+1) \]

将公式展开得:

\[pre(x,y)=\sum\limits_{i=1}^x\sum\limits_{j=1}^ydif(i,j) \times (x-i+1) \times (y-j+1)\\ =\sum\limits_{i=1}^x\sum\limits_{j=1}^ydif(i,j) \times (xy+x+y+1-xj-yi-i-j+ij)\\ =\sum\limits_{i=1}^x\sum\limits_{j=1}^ydif(i,j) \times (xy+x+y+1)- dif(i,j) \times j \times (x+1) - dif(i,j) \times i \times (y+1) +dif(i,j) \times ij \]

于是需要开四个树状数组,分别维护 \(dif(i,j),dij(i,j) \times j,dif(i,j) \times i,dif(i,j) \times ij\)

int c[4][MAXN][MAXN];
inline int lowbit(int x){return x&-x;}

void stdadd(int x,int y,int k)
{
    for(int i=x; i<=n; i+=lowbit(i))
        for(int j=y; j<=m; j+=lowbit(j))
        {
            c[0][i][j]+=k;
            c[1][i][j]+=k*x;
            c[2][i][j]+=k*y;
            c[3][i][j]+=k*x*y;
        }
}

void add(int x1,int y1,int x2,int y2,int k)
{
    stdadd(x1,y1,k);
    stdadd(x1,y2+1,-k);
    stdadd(x2+1,y1,-k);
    stdadd(x2+1,y2+1,k);
}

int stdget(int x,int y)
{
    int ret=0;
    for(int i=x; i; i-=lowbit(i))
        for(int j=y; j; j-=lowbit(j))
            ret+=(x+1)*(y+1)*c[0][i][j]-(y+1)*c[1][i][j]-(x+1)*c[2][i][j]+c[3][i][j];
    return ret;
}

int get(int x1,int y1,int x2,int y2)
{
    return stdget(x2, y2)-stdget(x2,y1-1)-stdget(x1-1,y2)+stdget(x1-1,y1-1);
}
posted @ 2022-01-15 12:28  yhang323  阅读(17)  评论(0编辑  收藏  举报