您的位置:首页 > 其它

poj 3744 概率dp 矩阵快速幂优化

2015-07-31 17:08 393 查看
一位童子兵要穿过一条路,路上有些地方放着地雷。这位童子兵非常好玩,走路一蹦一跳的。每次他在 i 位置有 p 的概率走一步到 i+1 ,或者 (1-p) 的概率跳一步到 i+2。童子兵初始在1位置,求他安全通过这条道路的概率。

以所在位置为状态,dp[i] 表示在位置 i 的安全的概率。

dp[i] = p * dp[i-1] + (1 - p) * dp[i-2]; // i 位置没有地雷

但是题目数据的范围是 10^8 这样dp的话会 TLE。

想想可以用矩阵快速幂优化。简单推出矩阵是

|p 1-p| * |dp[i] | = |dp[i+1]|

|1 0 | |dp[i-1]| |dp[i] |

而这时地雷位置是不满足这个矩阵的,因此我们得对地雷位置进行特判。而两个地雷中间的位置可以用快速幂优化。

假设 k 位置放有地雷,,我们可以得到 dp[k+1] = dp[k-1] * (1 - p);

对于***位置为 a[i] 和 a[i+1] 之间的数,知道 dp[a[i]+1] 后可以推出

|dp[a[i+1]-1]| = |p 1-p|^(a[i+1]-1-a[i]-1) * |dp[a[i]]+1|

|dp[a[i+1]-2]| |1 0 | |dp[a[i]] |

(视0位置有颗地雷,有地雷的位置的dp值为0)

于是我们可以对两个前后两个地雷之间用快速幂优化,并最终得到答案dp[max(a[i])+1];

hint:两个地雷相邻(a[i] + 1 = a[i+1])要特判,直接快速幂的话会TLE。

/***********************************************
 ** problem ID	: poj_3744.cpp
 ** create time	: Fri Jul 31 11:12:39 2015
 ** auther name	: xuelanghu
 ** auther blog	: blog.csdn.net/xuelanghu407
 **********************************************/

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

#define rep(i) for (int i=0; i<2; i++)

struct Array {
    double m[2][2];
    void _set(int c) {
        rep(i) rep(j) {
            m[i][j] = (i == j) ? c : 0;
        }
    }
    void _show() {
        cout << "==============" << endl;
        rep(i) {rep(j) cout << m[i][j] << " "; cout << endl;}
    }
};

Array _mul (Array a, Array b) {
    Array _tmp;
    _tmp._set(0.0);
    rep(i) rep(j) rep(k) {
        _tmp.m[i][j] += a.m[i][k] * b.m[k][j];
    }
    return _tmp;
}

Array power (Array s, int p) {
    Array res;
    res._set(1);

    for (; p; p >>= 1) {
        if (p & 1) res  = _mul(res, s);
        s = _mul(s, s);
    }
    return res;
}

int _unique(int a[], int n) {
    int j=0;
    for (int i=1; i<=n; i++) {
        if (a[i] == a[j]) continue;
        a[++j] = a[i];
    }
    return j;
}

int main () {
    int N;
    int a[12];
    double p;
    for (; scanf("%d%lf", &N, &p) == 2; ) {
        a[0] = 0;
        for (int i=1; i<=N; i++) {
            scanf ("%d", &a[i]);
        }
        sort(a, a+N+1);
        int K = _unique(a, N);
        Array A;
        A.m[0][0] = p;   A.m[0][1] = 1 - p;
        A.m[1][0] = 1.0; A.m[1][1] = 0.0;

        bool flag = false;
        for (int i=1; i<=K; i++) {
            if (a[i] == a[i-1] + 1) flag = true;
        }
        if (flag) { printf ("%.7f\n", 0.0); continue; }

        double res = 1;
        Array B;
        for (int i=1 ;i<=K; i++) {
            B = power(A, a[i] - a[i-1] - 2);

            res = res * B.m[0][0] * (1 - p);
        }

        printf ("%.7f\n", res);
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: