One-vs-Rest算法介绍及Spark MLlib调用实例(Scala/Java/Python)
2016-12-02 16:50
931 查看
One-vs-Rest
算法介绍:
OneVsRest将一个给定的二分类算法有效地扩展到多分类问题应用中,也叫做“One-vs-All.”算法。OneVsRest是一个Estimator。它采用一个基础的Classifier然后对于k个类别分别创建二分类问题。类别i的二分类分类器用来预测类别为i还是不为i,即将i类和其他类别区分开来。最后,通过依次对k个二分类分类器进行评估,取置信最高的分类器的标签作为i类别的标签。
参数:
featuresCol:
类型:字符串型。
含义:特征列名。
labelCol:
类型:字符串型。
含义:标签列名。
predictionCol:
类型:字符串型。
含义:预测结果列名。
classifier:
类型:分类器型。
含义:基础二分类分类器。
示例:
Scala:
Java:
Python:
算法介绍:
OneVsRest将一个给定的二分类算法有效地扩展到多分类问题应用中,也叫做“One-vs-All.”算法。OneVsRest是一个Estimator。它采用一个基础的Classifier然后对于k个类别分别创建二分类问题。类别i的二分类分类器用来预测类别为i还是不为i,即将i类和其他类别区分开来。最后,通过依次对k个二分类分类器进行评估,取置信最高的分类器的标签作为i类别的标签。
参数:
featuresCol:
类型:字符串型。
含义:特征列名。
labelCol:
类型:字符串型。
含义:标签列名。
predictionCol:
类型:字符串型。
含义:预测结果列名。
classifier:
类型:分类器型。
含义:基础二分类分类器。
示例:
Scala:
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator // load data file. val inputData = spark.read.format("libsvm") .load("data/mllib/sample_multiclass_classification_data.txt") // generate the train/test split. val Array(train, test) = inputData.randomSplit(Array(0.8, 0.2)) // instantiate the base classifier val classifier = new LogisticRegression() .setMaxIter(10) .setTol(1E-6) .setFitIntercept(true) // instantiate the One Vs Rest Classifier. val ovr = new OneVsRest().setClassifier(classifier) // train the multiclass model. val ovrModel = ovr.fit(train) // score the model on test data. val predictions = ovrModel.transform(test) // obtain evaluator. val evaluator = new MulticlassClassificationEvaluator() .setMetricName("accuracy") // compute the classification error on test data. val accuracy = evaluator.evaluate(predictions) println(s"Test Error : ${1 - accuracy}")
Java:
import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.OneVsRest; import org.apache.spark.ml.classification.OneVsRestModel; import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; // load data file. Dataset<Row> inputData = spark.read().format("libsvm") .load("data/mllib/sample_multiclass_classification_data.txt"); // generate the train/test split. Dataset<Row>[] tmp = inputData.randomSplit(new double[]{0.8, 0.2}); Dataset<Row> train = tmp[0]; Dataset<Row> test = tmp[1]; // configure the base classifier. LogisticRegression classifier = new LogisticRegression() .setMaxIter(10) .setTol(1E-6) .setFitIntercept(true); // instantiate the One Vs Rest Classifier. OneVsRest ovr = new OneVsRest().setClassifier(classifier); // train the multiclass model. OneVsRestModel ovrModel = ovr.fit(train); // score the model on test data. Dataset<Row> predictions = ovrModel.transform(test) .select("prediction", "label"); // obtain evaluator. MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() .setMetricName("accuracy"); // compute the classification error on test data. double accuracy = evaluator.evaluate(predictions); System.out.println("Test Error : " + (1 - accuracy));
Python:
from pyspark.ml.classification import LogisticRegression, OneVsRest from pyspark.ml.evaluation import MulticlassClassificationEvaluator # load data file. inputData = spark.read.format("libsvm") \ .load("data/mllib/sample_multiclass_classification_data.txt") # generate the train/test split. (train, test) = inputData.randomSplit([0.8, 0.2]) # instantiate the base classifier. lr = LogisticRegression(maxIter=10, tol=1E-6, fitIntercept=True) # instantiate the One Vs Rest Classifier. ovr = OneVsRest(classifier=lr) # train the multiclass model. ovrModel = ovr.fit(train) # score the model on test data. predictions = ovrModel.transform(test) # obtain evaluator. evaluator = MulticlassClassificationEvaluator(metricName="accuracy") # compute the classification error on test data. accuracy = evaluator.evaluate(predictions) print("Test Error : " + str(1 - accuracy))
相关文章推荐
- 随机森林回归(Random Forest)算法原理及Spark MLlib调用实例(Scala/Java/python)
- 三种特征选择方法及Spark MLlib调用实例(Scala/Java/python)
- 随机森林(Random Forest)算法原理及Spark MLlib调用实例(Scala/Java/python)
- 生存回归(加速失效时间模型)算法原理及Spark MLlib调用实例(Scala/Java/python)
- 广义线性模型(GLMs)算法原理及Spark MLlib调用实例(Scala/Java/Python)
- scala--三种文本特征提取(TF-IDF/Word2Vec/CountVectorizer)及Spark MLlib调用实例(Scala/Java/python)
- 二十种特征变换方法及Spark MLlib调用实例(Scala/Java/python)(一)
- 梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)
- 三种特征选择方法及Spark MLlib调用实例(Scala/Java/python)
- 梯度迭代树回归(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)
- K均值(K-means)算法原理及Spark MLlib调用实例(Scala/Java/python)
- 梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)
- MLlib--多层感知机(MLP)算法原理及Spark MLlib调用实例(Scala/Java/Python)
- 多层感知机(MLP)算法原理及Spark MLlib调用实例(Scala/Java/Python)
- 混合高斯模型(GMM)Spark MLlib调用实例(Scala/Java/Python)
- 二十种特征变换方法及Spark MLlib调用实例(Scala/Java/python)(二)
- 梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)
- 决策树回归算法原理及Spark MLlib调用实例(Scala/Java/python)
- 朴素贝叶斯算法原理及Spark MLlib调用实例(Scala/Java/Python)
- 决策树算法原理及Spark MLlib调用实例(Scala/Java/python)