多级筛选:
比如结构是2*2*3,只想选第三维的最大的
tx[index, best_n, g_y_center, g_x_center]
index=[01],best_n=[0,1]
最后只取两个值,第一行,第1列,第二行,第2列的。
筛选第3维最大的值,下面的代码不对,解决方法:查询max源码
也可以把3维用view降到2维再计算就可以了。
import torch
anch_ious = torch.Tensor([[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]])
print('b shape',anch_ious.shape)
b = torch.max(anch_ious, 2)
print(b[0])
print(b[1])
b = b[1].squeeze(1)
print(b)
print(anch_ious[list(range(anch_ious.size(0))),list(range(anch_ious.size(1))), b])
通过值筛选:
import torch
x = torch.linspace(1, 8, steps=8).view(4, 2)
#筛选第一维和第二维都
本文探讨了在PyTorch中如何进行多级筛选,特别是针对三维张量时如何选取第三维的最大值。介绍了错误代码示例并提出了解决方案,包括查看`max`函数源码以及将三维张量转换为二维来简化计算。此外,还提到如何筛选二维张量中最大值不满足条件的元素。
订阅专栏 解锁全文
1756

被折叠的 条评论
为什么被折叠?



