Efficiently returning insertion-point indices with a fractional component using NumPy

Posted 2024-04-24 04:29:38

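I want to efficiently compute the indices at which elements should be inserted into an array to maintain order, but including a fractional component that represents the "distance" between the two nearest points in the array.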

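It should be possible to recover the original value from the index and the fraction. In practice, and this is why performance matters, I will need to do this for a large number of data points.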
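A minimal vectorized sketch of the inverse mapping, assuming `idx` and `frac` come from `get_position`/`run_multiple` (so `frac` is 0 at both ends of `line`). The helper name `reconstruct` is hypothetical. Clamping `idx + 1` avoids the out-of-bounds read that the single-point check `line[idx + 1]` would hit when `idx == len(line) - 1`. Note that points outside `[line[0], line[-1]]` come back as the clamped end value, not the original value.

```python
import numpy as np

def reconstruct(line, idx, frac):
    """Hypothetical helper: map (index, fraction) pairs back to values.

    Assumes `line` is sorted, `idx` holds valid indices into `line`,
    and `frac` is 0 wherever idx == len(line) - 1 (as get_position returns).
    """
    line = np.asarray(line, dtype=float)
    idx = np.asarray(idx, dtype=np.intp)
    frac = np.asarray(frac, dtype=float)
    # Clamp the right neighbour so the last index does not read past the end;
    # the difference is then 0, which is harmless because frac is 0 there too.
    nxt = np.minimum(idx + 1, line.shape[0] - 1)
    return line[idx] + frac * (line[nxt] - line[idx])

line = np.linspace(-500, 500, 101)
values = reconstruct(line, np.array([22, 0, 100]),
                     np.array([0.5593248036337627, 0.0, 0.0]))
print(values)
```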

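To show what I mean, I have put together some working logic using np.searchsorted and a few if statements, but I haven't been able to vectorize it with NumPy. I'd also be happy to see an efficient Numba solution with performance comparable to or better than NumPy, or even a ready-made solution in NumPy, SciPy, etc. that I'm not aware of.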

I've also included some benchmark code below.

import numpy as np

np.random.seed(0)

datapoint = np.random.random() * np.random.choice([1, -1]) * 500  # -274.4067
line = np.linspace(-500, 500, 101)  # [-500, -490, ... , 0, ..., 490, 500] - an ordered array, may not be linspace

def get_position(line, point):
    position = np.searchsorted(line, point, side='right')
    size = line.shape[0]
    if position == 0:
        main = 0
        fraction = 0
    elif position == size:
        main = size-1
        fraction = 0
    else:
        main = position - 1
        fraction = (point - line[position-1]) / (line[position] - line[position-1])
    return main, fraction

idx, frac = get_position(line, datapoint)              # (22, 0.55932480363376269)
print(line[idx] + frac * (line[idx + 1] - line[idx]))  # -274.4067; test to see if you get back original value

def run_multiple(line, data):
    out = np.empty((data.shape[0], 3))
    for i in range(data.shape[0]):
        idx, frac = get_position(line, data[i])
        out[i, 0] = data[i]
        out[i, 1] = idx
        out[i, 2] = frac
    return out

Benchmarking

# Python 3.6.0, NumPy 1.11.3, Numba 0.30.1
# Note: Numba 0.30.1 does not support "side" argument of np.searchsorted; not able to upgrade

n = 10**5  # Actual n will be larger
res = run_multiple(line, np.random.random(n) * np.random.choice([1, -1], n) * 500)  # 901 ms per loop

# array([[ -4.22132874e+02,   7.00000000e+00,   7.86712571e-01],
#        [ -4.28972809e+02,   7.00000000e+00,   1.02719119e-01],
#        [  4.23625869e+02,   9.20000000e+01,   3.62586939e-01],
#        ..., 
#        [ -1.88627877e+02,   3.10000000e+01,   1.37212282e-01],
#        [  4.98162640e+01,   5.40000000e+01,   9.81626397e-01],
#        [  1.35777097e+02,   6.30000000e+01,   5.77709684e-01]])

2 answers

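If Numba (or the version you're using) doesn't support a function, it's always a good idea to look at the Numba source code and see what is already implemented. Often at least part of the problem has already been solved there.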

Code

import numpy as np
import numba as nb

#almost copied from Numba source
#https://github.com/numba/numba/blob/master/numba/targets/arraymath.py
"""Copyright (c) 2012, Anaconda, Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.

Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
@nb.njit()
def searchsorted_right(a, v):
    n = len(a)
    if np.isnan(v):
        # Find the first nan (i.e. the last from the end of a,
        # since there shouldn't be many of them in practice)
        for i in range(n, 0, -1):
            if not np.isnan(a[i - 1]):
                return i
        return 0
    lo = 0
    hi = n
    while hi > lo:
        mid = (lo + hi) >> 1
        if a[mid]<= v:
            # mid is too low => go up
            lo = mid + 1
        else:
            # mid is too high, or is a NaN => go down
            hi = mid
    return lo

@nb.njit()
def get_position(line, point):
    position = searchsorted_right(line, point)
    size = line.shape[0]
    if position == 0:
        main = 0
        fraction = 0
    elif position == size:
        main = size-1
        fraction = 0
    else:
        main = position - 1
        fraction = (point - line[position-1]) / (line[position] - line[position-1])
    return main, fraction

@nb.njit(parallel=True)
def run_multiple(line, data):
    out = np.empty((data.shape[0], 3))
    for i in nb.prange(data.shape[0]):
        idx, frac = get_position(line, data[i])
        out[i, 0] = data[i]
        out[i, 1] = idx
        out[i, 2] = frac
    return out
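The copied `searchsorted_right` is a standard upper-bound binary search, which for non-NaN inputs is what Python's `bisect.bisect_right` also computes. A quick cross-check against `np.searchsorted(..., side='right')` can catch copy errors before jitting. This sketch uses a hypothetical pure-Python copy, `searchsorted_right_py`, without the NaN branch and without Numba; the same comparison should work on the jitted function.

```python
import bisect
import numpy as np

def searchsorted_right_py(a, v):
    # Pure-Python copy of the upper-bound binary search used in
    # searchsorted_right, minus the NaN branch (assumes v is not NaN).
    lo, hi = 0, len(a)
    while hi > lo:
        mid = (lo + hi) >> 1
        if a[mid] <= v:
            lo = mid + 1  # a[mid] <= v: insertion point is right of mid
        else:
            hi = mid      # a[mid] > v: insertion point is mid or left of it
    return lo

np.random.seed(0)
line = np.linspace(-500, 500, 101)
line_list = line.tolist()
# Random points plus exact grid values and far out-of-range values,
# so ties (side='right') and both edges are exercised.
probes = np.concatenate([np.random.uniform(-600, 600, 1000), line, [-1e9, 1e9]])

expected = np.searchsorted(line, probes, side='right')
mine = np.array([searchsorted_right_py(line_list, v) for v in probes])
via_bisect = np.array([bisect.bisect_right(line_list, v) for v in probes])
print((mine == expected).all(), (via_bisect == expected).all())
```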

Timings

n = 10**5
line = np.linspace(-500, 500, 101)
points = np.random.random(n) * np.random.choice([1, -1], n) * 500

%timeit run_multiple(line, points)
#1.08 ms ± 14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

# frac() from @user3483203's answer
%timeit frac(line, points)
#8.65 ms ± 266 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

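To vectorize this, I would mask out the edge cases and deal with them at the end. In fact, you only need to handle the position == size condition: for the low condition (position == 0) the index and fraction columns are simply zero, which the out array, initialized with np.zeros, already satisfies.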

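def frac(line, points):
    # Insertion points, same convention as the loop version (side='right')
    pos = np.searchsorted(line, points, side='right')
    low = pos == 0                    # points below line[0]
    high = pos == line.shape[0]       # points at or above line[-1]
    m = ~(low | high)                 # interior points
    ii = points[m]
    jj = pos[m]
    fractions = (ii - line[jj-1]) / (line[jj] - line[jj-1])
    out = np.zeros((points.shape[0], 3))
    out[:, 0] = points
    out[m, 1] = jj - 1
    out[m, 2] = fractions
    out[high, 1] = line.shape[0] - 1  # low rows stay (0, 0) from np.zeros
    return out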
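A variation on the same idea, not taken from the answer: clamp the bracket index instead of boolean-indexing the interior points, compute a fraction for every point, then zero it at the edges with `np.where`. The name `frac_clipped` is hypothetical, and the sketch assumes `line` is strictly increasing with at least two elements (otherwise the edge brackets can divide by zero). Whether it beats the masked version depends on the data and NumPy version, so time both.

```python
import numpy as np

def frac_clipped(line, points):
    # Hypothetical variant; assumes `line` is strictly increasing, len >= 2.
    size = line.shape[0]
    pos = np.searchsorted(line, points, side='right')
    # Left bracket index, clamped so (idx, idx + 1) is always a valid pair.
    idx = np.clip(pos - 1, 0, size - 2)
    f = (points - line[idx]) / (line[idx + 1] - line[idx])
    edge = (pos == 0) | (pos == size)
    f = np.where(edge, 0.0, f)                  # both edges get fraction 0
    idx = np.where(pos == size, size - 1, idx)  # top edge maps to last index
    return np.column_stack((points, idx, f))

line = np.linspace(-500, 500, 101)
pts = np.array([-600.0, -500.0, -274.4067519636624, 0.0, 499.0, 500.0, 700.0])
print(frac_clipped(line, pts))
```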

Benchmark

n = 10**5
line = np.linspace(-500, 500, 101)
points = np.random.random(n) * np.random.choice([1, -1], n) * 500

In [5]: %timeit run_multiple(line, points)
1.23 s ± 53.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [7]: %timeit frac(line, points)
13.4 ms ± 290 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [8]: np.allclose(frac(line, points), run_multiple(line, points))
Out[8]: True
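On the question's request for a ready-made routine: one option not mentioned in the thread is `np.interp`, interpolating each point against the index array `0, 1, ..., n-1`. Because `np.interp` clamps to the end values, the result is `idx + frac` with the same edge behaviour as `get_position`, and floor/subtract splits it back into the two columns. This is a sketch assuming `line` is strictly increasing; the name `frac_interp` is hypothetical. In rare cases a point extremely close to the next grid value may round up to that integer index with fraction 0, so compare against the searchsorted versions on your own data (and time it) before relying on it.

```python
import numpy as np

def frac_interp(line, points):
    # Hypothetical helper; assumes `line` is strictly increasing.
    # Interpolating against 0, 1, ..., n-1 yields a fractional index;
    # np.interp clamps points outside [line[0], line[-1]] to 0 and n - 1.
    pos = np.interp(points, line, np.arange(line.shape[0], dtype=float))
    idx = np.floor(pos)
    return np.column_stack((points, idx, pos - idx))

line = np.linspace(-500, 500, 101)
pts = np.array([-600.0, -500.0, -274.4067519636624, 0.0, 499.0, 500.0, 700.0])
print(frac_interp(line, pts))
```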
