图书馆神经网络.BernoulliRBM”创建高度相关的特征

2024-04-19 13:54:59 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用sklearn库生成数据集的新功能
使用受限玻耳兹曼机(RBM,
神经网络伯努利先生)。我使用以下环境:

python 3.5.0版 numpy==1.11.1 scikit学习==0.18

我已经尝试了大量的迭代(n\u iter=6000)和
所有训练数据(373个样本)的学习率低(0.0001)。但是,
RBM生成的新功能都非常重要
相关的。有人能解释为什么会这样吗?你知道吗

以下是MWE:

import numpy as np
import csv
from sklearn.neural_network import BernoulliRBM

# train data
train_data = np.array(
[[0.0326086956522,0.0,0.0,0.0200400801603,0.0674157303371,0.000805152979066,0.00200803212851,0.243243243243,0.0123456790123,0.55,0.0233428760185,0.0,0.0,0.0,0.444444444,0.0,0.0,0.157556270138,0.0188679245283,0.0983652512615],
[0.0108695652174,0.2,0.0,0.00200400801603,0.0112359550562,0.0,0.0,0.027027027027,0.0123456790123,1.0,0.00154151068047,0.0,0.0,1.0,1.0,0.0,0.0,0.0289389067571,0.0,0.0],
[0.0869565217391,0.0,0.152542372881,0.0260521042084,0.0749063670412,0.00322061191626,0.0180722891566,0.108108108108,0.0987654320988,0.4,0.022241796961,0.2,0.0909090909091,0.0,0.40625,0.0,0.0,0.053054662388,0.0188679245283,0.129097937384],
[0.0326086956522,0.2,0.0847457627119,0.0140280561122,0.0149812734082,0.000268384326355,0.0120481927711,0.027027027027,0.0246913580247,0.25,0.00352345298392,1.0,0.0,0.75,0.555555556,0.0,0.0,0.0192926045047,0.0188679245283,0.0983652512615],
[0.0978260869565,0.0,0.0,0.0100200400802,0.0711610486891,0.00214707461084,0.00803212851406,0.027027027027,0.111111111111,0.265625,0.0262056815679,1.0,0.0,0.0,0.518518519,0.0,0.0,0.0568060021635,0.0566037735849,0.213107498008],
[0.0760869565217,0.8,0.0,0.0180360721443,0.0936329588015,0.0,0.0120481927711,0.0810810810811,0.0864197530864,0.3333333335,0.0561550319313,0.0,0.0,0.863636364,0.342857143,0.5,0.333333333333,0.168121267841,0.169811320755,0.463705037033],
[0.0978260869565,1.0,0.0,0.0100200400802,0.063670411985,0.00697799248524,0.0,0.135135135135,0.0740740740741,0.4166666665,0.0156353226162,0.0,0.0,0.949367089,0.333333333,0.25,0.266666666667,0.0316184351626,0.0566037735849,0.163932249402],
[0.0326086956522,0.2,0.0,0.0380761523046,0.0374531835206,0.000805152979066,0.0281124497992,0.135135135135,0.037037037037,1.0,0.00836820083682,0.0,0.0,0.923076923,0.583333333,0.0,0.0,0.0562700964881,0.0188679245283,0.0491752486057],
[0.0108695652174,0.0,0.0,0.0200400801603,0.00374531835206,0.0,0.0160642570281,0.0540540540541,0.0123456790123,1.0,0.000220215811495,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0188679245283,0.147540499867],
[0.217391304348,0.0,0.0,0.0140280561122,0.295880149813,0.0365002683843,0.0100401606426,0.135135135135,0.123456790123,0.4487534625,0.183880202599,1.0,0.0909090909091,0.0,0.19375,0.0,0.0,0.191961414822,0.188679245283,0.287703974741],
[0.0652173913043,0.0,0.0,0.0160320641283,0.0224719101124,0.00402576489533,0.0140562248996,0.027027027027,0.0740740740741,1.0,0.00132129486897,0.0,0.0,0.0,0.444444444,0.0,0.0,0.0,0.0188679245283,0.147540499867],
[0.0326086956522,0.6,0.0,0.0100200400802,0.0411985018727,0.000268384326355,0.00200803212851,0.108108108108,0.0123456790123,0.25,0.00902884827131,1.0,0.0909090909091,0.971428571,0.75,0.25,0.133333333333,0.0594855305401,0.0566037735849,0.147540499867],
[0.119565217391,0.2,0.0,0.0140280561122,0.0973782771536,0.0,0.0100401606426,0.0540540540541,0.135802469136,0.29,0.0398590618806,1.0,0.0,0.529411765,0.409090909,0.0,0.0,0.0723472668927,0.0188679245283,0.107306205553],
[0.0326086956522,0.2,0.0,0.0100200400802,0.0262172284644,0.000268384326355,0.00200803212851,0.108108108108,0.037037037037,0.25,0.00638625853336,1.0,0.0,0.818181818,0.666666667,0.0,0.0,0.0401929260499,0.0188679245283,0.0983652512615],
[0.173913043478,0.4,0.0,0.0300601202405,0.243445692884,0.020397208803,0.0,0.405405405405,0.16049382716,0.46,0.106364236952,1.0,0.0,0.725490196,0.311111111,0.0,0.0,0.136254019315,0.169811320755,0.230532031043],
[0.163043478261,0.4,0.0,0.0180360721443,0.153558052434,0.0,0.0,0.243243243243,0.185185185185,0.3392857145,0.044924025545,1.0,0.0909090909091,0.725490196,0.225,0.25,0.133333333333,0.0594855305401,0.0377358490566,0.226223848446],
[0.152173913043,0.6,0.0508474576271,0.0220440881764,0.10861423221,0.0228126677402,0.00602409638554,0.216216216216,0.135802469136,0.2884615385,0.0237833076415,1.0,0.0909090909091,0.759259259,0.321428571,0.0,0.0,0.0316949931128,0.0754716981132,0.189692820679],
[0.29347826087,0.4,0.0,0.0160320641283,0.378277153558,0.0421363392378,0.0100401606426,0.0810810810811,0.185185185185,0.4123931625,0.283197533583,0.888888889,0.0909090909091,0.294117647,0.183760684,0.25,0.466666666667,0.220078599537,0.0754716981132,0.163932249402],
[0.0326086956522,0.0,0.0,0.00400801603206,0.0112359550562,0.000805152979066,0.00401606425703,0.0,0.037037037037,0.75,0.000880863245981,0.0,0.0,0.0,0.666666667,0.0,0.0,0.0,0.0188679245283,0.147540499867],
[0.597826086957,0.4,0.135593220339,0.0400801603206,0.397003745318,0.352388620505,0.0160642570281,0.324324324324,0.111111111111,0.4782763535,0.249504514424,1.0,0.181818181818,0.406593407,0.195454545,0.0,0.0,0.0922537270084,0.188679245283,0.273613857004]]
)


# define the RBM model
random_state = 200
model = BernoulliRBM(n_components=10,n_iter=10,random_state=random_state)

# building RBM and creating RBM features
# Each column means one feature, each row means one line of the train data.
RBM_feature_data = model.fit_transform(train_data)

print(RBM_feature_data)

谢谢你!你知道吗


Tags: 数据import功能numpydatamodelnptrain