Shortcuts

学习基础知识 || 快速入门 || 张量 || 数据集与数据加载器 || 变换 || 构建模型 || 自动微分 || 优化 || 保存和加载模型

变换

Created On: Feb 09, 2021 | Last Updated: Aug 11, 2021 | Last Verified: Not Verified

数据并非总是以机器学习算法所需的最终处理形式出现。我们使用 变换 对数据进行一些操作,使其适合训练。

所有 TorchVision 数据集都有两个参数 -transform 用于修改特征,target_transform 用于修改标签 - 接受包含转换逻辑的可调用对象。torchvision.transforms 模块提供了几个常用的变换。

FashionMNIST 特征为 PIL 图像格式,标签为整数。在训练中,我们需要将特征转为归一化张量,并将标签转为独热编码张量。为了完成这些转换,我们使用 ToTensorLambda

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
  0%|          | 0.00/26.4M [00:00<?, ?B/s]
  0%|          | 32.8k/26.4M [00:00<02:30, 175kB/s]
  0%|          | 65.5k/26.4M [00:00<02:31, 174kB/s]
  0%|          | 98.3k/26.4M [00:00<02:31, 173kB/s]
  1%|          | 229k/26.4M [00:00<01:09, 378kB/s]
  2%|1         | 426k/26.4M [00:00<00:40, 644kB/s]
  3%|3         | 885k/26.4M [00:01<00:21, 1.21MB/s]
  7%|6         | 1.74M/26.4M [00:01<00:10, 2.37MB/s]
 13%|#3        | 3.47M/26.4M [00:01<00:05, 4.58MB/s]
 25%|##5       | 6.62M/26.4M [00:01<00:02, 8.39MB/s]
 37%|###6      | 9.76M/26.4M [00:01<00:01, 10.5MB/s]
 49%|####8     | 12.9M/26.4M [00:02<00:01, 12.3MB/s]
 61%|######    | 16.0M/26.4M [00:02<00:00, 14.0MB/s]
 72%|#######2  | 19.1M/26.4M [00:02<00:00, 14.7MB/s]
 84%|########3 | 22.1M/26.4M [00:02<00:00, 14.5MB/s]
 95%|#########5| 25.2M/26.4M [00:02<00:00, 15.0MB/s]
100%|##########| 26.4M/26.4M [00:02<00:00, 9.26MB/s]

  0%|          | 0.00/29.5k [00:00<?, ?B/s]
100%|##########| 29.5k/29.5k [00:00<00:00, 206kB/s]
100%|##########| 29.5k/29.5k [00:00<00:00, 205kB/s]

  0%|          | 0.00/4.42M [00:00<?, ?B/s]
  1%|          | 32.8k/4.42M [00:00<00:27, 161kB/s]
  1%|1         | 65.5k/4.42M [00:00<00:27, 160kB/s]
  2%|2         | 98.3k/4.42M [00:00<00:27, 160kB/s]
  5%|5         | 229k/4.42M [00:00<00:12, 349kB/s]
 10%|9         | 426k/4.42M [00:01<00:07, 568kB/s]
 20%|##        | 885k/4.42M [00:01<00:03, 1.13MB/s]
 39%|###9      | 1.74M/4.42M [00:01<00:01, 2.12MB/s]
 79%|#######8  | 3.47M/4.42M [00:01<00:00, 4.14MB/s]
100%|##########| 4.42M/4.42M [00:01<00:00, 2.69MB/s]

  0%|          | 0.00/5.15k [00:00<?, ?B/s]
100%|##########| 5.15k/5.15k [00:00<00:00, 23.0MB/s]

ToTensor()

ToTensor 将 PIL 图像或 NumPy ndarray 转为 FloatTensor,并将图像的像素强度值缩放到范围 [0., 1.]

Lambda 转换

Lambda 转换应用任何用户定义的 lambda 函数。这里我们定义一个函数,将整数转为独热编码张量。它首先创建一个大小为 10(我们数据集中标签数量)的零张量,并调用 scatter_,在标签 y 给出的索引上分配一个 value=1

target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源