从文件加载数据集以用于sklearn/numpy,包括标签

7 投票
1 回答
23032 浏览
提问于 2025-04-17 17:17

我看到在sklearn这个库里,我们可以使用一些预定义的数据集,比如用 mydataset = datasets.load_digits() 这个命令,就能得到一个数组(可能是numpy数组?)存放数据集的内容,使用 mydataset.data 可以获取这个数组,而 mydataset.target 则是对应的标签数组。不过我想加载我自己的数据集,以便能在sklearn中使用。那我该怎么做,数据需要什么格式呢?我的文件格式是这样的(每一行代表一个数据点):

-0.2080,0.3480,0.3280,0.5040,0.9320,1.0000,label1
-0.2864,0.1992,0.2822,0.4398,0.7012,0.7800,label3
...
...
-0.2348,0.3826,0.6142,0.7492,0.0546,-0.4020,label2
-0.1856,0.3592,0.7126,0.7366,0.3414,0.1018,label1

1 个回答

14

你可以使用numpy的genfromtxt函数从文件中获取数据(http://docs.scipy.org/doc/numpy/reference/generated/numpy.genfromtxt.html

import numpy as np
mydata = np.genfromtxt(filename, delimiter=",")

不过,如果你的数据里有文本列,使用genfromtxt就会比较麻烦,因为你需要指定数据的类型。

用优秀的Pandas库来处理这些数据会简单得多(http://pandas.pydata.org/

import pandas as pd
mydata = pd.read_csv(filename)
target = mydata["Label"]  #provided your csv has header row, and the label column is named "Label"

#select all but the last column as data
data = mydata.ix[:,:-1]

撰写回答