您的位置:首页 > 其它

AC自动机

2020-04-05 17:12 453 查看

#AC自动机(Aho-Corasick自动机)

部分资料来自

https://www.geek-share.com/detail/2711281586.html

http://www.cppblog.com/menjitianya/archive/2014/07/10/207604.html

https://www.luogu.com.cn/problemnew/solution/P5357

##问题模式

给定多个模式串和一个目标串

问有关模式串匹配问题

如果没有AC自动机,你可能需要对n个模板串分别求一趟KMP,但是复杂度过高,而AC自动机可以一次匹配,效率更优秀

俗话称 AC自动机=KMP+Trie

##模型的应用

很多网站,游戏都有敏感词过滤功能,其底层实现也无非就是ac自动机

##解决步骤

##1.将各模式串构建起Trie树

此部分可看Trie相关芝士~

##2.构建失配指针(核心部分)

他的遍历方式是利用BFS

至于为什么使用BFS,下文有提及,他与构建失配指针以及更新访问域的先后有关

注意,失配指针很多文章没有说清楚真正的含义,它实际上有两部分

###Part1.强制失配指针(伪失配处理,failptr)

先看上图,理解一下它的含义,它与KMP里面的next数组作用相似:

对于节点x而言,它的失配指针指向的节点标号为u,则有

自上而下形成的字符串中,x对应字符串的最长后缀等于u对应字符串的最长前缀

(如图红线,最长后缀是she中的he,也是her的最长前缀he)

它的作用在于在匹配一段相对较长的模式串时,可能其后缀蕴含了一段(或多段)其他的模式串,当前者匹配成功,后者同样也能被匹配出来

如offset,set,你现在正在遍历offset,它后缀中含有set,你正在匹配offset的时候不可能说我停下来去check一下它后缀情况如何。取而代之的是,我们使用failptr去假装它失配了,人为强制地让这个指针去跳转检索一下,看看当前后缀当中是否能匹配到其他模式串

构建failptr对应代码:

trie[trie[curpos].vis[i]].failptr=trie[trie[curpos].failptr].vis[i];

当前踩在curpos这个父节点,父节点curpos失配指针指向节点q,我们现在访问curpos的子节点vis,如果它是存在的,那么vis节点的失配指针指向是q节点对应字符的vis节点

这一段代码实现了三个功能,对应q节点的三种不同情况:

1.若非空,那么我们就成功使得原来最长后缀++

2.若此节点为空,且指向0,就直接指向了root

3.若此节点为空,且访问域被更新过,那么就是把访问域的地址给了过去(后面讲访问域)

这里啰嗦一句,这里体现了为什么我们节点之间是通过标号来构建关系(每新构建一个新的节点,标号加1)

其优势在于,我们所有指向空的表示为0,Trie树的虚根也记为0,那么在构建失配指针的时候会让代码显得很简洁,不会像指针型这一个null,那一个null

###Part2.访问域的更新(后继状态的更新,遍历跳转指针的更新,vis[26])

我们再来理解一下访问域的更新

它其实表示的是节点后继状态的更新。Trie树的每个节点均有26个访问域,对应26个字母。

它既可以储存指向代表含真正字符的子节点地址,又可能是空

倘若它是空的,那就白白浪费了一块空间。类似线索二叉树,AC自动机将这块空的区域利用了起来,加快跳转效率

如上图,我们构建了含有模式串abc,bc,s的trie树,菱形s是c对应的vis访问域,我们现在匹配abcs

当我们匹配abcs时,abc匹完了,现在多个s,接下来该跳转到哪里嘞?

显然c访问域为空,就会匹配指针无所适从,不知道下一个s何去何从

解决问题的代码,当vis指向空:

trie[curpos].vis[i]=trie[trie[curpos].failptr].vis[i];

我们建树的时候将这个空的区域指向真实存在的元素,并且将这个对应的地址传递下去。

这时BFS的优势就体现了出来,我给bc串更新访问域vis['s']时,让他指向了s,我们再给abc更新的时候,我们赋给它bc串的vis['s'],也就是s的地址,换句话说,我们就成功的把s的地址传递下移了

