您的位置:首页 > 其它

HDU 2243 考研路茫茫——单词情结(AC自动机+矩阵幂)

2014-04-12 12:11 344 查看
HDU 2243 考研路茫茫——单词情结(AC自动机+矩阵幂)
http://acm.hdu.edu.cn/showproblem.php?pid=2243
题意:

给你多个模板串,现在要问你长度不超过L的文本串中,至少包含一个模板串的的文本串有多少个?

分析:

其实本题就是POJ 2778的变形:

/article/1517634.html

我们只需要求出所有串中不包含模板串的长度不超过L的文本串有x个,然后用总数26^1+26^2+…26^L-x即可.本题要用unsigned long long. 注意:26^1+26^2+…26^L是个等比数列,但是因为要求余,所以用公式肯定不行.这里我们依然用矩阵幂来算:





上面的sum=1+26+26^2+26^3…+26^n,验证一下上面的结果看看是不是这样.然后我们只需要求那个2*2矩阵的n次幂即可.不过注意最后求出的sum还是需要-1的.因为sum包括了1.

POJ 2778中我们求的是长度为L的不包含模式串的文本个数,现在这里我们要求长度为长度不超过L的不包含模式串的文本串个数。令f[i]
==x表示
当前在i点,已经走了n步且没有走过单词节点的总方法数。我们依然求出(m为自动机节点数目):

f[i]
=a0*f[0][n-1]+a1*f[1][n-1]+a2*f[2][n-1]+…+am-1*f[m-1][n-1]

的递推矩阵。

不过这里我们假想在AC自动机中加一个节点,该节点表示符号’\0’即结束符.然后任何其他有效字符包括’\0’自己都可以到达’\0’,但是’\0’的后继只能是自己.也就是说无论走到了第几步,当前节点都可以选择去走’\0’,只要它选择了走’\0’,那么以后只能继续走’\0’了。

假设’\0’的节点序号为sz,那么其转移方程为:

f[sz]
=f[0][n-1]+f[1][n-1]+….+f[sz][n-1]

初值f[i][0]=1 其中i为非单词节点(包括表示'\0'的sz节点),f[j][0]=0 其中j为单词节点。

最终长度<=L且不包含任何模式串的文本个数为:f[0][L]+f[1][L]+f[2][L]+...f[sz][L]。其中f[sz][L]表示所有那些真实长度<L,所以只能以'\0'结尾的合法字符串。

即把以前的矩阵从sz-1*sz-1规模变成了sz*sz的规模,添加了最后一行和最后一列,其中最后一列除了末尾为1其他都为0,最后一行都是1(仔细想想是不是)。

注意我们最后算出来的结果还有一个是非法的,因为最后mat^L取第一列,最后节点包含了空串的情况,所以我们还要-1才是不包含模式串且长度不超过L的串数。

AC代码:31ms

#include<queue>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<iostream>
using namespace std;
const int maxnode=35;
const int sigma_size=26;
struct AC_Automata
{
    int ch[maxnode][sigma_size];
    int match[maxnode];
    int f[maxnode];
    int sz;
    void init()
    {
        sz=1;
        memset(ch[0],0,sizeof(ch[0]));
        f[0]=match[0]=0;
    }
    void insert(char *s)
    {
        int n=strlen(s),u=0;
        for(int i=0;i<n;i++)
        {
            int id=s[i]-'a';
            if(ch[u][id]==0)
            {
                ch[u][id]=sz;
                memset(ch[sz],0,sizeof(ch[sz]));
                match[sz++]=0;
            }
            u=ch[u][id];
        }
        match[u]=1;
    }
    void getFail()
    {
        f[0]=0;
        queue<int> q;
        for(int i=0;i<sigma_size;i++)
        {
            int u=ch[0][i];
            if(u)
            {
                f[u]=0;
                q.push(u);
            }
        }
        while(!q.empty())
        {
            int r=q.front();q.pop();
            for(int i=0;i<sigma_size;i++)
            {
                int u=ch[r][i];
                if(!u){ch[r][i]=ch[f[r]][i]; continue;}
                q.push(u);//之前漏了这句话,找BUG找了半天.我擦
                int v=f[r];
                while(v && ch[v][i]==0) v=f[v];
                f[u]=ch[v][i];
                match[u] |= match[f[u]];
            }
        }
    }
};
AC_Automata ac;
unsigned long long z[maxnode][maxnode];
unsigned long long mat[maxnode][maxnode],mat2[maxnode][maxnode];
unsigned long long ans[maxnode][maxnode],ans2[maxnode][maxnode];
void mutiply(unsigned long long x[maxnode][maxnode], unsigned long long y[maxnode][maxnode],int sz)
{
    for(int i=0;i<=sz;i++)
    {
        for(int j=0;j<=sz;j++)
        {
            z[i][j]=0;
            for(int k=0;k<=sz;k++)
            {
                z[i][j] += x[i][k]*y[k][j];
            }
        }
    }
    for(int i=0;i<=sz;i++)
        for(int j=0;j<=sz;j++)
            y[i][j]=z[i][j];
}
int main()
{
    int n,L;
    while(scanf("%d%d",&n,&L)==2)
    {
        ac.init();
        memset(mat,0,sizeof(mat));
        memset(ans,0,sizeof(ans));
        memset(mat2,0,sizeof(mat2));
        memset(ans2,0,sizeof(ans2));
        for(int i=0;i<n;i++)
        {
            char str[20];
            scanf("%s",str);
            ac.insert(str);
        }
        ac.getFail();
        for(int i=0;i<ac.sz;i++)
            if(ac.match[i]==0)
                for(int j=0;j<sigma_size;j++)
                    if(ac.match[ac.ch[i][j]]==0)
                        mat[ac.ch[i][j]][i]++;

        for(int i=0;i<=ac.sz;i++)
        {
            mat[ac.sz][i]=1;
            ans[i][i]=1;
        }
        int m=L;
        while(m)
        {
            if(m&1) mutiply(mat,ans,ac.sz);
            mutiply(mat,mat,ac.sz);
            m>>=1;
        }
        unsigned long long no_pattern_sum  =0;
        for(int i=0;i<=ac.sz;i++) no_pattern_sum += ans[i][0];
        no_pattern_sum = no_pattern_sum -1 ;//减掉空串

        unsigned long long res=0;//用来计算26+26^2+26^3+..26^L的
        mat2[0][0]=1,mat2[0][1]=26;
        mat2[1][0]=0,mat2[1][1]=26;
        ans2[0][0]=ans2[1][1]=1;
        m=L;
        while(m)
        {
            if(m&1) mutiply(mat2,ans2,1);
            mutiply(mat2,mat2,1);
            m>>=1;
        }
        res = ans2[0][0]+ans2[0][1]-1;
        printf("%I64u\n",res-no_pattern_sum);
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: