pytorch中的dispatcher

前置知识: virtual table in C++

对于每个opr,dispatcher构建了一个vtable(c++多态性相关概念)。dispatcher的工作就是根据输入的tensor和其他一些meta信息,计算dispatch key,然后根据vtable跳转到相应的函数
c++ virtual table : https://www.cnblogs.com/ijpq/p/16291824.html

与c++ vtable区别:

c++ vatable pytorch vtable 解释
每个类有一个vtable 每个opr有一个vtable 在pytorch中,扩展一个已有的opr,只需要提供一个新的vtable。
只有*this指针重要 不仅考虑tensor,还有其他的meta信息
支持boxing和unboxing

  • 每个tensor可以决定一套dispatch key
  • local include set, 与tensor无关的一些key, 例如tracing(目前还不知道tracing干啥的)
  • global set,是一些必须要设定的key。之前autograd是位于这个global set的,但是现在放进tensor了
  • local exclude set, 放一些已经dispatch过的key,这样就不会再set中出现,避免重复dispatch

一个autograd的dispatch过程(早期位于global set时的过程)
起始状态时,autograd位于global,exclude是空的。
执行dispatch过程,找到优先级最高的autograd key,调用autograd handler。在autograd过程中,创建了RAII AutoNonVariableTypeMode,它的作用是把autograd放入exclude中。
继续dispatch,跳过autograd,找到cpu key。
local TLS仍然处于call tree中,使得后续的dispatch操作会跳过autograd。
最后,从函数返回,RAII将autograd从exclude中删除.

一个tracing的过程

一个backendselect过程

函数指针是如何进入virtual table的?

通过registeration api实现的: https://pytorch.org/tutorials/advanced/dispatcher.html

与registeration api交互的三个方式:定义schema(m.def),给一个dispatch key实现register(m.impl), fallback(m.fallback)(这三个表述比较模糊,见下面的图)

给一个dispatch key实现register如下,图示给CPU这个key实现register,即将cpu_mul这个kernel注册到cpu这个key上

比如还可以给所有的key都注册同样的kernel。不过这种方式应该是不建议的

或者是给所有的opr都注册同样的fallback

这三种操作存在优先级顺序,如下

boxing/unboxing

boxed数据表示,是指各种类型的数据有同样的layout。这样的话可以给不同type的数据写相同的调用方法
unboxed数据表示,是指各种类型的数据有适合自己的layout,在cpp中,由于是unboxed数据表示,所以可以借助模板来写相同的调用代码。

在torch中实现了boxed数据表示,//TODO

posted @ 2022-05-20 16:18  ijpq  阅读(147)  评论(0编辑  收藏  举报