在一个操作中从十位数

2024-04-26 06:35:41 发布

您现在位置:Python中文网/ 问答频道 /正文

在最近的TensorFlow(1.132.0)中,有没有一种方法可以在一个过程中从张量中提取非连续切片?怎么做? 例如,使用以下张量:

1 2 3 4
5 6 7 8 

我想在一个操作中提取列1和列3以获得:

2 4
6 8

然而,似乎我不能在一个单一的行动与切片。 正确/最快/最优雅的方法是什么?你知道吗


Tags: 方法过程tensorflow切片
2条回答

1.使用tf.gather(tensor, columns, axis=1)TF1.xTF2):

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3]

print(tf.gather(tensor, columns, axis=1).numpy())
%timeit -n 10000 tf.gather(tensor, columns, axis=1)
# [[2. 4.]
#  [6. 8.]]
82.6 µs ± 5.76 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

2.带索引(TF1.xTF2):

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # < columns you want to extract

transposed = tf.transpose(tensor)
sliced = [transposed[c] for c in columns]
stacked = tf.transpose(tf.stack(sliced, axis=0))
# print(stacked.numpy()) # <  TF2, TF1.x-eager

with tf.Session() as sess:  # <  TF1.x
    print(sess.run(stacked))
# [[2. 4.]
#  [6. 8.]]

将其包装到函数并在tf.__version__=='2.0.0-alpha0'中运行%timeit

154 µs ± 2.61 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

@tf.function装饰它要快2倍多:

import tensorflow as tf
tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # < columns you want to extract
@tf.function
def extract_columns(tensor=tensor, columns=columns):
    transposed = tf.transpose(tensor)
    sliced = [transposed[c] for c in columns]
    stacked = tf.transpose(tf.stack(sliced, axis=0))
    return stacked

%timeit -n 10000 extract_columns()
66.8 µs ± 2.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

3.一行代码用于急切执行TF2TF1.x-eager):

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # < columns you want to extract

res = tf.transpose(tf.stack([t for i, t in enumerate(tf.transpose(tensor))
                             if i in columns], 0))
print(res.numpy())
# [[2. 4.]
#  [6. 8.]]

%timeittf.__version__=='2.0.0-alpha0'中:

242 µs ± 2.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

4.使用tf.one_hot()指定行/列,然后tf.boolean_mask()提取这些行/列(TF1.xTF2):

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # < columns you want to extract

mask = tf.one_hot(columns, tensor.get_shape().as_list()[-1])
mask = tf.reduce_sum(mask, axis=0)
res = tf.transpose(tf.boolean_mask(tf.transpose(tensor), mask))
# print(res.numpy()) # <  TF2, TF1.x-eager

with tf.Session() as sess: # TF1.x
    print(sess.run(res))
# [[2. 4.]
#  [6. 8.]]

%timeittf.__version__=='2.0.0-alpha0'中:

494 µs ± 4.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

您可以使用整形和切片的组合获得所有奇数列:

N = 4
M = 10
input = tf.constant(np.random.rand(M, N))
slice_odd = tf.reshape(tf.reshape(input, (-1, 2))[:,1], (-1, int(N/2)))

相关问题 更多 >