您的位置:首页 > 编程语言 > Java开发

FPGrowth的java实现

2015-06-23 13:40 393 查看
1、公共类

package com.apriori.common;

import java.util.ArrayList;

import java.util.Collections;

import java.util.HashMap;

import java.util.List;

import java.util.Map;

import java.util.Set;

import java.util.Map.Entry;

/**

* <p>本类描述: 公共类</p>

* <p>其他说明: </p>

* @author Wang Haiyang

* @date 2015-6-23 下午01:42:01

*/

public class Aprioris {

/**

* 方法描述:得到频繁1项集

* @param D:事务数据库

* @param min_sup:最小支持度阀值

* @return

*/

public static List<ArrayList<Integer>> getFrequent1Itemsets(List<ArrayList<Integer>> D, Integer min_sup, Map<ArrayList<Integer>, Integer> L) {

List<ArrayList<Integer>> results = new ArrayList<ArrayList<Integer>>();

Map<Integer, Integer> map = new HashMap<Integer, Integer>();

for (ArrayList<Integer> d : D) {

for (Integer g : d) {

if (map.containsKey(g)) {

map.put(g, map.get(g) + 1);

} else {

map.put(g, 1);

}

}

}

Set<Entry<Integer, Integer>> entrySet = map.entrySet();

for (Entry<Integer, Integer> entry : entrySet) {

if (entry.getValue() >= min_sup) {

ArrayList<Integer> l = new ArrayList<Integer>();

l.add(entry.getKey());

results.add(l);

L.put(l, entry.getValue());

}

}

return results;

}

public static void displayAssociationRules(Map<String, Double> rules) {

for (Entry<String, Double> entry : rules.entrySet()) {

System.out.println(entry.getKey() + ":" + entry.getValue());

}

}

/**

* 方法描述:遍历频繁项集

* @param L

*/

public static void displayFrequentItemsets(Map<ArrayList<Integer>, Integer> L) {

for (Entry<ArrayList<Integer>, Integer> entry : L.entrySet()) {

System.out.print("(");

for (Integer integer : entry.getKey()) {

System.out.print(integer);

System.out.print(",");

}

System.out.print(")");

System.out.println();

}

}

/**

* 方法描述:产生关联规则

* @param L

* @param min_con

* @return

*/

public static Map<String, Double> produceAssociationRules(Map<ArrayList<Integer>, Integer> L, Double min_con) {

Map<String, Double> result = new HashMap<String, Double>();

for (Entry<ArrayList<Integer>, Integer> entry : L.entrySet()) {

ArrayList<Integer> v = entry.getKey();

if (v.size() > 1) {

List<ArrayList<Integer>> lists = subList(v); // 得到给定list的所有非空真子集

for (ArrayList<Integer> list : lists) {

List<Integer> exp = exceptList(v, list); // 得到除了list之外的子集

Integer integer1 = entry.getValue();

Integer integer2 = L.get(list);

if (integer1 != null && integer2 != null) {

Double per = Double.parseDouble(integer1 + "") / integer2;

if (per >= min_con) {

result.put(list.toString() + "=>" + exp.toString(), per);

}

}

}

}

}

return result;

}

/**

* 方法描述:得到除了list之外的子集

* @param key

* @param list

* @return

*/

private static List<Integer> exceptList(ArrayList<Integer> key, ArrayList<Integer> list) {

List<Integer> results = new ArrayList<Integer>();

for (Integer l : key) {

if (!list.contains(l)) {

results.add(l);

}

}

return results;

}

/**

* 方法描述:得到给定list的所有非空真子集

* @param key

* @return

*/

private static List<ArrayList<Integer>> subList(ArrayList<Integer> key) {

List<ArrayList<Integer>> results = new ArrayList<ArrayList<Integer>>();

for (int i = 0; i < key.size(); i++) {

ArrayList<Integer> l = new ArrayList<Integer>();

l.add(key.get(i));

results.add(l);

}

for (int i = 0; i < key.size(); i++) {

int keyi = key.get(i);

for (int j = i + 1; j < key.size(); j++) {

int keyj = key.get(j);

ArrayList<Integer> l = new ArrayList<Integer>();

l.add(keyi);

l.add(keyj);

Collections.sort(l);

if (!l.containsAll(key)) {

if (!results.containsAll(l)) {

results.add(l);

}

}

}

}

return results;

}

}

