Python中文
首页
教程
问答
标签
搜索
登录
注册
PyTorch如何在多个维度上进行聚集
回答此问题可获得
20
贡献值,回答如果被采纳可获得
50
分。
<p>我想找到一种不用for循环的方法</p> <p>假设我有一个多维张量<code>t0</code>:</p> <pre><code>bs = 4 seq = 10 v = 16 t0 = torch.rand((bs, seq, v)) </code></pre> <p>它的形状是:<code>torch.Size([4, 10, 16])</code></p> <p>我有另一个张量<code>labels</code>,它是<code>seq</code>维中5个随机指数的一批:</p> <pre><code>labels = torch.randint(0, seq, size=[bs, sample]) </code></pre> <p>这就是形状<code>torch.Size([4, 5])</code>。这用于索引<code>t0</code>的<code>seq</code>维度</p> <p>我想做的是使用<code>labels</code>张量在批处理维度上循环进行聚集。我的暴力解决方案是:</p> <pre><code>t1 = torch.empty((bs, sample, v)) for b in range(bs): for idx0, idx1 in enumerate(labels[b]): t1[b, idx0, :] = t0[b, idx1, :] </code></pre> <p>导致张量<code>t1</code>的形状:<code>torch.Size([4, 5, 16])</code></p> <p>在pytorch中有没有更惯用的方法</p>
0 条评论
分类:
Python问答
请先
登录
后评论
默认排序
时间排序
1 个回答
匿名
1天前
擅长:python、mysql、java
<p>你可以这样做:</p> <pre><code>t1 = t0[[[b] for b in range(bs)], labels] </code></pre> <p>或</p> <pre><code>t1 = torch.stack([t0[b, labels[b]] for b in range(bs)]) </code></pre>
请先
登录
后评论
针对此问题:
更多的回答
关注
89
关注
收藏
1
收藏,
216
浏览
网友 提问于 2天前
相关Python问题
Django:。是不是“超级用户”字段不起作用
6 回答
Django:'DeleteQuery'对象没有属性'add'
2 回答
Django:'ModelForm'对象没有属性
3 回答
Django:'python manage.py runserver'返回'TypeError:'WindowsPath'类型的对象没有len()
2 回答
Django:'Python管理.pysyncdb'不创建我的架构表
9 回答
Django:'Python管理.py迁移“耗时数小时(和其他奇怪的行为)
8 回答
Django:'readonly'属性在我的ModelForm上不起作用
10 回答
Django:'RegisterEmployeeView'对象没有属性'object'
3 回答
Django:'str'对象没有属性'get'
5 回答
Django:'创建' 不能被指定为Order模型表单中的值,因为它是一个不可编辑的字段
2 回答
Django:“'QuerySet'类型的对象不是JSON可序列化的”
10 回答
Django:“'utf8'编解码器无法解码位置19983中的字节0xe9:无效的连续字节”,加载临时文件时
8 回答
Django:“<…>”需要有一个字段“id”的值,然后才能使用这个manytomy关系
7 回答
Django:“AnonymousUser”对象没有“get_full_name”属性
7 回答
Django:“ascii”编解码器无法解码位置1035中的字节0xc3:序号不在范围内(128)
5 回答
Django:“BaseTable”对象不支持索引
5 回答
Django:“collections.OrderedDict”对象不可调用
1 回答
Django:“Country”对象没有属性“all”
4 回答
Django:“Data”对象没有属性“save”
4 回答
Django:“datetime”类型的对象不是JSON serializab
7 回答