您的位置:首页 > 其它

CQUOJ 24914 线段树写个爽(线段树TAT)

2016-07-27 15:03 337 查看
http://acm.cqu.edu.cn:8888/oj/problem_show.php?pid=24914

一个有四种操作的线段树,区间加和,区间加等差数列,区间赋值,区间查询。

真是写死人并且那个坑调了好久才调出来。

区间加等差数列可以维护个首项和公差就行了,一样pushdown,然后区间赋值的时候,需要把前两个的标记全部清除,在pushdown的时候也是,左右子树的前两个操作的标记要全部清除,而且要先执行第三个操作,因为如果前两个操作在这个区间,再来第三个,就会把他们清除,这无所谓,但是如果是第三个操作先在这个区间,然后后两个操作也在这个区间,pushdown的时候如果先执行了前两个,再执行第三个的话,左右子树的标记就会被第三个操作给清除,所以第三个操作应该优先执行,反正他的标记只能被他自己清除。这里错了好久。

代码:

#include <map>
#include <set>
#include <stack>
#include <queue>
#include <cmath>
#include <string>
#include <vector>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <sstream>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#pragma comment(linker, "/STACK:102400000,102400000")

using namespace std;
#define   MAX           100005
#define   MAXN          6005
#define   maxnode       15
#define   sigma_size    30
#define   lson          l,m,rt<<1
#define   rson          m+1,r,rt<<1|1
#define   lrt           rt<<1
#define   rrt           rt<<1|1
#define   middle        int m=(r+l)>>1
#define   LL            long long
#define   ull           unsigned long long
#define   mem(x,v)      memset(x,v,sizeof(x))
#define   lowbit(x)     (x&-x)
#define   pii           pair<int,int>
#define   bits(a)       __builtin_popcount(a)
#define   mk            make_pair
#define   limit         10000

//const int    prime = 999983;
const int    INF   = 0x3f3f3f3f;
const LL     INFF  = 0x3f3f;
const double pi    = acos(-1.0);
//const double inf   = 1e18;
const double eps   = 1e-8;
const LL     mod    = 1e9+7;
const ull    mx    = 133333331;

/*****************************************************/
inline void RI(int &x) {
char c;
while((c=getchar())<'0' || c>'9');
x=c-'0';
while((c=getchar())>='0' && c<='9') x=(x<<3)+(x<<1)+c-'0';
}
/*****************************************************/

LL sum[MAX<<2];
LL add[MAX<<2];
LL a1[MAX<<2],a2[MAX<<2];
LL col[MAX<<2];
LL aa[MAX<<2],bb[MAX<<2],cc[MAX<<2];
void pushup(int rt){
sum[rt]=sum[lrt]+sum[rrt];
}

void build(int l,int r,int rt){
add[rt]=a1[rt]=a2[rt]=col[rt]=0;
aa[rt]=bb[rt]=cc[rt]=0;
if(l==r){
scanf("%lld",&sum[rt]);
return ;
}
middle;
build(lson);
build(rson);
pushup(rt);
}

void pushdown(int rt,int m){
if(cc[rt]){
cc[lrt]=1;
aa[lrt]=0;add[lrt]=0;
bb[lrt]=0;a1[lrt]=0,a2[lrt]=0;
cc[rrt]=1;
aa[rrt]=0;add[rrt]=0;
bb[rrt]=0;a1[rrt]=0,a2[rrt]=0;
col[lrt]=col[rt];
col[rrt]=col[rt];
sum[lrt]=col[rt]*(m-(m>>1));
sum[rrt]=col[rt]*(m>>1);
col[rt]=0;
cc[rt]=0;
}
if(aa[rt]){
aa[lrt]=1;
aa[rrt]=1;
aa[rt]=0;
add[lrt]+=add[rt];
add[rrt]+=add[rt];
sum[lrt]+=add[rt]*(m-(m>>1));
sum[rrt]+=add[rt]*(m>>1);
add[rt]=0;
}
if(bb[rt]){
bb[lrt]=1;
bb[rrt]=1;
bb[rt]=0;
a1[lrt]+=a1[rt];
a1[rrt]+=a1[rt]+a2[rt]*(m-(m>>1));
a2[lrt]+=a2[rt];
a2[rrt]+=a2[rt];
sum[lrt]+=(2*a1[rt]+a2[rt]*(m-(m>>1)-1))*(m-(m>>1))/2;
sum[rrt]+=(2*a1[rt]+a2[rt]*(2*m-(m>>1)-1))*(m>>1)/2;
a1[rt]=0;a2[rt]=0;
}
}

void update1(int l,int r,int rt,int L,int R,int d){
if(L<=l&&r<=R){
sum[rt]+=(r-l+1)*d;
add[rt]+=d;
aa[rt]=1;
return;
}
middle;
pushdown(rt,r-l+1);
if(L<=m) update1(lson,L,R,d);
if(R>m) update1(rson,L,R,d);
pushup(rt);
}

void update2(int l,int r,int rt,int L,int R,int d){
if(L<=l&&r<=R){
sum[rt]+=(LL)(r+l-2*L+2)*(r-l+1)*d/2;
a1[rt]+=(LL)d*(l-L+1);
a2[rt]+=d;
bb[rt]=1;
return;
}
middle;
pushdown(rt,r-l+1);
if(L<=m) update2(lson,L,R,d);
if(R>m) update2(rson,L,R,d);
pushup(rt);
}

void update3(int l,int r,int rt,int L,int R,int d){
if(L<=l&&r<=R){
sum[rt]=(r-l+1)*d;
col[rt]=d;
cc[rt]=1;
aa[rt]=0;add[rt]=0;
bb[rt]=0;a1[rt]=0,a2[rt]=0;
return;
}
middle;
pushdown(rt,r-l+1);
if(L<=m) update3(lson,L,R,d);
if(R>m) update3(rson,L,R,d);
pushup(rt);
}

LL query(int l,int r,int rt,int L,int R){
if(L<=l&&r<=R) return sum[rt];
middle;
pushdown(rt,r-l+1);
LL ans=0;
if(L<=m) ans+=query(lson,L,R);
if(R>m) ans+=query(rson,L,R);
pushup(rt);
return ans;
}

int main(){
int t;
cin>>t;
while(t--){
int n,m;
cin>>n>>m;
build(1,n,1);
while(m--){
int op;
scanf("%d",&op);
if(op==1){
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
update1(1,n,1,a,b,c);
}
else if(op==2){
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
update2(1,n,1,a,b,c);
}
else if(op==3){
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
update3(1,n,1,a,b,c);
}
else{
int a,b;
scanf("%d%d",&a,&b);
printf("%lld\n",query(1,n,1,a,b));
}
}
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: