您的位置:首页 > 产品设计 > UI/UE

[leetcode] 307. Range Sum Query - Mutable 解题报告

2016-05-03 14:23 477 查看
题目链接: https://leetcode.com/problems/range-sum-query-mutable/
Given an integer array nums, find the sum of the elements between indices i and j (i ≤ j), inclusive.
The update(i,
val) function modifies nums by
updating the element at index i to val.

Example:

Given nums = [1, 3, 5]

sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8


Note:

The array is only modifiable by the update function.
You may assume the number of calls to update and sumRange function is distributed evenly.

思路: 想了好久只想出一个线性时间的方法, 后来才发现原来是另外一个数据结构树状数组的应用, 大学时候看过, 忘掉了.

这种数据结构真是太适合区间求和了, 初始化的时候时间复杂度是n log n, 以后每次查询和更新log n, 相比于每次线性时间的更新和常数时间的求和要快了很多.

这种数据结构的原理是利用数字的二进制性质. 将索引数字转换成二进制, 按照其二进制数从右往左第一个1所在位置代表的值将其分层. 我们知道从右往左每增加一位, 数字增长一倍, 故可以按照数字二进制1的位置分层如下:

1 2 4 8 ....

举个栗子, 5 = (101), 故其lowbit(5) = 1. 8 = (1000), lowbit(8) = 8, 所以可以看到数字按照lowbit分布如下:

lowbit=113579...
lowbit=226101418...
lowbit=4412202836...
lowbit=8824405672...
求一个数的lowbit可以利用这个公式lowbit(val) = val & -val, 其原理是数字在计算机中是以补码形式存储的, 一个正数的补码是其本身, 一个负数的补码是按照其正数的按位取反加1.

ok, 有了上面的基础我们要看看如何利用其性质来求和.

8val
4val
2valval
1valvalvalvalval
0123456789
每一行代表不同的lowbit分层, 每一列代表数组索引, 上表代表每一个索引处在哪一层. 我们可以看到其实这是像树一样的结构的, 具有比较高的lowbit的索引位置将存储在其左边的比其低的数值的和, 即可以说一个结点存储其本身和其左子树的值的和.

举个栗子, 索引4的lowbit是4, 其左子树是索引1, 2, 3, 因此索引4的位置保存了数组1, 2, 3, 4位置元素的和. 而索引5的位置只保存了其自身的值, 因为他没有左子树.

然后我们再来看如何初始化一个二叉索引树:

因为一个数包含其左子树的值, 所以一个数的值更新的时候就要更新其父结点的值, 而寻找父结点的方式就是利用lowbit, 也就是一个位置的父结点就是其本身位置+其lowbit, 这样就可以在log n的时间更新其父结点的值.代码如下:

void add(int index, int d)
{
while(index <= tem.size())
{
bit[index] += d;
index += (index&-index);
}
}


这样初始化完成一个二叉索引树之后我们再来看如下查询前n个数的值:

我们已经知道一个结点包含其左子树和其本身的值, 也就是说结点是不包含其右子树的值的, 因此我们看成要查询的位置是在一个右子树上, 我们只要一直往上找其父结点, 再将其父结点看成是一个右子节点, 找其父结点的父结点, 直到找到最高的父结点, 并将沿途的值加起来就是我们要找的前n个数的值. 可以看到这个过程是一直往左找的过程, 因为是将其看做是在右子树. 而更新是往右寻找的过程. 查询代码如下:

int sum(int index)
{
int ans = 0;
while(index > 0)
{
ans += bit[index];
index -= (index&-index);
}
return ans;
}


因此当我们要查询一个区间的时候就很容易用前n个元素的值- 前i个元素的值就是[i, n]区间的值.

代码如下:

class NumArray {
public:
NumArray(vector<int> &nums):tem(nums) {
bit.resize(nums.size()+1, 0);
for(int i =0; i< nums.size(); i++)
add(i+1, nums[i]);
}

void add(int index, int d) { while(index <= tem.size()) { bit[index] += d; index += (index&-index); } }

int sum(int index)
{
int ans = 0;
while(index > 0)
{
ans += bit[index];
index -= (index&-index);
}
return ans;
}

void update(int i, int val) {
add(i+1, val-tem[i]);
tem[i] = val;
}

int sumRange(int i, int j) {
return sum(j+1) - sum(i);
}
private:
vector<int> bit;
vector<int>& tem;
};

// Your NumArray object will be instantiated and called as such:
// NumArray numArray(nums);
// numArray.sumRange(0, 1);
// numArray.update(1, 10);
// numArray.sumRange(1, 2);
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: