您的位置:首页 > 其它

Baby Step Giant Step 及扩展 进一步解释补充 和 poj 3243

2014-05-10 10:47 405 查看
先 mark一下 后面进行补充 想办法改进和补充说明(标为红色)。如有错误,请指正

转载自作者AekdyCoin !

【普通Baby Step Giant Step】

【问题模型】

求解

A^x = B (mod C) 中 0 <= x < C 的解,C 为素数

【思路】

我们可以做一个等价

x = i * m + j ( 0 <= i < m, 0 <=j < m) m = Ceil ( sqrt( C) )

而这么分解的目的无非是为了转化为:

(A^i)^m * A^j = B ( mod C)

之后做少许暴力的工作就可以解决问题:

(1) for i = 0 -> m, 插入Hash (i, A^i mod C)

(2) 枚举 i ,对于每一个枚举到的i,令 AA = (A^m)^i mod C

我们有

AA * A^j = B (mod C)

显然AA,B,C均已知,而由于C为素数,那么(AA,C)无条件为1

于是对于这个模方程解的个数唯一(可以利用扩展欧几里得或 欧拉定理来求解)

那么对于得到的唯一解X,在Hash表中寻找,如果找到,则返回 i * m +
j

注意:由于i从小到大的枚举,而Hash表中存在的j必然是对于某个剩余系内的元素X
是最小的(就是指标)

所以显然此时就可以得到最小解


如果需要得到 x > 0的解,那么只需要在上面的步骤中判断 当 i * m + j > 0 的时候才返回

到目前为止,以上的算法都不存在争议,大家实现的代码均相差不大。可见当C为素数的时候,此类离散对数的问题可以变得十分容易实现。

【扩展Baby Step Giant Step】

【问题模型】

求解

A^x = B (mod C) 中 0 <= x < C 的解,C 无限制(当然大小有限制……)

【写在前面】

这个问题比较麻烦,目前网络上流传许多版本的做法,不过大部分已近被证明是完全错误的!

这里就不再累述这些做法,下面是我的做法(有问题欢迎提出)

下面先给出算法框架,稍后给出详细证明:

(0) for i = 0 -> 50 if(A^i mod C == B) return i O(50)

(1) d<- 0 D<- 1 mod C(d
= 1,D = 1 % C)

while((tmp=gcd(A,C))!=1)

{

if(B%tmp)return -1; // 无解!

++d;

C/=tmp;

B/=tmp;

D=D*A/tmp%C;

}

(2) m = Ceil ( sqrt(C) ) //Ceil是必要的 O(1)

(3) for i = 0 -> m 插入Hash表(i, A^i mod C) O( m)

(4) K=pow_mod(A,m,C)

for i = 0 -> m

解 D * X = B (mod C) 的唯一解 (如果存在解,必然唯一!)

之后Hash表中查询,若查到(假设是 j),则 return i * m + j + d

否则

D=D*K%C,继续循环

(5) 无条件返回 -1 ;//无解!

下面是证明:

推论1:

A^x = B(mod C)

等价为

A^a * A^b = B ( mod C) (a+b) == x a,b >= 0

证明:

A^x = K * C + B (模的定义)

A^a * A^b = K*C + B( a,b >=0, a + b == x)

所以有

A^a * A^b = B(mod C)

推论 2:

令 AA * A^b = B(mod C)

那么解存在的必要条件为: 可以得到至少一个可行解 A^b = X (mod C)

使上式成立

推论3

AA * A^b = B(mod C)

中解的个数为 (AA,C)(这里指的是解x的个数,而不是A^b中b的解的个数,因为对于每个x并不是都有一个b的解)

由推论3 不难想到对原始Baby Step Giant Step的改进

For I = 0 -> m

For any solution that AA * X = B (mod C)

If find X (AA = A^i ,如果有解,找出所有的解X,对于每个X,找出A^j = X (mod C) 的解j,如果存在,return)

Return I * m + j

而根据推论3,以上算法的复杂度实际在 (AA,C)很大的时候会退化到几乎O(C)

归结原因,是因为(AA,C)过大,而就是(A,C)过大

于是我们需要找到一种做法,可以将(A,C)减少,并不影响解

下面介绍一种“消因子”的做法

一开始D = 1 mod C

进行若干论的消因子,对于每次消因子

令 G = (A,C[i]) // C[i]表示经过i轮消因子以后的C的值

如果不存在 G | B[i] //B[i]表示经过i轮消因子以后的B的值

直接返回无解

否则

B[i+1] = B[i] / G

C[i+1] = C[i] / G

D = D * A / G

具体实现只需要用若干变量,细节参考代码

假设我们消了 a' 轮(假设最后得到的B,C分别为B',C')

那么有

D * A^b = B' (mod C')

于是可以得到算法

for i = 0 -> m

解 ( D* (A^m) ^i ) * X = B'(mod C')

