将MySQL结果集转换为NumPy数组的最高效方法是什么?
我正在使用MySQLdb和Python。我有一些基本的查询,比如这个:
c=db.cursor()
c.execute("SELECT id, rating from video")
results = c.fetchall()
我希望“结果”能变成一个NumPy数组,并且我想节省内存使用。逐行复制数据似乎会非常低效(这样会需要双倍的内存)。有没有更好的方法可以把MySQLdb的查询结果转换成NumPy数组格式呢?
我想使用NumPy数组格式的原因是,我希望能够轻松地对数据进行切片和处理,而在这方面,Python似乎对多维数组不太友好。
e.g. b = a[a[:,2]==1]
谢谢!
3 个回答
NumPy的fromiter方法在这里看起来是最合适的(就像Keith的回答中提到的那样)。
用fromiter把MySQLdb游标方法返回的结果集转换成NumPy数组很简单,但有几个细节可能值得一提。
import numpy as NP
import MySQLdb as SQL
cxn = SQL.connect('localhost', 'some_user', 'their_password', 'db_name')
c = cxn.cursor()
c.execute('SELECT id, ratings from video')
# fetchall() returns a nested tuple (one tuple for each table row)
results = cursor.fetchall()
# 'num_rows' needed to reshape the 1D NumPy array returend by 'fromiter'
# in other words, to restore original dimensions of the results set
num_rows = int(c.rowcount)
# recast this nested tuple to a python list and flatten it so it's a proper iterable:
x = map(list, list(results)) # change the type
x = sum(x, []) # flatten
# D is a 1D NumPy array
D = NP.fromiter(iterable=x, dtype=float, count=-1)
# 'restore' the original dimensions of the result set:
D = D.reshape(num_rows, -1)
需要注意的是,fromiter返回的是一个1D的NumPy数组,
(这很合理,因为你可以用fromiter只返回MySQL表中某一行的一部分,通过传递一个count参数来实现)。
不过,你需要把它恢复成2D的形状,所以要调用游标方法rowcount,然后在最后一行调用reshape。
最后,count参数的默认值是'-1',这意味着会获取整个可迭代对象。
这个解决方案使用了Kieth的fromiter技巧,但更直观地处理了SQL结果的二维表结构。同时,它还改进了Doug的方法,避免了在Python数据类型中进行所有的重塑和扁平化。通过使用结构化数组,我们几乎可以直接从MySQL结果读取到numpy中,几乎不需要使用Python的数据类型。我说“几乎”是因为fetchall迭代器仍然会生成Python元组。
不过有一个小注意事项,但这并不算大问题。你必须提前知道列的数据类型和行的数量。
知道列的类型应该很简单,因为你应该知道查询是什么,否则你可以使用curs.description和MySQLdb.FIELD_TYPE.*常量的映射。
知道行数意味着你需要使用客户端游标(这是默认的)。我对MySQLdb和MySQL客户端库的内部工作原理了解不多,但我理解的是,当使用客户端游标时,整个结果会被提取到客户端内存中,虽然我怀疑实际上会涉及一些缓冲和缓存。这意味着结果会占用双倍的内存,一次是游标的副本,一次是数组的副本,所以如果结果集很大,尽快关闭游标以释放内存可能是个好主意。
严格来说,你不必提前提供行数,但这样做意味着数组的内存会提前一次性分配,而不是随着迭代器中更多行的到来而不断调整大小,这样可以大大提高性能。
接下来是一些代码
import MySQLdb
import numpy
conn = MySQLdb.connect(host='localhost', user='bob', passwd='mypasswd', db='bigdb')
curs = conn.cursor() #Use a client side cursor so you can access curs.rowcount
numrows = curs.execute("SELECT id, rating FROM video")
#curs.fetchall() is the iterator as per Kieth's answer
#count=numrows means advance allocation
#dtype='i4,i4' means two columns, both 4 byte (32 bit) integers
A = numpy.fromiter(curs.fetchall(), count=numrows, dtype=('i4,i4'))
print A #output entire array
ids = A['f0'] #ids = an array of the first column
#(strictly speaking it's a field not column)
ratings = A['f1'] #ratings is an array of the second colum
有关dtype的详细信息以及如何指定列数据类型和列名称,请查看numpy文档和上面关于结构化数组的链接。
fetchall
方法实际上返回的是一个迭代器,而 numpy 有一个叫 fromiter 的方法,可以用来从迭代器创建一个数组。所以,根据表格里的数据,你可以很方便地把这两者结合起来,或者使用一个适配器生成器。