Shortcuts

(原型) MaskedTensor高级语义

Created On: Oct 28, 2022 | Last Updated: Oct 28, 2022 | Last Verified: Not Verified

开始学习本教程之前,请确保已查看我们的`MaskedTensor概述教程 <https://pytorch.org/tutorials/prototype/maskedtensor_overview.html>`。

本教程的目的是帮助用户了解一些高级语义的工作原理及其生成过程。我们会重点介绍其中两种语义:

*. MaskedTensor与`NumPy&apos;s MaskedArray <https://numpy.org/doc/stable/reference/maskedarray.html>`__之间的差异 *. 归约语义

准备工作

import torch
from torch.masked import masked_tensor
import numpy as np
import warnings

# Disable prototype warnings and such
warnings.filterwarnings(action='ignore', category=UserWarning)

MaskedTensor与NumPy&apos;s MaskedArray

NumPy&apos;s ``MaskedArray``与MaskedTensor存在一些基本语义差异。

*. 它们的工厂函数和基本定义会反转掩码(类似于``torch.nn.MHA``);也就是说,MaskedTensor

使用``True``表示“指定”而``False``表示“未指定”或“有效”/“无效”,而NumPy则相反。我们认为我们的掩码定义不仅更直观,而且与PyTorch整体上的现有语义更加一致。

*. 交集语义。在NumPy中,如果其中一个元素被掩盖,结果元素将

同样被掩盖——实际上,它们`应用逻辑或操作符 <https://github.com/numpy/numpy/blob/68299575d8595d904aff6f28e12d21bf6428a4ba/numpy/ma/core.py#L1016-L1024>`__。

data = torch.arange(5.)
mask = torch.tensor([True, True, False, True, False])
npm0 = np.ma.masked_array(data.numpy(), (~mask).numpy())
npm1 = np.ma.masked_array(data.numpy(), (mask).numpy())

print("npm0:\n", npm0)
print("npm1:\n", npm1)
print("npm0 + npm1:\n", npm0 + npm1)

与此同时,MaskedTensor不支持掩码不匹配时的加法或二元操作符——要理解原因,请查看:ref:归约章节<reduction-semantics>

mt0 = masked_tensor(data, mask)
mt1 = masked_tensor(data, ~mask)
print("mt0:\n", mt0)
print("mt1:\n", mt1)

try:
    mt0 + mt1
except ValueError as e:
    print ("mt0 + mt1 failed. Error: ", e)

不过,如果需要这种行为,MaskedTensor确实支持这些语义,可以访问数据和掩码,同时通过使用:func:`to_tensor`方便地将MaskedTensor转换为用掩码值填充的Tensor。例如:

t0 = mt0.to_tensor(0)
t1 = mt1.to_tensor(0)
mt2 = masked_tensor(t0 + t1, mt0.get_mask() & mt1.get_mask())

print("t0:\n", t0)
print("t1:\n", t1)
print("mt2 (t0 + t1):\n", mt2)

注意,掩码是`mt0.get_mask() & mt1.get_mask()`,因为:class:`MaskedTensor`的掩码是NumPy&apos;s的反转版本。

归约语义

回想在`MaskedTensor&apos;s概述教程 <https://pytorch.org/tutorials/prototype/maskedtensor_overview.html>`__中我们讨论过的“实现缺失的torch.nan*操作”。这些是归约的例子——运算符从Tensor中移除一个(或多个)维度并聚合结果。在本节中,我们将使用归约语义来说明我们之前严格要求匹配掩码的原因。

基本上,:class:`MaskedTensor`在忽略掩盖(未指定)值的情况下执行相同的归约操作。以一个例子说明:

data = torch.arange(12, dtype=torch.float).reshape(3, 4)
mask = torch.randint(2, (3, 4), dtype=torch.bool)
mt = masked_tensor(data, mask)

print("data:\n", data)
print("mask:\n", mask)
print("mt:\n", mt)

现在,不同的归约(所有在dim=1上):

print("torch.sum:\n", torch.sum(mt, 1))
print("torch.mean:\n", torch.mean(mt, 1))
print("torch.prod:\n", torch.prod(mt, 1))
print("torch.amin:\n", torch.amin(mt, 1))
print("torch.amax:\n", torch.amax(mt, 1))

需要注意的是,被掩盖的元素下的值未必保证具有任何特定的值,特别是当整行或整列被完全掩盖时(归一化也是如此)。关于掩码语义的更多细节,可以查阅此`RFC <https://github.com/pytorch/rfcs/pull/27>`__。

现在,我们可以重新审视这个问题:为什么我们会在二元操作上强制执行掩码必须匹配的约束?换句话说,为什么我们不使用``np.ma.masked_array``的相同语义?请看以下示例:

data0 = torch.arange(10.).reshape(2, 5)
data1 = torch.arange(10.).reshape(2, 5) + 10
mask0 = torch.tensor([[True, True, False, False, False], [False, False, False, True, True]])
mask1 = torch.tensor([[False, False, False, True, True], [True, True, False, False, False]])
npm0 = np.ma.masked_array(data0.numpy(), (mask0).numpy())
npm1 = np.ma.masked_array(data1.numpy(), (mask1).numpy())

print("npm0:", npm0)
print("npm1:", npm1)

现在,我们尝试加法操作:

print("(npm0 + npm1).sum(0):\n", (npm0 + npm1).sum(0))
print("npm0.sum(0) + npm1.sum(0):\n", npm0.sum(0) + npm1.sum(0))

求和和加法显然应该具有结合性,但使用NumPy&apos;s的语义时,却不具有,这可能会给用户带来困惑。

MaskedTensor`则不会允许这种操作,因为`mask0 != mask1。话虽如此,如果用户愿意,可以通过一些方式规避这一点(例如,如下图所示,使用:func:`to_tensor`将MaskedTensor的未定义元素填充为0值),但用户现在必须更加明确自己的意图。

mt0 = masked_tensor(data0, ~mask0)
mt1 = masked_tensor(data1, ~mask1)

(mt0.to_tensor(0) + mt1.to_tensor(0)).sum(0)

总结

在本教程中,我们了解了MaskedTensor与NumPy&apos;s MaskedArray背后的不同设计决策,以及归约语义。总体而言,MaskedTensor旨在避免歧义和令人困惑的语义(例如,我们尝试在二元操作之间保留结合性属性),这反过来可能会使用户有时需要更加有意图地编写代码,但我们认为这是更好的选择。如果您对此有任何想法,请`告诉我们 <https://github.com/pytorch/pytorch/issues>`__!

脚本总运行时间: (0分钟 0.000秒)

画廊由Sphinx-Gallery生成

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源