线性代数后端的通用接口

backends的Python项目详细描述


LAB

BuildCoverage StatusLatest Docs

线性代数后端的通用接口:编写一次代码,在任何 后端

注意:实验室需要TensorFlow 2。

安装

在安装包之前,请确保gccgfortran 可用。 在os x上,这两个都是用brew install gcc安装的; 水蟒的使用者可能会考虑conda install gcc。 然后简单地

pip install backends

基本用法

包的基本用例是编写自动 根据其参数的类型确定要使用的后端。

示例:

importlabasBimportlab.torch# Load the PyTorch extension.importlab.tensorflow# Load the TensorFlow extension.defobjective(matrix):outer_product=B.matmul(matrix,matrix,tr_b=True)returnB.mean(outer_product)

默认情况下,不会加载pytorch和tensorflow扩展来保存 启动时间。或者,可以直接import lab.torch as Bimport lab.tensorflow as B

用numpy和autograd运行它:

>>>importautograd.numpyasnp>>>objective(B.randn(np.float64,2,2))0.15772589216756833

用TensorFlow运行它:

>>>importtensorflowastf>>>objective(B.randn(tf.float64,2,2))<tf.Tensor'Mean:0'shape=()dtype=float64>

用pytorch运行它:

>>>importtorch>>>objective(B.randn(torch.float64,2,2))tensor(1.9557,dtype=torch.float64)

类型列表

本节列出了所有可用类型,可用于检查 对象或扩展函数。

示例:

>>>importlabasB>>>fromplumimportList,Tuple>>>importnumpyasnp>>>isinstance([1.,np.array([1.,2.])],List(B.NPNumeric))True>>>isinstance([1.,np.array([1.,2.])],List(B.TFNumeric))False>>>importtensorflowastf>>>importlab.tensorflow>>>isinstance((tf.constant(1.),tf.ones(5)),Tuple(B.TFNumeric))True

一般

Int          # Integers
Float        # Floating-point numbers
Bool         # Booleans
Number       # Numbers
Numeric      # Numerical objects, including booleans
DType        # Data type
Framework    # Anything accepted by supported frameworks

纽比

NPNumeric
NPDType
 
NP           # Anything NumPy

张量流

TFNumeric
TFDType
 
TF           # Anything TensorFlow

火把

TorchNumeric
TorchDType
 
Torch        # Anything PyTorch

方法列表

本节列出了所有可用的常量和方法。

  • 参数必须作为参数给出,关键字参数必须作为参数 作为关键字参数给定。 例如,sum(tensor, axis=1)是有效的,但是sum(tensor, 1)不是。

  • 参数的名称表示其函数:

    • abc表示一般张量。
    • dtype表示数据类型。例如,np.float32tf.float64;和 rand(np.float32)创建一个NumPy随机数,而 rand(tf.float64)创建tensorflow随机数。 数据类型总是作为第一个参数给出。
    • shape表示形状。 形状的尺寸总是作为 功能。 例如,reshape(tensor, 2, 2)是有效的,但是reshape(tensor, (2, 2)) 不是。
    • {
    • ref表示一个引用张量 将使用形状和数据类型。例如,zeros(tensor)创建 与tensor形状和数据类型相同的满零张量。

有关每个功能的详细说明,请参阅文档。

特殊变量

default_dtype  # Default data type.
epsilon        # Magnitude of diagonal to regularise matrices with.

常数

nan
pi
log_2_pi

isnan(a)

一般

zeros(dtype, *shape)
zeros(*shape)
zeros(ref)

ones(dtype, *shape)
ones(*shape)
ones(ref)

eye(dtype, *shape)
eye(*shape)
eye(ref)

linspace(dtype, a, b, num)
linspace(a, b, num)

range(dtype, start, stop, step)
range(dtype, stop)
range(dtype, start, stop)
range(start, stop, step)
range(start, stop)
range(stop)

cast(dtype, a)

identity(a)
abs(a)
sign(a)
sqrt(a)
exp(a)
log(a)
sin(a)
cos(a)
tan(a)
sigmoid(a)
softplus(a)
relu(a)

add(a, b)
subtract(a, b)
multiply(a, b)
divide(a, b)
power(a, b)
minimum(a, b)
maximum(a, b)
leaky_relu(a, alpha)

min(a, axis=None)
max(a, axis=None)
sum(a, axis=None)
mean(a, axis=None)
std(a, axis=None)
logsumexp(a, axis=None)

all(a, axis=None)
any(a, axis=None)

lt(a, b)
le(a, b)
gt(a, b)
ge(a, b)

bvn_cdf(a, b, c)

scan(f, xs, *init_state)

sort(a, axis=-1, descending=False)
argsort(a, axis=-1, descending=False)

线性代数

transpose(a, perm=None) (alias: t, T)
matmul(a, b, tr_a=False, tr_b=False) (alias: mm, dot)
trace(a, axis1=0, axis2=1)
kron(a, b)
svd(a, compute_uv=True)
solve(a, b)
inv(a)
det(a) 
logdet(a) 
cholesky(a) (alias: chol)

cholesky_solve(a, b)  (alias: cholsolve)
triangular_solve(a, b, lower_a=True) (alias: trisolve)
toeplitz_solve(a, b, c) (alias: toepsolve)
toeplitz_solve(a, c)

outer(a, b)
reg(a, diag=None, clip=True)

pw_dists2(a, b)
pw_dists2(a)
pw_dists(a, b)
pw_dists(a)

ew_dists2(a, b)
ew_dists2(a)
ew_dists(a, b)
ew_dists(a)

pw_sums2(a, b)
pw_sums2(a)
pw_sums(a, b)
pw_sums(a)

ew_sums2(a, b)
ew_sums2(a)
ew_sums(a, b)
ew_sums(a)

随机

set_random_seed(seed) 

rand(dtype, *shape)
rand(*shape)
rand(ref)

randn(dtype, *shape)
randn(*shape)
randn(ref)

choice(a, n)
choice(a)

成形

shape(a)
rank(a)
length(a) (alias: size)
isscalar(a)
expand_dims(a, axis=0)
squeeze(a)
uprank(a)

diag(a)
flatten(a)
vec_to_tril(a)
tril_to_vec(a)
stack(*elements, axis=0)
unstack(a, axis=0)
reshape(a, *shape)
concat(*elements, axis=0)
concat2d(*rows)
tile(a, *repeats)
take(a, indices, axis=0)

欢迎加入QQ群-->: 979659372 Python中文网_新手群

推荐PyPI第三方库


热门话题
java Clojure关键字在内存中的大小是多少?   Java中有固定长度的通用数组对象吗?   PostgreSQL:通过Java更新我的用户表   错误:使用java解析xml   java Json显示列表中对象的名称   java比较JodaTime时区   与JAVA中的API和包的区别?   java的int值在for循环中不改变   谷歌应用引擎中的java RSA   迁移到spring 5后出现java非法字符错误   java Websphere管理控制台不工作   JavaGSON如何始终在json中包含毫秒?   带有空格和双引号的windows Java ProcessBuilder命令参数失败   java错误:重复的zip条目[43.jar:org/apache/http/annotation/NotThreadSafe.class]