您的位置:首页 > 理论基础 > 数据结构算法

【可持久化数据结构】函数式线段树

2012-05-19 21:05 519 查看
clj的论文很不错

总体思想就是只赋值不修改,同时充分运用历史版本,也正因为这个特性,所以可以完成在线询问历史版本的功能

这东西不同于后缀自动机基本基于原有知识就可以有直观的理解,往往平常想题的时候想到某个思路认为无法实现就直接枪毙,但这些东西恰好用函数式编程就迎刃而解,如果用以前的思维方式很有可能直接滤过,囧的就是明明拿着钥匙却偏偏认为那道门打不开,所以对于此类题目关键就是要贴合函数式编程的优势(貌似我真心不习惯)

poj2104 区间k大值

我们维护一个控制值域的线段树,假设我们已经得到我们需要的区间的值域的线段树,用类似约瑟夫环的查找方式(主要是我第一次用这种查找做的是约瑟夫环,其实就是适时判断解在左区间还是右区间),就可以找到第k个值,关键是如何得到这棵线段树。

考虑挨个插入元素,第i个插入存的线段树是[1,i]的值域线段树,由于每个数的存在满足区间减法,对于询问[l,r],他的值域线段树每个节点的值就是历史版本r对应节点减去l-1的值,这样就很方便的在logn的时间回答询问

#include <cstdio>
#include <cstdlib>
#include <cstring>
int ans,n,m,a[100001],p[100001],u[100001],q[100001],ss,ne,mid,root[100001],lim;
int l[2000000],r[2000000],ls[2000000],rs[2000000],size[2000000];
inline void updata(int lson,int rson,int x)
{
  size[++ss]=size[x]+1;
  l[ss]=lson,r[ss]=rson,ls[ss]=ls[x],rs[ss]=rs[x];
}
inline void news(int x)
{
  l[x]=++ss,ls[ss]=ls[x],rs[ss]=(ls[x]+rs[x])/2;
  r[x]=++ss,ls[ss]=rs[ss-1]+1,rs[ss]=rs[x];
}
inline int change(int x,int w)
{
  if (ls[x]==rs[x]) {size[++ss]=size[x]+1,ls[ss]=rs[ss]=ls[x];return ss;}
  mid=(ls[x]+rs[x])/2;
  if (!l[x]) news(x);
  if (w<=mid) ne=change(l[x],w),updata(ne,r[x],x);
  else ne=change(r[x],w),updata(l[x],ne,x);
  return ss;
}
inline int ask(int i,int j,int k)
{
  for (int sum;ls[j]!=rs[j];) {
    if (!l[i]) news(i);
    if (!l[j]) news(j);
    sum=size[l[j]]-size[l[i]];
    if (sum<k) k-=sum,j=r[j],i=r[i];
    else i=l[i],j=l[j];
  }
  return q[ls[j]];
}
int cmp(const void *i,const void *j) {return a[*(int *)i]-a[*(int *)j];}
void init()
{
    int i,j,l,r,k;
  scanf("%d%d\n",&n,&m);
  for (i=1;i<=n;i++) scanf("%d",&a[i]),u[i]=i;
  qsort(u+1,n,sizeof(u[1]),cmp);
  for (p[u[1]]=1,q[1]=a[u[1]],i=2;i<=n;i++)
    if (a[u[i]]!=a[u[i-1]]) p[u[i]]=i,q[i]=a[u[i]];else p[u[i]]=p[u[i-1]];
  root[0]=0,ls[0]=1,rs[0]=n;
  for (i=1,lim=0;i<=m;i++) {
    scanf("%d%d%d",&l,&r,&k);
    if (lim<r) 
	for (j=lim+1,lim=r;j<=r;j++) root[j]=change(root[j-1],p[j]);
    ans=ask(root[l-1],root[r],k);
    printf("%d\n",ans);
  }
}
int main()
{
  freopen("poj2104.in","r",stdin);
  freopen("poj2104.out","w",stdout);
   init();
  return 0;
}


rank:带修改的区间k大值

同样是维护值域线段树,但是一个元素的修改会影响o(n)棵线段树,那么我们用树状数组的思想,每课线段树存i-lowbit(i)+1~i的区间值域线段树的值,但是这样的话,就不能像无修改一样由i-1递推,而是每棵线段树重开,虽然只有o(nlogn*logn)空间消耗,但是如果强用函数式加上我没事先离散(毕竟函数式是解决在线问题的利器),足足用了170+M空间,如果换成朴素线段树还是要100+M空间,极度怀疑是我理解错误

函数式

#include <cstdio>
#include <cstdlib>
#include <cstring>
const int maxc=1000000000,max=15000000;
int size[max],l[max],r[max],root[100000],ss,mid,ne,ans,n,m,st[2][100000],t[2],a[100000];
void origin()
{
  for (int i=1;i<=n+1;i++) root[i]=i;
  ss=n+1;
}
inline void updata(int lson,int rson,int x,int w) {size[++ss]=size[x]+w;l[ss]=lson,r[ss]=rson;}
inline int change(int x,int ls,int rs,int y,int w)
{
  if (ls==rs) {updata(0,0,x,w);return ss;}
  mid=(ls+rs)>>1;
  if (y<=mid) ne=change(l[x],ls,mid,y,w),updata(ne,r[x],x,w);
  else ne=change(r[x],mid+1,rs,y,w),updata(l[x],ne,x,w);
  return ss;
}
inline int need(int e)
{
  int sum=0;
  for (int i=1;i<=t[e];i++) sum+=size[l[st[e][i]]];
  return sum;
}
inline void left(int e) {for (int i=1;i<=t[e];i++) st[e][i]=l[st[e][i]];}
inline void right(int e){for (int i=1;i<=t[e];i++) st[e][i]=r[st[e][i]];}
inline int ask(int ll,int rr,int k)
{
  t[0]=t[1]=0;
  int ls=1,rs=maxc;
  for (;ll;ll-=ll & -ll) st[0][++t[0]]=root[ll];
  for (;rr;rr-=rr & -rr) st[1][++t[1]]=root[rr];
  for (int sum1,sum2;ls!=rs;) {
    sum1=need(0),sum2=need(1);
    if (sum2-sum1<k) k-=sum2-sum1,right(0),right(1),ls=((ls+rs)>>1)+1;
    else left(0),left(1),rs=(ls+rs)>>1;
  }
  return ls;
}
void init()
{
  int i,j,x,l,r,k;
  char ch;
  scanf("%d%d\n",&n,&m);
  origin();
  for (i=1;i<=n;i++) {
    scanf("%d",&x);
    for (a[j=i+1]=x;j<=n+1;j+=j & -j) root[j]=change(root[j],1,maxc,x,1);
  }
  scanf("\n");
  for (i=1;i<=m;i++) {
    scanf("%c",&ch);
    if ('Q'==ch) {
      scanf("%d%d%d\n",&l,&r,&k);l++,r++;
      ans=ask(l-1,r,k);
      printf("%d\n",ans);
    }
    else {
      scanf("%d%d\n",&l,&x);l++;
      for (r=l;l<=n+1;l+=l & -l) root[l]=change(root[l],1,maxc,a[r],-1);a[r]=x;
      for (l=r;l<=n+1;l+=l & -l) root[l]=change(root[l],1,maxc,x,1);
    }
  }
}
int main()
{
  freopen("rank.in","r",stdin);
  freopen("rank.out","w",stdout);
   init();
  return 0;
}


朴素

#include <cstdio>
#include <cstdlib>
#include <cstring>
const int maxc=1000000000,max=8780000;
int size[max],l[max],r[max],root[100000],ss,mid,ne,ans,n,m,st[2][100000],t[2],a[100000];
void origin()
{
  for (int i=1;i<=n+1;i++) root[i]=i;
  ss=n+1;
}
inline void change(int x,int ls,int rs,int y,int w)
{
    size[x]+=w;
    if (ls==rs) return ;
    mid=(ls+rs)>>1;
    if (y<=mid) {
	if (!l[x]) l[x]=++ss;
	change(l[x],ls,mid,y,w);
    }
    else {
	if (!r[x]) r[x]=++ss;
	change(r[x],mid+1,rs,y,w);
    }
}
inline int need(int e)
{
  int sum=0;
  for (int i=1;i<=t[e];i++) sum+=size[l[st[e][i]]];
  return sum;
}
inline void left(int e) {for (int i=1;i<=t[e];i++) st[e][i]=l[st[e][i]];}
inline void right(int e){for (int i=1;i<=t[e];i++) st[e][i]=r[st[e][i]];}
inline int ask(int ll,int rr,int k)
{
  t[0]=t[1]=0;
  int ls=1,rs=maxc;
  for (;ll;ll-=ll & -ll) st[0][++t[0]]=root[ll];
  for (;rr;rr-=rr & -rr) st[1][++t[1]]=root[rr];
  for (int sum1,sum2;ls!=rs;) {
    sum1=need(0),sum2=need(1);
    if (sum2-sum1<k) k-=sum2-sum1,right(0),right(1),ls=((ls+rs)>>1)+1;
    else left(0),left(1),rs=(ls+rs)>>1;
  }
  return ls;
}
void init()
{
  int i,j,x,l,r,k;
  char ch;
  scanf("%d%d\n",&n,&m);
  origin();
  for (i=1;i<=n;i++) {
    scanf("%d",&x);
    for (a[j=i+1]=x;j<=n+1;j+=j & -j) change(root[j],1,maxc,x,1);
  }
  scanf("\n");
  for (i=1;i<=m;i++) {
    scanf("%c",&ch);
    if ('Q'==ch) {
      scanf("%d%d%d\n",&l,&r,&k);l++,r++;
      ans=ask(l-1,r,k);
      printf("%d\n",ans);
    }
    else {
      scanf("%d%d\n",&l,&x);l++;
      for (r=l;l<=n+1;l+=l & -l) change(root[l],1,maxc,a[r],-1);a[r]=x;
      for (l=r;l<=n+1;l+=l & -l) change(root[l],1,maxc,x,1);
    }
  }
}
int main()
{
  freopen("rank.in","r",stdin);
  freopen("rank.out","w",stdout);
   init();
  return 0;
}


spoj cot1

询问树上k大值

