生成TensorFlow自定义运算符样板
tfopgen的Python项目详细描述
编写tensorflow运算符需要编写大量 样板C++和CUDA代码。此脚本为CPU生成代码 和GPU版本的TensorFlow运算符。更具体地说,给定 tensorflowinputs、outputs和attributes,它生成:
- 定义操作符类的C++Head文件,在设备上模板化。
- C++Head文件,定义操作员的CPU实现。
- 具有形状函数的C++源文件,RealStudioP和 注册内核生成器构造。
- 定义运算符的GPU实现的CUDA头, 包括一个cuda内核。
- 具有GPU寄存器内核生成器的CUDA源文件 接线员。
- python单元测试用例,它构造随机输入数据,并调用 接线员。
- 用于将运算符编译到共享库中的makefile,使用g++ 和NVCC。
要求
建造操作员所需的Tensorflow装置。
pip install tensorflow
安装
pip install tfopgen
用法
用户应提供定义操作员的yaml配置文件:
- 输入和可选的形状。
- 输出和可选的输出。
- 多态类型属性。
- 其他属性。
- 文件。
例如,我们可以在complex_phase.yml文件中定义ComplexPhase运算符的大纲。
---project:astronomylibrary:fouriername:ComplexPhasetype_attrs:-"FT:{float,double}=DT_FLOAT"-"CT:{complex64,complex128}=DT_COMPLEX64"inputs:-["uvw:FT",[null,null,3]]# (ntime, nbl, 3)-["frequency:FT",[null]]# (nchan, )-["lm:FT",[null,2]]# (nsrc, 2)outputs:-["complex_phase:CT",[null,null,null,null]]doc:>Given tensors(1) of (U, V, W) baseline coordinates with shape (ntime, nbl, 3)(2) of (L, M) sky coordinates with shape (nsrc, 2)(3) of frequencies,compute the complex phase with shape (nsrc, ntime, nbl, nchan)
然后我们可以运行:
$ tfopgen complex_phase.yml
创建以下目录结构和文件:
$ tree fourier/ fourier/ ├── complex_phase_op_cpu.cpp ├── complex_phase_op_cpu.h ├── complex_phase_op_gpu.cu ├── complex_phase_op_gpu.cuh ├── complex_phase_op.h ├── Makefile └── test_complex_phase.py
^ {TT6}$和^ {TT7}$选项在内部指定C++命名空间 创建了运算符。此外,makefile将创建 fourier.so可以用tf.load_op_library('fourier.so')加载的共享库。
应提供任何多态类型属性。发电机将 将类型属性上的运算符模板化。它还将产生 CPU和 使用类型属性(float, 下面是双人间、64间和128间:
type_attrs:-"FT:{float,double}=DT_FLOAT"-"CT:{complex64,complex128}=DT_COMPLEX64"
运算符输入及其可选形状应指定为 包含定义.Input指令的字符串的列表,以及 描述输入张量的形状。形状中的null值 将被转换为pythonNone。如果规定了具体尺寸, 将在与 接线员。
inputs:-["uvw:FT",[null,null,3]]# (ntime, nbl, 3)-["frequency:FT",[null]]# (nchan, )-["lm:FT",[null,2]]# (nsrc, 2)
运算符输出也应类似地定义。
outputs:-["complex_phase:CT",[null,null,null,null]]
考虑到这些输入和输出,cpu和gpu操作符是用 与输入和输出相对应的命名变量。此外,a 创建具有给定输入和输出的CUDA内核,以及 形状函数检查所提供输入的秩和维数。
可以指定其他属性(并将在 register-op)指令,但不能由 生成器代码作为属性行为的范围是复杂的。
op_other_attrs:-"iterations:int32>=2",
最后,还可提供操作员文件。
doc:>Given tensors(1) of (U, V, W) baseline coordinates with shape (ntime, nbl, 3)(2) of (L, M) sky coordinates with shape (nsrc, 2)(3) of frequencies,compute the complex phase with shape (nsrc, ntime, nbl, nchan)