Treevalue(0x02)——函数树化详细解析(上篇)

本文将对 func_treelize 这一treevalue库中的核心功能进行详细的原理解析。

关于treevalue的概述,可以参考之前的文章:Treevalue(0x01)——功能概述

树化函数基本原理

在treevalue库中, func_treelize 是核心特性之一,可以将普通的函数快速作用于树对象上。而这一“作用”的原理是什么呢,我们来一起看看——首先准备一个普通的函数,并加上 func_treelize 装饰器,就像这样

from treevalue import func_treelize


@func_treelize()
def gcd(a, b):  # GCD calculation
    print('gcd', a, b)
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

函数的部分是一个最大公因数的计算,并且和之前文章(Treevalue(0x01)——功能概述)中的区别在于,添加了一行 print 输出,用于体现函数内部在整个计算过程中是如何被调用的。基于这一函数,我们进行如下的调用,可以得到对应的输出结果

from treevalue import FastTreeValue

gcd(9, 12)
# gcd 9 12
# 3

t1 = FastTreeValue({'a': 2, 'b': 30, 'x': {'c': 4, 'd': 9}})
t2 = FastTreeValue({'a': 4, 'b': 48, 'x': {'c': 6, 'd': 54}})
gcd(t1, t2)
# gcd 30 48
# gcd 9 54
# gcd 4 6
# gcd 2 4
# <TreeValue 0x7f12950e3be0>
# ├── a --> 2
# ├── b --> 6
# └── x --> <TreeValue 0x7f1296732310>
#     ├── c --> 2
#     └── d --> 9

根据输出语句,不难发现——经过func_treelize装饰后的函数,在被传入TreeValue类型的时候,会自动基于其结构将内部的数值一一对应传入原函数,并在执行计算后组装成与原来相同的树结构
基于以上基本特性,func_treelize这一过程也被称为函数的树化,经过树化后的函数将满足以下基本特性:

  1. 当所有传入参数均为非树对象时,函数行为与返回值与原函数保持严格一致,即树化后的函数依然可以像原函数一样地使用
  2. 树化的函数本身不会对传入的树对象内部结构有显式的限制,在函数的树化逻辑中将基于传入树参数的结构生成最终的返回值结构。
  3. 函数的树化逻辑部分不会对树对象内部的值进行任何的判定与检测,只是作为一个中继器将对应的值传入原函数并获取运算结果

树化函数运行机制

通过开头章节的简单例子展示,相信各位已经对函数的树化有了基本的概念和了解。在本章中,将对函数的树化过程进行更加详细的机制分析。

机制概述

在开头章节的例子中,展现的只是两种最为理想化的情况:

  1. 传入的参数均为非树对象
  2. 传入的参数均为结构完全一致的树对象

然而实际上,基于对“树”这一数据结构的基本了解,不难发现实际上需要作出处理的情况依然有很多,包括但不限于:

  • 键值缺少——参与计算的某个树对象在对应的位置上缺少了对应的键值,这样的情况如何处理?例如下图中, t2.x.d 缺失,这样的情况该如何处理?

  • 键值类型不匹配——参与计算的某几个树对象对应位置上,有些是叶子节点值,有些是非叶子节点子树,形成“值-子树”之间的直接运算,这样的情况如何定义?例如下图中, t1.b 为子树但是 t2.b 为值,这样的情况如何定义?

  • 计算模式多样性——当参与计算的树对象之间的结构存在较多较大差异性时,如何设计计算策略使之能支持更多样化的计算?例如下列的场景,如何组织对如此结构各异的树之间的运算?

  • 数据格式多样性——当参与计算的叶子节点值格式存在不统一时,如何处理?例如下面的场景,如何对 t1t2 下显然不同尺寸的 torch.Tensor 进行处理?

因此,基于这些很现实的问题,我们为树化函数定义了如下的选项:

  • 模式选项(mode)——决定树化函数的整体运行机制。
  • 继承选项(inherit)——对键值类型不匹配的情况进行了定义,并提供了处理机制。
  • 缺省选项(missing)——为键值缺少的情况提供了缺省值补全机制。

模式选项(mode)

模式选项是树化函数中最为重要的选项,其将直接决定树化函数的主体计算逻辑。目前定义了四种常用模式:

  • 严格模式(STRICT)
  • 内共同模式(INNER)
  • 外共有模式(OUTER)
  • 左优先模式(LEFT)

接下来的子章节中会结合例子进行逐一介绍。

严格模式(STRICT)

严格模式是最常用的模式选项,意味着当且仅当所有树参数在当前子树位置上的键一一对应时,会将其键值进行一一对应地代入计算,否则抛出异常。代码实现如下,与开头的例子等价,模式选项的默认值即为严格模式

from treevalue import func_treelize


@func_treelize(mode='strict')
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

在上述的树化gcd函数中,完整的计算机制如下图1所示, tr 为树化gcd的运算结果


(图1,t1、t2内的键值可以形成一一对应)

但是当出现如下所示的参数时,则应抛出异常,因为部分键存在缺失,无法形成一一对应。


(图2,t1.b与t1.x.c缺失,无法形成一一对应)

严格模式是一种最为常见的计算逻辑,适用于大部分常见情况,也是在业务逻辑上最为顺理成章的一种模式。但是对非规则结构下的计算则不能兼容,因此另外三种模式选项分别针对不同的情况来支持非规则结构下的计算。

内共同模式(INNER)

内共同模式下,仅会对全部树参数当前子树位置上均存在此键时,才会对将其键值进行一一对应地代入计算,而当此键值在某一树参数当前子树位置上存在缺失情况是,则会直接忽略该组键值。代码实现如下,将 mode 设置为 inner 即可

from treevalue import func_treelize


@func_treelize(mode='inner')
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

例如对图2所示的例子,在内共同模式下可以正常计算,如图3所示


(图3,t1.x.c和t2.b因为t2.x.c和t1.b的缺失而被忽略)

内共同模式会忽略无法形成对应的多余值,可以确保在几乎所有情况下均能得出计算结果而不会产生错误。但是会不可避免地造成部分信息丢失,而在一部分情况下这是不可接受的,因此请根据实际需求进行选择。

外共有模式(OUTER)

外共有模式下,只要在任意一个树参数的当前子树位置上存在此键值,则会将其进行代入计算。而对于缺失的值,则会使用缺省选项中设置的值或生成器进行获取并代入。代码实现如下,将 mode 设置为 outer 即可,并将缺省选项设置为值 1

from treevalue import func_treelize


@func_treelize(mode='outer', missing=1)
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

例如对图2所示的例子,在外共有模式下可以正常计算,如图4所示


(图4,t1.b和t1.x.c缺失,将使用缺省选项指定的默认值1)

外共有模式将会让所有的数值参与运算,但是在绝大部分情况下均依赖缺省选项的设置,因此在使用前请确保缺省选项的正确配置,以及业务逻辑上的自洽。

左优先模式(LEFT)

左优先模式下,参与运算的键值将以全部树参数中最左的一项为参考。其中最左的一项定义为,在python函数调用的位置参数(postional argument)中,如果存在树参数,则取最左的一项;如果不存在,则在函数调用的键值参数(key-word argument)红,取字典序最小的一项。代码实现如下,将 mode 设置为 left 即可,并将缺省选项设置为值 1

from treevalue import func_treelize


@func_treelize(mode='left', missing=1)
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

例如对于图2所示的 gcd(t1, t2) 例子中,在左优先模式下计算结果如下,如图5所示


(图5,t2.b因t1.b的缺失而被忽略,而t2.x.c取缺省值1)

而在 gcd(t2, t1) 例子中,左优先计算结果如下,如图6所示


(图6,t1.x.c因t2.x.c的缺失而被忽略,而t1.b取缺省值1)

左优先模式会按照最左树参数的结构来进行计算,生成的计算结果也将和最左的参数保持一致。但是与外共有模式类似,左优先模式在绝大部分情况下依赖缺省选项的配置,需要确保配置准确无误且自洽。此外,对于原本满足交换律的运算,经过左优先模式的树化后将会失去原有的交换律性质,这一点请务必留意。

继承选项(inherit)

继承选项可以通过普通值的继承机制,让树化函数在实际应用中使用起来更加简洁,也让树参数可以和普通参数在树化后的函数中被混用。在默认情况下,继承选项是处于开启状态的,即等价于如下的代码

from treevalue import func_treelize


@func_treelize(inherit=True)
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

因此,有如下的例子 gcd(t1, t2) ,其计算结果如图7所示


(图7,t2.x.c和t2.x.d继承t2.x的值6)

此外显而易见的是,也可以直接将非树值直接传入,和树参数混用,例如下面的例子 gcd(100, t1) ,其计算结果如图x所示


(图8,值100被完全继承并作为第一棵树的全部值)

而当继承选项被关闭时,则上述两个例子均会抛出异常,因为存在值和子树混用的情况。

从业务逻辑的角度来看,继承选项可以良好地适应大部分真实存在的值复用情况,且值和子树混用在大多数业务逻辑上也是有明确意义的。但是当混用在业务逻辑角度上意义不明且需要被显式地检测时,则建议关闭继承选项

缺省选项(missing)

缺省选项可以为部分键值存在缺失的情况提供一个值的补充,主要作用于外共有模式和左优先模式。我们可以通过 missing 参数直接提供值,如下所示

from treevalue import func_treelize, FastTreeValue

@func_treelize(mode='outer', missing=0)
def total(*args):
    return sum(args)

上述的加法函数计算例子如下, total(t1, t2, t3) 计算结果如下图9所示


(图9,缺省值0被全面用于填补空缺,并最终计算出了有效的总和)

此外考虑到有些情况下,直接使用值作为缺省值可能会存在公用同一个对象导致错误的情况,因此我们提供了通过传入生成函数来产生默认值的用法。可以通过 missing 参数传入值生成器,如下所示

from treevalue import func_treelize, FastTreeValue

@func_treelize(mode='outer', missing=lambda: [])
def append(arr: list, *args):
    for item in args:
        if item:
            arr.append(item)
    return arr

上述的列表追加值计算例子如下, append(t0, t1, t2, t3) 运算结果如下图10所示


(图10,每次缺省均会生成新的空列表)

通过缺省选项的有效配置,结合外共有模式和左优先模式,可以有效扩展树化函数对值缺省情况的处理能力。不过值得注意的是,缺省选项在严格模式下无法生效,因为当检测到键缺失时将会直接抛出异常;以及缺省模式在内共同模式下永远无法实质上生效,因此树化函数会针对这一情况抛出一个警告信息。

上升、下沉选项

除了上述的基本机制选项之外,树化函数还提供了上升(rise)和下沉(subside)选项,以简化对结构化数据的处理。两者的功能分别为:

  • 下沉(subside)——尝试将参数中顶层结构非树的对象,提取结构后将结构下沉至树内,使原函数在运行过程中可以接收到。关于下沉函数的具体细节可以参考之前文章
  • 上升(rise)——尝试从返回结果树的叶子节点值中提取共同结构,向上升至树外,使返回值的逻辑结构可以被外部直接访问。关于上升函数的具体细节可以参考之前文章

因此我们可以在需要的时候打开这两个选项,代码如下,实现的效果是从列表 arr 中查找首个满足条件值的位置( position ),并统计共有多少个满足条件的值( cnt

from treevalue import func_treelize, FastTreeValue


@func_treelize(subside=True, rise=True)
def check(arr: list, target):
    position = None
    cnt = 0
    for i, item in enumerate(arr):
        if target(item):
            if position is None:
                position = i
            cnt += 1

    return position, cnt


t1 = FastTreeValue({'a': 2, 'b': 4, 'x': {'c': 7, 'd': 9}})
t2 = FastTreeValue({'a': 4, 'b': 48, 'x': {'c': 2, 'd': 53}})
t3 = FastTreeValue({'a': 9, 'b': -12, 'x': {'c': 3, 'd': 7}})

tr1, tr2 = check([t1, t2, t3], lambda x: x % 2 == 0)

代码中可以看到三棵树 t1t2t3 可以直接用列表装载,在原函数 check 中可以接收到对应位置上的值列表。并且由于 rise 选项的开启,位置和数量所构成的二元组也会被提取出来,形成两棵树,即 tr1tr2 ,如下图11所示


(图11,[t1, t2, t3]作为列表参数,tr1, tr2作为返回值树)

此外,上升和下沉选项一个更加有效的使用例子是对 torch.splittorch.stack 函数进行装饰,代码如下所示

import torch

from treevalue import func_treelize, TreeValue

stack = func_treelize(subside=True)(torch.stack)
split = func_treelize(rise=True)(torch.split)

trees = [TreeValue({
    'a': torch.randn(2, 4),
    'b': torch.randn(3, 4),
    'x': {'c': torch.randn(2, 1, 3)}
}) for _ in range(10)]

st = stack(trees)  # stack all the trees together
splitted = split(st, [1] * 10)  # split back to trees

# splitted should be equal to trees

其中 st 即为合并后的树,而 splitted 为再次拆分后的树, splittedtrees 等价。

后续预告

本文主要针对treevalue的核心特性——树化函数,基于其自身进行了详细的原理解析,受限于篇幅,本次只着重讲述了原生树化函数本身的原理、特性以及例子。在下一篇中将会针对更多衍生场景进行分析与展示,敬请期待。

同时欢迎了解其他OpenDILab的开源项目:https://github.com/opendilab

posted @ 2021-11-05 21:26  HansBug  阅读(178)  评论(0编辑  收藏  举报