您的位置:首页 > 其它

POJ2778----AC自动机的变形+矩阵快速幂(AC自动机和矩阵快速幂必做题)

2013-05-24 20:24 309 查看
题目地址:http://poj.org/problem?id=2778

题目意思:

给你M个DNA的小序列

然后要你求出长度为N但是不含给出的M个小DNA的情况有多少种

这是一道很好的题目

对算法的要求很高,具体的思路我是在:http://blog.csdn.net/morgan_xww/article/details/7834801学来的,所以可以移步去看原创

我主要说说几个要注意的地方

首先就是AC自动机的创建

由于这个题目我们求的是一个跳转矩阵,所以和之前的匹配不是一个意思

那么在创建AC自动机这个数据结构的时候就要注意

主要的区别在于getfail()这个函数里面,个中真意还是自己体会一下比较好

然后就是快速矩阵幂了

我开始用一个比较粗糙的算法写的,但是直接RE了

后来看DISCUSS里面,反思了一下

因为我是开的long long 的101*101的矩阵

而且是用递归写的,用之前的粗糙的算法在递归的过程中我每一层都申请了矩阵的

那么出于保护现场的原因,我就相当于申请了很多的矩阵,最后,肯定就是内存HOLD不住了

所以,我修改了一下,我不用重新申请,直接从上一层引用,这样就可以大大的节约内存

特别是在层数很多的时候

下面上我的代码:

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;

const int maxnode = 10*10+5;
const int size=4;
const int mod=100000;

struct AC
{
    int ch[maxnode][size];
    int f[maxnode];
    bool val[maxnode];
    int sz;

    void init()
    {
        memset(ch[0],-1,sizeof(ch[0]));
        sz=1;
        val[0]=false;
    }

    int idx(char c)
    {
        if(c=='A')
            return 0;
        else if(c=='C')
            return 1;
        else if(c=='T')
            return 2;
        else
            return 3;
    }

    void insert(char *s)
    {
        int len = strlen(s);
        int u=0;
        for(int i=0;i<len;i++)
        {
            int c=idx(s[i]);
            if(-1 == ch[u][c])
            {
                memset(ch[sz],-1,sizeof(ch[sz]));
                val[sz]=false;
                ch[u][c]=sz++;
            }
            u=ch[u][c];
        }
        val[u]=true;
    }

    void getfail()
    {
        queue<int> q;
        for(int i=0;i<size;i++)
        {
            if(-1 != ch[0][i])
            {
                f[ch[0][i]] = 0;
                q.push(ch[0][i]);
            }
            else
                ch[0][i] = 0;
        }
        while(!q.empty())
        {
            int r=q.front();
            q.pop();

            //这里我们是要把trie变为一棵跳转的树,不是一棵匹配树
            if(val[f[r]])
                val[r]=true;
            for(int i=0;i<size;i++)
            {
                int &v=ch[r][i];
                if(-1 != v)
                {
                    q.push(v);
                    f[v]=ch[f[r]][i];
                }
                else
                    v=ch[f[r]][i];
            }
        }
    }
};

AC ac;

struct maxtrix
{
    long long m[maxnode][maxnode];
};

void maxtrixmul(maxtrix a,maxtrix b,maxtrix &ans)
{
    int n=ac.sz;
    memset(ans.m,0,sizeof(ans.m));
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<n;j++)
        {
            for(int k=0;k<n;k++)
            {
                ans.m[i][j]+=a.m[i][k]*b.m[k][j];
            }
            ans.m[i][j] = ans.m[i][j]%mod;
        }
    }
}

void maxtrixpower(int n,maxtrix a,maxtrix &ans)
{
    if(n==1)
    {
        ans=a;
        return;
    }
    if(n&1)
    {
        int p=n-1;
        maxtrixpower(p/2,a,ans);
        maxtrixmul(ans,ans,ans);
        maxtrixmul(ans,a,ans);
    }
    else
    {
        maxtrixpower(n/2,a,ans);
        maxtrixmul(ans,ans,ans);
    }
}

void build(maxtrix &ans)
{
    memset(ans.m,0,sizeof(ans.m));
    int n = ac.sz;
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<size;j++)
        {
            if(!ac.val[i] && !ac.val[ac.ch[i][j]])
            {
                ans.m[i][ac.ch[i][j]]++;
            }

        }
    }
}

int main()
{
    int mm,nm;
    while(~scanf("%d%d",&mm,&nm))
    {
        char op[20];
        ac.init();
        for(int i=0;i<mm;i++)
        {
            scanf("%s",op);
            ac.insert(op);
        }
        ac.getfail();
        maxtrix mat;
        build(mat);
        maxtrixpower(nm,mat,mat);
        long long cnt=0;
        for(int i=0;i<ac.sz;i++)
            cnt+=mat.m[0][i];
        printf("%I64d\n",cnt%mod);
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: