您的位置:首页 > 其它

TensorFlow学习笔记-tf.estimator

2018-02-20 18:53 453 查看
tfestimatorEstimator
属性

方法

tf.estimator.Estimator

Estimator class训练和测试TF模型。
Estimator
对象封装好通过
model_fn
指定的模型,给定输入和其它超参数,返回ops执行training, evaluation or prediction. 所有的输出(包含checkpoints, event files, etc.)被写入
model_dir


属性

config

传入
model_fn
,如果
model_fn
有参数named “config”

model_dir

model_fn

The model_fn with following signature:
def model_fn(features, labels, mode, config)


params

方法

__init__


__init__(
model_fn,
model_dir=None,
config=None,
params=None # 将要传入model_fn的超参数字典
)


evaluate


对训练模型评价

evaluate(
input_fn, # 输入函数,返回元组features和labels
steps=None,
hooks=None, # List of SessionRunHook subclass instances
checkpoint_path=None, # if none, 用model_dir中latest checkpoint
name=None
)


export_savemodel


导出inference graph作为一个SavedModel

export_savedmodel(
export_dir_base, # 目录
serving_input_receiver_fn, # 返回ServingInputReceiver的函数
assets_extra=None,
as_text=False,
checkpoint_path=None
)


get_variable_names

get_variable_names()

返回模型中所有变量名字的列表

get_variable_value(name)

根据变量name返回value

latest_checkpoint()

model_dir
中找到最近保存的checkpoint

predict

根据给定的features产生预测

predict(
input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None
)


train

给定训练数据后训练model

train(
input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None
)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息