用python排序20个新闻组数据集

2024-04-20 10:26:05 发布

您现在位置:Python中文网/ 问答频道 /正文

在下面的代码中,我尝试对20个新闻组数据集进行排序。但它提供了错误。奇怪的是这个数组的维数(1131410107)。有人知道为什么吗?你知道吗

import numpy as np

import tensorflow as tf

from pprint import pprint

from sklearn.datasets import fetch_20newsgroups_vectorized

data_train= fetch_20newsgroups_vectorized(subset='train')

temp= np.sort(data_train.data, axis=1)

Tags: 数据代码fromimportdata排序as错误
1条回答
网友
1楼 · 发布于 2024-04-20 10:26:05

正如fetch_20newsgroup_vectorized所描述的,它返回一个csr\u矩阵,而不是np.矩阵地址:

Returns bunch : Bunch object bunch.data: sparse matrix, shape [n_samples, n_features] bunch.target: array, shape [n_samples] bunch.target_names: list, length [n_classes]

你需要把它转到np.矩阵使用todense

np.sort(data_train.data.todense(), axis=1)

相关问题 更多 >