您的位置:首页 > 产品设计 > UI/UE

[SPOJ GSS4] Can you answer these queries IV [树状数组+并查集][线段树+双向链表]

2014-07-27 10:56 435 查看
给一个序列,有两种操作,一种是把第x个数到第y个数都开根号,另一种是询问第x个数到第y个数的和。

实际上一个数最多开根号几次,就会变成1。所以开根号操作,直接对所有的非1节点暴力进行即可。

因为每个数最大可能到long long,所以sqrt最好不要直接用STL里边的double的,会丢掉最后几位造成精度损失..虽然这道题的数据测补出来..

但是自己写的sqrt又比STL里边的慢...好纠结...

可以用并查集记录所有的非1节点,然后用树状数组求和。

#include <cstdio>
#include <iostream>

using namespace std;

struct DisjoinSet {
int a[100002];
void clear(int n) {
for (int i=1;i<=n+1;i++) a[i]=i;
}
int get(int i) {
if (a[i]==i) return i;
return a[i]=get(a[i]);
}
void tosame(int x,int y) {
x=get(x);
y=get(y);
a[x]=y;
}
};
struct BIT {
long long a[100001];
int n;
void clear(int nn) {
n=nn;
for (int i=1;i<=n;i++) a[i]=0;
}
int lb(int i) {
return i&-i;
}
void set(int i,long long x) {
if (i==0) return;
for (;i<=n;i+=lb(i)) a[i]+=x;
}
long long get(int i) {
long long ans=0;
for (;i>0;i-=lb(i)) ans+=a[i];
return ans;
}
};
DisjoinSet c;
BIT b;
long long a[100001];
int n;

int sqrt(long long v) {
int l=1,r=1000000000;
while (l!=r) {
int t=(l+r)/2;
if ((long long)t*t>v) r=t;
else l=t+1;
}
return l-1;
}
inline long long in() {
char c=getchar();
while (c<'0'||c>'9') c=getchar();
long long ans=0;
while (c>='0'&&c<='9') {
ans=ans*10+c-'0';
c=getchar();
}
return ans;
}

int main() {
int i,cas=1,m,x,y,z,j;
while (scanf("%d",&n)!=EOF) {
b.clear(n);
c.clear(n);
for (i=1;i<=n;i++) {
a[i]=in();
b.set(i,a[i]);
}
for (i=1;i<=n;i++) if (a[i]==1) c.tosame(i,i+1);
printf("Case #%d:\n",cas++);
scanf("%d",&m);
for (i=0;i<m;i++) {
x=in();y=in();z=in();
if (y>z) swap(y,z);
if (x==0) {
for (j=c.get(y);j<=z;j=c.get(j+1)) {
long long tmp=sqrt(a[j]);
b.set(j,tmp-a[j]);
a[j]=tmp;
if (a[j]==1) c.tosame(j,j+1);
}
} else {
printf("%lld\n",b.get(z)-b.get(y-1));
}
}
printf("\n");
}
return 0;
}


也可以用线段树记录第一个非1节点的位置以及区间和,用双向链表记录每一个非1节点的下一个非1节点。

#include <cstdio>
#include <iostream>

using namespace std;

struct SeqNode {
SeqNode *ls,*rs;
long long sum;
int first;
};
struct ListNode {
long long data;
int l,r;
};
ListNode a[100001];
SeqNode b[200000];
SeqNode *root,*bp;
int n;

void lisDel(int i) {
if (a[i].l!=-1) a[a[i].l].r=a[i].r;
if (a[i].r!=-1) a[a[i].r].l=a[i].l;
}
int sqrt(long long v) {
int l=1,r=1000000000;
while (l!=r) {
int t=(l+r)/2;
if ((long long)t*t>v) r=t;
else l=t+1;
}
return l-1;
}
void sqr(SeqNode *from,int l,int r,int i) {
if (l==r) {
a[l].data=sqrt(a[l].data);
from->sum=a[l].data;
if (a[l].data==1) {
from->first=-1;
lisDel(l);
} else {
from->first=l;
}
} else {
int t=(l+r)/2;
if (i<=t) sqr(from->ls,l,t,i);
else sqr(from->rs,t+1,r,i);
if (from->ls->first==-1) from->first=from->rs->first;
else from->first=from->ls->first;
from->sum=from->ls->sum+from->rs->sum;
}
}
SeqNode *maketree(int l,int r) {
SeqNode *ans=bp++;
if (l==r) {
ans->sum=a[l].data;
if (a[l].data==1) ans->first=-1;
else ans->first=l;
ans->ls=ans->rs=NULL;
} else {
int t=(l+r)/2;
ans->ls=maketree(l,t);
ans->rs=maketree(t+1,r);
if (ans->ls->first==-1) ans->first=ans->rs->first;
else ans->first=ans->ls->first;
ans->sum=ans->ls->sum+ans->rs->sum;
}
return ans;
}
int getfirst(SeqNode *from,int l,int r,int ll,int rr) {
if (l==ll&&r==rr) return from->first;
int t=(l+r)/2;
if (rr<=t) return getfirst(from->ls,l,t,ll,rr);
else if (ll>t) return getfirst(from->rs,t+1,r,ll,rr);
else {
int tmp=getfirst(from->ls,l,t,ll,t);
if (tmp!=-1) return tmp;
else return getfirst(from->rs,t+1,r,t+1,rr);
}
}
long long getsum(SeqNode *from,int l,int r,int ll,int rr) {
if (l==ll&&r==rr) return from->sum;
int t=(l+r)/2;
if (rr<=t) return getsum(from->ls,l,t,ll,rr);
else if (ll>t) return getsum(from->rs,t+1,r,ll,rr);
else return getsum(from->ls,l,t,ll,t)+getsum(from->rs,t+1,r,t+1,rr);
}
inline long long in() {
char c=getchar();
while (c<'0'||c>'9') c=getchar();
long long ans=0;
while (c>='0'&&c<='9') {
ans=ans*10+c-'0';
c=getchar();
}
return ans;
}

int main() {
int i,cas=1,m,x,y,z,j;
while (scanf("%d",&n)!=EOF) {
for (i=1;i<=n;i++) {
a[i].l=i-1;
a[i].r=i+1;
a[i].data=in();
}
a[1].l=-1;a
.r=-1;
for (i=1;i<=n;i++) if (a[i].data==1) lisDel(i);
bp=b;
root=maketree(1,n);
printf("Case #%d:\n",cas++);
scanf("%d",&m);
for (i=0;i<m;i++) {
x=in();y=in();z=in();
if (y>z) swap(y,z);
if (x==0) {
for (j=getfirst(root,1,n,y,z);j!=-1&&j<=z;j=a[j].r) {
sqr(root,1,n,j);
}
} else {
printf("%lld\n",getsum(root,1,n,y,z));
}
}
printf("\n");
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: