二维树状数组
二维树状数组
学一学吧,没准考试就用到了呢。。。
学了也不亏,毕竟限制条件太多,有也肯定能看出来。
二维矩阵上应用,且维护的结果必须具有加和性。
引入
在一维树状数组中, 本质上 \(c[x]\) 表示的是右端点为 \(x\) ,区间长度是 \(lowbit(x)\) 的区间和。
我们把一维树状数组扩展一下, 用 \(c[x][y]\) 表示右下角是 \((x,y)\) ,长度为 \(lowbit(x)\) 且宽度为 \(lowbit(y)\) 这个矩阵的和,这就是二维树状数组。
下面分别讲一维树状数组的三个操作在二维树状数组中的实现。
单点修改+区间求和
比较简单,只需要在一维树状数组的基础上再枚举一维即可。
计算结果时,一维前缀和变成二维前缀和。
这里放图来解释一下二位前缀和的计算。
#include<iostream>
using namespace std;
#define int long long
const int N=5000;
int c[N][N];
int n,m,pd,aa,bb,cc,dd;
int lowbit(int x)
{
return x&(-x);
}
void add(int x,int y,int z)
{
for(int i=x;i<=n;i+=lowbit(i))
for(int j=y;j<=m;j+=lowbit(j))
c[i][j]+=z;
}
int query(int x,int y)
{
int res=0;
for(int i=x;i!=0;i-=lowbit(i))
for(int j=y;j!=0;j-=lowbit(j))
res+=c[i][j];
return res;
}
main()
{
scanf("%lld %lld",&n,&m);
while(scanf("%lld",&pd)!=EOF)
{
if(pd==1)
{
scanf("%lld %lld %lld",&aa,&bb,&cc);
add(aa,bb,cc);
}
if(pd==2)
{
scanf("%lld %lld %lld %lld",&aa,&bb,&cc,&dd);
printf("%lld\n",query(cc,dd)-query(aa-1,dd)-query(cc,bb-1)+query(aa-1,bb-1));
}
}
return 0;
}
区间修改+单点查询
类比于一维树状数组,我们将树状数组维护成差分数组。
这里放图来解释一下二维差分。
md 没找到图不惜放了。
如果我们需要对于左上角是 \((x1,y1)\) ,右上角是 \((x2,y2)\) 的矩阵进行修改,\(cf[x1][y2]+x,cf[x1][y2+1]-x,cf[x2+1][y1]-x,cf[x2+1][y2+1]+x\)
#include<iostream>
#define int long long
using namespace std;
const int N=5000;
int c[N][N];
int n,m;
inline int read()
{
int s=0,w=1;char ch=getchar();
while(ch<'0'||ch>'9')
{ if(ch=='-') w=-1; ch=getchar();}
while(ch>='0'&&ch<='9')
{ s=s*10+ch-'0'; ch=getchar();}
return s*w;
}
int lowbit(int x)
{
return x&(-x);
}
void add(int x,int y,int v)
{
for(int i=x;i<=n;i+=lowbit(i))
for(int j=y;j<=m;j+=lowbit(j))
c[i][j]+=v;
}
int query(int x,int y)
{
int res=0;
for(int i=x;i;i-=lowbit(i))
for(int j=y;j;j-=lowbit(j))
res+=c[i][j];
return res;
}
signed main()
{
n=read(),m=read();
int op;
while(scanf("%lld",&op)!=EOF)
{
if(op==1)
{
int a=read(),b=read(),c=read(),d=read(),k=read();
add(a,b,k),add(a,d+1,-k),add(c+1,b,-k),add(c+1,d+1,k);
}
if(op==2)
{
int x=read(),y=read();
printf("%lld\n",query(x,y));
}
}
return 0;
}
区间修改+区间查询
又到了令人 \(excited\) 的时候了,类比与一维树状数组,我们来推公式。
首先,我们的结果 \(sum[n][m]=\sum^n_{i=1}\sum^m_{j=1}a[i][j]\)
又因为 \(a[i][j]=\sum^i_{x=1}\sum^j_{y=1}cf[x][y]\) ,所以 \(sum[n][m]=\sum^n_{i=1}\sum^m_{j=1}\sum^i_{x=1}\sum^j_{y=1}cf[x][y]\)。
。。。。。。
内心极度的崩溃。
没事,我们慢慢化简。
类比于一维树状数组,我们可以发现, \(cf[1][1]\) 在每次循环的时候都会出现,出现的次数为 \(n \times m\)。同理,我们可以得出 \(cf[1][2]\) 的出现次数为 \(n \times (m-1)\) 次, \(cf[2][1]\) 的出现次数为 \((n-1) \times m\) 次……
所以,我们的式子就可以化简为 \(sum[n][m]=\sum^n_{i=1}\sum^m_{j=1}cf[i][j] \times (n+1-i) \times (m+1-j)\)
我们把式子展开,变为 \(sum[n][m]=\sum^n_{i=1}\sum^m_{j=1}cf[i][j] \times (n+1) \times (m+1) \ -\sum^n_{i=1}\sum^m_{j=1}cf[i][j] \times (n+1) \times j \ -\sum^n_{i=1}\sum^m_{j=1}cf[i][j] \times (m+1) \times i \ +\sum^n_{i=1}\sum^m_{j=1}cf[i][j] \times i \times j\)
我们把四个式子分开,得到 \((n+1)\times (m+1) \times \sum^n_{i=1}\sum^m_{j=1}cf[i][j]\),\(\sum^n_{i=1}\sum^m_{j=1}cf[i][j] \times (n+1) \times j\),\(\sum^n_{i=1}\sum^m_{j=1}cf[i][j] \times (m+1) \times i\),\(\sum^n_{i=1}\sum^m_{j=1}cf[i][j] \times i \times j\)。
\(n,m\) 已知,用四个树状数组分别维护 \(\sum^n_{i=1}\sum^m_{j=1}cf[i][j]\),\(\sum^n_{i=1}\sum^m_{j=1}cf[i][j] \times i\),\(\sum^n_{i=1}\sum^m_{j=1}cf[i][j] \times j\),\(\sum^n_{i=1}\sum^m_{j=1}cf[i][j] \times i \times j\)。
#include<iostream>
#define int long long
using namespace std;
const int N=3000;
int c1[N][N],c2[N][N],c3[N][N],c4[N][N];
int n,m;
inline int read()
{
int s=0,w=1;char ch=getchar();
while(ch<'0'||ch>'9')
{ if(ch=='-') w=-1; ch=getchar();}
while(ch>='0'&&ch<='9')
{ s=s*10+ch-'0'; ch=getchar();}
return s*w;
}
int lowbit(int x)
{
return x&(-x);
}
void add(int x,int y,int v)
{
for(int i=x;i<=n;i+=lowbit(i))
for(int j=y;j<=m;j+=lowbit(j))
{
c1[i][j]+=v;
c2[i][j]+=v*x;
c3[i][j]+=v*y;
c4[i][j]+=v*x*y;
}
}
int query(int x,int y)
{
int res=0;
for(int i=x;i;i-=lowbit(i))
for(int j=y;j;j-=lowbit(j))
res+=(x+1)*(y+1)*c1[i][j]-(y+1)*c2[i][j]-(x+1)*c3[i][j]+c4[i][j];
return res;
}
signed main()
{
n=read(),m=read();
int op;
while(scanf("%lld",&op)!=EOF)
{
if(op==1)
{
int a=read(),b=read(),c=read(),d=read(),x=read();
add(a,b,x),add(a,d+1,-x),add(c+1,b,-x),add(c+1,d+1,x);
}
if(op==2)
{
int a=read(),b=read(),c=read(),d=read();
printf("%lld\n",query(c,d)-query(c,b-1)-query(a-1,d)+query(a-1,b-1));
}
}
return 0;
}