用类似区间k大值思想,只不过第i个元素插入是在root[i]线段树上修改,询问的时候减去2*lca线段树上的值,这类问题大约都是这样子的吧

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
const int oo=1073741819;
int b,e,ne,mid,n,m,ans;
int root[100001],tail[100001],next[500000],sora[500000],f[100001][18],a[100001];
int l[2000000],r[2000000],size[2000000],ss,s1,d[100001],q[100001],u[100001],p[100001],st[100001];
void origin() {for (int i=1;i<=n;i++) tail[i]=i;s1=n;}
inline void updata(int lson,int rson,int x)
{
  size[++ss]=size[x]+1,l[ss]=lson,r[ss]=rson;
}
inline int change(int x,int ls,int rs,int y)
{
  if (ls==rs) {updata(0,0,x);return ss;}
  mid=(ls+rs)>>1;
  if (y<=mid) ne=change(l[x],ls,mid,y),updata(ne,r[x],x);
  else ne=change(r[x],mid+1,rs,y),updata(l[x],ne,x);
  return ss;
}
void bfs(int s)
{
  int h,r,i,ne,na,j,k;
  memset(d,127,sizeof(d));
  h=r=0;
  st[r=1]=s,d[s]=0;
  root[s]=change(0,1,n,p[s]);
  for (;h<r;) {
    ne=st[++h];
    for (i=ne;next[i]!=0;) {
      i=next[i],na=sora[i];
      if (d[na]>oo) d[na]=d[ne]+1,st[++r]=na,root[na]=change(root[ne],1,n,p[na]),f[na][0]=ne;
    }
  }
  k=(int)log2(d[st[r]]);
  for (j=1;j<=k;j++)
    for (i=1;i<=n;i++) f[i][j]=f[f[i][j-1]][j-1];
}
inline int ask(int i,int j,int v,int k)
{
  int ls=1,rs=n,sum;
  for (;ls!=rs;) {
    sum=size[l[i]]-(size[l[j]]<<1)+size[l[v]];
    if (p[ne]<=(ls+rs)>>1 && p[ne]>=ls) sum++;
    if (sum<k) k-=sum,i=r[i],j=r[j],v=r[v],ls=((ls+rs)>>1)+1;
    else i=l[i],j=l[j],v=l[v],rs=(ls+rs)>>1;
  }
  return q[ls];
}
inline int lca(int x,int y)
{
  if (d[x]<d[y]) e=x,x=y,y=e;
  for (e=d[x]-d[y],b=0;e;e>>=1,b++) if (e&1) x=f[x][b];
  if (x==y) return x;
  for (e=1;;) {
    if (f[x][e]==f[y][e]) 
      if (e) e--;else break;
    else x=f[x][e],y=f[y][e],e++;
  }
  return f[x][e];
}
inline void link(int x,int y) 
{
  s1++,next[tail[x]]=s1,tail[x]=s1,sora[s1]=y;
  s1++,next[tail[y]]=s1,tail[y]=s1,sora[s1]=x;
}
inline int cmp(const void *i,const void *j) {return a[*(int *)i]-a[*(int *)j];}
void init()
{
  int i,x,y,k,l,r;
  scanf("%d%d\n",&n,&m);
  origin();
  for (i=1;i<=n;i++) scanf("%d",&a[i]),u[i]=i;
  qsort(u+1,n,sizeof(u[1]),cmp);
  for (p[u[1]]=1,q[1]=a[u[1]],i=2;i<=n;i++) 
    if (a[u[i]]!=a[u[i-1]]) p[u[i]]=i,q[i]=a[u[i]];else p[u[i]]=p[u[i-1]];
  for (i=1;i<=n-1;i++) {
    scanf("%d%d\n",&x,&y);
    link(x,y);
  }
  bfs(1);
  for (i=1;i<=m;i++) {
    scanf("%d%d%d\n",&l,&r,&k);
    if (r==l) {printf("%d\n",a[l]);continue;}
    ne=mid=lca(l,r);
    ans=ask(root[l],root[mid],root[r],k);
    printf("%d\n",ans);
  }
}
int main()
{
  freopen("spojcot1.in","r",stdin);
  freopen("spojcot1.out","w",stdout);
   init();
  return 0;
}


clj middle

一个长度为n的序列a,设其排过序之后为b,其中位数定义为b[n/2],其中a,b从0开始标号,除法取下整。

给你一个长度为n的序列s。

回答Q个这样的询问:s的左端点在[a,b]之间,右端点在[c,d]之间的子序列中,最大的中位数。

其中a<b<c<d。

位置也从0开始标号。

我会使用一些方式强制你在线。

这道题反映出对函数式编程解题还完全不熟悉,想了半天还是在传统解法兜圈,最后看了solution的提示才想出解法

首先是解决最大中位数的常用方法,二分答案用区间(小为-1大为1)最大和判断,问题是不能每次都裸扫来算,因为没有修改,所以可能成为答案的值有o(n)个,排过序后依次插入,用线段树维护区间和及最大和(初始全为1,之后依次改为-1),每次就可以在o(logn)的时间完成判断

#include <cstdio>
#include <cstdlib>
#include <cstring>
int n,m,a[50000],u[50000],p[50000],q[50000],as[5],st[500000],tail[500000],next[500000],sora[500000],root[500000];
int sum[3000000],lsum[3000000],rsum[3000000],s1,ss,l[3000000],r[3000000],mid,ne,t,ans,e;
void origin() {for (int i=1;i<=n;i++) tail[i]=i;s1=n;}
inline int max(int x,int y) {return (x>y) ? x : y;}
inline void link(int x,int y) {s1++,next[tail[x]]=s1,tail[x]=s1,sora[s1]=y;}
inline void updata(int lson,int rson,int ls,int rs)
{
  int mid=(ls+rs)>>1;
  ss++,l[ss]=lson,r[ss]=rson;
  if (!l[ss] && !r[ss]) lsum[ss]=rsum[ss]=sum[ss]=-1;
  else if (!l[ss]) {
    sum[ss]=mid-ls+1+sum[r[ss]];
    lsum[ss]=mid-ls+1+max(0,lsum[r[ss]]);
    rsum[ss]=max(rsum[r[ss]],sum[ss]);
  }
  else if (!r[ss]) {
    sum[ss]=sum[l[ss]]+rs-mid;
    lsum[ss]=max(lsum[l[ss]],sum[ss]);
    rsum[ss]=rs-mid+max(0,rsum[l[ss]]);
  }
  else {
    sum[ss]=sum[l[ss]]+sum[r[ss]];
    lsum[ss]=max(lsum[l[ss]],sum[l[ss]]+lsum[r[ss]]);
    rsum[ss]=max(rsum[r[ss]],sum[r[ss]]+rsum[l[ss]]);
  }
}
inline int change(int x,int ls,int rs,int y)
{
  mid=(ls+rs)>>1;
  if (ls==rs) {updata(0,0,ls,rs);return ss;}
  if (y<=mid) ne=change(l[x],ls,mid,y),updata(ne,r[x],ls,rs);
  else ne=change(r[x],mid+1,rs,y),updata(l[x],ne,ls,rs);
  return ss;
}
inline void newsl(int x,int ls,int rs)
{
  int mid=(ls+rs)>>1;
  l[x]=++ss,sum[ss]=lsum[ss]=rsum[ss]=mid-ls+1;
}
inline void newsr(int x,int ls,int rs)
{
  int mid=(ls+rs)>>1;
   r[x]=++ss,sum[ss]=lsum[ss]=rsum[ss]=rs-mid;
}
inline void find(int x,int ls,int rs,int ll,int rr)
{
  if (ls==ll && rr==rs) {st[++t]=x;return ;}
  int mid=(ls+rs)>>1;
  if (rr<=mid) {
    if (!l[x]) newsl(x,ls,rs);
    find(l[x],ls,mid,ll,rr);
  }
  else if (ll>=mid+1) {
    if (!r[x]) newsr(x,ls,rs);
    find(r[x],mid+1,rs,ll,rr);
  }
  else {
    if (!l[x]) newsl(x,ls,rs);find(l[x],ls,mid,ll,mid);
    if (!r[x]) newsr(x,ls,rs);find(r[x],mid+1,rs,mid+1,rr);
  }
}
inline int check(int x)
{
  int ans=0,lans=0,rans=0,i;
  t=0,find(x,1,n,as[1],as[2]);
  for (i=t,ans=-1;i>=1;i--) {
    ans=max(ans,rsum[st[i]]+rans);
    rans+=sum[st[i]];
  }
  t=0;
  if (as[2]+1<=as[3]-1) find(x,1,n,as[2]+1,as[3]-1);
  for (i=1;i<=t;i++) ans+=sum[st[i]];
  t=0,find(x,1,n,as[3],as[4]);
  for (i=1,rans=-1;i<=t;i++) {
    rans=max(rans,lsum[st[i]]+lans);
    lans+=sum[st[i]];
  }
  ans+=rans;
  return ans;
}
inline int cmp(const void *i,const void *j) {return a[*(int *)i]-a[*(int *)j];}
void init()
{
  int i,j,k,l,r,mid,ne,maxc,h;
  scanf("%d\n",&n);
  origin();
  for (i=1;i<=n;i++) scanf("%d\n",&a[i]),u[i]=i;
  qsort(u+1,n,sizeof(u[1]),cmp);
  p[u[1]]=1;q[1]=a[u[1]];
  link(1,u[1]);
  for (i=2,k=1;i<=n;i++)
    if (a[u[i]]!=a[u[i-1]]) p[u[i]]=++k,q[k]=a[u[i]],link(k,u[i]);
    else p[u[i]]=k,link(k,u[i]);
  maxc=k,sum[0]=lsum[0]=rsum[0]=n;
  for (i=1,h=0;i<=maxc;i++) {
    for (j=i;next[j]!=0;) {
      j=next[j],ne=sora[j];
      root[i]=change(h,1,n,ne);
      h=root[i];
    }
  }
  scanf("%d\n",&m);
  for (ans=0,i=1;i<=m;i++) {
    scanf("%d%d%d%d\n",&as[1],&as[2],&as[3],&as[4]);
    // for (j=1;j<=4;j++) as[j]=(as[j]+ans)%n;
    // for (j=1;j<=3;j++)
    //   for (k=j+1;k<=4;k++)
    // 	if (as[j]>as[k]) e=as[j],as[j]=as[k],as[k]=e;
    for (j=1;j<=4;j++) as[j]++;
    for (l=1,r=maxc;l<=r;) {
      mid=(l+r)>>1;
      if (check(root[mid])>=0) l=mid+1;else r=mid-1;
    }
    ans=q[l];
    printf("%d\n",ans);
  }
}
int main()
{
  freopen("middle.in","r",stdin);
  freopen("middle.out","w",stdout);
   init();
  return 0;
}


2010 天津 J

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
const long long lim=(1LL<<31);
using namespace std;
int ss,n;
int l[6000000],r[6000000];
long long ls[6000000],rs[6000000];
int size[6000000];
int root[2000000];
char ch[2000];
int ori()
{
    ++ss;
    l[ss]=r[ss]=ls[ss]=rs[ss]=0;
    size[ss]=0;
    return ss;
}
void origin()
{
    ss=0;
    root[0]=ori(),ls[ss]=1,rs[ss]=lim;
}
int updata(int lson,int rson,int x)
{
    ori();
    size[ss]=size[x]+1;
    l[ss]=lson,r[ss]=rson,ls[ss]=ls[x],rs[ss]=rs[x];
    return ss;
}
void news(int x)
{
    l[x]=ori(),ls[ss]=ls[x],rs[ss]=(ls[x]+rs[x])>>1;
    r[x]=ori(),ls[ss]=rs[ss-1]+1,rs[ss]=rs[x];
}
int change(int x,int w)
{
    if (ls[x]==rs[x]) {
        ori();
        size[ss]=size[x]+1,ls[ss]=rs[ss]=ls[x];
        return ss;
    }
    long long mid=((ls[x]+rs[x])>>1);
    if (!l[x]) news(x);
    if (w<=mid) {
        int ne=change(l[x],w);
        return updata(ne,r[x],x);
    }
    else {
        int ne=change(r[x],w);
        return updata(l[x],ne,x);
    }
}
long long ask(int L,int R,int k)
{
    int sum=0;
    L--;
    L=root[L],R=root[R];
    for (;ls[R]!=rs[R];) {
        if (!l[L]) news(L);
        if (!l[R]) news(R);
        sum=size[l[R]]-size[l[L]];
//        cout<<sum<<' '<<ls[L]<<' '<<rs[L]<<endl;
        if (sum<k) k-=sum,L=r[L],R=r[R];
        else L=l[L],R=l[R];
    }
    return ls[R];
}
int ask2(int x,int w)
{
    int sum=0;
    x=root[x];
    for (;ls[x]!=rs[x];) {
        if (!l[x]) news(x);
        long long mid=(ls[x]+rs[x])>>1;
        if (w<=mid) x=l[x];
        else {
            sum+=size[l[x]];
            x=r[x];
        }
    }
    sum+=size[x];
    return sum;
}
int main()
{
    for (int test=1;scanf("%d",&n)==1;test++) {
        printf("Case %d:\n",test);
        origin();
        int tot=0;
        long long sum1=0,sum2=0,sum3=0;
        for (int i=1;i<=n;i++) {
            scanf("%s",ch+1);
            if (ch[1]=='I') {
                int x;
                scanf("%d",&x);
                ++tot;
                root[tot]=change(root[tot-1],x);
            }
            else if (ch[1]=='Q' && ch[7]=='1') {
                int L,R,k;
                scanf("%d%d%d",&L,&R,&k);
                int ans=ask(L,R,k);
//                printf("%d\n",ans);
                sum1+=ans;
            }
            else if (ch[1]=='Q' && ch[7]=='2') {
                int x;
                scanf("%d",&x);
                int ans=ask2(tot,x);
//                printf("%d\n",ans);
                sum2+=ans;
            }
            else if (ch[1]=='Q' && ch[7]=='3') {
                int k;
                scanf("%d",&k);
                int ans=ask(1,tot,k);
//                printf("%d\n",ans);
                sum3+=ans;
            }
        }
        cout<<sum1<<'\n'<<sum2<<'\n'<<sum3<<endl;
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: