torchcluster是一个用于集群分析的python包。
torchcluster的Python项目详细描述
torchcluster是一个用于集群分析的python包。的速度 利用pytorch对聚类算法进行了有效的改进 后端。我们也在研究测试数据集和可视化工具。 相关工作将在下一个版本中发布。
系统要求
火炬组应该在
- 所有Linux发行版不早于Ubuntu16.04
- MacOS X
- Windows 10
torchcluster还需要python 3.5或更高版本。python 2支持是 来了。
现在,torchcluster在PyTorch上工作 0.4.1条。
安装
使用pip
pip install torchcluster
使用水蟒
conda install -c tczhangzhi torchcluster
火炬团的样子
定义数据集生成器并生成数据集:
from torchcluster.dataset.simple import SimpleDataset dataset_factory = SimpleDataset(2, feature=2, sigma=2, device=device) dataset = dataset_factory(100)
配置群集算法并获取结果:
from torchcluster.zoo.spectrum import SpectrumClustering cluster = SpectrumClustering(2) result, _ = cluster(dataset)
您还可以对自己的数据集进行集群。数据集应该是张量 n乘以m,其中n是数据集中的数据点数量,m是 每个数据点的尺寸:
dataset = torch.cat([torch.randn(500,2) + torch.Tensor([-2,-3]), torch.randn(500,2) + torch.Tensor([2,1])])
使用光谱聚类得到以下结果:
tensor([0, 0, ..., 1, 1])