您的位置:首页 > 其它

【bzoj3238】[Ahoi2013]差异 后缀数组+单调栈

2015-12-06 19:01 471 查看
首先求出height数组,原式很明显可以化成一堆长度的和-两两LCP的和,所以我们考虑每个height能充当多少个区间的最小值即可,那么这个问题可以用单调栈解决,从左和从右各维护一个单调递增的单调栈,求出点i向左和向右分别最多能延伸多长。

注意:

1.height数组的[i,i]是要计入区间数的,因为我们查询lcp(i,j)的时候查询的是height数组中rank[i]+1~rank[j]的最小值,所以[i,i]这个区间其实在原串中是对应的两个问题。

2.注意处理相等的情况,如果两边都是维护单调递增的单调栈的话,那么会少算一部分情况,所以我们一边用单调递增的,另一边用单调不降的就可以了。证明的话自己画个图应该能看出来。

#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<iostream>
#define maxn 500010

using namespace std;

int sa[maxn],height[maxn],rank[maxn];
int wa[maxn],wb[maxn],wc[maxn];
char s[maxn];
int st[maxn],l[maxn],r[maxn];
int n,m,top;

void get_sa()
{
	m=200;
	int *x=wa,*y=wc,*t;
	for (int i=0;i<=m;i++) wb[i]=0;
	for (int i=1;i<=n;i++) wb[x[i]=s[i]]++;
	for (int i=1;i<=m;i++) wb[i]+=wb[i-1];
	for (int i=n;i>=1;i--) sa[wb[x[i]]--]=i;
	for (int j=1,p=0;p<n;j*=2,m=p)
	{
		p=0;
		for (int i=n-j+1;i<=n;i++) y[++p]=i;
		for (int i=1;i<=n;i++) if (sa[i]>j) y[++p]=sa[i]-j;
		for (int i=0;i<=m;i++) wb[i]=0;
		for (int i=1;i<=n;i++) wb[x[y[i]]]++;
		for (int i=1;i<=m;i++) wb[i]+=wb[i-1];
		for (int i=n;i>=1;i--) sa[wb[x[y[i]]]--]=y[i];
		t=x;x=y;y=t;
		p=1;x[sa[1]]=1;
		for (int i=2;i<=n;i++)
		  if (y[sa[i]]==y[sa[i-1]] && y[sa[i]+j]==y[sa[i-1]+j]) x[sa[i]]=p;
		  else x[sa[i]]=++p;
	}
}

void get_height()
{
	for (int i=1;i<=n;i++) rank[sa[i]]=i;
	int k=0;
	for (int i=1;i<=n;i++)
	{
		if (k) k--;
		if (rank[i]==1) continue;
		int j=sa[rank[i]-1];
		while (s[i+k]==s[j+k]) k++;
		height[rank[i]]=k;
	}
}

int main()
{
	scanf("%s",s+1);
	n=strlen(s+1);
	get_sa();
	get_height();
	height[0]=-1;height[n+1]=-1;
	top=1;st[1]=0;
	for (int i=1;i<=n;i++)
	{
		while (top && height[st[top]]>=height[i]) top--;
		l[i]=st[top]+1;
		st[++top]=i;
	}
	top=1;st[1]=n+1;
	for (int i=n;i>=1;i--)
	{
		while (top && height[st[top]]>height[i]) top--;
		r[i]=st[top]-1;
		st[++top]=i;
	}
	long long ans=0;
	for (int i=1;i<=n;i++) ans-=(long long)height[i]*(long long)(i-l[i]+1)*(long long)(r[i]-i+1)*2;
	for (int i=1;i<=n;i++) ans+=(long long)(n-i+1)*(n-1);
	printf("%lld\n",ans);
	return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: