您的位置:首页 > 其它

HDU 4777 Rabbit Kingdom(树状数组+离线处理+尺取法)

2015-09-02 20:53 393 查看

题意:

给你n个数,有m个查询,问区间[L,R]之间有多少个数与这个区间内的其他数都互质。

解析:

很显然,[L,R][L,R]区间内的答案就是一个区间内的数的个数,减去与其他数不互质的数即可,即离当前数 aia_i 左边最近的不互质的数的位置(设为L[i])和右边最近的不互质的数的位置(设为R[i])有一个在区间[L,R][L,R]内。


那么问题就变成统计:

(1) 区间[L,R]中有多少个数的 L[i] 或 R[i] 在区间[L,R]内。

(2) 多少个数的 L[i]且R[i] 在区间[L,R]内。


对于每个询问,

令区间内的个数为lenlen,(1)的结果个数为cnt1,(2)的结果的个数为cnt2

那么区间内不合法的个数就是 (cnt1−cnt2)(cnt1 - cnt2)

那么每次询问的答案就是 len−(cnt1−cnt2)len - (cnt1 - cnt2)


(2)(2)的结果其实就是询问有多少个区间 [L[i],R[i]][L[i], R[i]] 完全在给定区间 [L,R][L,R] 内。


其实(1)(1)也可以转化为相同的问题,即区间[L[i],i][L[i],i]或[i,R[i]][i, R[i]],是否在给定区间内。

具体实现

对于如何求出一个a[i]a[i]的最大区间[L[i],R[i]][L[i], R[i]]?

可以分解成两次计算,先算L[i]L[i],可以枚举每个a[i]a[i],对a[i]a[i]分解质因子,令pos[num]表示每个质因子num出现的位置。

那么当前的L[i]=max(L[i],pos[num])L[i]=max(L[i], pos[num]),其中numnum就是a[i]a[i]的质因子。

并更新pos[num]=ipos[num] = i

求R[i]R[i]和L[i]L[i]类似


对于问有多少个区间是在给定的区间内?

可以直接离线,先离线处理出4个区间。

QueryQuery表示查询的区间

rad[0]=[L[i],i]rad[0] = [L[i], i]

rad[0]=[i,R[i]]rad[0] = [i, R[i]]

rad[0]=[L[i],R[i]]rad[0] = [L[i], R[i]]

先将所有的区间按照右边界进行排序,这样右边界就满足了单调性。

可以用尺取法,来枚举左端点,然后再利用树状数组,来维护左端点出现的次数,然后每次询问树状数组上,每个区间左端点总计出现了多少次,这就是完全包含区间的个数,并累加到一个离线查询的数组上。


这样cnt1和cnt2就求出来了。

最后的答案就是len−(cnt1−cnt2)len - (cnt1 - cnt2)


如果还有不理解的请看下面的代码。

mymy codecode

[code]#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define pb push_back
#define LEN(L, R) ((R) - (L) + 1)
using namespace std;
const int N = (int)2e5 + 10;
int n, m;
int a
;

struct BIT {
    int C
;
    void clear() { memset(C, 0, sizeof(C)); }
    inline int lowbit(int x) { return x&(-x); }  

    void add(int pos) {
        if(pos==0 || pos==n+1) return;
        for(int i = pos; i <= n; i += lowbit(i)) C[i]++;
    }

    int query(int st, int ed) {
        int ans=0;
        for(int i = ed; i >= 1; i -= lowbit(i)) ans += C[i];  
        for(int i = st-1; i >= 1; i -= lowbit(i)) ans -= C[i];  
        return ans;
    }
} bit;

struct Segment { int st, en, id; };

bool cmp(Segment a, Segment b) {
    return a.en < b.en;
}

vector<Segment> Query, rad[3];
int length
, cnt[3]
;

int L
, R
, pos
;
void getLeft() {
    for(int i = 0; i < N; i++)
        L[i] = pos[i] = 0;
    for(int i = 1; i <= n; i++) {
        int tmp = a[i];
        for(int j = 2; j*j <= tmp; j++) {
            if(tmp % j != 0) continue;
            L[i] = max(L[i], pos[j]);
            pos[j] = i;
            while(tmp % j == 0) tmp /= j;
        }
        if(tmp != 1) {
            L[i] = max(L[i], pos[tmp]);
            pos[tmp] = i;
        }
    }
}

void getRight() {
    for(int i = 0; i < N; i++)
        R[i] = pos[i] = n+1;
    for(int i = n; i >= 1; i--) {
        int tmp = a[i];
        for(int j = 2; j*j <= tmp; j++) {
            if(tmp % j != 0) continue;
            R[i] = min(R[i], pos[j]);
            pos[j] = i;
            while(tmp % j == 0) tmp /= j;
        }
        if(tmp != 1) {
            R[i] = min(R[i], pos[tmp]);
            pos[tmp] = i;
        }
    }
}

void init() {
    memset(cnt, 0, sizeof(cnt));
    Query.clear();
    for(int i = 0; i < 3; i++)
        rad[i].clear();
}

void getCnt(int x) {
    bit.clear();
    int st, en, id;
    int cur = 0;
    for(int i = 0; i < m; i++) {
        st = Query[i].st, en = Query[i].en;
        id = Query[i].id;
        while(cur < n && rad[x][cur].en <= en) {
            bit.add(rad[x][cur].st);
            cur++;
        }
        cnt[x][id] += bit.query(st, en);
    }
}

void prepare() {
    getLeft(); getRight();

    int st, en;
    for(int i = 0; i < m; i++) {
        scanf("%d%d", &st, &en);
        length[i] = LEN(st, en);
        Query.pb((Segment){st, en, i});
    }

    for(int i = 1; i <= n; i++) {
        rad[0].pb((Segment){L[i], i, 0});
        rad[1].pb((Segment){i, R[i], 0});
        rad[2].pb((Segment){L[i], R[i], 0});
    }

    sort(Query.begin(), Query.end(), cmp);
    for(int i = 0; i < 3; i++)
        sort(rad[i].begin(), rad[i].end(), cmp);
}

int main() {
    while(~scanf("%d%d", &n, &m) && (n || m)) {
        init();
        for(int i = 1; i <= n; i++)
            scanf("%d", &a[i]);

        prepare();
        for(int i = 0; i < 3; i++)
            getCnt(i);

        for(int i = 0; i < m; i++) {
            int ans = length[i] - (cnt[0][i] + cnt[1][i] - cnt[2][i]);
            printf("%d\n", ans);
        }
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: