您的位置:首页 > 其它

hdoj 4747 线段树

2015-08-07 13:06 260 查看
hdoj 4747

题意:给一个序列,问所有区间[i,j](i<=j)中,区间中最小的非负整数加和是多少。

思路:

看样例:

1 0 2 0 1

做出以第一个元素1为起点的分别到其他各点的所有区间最小值:

0 2 3 3 3

发现是一个单调递增序列,实际上对于任何一个样例都是这样的结果,因为随着数字的增加,最小值只会增大。

然后我们删掉第一个数字,做第二个元素0的区间最小值:

x 1 1 1 3

因为我们删掉了0之前的1后,在下一个1出现之前,这段区间中就可能拥有一个更小的值1。

而相比于1的区间最小值,0的区间最小值只修改了两个1中间的这段序列,并且是将所有大于1的数字修改为1。

而又因为这是个单调序列,所以修改的这段范围最多有1个(如果没有比删掉的这个值小的数就不用修改)

所以题目就转化成了线段树区间修改区间查询。

先预处理相同数字的位置邻接表,再找出序列中最左端的大于等于删除掉的数Vi的位置L,然后找到下个Vi出现的位置R,如果没有就是序列的最后位置n,最后将这个区间的值修改为Vi。

#include <cstdio>
#include <cstring>
#include <map>
#include <algorithm>
using namespace std;

const int M = 200020;
struct Tree{
int l, r, val;
int flag;
long long sum;
}tree[M * 4];
int seq[M], mex[M], Next[M];
bool vis[M];
map<int, int>mp;
int cnt, n;
void buildtree(int rt, int l, int r) {
tree[rt].l = l, tree[rt].r = r, tree[rt].flag = false;
if(l == r) {
tree[rt].val = mex[cnt++];
tree[rt].sum = tree[rt].val;
//    printf("%d %d %d %I64d\n", rt, tree[rt].l, tree[rt].r, tree[rt].sum);
return ;
}
int mid = (l + r) / 2;
buildtree(rt * 2, l, mid);
buildtree(rt * 2 + 1, mid + 1, r);
tree[rt].val = max(tree[rt * 2].val, tree[rt * 2 + 1].val);
tree[rt].sum = tree[rt * 2].sum + tree[rt * 2 + 1].sum;
//   printf("%d %d %d %I64d\n", rt, tree[rt].l, tree[rt].r, tree[rt].sum);
}
void pushdown(int rt){
if(tree[rt].flag) {
tree[rt * 2].val = tree[rt * 2 + 1].val = tree[rt].val;
tree[rt * 2].sum = (long long) (tree[rt * 2].r - tree[rt * 2].l + 1) * tree[rt * 2].val;
tree[rt * 2 + 1].sum = (long long) (tree[rt * 2 + 1].r - tree[rt * 2 + 1].l + 1) * tree[rt * 2 + 1].val;
tree[rt * 2].flag = tree[rt * 2 + 1].flag = true;
tree[rt].flag = false;
}
}
void pullup(int rt) {
tree[rt].sum = tree[rt * 2].sum + tree[rt * 2 + 1].sum;
tree[rt].val = max(tree[rt * 2].val, tree[rt * 2 + 1].val);
}
long long query(int rt, int l, int r) {
if(tree[rt].l == l && tree[rt].r == r) {
return tree[rt].sum;
}
pushdown(rt);
int mid = (tree[rt].l + tree[rt].r) / 2;
if(l > mid) return query(rt * 2 + 1, l, r);
else if(r <= mid) return query(rt * 2, l, r);
else return query(rt * 2, l, mid) + query(rt * 2 + 1, mid + 1, r);
}
void update(int rt, int l, int r, int a) {
if(tree[rt].l == l && r == tree[rt].r) {
tree[rt].val = a;
tree[rt].sum = (long long)(tree[rt].r - tree[rt].l + 1) * a;
tree[rt].flag = true;
return ;
}
pushdown(rt);
int mid = (tree[rt].l + tree[rt].r) / 2;
if(l > mid) update(rt * 2 + 1, l, r, a);
else if(r <= mid) update(rt * 2, l, r, a);
else {
update(rt * 2, l, mid, a);
update(rt * 2 + 1, mid + 1, r, a);
}
pullup(rt);
}
int getLeft(int rt, int val) {
if(tree[rt].l == tree[rt].r) {
if(tree[rt].val >= val) return tree[rt].l;
else return n + 1;
}
pushdown(rt);
if(tree[rt * 2].val >= val) return getLeft(rt * 2, val);
else return getLeft(rt * 2 + 1, val);
}
int main() {
//  freopen("in.txt", "r", stdin);
while(~scanf("%d", &n) && n) {
mp.clear();
for(int i = 1; i <= n; i++) scanf("%d", &seq[i]);
for(int i = 1; i <= n; i++) Next[i] = n;
memset(vis, 0, sizeof vis);
int val = 0;
for(int i = 1; i <= n; i++) {
if(seq[i] < M) vis[seq[i]] = true;
while(vis[val]) val++;
mex[i] = val;
}
for(int i = 1; i <= n; i++) {
if(mp.find(seq[i]) != mp.end()) Next[mp[seq[i]]] = i - 1;
mp[seq[i]] = i;
}
long long ans = 0;
cnt = 1;
buildtree(1, 1, n);
for(int i = 1; i <= n; i++) {
ans += query(1, i, n);
int left = getLeft(1, seq[i]);
//    printf("[%d, %d] = %d\n", left, Next[i], seq[i]);
if(left <= Next[i]) update(1, left, Next[i], seq[i]);
//      printf("%I64d\n", ans);
}
printf("%I64d\n", ans);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: