
Gaussian Mixture Models & Choosing the Number of Centers with AIC/BIC

2015-12-07 10:29
For a lab project I needed to split data into several normally distributed classes, but the number of classes was unknown (probably no more than three). I used the BIC criterion to choose the number of components, and it worked surprisingly well.

The initial values have a big impact on the result of a GMM fit. Spacing the initial means evenly between the data's min and max, and using the variance of the whole data set as the initial variance, solved this nicely; it may simply be due to the nature of our data.

Glad that something I learned in my first year of grad school finally came in handy.
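For reference, each component in the code below evaluates a weighted normal density, and model selection minimizes a BIC-style score that counts two parameters (mean and standard deviation) per component, which matches what getProbability and caculateBIC compute:

p_j(x) = w_j \cdot \frac{1}{\sqrt{2\pi}\,\sigma_j} \exp\!\left(-\frac{(x-\mu_j)^2}{2\sigma_j^2}\right)

\mathrm{BIC} = 2K\ln N - 2\sum_{i=1}^{N} \ln\!\Big(\sum_{j=1}^{K} p_j(x_i)\Big)

where K is the number of components and N the number of samples; the K with the lowest score wins.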

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <iostream>
#include <numeric>
#include <vector>

// Parameters for the EM loop and the post-processing merge step.
const static int MAX_ITERATOR = 1000;  // maximum number of EM iterations
const static double END_THR = 0.0001;  // relative change in means below which EM stops
const static double SIM_THR = 0.2;     // components whose means differ by less than 20% are merged
const static double PI = 3.14159265358979;

// One Gaussian component: mean, standard deviation (dalta) and mixture weight.
struct Gaussian {
    double mean, dalta;
    double weight;
    Gaussian(double m = 0, double v = 0, double w = 1.0) : mean(m), dalta(v), weight(w) {
    }
    // Weighted density of this component at x: weight * N(x; mean, dalta^2).
    double getProbability(double x) const {
        return weight * std::exp(-std::pow(x - mean, 2.0) / (2 * dalta * dalta)) / (std::sqrt(2 * PI) * dalta);
    }
private:
    friend std::ostream& operator<<(std::ostream& os, const Gaussian& x);
};

std::ostream& operator<<(std::ostream& os, const Gaussian& x) {
    os << "mean: " << x.mean << " dalta: " << x.dalta << " weight: " << x.weight;
    return os;
}

class GMM {
public:
    // Fit GMMs with 1..mxCenter components, keep the one with the lowest BIC,
    // then merge components whose means are closer than SIM_THR (relative).
    void gmm(const std::vector<double>& data, int mxCenter, std::vector<Gaussian>& re) {
        double BIC = DBL_MAX;
        std::vector<Gaussian> tmpResult;
        for (int i = 1; i <= mxCenter; ++i) {
            std::vector<Gaussian> tmp;
            double newBIC = fixCenterGmm(data, i, tmp);
            if (newBIC < BIC) {
                BIC = newBIC;
                tmpResult = tmp;
            }
        }
        // Merge near-duplicate components: fold the weight of component i into
        // the first later component whose mean lies within SIM_THR of it.
        for (size_t i = 0; i < tmpResult.size(); ++i) {
            bool ok = true;
            for (size_t j = i + 1; j < tmpResult.size(); ++j) {
                if (std::fabs(tmpResult[i].mean - tmpResult[j].mean) < tmpResult[i].mean * SIM_THR) {
                    ok = false;
                    tmpResult[j].weight += tmpResult[i].weight;
                    break;
                }
            }
            if (ok) {
                re.push_back(tmpResult[i]);
            }
        }
        return;
    }

private:
    // Accumulates (sum of x, sum of x^2) so mean and variance can be computed in one pass.
    struct comp {
        std::pair<double, double> operator()(const std::pair<double, double>& a, double x) {
            return std::make_pair(a.first + x, a.second + x * x);
        }
    };

    // Single Gaussian fitted to all of the data (used when centers <= 1).
    Gaussian getGaussian(const std::vector<double>& data) {
        std::pair<double, double> re = std::accumulate(data.begin(), data.end(), std::make_pair(0.0, 0.0), comp());
        double mean = re.first / data.size();
        return Gaussian(mean, std::sqrt(re.second / data.size() - mean * mean), 1.0);
    }

    // Standard deviation of the whole data set (used to initialize every component).
    double getDalta(const std::vector<double>& data) {
        std::pair<double, double> re = std::accumulate(data.begin(), data.end(), std::make_pair(0.0, 0.0), comp());
        double mean = re.first / data.size();
        return std::sqrt(re.second / data.size() - mean * mean);
    }

    // Run EM for a fixed number of components and return the BIC of the fit.
    double fixCenterGmm(const std::vector<double>& data, int centers, std::vector<Gaussian>& re) {
        if (centers <= 1) {
            re.push_back(getGaussian(data));
            return caculateBIC(data, re);
        }
        // Initialization: means spread evenly over [min, max], shared standard
        // deviation taken from the whole data set, equal weights.
        double mx = *std::max_element(data.begin(), data.end());
        double mn = *std::min_element(data.begin(), data.end());
        double diff = mx - mn;
        double dalta = getDalta(data);
        for (int i = 0; i < centers; ++i) {
            re.push_back(Gaussian(mn + i * diff / (centers - 1), dalta, 1.0 / centers));
        }
        // beta[i][j]: responsibility of component j for sample i.
        std::vector< std::vector<double> > beta(data.size(), std::vector<double>(centers, 0.0));
        std::vector<Gaussian> tmp(centers, Gaussian());
        int itera = 0;
        while (itera++ < MAX_ITERATOR && !ok(re, tmp)) {
            tmp = re;
            // E step: compute and normalize the responsibilities.
            for (size_t i = 0; i < data.size(); ++i) {
                for (int j = 0; j < centers; ++j) {
                    beta[i][j] = re[j].getProbability(data[i]);
                }
                double sum = std::accumulate(beta[i].begin(), beta[i].end(), 0.0);
                for (int j = 0; j < centers; ++j) {
                    beta[i][j] /= sum;
                }
            }
            // M step: re-estimate weight, mean and standard deviation of each component.
            for (int j = 0; j < centers; ++j) {
                double sumBeta = 0.0, sumweightBeta = 0.0, sumVar = 0.0;
                for (size_t i = 0; i < data.size(); ++i) {
                    sumBeta += beta[i][j];
                    sumweightBeta += data[i] * beta[i][j];
                }
                re[j].weight = sumBeta / data.size();
                re[j].mean = sumweightBeta / sumBeta;
                for (size_t i = 0; i < data.size(); ++i) {
                    sumVar += beta[i][j] * (data[i] - re[j].mean) * (data[i] - re[j].mean);
                }
                re[j].dalta = std::sqrt(sumVar / sumBeta);
            }
        }
        return caculateBIC(data, re);
    }

    // Convergence test: relative total change of the means between two EM iterations.
    bool ok(const std::vector<Gaussian>& re, const std::vector<Gaussian>& tmp) {
        double diff = 0.0;
        double sum = 0.0;
        for (size_t i = 0; i < re.size(); ++i) {
            diff += std::fabs(re[i].mean - tmp[i].mean);
            sum += re[i].mean;
        }
        return diff / sum < END_THR;
    }

    // BIC-style score: penalty of 2 parameters (mean, dalta) per component
    // minus twice the log-likelihood of the data under the mixture.
    double caculateBIC(const std::vector<double>& data, const std::vector<Gaussian>& gau) {
        double BIC = (2 * gau.size()) * std::log((double)data.size());
        for (size_t i = 0; i < data.size(); ++i) {
            double pro = 0.0;
            for (size_t j = 0; j < gau.size(); ++j) {
                pro += gau[j].getProbability(data[i]);
            }
            BIC -= 2 * std::log(pro);
        }
        return BIC;
    }
};
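A minimal usage sketch, not from the original post: the two-cluster sample data, the seed, and the choice of mxCenter = 3 below are all made up for illustration.

#include <random>

int main() {
    // Hypothetical data: two normal clusters, roughly around 10 and 30.
    std::mt19937 rng(42);
    std::normal_distribution<double> a(10.0, 1.5), b(30.0, 2.0);
    std::vector<double> data;
    for (int i = 0; i < 500; ++i) data.push_back(a(rng));
    for (int i = 0; i < 500; ++i) data.push_back(b(rng));

    GMM model;
    std::vector<Gaussian> components;
    model.gmm(data, 3, components);   // try 1..3 components, keep the best by BIC

    for (size_t i = 0; i < components.size(); ++i) {
        std::cout << components[i] << std::endl;
    }
    return 0;
}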
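The title mentions AIC as well, but the code above only implements BIC. A sketch of an analogous AIC score, assuming the same convention of counting 2 parameters (mean and dalta) per component, could be added as another private member of GMM; the name caculateAIC is hypothetical and mirrors the existing caculateBIC.

    // Hypothetical AIC counterpart to caculateBIC: the ln(N) penalty of BIC is
    // replaced by a constant factor of 2 per parameter (assumed convention).
    double caculateAIC(const std::vector<double>& data, const std::vector<Gaussian>& gau) {
        double AIC = 2.0 * (2 * gau.size());
        for (size_t i = 0; i < data.size(); ++i) {
            double pro = 0.0;
            for (size_t j = 0; j < gau.size(); ++j) {
                pro += gau[j].getProbability(data[i]);
            }
            AIC -= 2 * std::log(pro);
        }
        return AIC;
    }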