2、ConditionalPatternBase.java

package com.apriori.fpgrowth;

import java.util.ArrayList;

import java.util.List;

/**

* <p>本类描述: t条件模式基</p>

* <p>其他说明: </p>

* @author Wang Haiyang

* @date 2015-6-19 下午05:01:40

*/

public class ConditionalPatternBase {

/**每个条件模式基*/

private List<Integer> base = new ArrayList<Integer>();

/**每个条件模式基的值*/

private Integer value;

public List<Integer> getBase() {

return base;

}

public void setBase(List<Integer> base) {

this.base = base;

}

public Integer getValue() {

return value;

}

public void setValue(Integer value) {

this.value = value;

}

}

3、TreeNode

package com.apriori.fpgrowth;

import java.util.ArrayList;

import java.util.List;

public class TreeNode implements Comparable<TreeNode>{

/**节点名字*/

private Integer name;

/**节点的出现次数*/

private Integer value = 0;

/**节点的孩子*/

private List<TreeNode> child = new ArrayList<TreeNode>();

/**节点的父亲*/

private TreeNode parent;

@Override

public int compareTo(TreeNode o) {

return o.getValue() - this.value;

}

public List<TreeNode> getChild() {

return child;

}

public void setChild(List<TreeNode> child) {

this.child = child;

}

public TreeNode getParent() {

return parent;

}

public void setParent(TreeNode parent) {

this.parent = parent;

}

public Integer getName() {

return name;

}

public void setName(Integer name) {

this.name = name;

}

public Integer getValue() {

return value;

}

public void setValue(Integer value) {

this.value = value;

}

}

4、FPGrowth

package com.apriori.fpgrowth;

import java.util.ArrayList;

import java.util.Collections;

import java.util.HashMap;

import java.util.List;

import java.util.Map;

import java.util.Map.Entry;

import com.apriori.common.Aprioris;

/**

* <p>

* 本类描述:

* 本类主要完成找出频繁项集

* </p>

* <p>

* 主要步骤:

* 1. 找出频繁1项集,并按照支持度递减排列

* 2. 遍历项集D,按照频繁1项集的顺序构造频繁模式增长树

* 3. 找出条件模式基

* 4. 根据条件模式基找出频繁模式

* </p>

* @author Wang Haiyang

* @date 2015-6-19 上午10:26:53

*/

public class FPGrowth {

/**

* 方法描述: 得到频繁模式集

* @param D

* @param min_sup

* @return

*/

public static Map<ArrayList<Integer>, Integer> getFrequentItemsets(List<ArrayList<Integer>> D, Integer min_sup) {

Map<ArrayList<Integer>, Integer> L = new HashMap<ArrayList<Integer>, Integer>();

Aprioris.getFrequent1Itemsets(D, min_sup, L); // 得到频繁1项集

ArrayList<TreeNode> L1 = sortDes(L); // 降序排序频繁1项集

TreeNode root = createFPTree(D, L1); // 得到频繁模式增长树

getFrequentItemsetsByFPTree(D, min_sup, root, L, L1); // 得到频繁模式

return L;

}

/**

* 方法描述:得到频繁模式

* @param D

* @param min_sup

* @param root

* @param L

*/

private static void getFrequentItemsetsByFPTree(List<ArrayList<Integer>> D, Integer min_sup, TreeNode root,

Map<ArrayList<Integer>, Integer> L, ArrayList<TreeNode> L1) {

List<TreeNode> nodes = getLeafs(root); // 得到树的叶子节点

Map<Integer, ArrayList<ConditionalPatternBase>> map = getAllConditionalPatternBases(nodes); // 得到所有的条件模式基

// 得到所有的频繁模式

for (Entry<Integer, ArrayList<ConditionalPatternBase>> entry : map.entrySet()) { // 得到组合条件模式基

ArrayList<ConditionalPatternBase> value = entry.getValue();

TreeNode t = createFPTree(value, L1); // 得到组合条件模式基树

List<TreeNode> n = new ArrayList<TreeNode>();

getNodes(t, min_sup, n); // 得到满足min_sup的节点

for (int i = 0; i < n.size(); i++) {

ArrayList<Integer> l = new ArrayList<Integer>();

l.add(n.get(i).getName());

l.add(entry.getKey());

L.put(l, n.get(i).getValue());

}

for (int i = 0; i < n.size(); i++) {

int keyi = n.get(i).getName();

for (int j = i + 1; j < n.size(); j++) {

int keyj = n.get(j).getName();

ArrayList<Integer> l = new ArrayList<Integer>();

l.add(keyi);

l.add(keyj);

l.add(entry.getKey());

L.put(l, n.get(j).getValue());

}

}

}

}

/**

* 方法描述:得到满足min_sup的节点

* @param node

* @param min_sup

* @param results

*/

private static void getNodes(TreeNode node, Integer min_sup, List<TreeNode> results) {

List<TreeNode> childs = node.getChild();

if (childs == null || childs.size() == 0) {

return;

} else {

for (TreeNode child : childs) {

if (child.getValue() >= min_sup) {

results.add(child);

}

getNodes(child, min_sup, results);

}

}

return;

}

/**

* 方法描述:得到所有的条件模式基

* @param nodes

* @return

*/

private static Map<Integer, ArrayList<ConditionalPatternBase>> getAllConditionalPatternBases(List<TreeNode> nodes) {

Map<Integer, ArrayList<ConditionalPatternBase>> results = new HashMap<Integer, ArrayList<ConditionalPatternBase>>();

for (TreeNode leaf : nodes) {

ConditionalPatternBase base = new ConditionalPatternBase();

TreeNode parent = leaf.getParent();

base.setValue(leaf.getValue());

List<Integer> ins = new ArrayList<Integer>();

while (parent != null && parent.getName() != null) {

ins.add(parent.getName());

parent = parent.getParent();

}

Collections.reverse(ins);

base.setBase(ins);

if (results.containsKey(leaf.getName())) {

results.get(leaf.getName()).add(base);

} else {

ArrayList<ConditionalPatternBase> lists = new ArrayList<ConditionalPatternBase>();

lists.add(base);

results.put(leaf.getName(), lists);

}

}

return results;

}

/**

* 方法描述:得到指定树的所有叶子节点

* @param root

* @return

*/

private static List<TreeNode> getLeafs(TreeNode root) {

List<TreeNode> results = new ArrayList<TreeNode>();

traverseTree(root, results);

return results;

}

/**

* 方法描述:递归遍整个树

* @param node

*/

private static void traverseTree(TreeNode node, List<TreeNode> results) {

List<TreeNode> childs = node.getChild();

if (childs == null || childs.size() == 0) {

results.add(node);

} else {

for (TreeNode child : childs) {

traverseTree(child, results);

}

}

}

/**

* 方法描述: 得到频繁模式增长树

* @param D

* @param L1

* @return

*/

private static TreeNode createFPTree(List<ArrayList<Integer>> D, ArrayList<TreeNode> L1) {

TreeNode root = new TreeNode();

for (ArrayList<Integer> lists : D) {

int flag = 0;

for (TreeNode node : L1) { // 针对lists,按照L1的顺序排序

if(lists.contains(node.getName())) {

int index = lists.indexOf(node.getName());

swap(lists, index, flag);

flag++;

}

}

TreeNode node = root;

for (Integer element : lists) { // 将lists放到result(即tree中)

if(containsValue(node.getChild(), element)) {

int index = getIndexOf(node.getChild(), element);

node.getChild().get(index).setValue(node.getChild().get(index).getValue() + 1);

node.getChild().get(index).setParent(node);

node = node.getChild().get(index);

} else {

TreeNode n = new TreeNode();

n.setName(element);

n.setValue(1);

node.getChild().add(n);

n.setParent(node);

node = n;

}

}

}

return root;

}

private static TreeNode createFPTree(ArrayList<ConditionalPatternBase> value, ArrayList<TreeNode> L1) {

TreeNode root = new TreeNode();

for (ConditionalPatternBase c : value) {

ArrayList<Integer> lists = (ArrayList<Integer>)c.getBase();

int v = c.getValue();

int flag = 0;

for (TreeNode node : L1) { // 针对lists,按照L1的顺序排序

if(lists.contains(node.getName())) {

int index = lists.indexOf(node.getName());

swap(lists, index, flag);

flag++;

}

}

TreeNode node = root;

for (Integer element : lists) { // 将lists放到result(即tree中)

if(containsValue(node.getChild(), element)) {

int index = getIndexOf(node.getChild(), element);

node.getChild().get(index).setValue(node.getChild().get(index).getValue() + v);

node.getChild().get(index).setParent(node);

node = node.getChild().get(index);

} else {

TreeNode n = new TreeNode();

n.setName(element);

n.setValue(v);

node.getChild().add(n);

n.setParent(node);

node = n;

}

}

}

return root;

}

/**

* 方法描述: 交换

* @param lists

* @param index

* @param flag

*/

private static void swap(ArrayList<Integer> lists, int index, int flag) {

int temp = lists.get(index);

lists.set(index, lists.get(flag));

lists.set(flag, temp);

}

/**

* 方法描述:按照出现次数降序排序频繁1项集

* @param L

* @return

*/

private static ArrayList<TreeNode> sortDes(Map<ArrayList<Integer>, Integer> L) {

ArrayList<TreeNode> results = new ArrayList<TreeNode>();

for (Entry<ArrayList<Integer>, Integer> enttry : L.entrySet()) {

TreeNode node = new TreeNode();

node.setName(enttry.getKey().get(0));

node.setValue(enttry.getValue());

results.add(node);

}

Collections.sort(results);

return results;

}

private static int getIndexOf(List<TreeNode> child, Integer element) {

for (int i = 0; i < child.size(); i++) {

if(child.get(i).getName() == element) {

return i;

}

}

return 0;

}

private static boolean containsValue(List<TreeNode> child, Integer element) {

for (TreeNode node : child) {

if(node.getName() == element) {

return true;

}

}

return false;

}

public static void main(String[] args) {

List<ArrayList<Integer>> D = new ArrayList<ArrayList<Integer>>();

ArrayList<Integer> list1 = new ArrayList<Integer>();

list1.add(1);

list1.add(2);

list1.add(5);

ArrayList<Integer> list2 = new ArrayList<Integer>();

list2.add(2);

list2.add(4);

ArrayList<Integer> list3 = new ArrayList<Integer>();

list3.add(2);

list3.add(3);

ArrayList<Integer> list4 = new ArrayList<Integer>();

list4.add(1);

list4.add(2);

list4.add(4);

ArrayList<Integer> list5 = new ArrayList<Integer>();

list5.add(1);

list5.add(3);

ArrayList<Integer> list6 = new ArrayList<Integer>();

list6.add(2);

list6.add(3);

ArrayList<Integer> list7 = new ArrayList<Integer>();

list7.add(1);

list7.add(3);

ArrayList<Integer> list8 = new ArrayList<Integer>();

list8.add(1);

list8.add(2);

list8.add(3);

list8.add(5);

ArrayList<Integer> list9 = new ArrayList<Integer>();

list9.add(1);

list9.add(2);

list9.add(3);

D.add(list1);

D.add(list2);

D.add(list3);

D.add(list4);

D.add(list5);

D.add(list6);

D.add(list7);

D.add(list8);

D.add(list9);

Integer min_sup = 2;

Double min_con = 0.7;

Map<ArrayList<Integer>, Integer> L = getFrequentItemsets(D, min_sup);

Aprioris.displayFrequentItemsets(L); // 打印频繁项集

Map<String, Double> rules = Aprioris.produceAssociationRules(L, min_con); // 产生关联规则

Aprioris.displayAssociationRules(rules); // 打印关联规则

}

}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: