您的位置:首页 > 产品设计 > UI/UE

学习笔记:csv文件的读取和tf.contrib.learn Quickstart

2017-04-07 18:00 405 查看
tf.contrib.learn 是TensorFlow高层次机器学习API。

以下是TensorFlow官方文档的实例代码解析

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import pandas as pd
import tensorflow as tf

tf.logging.set_verbosity(tf.logging.INFO)#set logging verbosity to INFO

COLUMNS = ["crim", "zn", "indus", "nox", "rm", "age","dis", "tax", "ptratio", "medv"]
FEATURES = ["crim", "zn", "indus", "nox", "rm","age", "dis", "tax", "ptratio"]
LABEL = "medv"
#数据集读取,训练集,测试集,预测集
training_set = pd.read_csv("boston_train.csv", skipinitialspace=True, skiprows=1, names=COLUMNS)
test_set = pd.read_csv("boston_test.csv", skipinitialspace=True,skiprows=1, names=COLUMNS)
prediction_set = pd.read_csv("boston_predict.csv", skipinitialspace=True,skiprows=1, names=COLUMNS)

#skipinitialspace : boolean, default False忽略分隔符后的空白(默认为False,即不忽略).
#skiprows : list-like or integer, default None需要忽略的行数(从文件开始处算起),或需要跳过的行号列表(从0开始)。
#创建特征容量器FeatureColumns,把读入的特征分解为一个列表
feature_cols = [tf.contrib.layers.real_valued_column(k) for k in FEATURES]
#构建DNN网络,两个隐层,每层10个神经单元
regressor = tf.contrib.learn.DNNRegressor(feature_columns=feature_cols,
hidden_units=[10, 10],
model_dir="/tmp/boston_model")
#定义输入函数,输入是数据集,返回的是FeatureColumns和labels(标签)
def input_fn(data_set):
feature_cols = {k: tf.constant(data_set[k].values) for k in FEATURES}#把feature_cols变为TensorFlow常量
labels = tf.constant(data_set[LABEL].values)
return feature_cols, labels

#-------------------------Training the Regressor-------------------------------
#迭代5000步 classifer.fit 训练模型
regressor.fit(input_fn=lambda: input_fn(training_set), steps=5000)
#--------------------------Evaluating the Model----------------------
#计算精度
ev = regressor.evaluate(input_fn=lambda: input_fn(test_set), steps=1)

loss_score = ev["loss"]
print("Loss: {0:f}".format(loss_score))

#-------------------------Making Predictions-----------------------
#输入prediction_set数据集对模型预测
y = regressor.predict(input_fn=lambda: input_fn(prediction_set))
# .predict() returns an iterator; convert to a list and print predictions
predictions = list(itertools.islice(y, 6))#itertools用于高效循环的迭代函数集合,返回前6个值
print ("Predictions: {}".format(str(predictions)))


重要知识点:

tf.contrib.learn.datasets.base.load_csv_with_header 加载csv格式数据

tf.contrib.learn.DNNClassifier 建立DNN模型(classifier)

classifer.fit 训练模型

classifier.evaluate 评价模型

classifier.predict 预测新样本
函数my_input_fn():返回值frature_cols是一个字典,包含键值对,把列名和数据特征对应起来,返回值labels只是一个包含标签的张量。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: