您的位置:首页 > 编程语言 > C语言/C++

B00014 C++实现的AC自动机

2016-06-19 22:08 393 查看
代码来自:A C++ implementation of the Aho-Corasick pattern search algorithm

源程序如下:

/*
* Copyright (C) 2015 Christopher Gilbert.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/

#ifndef AHO_CORASICK_HPP
#define AHO_CORASICK_HPP

#include <algorithm>
#include <cctype>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <queue>
#include <vector>

namespace aho_corasick {

// class interval
//
// A closed range [start, end] of positions. Both bounds are inclusive, so an
// interval whose start equals its end covers exactly one position.
class interval {
    size_t d_start;
    size_t d_end;

public:
    // Creates the interval [start, end]; the bounds are stored as given.
    interval(size_t start, size_t end)
        : d_start(start)
        , d_end(end) {}

    // First position covered by the interval.
    size_t get_start() const { return d_start; }

    // Last position covered by the interval (inclusive).
    size_t get_end() const { return d_end; }

    // Number of positions covered, i.e. end - start + 1.
    size_t size() const { return d_end - d_start + 1; }

    // True when the two closed ranges share at least one position.
    bool overlaps_with(const interval& other) const {
        return d_start <= other.d_end && d_end >= other.d_start;
    }

    // True when `point` lies inside the closed range.
    bool overlaps_with(size_t point) const {
        return d_start <= point && point <= d_end;
    }

    // Strict ordering by start position, ties broken by end position.
    // The tie-break keeps this ordering consistent with operator==: two
    // intervals are equivalent under `<` only when they are equal, so ordered
    // containers such as std::set do not merge distinct intervals that merely
    // begin at the same position.
    bool operator <(const interval& other) const {
        if (d_start != other.d_start) {
            return d_start < other.d_start;
        }
        return d_end < other.d_end;
    }

    // True when either bound differs.
    bool operator !=(const interval& other) const {
        return d_start != other.d_start || d_end != other.d_end;
    }

    // True when both bounds match.
    bool operator ==(const interval& other) const {
        return d_start == other.d_start && d_end == other.d_end;
    }
};

// class interval_tree
//
// Centered interval tree over a fixed collection of intervals. T must provide
// get_start(), get_end(), size() and operator!=. Each node stores the
// intervals that contain its centre point; intervals lying entirely before or
// after that point are pushed into the left or right subtree.
template<typename T>
class interval_tree {
public:
    using interval_collection = std::vector<T>;

private:
    // class node
    class node {
        enum direction {
            LEFT, RIGHT
        };
        using node_ptr = std::unique_ptr<node>;

        size_t              d_point;      // centre point of this node
        node_ptr            d_left;       // intervals ending before d_point
        node_ptr            d_right;      // intervals starting after d_point
        interval_collection d_intervals;  // intervals containing d_point

    public:
        // Builds the subtree for `intervals`, recursing on whatever does not
        // contain the centre point.
        node(const interval_collection& intervals)
            : d_point(0)
            , d_left(nullptr)
            , d_right(nullptr)
            , d_intervals()
        {
            d_point = determine_median(intervals);
            interval_collection to_left, to_right;
            for (const auto& cur : intervals) {
                if (cur.get_end() < d_point) {
                    to_left.push_back(cur);
                } else if (cur.get_start() > d_point) {
                    to_right.push_back(cur);
                } else {
                    d_intervals.push_back(cur);
                }
            }
            if (!to_left.empty()) {
                d_left.reset(new node(to_left));
            }
            if (!to_right.empty()) {
                d_right.reset(new node(to_right));
            }
        }

        // Midpoint between the smallest start and the largest end, rounded
        // down. Returns 0 for an empty collection, where the value is unused.
        size_t determine_median(const interval_collection& intervals) const {
            if (intervals.empty()) {
                return 0;
            }
            size_t lo = intervals.front().get_start();
            size_t hi = intervals.front().get_end();
            for (const auto& cur : intervals) {
                lo = std::min(lo, cur.get_start());
                hi = std::max(hi, cur.get_end());
            }
            // floor((lo + hi) / 2) computed without overflowing size_t.
            return (lo & hi) + ((lo ^ hi) >> 1);
        }

        // Collects every stored interval that overlaps `i`, excluding
        // intervals that compare equal to `i` itself.
        interval_collection find_overlaps(const T& i) const {
            interval_collection overlaps;
            if (d_point < i.get_start()) {
                // Query lies right of the centre: only the right subtree and
                // local intervals reaching far enough right can overlap.
                add_to_overlaps(i, overlaps, find_overlapping_ranges(d_right, i));
                add_to_overlaps(i, overlaps, check_right_overlaps(i));
            } else if (d_point > i.get_end()) {
                // Query lies left of the centre: mirror image of the above.
                add_to_overlaps(i, overlaps, find_overlapping_ranges(d_left, i));
                add_to_overlaps(i, overlaps, check_left_overlaps(i));
            } else {
                // Query contains the centre, so every local interval overlaps.
                add_to_overlaps(i, overlaps, d_intervals);
                add_to_overlaps(i, overlaps, find_overlapping_ranges(d_left, i));
                add_to_overlaps(i, overlaps, find_overlapping_ranges(d_right, i));
            }
            return overlaps;
        }

    protected:
        // Appends every element of `new_overlaps` that differs from `i`.
        void add_to_overlaps(const T& i, interval_collection& overlaps, const interval_collection& new_overlaps) const {
            for (const auto& cur : new_overlaps) {
                if (cur != i) {
                    overlaps.push_back(cur);
                }
            }
        }

        // Local intervals that start at or before the end of `i`.
        interval_collection check_left_overlaps(const T& i) const {
            return check_overlaps(i, LEFT);
        }

        // Local intervals that end at or after the start of `i`.
        interval_collection check_right_overlaps(const T& i) const {
            return check_overlaps(i, RIGHT);
        }

        // Filters the local intervals against `i` from the given side.
        interval_collection check_overlaps(const T& i, direction d) const {
            interval_collection overlaps;
            for (const auto& cur : d_intervals) {
                const bool hit = (d == LEFT)
                    ? cur.get_start() <= i.get_end()
                    : cur.get_end() >= i.get_start();
                if (hit) {
                    overlaps.push_back(cur);
                }
            }
            return overlaps;
        }

        // Overlaps found in `child`, or nothing when the subtree is absent.
        interval_collection find_overlapping_ranges(const node_ptr& child, const T& i) const {
            if (child) {
                return child->find_overlaps(i);
            }
            return interval_collection();
        }
    };
    node d_root;

public:
    // Builds the tree over a copy of the given intervals.
    interval_tree(const interval_collection& intervals)
        : d_root(intervals) {}

    // Returns the subset of `intervals` left after discarding overlaps.
    // Candidates are visited longest first (for equal sizes, the one with the
    // larger start first); every interval in the tree that overlaps a
    // surviving candidate is discarded. The result is sorted by start.
    //
    // Discarded intervals are tracked by their (start, end) bounds rather
    // than through T's operator<, so intervals sharing a start position but
    // ending at different positions are not confused with one another.
    interval_collection remove_overlaps(const interval_collection& intervals) {
        interval_collection result(intervals.begin(), intervals.end());
        std::sort(result.begin(), result.end(), [](const T& a, const T& b) -> bool {
            if (a.size() == b.size()) {
                return a.get_start() > b.get_start();
            }
            return a.size() > b.size();
        });

        auto by_bounds = [](const T& a, const T& b) -> bool {
            if (a.get_start() != b.get_start()) {
                return a.get_start() < b.get_start();
            }
            return a.get_end() < b.get_end();
        };
        std::set<T, decltype(by_bounds)> discarded(by_bounds);

        for (const auto& candidate : result) {
            if (discarded.count(candidate) != 0) {
                continue;  // already beaten by a longer interval
            }
            for (const auto& overlap : find_overlaps(candidate)) {
                discarded.insert(overlap);
            }
        }

        // Single pass removal; also safe when a discarded interval is not
        // present in `result` at all.
        result.erase(
            std::remove_if(result.begin(), result.end(), [&discarded](const T& cur) -> bool {
                return discarded.count(cur) != 0;
            }),
            result.end()
        );

        std::sort(result.begin(), result.end(), [](const T& a, const T& b) -> bool {
            return a.get_start() < b.get_start();
        });
        return result;
    }

    // Returns every interval in the tree that overlaps `i`, excluding
    // intervals that compare equal to `i`.
    interval_collection find_overlaps(const T& i) {
        return d_root.find_overlaps(i);
    }
};

// class emit
//
// An interval tagged with the keyword that was found there. A
// default-constructed emit has both bounds set to size_t(-1) and an empty
// keyword; is_empty() recognises exactly that state.
template<typename CharType>
class emit: public interval {
public:
    using string_type     = std::basic_string<CharType>;
    using string_ref_type = std::basic_string<CharType>&;

private:
    string_type d_keyword;

public:
    // Creates the "empty" emit (both bounds are size_t(-1)).
    emit()
        : interval(static_cast<size_t>(-1), static_cast<size_t>(-1))
        , d_keyword() {}

    // Creates an emit for `keyword` covering the positions [start, end].
    emit(size_t start, size_t end, string_type keyword)
        : interval(start, end)
        , d_keyword(keyword) {}

    // Returns a copy of the stored keyword.
    string_type get_keyword() const { return d_keyword; }

    // True when both bounds still hold the size_t(-1) sentinel.
    bool is_empty() const {
        const size_t unset = static_cast<size_t>(-1);
        if (get_start() != unset) {
            return false;
        }
        return get_end() == unset;
    }
};

// class token
//
// One piece of tokenised text: either a plain fragment, or a fragment that
// corresponds to a keyword match and carries the emit describing it.
template<typename CharType>
class token {
public:
    enum token_type {
        TYPE_FRAGMENT,
        TYPE_MATCH,
    };

    typedef std::basic_string<CharType>  string_type;
    typedef std::basic_string<CharType>& string_ref_type;
    typedef emit<CharType>               emit_type;

private:
    token_type  d_type;      // fragment or match
    string_type d_fragment;  // text covered by this token
    emit_type   d_emit;      // default-constructed for plain fragments

public:
    // Creates a plain fragment token holding a copy of `fragment`.
    token(string_ref_type fragment)
        : d_type(TYPE_FRAGMENT)
        , d_fragment(fragment)
        , d_emit() {}

    // Creates a match token holding copies of `fragment` and the emit `e`.
    token(string_ref_type fragment, const emit_type& e)
        : d_type(TYPE_MATCH)
        , d_fragment(fragment)
        , d_emit(e) {}

    // True for tokens built with an emit.
    bool is_match() const {
        return TYPE_MATCH == d_type;
    }

    // Returns a copy of the token's text.
    string_type get_fragment() const {
        return d_fragment;
    }

    // Returns a copy of the emit stored in this token.
    emit_type get_emit() const {
        return d_emit;
    }
};

// class state
//
// One node of the keyword trie. A state owns its child states (one per
// outgoing character), holds a non-owning failure link, and records the
// keywords that are recognised when the state is reached.
//
// A state built with depth 0 is a root: looking up a character it has no
// transition for yields the root itself instead of nullptr.
template<typename CharType>
class state {
public:
    typedef state<CharType>*                 ptr;
    typedef std::unique_ptr<state<CharType>> unique_ptr;
    typedef std::basic_string<CharType>      string_type;
    typedef std::basic_string<CharType>&     string_ref_type;
    typedef std::set<string_type>            string_collection;
    typedef std::vector<ptr>                 state_collection;
    typedef std::vector<CharType>            transition_collection;

private:
    size_t                         d_depth;    // distance from the root
    ptr                            d_root;     // `this` for a root, nullptr otherwise
    std::map<CharType, unique_ptr> d_success;  // owned child states by character
    ptr                            d_failure;  // non-owning failure link
    string_collection              d_emits;    // keywords recognised here

public:
    // Creates a root state (depth 0).
    state(): state(0) {}

    // Creates a state at the given depth; depth 0 makes it a root.
    state(size_t depth)
        : d_depth(depth)
        , d_root(depth == 0 ? this : nullptr)
        , d_success()
        , d_failure(nullptr)
        , d_emits() {}

    // Child reached through `character`. Without such a child, a root
    // returns itself and any other state returns nullptr.
    ptr next_state(CharType character) const {
        return next_state(character, false);
    }

    // Child reached through `character`, or nullptr when there is none
    // (even on a root).
    ptr next_state_ignore_root_state(CharType character) const {
        return next_state(character, true);
    }

    // Returns the child for `character`, creating it when it is missing.
    ptr add_state(CharType character) {
        ptr next = next_state_ignore_root_state(character);
        if (next == nullptr) {
            // Hold the new node in a smart pointer before touching the map,
            // so it is released if the map insertion throws.
            unique_ptr fresh(new state<CharType>(d_depth + 1));
            next = fresh.get();
            d_success[character].swap(fresh);
        }
        return next;
    }

    size_t get_depth() const { return d_depth; }

    // Records `keyword` as recognised at this state (duplicates are ignored).
    void add_emit(const string_type& keyword) {
        d_emits.insert(keyword);
    }

    // Records every keyword in `emits` as recognised at this state.
    void add_emit(const string_collection& emits) {
        for (const auto& keyword : emits) {
            d_emits.insert(keyword);
        }
    }

    // Returns a copy of the keywords recognised at this state.
    string_collection get_emits() const { return d_emits; }

    // Current failure link; nullptr until set_failure() is called.
    ptr failure() const { return d_failure; }

    void set_failure(ptr fail_state) { d_failure = fail_state; }

    // Child states, ordered by their transition character.
    state_collection get_states() const {
        state_collection result;
        result.reserve(d_success.size());
        for (const auto& entry : d_success) {
            result.push_back(entry.second.get());
        }
        return result;
    }

    // Characters that have a child state, in ascending order.
    transition_collection get_transitions() const {
        transition_collection result;
        result.reserve(d_success.size());
        for (const auto& entry : d_success) {
            result.push_back(entry.first);
        }
        return result;
    }

private:
    // Shared lookup behind the two public next_state variants.
    ptr next_state(CharType character, bool ignore_root_state) const {
        const auto found = d_success.find(character);
        if (found != d_success.end()) {
            return found->second.get();
        }
        if (!ignore_root_state && d_root != nullptr) {
            return d_root;
        }
        return nullptr;
    }
};

template<typename CharType>
class basic_trie {
public:
using string_type = std::basic_string < CharType > ;
using string_ref_type = std::basic_string<CharType>&;

typedef state<CharType>         state_type;
typedef state<CharType>*        state_ptr_type;
typedef token<CharType>         token_type;
typedef emit<CharType>          emit_type;
typedef std::vector<token_type> token_collection;
typedef std::vector<emit_type>  emit_collection;

class config {
bool d_allow_overlaps;
bool d_only_whole_words;
bool d_case_insensitive;

public:
config()
: d_allow_overlaps(true)
, d_only_whole_words(false)
, d_case_insensitive(false) {}

bool is_allow_overlaps() const { return d_allow_overlaps; }
void set_allow_overlaps(bool val) { d_allow_overlaps = val; }

bool is_only_whole_words() const { return d_only_whole_words; }
void set_only_whole_words(bool val) { d_only_whole_words = val; }

bool is_case_insensitive() const { return d_case_insensitive; }
void set_case_insensitive(bool val) { d_case_insensitive = val; }
};

private:
std::unique_ptr<state_type> d_root;
config                      d_config;
bool                        d_constructed_failure_states;

public:
basic_trie(): basic_trie(config()) {}

basic_trie(const config& c)
: d_root(new state_type())
, d_config(c)
, d_constructed_failure_states(false) {}

basic_trie& case_insensitive() {
d_config.set_case_insensitive(true);
return (*this);
}

basic_trie& remove_overlaps() {
d_config.set_allow_overlaps(false);
return (*this);
}

basic_trie& only_whole_words() {
d_config.set_only_whole_words(true);
return (*this);
}

void insert(string_type keyword) {
if (keyword.empty())
return;
state_ptr_type cur_state = d_root.get();
for (const auto& ch : keyword) {
cur_state = cur_state->add_state(ch);
}
cur_state->add_emit(keyword);
}

template<class InputIterator>
void insert(InputIterator first, InputIterator last) {
for (InputIterator it = first; first != last; ++it) {
insert(*it);
}
}

token_collection tokenise(string_type text) {
token_collection tokens;
auto collected_emits = parse_text(text);
size_t last_pos = -1;
for (const auto& e : collected_emits) {
if (e.get_start() - last_pos > 1) {
tokens.push_back(create_fragment(e, text, last_pos));
}
tokens.push_back(create_match(e, text));
last_pos = e.get_end();
}
if (text.size() - last_pos > 1) {
tokens.push_back(create_fragment(typename token_type::emit_type(), text, last_pos));
}
return token_collection(tokens);
}

emit_collection parse_text(string_type text) {
check_construct_failure_states();
size_t pos = 0;
state_ptr_type cur_state = d_root.get();
emit_collection collected_emits;
for (auto c : text) {
if (d_config.is_case_insensitive()) {
c = std::tolower(c);
}
cur_state = get_state(cur_state, c);
store_emits(pos, cur_state, collected_emits);
pos++;
}
if (d_config.is_only_whole_words()) {
remove_partial_matches(text, collected_emits);
}
if (!d_config.is_allow_overlaps()) {
interval_tree<emit_type> tree(typename interval_tree<emit_type>::interval_collection(collected_emits.begin(), collected_emits.end()));
auto tmp = tree.remove_overlaps(collected_emits);
collected_emits.swap(tmp);
}
return emit_collection(collected_emits);
}

private:
token_type create_fragment(const typename token_type::emit_type& e, string_ref_type text, size_t last_pos) const {
auto start = last_pos + 1;
auto end = (e.is_empty()) ? text.size() : e.get_start();
auto len = end - start;
typename token_type::string_type str(text.substr(start, len));
return token_type(str);
}

token_type create_match(const typename token_type::emit_type& e, string_ref_type text) const {
auto start = e.get_start();
auto end = e.get_end() + 1;
auto len = end - start;
typename token_type::string_type str(text.substr(start, len));
return token_type(str, e);
}

void remove_partial_matches(string_ref_type search_text, emit_collection& collected_emits) const {
size_t size = search_text.size();
emit_collection remove_emits;
for (const auto& e : collected_emits) {
if ((e.get_start() == 0 || !std::isalpha(search_text.at(e.get_start() - 1))) &&
(e.get_end() + 1 == size || !std::isalpha(search_text.at(e.get_end() + 1)))
) {
continue;
}
remove_emits.push_back(e);
}
for (auto& e : remove_emits) {
collected_emits.erase(
std::find(collected_emits.begin(), collected_emits.end(), e)
);
}
}

state_ptr_type get_state(state_ptr_type cur_state, CharType c) const {
state_ptr_type result = cur_state->next_state(c);
while (result == nullptr) {
cur_state = cur_state->failure();
result = cur_state->next_state(c);
}
return result;
}

void check_construct_failure_states() {
if (!d_constructed_failure_states) {
construct_failure_states();
}
}

void construct_failure_states() {
std::queue<state_ptr_type> q;
for (auto& depth_one_state : d_root->get_states()) {
depth_one_state->set_failure(d_root.get());
q.push(depth_one_state);
}
d_constructed_failure_states = true;

while (!q.empty()) {
auto cur_state = q.front();
for (const auto& transition : cur_state->get_transitions()) {
state_ptr_type target_state = cur_state->next_state(transition);
q.push(target_state);

state_ptr_type trace_failure_state = cur_state->failure();
while (trace_failure_state->next_state(transition) == nullptr) {
trace_failure_state = trace_failure_state->failure();
}
state_ptr_type new_failure_state = trace_failure_state->next_state(transition);
target_state->set_failure(new_failure_state);
target_state->add_emit(new_failure_state->get_emits());
}
q.pop();
}
}

void store_emits(size_t pos, state_ptr_type cur_state, emit_collection& collected_emits) const {
auto emits = cur_state->get_emits();
if (!emits.empty()) {
for (const auto& str : emits) {
auto emit_str = typename emit_type::string_type(str);
collected_emits.push_back(emit_type(pos - emit_str.size() + 1, pos, emit_str));
}
}
}
};

// Ready-made tries for narrow (char) and wide (wchar_t) strings.
using trie  = basic_trie<char>;
using wtrie = basic_trie<wchar_t>;

} // namespace aho_corasick

#endif // AHO_CORASICK_HPP
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: