您的位置:首页 > 其它

线段树小结

2016-04-19 16:24 423 查看
因为要给大一讲课,所以又把线段树的模板题拿出来做了一下。

中所周知,线段树是用来处理区间的利器,这是因为它把区间按二分的思想分成一个个小段,使得我们用lgn的复杂度就能查询一个区间(例如最大值,最小值,和等等)

它的结构如下图所述:



我们可以看到,对于一个节点root,它的左子树下标为root*2,右子树下标为root*2+1。

实际上,由于线段树的父节点区间是平均分割到左右子树,因此线段树是完全二叉树,对于包含n个叶子节点的完全二叉树,它一定有n-1个非叶节点,总共2n-1个节点,因此存储线段是需要的空间复杂度是O(n)。接下来我们来看看如何用代码来实现它,以求区间和为例

1.创建线段树

void build(int root, int l, int r)
{
if(l == r) //叶子节点
{
scanf("%d", &segTree[root].sum);
return;
}
else
{
int mid = (l + r) >> 1;
build(root << 1, l, mid); //递归构造左子树
build((root << 1) + 1, mid + 1, r); //递归构造右子树
segTree[root].sum = segTree[root << 1].sum + segTree[(root << 1) + 1].sum; //根据左右子树跟节点的值来更新当前跟节点的值
}
}


2.查询

int query(int root, int L, int R, int l, int r)
{
if(L <= l && R >= r) //当前节点在区间内
{
return segTree[root].sum;
}
int mid = (l + r) >> 1;
int ans = 0;
if(L <= mid)
ans += query(root << 1, L, R, l, mid); //统计左子树
if(R > mid)
ans += query((root << 1) + 1, L, R, mid + 1, r); //统计右子树
return ans;
}


3.单点更新

void addNode(int root, int i, int add, int l, int r)
{
if(l == r)
{
segTree[root].sum += add; //找到叶子更新之
return;
}
int mid = (l + r) >> 1;
if(i <= mid)
addNode(root << 1, i, add, l, mid); //更新左子树
else
addNode((root << 1) + 1, i, add, mid + 1, r); //更新右子树
segTree[root].sum = segTree[root << 1].sum + segTree[(root << 1) + 1].sum; //根据左右子树回溯跟节点
}


hdu1166 敌兵布阵(典型的区间求和,单点更新)

#include <cstdio>
#include <cstring>
using namespace std;

struct SegTreeNode
{
int sum;
}segTree[50000 << 2];

char s[100];

void build(int root, int l, int r)
{
if(l == r)
{
scanf("%d", &segTree[root].sum);
return;
}
else
{
int mid = (l + r) >> 1;
build(root << 1, l, mid);
build((root << 1) + 1, mid + 1, r);
segTree[root].sum = segTree[root << 1].sum + segTree[(root << 1) + 1].sum;
}
}

void addNode(int root, int i, int add, int l, int r)
{
if(l == r)
{
segTree[root].sum += add;
return;
}
int mid = (l + r) >> 1;
if(i <= mid)
addNode(root << 1, i, add, l, mid);
else
addNode((root << 1) + 1, i, add, mid + 1, r);
segTree[root].sum = segTree[root << 1].sum + segTree[(root << 1) + 1].sum;
}

int query(int root, int L, int R, int l, int r)
{
if(L <= l && R >= r)
{
return segTree[root].sum;
}
int mid = (l + r) >> 1;
int ans = 0;
if(L <= mid)
ans += query(root << 1, L, R, l, mid);
if(R > mid)
ans += query((root << 1) + 1, L, R, mid + 1, r);
return ans;
}

int main()
{
int T;
scanf("%d", &T);
for(int nCase = 1; nCase <= T; nCase++)
{
printf("Case %d:\n", nCase);
int n;
scanf("%d", &n);
build(1, 1, n);
while(1)
{
scanf("%s", s);
if(s[0] == 'Q')
{
int a,b;
scanf("%d %d", &a, &b);
printf("%d\n", query(1, a, b, 1, n));
}
else if(s[0] == 'A')
{
int a,b;
scanf("%d %d", &a, &b);
addNode(1, a, b, 1, n);
}
else if(s[0] == 'S')
{
int a,b;
scanf("%d %d", &a, &b);
addNode(1, a, -b, 1, n);
}
else if(s[0] == 'E')
break;
}
}
return 0;
}


hdu1394(求区间最大值)

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;

struct SegTreeNode
{
int val;
}segTree[200000 << 2];

void build(int root, int l, int r)
{
if(l == r)
{
scanf("%d", &segTree[root].val);
return;
}
int mid = (l + r) >> 1;
build(root << 1, l, mid);
build(root << 1|1, mid + 1, r);
segTree[root].val = max(segTree[root << 1].val, segTree[root << 1|1].val);
}

void update(int root, int i ,int v, int l, int r)
{
//printf("%d %d\n", l, r);
if(l == r)
{
segTree[root].val = v;
//printf("%d %d\n", root, segTree[root].val);
return;
}
int mid = (l + r) >> 1;
if(mid >= i)
update(root << 1, i, v, l, mid);
else
update(root << 1|1, i, v, mid + 1, r);
segTree[root].val = max(segTree[root << 1].val, segTree[root << 1|1].val);
}

int query(int root, int L, int R, int l, int r)
{
//printf("%d %d\n", l, r);
if(L <= l && R >= r)
{
return segTree[root].val;
}
int mid = (l + r) >> 1;
int Max = 0;
if(mid >= L)
Max = query(root << 1, L, R, l, mid);
//printf("%d ", Max);
if(mid < R)
Max = max(Max, query(root << 1|1, L, R, mid + 1, r));
//printf("%d\n", Max);
return Max;
//return max(query(root << 1, L, R, l, mid), query((root << 1) + 1, L, R, mid + 1, r));
}

int main()
{
int n,m;
while(~scanf("%d %d", &n, &m))
{
//memset(segTree, 0, sizeof(segTree));
build(1, 1, n);
while(m--)
{
getchar();
char c;
int a,b;
scanf("%c %d %d", &c, &a, &b);
if(c == 'Q')
printf("%d\n", query(1, a, b, 1, n));
else if(c == 'U')
update(1, a, b, 1, n);
}
}
return 0;
}


实际问题中还有一些问题需要我们成段更新区间的操作,如果对这个区间一个个进行单点更新,那么很显然时间效率会非常低,我们仔细思考一下,我们发现并没有必要更新所有点,只需要更新要用到的点,并标记,区别已更新和没更新的点,在下次更新时,通过该标记来判断当前点是否已更新

hdu 1698 (典型的成段更新)

#include <cstdio>
#include <iostream>
#include <cstring>
using namespace std;

struct SegTreeNode
{
int sum;
int lazy;
}segTree[100000 << 2];

void build(int root, int l, int r)
{
segTree[root].lazy = 0;
if(l == r)
{
segTree[root].sum = 1;
return;
}
int mid = (l + r) >> 1;
build(root << 1, l, mid);
build(root << 1|1, mid + 1, r);
segTree[root].sum = segTree[root << 1].sum + segTree[root << 1|1].sum;
}

void pushdown(int root, int l, int r)
{
if(segTree[root].lazy)
{
segTree[root << 1].lazy = segTree[root].lazy; //标记左子树,说明更新到这一点
segTree[root << 1|1].lazy = segTree[root].lazy; //标记右子树,说明更新到这一点
int mid = (l + r) >> 1;
segTree[root << 1].sum = (mid - l + 1) * segTree[root].lazy; //因为求得是区间的和,所以要加上区间的数的个数乘上增加的值
segTree[root << 1|1].sum = (r - mid) * segTree[root].lazy;
segTree[root].lazy = 0; //标记置0,说明该点的左右孩子已更新
}
}

void update(int root, int x, int L, int R, int l, int r)
{
if(L <= l && R >= r)  //当前区间在所求区间内
{
segTree[root].sum = (r - l + 1) * x;
segTree[root].lazy = x;
//printf("(%d %d) %d\n",l, r, segTree[root].sum);
return;
}
pushdown(root, l, r); //向下更新,同时传递标记
int mid = (l + r) >> 1;
if(L <= mid)
update(root << 1, x, L, R, l, mid); //更新左子树
if(R > mid)
update(root << 1|1, x, L, R, mid + 1, r); //更新右子树
segTree[root].sum = segTree[root << 1].sum + segTree[root << 1|1].sum; //根据左右根节点回溯当前跟节点
}

int main()
{
int T;
scanf("%d", &T);
for(int i = 1; i <= T; i++)
{
int n;
scanf("%d", &n);
memset(segTree, 0, sizeof(segTree));
build(1,1,n);
//printf("%d\n", segTree[1].sum);
int Q;
scanf("%d", &Q);
while(Q--)
{
int x,y,z;
scanf("%d %d %d", &x, &y, &z);
update(1,z,x,y,1,n);
}
printf("Case %d: The total value of the hook is %d.\n", i, segTree[1].sum);
}
return 0;
}


poj3468 (区间求和,成段更新)

#include <cstdio>
#include <cstring>
#define LL long long
using namespace std;

struct SegTreeNode
{
LL sum;
LL add;
}segTree[100000 << 2];

