学习笔记:Creating Estimators in tf.contrib.learn
2017-04-08 10:52
423 查看
tf.contrib.learn是一个高级的训练API,利用它可以简单的创建和训练机器学习模型。
提供了四种模型:
1、LinearClassifier:线性分类模型
2、LinearRegressor:线性回归模型
3、DNNClassifier:神经网络分类模型
4、DNNRegressor:神经网络回归模型
注:回归和分类的区别:回归问题通常是用来预测一个值,如预测房价、未来的天气情况等等,分类问题是用于将事物打上一个标签,通常结果为离散值。
但是:如果已经给出的tf.contrib.learn无法满足你的要求呢,就得自己创建一个模型。
以下教程就是教你如何利用tf.contrib.learn提供的模块创建自己的Estimator,
提供了四种模型:
1、LinearClassifier:线性分类模型
2、LinearRegressor:线性回归模型
3、DNNClassifier:神经网络分类模型
4、DNNRegressor:神经网络回归模型
注:回归和分类的区别:回归问题通常是用来预测一个值,如预测房价、未来的天气情况等等,分类问题是用于将事物打上一个标签,通常结果为离散值。
但是:如果已经给出的tf.contrib.learn无法满足你的要求呢,就得自己创建一个模型。
以下教程就是教你如何利用tf.contrib.learn提供的模块创建自己的Estimator,
from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import sys import tempfile#临时文件 from six.moves import urllib import numpy as np import tensorflow as tf from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib FLAGS = None tf.logging.set_verbosity(tf.logging.INFO) # Learning rate for the model LEARNING_RATE = 0.001 #数据集的下载 def maybe_download(train_data, test_data, predict_data): """Maybe downloads training data and returns train and test file names.""" if train_data: train_file_name = train_data else: train_file = tempfile.NamedTemporaryFile(delete=False) urllib.request.urlretrieve( "http://download.tensorflow.org/data/abalone_train.csv", train_file.name) #urlretrieve() 方法直接将远程数据下载到本地。 train_file_name = train_file.name train_file.close() print("Training data is downloaded to %s" % train_file_name) if test_data: test_file_name = test_data else: test_file = tempfile.NamedTemporaryFile(delete=False) urllib.request.urlretrieve( "http://download.tensorflow.org/data/abalone_test.csv", test_file.name) test_file_name = test_file.name test_file.close() print("Test data is downloaded to %s" % test_file_name) if predict_data: predict_file_name = predict_data else: predict_file = tempfile.NamedTemporaryFile(delete=False) urllib.request.urlretrieve( "http://download.tensorflow.org/data/abalone_predict.csv", predict_file.name) predict_file_name = predict_file.name predict_file.close() print("Prediction data is downloaded to %s" % predict_file_name) return train_file_name, test_file_name, predict_file_name def model_fn(features, targets, mode, params): """Model function for Estimator. # Logic to do the following: # 1. Configure the model via TensorFlow operations # 2. Define the loss function for training/evaluation # 3. Define the training operation/optimizer # 4. Generate predictions """ # Connect the first hidden layer to input layer # (features) with relu activation first_hidden_layer = tf.contrib.layers.relu(features, 10) # Connect the second hidden layer to first hidden layer with relu second_hidden_layer = tf.contrib.layers.relu(first_hidden_layer, 10) # Connect the output layer to second hidden layer (no activation fn) output_layer = tf.contrib.layers.linear(second_hidden_layer, 1) # Reshape output layer to 1-dim Tensor to return predictions predictions = tf.reshape(output_layer, [-1]) predictions_dict = {"ages": predictions} # Calculate loss using mean squared error loss = tf.contrib.losses.mean_squared_error(targets, predictions) # Calculate root mean squared error as additional eval metric计算均方差 eval_metric_ops = { "rmse": tf.contrib.metrics.streaming_root_mean_squared_error( tf.cast(targets, tf.float64), predictions) } train_op = tf.contrib.layers.optimize_loss( loss=loss, global_step=tf.contrib.framework.get_global_step(), learning_rate=params["learning_rate"], optimizer="SGD") #返回的是什么意思 return model_fn_lib.ModelFnOps( mode=mode, predictions=predictions_dict, loss=loss, b18e train_op=train_op, eval_metric_ops=eval_metric_ops) def main(unused_argv): # Load datasets abalone_train, abalone_test, abalone_predict = maybe_download( FLAGS.train_data, FLAGS.test_data, FLAGS.predict_data) # Training examples 读取数据 没有header training_set = tf.contrib.learn.datasets.base.load_csv_without_header( filename=abalone_train, target_dtype=np.int, features_dtype=np.float64) # Test examples test_set = tf.contrib.learn.datasets.base.load_csv_without_header( filename=abalone_test, target_dtype=np.int, features_dtype=np.float64) # Set of 7 examples for which to predict abalone ages prediction_set = tf.contrib.learn.datasets.base.load_csv_without_header( filename=abalone_predict, target_dtype=np.int, features_dtype=np.float64) # Set model params model_params = {"learning_rate": LEARNING_RATE} # Instantiate Estimator 建立模型 nn = tf.contrib.learn.Estimator(model_fn=model_fn, params=model_params) # Fit 训练 nn.fit(x=training_set.data, y=training_set.target, steps=5000) # Score accuracy计算精度 ev = nn.evaluate(x=test_set.data, y=test_set.target, steps=1) print("Loss: %s" % ev["loss"]) print("Root Mean Squared Error: %s" % ev["rmse"]) # Print out predictions predictions = nn.predict(x=prediction_set.data, as_iterable=True) for i, p in enumerate(predictions): print("Prediction %s: %s" % (i + 1, p["ages"])) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.register("type", "bool", lambda v: v.lower() == "true") parser.add_argument( "--train_data", type=str, default="", help="Path to the training data.") parser.add_argument( "--test_data", type=str, default="", help="Path to the test data.") parser.add_argument( "--predict_data", type=str, default="", help="Path to the prediction data.") FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
相关文章推荐
- 深度学习笔记——深度学习框架TensorFlow(十)[Creating Estimators in tf.contrib.learn]
- TensorFlow学习笔记12----Creating Estimators in tf.contrib.learn
- TensorFlow学习笔记6----tf.contrib.learn Quickstart
- tensorflow学习笔记十四:TF官方教程学习 tf.contrib.learn Quickstart
- 学习笔记:csv文件的读取和tf.contrib.learn Quickstart
- 深度学习笔记——深度学习框架TensorFlow(八)[Logging and Monitoring Basics with tf.contrib.learn]
- TensorFlow学习笔记10----Logging and Monitoring Basics with tf.contrib.learn
- 深度学习笔记——深度学习框架TensorFlow(九)[Building Input Functions with tf.contrib.learn]
- Learn Emacs in 21 Days: day 5 学习笔记
- Learn Emacs in 21 Days: day 3 学习笔记
- 学习笔记TF044:TF.Contrib组件、统计分布、Layer、性能分析器tfprof
- 学习笔记TF054:TFLearn、Keras
- 学习笔记TF054:TFLearn、Keras
- tensorflow学习笔记十五:tensorflow官方文档学习 Logging and Monitoring Basics with tf.contrib.learn
- 学习使用tf.contrib.learn框架开发机器学习程序
- 深度学习笔记——深度学习框架TensorFlow(四)[高级API tf.contrib.learn]
- Learn Emacs in 21 Days: day 1 学习笔记
- [Python][MachineLeaning]Python Scikit-learn学习笔记1-Datasets&Estimators
- 学习笔记TF043:TF.Learn 机器学习Estimator、DataFrame、监督器Monitors
- Learn Emacs in 21 Days: day 2 学习笔记