水平堆叠层的包装层

keras-multi-head的Python项目详细描述


路缘石多头

TravisCoverageVersionDownloadsLicense

用于水平堆叠层的包装层。

安装

pip install keras-multi-head

用法

重复层

如果仅提供一个层,则将复制该层。参数layer_num控制最终将复制多少层。

importkerasfromkeras_multi_headimportMultiHeadmodel=keras.models.Sequential()model.add(keras.layers.Embedding(input_dim=100,output_dim=20,name='Embedding'))model.add(MultiHead(keras.layers.LSTM(units=32),layer_num=5,name='Multi-LSTMs'))model.add(keras.layers.Flatten(name='Flatten'))model.add(keras.layers.Dense(units=4,activation='softmax',name='Dense'))model.build()model.summary()

使用多层

第一个参数也可以是具有不同配置的层的列表,但是,它们必须具有相同的输出形状。

importkerasfromkeras_multi_headimportMultiHeadmodel=keras.models.Sequential()model.add(keras.layers.Embedding(input_dim=100,output_dim=20,name='Embedding'))model.add(MultiHead([keras.layers.Conv1D(filters=32,kernel_size=3,padding='same'),keras.layers.Conv1D(filters=32,kernel_size=5,padding='same'),keras.layers.Conv1D(filters=32,kernel_size=7,padding='same'),],name='Multi-CNNs'))model.build()model.summary()

线性变换

当给定hidden_dim时,输入数据将映射到每个层的相同形状的不同值。

正则化

当您希望从平行层中提取不同的特征时,将使用正则化。可以自定义层中权重的索引,间隔表示权重的部分和正则化因子。

例如,双向lstm层默认有6个权重,前3个属于前向层。前向层中的第二个权重(递归核)控制递归连接的门的计算。计算单元状态的核心是递归核的x 2到x 3个单元。我们可以对内核使用正则化:

importkerasfromkeras_multi_headimportMultiHeadmodel=keras.models.Sequential()model.add(keras.layers.Embedding(input_dim=5,output_dim=3,name='Embed'))model.add(MultiHead(layer=keras.layers.Bidirectional(keras.layers.LSTM(units=16),name='LSTM'),layer_num=5,reg_index=[1,4],reg_slice=(slice(None,None),slice(32,48)),reg_factor=0.1,name='Multi-Head-Attention',))model.add(keras.layers.Flatten(name='Flatten'))model.add(keras.layers.Dense(units=2,activation='softmax',name='Dense'))model.build()
  • reg_index:指数layer.get_weights(),单个整数或整数列表。
  • reg_sliceslices或slices的元组或以前选择的列表。如果在reg_index中提供了多个索引,并且reg_slice不是列表,则假定reg_slice等于所有索引。如果将此参数保留为None,则将使用整个数组。
  • reg_factor:正则化因子,浮点数或浮点数列表。

多头注意力

提供了一个更具体的多头层(因为普通层更难使用)。该层使用缩放的点积注意层作为其子层,只需要head_num

importkerasfromkeras_multi_headimportMultiHeadAttentioninput_layer=keras.layers.Input(shape=(2,3),name='Input',)att_layer=MultiHeadAttention(head_num=3,name='Multi-Head',)(input_layer)model=keras.models.Model(inputs=input_layer,outputs=att_layer)model.compile(optimizer='adam',loss='mse',metrics={},)model.summary()

当输入只有一层时,输入张量和输出张量的形状是相同的。当给定列表时,输入层将被视为查询、键和值:

importkerasfromkeras_multi_headimportMultiHeadAttentioninput_query=keras.layers.Input(shape=(2,3),name='Input-Q',)input_key=keras.layers.Input(shape=(4,5),name='Input-K',)input_value=keras.layers.Input(shape=(4,6),name='Input-V',)att_layer=MultiHeadAttention(head_num=3,name='Multi-Head',)([input_query,input_key,input_value])model=keras.models.Model(inputs=[input_query,input_key,input_value],outputs=att_layer)model.compile(optimizer='adam',loss='mse',metrics={},)model.summary()

欢迎加入QQ群-->: 979659372 Python中文网_新手群

推荐PyPI第三方库


热门话题
垃圾收集Java将最大堆大小(Xmx)设置为物理内存的一部分   java调用getFragmentManager()而不扩展片段   Java将SQL db中的日期格式字符串转换为日期对象以进行比较   java如何通过输入月份和年份的整数来查找特定月份?   java Struts2推送通知/WebSocket交互最佳实践   java为什么spring jdbcTemplate batchUpdate逐行插入   java句柄FileUploadBase。带弹簧2的SizeLimitExceedeException   使用Google App Engine(GAE)生成哈希代码的java   javaspringmysql.com。mysql。希杰。jdbc。例外情况。通信异常:通信链路故障   java当我想构建我的循环视图项目时遇到了一些问题   java小部件列表视图加载视图   安卓 java。伊奥。IOException:损坏的文件描述符   java在从库中拾取文件时设置方向   SFTP中每个文件的java线程   读取数据库时出现安卓 Java内存不足异常   java Swagger不够聪明,无法处理匿名类型(如地图)   java需要了解安卓视图中onClick回调的实现   java从管理面板退出后,如何返回主菜单并进入任何面板?   java高效增长的原子阵列