如何实现"where"(numpy.where(...))?
我刚接触函数式编程,想知道怎么在Python、Scala或Haskell中实现numpy.where()这个功能。如果能有个简单明了的解释就太好了。
3 个回答
这里有两种使用场景。第一种是你有两个数组,第二种是你只有一个数组。
在第一种情况下,使用 numpy.where(cond)
,你会得到一个索引列表,这些索引对应的条件数组是满足条件的。在Scala中,你通常会这样做:
(cond, cond.indices).zipped.filter((c,_) => c)._2
这显然没有那么简洁,但在Scala中,这不是人们通常使用的基本操作(因为它的构建方式不同,比如不太强调索引)。
在第二种情况下,使用 numpy.where(cond,x,y)
,你会根据条件 cond
的真假来得到 x
或 y
。如果 cond
为真,就得到 x
,如果为假,就得到 y
。在Scala中,
(cond, x, y).zipped.map((c,tx,ty) => if (c) tx else ty)
也能完成同样的操作(虽然不那么简洁,但同样,这通常也不是基本操作)。需要注意的是,在Scala中,你可以更方便地将 cond
设为一个方法,这个方法会测试 x
和 y
,并返回真或假,然后你就可以这样做:
(x, y).zipped.map((tx,ty) => if (c(tx,ty)) tx else ty)
(虽然通常即使简洁起见,你也会把数组命名为 xs
和 ys
,而单个元素命名为 x
和 y
)。
在Python中,可以通过 numpy.where.__doc__
来查看相关文档:
If `x` and `y` are given and input arrays are 1-D, `where` is
equivalent to::
[xv if c else yv for (c,xv,yv) in zip(condition,x,y)]
在Haskell中,要处理n维列表,就像NumPy那样,需要比较复杂的类型类构建,但一维的情况就简单多了:
select :: [Bool] -> [a] -> [a] -> [a]
select [] [] [] = []
select (True:bs) (x:xs) (_:ys) = x : select bs xs ys
select (False:bs) (_:xs) (y:ys) = y : select bs xs ys
这只是一个简单的递归过程,逐个检查每个列表中的每个元素,当所有列表都到达末尾时,返回一个空列表。(注意,这里说的是列表,不是数组。)
这里有一个更简单但不太明显的一维列表实现,翻译自NumPy文档(感谢joaquin的提醒):
select :: [Bool] -> [a] -> [a] -> [a]
select bs xs ys = zipWith3 select' bs xs ys
where select' True x _ = x
select' False _ y = y
为了实现两个参数的情况(返回条件为真的所有索引;感谢Rex Kerr的提醒),可以使用列表推导式:
trueIndices :: [Bool] -> [Int]
trueIndices bs = [i | (i,True) <- zip [0..] bs]
也可以用现有的select
来写,虽然这样做意义不大:
trueIndices :: [Bool] -> [Int]
trueIndices bs = catMaybes $ select bs (map Just [0..]) (repeat Nothing)
这里是n维列表的三参数版本:
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances #-}
class Select bs as where
select :: bs -> as -> as -> as
instance Select Bool a where
select True x _ = x
select False _ y = y
instance (Select bs as) => Select [bs] [as] where
select = zipWith3 select
下面是一个例子:
GHCi> select [[True, False], [False, True]] [[0,1],[2,3]] [[4,5],[6,7]]
[[0,5],[6,3]]
不过在实际应用中,你可能更想使用一个合适的n维数组类型。如果你只是想在一个特定的n上使用select
,luqui在这个答案的评论中给出的建议更好:
在实际操作中,我会用
(zipWith3.zipWith3.zipWith3) select' bs xs ys
(针对三维情况)。
(随着n的增加,可以添加更多的zipWith3
组合。)