本文已参与「新人创作礼」活动, 一起开启掘金创作之路。
功能
维护一个数据结构,对一列n个数,实现下面两种操作:
- 将某一个数加上 x
- 求出下标从x到y所有数的和
思路
定义 表示将 i 转化为二进制后最低位的 1 所对应的值。
比如 12 的二进制表示为 ,则 。
开一个长度为n的数组 s[i],表示从下标i开始前 项和。
上图中 a 数组表示输入的原数组,s 为树状数组。则:
- s[1]=a[1]
- s[2]=a[1]+a[2]
- s[3]=a[3]
- s[4]=a[1]+a[2]+a[3]+a[4]
- s[5]=a[5]
- s[6]=a[5]+a[6]
- s[7]=a[7]
- s[8]=a[1]+a[2]+a[3]+a[4]+a[5]+a[6]+a[7]+a[8]
- s[9]=a[9]
- ……
那么树状数组应该怎样建立和维护呢?
建树
我们只需自下而上的遍历上图即可完成建树。
void build()
{
for (int i=1;i<=n;++i) s[i]=a[i];
for (int i=1;i<=n;i*=2)
for (int j=i*2;j<=n;j+=i*2) s[j]+=s[j-i];
}
修改
考虑第x个数的值改变会对树状数组中那些值产生影响。
由 s[i] 表示从下标 i 开始前 项和,即第 i-lowbit(i)+1 项到第 i 项之和。
观察发现修改 x 的值只会上图中一条链的值,且该条链中 i 的下一个元素为 i+lowbit(i)。
修改第 x 个元素的值的操作如下:
- 将修改树状数组中第x个元素的值。
- 若 x+lowbit(x) 不超过n,则 x+=lowbit(x),重复step1。
举个栗子: 比如我们想要让 a[5] 加 1,则需要依次进行以下操作:
代码实现如下:
void change(int x,int y)
{
for (int i=x;i<=n;i+=lowbit(i))
s[i]+=y;
}
查询
易知想要查询下标从x到y的所有元素的和,只需要查询前y项和与前x-1项和作差即可。
统计前 x 项的和的方法如下:
- 将第 x-lowbit(x)+1~x 个元素的和加入答案,即 ans+=s[x]。
- 若 x-lowbit(x) 不为 0,则 x-=lowbit(x),重复 step1。
- ans 即为所求。
举个栗子: 我们想要统计前 82 项和,现将 82 转化为二进制得 。由上述树状数组定义得:
则前82项和为 s[82]+s[80]+s[64]。 代码实现如下。
int find(int x)
{
int sum=0;
for (int i=x;i;i-=lowbit(i))
sum+=s[i];
return sum;
}
lowbit
定义 owbit(i) 表示将 i 转化为二进制后最低位的 1 所对应的值。上文中基于 lowbit 的定义我们实现了单次操作时间复杂度为 O(log(n)) 维护上述数据结构。下面给出 lowbit(x) 的实现。
int lowbit(int x)
{
return x&(-x);
}
不理解&运算符请搜索C语言位运算。
为什么x&(-x)就是将x转化为二进制后最低位的1所对应的值呢?
这就涉及到了数在计算机中的存储方式。
一个数在计算机中的二进制表示形式叫做这个数的机器数。机器数是带符号的,在计算机用一个数的最高位存放符号,正数符号位为0,负数符号位为1。一个数在计算机中多以二进制补码形式存储。
在介绍补码之前,我们需要先定义原码和反码。
-
原码 现将数的绝对值表示为二进制形式,将最高位按符号置为0/1,就能得到数的原码表示形式。
-
反码 非负数的反码是它本身。 负数的反码是原码除符号位逐位取反。
-
补码 非负数的补码是它本身。 负数的补码是反码+1。
举个栗子:
由于 x 为数组下标,下文将 x 视为非负的整类型变量展开讨论。
当 x 为 0 时,0&0=0。
当 x 不为 0 时,x 为正数,则其补码就是自身的二进制形式。
-x 的补码是将反码 +1 的结果,
当 后,
将会对 包括最低位 0 开始右边所有字节逐位取反,
最低位 0 左边保持不变。
易知 与 的每一位都不同,
所以 包括最低位 1 开始向右所有字节均与 相同,
最低位 1 左边所有字节均与 不同,
则 x&(-x) 会取出将 x 转化为二进制后最低位的 1 所对应的值。
代码
完整代码如下:
#include <stdio.h>
#define lowbit(x) x&(-x)
int n,m,s[500001],t,x,y;
void build()
{
for (int i=1;i<=n;i*=2)
for (int j=i*2;j<=n;j+=i*2) s[j]+=s[j-i];
}
int find(int x)
{
int sum=0;
for (int i=x;i;i-=lowbit(i))
sum+=s[i];
return sum;
}
void change(int x,int y)
{
for (int i=x;i<=n;i+=lowbit(i))
s[i]+=y;
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;++i) scanf("%d",&s[i]);
build();
while (m--)
{
scanf("%d%d%d",&t,&x,&y);
if (t==1) change(x,y);
else printf("%d\n",find(y)-find(x-1));
}
return 0;
}