您的位置:首页 > 编程语言 > Java开发

AdaBoost 的 Java 实现

2014-02-22 11:30 288 查看
实体类

/**
 * A single training sample: a feature vector together with its class label.
 *
 * <p>Both fields are public and mutable. The feature array passed to the
 * constructor is stored by reference, not copied.
 *
 * @author Administrator
 */
public class Instance {

    /** Feature values, one entry per dimension. */
    public double[] dim;

    /** Class label of this sample. */
    public int label;

    /**
     * Creates a sample.
     *
     * @param dim   feature values; stored as-is (not copied)
     * @param label class label of the sample
     */
    public Instance(double[] dim, int label) {
        this.label = label;
        this.dim = dim;
    }
}


分类器抽象父类

/**
 * Abstract base type for a weak classifier.
 *
 * <p>Besides the {@link #classify} contract, it exposes two public fields in
 * which an implementation can record how it performed on its training data.
 */
public abstract class Classifier {

    /** Training error rate recorded by the implementation (may be a weighted error). */
    public double errorRate;

    /** Number of misclassified training samples recorded by the implementation. */
    public int errorNumber;

    /**
     * Predicts the class label of one sample.
     *
     * @param instance the sample to classify
     * @return the predicted label
     */
    public abstract int classify(Instance instance);
}


自己实现的简单分类器（单维阈值的决策树桩），功能很弱，仅作演示，可换用其他分类器。

/**
* 简单分类器
* @author wangyongkang
*
*/
public class SimpleClassifier extends Classifier{

double threshold ;	//分类的阈值
int dimNum;			//对哪个维度分类
int fuhao = 1;		//对阈值两边的处理

public int classify(Instance instance) {

if(instance.dim[dimNum] >= threshold) {
return fuhao;
}else {
return -fuhao;
}
}

/**
* 训练出threshold和fuhao
* @param instances
* @param W 样例的权重
* @param dim 对样例的哪个维度进行训练
*/
public void train(Instance[] instances, double[] W, int dimNum) {

errorRate = Double.MAX_VALUE;
this.dimNum = dimNum;
double adaThreshold = 0;
int adaFuhao = 0;
for(Instance instance : instances) {
threshold = instance.dim[dimNum];
for(int fuhaoIt = 0; fuhaoIt < 2; fuhaoIt ++) {
fuhao = -fuhao;
double error = 0;
int errorNum = 0;
for(int i = 0; i< instances.length; i++) {
if(classify(instances[i]) != instances[i].label) {
error += W[i];
errorNum++;
}
}
if(errorRate > error){
errorRate = error;
errorNumber = errorNum;
adaThreshold = threshold;
adaFuhao = fuhao;
}
}
}
threshold = adaThreshold;
fuhao = adaFuhao;
}
}


AdaBoost 类

/**
 * AdaBoost ensemble built from {@link SimpleClassifier} decision stumps.
 *
 * <p>Labels are assumed to be -1 or +1. The sample-weight update multiplies the
 * true label by the weak prediction, and {@link #predict} only ever returns -1 or +1.
 *
 * @author Administrator
 */
public class Adaboost {

    /**
     * Floor applied to a weak classifier's weighted error when computing its vote
     * weight. With this floor, a perfect classifier (error 0) gets a large but finite
     * alpha instead of +Infinity.
     */
    private static final double MIN_ERROR_RATE = 1e-10;

    Instance[] instances;
    List<Classifier> classifierList = null; // trained weak classifiers, in training order
    List<Double> alphaList = null;          // vote weight of each weak classifier, parallel to classifierList

    /**
     * @param instances training samples; stored by reference, not copied
     */
    public Adaboost(Instance[] instances) {
        this.instances = instances;
    }

    /**
     * Runs up to {@code T} boosting rounds, replacing any previously trained ensemble.
     *
     * <p>After each round, the number of training samples the ensemble so far
     * misclassifies is printed to stdout.
     *
     * <p>Training stops early once a weak classifier has zero weighted error. At that
     * point every weighted sample is classified correctly, so all weights are scaled
     * by the same factor. Normalisation then restores them unchanged, and every
     * further round would re-select the same stump.
     *
     * @param T maximum number of boosting rounds
     * @throws IllegalStateException if there are no training samples
     */
    public void adaboost(int T) {
        if (instances == null || instances.length == 0) {
            throw new IllegalStateException("no training instances");
        }
        int len = instances.length;
        double[] W = new double[len]; // sample weights, initially uniform
        for (int i = 0; i < len; i++) {
            W[i] = 1.0 / len;
        }
        classifierList = new ArrayList<Classifier>();
        alphaList = new ArrayList<Double>();
        for (int t = 0; t < T; t++) {
            Classifier cf = getMinErrorRateClassifier(W);
            classifierList.add(cf);
            double errorRate = cf.errorRate;
            // alpha = 0.5 * ln((1 - e) / e). Without the floor, e == 0 gives alpha = +Infinity:
            // every weight below becomes exp(-Infinity) = 0, z = 0, and W turns into 0/0 = NaN.
            double alpha = 0.5 * Math.log((1 - errorRate) / Math.max(errorRate, MIN_ERROR_RATE));
            alphaList.add(alpha);
            // Re-weight the samples. Correct ones (label * prediction = +1) shrink and
            // misclassified ones (-1) grow; then normalise so the weights sum to 1.
            double z = 0;
            for (int i = 0; i < len; i++) {
                W[i] *= Math.exp(-alpha * instances[i].label * cf.classify(instances[i]));
                z += W[i];
            }
            for (int i = 0; i < len; i++) {
                W[i] /= z;
            }
            System.out.println(getErrorCount());
            if (errorRate <= 0) {
                break; // perfect weak classifier: later rounds would repeat it
            }
        }
    }

    /** Counts training samples whose ensemble prediction differs from their label. */
    private int getErrorCount() {
        int count = 0;
        for (Instance instance : instances) {
            if (predict(instance) != instance.label) {
                count++;
            }
        }
        return count;
    }

    /**
     * Predicts a label with the alpha-weighted vote of all trained weak classifiers.
     *
     * @param instance the sample to classify
     * @return 1 if the weighted vote is strictly positive, otherwise -1 (ties go to -1)
     * @throws IllegalStateException if {@link #adaboost(int)} has not been called yet
     */
    public int predict(Instance instance) {
        if (classifierList == null || alphaList == null) {
            throw new IllegalStateException("call adaboost(T) before predict");
        }
        double p = 0;
        for (int i = 0; i < classifierList.size(); i++) {
            p += classifierList.get(i).classify(instance) * alphaList.get(i);
        }
        return p > 0 ? 1 : -1;
    }

    /**
     * Trains one stump per feature dimension under weights {@code W} and returns the
     * one with the lowest weighted error. On ties, the lowest dimension index wins.
     *
     * @param W current sample weights, parallel to {@code instances}
     * @return the best stump; never null
     * @throws IllegalStateException if no stump could be selected (e.g. the samples have
     *                               no feature dimensions, or the weights are not finite)
     */
    private Classifier getMinErrorRateClassifier(double[] W) {
        double errorRate = Double.MAX_VALUE;
        SimpleClassifier minErrorRateClassifier = null;
        int dimLength = instances[0].dim.length;
        for (int i = 0; i < dimLength; i++) {
            SimpleClassifier sc = new SimpleClassifier();
            sc.train(instances, W, i);
            if (errorRate > sc.errorRate) {
                errorRate = sc.errorRate;
                minErrorRateClassifier = sc;
            }
        }
        if (minErrorRateClassifier == null) {
            throw new IllegalStateException(
                    "no weak classifier could be trained (no feature dimensions or non-finite weights)");
        }
        return minErrorRateClassifier;
    }

}


测试类

/**
 * Demo driver: trains {@link Adaboost} for 10 rounds on a small hard-coded
 * two-feature data set of ten labelled samples.
 */
public class AdaboostTest {

    public static void main(String[] args) {
        // One row per sample: its feature vector {x0, x1}.
        double[][] features = {
            {0, 3}, {1, 3}, {2, 3}, {3, 1}, {4, 1},
            {5, 1}, {6, 3}, {7, 3}, {8, 0}, {9, 1},
        };
        // Label of the sample in the same row of `features`.
        int[] labels = {1, 1, 1, -1, -1, -1, 1, 1, 1, -1};

        Instance[] instances = new Instance[features.length];
        for (int k = 0; k < instances.length; k++) {
            instances[k] = new Instance(features[k], labels[k]);
        }

        Adaboost booster = new Adaboost(instances);
        booster.adaboost(10);
    }
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: