如何实现"where"(numpy.where(...))?

2 投票
3 回答
893 浏览
提问于 2025-04-17 08:36

我刚接触函数式编程,想知道怎么在Python、Scala或Haskell中实现numpy.where()这个功能。如果能有个简单明了的解释就太好了。

3 个回答

3

这里有两种使用场景。第一种是你有两个数组,第二种是你只有一个数组。

在第一种情况下,使用 numpy.where(cond),你会得到一个索引列表,这些索引对应的条件数组是满足条件的。在Scala中,你通常会这样做:

(cond, cond.indices).zipped.filter((c,_) => c)._2

这显然没有那么简洁,但在Scala中,这不是人们通常使用的基本操作(因为它的构建方式不同,比如不太强调索引)。

在第二种情况下,使用 numpy.where(cond,x,y),你会根据条件 cond 的真假来得到 xy。如果 cond 为真,就得到 x,如果为假,就得到 y。在Scala中,

(cond, x, y).zipped.map((c,tx,ty) => if (c) tx else ty)

也能完成同样的操作(虽然不那么简洁,但同样,这通常也不是基本操作)。需要注意的是,在Scala中,你可以更方便地将 cond 设为一个方法,这个方法会测试 xy,并返回真或假,然后你就可以这样做:

(x, y).zipped.map((tx,ty) => if (c(tx,ty)) tx else ty)

(虽然通常即使简洁起见,你也会把数组命名为 xsys,而单个元素命名为 xy)。

5

在Python中,可以通过 numpy.where.__doc__ 来查看相关文档:

If `x` and `y` are given and input arrays are 1-D, `where` is
equivalent to::

    [xv if c else yv for (c,xv,yv) in zip(condition,x,y)]
6

在Haskell中,要处理n维列表,就像NumPy那样,需要比较复杂的类型类构建,但一维的情况就简单多了:

select :: [Bool] -> [a] -> [a] -> [a]
select [] [] [] = []
select (True:bs) (x:xs) (_:ys) = x : select bs xs ys
select (False:bs) (_:xs) (y:ys) = y : select bs xs ys

这只是一个简单的递归过程,逐个检查每个列表中的每个元素,当所有列表都到达末尾时,返回一个空列表。(注意,这里说的是列表,不是数组。)

这里有一个更简单但不太明显的一维列表实现,翻译自NumPy文档(感谢joaquin的提醒):

select :: [Bool] -> [a] -> [a] -> [a]
select bs xs ys = zipWith3 select' bs xs ys
  where select' True x _ = x
        select' False _ y = y

为了实现两个参数的情况(返回条件为真的所有索引;感谢Rex Kerr的提醒),可以使用列表推导式:

trueIndices :: [Bool] -> [Int]
trueIndices bs = [i | (i,True) <- zip [0..] bs]

也可以用现有的select来写,虽然这样做意义不大:

trueIndices :: [Bool] -> [Int]
trueIndices bs = catMaybes $ select bs (map Just [0..]) (repeat Nothing)

这里是n维列表的三参数版本:

{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances #-}

class Select bs as where
  select :: bs -> as -> as -> as

instance Select Bool a where
  select True x _ = x
  select False _ y = y

instance (Select bs as) => Select [bs] [as] where
  select = zipWith3 select

下面是一个例子:

GHCi> select [[True, False], [False, True]] [[0,1],[2,3]] [[4,5],[6,7]]
[[0,5],[6,3]]

不过在实际应用中,你可能更想使用一个合适的n维数组类型。如果你只是想在一个特定的n上使用select,luqui在这个答案的评论中给出的建议更好:

在实际操作中,我会用(zipWith3.zipWith3.zipWith3) select' bs xs ys(针对三维情况)。

(随着n的增加,可以添加更多的zipWith3组合。)

撰写回答