这样一来,当我们的abc完成了匹配,要跳转匹配s的时候,一步到位

真正模拟上述过程很麻烦,他们互相融合交错的,分不了孰先孰后,多画图模拟几次

上述对应的代码

值得强调的是,虚根的直接子节点一定要额外优先处理(建立failptr,入队),不然会有瑕疵

void built_fail()
{
queue<int>curline;
for(int i=0;i<26;i++)
{
if(trie[0].vis[i])trie[trie[0].vis[i]].failptr=0,curline.push(trie[0].vis[i]);
}
while(!curline.empty())
{
int curpos=curline.front();
curline.pop();
for(int i=0;i<26;i++)
{
if(trie[curpos].vis[i])
{
trie[trie[curpos].vis[i]].failptr=trie[trie	[curpos].failptr].vis[i];
curline.push(trie[curpos].vis[i]);
}
else
trie[curpos].vis[i]=trie[trie[curpos].failptr].vis[i];}
}

}

###3.遍历目标串

完成上一步失配指针的配置后,匹配就变得简单了

每次沿着Trie树匹配,匹配到当前位置失配时,直接跳转到访问域所指向的位置继续进行匹配,而每次匹配过程中还要进行一次伪失配的处理,在这个过程中进行统计。

核心代码

int checkAC(string s)
{
int ans=0,curpos=0;
for(int i=0;i<s.size();i++)
{
curpos=trie[curpos].vis[s[i]-'a'];
for(int j=curpos;j&&trie[j].sum!=-1;j=trie[j].failptr)
{
ans+=trie[j].sum;
trie[j].sum=-1;
}
}
return ans;
}

##优化

我们上面AC自动机在进行多模匹配时是暴力跳转failptr,但这样做复杂度还是有问题

在类似于aaaaa……aaaaa这样的串中,复杂度会退化成O(模式串长度·目标串长度)为什么?因为对于每一次跳转failptr我们都只使深度减1,那样深度(深度最深是模式串长度)是多少,每一次跳的时间复杂度就是多少。那么还要乘上文本串长度,就几乎是O(模式串长度·文本串长度)的了

再举个栗子

我们匹配ABC的时候(1234),强制失配跳转failptr的时候要先经过BC(57),可是57上并没有结束点,我们要的是9上的c,所以跳转效率大打折扣

###优化思路一,我自己想的(拉垮的很):

这里运用路径压缩的思想。我们强制跳转failptr的过程采用递归的方式,通过记录含有结束单词位置,不断回溯更新failptr的地址

代码如下:

void dfs(int curpos)
{
if(!curpos){existpos=0;return;}
dfs(trie[curpos].failptr);
trie[curpos].failptr=existpos;
if(trie[curpos].stringpos)total[trie[curpos].stringpos]++,existpos=curpos;
}

在后面模板题洛谷P5356中,经过我反复各类卡常优化,我唯一T的点还是T了,1.15s->1.09s

(假算法++)

###优化思路二,大佬题解:

它的核心是,观察到了failptr与节点之间构成了DAG

它优先进行目标串的匹配,放弃每步暴力跳转failptr,优先给各个节点统计上遍历过的次数sum

接下来再进行在DAG上进行拓扑排序,累加统计即可

数组in记录入度

in[trie[trie[curpos].vis[i]].failptr]++;//当vis存在,并且被更新了failptr

AC自动机匹配时,先遍历目标串,统计各点的sum:

int ans=0,curpos=0;
for(register int i=0;i<s.size();i++)
{
curpos=trie[curpos].vis[s[i]-'a'];
trie[curpos].sum++;
}

DAG上拓扑累加:

queue<int>curline;
for(int i=1;i<=cnt;i++)if(!in[i])curline.push(i);//初始化拓扑队列
while(!curline.empty())
{
int x=curline.front();curline.pop();
total[trie[x].stringpos]=trie[x].sum;//记录对应模式串出现的次数
int y=trie[x].failptr;in[y]--;
trie[y].sum+=trie[x].sum;//DAG上状态的转移
if(in[y]==0)curline.push(y);//入度为0,入队
}

这样一来复杂度就成了O(max(模式串长度,文本串长度))

详细代码见下面P5357

优化效果果然大幅度提升

##三套模版(第三套效率比较好)

1.洛谷模版(弱化版)P3808

查询有几个模式串出现在目标串中

#include<iostream>
#include<cstring>
#include<cstdio>
#include<string>
#include<queue>
using namespace std;
#define INF 1e10+5
#define maxn 1000005
#define minn -105
#define ll long long int
#define ull unsigned long long int
#define uint unsigned int
struct trienode
{
int failptr;
int vis[26];
int sum;
trienode(){memset(vis,0,sizeof(vis));failptr=0;sum=0;}
};
trienode trie[maxn];
int cnt=0;
void built(string s)//build trie
{
int curpos=0;
for(int i=0;i<s.size();i++)
{
if(!trie[curpos].vis[s[i]-'a'])trie[curpos].vis[s[i]-'a']=++cnt;
curpos=trie[curpos].vis[s[i]-'a'];
}
trie[curpos].sum++;
}
void built_fail()//build fail_ptr
{
queue<int>curline;
for(int i=0;i<26;i++)
{
if(trie[0].vis[i])trie[trie[0].vis[i]].failptr=0,curline.push(trie[0].vis[i]);
}
while(!curline.empty())
{
int curpos=curline.front();
curline.pop();
for(int i=0;i<26;i++)
{
if(trie[curpos].vis[i])
{
trie[trie[curpos].vis[i]].failptr=trie[trie	[curpos].failptr].vis[i];
curline.push(trie[curpos].vis[i]);
}
else
{
trie[curpos].vis[i]=trie[trie[curpos].failptr].vis[i];}
}
}

}
int checkAC(string s)
{
int ans=0,curpos=0;
for(int i=0;i<s.size();i++)
{
curpos=trie[curpos].vis[s[i]-'a'];
for(int j=curpos;j&&trie[j].sum!=-1;j=trie[j].failptr)
{
ans+=trie[j].sum;
trie[j].sum=-1;
}
}
return ans;
}int main()
{
int _t;
string s;
cin>>_t;
while(_t--)
{
cin>>s;
built(s);
}
trie[0].failptr=0;
built_fail();
cin>>s;
cout<<checkAC(s)<<endl;
return 0;
}

2.洛谷板子(强化版)P3796

查询最多出现的模式串,并输出(可能有多个)

#include<iostream>
#include<cstring>
#include<cstdio>
#include<string>
#include<queue>
#include<map>
using namespace std;
#define INF 1e10+5
#define maxn 1000005
#define minn -105
#define ll long long int
#define ull unsigned long long int
#define uint unsigned int
struct trienode
{
int failptr;
int vis[26];
int sum;
int stringpos;
void strienode(){memset(vis,0,sizeof(vis));failptr=0;sum=0;stringpos=0;}
};
trienode trie[maxn];
string stringsave[maxn];
map<string,int>Map;
int maxans;
int total[maxn];
int cnt=0;

void built(string s)
{
int curpos=0;
for(int i=0;i<s.size();i++)
{
if(!trie[curpos].vis[s[i]-'a'])trie[curpos].vis[s[i]-'a']=++cnt;
curpos=trie[curpos].vis[s[i]-'a'];
}
trie[curpos].sum++;
trie[curpos].stringpos=Map[s];
}
void built_fail()
{
queue<int>curline;
for(int i=0;i<26;i++)
{
if(trie[0].vis[i])trie[trie[0].vis[i]].failptr=0,curline.push(trie[0].vis[i]);
}
while(!curline.empty())
{
int curpos=curline.front();
curline.pop();
for(int i=0;i<26;i++)
{
if(trie[curpos].vis[i])
{
trie[trie[curpos].vis[i]].failptr=trie[trie[curpos].failptr].vis[i];curline.push(trie[curpos].vis[i]);
}
else
{
trie[curpos].vis[i]=trie[trie[curpos].failptr].vis[i];}
}
}

}
void checkAC(string s)
{
int ans=0,curpos=0;
for(int i=0;i<s.size();i++)
{
curpos=trie[curpos].vis[s[i]-'a'];
for(int j=curpos;j;j=trie[j].failptr)
{
total[trie[j].stringpos]+=trie[j].sum;
maxans=max(maxans,total[trie[j].stringpos]);
}
}
}
int main()
{
int _t;
string s;
while(1)
{
cin>>_t;
if(!_t)break;
int index=0;
maxans=0;
memset(total,0,sizeof(total));
Map.clear();
for(int i=0;i<100000;i++)
trie[i].strienode();
while(_t--)
{
cin>>s;
if(s.size()>=maxn)continue;
if(!Map[s])Map[s]=index,stringsave[index]=s,index++;
built(s);
}
trie[0].failptr=0;
built_fail();
cin>>s;
checkAC(s);
cout<<maxans<<endl;
for(int i=0;i<index;i++)
{
if(total[i]==maxans)cout<<stringsave[i]<<endl;
}
}

return 0;
}

3.洛谷模版3(二次强化版)P5357

这个有一点值得注意的是

aa与aa出现在aaa中分别认为是2次,2次,而不是4次,4次

#include<iostream>
#include<cstring>
#include<cstdio>
#include<string>
#include<queue>
using namespace std;
#define INF 1e10+5
#define maxn 1000005
#define minn -105
#define ll long long int
#define ull unsigned long long int
#define uint unsigned int
struct trienode
{
int failptr;
int vis[26];
int sum;
int stringpos;
trienode(){memset(vis,0,sizeof(vis));failptr=0;stringpos=0;sum=0;}
};
trienode trie[maxn];
queue<int>curline;
string s;
int total[maxn],Map[maxn],in[maxn];
int cnt=0;
int existpos=0;
int indexnum=0;
void built(int p)
{
int curpos=0;
int len=s.size();
for(register int i=0;i<len;i++)
{
if(!trie[curpos].vis[s[i]-'a'])trie[curpos].vis[s[i]-'a']=++cnt;
curpos=trie[curpos].vis[s[i]-'a'];
}
if(!trie[curpos].stringpos)trie[curpos].stringpos=++indexnum;
Map[p]=trie[curpos].stringpos;
}
void built_fail()
{
for(register int i=0;i<26;i++)
{
if(trie[0].vis[i])trie[trie[0].vis[i]].failptr=0,curline.push(trie[0].vis[i]);
}
while(!curline.empty())
{
int curpos=curline.front();
curline.pop();
for(int i=0;i<26;i++)
{
if(trie[curpos].vis[i])
{
int x=trie[curpos].vis[i];
trie[x].failptr=trie[trie[curpos].failptr].vis[i];
in[trie[x].failptr]++;
curline.push(trie[curpos].vis[i]);
}
else
{
trie[curpos].vis[i]=trie[trie[curpos].failptr].vis[i];}
}
}
}
void checkAC()
{
int ans=0,curpos=0;
//targetstring 的遍历
for(register int i=0;i<s.size();i++)
{
curpos=trie[curpos].vis[s[i]-'a'];
trie[curpos].sum++;
}
//topu on DAG
queue<int>curline;
for(int i=1;i<=cnt;i++)if(!in[i])curline.push(i);
while(!curline.empty())
{
int x=curline.front();curline.pop();
total[trie[x].stringpos]=trie[x].sum;
int y=trie[x].failptr;in[y]--;
trie[y].sum+=trie[x].sum;
if(in[y]==0)curline.push(y);
}
}
int main()
{
int _t;
cin>>_t;
memset(total,0,sizeof(total));
memset(in,0,sizeof(in));
for(register int i=1;i<=_t;i++)
{
cin>>s;
built(i);
}
trie[0].failptr=0;
built_fail();
cin>>s;
checkAC();
for(register int i=1;i<=_t;i++)cout<<total[Map[i]]<<'\n';
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: