您的位置:首页 > 其它

hdu 4578 Transformation(线段树区间更新)

2015-08-29 09:46 302 查看

题意:

给你一个数组,初始值为零,有四种操作

(1)”1 x y c”,代表 把区间 [x,y] 上的值全部加c

(2)”2 x y c”,代表 把区间 [x,y] 上的值全部乘以c

(3)”3 x y c” 代表 把区间 [x,y]上的值全部赋值为c

(4)”4 x y p” 代表 求区间 [x,y] 上值的p次方和1<=p<=3

解析:

我们可以把区间内的数字看做 ax+bax + b 的形式。

xx 是原来的值,aa 当前区间乘上的值,bb 是当前区间要加上的值。

因为P只有1到3,所以我们可以开3个数组来保存每个次方的和,分别是sum[1]sum[1],sum[2]sum[2],sum[3]sum[3]。

对于加cc操作,则变成 ax+b+cax+b+c(b−>b+cb->b+c)。

对于乘cc操作,则变成 acx+bcacx+bc(a−>ac,b−>bca->ac,b->bc)

对于赋值c操作,则变成c,即(a−>1,x−>c,b−>0a->1,x->c,b->0)

这里要解决的一个问题是加上或乘上一个数对这个区间的p次方和分别产生什么改变,很简单,一化简就能得到。

乘上一个数aa之后

sum[1]=a∗sum[1]sum[1] = a * sum[1]

sum[2]=a2∗sum[2]sum[2] = a^2 * sum[2]

sum[3]=a3∗sum[3]sum[3] = a^3 * sum[3]

当加上一个数bb之后

sum[1]=sum[1]+len∗bsum[1] = sum[1] + len*b

sum[2]=sum[2]+2∗sum[1]∗b+len∗bsum[2] = sum[2] + 2*sum[1]*b + len*b

sum[3]=sum[3]+3∗b2∗sum[1]+3∗b∗sum[2]+len∗b3sum[3] = sum[3] + 3 * b^2 * sum[1] + 3 * b * sum[2] + len * b^3

其中lenlen是当前区间的长度。

mymy codecode

[code]#include <cstdio>
#include <cstring>
#include <algorithm>
#define ls (o<<1)
#define rs (o<<1|1)
#define lson ls, L, M
#define rson rs, M+1, R
using namespace std;
const int MOD = (int)1e4 + 7;
const int MAXN = (int)1e5 + 10;
int n, m;

struct Node {
    int L, R;
    int sum[4];
    int mult, addv;

    inline int length() { return R - L + 1; }

    void multiply(int val) {
        mult = (mult * val) % MOD;
        addv = (addv * val) % MOD;
        for(int i = 1; i <= 3; i++) {
            for(int p = 1; p <= i; p++) {
                sum[i] = (sum[i] * val) % MOD;
            }
        }
    }

    void add(int val) {
        int len = length();
        addv = (addv + val) % MOD;

        sum[3] = (sum[3] + 3 * val % MOD * val % MOD * sum[1] % MOD) % MOD;
        sum[3] = (sum[3] + 3 * val % MOD * sum[2] % MOD) % MOD;
        sum[3] = (sum[3] + len * val % MOD * val % MOD * val % MOD) % MOD;

        sum[2] = (sum[2] + 2 * val % MOD * sum[1] % MOD) % MOD;
        sum[2] = (sum[2] + len * val % MOD * val % MOD) % MOD;

        sum[1] = (sum[1] + len * val % MOD) % MOD;
    }

    void cal(int MUL, int ADD) {
        multiply(MUL);
        add(ADD);
    }

} node[MAXN << 2];

void pushDown(int o) {
    if(node[o].mult != 1 || node[o].addv != 0) {
        node[ls].cal(node[o].mult, node[o].addv);
        node[rs].cal(node[o].mult, node[o].addv);
        node[o].mult = 1, node[o].addv = 0;
    }
}

void pushUp(int o) {
    for(int i = 1; i <= 3; i++) {
        node[o].sum[i] = (node[ls].sum[i] + node[rs].sum[i]) % MOD;
    }
}

void build(int o, int L, int R) {
    node[o].L = L, node[o].R = R;
    node[o].addv = 0, node[o].mult = 1;
    memset(node[o].sum, 0, 4*sizeof(int));
    if(L == R) return ;
    int M = (L + R)/2;
    build(lson);
    build(rson);
}

int query(int o, int L, int R, int ql, int qr, int p) {
    if(ql <= L && R <= qr) return node[o].sum[p];
    int M = (L + R)/2, ret = 0;
    pushDown(o);
    if(ql <= M) ret = (ret + query(lson, ql, qr, p)) % MOD;
    if(qr > M) ret = (ret + query(rson, ql, qr, p)) % MOD;
    return ret;
}

void modify(int o, int L, int R, int ql, int qr, int val, int op) {
    if(ql <= L && R <= qr) {
        if(op == 1) node[o].cal(1, val);
        else if(op == 2) node[o].cal(val, 0);
        else node[o].cal(0, val);
        return ;
    }
    int M = (L + R)/2;
    pushDown(o);
    if(ql <= M) modify(lson, ql, qr, val, op);
    if(qr > M) modify(rson, ql, qr, val, op);
    pushUp(o);
}

int main() {
    int op, ql, qr, val;
    while(~scanf("%d%d", &n, &m) && (n || m)) {
        build(1, 1, n);
        while(m--) {
            scanf("%d%d%d%d", &op, &ql, &qr, &val);
            if(op == 4) {
                printf("%d\n", query(1, 1, n, ql, qr, val) % MOD);
            }else {
                modify(1, 1, n, ql, qr, val, op);
            }
        }
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: