把tensorflow放回pytorch,放回tensorflow(可微的tensorflow pytorch适配器)。
tfpyth的Python项目详细描述
tfpyth
把tensorflow放回pytorch,放回tensorflow(可微的tensorflow pytorch适配器)。
A light-weight differentiable adapter library to make TensorFlow and PyTorch interact.
安装
pip install tfpyth
示例
importtensorflowastfimporttorchasthimportnumpyasnpimporttfpythsession=tf.Session()defget_torch_function():a=tf.placeholder(tf.float32,name='a')b=tf.placeholder(tf.float32,name='b')c=3*a+4*b*bf=tfpyth.torch_from_tensorflow(session,[a,b],c).applyreturnff=get_torch_function()a=th.tensor(1,dtype=th.float32,requires_grad=True)b=th.tensor(3,dtype=th.float32,requires_grad=True)x=f(a,b)assertx==39.x.backward()assertnp.allclose((a.grad,b.grad),(3.,24.))
它有什么
torch_from_tensorflow
通过计算给定输入占位符的TensorFlow输出张量,创建可微的Pythorch函数。
eager_tensorflow_from_torch
从pytorch函数创建一个eager tensorflow函数。
tensorflow_from_torch
从pytorch函数创建tensorflow op/tensor。
未来工作
- []支持jax
- []支持高阶导数