您的位置:首页 > 其它

莫队算法

2016-01-26 10:37 183 查看
莫队算法常用于处理不带修改的连续区间的询问问题上,一http://般来说都有以下这种形式:

给你n(1e5左右)个数字A[i],再给你q(1e5左右)个询问,每个询问包含一个l和r,表示用某种操作,对原数组的A[l...r]进行操作,并求出值。

一般来说,用莫队算法的情况下,我们已知A[l...r]的值,要转移到A[l'...r']的值需要花费为O((|l - l'| + |r - r'|) * T)的话,那么总的时间复杂度就为O((n + q) * sqrt(n) * T).这里的证明不再说明,具体可见最小曼哈顿生成树。

莫队算法一定是离线算法,其中先保存下来所有的询问。对于每个询问排序,设S为sqrt(n),询问要按照如下排序,

bool cmp(query a,query b){
return (a.l / S != b.l / S) ? a.l / S < b.l / S : a.r < b.r;
}


之后对于每一个询问,直接暴力转移即可。

例题1:

CF 86D

题意:

有一个数组A有n个数字,t个询问,对于每个询问有个l,r,计算A[l...r]中所有数字乘以它出现次数的平方的和。

解法就强行搬上面的结论就可以了

#include <bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof(a))
typedef long long ll;
using namespace std;
const int maxn = 200005;

ll mark[maxn * 5];
int n,t;
ll now;
int blo;
struct ppp{
int l,r,id;
void read(int _id){
scanf("%d%d",&l,&r);
id = _id;
}
bool operator < (const ppp & nex)const{
return (l / blo != nex.l / blo) ? l / blo < nex.l / blo : r > nex.r;
}
}que[maxn];

void init(){
blo = sqrt(n);
now = 0;
}
int ori[maxn];
void add(int x){
now += (2 * mark[x] + 1) * x;
mark[x]++;
}
void del(int x){
now += (-2 * mark[x] + 1) * x;
mark[x]--;
}
ll ans[maxn];

int main(){
while(cin>>n>>t){
init();
for(int i = 1;i <= n;i++)scanf("%d",ori + i);
for(int i = 0;i < t;i++)que[i].read(i);
sort(que,que + t);
int l = 1,r = 0;
for(int i = 0;i < t;i++){
while(que[i].l < l)add(ori[--l]);
while(que[i].l > l)del(ori[l++]);
while(que[i].r < r)del(ori[r--]);
while(que[i].r > r)add(ori[++r]);
ans[que[i].id] = now;
}
for(int i = 0;i < t;i++)printf("%I64d\n",ans[i]);
}
}


例题2:

CF 617E

题意:

给你一个数组A有n个数字,给你m个询问,每个询问l,r,要你求出A[l...r]中有多少不同组i,j(l <= i <= j <= r)满足A[i] xor A[i + 1] xor ...A[j]的值为k,输出组数。

题解:

那么我们只要保存A[1]到A[i](for all i)的异或和表示为func[i],那么A[i] xor ...A[j]就是func[i - 1] xor func[j],那么就可以写了。要注意的是当前数组在往左边移动时,要计算到A[i - 1]为止,往右边移动时到A[j]就好,这样才满足func[i - 1] xor func[j]。

#include <bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof(a))
typedef long long ll;
using namespace std;
const int maxn = 100005;

int n,m,k;
struct ppp{
int l,r,id;
void read(int _id){
scanf("%d%d",&l,&r);
id = _id;
}
}que[maxn];
int ori[maxn];
int xor_sum[maxn];
int sqrt_n;
ll ans[maxn];
ll cnt[1 << 20];
bool cmp(ppp a,ppp b){
return (a.l / sqrt_n != b.l / sqrt_n) ? a.l / sqrt_n < b.l / sqrt_n : a.r > b.r;
}

void show(){
for(int i= 1;i <= n;i++)cout<<xor_sum[i ^ ori[i]]<<" ";
cout<<endl;
for(int i = 1;i <= n;i++)cout<<xor_sum[i]<<" ";
cout<<endl;
}

ll anss;
void add(int x){
anss += cnt[x ^ k];
cnt[x]++;
}
void del(int x){
cnt[x]--;
anss -= cnt[x ^ k];
}

int main(){
while(cin>>n>>m>>k){
for(int i = 1;i <= n;i++)scanf("%d",&ori[i]);
for(int i = 0;i < m;i++)que[i].read(i);
mem(cnt,0);
xor_sum[0] = 0;
anss = 0;
for(int i = 1;i <= n;i++)xor_sum[i] = xor_sum[i - 1] ^ ori[i];
sqrt_n = sqrt(n);
sort(que,que + m,cmp);
int prel = 1,prer = 0;
for(int i = 0;i < m;i++){
while(que[i].l - 1 > prel)del(xor_sum[prel++]);//左边都是维护l - 1,右边维护r
while(que[i].l - 1 < prel)add(xor_sum[--prel]);//那么异或的结果就是l ~ r之间的
while(que[i].r > prer)add(xor_sum[++prer]);
while(que[i].r < prer)del(xor_sum[prer--]);
int &id = que[i].id;
ans[id] = anss;
}
for(int i = 0;i < m;i++)printf("%I64d\n",ans[i]);
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: