二维树状数组导论
关于这篇文章
是一篇关于 树状数组导论 提到的二维树状数组的文章
如果不知道树状数组,出门左拐,树状数组导论
关于树状数组,你需要了解的
二维树状数组的定义
首先了解一下二维树状数组解决的是什么问题
二维树状数组和树状数组一样,都是用来维护一个数组内的前缀和
只不过树状数组维护的是一维数组,二维树状数组维护的是二维数组
也就是说二维树状数组支持单点修改和二位前缀和
二维前缀和
什么是二维前缀和
出门左转,OI-Wiki
简单的说,就是从 \(a_{1,1}\) 到 \(a_{x,y}\) 的矩阵所包含的所有元素的和
在这篇文章中,\(pre(x,y)\) 就是二维数组 \(a\) 中到 \(a_{x,y}\) 为止的二位前缀和,
形式化地说,\(pre(x,y)=\sum\limits_{i=1}^x\sum\limits_{j=1}^ya_{i,j}\)
二维树状数组的思想
这里就要用到 《树状数组导论》里的内容了
树状数组中我们用 c[i]
表示以 \(a_i\) 长度为 \(lowbit(a_i)\) 的区间和
二维树状数组中我们可以类似的用 c[i][j]
表示以 \(a_{i,j}\) 结尾的高为 \(lowbit(x)\) ,长为 \(lowbit(y)\) 的矩阵和
二维树状数组的实现
矩阵求和
十分简单
维护 ret
变量,代表当前统计了的部分的和
当 \(x\) 和 \(y\) 都不等于 \(0\) 时,
定义 i=x,j=y
然后进行循环操作
ret+=c[i][j]
,统计当前矩阵的和
j-=lowbit(x)
,减去当前的矩阵的高以及长,去查找下一个矩阵
如此循环,将所有 \(c[i][?]\) 的矩阵统计完,也就是 j
为 \(0\)
然后 x-=lowbit(x)
代码如下:
int get(int x,int y)
{
int ret=0,ry=y;
while(x)
{
y=ry;
while(y)
{
ret+=c[x][y];
y-=lowbit(y);
}
x-=lowbit(x);
}
return ret;
}
单点修改
与上面的矩阵查询同样的想法
只不过将 ret+=c[x][y]
替换为 c[x][y]+=k
,
再将 -=lowbit
替换为 +=lowbit
代码如下:
void add(int x,int y,int k)
{
int ry=y;
while(x<=n)
{
y=ry;
while(y<=m)
{
c[x][y]+=k;
y+=lowbit(y);
}
x+=lowbit(x);
}
}
二维树状数组的扩展
差分(矩阵修改+单点查询)
类似于树状数组差分的想法,
由于二维前缀和 \(pre(i,j)=pre(i-1,j)+pre(i,j-1)-pre(i-1,j-1)+a_{i,j}\)
由于差分是前缀和的逆运算,定义二维差分 \(dif(i,j)=a_{i,j}-a_{i-1,j}-a_{i,j-1}+a_{i-1,j-1}\)
然后用二维树状数组维护数组 \(a\) 的二维差分就好了
要查询单点,就直接将 \(get(x,y)\) 就好了
但如果要矩阵修改,就需要拆分成 \(4\) 个操作:
void add(int x1,int y1,int x2,int y2,int k)
{
stdadd(x1,y1,k);
stdadd(x1,y2+1,-k);
stdadd(x2+1,y1,-k);
stdadd(x2+1,y2+1,k);
}
完整代码如下:
int c[MAXN][MAXN];
inline int lowbit(int x){return x&-x;}
int get(int x,int y)
{
int ret=0,ry=y;
while(x)
{
y=ry;
while(y)
{
ret+=c[x][y];
y-=lowbit(y);
}
x-=lowbit(x);
}
return ret;
}
void stdadd(int x,int y,int k)
{
int ry=y;
while(x<=n)
{
y=ry;
while(y<=m)
{
c[x][y]+=k;
y+=lowbit(y);
}
x+=lowbit(x);
}
}
void add(int x1,int y1,int x2,int y2,int k)
{
stdadd(x1,y1,k);
stdadd(x1,y2+1,-k);
stdadd(x2+1,y1,-k);
stdadd(x2+1,y2+1,k);
}
矩阵修改+矩阵查询
接下来,请欣赏美妙的公式推导:
首先根据前缀和以及差分的定义得:
然后,仿照一维树状数组的推导:
将公式展开得:
于是需要开四个树状数组,分别维护 \(dif(i,j),dij(i,j) \times j,dif(i,j) \times i,dif(i,j) \times ij\)
int c[4][MAXN][MAXN];
inline int lowbit(int x){return x&-x;}
void stdadd(int x,int y,int k)
{
for(int i=x; i<=n; i+=lowbit(i))
for(int j=y; j<=m; j+=lowbit(j))
{
c[0][i][j]+=k;
c[1][i][j]+=k*x;
c[2][i][j]+=k*y;
c[3][i][j]+=k*x*y;
}
}
void add(int x1,int y1,int x2,int y2,int k)
{
stdadd(x1,y1,k);
stdadd(x1,y2+1,-k);
stdadd(x2+1,y1,-k);
stdadd(x2+1,y2+1,k);
}
int stdget(int x,int y)
{
int ret=0;
for(int i=x; i; i-=lowbit(i))
for(int j=y; j; j-=lowbit(j))
ret+=(x+1)*(y+1)*c[0][i][j]-(y+1)*c[1][i][j]-(x+1)*c[2][i][j]+c[3][i][j];
return ret;
}
int get(int x1,int y1,int x2,int y2)
{
return stdget(x2, y2)-stdget(x2,y1-1)-stdget(x1-1,y2)+stdget(x1-1,y1-1);
}