TF2 自定义优化器

# -*- coding: utf-8 -*-

from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


class Adammom(optimizer_v2.OptimizerV2):
    """Adammom optimizer.

    A momentum update whose step is rescaled element-wise by a decayed
    second-moment estimate of the gradient. For every trainable variable
    ``w`` with gradient ``grad`` (all operations are element-wise)::

        d2sum = d2sum * ada_decay_rate + 1
        g2sum = g2sum * ada_decay_rate + grad * grad
        scale = sqrt((1.0 + ada_epsilon) / (g2sum / d2sum + ada_epsilon))
        velocity = mom_decay_rate * velocity + (1 - mom_decay_rate) * grad
        w = w - learning_rate * velocity * scale

    ``d2sum``, ``g2sum`` and ``velocity`` are per-variable slots created
    with ``add_slot``'s default initializer. Only dense gradients are
    supported; ``_resource_apply_sparse`` raises ``NotImplementedError``.

    Args:
        learning_rate: (float) Step size. Defaults to 0.0001. If the legacy
            ``lr`` keyword is passed, it takes precedence.
        ada_decay_rate: (float) Decay applied to ``d2sum`` and ``g2sum`` on
            every step. Defaults to 0.9999.
        ada_epsilon: (float) Small constant that keeps ``scale`` finite when
            ``g2sum`` is close to zero. Defaults to 1e-08.
        mom_decay_rate: (float) Decay rate of ``velocity`` (the momentum).
            Defaults to 0.99.
        name: (str) Name of the optimizer. Defaults to "Adammom".
        **kwargs: Forwarded unchanged to ``OptimizerV2.__init__``.
    """

    _HAS_AGGREGATE_GRAD = True

    def __init__(
        self,
        learning_rate=0.0001,
        ada_decay_rate=0.9999,
        ada_epsilon=1e-08,
        mom_decay_rate=0.99,
        name="Adammom",
        **kwargs
    ):
        super(Adammom, self).__init__(name, **kwargs)
        # Hyperparameters registered with _set_hyper are serialized by
        # get_config and read back per dtype in _prepare_local.
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("ada_decay_rate", ada_decay_rate)
        self._set_hyper("mom_decay_rate", mom_decay_rate)
        # Kept as a plain Python float; converted to the variable dtype in
        # _prepare_local.
        self.ada_epsilon = ada_epsilon

    def _create_slots(self, var_list):
        """Create the per-variable optimizer state.

        Slots:
            d2sum: decayed step count. It is updated with a scalar recurrence,
                so every element of the slot holds the same value.
            g2sum: decayed sum of squared gradients.
            velocity: exponential moving average of the gradient (momentum).
        """
        # One loop per slot kind so all slots of one kind are created
        # together; this fixes the creation order of the optimizer's
        # variables.
        for var in var_list:
            self.add_slot(var, "d2sum")
        for var in var_list:
            self.add_slot(var, "g2sum")
        for var in var_list:
            self.add_slot(var, "velocity")

    def _prepare_local(self, var_device, var_dtype, apply_state):
        """Cache dtype-specific hyperparameter tensors in ``apply_state``.

        The base class supplies ``lr_t``. This method adds ``ada_epsilon``,
        ``ada_decay_rate_t`` and ``mom_decay_rate_t`` under the
        ``(var_device, var_dtype)`` key that ``_resource_apply_dense`` reads.
        """
        super(Adammom, self)._prepare_local(var_device, var_dtype, apply_state)

        ada_decay_rate_t = array_ops.identity(
            self._get_hyper("ada_decay_rate", var_dtype)
        )
        mom_decay_rate_t = array_ops.identity(
            self._get_hyper("mom_decay_rate", var_dtype)
        )

        apply_state[(var_device, var_dtype)].update(
            dict(
                ada_epsilon=ops.convert_to_tensor_v2_with_dispatch(
                    self.ada_epsilon, var_dtype
                ),
                ada_decay_rate_t=ada_decay_rate_t,
                mom_decay_rate_t=mom_decay_rate_t,
            )
        )

    @def_function.function(jit_compile=True)
    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Apply one Adammom step to ``var`` using a dense ``grad``.

        Updates the ``d2sum``, ``g2sum`` and ``velocity`` slots in place, then
        subtracts ``lr_t * velocity * scale`` from ``var``. Coefficients come
        from ``apply_state``, or from ``_fallback_apply_state`` when no cached
        entry exists for this device/dtype.
        """
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)
        # TODO(lebronzheng): The following calculations should be fused into a c++ kernel
        d2sum = self.get_slot(var, "d2sum")
        g2sum = self.get_slot(var, "g2sum")
        ada_decay_rate = coefficients["ada_decay_rate_t"]
        # d2sum = d2sum * ada_decay_rate + 1
        d2sum.assign(d2sum * ada_decay_rate + 1)
        # g2sum = g2sum * ada_decay_rate + grad * grad
        g2sum.assign(g2sum * ada_decay_rate + math_ops.square(grad))
        # scale = sqrt((1.0 + ada_epsilon) / (g2sum / d2sum + ada_epsilon))
        # (reads the slot values just assigned above)
        ada_epsilon = coefficients["ada_epsilon"]
        scale = math_ops.sqrt((1 + ada_epsilon) / (g2sum / d2sum + ada_epsilon))
        # velocity = mom_decay_rate * velocity + (1 - mom_decay_rate) * grad
        mom_decay_rate = coefficients["mom_decay_rate_t"]
        velocity = self.get_slot(var, "velocity")
        velocity.assign(mom_decay_rate * velocity + (1 - mom_decay_rate) * grad)
        # w = w - learning_rate * velocity * scale
        var.assign_sub(coefficients["lr_t"] * velocity * scale)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        """Sparse (row-indexed) updates are not supported yet.

        Raises:
            NotImplementedError: always.
        """
        # `NotImplemented` is a comparison sentinel, not an exception class:
        # calling it raised TypeError instead of signalling "unsupported".
        raise NotImplementedError("Not implemented currently")

    def get_config(self):
        """Return the constructor arguments needed to recreate this optimizer."""
        config = super(Adammom, self).get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter("learning_rate"),
                "decay": self._initial_decay,
                "ada_decay_rate": self._serialize_hyperparameter("ada_decay_rate"),
                "mom_decay_rate": self._serialize_hyperparameter("mom_decay_rate"),
                "ada_epsilon": self.ada_epsilon,
            }
        )
        return config

  

1. `_resource_apply_sparse` 主要为稀疏梯度场景设计，例如实现 LazyAdam：只更新梯度中指定的行（indices），其他行保持不变。上面的 Adammom 暂未实现该方法（不支持稀疏梯度）。
2. `self._iterations` 表示优化器的更新次数，在一些依赖 time step 的优化器中会用到，例如 Adam 中计算 $\beta^t$ 时的 $t$。
但是这个 iterations 是优化器级别的，也就是说优化器中的所有 variable 共用同一个 iterations。
如果每轮迭代全部参数都进行更新，那没有任何问题；但如果每轮只更新部分参数，那么未被更新的参数的 $t$ 也等价于被 +1 了，导致计算结果与 Adam 原始公式不一致。
当然这未必一定会影响效果，需要实验验证。如果需要实现参数级别的 iteration，只需在 `_create_slots` 中为每个 variable 额外创建一个 iteration slot，并在每次 apply 该 variable 时手动将其加 1。
3. `_create_slots` 用于定义训练参数之外的优化器状态变量（slot），例如 momentum、energy 等；上面 Adammom 中的 d2sum、g2sum 和 velocity 就是这样的 slot。
posted @ 2022-09-19 11:45  灰太狼锅锅  阅读(165)  评论(0)  编辑  收藏  举报