Transforms
Created On: Feb 09, 2021 | Last Updated: Aug 11, 2021 | Last Verified: Not Verified
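Data does not always arrive in the final, processed form that machine learning algorithms require. We use transforms to manipulate the data and make it suitable for training.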
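All TorchVision datasets have two parameters, transform to modify the features and target_transform to modify the labels, each of which accepts a callable containing the transformation logic. The torchvision.transforms module provides several commonly used transforms.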
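To make the role of the two parameters concrete, here is a minimal sketch of how a dataset might apply them when a sample is fetched. TinyDataset and scale_pixels are hypothetical names invented for this illustration; they are not part of TorchVision, and the real datasets may differ in detail.

```python
class TinyDataset:
    """Hypothetical stand-in for a dataset, for illustration only."""

    def __init__(self, samples, transform=None, target_transform=None):
        self.samples = samples                    # list of (feature, label) pairs
        self.transform = transform                # callable applied to the feature
        self.target_transform = target_transform  # callable applied to the label

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        feature, label = self.samples[idx]
        # Each callable is optional and is applied independently.
        if self.transform is not None:
            feature = self.transform(feature)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return feature, label


def scale_pixels(pixels):
    """Map 0-255 integer pixel values to floats in [0.0, 1.0]."""
    return [p / 255 for p in pixels]


tiny = TinyDataset(
    samples=[([0, 255, 51], 2)],
    transform=scale_pixels,
    target_transform=lambda y: y + 1,  # arbitrary change, only to show the hook
)
feature, label = tiny[0]
print(feature, label)
```

Any callable works in either slot: a plain function, a lambda, or an object that defines `__call__`.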
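The FashionMNIST features are in PIL Image format, and the labels are integers. For training, we need the features as normalized tensors and the labels as one-hot encoded tensors. To make these transformations, we use ToTensor and Lambda.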
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
Lambda Transforms
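Lambda transforms apply any user-defined lambda function. Here, we define a function that turns an integer into a one-hot encoded tensor. It first creates a zero tensor of size 10 (the number of labels in our dataset) and then calls scatter_, which assigns value=1 at the index given by the label y.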
target_transform = Lambda(
    lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)
)
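As a quick check of the one-hot logic, this sketch wraps the same expression in a helper and compares it with torch.nn.functional.one_hot. The helper name one_hot and the sample label 3 are chosen here for illustration only.

```python
import torch
import torch.nn.functional as F


def one_hot(y, num_classes=10):
    """Same expression as the Lambda above, wrapped in a named helper."""
    return torch.zeros(num_classes, dtype=torch.float).scatter_(
        dim=0, index=torch.tensor(y), value=1
    )


encoded = one_hot(3)
print(encoded)  # tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])

# The built-in helper returns an integer tensor, so cast it before comparing.
reference = F.one_hot(torch.tensor(3), num_classes=10).float()
print(torch.equal(encoded, reference))  # True
```

Only position 3 is set to 1, and every other entry stays 0.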