Shortcuts

(原型) MaskedTensor概述

Created On: Oct 28, 2022 | Last Updated: Oct 28, 2022 | Last Verified: Not Verified

本教程旨在作为使用MaskedTensor的起点,并讨论其掩码语义。

MaskedTensor作为:class:`torch.Tensor`的扩展,为用户提供以下能力:

  • 使用任何掩码语义(例如,变量长度张量、nan*操作等)

  • 区分0和NaN梯度

  • 各种稀疏应用(请参阅下方教程)

有关MaskedTensor更详细的介绍,请参阅`torch.masked文档 <https://pytorch.org/docs/master/masked.html>`__。

使用MaskedTensor

在本节中,我们讨论如何使用MaskedTensor,包括如何构造、访问数据和掩码,以及如何进行索引和切片。

准备工作

我们将首先为教程做必要的设置:

import torch
from torch.masked import masked_tensor, as_masked_tensor
import warnings

# Disable prototype warnings and such
warnings.filterwarnings(action='ignore', category=UserWarning)

构造

构造MaskedTensor有几种不同的方法:

  • 第一种方法是直接调用MaskedTensor类

  • 第二种(也是我们推荐的方法)是使用:func:`masked.masked_tensor`和:func:`masked.as_masked_tensor`工厂函数,这与:func:`torch.tensor`和:func:`torch.as_tensor`类似

在本教程中,我们将假设如下的导入行:from torch.masked import masked_tensor

访问数据和掩码

MaskedTensor中的底层字段可以通过以下方式访问:

  • :meth:`MaskedTensor.get_data`函数

  • :meth:`MaskedTensor.get_mask`函数。请记住,``True``表示“指定”或“有效”,而``False``表示“未指定”或“无效”。

总的来说,返回的底层数据在未指定的条目中可能无效,因此我们建议当用户需要一个没有任何掩码条目的Tensor时,使用:meth:`MaskedTensor.to_tensor`(如上所示)返回一个填充值的Tensor。

索引和切片

:class:`MaskedTensor`是Tensor的子类,这意味着它继承了与:class:`torch.Tensor`相同的索引和切片语义。以下是一些常见的索引和切片模式的示例:

data = torch.arange(24).reshape(2, 3, 4)
mask = data % 2 == 0

print("data:\n", data)
print("mask:\n", mask)
# float is used for cleaner visualization when being printed
mt = masked_tensor(data.float(), mask)

print("mt[0]:\n", mt[0])
print("mt[:, :, 2:4]:\n", mt[:, :, 2:4])

为什么MaskedTensor有用?

因为:class:`MaskedTensor``对指定和未指定值的处理是作为一等公民而不是事后补充(通过填充值、nans等),它能够解决常规Tensor无法解决的多个不足之处;事实上,:class:`MaskedTensor`的诞生很大程度上是由于这些重复性问题。

以下我们将讨论PyTorch中至今未解决的一些常见问题,并说明如何用:class:`MaskedTensor`解决这些问题。

区分0和NaN梯度

一个问题是 torch.Tensor 面临无法区分梯度是未定义(NaN)还是实际为 0 的困境。由于 PyTorch 没有一种方法标记一个值是已指定/有效还是未指定/无效,它不得不依赖 NaN 或 0(取决于具体情况),导致不可靠的语义,因为许多操作不能正确处理 NaN 值。更混乱的是,有时梯度会因操作顺序不同而有所变化(例如,取决于链操作中 NaN 值出现的早晚)。

MaskedTensor 是解决这个问题的完美方案!

torch.where

Issue 10729 中,我们注意到使用 torch.where() 时操作顺序可能会产生影响,因为我们无法区分 0 是实际的 0 还是来自未定义梯度的结果。因此,我们保持一致并屏蔽结果:

当前结果:

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
y.sum().backward()
x.grad

MaskedTensor 结果:

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
mask = x < 0
mx = masked_tensor(x, mask, requires_grad=True)
my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
y = torch.where(mask, torch.exp(mx), my)
y.sum().backward()
mx.grad

这里的梯度仅提供给选定的子集。实际上,这会更改 where 的梯度以屏蔽掉元素,而不是将它们设置为零。

另一个 torch.where

Issue 52248 是另一个例子。

当前结果:

a = torch.randn((), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))

MaskedTensor 结果:

a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))

此问题类似(甚至链接到下面的下一个问题),因为它表达了对意外行为的挫败感,原因是无法区分“无梯度”和“零梯度”,这反过来使处理其他操作变得难以理解。

使用掩码时,x/0 产生 NaN 梯度

Issue 4132 中,用户提出 x.grad 应为 [0, 1] 而不是 [nan, 1],而 MaskedTensor 通过完全屏蔽梯度使这一点变得非常清晰。

当前结果:

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0)  # => mask is [0, 1]
y[mask].backward()
x.grad

MaskedTensor 结果:

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
loss = as_masked_tensor(y, mask)
loss.sum().backward()
x.grad

torch.nansum()torch.nanmean()

Issue 67180 中,梯度未正确计算(一个长期存在的问题),而 MaskedTensor 正确处理了它。

当前结果:

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
c = a * b
c1 = torch.nansum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1

MaskedTensor 结果:

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
mt = masked_tensor(a, ~torch.isnan(a))
c = mt * b
c1 = torch.sum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1

安全 Softmax

安全 softmax 是另一个经常出现的很好的例子,问题详见 issue。简而言之,如果整个批次被“屏蔽”或完全由填充组成(在 softmax 情况下,表示设置为 -inf),这会导致 NaN,进而可能导致训练发散。

幸运的是,MaskedTensor 已解决该问题。考虑以下设置:

data = torch.randn(3, 3)
mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]])
x = data.masked_fill(~mask, float('-inf'))
mt = masked_tensor(data, mask)
print("x:\n", x)
print("mt:\n", mt)

例如,我们希望沿着 dim=0 计算 softmax。注意第二列是“不安全的”(即完全被屏蔽),因此计算 softmax 时结果会产生 0/0 = nan,因为 exp(-inf) = 0。然而,我们真正希望的是屏蔽梯度,因为它们未被指定且对训练无效。

PyTorch 结果:

x.softmax(0)

MaskedTensor 结果:

mt.softmax(0)

实现缺失的 torch.nan* 操作符

Issue 61474 中,有添加操作符以覆盖各种 torch.nan* 应用的请求,例如 torch.nanmaxtorch.nanmin 等。

一般来说,这些问题更自然地适配屏蔽语义,因此与其引入额外的操作符,我们建议使用 MaskedTensor。由于 nanmean 已经实现,我们可以将其用作比较点:

x = torch.arange(16).float()
y = x * x.fmod(4)
z = y.masked_fill(y == 0, float('nan'))  # we want to get the mean of y when ignoring the zeros
print("y:\n", y)
# z is just y with the zeros replaced with nan's
print("z:\n", z)
print("y.mean():\n", y.mean())
print("z.nanmean():\n", z.nanmean())
# MaskedTensor successfully ignores the 0's
print("torch.mean(masked_tensor(y, y != 0)):\n", torch.mean(masked_tensor(y, y != 0)))

在上述示例中,我们构造了一个 y,并希望计算序列的平均值,同时忽略零值。torch.nanmean 可以用于该操作,但其余的 torch.nan* 操作缺乏实现。MaskedTensor 通过能够使用基础操作解决了这个问题,我们已经支持该问题中列出的其他操作。例如:

torch.argmin(masked_tensor(y, y != 0))

确实,当忽略零值时,最低参数的索引为索引 1 的 1。

MaskedTensor 还可以支持当数据完全被屏蔽时的归约操作,相当于上述情况,当数据张量完全为 nan 时。nanmean 会返回 ``nan``(一个模糊的返回值),而 MaskedTensor 更准确地指示一个屏蔽结果。

x = torch.empty(16).fill_(float('nan'))
print("x:\n", x)
print("torch.nanmean(x):\n", torch.nanmean(x))
print("torch.nanmean via maskedtensor:\n", torch.mean(masked_tensor(x, ~torch.isnan(x))))

这是一个类似于安全 softmax 的问题,其中 0/0 = nan 而我们真正需要的是一个未定义的值。

总结

在本教程中,我们介绍了 MaskedTensors 是什么,演示了如何使用它们,并通过一系列示例和它们帮助解决的问题阐明了它们的价值。

进一步阅读

要继续学习更多内容,可以查看我们的 MaskedTensor Sparsity 教程,了解 MaskedTensor 如何实现稀疏性以及我们目前支持的不同存储格式。

脚本总运行时间: (0分钟 0.000秒)

画廊由Sphinx-Gallery生成

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源