void build(int root, int l, int r)
{
segTree[root].add = 0;
if(l == r)
{
scanf("%I64d", &segTree[root].sum);
return;
}
int mid = (l + r) >> 1;
build(root << 1, l, mid);
build(root << 1|1, mid + 1, r);
segTree[root].sum = segTree[root << 1].sum + segTree[root << 1|1].sum;
}

void pushdown(int root, int l, int r)
{
if(segTree[root].add)
{
int mid = (l + r) >> 1;
segTree[root << 1].sum += (mid - l + 1) * segTree[root].add;
segTree[root << 1|1].sum += (r - mid) * segTree[root].add;
segTree[root << 1].add += segTree[root].add;
segTree[root << 1|1].add += segTree[root].add;
segTree[root].add = 0;
}
}

void update(int root, int c, int L, int R, int l, int r)
{
if(L <= l && R >= r)
{
segTree[root].sum += (r - l + 1) * c;
segTree[root].add += c;
return;
}
pushdown(root, l, r);
int mid = (l + r) >> 1;
if(mid >= L)
update(root << 1, c, L, R, l, mid);
if(mid < R)
update(root << 1|1, c, L, R, mid + 1, r);
segTree[root].sum = segTree[root << 1].sum + segTree[root << 1|1].sum;
}

LL query(int root, int L, int R, int l, int r)
{
if(L <= l && R >= r)
{
return segTree[root].sum;
}
pushdown(root, l, r);
int mid = (l + r) >> 1;
LL ans = 0;
if(mid >= L)
ans += query(root << 1, L, R, l, mid);
if(mid < R)
ans += query(root << 1|1, L, R, mid + 1, r);
return(ans);
}

int main()
{
int n,q;
while(~scanf("%d %d", &n, &q))
{
build(1,1,n);
while(q--)
{
getchar();
char c;
scanf("%c", &c);
if(c == 'Q')
{
int a,b;
scanf("%d %d", &a, &b);
printf("%I64d\n", query(1,a,b,1,n));
}
else if(c == 'C')
{
int a,b,c;
scanf("%d %d %d", &a, &b, &c);
update(1,c,a,b,1,n);
}
}
}
return 0;
}


线段树还能用来求逆序对,怎么求呢?

很简单……
设数列为a,将数列离散化,在从前往后枚举,统计答案……
离散化:例如2 5 8 3 10 等价于 1 3 4 2 5,可以通过排序加小小处理解决。
枚举到第i个数,我们需要求出从1到i-1中有多少个比a[i]大的数,更新答案。
具体怎么做呢?
每次枚举完一个数之后,将这个数插入到线段树里,插入到线段树的神马地方呢?当然是这个数多大就插入到多大的地方。

hdu1394(线段树求逆序对,单点更新)

这题数列是环状结构,我们很容易得出,当前逆序对的个数是前一个逆序对的个数减去比这个数小的数的个数,再加上比这个数大的数的个数,这里比较坑的一点。。。数是从0~n-1的,之前没注意。。。

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;

struct SegTreeNode
{
int sum;
}segTree[5000 << 2];
int a[5000 + 10];

void update(int root, int x, int l, int r)
{
if(l == r)
{
segTree[root].sum++;
//printf("(%d %d) %d\n", l, r, segTree[root].sum);
return;
}
int mid = (l + r) >> 1;
if(x <= mid)
update(root << 1, x, l, mid);
if(x > mid)
update(root << 1|1, x, mid + 1, r);
segTree[root].sum = segTree[root << 1].sum + segTree[root << 1|1].sum;
}

int query(int root, int L, int R, int l, int r)
{
if(L <= l && R >= r)
{
return segTree[root].sum;
}
int mid = (l + r) >> 1;
int ans = 0;
if(L <= mid)
ans += query(root << 1, L, R, l, mid);
if(R > mid)
ans += query(root << 1|1, L, R, mid + 1, r);
return ans;
}

int main()
{
int n;
while(~scanf("%d", &n))
{
memset(segTree, 0, sizeof(segTree));
int ans = 0;
for(int i = 1; i <= n; i++)
{
scanf("%d", &a[i]);
//printf("%d\n", query(1, a[i] + 1, n - 1, 0, n - 1));
ans += query(1, a[i] + 1, n - 1, 0, n - 1);
update(1, a[i], 0, n - 1);
}
//printf("%d\n", ans);
int Min = ans;
for(int i = 1; i <= n; i++)
{
ans = ans - a[i] + (n - a[i] - 1);
//printf("%d %d\n",a[i], ans);
Min = min(Min, ans);
}
printf("%d\n", Min);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: