get_layer_and_variable_from tf.keras.Model

def get_layers_and_variables_from_model(model: tf.keras.Model, scope_name=None):
    layer_dict = {}
    if scope_name is not None:
        base_name = scope_name
    else:
        base_name = model.name

    # get Layers
    for layer in model.layers:
        if isinstance(layer, tf.keras.Model):
            sub_model_layer_dict = get_layers_and_variables_from_model(
                layer, "{}/{}".format(base_name, layer.name)
            )
            layer_dict.update(sub_model_layer_dict)
        elif isinstance(layer, tf.keras.layers.Layer):
            layer_dict["{}/{}".format(base_name, layer.name)] = layer

    # get Variables
    for attr_name, attr_value in model.__dict__.items():
        # NOTE: _train_counter, _test_counter, _predict_counter are
        # built-in variables of tf.keras.Model
        if attr_name not in [
            "_train_counter",
            "_test_counter",
            "_predict_counter",
        ] and isinstance(attr_value, tf.Variable):
            layer_dict["{}/{}".format(base_name, attr_value.name)] = attr_value
    return layer_dict

  

递归:因为keras.Model中可能包含另一个keras.Model

版本:tf2.5.1

-----2022.10.08补充------

这个方法有一个问题:如果定义keras.Model的时候,名字是相同的(或者不命名,默认名字则全部是Variable(名字不会自动unique,有可能是tf2中的eager模式认为tensor name不重要了)),例如:

self.v = tf.Variable(0.6, trainable=False, name="aaa")
self.v1 = tf.Variable(0.6, trainable=False, name="aaa")
 self.v2 = tf.Variable(0.6, trainable=True, name="aaa")

这种情况使用上面的函数将无法区分不同的"aaa",所以使用的时候要求给Variable手动设置unique的name.

------2022.10.31补充-------

如果把Variable存到了某个dict(成员变量),上面的代码获取Variable会有问题,attr_value会是一个dict而不是Variable. 需要进一步从dict中抽出Variable

posted @ 2022-09-19 10:52  灰太狼锅锅  阅读(30)  评论(0编辑  收藏  举报