由于 ( D* (A^m) ^i , C') = 1 (想想为什么?) ( (D,C') = 1,(A,C') = 1)

于是我们可以得到唯一解

之后的做法就是对于这个唯一解在Hash中查找

这样我们可以得到b的值,那么最小解就是a' + b !! (a’是上面的值,b = i*m + j)

现在问题大约已近解决了,可是细心看来,其实还是有BUG的,那就是

对于

A^x = B(mod C)

如果x的最小解< a',那么会出错

而考虑到每次消因子最小消 2

故a'最大值为log(C)

于是我们可以暴力枚举0->log(C)的解,若得到了一个解,直接返回

否则必然有 解x > log(C) (这就是为什么程序第一步预先判50次的原因,事实上a'比这个小很多)

PS.以上算法基于Hash 表,如果使用map等平衡树维护,那么复杂度会更大

poj 3243的代码:

#include <cstdio>
#include <iostream>
#include <cstring>
#include <queue>
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cmath>
#include <map>
using namespace std;
class node
{
    public:
        node() {
            num = 0;
            next = -1;
            key = "";
        }
        ~node() { }
        string key;
        long long num;
        int next;
};
class HashMap
{
       static const int HASHSIZE = 1500007, HASHSEED = 137, HASHCAP = 1233333;
public:
        node et[HASHCAP];//hash的节点
        int tot;  //节点数
        int eh[HASHSIZE]; //哈希表头
        HashMap(){ }
        ~HashMap() { }

        void init(){
            tot = 0;
            memset(eh,-1,sizeof(eh));
        }
        char str[25];
        void insert(long long s, long long num){
            int len = 0;
            while(s){
                str[len++] = s % 10 + '0';
                s /= 10;
            }
            str[len] = '\0';
            string tmp = str;
            int rt = (int)gethash(tmp, len);
            bool flag = false;
            for (int i = eh[rt]; i != -1; i = et[i].next) {
                if (tmp.compare(et[i].key) == 0) {
                    flag = true;
                    break;
                }
            }
            if (!flag) {
                et[tot].key = tmp;
                et[tot].num = num;
                et[tot].next = eh[rt];
                eh[rt] = tot++;
            }
        }
        long long find(long long x) {
            int len = 0;
            while(x){
                str[len++] = x % 10 + '0';
                x /= 10;
            }str[len] = '\0';
            string key = str;
            int rt = (int)gethash(key, len);
            string tmp = key;
            for (int i = eh[rt]; i != -1; i = et[i].next) {
                if (tmp.compare(et[i].key) == 0) {
                    return et[i].num;
                }
            }
            return -1;
        }
        unsigned long long gethash(string str, int len) {
            unsigned long long ret = 0;
            for (int i = len - 1; i >= 0; --i) {
                ret = (ret * HASHSEED + (unsigned long long)(str[i] - 'a'));
            }
            return ret % HASHSIZE;
        }
}hmap;

long long gcd(long long a,long long b)
{
    if(!b) return a;
    return gcd(b,a % b);
}
void exgcd(long long a,long long b,long long &d,long long &x,long long &y)
{
    if(!b) {
        d = a;x = a,y = 0;return;
    }
    exgcd(b,a%b,d,y,x);y -= x * (a / b);
}
long long powmod(long long a,long long b,long long n)
{
    long long res = 1;
    while(b)
    {
        if(b & 1)
        {
            res = res * a;
            if(res >= n)res %= n;
        }
        a *= a;
        if(a >= n) a %= n;
        b >>= 1;
    }
    return res;
}
int solveequation(long long a,long long b,long long c)
{
    if(a % c == b % c) return 0;
    long long x,y,d;
    exgcd(a,-c,d,x,y);
    if(b % d != 0) return -1;
    x *= (long long)fabs((long double)(b / d));
    if(x < 0)x = (x % c) + c;
    else x %= c;
    return hmap.find(x);
}
long long solve(long long x,long long z,long long k)
{
    if(x == 0){
        if(k % z == 0) return 1;
        return -1;
    }
    long long res = 1;k %= z;
    for(int i = 0;i <= 50; ++i) {if(res % z == k) return i;else res = res * x % z;
    }
    long long m = ceil(sqrt((long double)z));
    long long d = 0,D = 1 % z,g;
    while((g = gcd(z,x)) != 1)
    {
        d++;
        if(k % g != 0) return -1;
        k /= g;
        z /= g;
        D = D * x / g % z;
    }
    res = 1;
    for(int i = 0;i <= m; ++i) {hmap.insert(res,i); res = res * x % z;}
    long long t = powmod(x,m,z);
    for(int i = 0,j;i <= m; ++i)
        if(~(j = solveequation(D,k,z))) {
            return i * m + j + d;
        }else D = D * t % z;
    return -1;
}
//long long test(long long x,long long k,long long z)
//{
//    long long res = 1;k %= z;
//    long long d = gcd(x,z);
//    for(int i = 0;i < 50; ++i) if(res % z == k) return i;else res = res * x % z;
//    if(k % d != 0) return -1;
//    for(int i = 50;i <= z; ++i) if(res % z == k) return i;else res = res * x % z;
//    return -1;
//}
int main()
{
    long long x,z,k;
    while(~scanf("%lld%lld%lld",&x,&z,&k)){
        if(!(x | z | k)) continue;
        hmap.init();
        long long ans = solve(x,z,k);
        if(ans == -1) puts("No Solution");
        else printf("%lld\n",ans);
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: