How the TensorFlow Estimator communicates with model_fn

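When writing a custom Estimator, it is very important to understand how `Estimator` relates to `model_fn` and to its other parameters. In short, the Estimator takes the arguments it has received and feeds them into `model_fn`, and `model_fn` is the key consumer of that data.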
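The "feeding" described above can be made concrete with a small pure-Python sketch. It is not TensorFlow's code: `call_model_fn` is a hypothetical helper that mimics the behavior described in the docstring below, where `labels`, `mode`, `params` and `config` reach `model_fn` only if its signature declares a parameter with that name. The mode string `"train"` is just a stand-in value.

```python
import inspect


def call_model_fn(model_fn, features, labels, mode, params, config):
    """Hypothetical sketch: feed model_fn only the arguments it declares."""
    accepted = inspect.signature(model_fn).parameters
    kwargs = {}
    if "labels" in accepted:
        kwargs["labels"] = labels
    elif labels is not None:
        # model_fn cannot receive labels, so dropping them silently would hide a bug.
        raise ValueError("model_fn does not take labels, but input_fn returns labels.")
    # The optional arguments are passed by name, and only when model_fn asks for them.
    if "mode" in accepted:
        kwargs["mode"] = mode
    if "params" in accepted:
        kwargs["params"] = params
    if "config" in accepted:
        kwargs["config"] = config
    # features is always passed.
    return model_fn(features=features, **kwargs)


def my_model_fn(features, labels, mode, params):
    # No `config` parameter here, so call_model_fn does not pass one.
    return {"mode": mode, "lr": params["learning_rate"], "n": len(features)}


result = call_model_fn(my_model_fn, features=[1, 2, 3], labels=[0, 1, 0],
                       mode="train", params={"learning_rate": 0.1}, config=None)
print(result)  # {'mode': 'train', 'lr': 0.1, 'n': 3}
```

Inspecting the signature is what lets a `model_fn` stay as small as it likes: it declares only the inputs it actually uses.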

Depending on the value of `mode`, different arguments of the `tf.estimator.EstimatorSpec` returned by `model_fn` are required. Namely:

* For `mode == ModeKeys.TRAIN`: required fields are `loss` and `train_op`.
* For `mode == ModeKeys.EVAL`: required field is `loss`.
* For `mode == ModeKeys.PREDICT`: required field is `predictions`.
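These per-mode requirements can be expressed as a lookup table. The sketch below is hypothetical: `ModeKeys` is a local stand-in that reuses the string values of `tf.estimator.ModeKeys` (`"train"`, `"eval"`, `"infer"`), and `validate_spec` only imitates the kind of check `EstimatorSpec` performs; it is not the TensorFlow implementation.

```python
class ModeKeys:
    """Local stand-in for tf.estimator.ModeKeys."""
    TRAIN = "train"
    EVAL = "eval"
    PREDICT = "infer"


# Fields that must be non-None for each mode, per the list above.
REQUIRED_FIELDS = {
    ModeKeys.TRAIN: ("loss", "train_op"),
    ModeKeys.EVAL: ("loss",),
    ModeKeys.PREDICT: ("predictions",),
}


def validate_spec(mode, **fields):
    """Return fields unchanged, or raise ValueError if a required one is missing."""
    if mode not in REQUIRED_FIELDS:
        raise ValueError(f"Unknown mode: {mode!r}")
    missing = [name for name in REQUIRED_FIELDS[mode] if fields.get(name) is None]
    if missing:
        raise ValueError(f"Missing required fields for mode {mode!r}: {missing}")
    return fields


validate_spec(ModeKeys.EVAL, loss=0.25)                   # passes
validate_spec(ModeKeys.PREDICT, predictions={"y": [1]})   # passes
try:
    validate_spec(ModeKeys.TRAIN, loss=0.25)              # train_op is missing
except ValueError as err:
    print(err)  # Missing required fields for mode 'train': ['train_op']
```

In practice this means a single `model_fn` usually branches on `mode` and returns early for PREDICT, before any loss that needs labels is built.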

class Estimator(object):
"""Estimator class to train and evaluate TensorFlow models.

The Estimator object wraps a model which is specified by a model_fn,
which, given inputs and a number of other parameters, returns the ops
necessary to perform training, evaluation, or predictions.

All outputs (checkpoints, event files, etc.) are written to model_dir, or a
subdirectory thereof. If model_dir is not set, a temporary directory is used.
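The rules for choosing `model_dir`, spelled out in the `model_dir` argument description further down, can be sketched as a small resolver. `resolve_model_dir` is a hypothetical helper, `config_model_dir` stands in for the `model_dir` stored in `RunConfig`, and `os.fspath` is used here to turn `PathLike` objects into strings; TensorFlow uses its own path utilities internally.

```python
import os
import tempfile


def resolve_model_dir(model_dir=None, config_model_dir=None):
    """Hypothetical sketch of the documented model_dir precedence rules."""
    # PathLike objects (e.g. pathlib.Path) are converted to plain strings.
    if model_dir is not None:
        model_dir = os.fspath(model_dir)
    if config_model_dir is not None:
        config_model_dir = os.fspath(config_model_dir)

    # If both are set, they must be the same.
    if (model_dir is not None and config_model_dir is not None
            and model_dir != config_model_dir):
        raise ValueError(
            f"model_dir ({model_dir!r}) and config.model_dir "
            f"({config_model_dir!r}) are both set but differ")

    resolved = model_dir if model_dir is not None else config_model_dir
    # Neither is set: fall back to a fresh temporary directory.
    return resolved if resolved is not None else tempfile.mkdtemp()


print(resolve_model_dir("models/run1"))         # models/run1
print(resolve_model_dir(None, "models/run1"))   # models/run1
print(os.path.isdir(resolve_model_dir()))       # True
```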

The config argument can be passed a tf.estimator.RunConfig object containing
information about the execution environment. It is passed on to the
model_fn, if the model_fn has a parameter named "config" (and input
functions in the same manner). If the config parameter is not passed, it is
instantiated by the Estimator. Not passing config means that defaults useful
for local execution are used. Estimator makes config available to the model
(for instance, to allow specialization based on the number of workers
available), and also uses some of its fields to control internals, especially
regarding checkpointing.

The params argument contains hyperparameters. It is passed to the
model_fn, if the model_fn has a parameter named "params", and to the input
functions in the same manner. Estimator only passes params along, it does
not inspect it. The structure of params is therefore entirely up to the
developer.

None of Estimator's methods can be overridden in subclasses (its
constructor enforces this). Subclasses should use model_fn to configure
the base class, and may add methods implementing specialized functionality.
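The override restriction can be illustrated with a guard that runs in the constructor. This is a simplified pure-Python sketch, not TensorFlow's code: `ToyEstimator` and its single `train` method are hypothetical, and the check simply compares each non-dunder attribute of the base class with what the subclass resolves it to. Adding a new method passes, while redefining `train` is rejected.

```python
class ToyEstimator:
    """Hypothetical base class, only here to demonstrate the guard."""

    def __init__(self, model_fn):
        self._assert_members_not_overridden()
        self._model_fn = model_fn

    def _assert_members_not_overridden(self):
        cls = type(self)
        # Dunder names such as __init__ are skipped, so subclasses may still
        # define their own constructor.
        overridden = [
            name for name, value in vars(ToyEstimator).items()
            if not name.startswith("__") and getattr(cls, name) is not value
        ]
        if overridden:
            raise ValueError(
                "Subclasses of ToyEstimator cannot override members of "
                f"ToyEstimator. {cls.__name__} overrides: {overridden}")

    def train(self):
        return "training"


class GoodEstimator(ToyEstimator):
    def export(self):  # a new method: allowed
        return "exported"


class BadEstimator(ToyEstimator):
    def train(self):  # redefines a base-class method: rejected
        return "custom training"


print(GoodEstimator(model_fn=None).export())  # exported
try:
    BadEstimator(model_fn=None)
except ValueError as err:
    print(err)
```

Because the check lives in the base constructor, it fires as soon as a subclass is instantiated, which matches the docstring's remark that the constructor enforces the rule.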

Calling methods of Estimator will work while eager execution is enabled.
However, the model_fn and input_fn are not executed eagerly: Estimator
will switch to graph mode before calling all user-provided functions (incl.
hooks), so their code has to be compatible with graph mode execution. Note
that input_fn code using tf.data generally works in both graph and eager
modes.
def __init__(self, model_fn, model_dir=None, config=None, params=None,
             warm_start_from=None):
"""Constructs an Estimator instance.

See [estimators](https://tensorflow.org/guide/estimators) for more information.

To warm-start an `Estimator`:

estimator = tf.estimator.DNNClassifier(
    feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
    hidden_units=[1024, 512, 256],
    warm_start_from="/path/to/checkpoint/dir")

For more details on warm-start configuration, see `tf.estimator.WarmStartSettings`.

Args:

  model_fn: Model function. Follows the signature:

    * Args:

      * `features`: This is the first item returned from the `input_fn`
             passed to `train`, `evaluate`, and `predict`. This should be a
             single `tf.Tensor` or `dict` of same.
      * `labels`: This is the second item returned from the `input_fn`
             passed to `train`, `evaluate`, and `predict`. This should be a
             single `tf.Tensor` or `dict` of same (for multi-head models).
             If mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will
             be passed. If the `model_fn`'s signature does not accept
             `mode`, the `model_fn` must still be able to handle
             `labels=None`.
      * `mode`: Optional. Specifies if this is training, evaluation or
             prediction. See `tf.estimator.ModeKeys`.
      * `params`: Optional `dict` of hyperparameters.  Will receive what
             is passed to Estimator in its `params` parameter. This allows
             Estimators to be configured from hyperparameter tuning.
      * `config`: Optional `estimator.RunConfig` object. Will receive what
             is passed to Estimator as its `config` parameter, or a default
             value. Allows setting up things in your `model_fn` based on
             configuration such as `num_ps_replicas`, or `model_dir`.

    * Returns:

      `tf.estimator.EstimatorSpec`

  model_dir: Directory to save model parameters, graph, etc. This can
    also be used to load checkpoints from the directory into an estimator to
    continue training a previously saved model. If `PathLike` object, the
    path will be resolved. If `None`, the model_dir in `config` will be used
    if set. If both are set, they must be the same. If both are `None`, a
    temporary directory will be used.
  config: `estimator.RunConfig` configuration object.
  params: `dict` of hyperparameters that will be passed into `model_fn`.
          Keys are names of parameters, values are basic python types.
  warm_start_from: Optional string filepath to a checkpoint or SavedModel to
                   warm-start from, or a `tf.estimator.WarmStartSettings`
                   object to fully configure warm-starting.  If the string
                   filepath is provided instead of a
                   `tf.estimator.WarmStartSettings`, then all variables are
                   warm-started, and it is assumed that vocabularies
                   and `tf.Tensor` names are unchanged.
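The two accepted forms of `warm_start_from` can be sketched as a normalization step. `WarmStartConfig` is a hypothetical stand-in for `tf.estimator.WarmStartSettings` that keeps only two of its fields; its default regex `".*"` is what makes a bare path warm-start every variable. `normalize_warm_start` and `should_warm_start` are hypothetical helpers, and the sketch does not distinguish checkpoints from SavedModels.

```python
import re
from typing import NamedTuple, Optional, Union


class WarmStartConfig(NamedTuple):
    """Hypothetical stand-in for tf.estimator.WarmStartSettings."""
    ckpt_to_initialize_from: str
    vars_to_warm_start: str = ".*"  # regex; ".*" matches every variable name


def normalize_warm_start(
        warm_start_from: Union[None, str, WarmStartConfig]) -> Optional[WarmStartConfig]:
    """Turn the accepted warm_start_from forms into a single settings object."""
    if warm_start_from is None:
        return None
    if isinstance(warm_start_from, str):
        # A bare path: warm-start all variables from that location.
        return WarmStartConfig(ckpt_to_initialize_from=warm_start_from)
    if isinstance(warm_start_from, WarmStartConfig):
        return warm_start_from
    raise TypeError(f"Unsupported warm_start_from: {warm_start_from!r}")


def should_warm_start(settings: WarmStartConfig, var_name: str) -> bool:
    """Does this variable name match the vars_to_warm_start regex?"""
    return re.match(settings.vars_to_warm_start, var_name) is not None


ws = normalize_warm_start("/path/to/checkpoint")
print(should_warm_start(ws, "dense/kernel"))  # True

only_emb = normalize_warm_start(WarmStartConfig("/path/to/checkpoint", "embedding.*"))
print(should_warm_start(only_emb, "dense/kernel"))  # False
```

Passing a settings object is therefore the way to restrict warm-starting to a subset of variables; a plain string always means "everything".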

Raises:
  ValueError: parameters of `model_fn` don't match `params`.
  ValueError: if this is called via a subclass and if that class overrides
    a member of `Estimator`.
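The first `ValueError` comes from checking the signature of `model_fn` against what was passed to the constructor. The sketch below is loosely modeled on that idea and is not TensorFlow's code: `verify_model_fn_args` is hypothetical. It rejects a `model_fn` without `features`, a `model_fn` that lacks a `params` parameter while `params` was given, and any parameter name outside the five listed in the `model_fn` signature above.

```python
import inspect

# The parameter names allowed by the model_fn signature described above.
VALID_MODEL_FN_ARGS = {"features", "labels", "mode", "params", "config"}


def verify_model_fn_args(model_fn, params):
    """Hypothetical sketch: raise ValueError if model_fn's signature doesn't fit."""
    args = set(inspect.signature(model_fn).parameters)
    if "features" not in args:
        raise ValueError(f"model_fn ({model_fn.__name__}) must include features argument.")
    if params is not None and "params" not in args:
        raise ValueError(
            f"model_fn ({model_fn.__name__}) does not include params argument, "
            f"but params ({params!r}) is passed to Estimator.")
    unexpected = sorted(args - VALID_MODEL_FN_ARGS)
    if unexpected:
        raise ValueError(f"model_fn ({model_fn.__name__}) has unexpected args: {unexpected}")


def model_fn_without_params(features, labels, mode):
    return None


verify_model_fn_args(model_fn_without_params, params=None)  # OK
try:
    verify_model_fn_args(model_fn_without_params, params={"learning_rate": 0.1})
except ValueError as err:
    print(err)
```

Running such a check once in the constructor surfaces a mismatched `model_fn` immediately, instead of at the first call to `train`, `evaluate` or `predict`.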
