区域关注的实施
area-attention的Python项目详细描述
区域注意事项
【PyTorch实施注意事项】。 这个模块允许关注内存区域,每个区域包含一组空间上或时间上相邻的项。 TensorFlow实现可以找到here。在
设置
$ pip install area_attention
使用
单头区域注意:
^{pr2}$多头区域注意:
importtorchfromarea_attentionimportAreaAttention,MultiHeadAreaAttentionarea_attention=AreaAttention(key_query_size=32,area_key_mode='max',area_value_mode='mean',max_area_height=2,max_area_width=2,memory_height=4,memory_width=4,dropout_rate=0.2,top_k_areas=0)multi_head_area_attention=MultiHeadAreaAttention(area_attention=area_attention,num_heads=2,key_query_size=32,key_query_size_hidden=32,value_size=64,value_size_hidden=64)q=torch.rand(4,8,32)k=torch.rand(4,16,32)v=torch.rand(4,16,64)x=multi_head_area_attention(q,k,v)x# torch.Tensor with shape (8, 64)
单元测试
$ python -m pytest tests
书目
[1]李阳等.区域注意.机器学习国际会议。2019年PMLR。在
引用
@inproceedings{li2019area,title={Area attention},author={Li, Yang and Kaiser, Lukasz and Bengio, Samy and Si, Si},booktitle={International Conference on Machine Learning},pages={3846--3855},year={2019},organization={PMLR}}
- 项目
标签: