您的位置:首页 > 大数据 > 人工智能

阿里云 机器学习pai的使用数据的使用以及模型的存储

2017-10-14 16:14 267 查看
1.数据的使用  读取pickle

import os
import sys
import argparse
import tensorflow as tf
import pickle
from tensorflow.python.lib.io import file_io
FLAGS = None
def main(_):
dir = os.path.join(FLAGS.buckets, 'Parsing.pickle')
object = file_io.read_file_to_string(dir,True)
result = pickle.loads(object)
training_records = result['training']
validation_records = result['validation']
print(len(training_records))
print("good")

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--buckets', type=str, default='',
help='input data path')
parser.add_argument('--checkpointDir', type=str, default='',
help='output model path')
FLAGS, _ = parser.parse_known_args()
tf.app.run(main=main)注意点1:buckets的定义,而且是缺省值不用定义具体的oss地址
注意点2:使用tensorflow进行读取,Python的open方法在pai上不能使用

注意点3:pickle存储dump时协议要用2,以为pai上的Python是2.7

2.模型的存储

import tensorflow as tf
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string("checkpointDir", "model/test.ckpt", "path to logs directory")
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
print (sess.run(w4,feed_dict))
saver.save(sess,FLAGS.checkpointDir)注意点1:要定义checkpointDir
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
相关文章推荐