M - PHATE

m-phate的Python项目详细描述


m-phate

Latest PyPi versionTravis CI BuildCoverage StatusarXiv PreprintTwitterGitHub stars

Demonstration M-PHATE plot

多层phate(m-phate)是一种用于时间演化数据可视化的降维算法。要了解更多关于m-phate的知识,您可以阅读我们关于arxiv的预印本,其中我们将其应用于训练过程中神经网络的进化。上面我们展示了一个m-phate应用于300多个训练阶段的3层mlp的演示,由epoch(左)、hidden layer(中)和最强烈激活每个隐藏单元的数字标签(右)着色。在下面,你可以看到同样的网络,辍学应用在三维的训练中,也由最活跃的单位着色。

3D rotating gif

目录

工作原理

多层phate(m-phate)结合了一种新的多层核结构和PHATE visualization。我们的内核捕捉了进化图结构的动态,当可视化时,它给出了关于系统进化的独特直觉;在arxiv的预印本中,我们展示了在训练和再训练过程中将其应用于神经网络的情况。我们将m-phate与其他降维技术进行了比较,结果表明,多层核的组合构造和phate的使用为可视化提供了显著的改进。在两个小插曲中,我们演示了在连续学习中对已建立的训练任务和学习方法使用m-phate,以及在通常用于提高泛化性能的正则化技术中使用m-phate。

m-phate中使用的多层核包括在数据的时间片(如神经网络训练中的时间段)上建立图形,然后通过在时间上连接每个点到自身,并根据其相似性加权来连接这些片。其结果是一个高度稀疏、结构化的内核,它可以洞察数据的演化结构。

Example of multislice graph

Example of multislice kernel

安装

pypi

安装
pip install --user m-phate

从源安装

pip install --user git+https://github.com/scottgigante/m-phate.git

用法

基本用法示例

下面我们将m-phate应用于50个随机运动点的模拟数据。

import numpy as np
import m_phate
import scprep

# create fake data
n_time_steps = 100
n_points = 50
n_dim = 25
np.random.seed(42)
data = np.cumsum(np.random.normal(0, 1, (n_time_steps, n_points, n_dim)), axis=0)

# embedding
m_phate_op = m_phate.M_PHATE()
m_phate_data = m_phate_op.fit_transform(data)

# plot
time = np.repeat(np.arange(n_time_steps), n_points)
scprep.plot.scatter2d(m_phate_data, c=time, ticks=False, label_prefix="M-PHATE")

Example embedding

网络培训

为了将m-phate应用到神经网络中,我们在训练过程中提供了帮助类来存储网络中的样本。要使用这些,必须安装^{}^{}

import numpy as np

import keras
import scprep

import m_phate
import m_phate.train
import m_phate.data

# load data
x_train, x_test, y_train, y_test = m_phate.data.load_mnist()

# select trace examples
trace_idx = [np.random.choice(np.argwhere(y_test[:, i] == 1).flatten(),
                              10, replace=False)
             for i in range(10)]
trace_data = x_test[np.concatenate(trace_idx)]

# build neural network
lrelu = keras.layers.LeakyReLU(alpha=0.1)
inputs = keras.layers.Input(
    shape=(x_train.shape[1],), dtype='float32', name='inputs')
h1 = keras.layers.Dense(128, activation=lrelu, name='h1')(inputs)
h2 = keras.layers.Dense(128, activation=lrelu, name='h2')(h1)
h3 = keras.layers.Dense(128, activation=lrelu, name='h3')(h2)
outputs = keras.layers.Dense(10, activation='softmax', name='output_all')(h3)

# build trace model helper
model_trace = keras.models.Model(inputs=inputs, outputs=[h1, h2, h3])
trace = m_phate.train.TraceHistory(trace_data, model_trace)

# compile network
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=['categorical_accuracy', 'categorical_crossentropy'])

# train network
model.fit(x_train, y_train, batch_size=128, epochs=200,
          verbose=1, callbacks=[trace],
          validation_data=(x_test,
                           y_test))

# extract trace data
trace_data = np.array(trace.trace)
epoch = np.repeat(np.arange(trace_data.shape[0]), trace_data.shape[1])

# apply M-PHATE
m_phate_op = m_phate.M_PHATE()
m_phate_data = m_phate_op.fit_transform(trace_data)

# plot the result
scprep.plot.scatter2d(m_phate_data, c=epoch, ticks=False,
                      label_prefix="M-PHATE")

有关详细示例,请参见keras中的示例笔记本和^{}中的tensorflow

参数调整

调整m-phate参数的关键在于平衡片间连接和片内连接之间的平衡。这主要是通过interslice_knnintraslice_knn实现的。您可以在this notebook中看到参数调整效果的示例。

图形复制

我们提供脚本来复制预印本中的所有经验数据。

运行它们:

git clone https://github.com/scottgigante/m-phate
cd m-phate
pip install --user .
DATA_DIR=~/data/checkpoints/m_phate # change this if you want to store the data elsewhere

chmod +x scripts/generalization/generalization_train.sh
chmod +x scripts/task_switching/classifier_mnist_task_switch_train.sh

./scripts/generalization/generalization_train.sh $DATA_DIR
./scripts/task_switching/classifier_mnist_task_switch_train.sh $DATA_DIR

python scripts/demonstration_plot.py $DATA_DIR
python scripts/comparison_plot.py $DATA_DIR
python scripts/generalization_plot.py $DATA_DIR
python scripts/task_switch_plot.py $DATA_DIR

# generalization plot using training data
./scripts/generalization/generalization_train.sh ${DATA_DIR}/train_data --sample-train-data
mkdir train_data; cd train_data; python -i ../scripts/generalization_plot.py ${DATA_DIR}/train_data; cd ..

待办事项

  • 为Pythorch提供支持
  • 笔记本示例:
    • 分类,pytorch
    • 自动编码器,pytorch
  • 建立“已读文档”页
  • 更新arxiv链接

帮助

如果您有任何问题,请随时open an issue

欢迎加入QQ群-->: 979659372 Python中文网_新手群

推荐PyPI第三方库


热门话题
java数组。按字符串排序   如何使用Netbeans设置Java打印的页面大小   java有没有一种方法可以获取sparkjava/嵌入式jetty服务器的主线程执行器?   正则表达式Java正则表达式:需要更简单的解决方案   无法使用java解析XML   MySQL Java JDBC:如何获取自动递增列的名称?   java错误:“限定符必须是表达式”Android Studio   Spring+java。lang.NoClassDefFoundError:weblogic/logging/LogEntryFormatter   java将JList插入GridLayout   listview中的java Get selected复选框   使用CriteriaBuilder的java JPA左外部联接会导致错误:不允许部分对象查询维护缓存或进行编辑   java循环双链接列表addToHead和print   java更好地检测三角形按钮(libgdx)   java ConcurrentHashMap迭代保证人   java如何获取控制台。通过webdriver记录信息?   java Javafx阶段为空   java如何使用apachetika从excel文件中访问空白单元格   java使用SQlite数据库列填充AutoCompleteTextView   java如何在不使用idea构建整个maven项目的情况下运行主方法?