郭等的多视图网络的keras实现。
multi-view-network的Python项目详细描述
Keras中的多视图网络
本包装基于郭宏宇,Colin Cherry和Jiang Su(2017)的End-to-End Multi-View Networks for Text Classification。多视图网络(mvn)的总体架构在本文中并没有详细说明,所以我不得不做一些猜测。
请随时联系我在aless@ndro.xyz与任何反馈。
基本用法
假设您将您的语料库准备为一个文档列表,每个文档都由一个嵌入列表(每个令牌一个)表示,那么您可以这样训练mvn:
importmulti_view_networkimportnumpyasnp# Very important: the documents in embedded_corpus **need** to have# the same number of embedded_tokens. If this is not the case# you can use multi_view_network.pad_embedded_corpus() to pad# the documents with 0-filled mock embeddings.data=np.array(embedded_corpus)# The output of the MVN is softmaxed so it's important to# make sure the labels are one-hot encoded.labels=np.array([[0,1],[0,1],[1,0],etc.])model=multi_view_network.BuildMultiViewNetwork(embeddings_dim=300,hidden_units=16,dropout_rate=0.2,output_units=2)model.compile(optimizer='sgd',loss='categorical_crossentropy')model.fit(data,labels,epochs=200,batch_size=32)
更复杂的体系结构
^ {
importmulti_view_networkembeddings_dim=300hidden_units=64output_units=2inputs=keras.layers.Input(shape=(None,embeddings_dim))s1=SelectionLayer(name='s1')(inputs)s2=SelectionLayer(name='s2')(inputs)s3=SelectionLayer(name='s3')(inputs)s4=SelectionLayer(name='s4')(inputs)s5=SelectionLayer(name='s5')(inputs)s6=SelectionLayer(name='s6')(inputs)s7=SelectionLayer(name='s7')(inputs)s8=SelectionLayer(name='s8')(inputs)v1=ViewLayer(view_index=1,name='v1')(s1)v2=ViewLayer(view_index=2,name='v2')([s1,s2])v3=ViewLayer(view_index=3,name='v3')([s1,s2,s3])v4=ViewLayer(view_index=4,name='v4')([s1,s2,s3,s4])v5=ViewLayer(view_index=5,name='v5')([s1,s2,s3,s4,s5])v6=ViewLayer(view_index=6,name='v6')([s1,s2,s3,s4,s5,s6])v7=ViewLayer(view_index=7,name='v7')([s1,s2,s3,s4,s5,s6,s7])v8=ViewLayer(view_index='Last',name='v8')(s8)concatenation=keras.layers.concatenate([v1,v2,v3,v4,v5,v6,v7,v8],name='concatenation')fully_connected=keras.layers.Dense(units=hidden_units,name='fully_connected')(concatenation)dropout=keras.layers.Dropout(rate=dropout_rate)(fully_connected)another_dense_layer=keras.layers.Dense(units=hidden_units,name='another_dense_layer')(dropout)softmax=keras.layers.Dense(units=output_units,activation='softmax',name='softmax')(dropout)model=keras.models.Model(inputs=inputs,outputs=softmax)
实用程序
utils.py
模块包含两个函数,这些函数在预处理输入时可以派上用场。如上所述,当您强制将嵌入的_文档列表设置为np.array()
时,所有文档都具有相同数量的嵌入的_令牌是很重要的。否则,生成的数组将有一个不正确的.shape
,这将导致Keras抛出错误(因为输入与预期的形状不匹配)。
有两个实用函数可以用来解决这个问题:pad_embedded_corpus()和cap_embedded_corpus()。第一个将0填充的mock embedded_标记添加到每个文档,直到所有文档具有相同的长度。第二种方法裁剪每个文档,以便只维护前x个标记,从而获得相同的结果。
例如:
importmulti_view_networkembedded_corpus=[[[0,0]],[[0,0],[1,1]],[[0,0],[1,1],[2,1]]]padded_corpus=multi_view_network.pad_embedded_corpus(embedded_corpus,embeddings_dim=2)padded_corpus_sizes=[len(lst)forlstinpadded_corpus]# padded_corps_sizes# >>> [3, 3, 3]capped_corpus=multi_view_network.cap_embedded_corpus(embedded_corpus)capped_corpus_sizes=[len(lst)forlstincapped_corpus]#capped_corpus_sizes# >>> [1, 1, 1]
在文档中添加0填充向量对mvn的输出和训练性能没有影响,因此建议确保所有嵌入的文档具有相同的长度。