将MySQL结果集转换为NumPy数组的最高效方法是什么?

17 投票
3 回答
26881 浏览
提问于 2025-04-16 23:36

我正在使用MySQLdb和Python。我有一些基本的查询,比如这个:

c=db.cursor()
c.execute("SELECT id, rating from video")
results = c.fetchall()

我希望“结果”能变成一个NumPy数组,并且我想节省内存使用。逐行复制数据似乎会非常低效(这样会需要双倍的内存)。有没有更好的方法可以把MySQLdb的查询结果转换成NumPy数组格式呢?

我想使用NumPy数组格式的原因是,我希望能够轻松地对数据进行切片和处理,而在这方面,Python似乎对多维数组不太友好。

e.g. b = a[a[:,2]==1] 

谢谢!

3 个回答

7

NumPy的fromiter方法在这里看起来是最合适的(就像Keith的回答中提到的那样)。

fromiter把MySQLdb游标方法返回的结果集转换成NumPy数组很简单,但有几个细节可能值得一提。

import numpy as NP
import MySQLdb as SQL

cxn = SQL.connect('localhost', 'some_user', 'their_password', 'db_name')
c = cxn.cursor()
c.execute('SELECT id, ratings from video')

# fetchall() returns a nested tuple (one tuple for each table row)
results = cursor.fetchall()

# 'num_rows' needed to reshape the 1D NumPy array returend by 'fromiter' 
# in other words, to restore original dimensions of the results set
num_rows = int(c.rowcount)

# recast this nested tuple to a python list and flatten it so it's a proper iterable:
x = map(list, list(results))              # change the type
x = sum(x, [])                            # flatten

# D is a 1D NumPy array
D = NP.fromiter(iterable=x, dtype=float, count=-1)  

# 'restore' the original dimensions of the result set:
D = D.reshape(num_rows, -1)

需要注意的是,fromiter返回的是一个1D的NumPy数组,

(这很合理,因为你可以用fromiter只返回MySQL表中某一行的一部分,通过传递一个count参数来实现)。

不过,你需要把它恢复成2D的形状,所以要调用游标方法rowcount,然后在最后一行调用reshape

最后,count参数的默认值是'-1',这意味着会获取整个可迭代对象。

25

这个解决方案使用了Kieth的fromiter技巧,但更直观地处理了SQL结果的二维表结构。同时,它还改进了Doug的方法,避免了在Python数据类型中进行所有的重塑和扁平化。通过使用结构化数组,我们几乎可以直接从MySQL结果读取到numpy中,几乎不需要使用Python的数据类型。我说“几乎”是因为fetchall迭代器仍然会生成Python元组。

不过有一个小注意事项,但这并不算大问题。你必须提前知道列的数据类型和行的数量。

知道列的类型应该很简单,因为你应该知道查询是什么,否则你可以使用curs.description和MySQLdb.FIELD_TYPE.*常量的映射。

知道行数意味着你需要使用客户端游标(这是默认的)。我对MySQLdb和MySQL客户端库的内部工作原理了解不多,但我理解的是,当使用客户端游标时,整个结果会被提取到客户端内存中,虽然我怀疑实际上会涉及一些缓冲和缓存。这意味着结果会占用双倍的内存,一次是游标的副本,一次是数组的副本,所以如果结果集很大,尽快关闭游标以释放内存可能是个好主意。

严格来说,你不必提前提供行数,但这样做意味着数组的内存会提前一次性分配,而不是随着迭代器中更多行的到来而不断调整大小,这样可以大大提高性能。

接下来是一些代码

import MySQLdb
import numpy

conn = MySQLdb.connect(host='localhost', user='bob', passwd='mypasswd', db='bigdb')
curs = conn.cursor() #Use a client side cursor so you can access curs.rowcount
numrows = curs.execute("SELECT id, rating FROM video")

#curs.fetchall() is the iterator as per Kieth's answer
#count=numrows means advance allocation
#dtype='i4,i4' means two columns, both 4 byte (32 bit) integers
A = numpy.fromiter(curs.fetchall(), count=numrows, dtype=('i4,i4'))

print A #output entire array
ids = A['f0'] #ids = an array of the first column
              #(strictly speaking it's a field not column)
ratings = A['f1'] #ratings is an array of the second colum

有关dtype的详细信息以及如何指定列数据类型和列名称,请查看numpy文档和上面关于结构化数组的链接。

16

fetchall 方法实际上返回的是一个迭代器,而 numpy 有一个叫 fromiter 的方法,可以用来从迭代器创建一个数组。所以,根据表格里的数据,你可以很方便地把这两者结合起来,或者使用一个适配器生成器。

撰写回答