博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Estimator guide
阅读量:6565 次
发布时间:2019-06-24

本文共 2844 字,大约阅读时间需要 9 分钟。

1,introduction

Estimator 会封装下列操作:

  • 训练
  • 评估
  • 预测
  • 导出以供使用

预创建的 Estimator,也可以编写自定义 Estimator。所有 Estimator(无论是预创建的还是自定义)都是基于 类的类

2,Estimator 的优势

  • 您可以在本地主机上或分布式多服务器环境中运行基于 Estimator 的模型,而无需更改模型。此外,您可以在 CPU、GPU 或 TPU 上运行基于 Estimator 的模型,而无需重新编码模型。
  • Estimator 简化了在模型开发者之间共享实现的过程。
  • 您可以使用高级直观代码开发先进的模型。简言之,采用 Estimator 创建模型通常比采用低阶 TensorFlow API 更简单。
  • Estimator 本身在 之上构建而成,可以简化自定义过程。
  • Estimator 会为您构建图。
  • Estimator 提供安全的分布式训练循环,可以控制如何以及何时:
    • 构建图
    • 初始化变量
    • 开始排队
    • 处理异常
    • 创建检查点文件并从故障中恢复
    • 保存 TensorBoard 的摘要

3,预创建的 Estimator

预创建的 Estimator 会为您创建和管理 和 对象

预创建的 Estimator 程序的结构

编写一个或多个数据集导入函数

def input_fn(dataset): ... # manipulate dataset, extracting the feature dict and the label return feature_dict, label

定义特征列

以下代码段创建了三个存储整数或浮点数据的特征列。前两个特征列仅标识了特征的名称和类型。第三个特征列还指定了一个 lambda,该程序将调用此 lambda 来调节原始数据 # Define three numeric feature columns. population = tf.feature_column.numeric_column('population') crime_rate = tf.feature_column.numeric_column('crime_rate') median_education = tf.feature_column.numeric_column('median_education', normalizer_fn=lambda x: x - global_education_mean)

实例化相关的预创建的 Estimator

# Instantiate an estimator, passing the feature columns. estimator = tf.estimator.LinearClassifier( feature_columns=[population, crime_rate, median_education], )

调用训练、评估或推理方法

# my_training_set is the function created in Step 1 estimator.train(input_fn=my_training_set, steps=2000)

4,自定义 Estimator

每个 Estimator(无论是预创建还是自定义)的核心都是其模型函数,这是一种为训练、评估和预测构建图的方法。如果您使用预创建的 Estimator,则有人已经实现了模型函数。如果您使用自定义 Estimator,则必须自行编写模型函数

 

5,从 Keras 模型创建 Estimator

您可以将现有的 Keras 模型转换为 Estimator。这样做之后,Keras 模型就可以利用 Estimator 的优势,例如分布式训练。调用

 

# Instantiate a Keras inception v3 model. keras_inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights=None) # Compile model with the optimizer, loss, and metrics you'd like to train with. keras_inception_v3.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy', metric='accuracy') # Create an Estimator from the compiled Keras model. Note the initial model # state of the keras model is preserved in the created Estimator. est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model=keras_inception_v3) # Treat the derived Estimator as you would with any other Estimator. # First, recover the input name(s) of Keras model, so we can use them as the # feature column name(s) of the Estimator input function: keras_inception_v3.input_names # print out: ['input_1'] # Once we have the input name(s), we can create the input function, for example, # for input(s) in the format of numpy ndarray: train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( x={ "input_1": train_data}, y=train_labels, num_epochs=1, shuffle=False) # To train, we call Estimator's train function: est_inception_v3.train(input_fn=train_input_fn, steps=2000)

 

转载于:https://www.cnblogs.com/augustone/p/10520168.html

你可能感兴趣的文章
BZOJ3799 : 字符串重组
查看>>
用纯JS做俄罗斯方块 - 简要思路介绍(1)
查看>>
blog摘录--测试感触
查看>>
数据持久化的复习
查看>>
【DeepLearning】Exercise:Sparse Autoencoder
查看>>
Util应用程序框架公共操作类(八):Lambda表达式公共操作类(二)
查看>>
android 设置布局横屏竖屏
查看>>
Java从零开始学六(运算符)
查看>>
thinkphp学习笔记10—看不懂的路由规则
查看>>
Eclipse中SVN的安装步骤(两种)和使用方法[转载]
查看>>
JavaScript学习
查看>>
Codeforces Round #295 (Div. 2)B - Two Buttons BFS
查看>>
使用SQLServer 2008的CDC功能实现数据变更捕获
查看>>
iPad 3g版完美实现打电话功能(phoneitipad破解)
查看>>
VBoxGuestAdditions.iso下载地址
查看>>
EXPORT_SYMBOL的作用是什么
查看>>
BZOJ 1022 [SHOI2008]小约翰的游戏John AntiNim游戏
查看>>
PPTPD服务端搭建
查看>>
SyncTrayzor -- Windows tray utility / filesystem watcher / launcher for syncthing
查看>>
SANS top 20
查看>>