正文

处理数据样本的代码可能会逐渐变得混乱且难以维护;理想情况下,我们希望我们的数据集代码与我们的模型训练代码分离,以获得更好的可读性和模块化。PyTorch 提供了两个数据原语:torch.utils.data.DataLoadertorch.utils.data.Dataset 允许我们使用预加载的数据集以及自定义数据。 Dataset存储样本及其对应的标签,DataLoader封装了一个迭代器用于遍历Dataset,以便轻松访问样本数据。

PyTorch 领域库提供了许多预加载的数据集(例如 FashionMNIST),这些数据集继承自torch.utils.data.Dataset并实现了特定于特定数据的功能。它们可用于对您的模型进行原型设计和基准测试。你可以在这里找到它们:图像数据集、 文本数据集和 音频数据集

1. 加载数据集

下面是如何从 TorchVision 加载Fashion-MNIST数据集的示例。Fashion-MNIST 是 Zalando 文章图像的数据集,由 60,000 个训练示例和 10,000 个测试示例组成。每个示例都包含 28×28 灰度图像和来自 10 个类别之一的相关标签。

我们使用以下参数加载FashionMNIST 数据集:

  • root是存储训练/测试数据的路径,
  • train指定训练或测试数据集,
  • download=True如果数据不可用,则从 Internet 下载数据root
  • transformtarget_transform指定特征和标签转换
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
  0%|          | 0/26421880 [00:00<?, ?it/s]
  0%|          | 32768/26421880 [00:00<01:26, 303914.51it/s]
  0%|          | 65536/26421880 [00:00<01:27, 301769.74it/s]
  0%|          | 131072/26421880 [00:00<01:00, 437795.76it/s]
  1%|          | 229376/26421880 [00:00<00:42, 621347.43it/s]
  2%|1         | 491520/26421880 [00:00<00:20, 1259673.64it/s]
  4%|3         | 950272/26421880 [00:00<00:11, 2264911.11it/s]
  7%|7         | 1933312/26421880 [00:00<00:05, 4467299.81it/s]
 15%|#4        | 3833856/26421880 [00:00<00:02, 8587616.55it/s]
 26%|##6       | 6881280/26421880 [00:00<00:01, 14633777.99it/s]
 37%|###7      | 9830400/26421880 [00:01<00:00, 18150145.01it/s]
 49%|####8     | 12910592/26421880 [00:01<00:00, 21161097.17it/s]
 61%|######    | 16023552/26421880 [00:01<00:00, 23366004.89it/s]
 72%|#######2  | 19136512/26421880 [00:01<00:00, 24967488.10it/s]
 84%|########4 | 22249472/26421880 [00:01<00:00, 26016258.24it/s]
 95%|#########5| 25231360/26421880 [00:01<00:00, 26218488.24it/s]
100%|##########| 26421880/26421880 [00:01<00:00, 15984902.80it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
  0%|          | 0/29515 [00:00<?, ?it/s]
100%|##########| 29515/29515 [00:00<00:00, 268356.24it/s]
100%|##########| 29515/29515 [00:00<00:00, 266767.69it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
  0%|          | 0/4422102 [00:00<?, ?it/s]
  1%|          | 32768/4422102 [00:00<00:14, 302027.13it/s]
  1%|1         | 65536/4422102 [00:00<00:14, 300501.69it/s]
  3%|2         | 131072/4422102 [00:00<00:09, 436941.45it/s]
  5%|5         | 229376/4422102 [00:00<00:06, 619517.19it/s]
 10%|9         | 425984/4422102 [00:00<00:03, 1044158.55it/s]
 20%|##        | 884736/4422102 [00:00<00:01, 2114396.73it/s]
 40%|####      | 1769472/4422102 [00:00<00:00, 4067080.68it/s]
 80%|########  | 3538944/4422102 [00:00<00:00, 7919346.09it/s]
100%|##########| 4422102/4422102 [00:00<00:00, 5036535.17it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
  0%|          | 0/5148 [00:00<?, ?it/s]
100%|##########| 5148/5148 [00:00<00:00, 22168662.21it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

2. 迭代和可视化数据集

我们可以像python 列表一样索引Datasets,比如:

training_data[index]

我们用matplotlib来可视化训练数据中的一些样本。

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows   1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

3.创建自定义数据集

自定义 Dataset 类必须实现三个函数:initlen__和__getitem

比如: FashionMNIST 图像存储在一个目录img_dir中,它们的标签分别存储在一个 CSV 文件annotations_file中。

在接下来的部分中,我们将分析每个函数中发生的事情。

import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
    def __len__(self):
        return len(self.img_labels)
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image,/uploads/20230220/63daddd909e747db7c388155edcefa4f.jpg,/uploads/20230220/afca6341cf19dfc06534123e87f9ae97.jpg,/uploads/20230220/54386c70565510f85a61db765d5cad1c.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

3.2 __len__

len 函数返回我们数据集中的样本数。

例子:

def __len__(self):
    return len(self.img_labels)

3.3 __getitem__

getitem 函数从给定索引处的数据集中加载并返回一个样本idx。基于索引,它识别图像在磁盘上的位置,使用 将其转换为张量read_image,从 csv 数据中检索相应的标签self.img_labels,调用它们的转换函数(如果适用),并返回张量图像和相应的标签一个元组。

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label

4. 使用 DataLoaders 为训练准备数据

Dataset一次加载一个样本数据和其对应的label。在训练模型时,我们通常希望以minibatches“小批量”的形式传递样本,在每个 epoch 重新洗牌以减少模型过拟合,并使用 Pythonmultiprocessing加速数据检索。

DataLoader是一个可迭代对象,它封装了复杂性并暴漏了简单的API。

from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

5.遍历 DataLoader

我们已将该数据集加载到 DataLoader中,并且可以根据需要遍历数据集。下面的每次迭代都会返回一批train_featurestrain_labels(分别包含batch_size=64特征和标签)。因为我们指定shuffle=True了 ,所以在我们遍历所有批次之后,数据被打乱(为了更细粒度地控制数据加载顺序,请查看Samplers)。

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 4

以上就是python机器学习pytorch自定义数据加载器的详细内容,更多关于python pytorch自定义数据加载器的资料请关注Devmax其它相关文章!

python机器学习pytorch自定义数据加载器的更多相关文章

  1. XCode 3.2 Ruby和Python模板

    在xcode3.2下,我的ObjectiveCPython/Ruby项目仍然可以打开更新和编译,但是你无法创建新项目.鉴于xcode3.2中缺少ruby和python的所有痕迹(即创建项目并添加新的ruby/python文件),是否有一种简单的方法可以再次安装模板?我发现了一些关于将它们复制到某个文件夹的信息,但我似乎无法让它工作,我怀疑文件夹的位置已经改变为3.2.解决方法3.2中的应用程序模板

  2. Swift基本使用-函数和闭包(三)

    声明函数和其他脚本语言有相似的地方,比较明显的地方是声明函数的关键字swift也出现了Python中的组元,可以通过一个组元返回多个值。传递可变参数,函数以数组的形式获取参数swift中函数可以嵌套,被嵌套的函数可以访问外部函数的变量。可以通过函数的潜逃来重构过长或者太复杂的函数。

  3. 10 个Python中Pip的使用技巧分享

    众所周知,pip 可以安装、更新、卸载 Python 的第三方库,非常方便。本文小编为大家总结了Python中Pip的使用技巧,需要的可以参考一下

  4. Swift、Go、Julia与R能否挑战 Python 的王者地位

    本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请发送邮件至dio@foxmail.com举报,一经查实,本站将立刻删除。

  5. 红薯因 Swift 重写开源中国失败,貌似欲改用 Python

    本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请发送邮件至dio@foxmail.com举报,一经查实,本站将立刻删除。

  6. 你没看错:Swift可以直接调用Python函数库

    上周Perfect又推出了新一轮服务器端Swift增强函数库:Perfect-Python。对,你没看错,在服务器端Swift其实可以轻松从其他语种的函数库中直接拿来调用,不需要修改任何内容。以如下python脚本为例:Perfect-Python可以用下列方法封装并调用以上函数,您所需要注意的仅仅是其函数名称以及参数。

  7. Swift中的列表解析

    在Swift中完成这个的最简单的方法是什么?我在寻找类似的东西:从Swift2.x开始,有一些与你的Python样式列表解析相当的东西。(在这个意义上,它更像是Python的xrange。如果你想保持集合懒惰一路通过,只是这样说:与Python中的列表解析语法不同,Swift中的这些操作遵循与其他操作相同的语法。

  8. swift抛出终端的python错误

    每当我尝试启动与python相关的swift时,我都会收到错误.我该如何解决?

  9. 在Android上用Java嵌入Python

    解决方法看看this,它适用于J2SE,你可以尝试在Android上运行.

  10. Android中的自然语言处理API

    我正在尝试制作类似于thiswebsite的Android应用程序.问题是我对自然语言处理领域很陌生.我不希望实现太多,只是提供用户与应用程序的一些交互,给他一种感觉,他确实在与某人聊天.基本上,我只是捕获用户输入的文本并将其发送到API并显示从API检索的结果.我遇到了http://opennlp.apache.org/和http://gate.ac.uk/,但不知道如何在我的Android应用

随机推荐

  1. 10 个Python中Pip的使用技巧分享

    众所周知,pip 可以安装、更新、卸载 Python 的第三方库,非常方便。本文小编为大家总结了Python中Pip的使用技巧,需要的可以参考一下

  2. python数学建模之三大模型与十大常用算法详情

    这篇文章主要介绍了python数学建模之三大模型与十大常用算法详情,文章围绕主题展开详细的内容介绍,具有一定的参考价值,感想取得小伙伴可以参考一下

  3. Python爬取奶茶店数据分析哪家最好喝以及性价比

    这篇文章主要介绍了用Python告诉你奶茶哪家最好喝性价比最高,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习吧

  4. 使用pyinstaller打包.exe文件的详细教程

    PyInstaller是一个跨平台的Python应用打包工具,能够把 Python 脚本及其所在的 Python 解释器打包成可执行文件,下面这篇文章主要给大家介绍了关于使用pyinstaller打包.exe文件的相关资料,需要的朋友可以参考下

  5. 基于Python实现射击小游戏的制作

    这篇文章主要介绍了如何利用Python制作一个自己专属的第一人称射击小游戏,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起动手试一试

  6. Python list append方法之给列表追加元素

    这篇文章主要介绍了Python list append方法如何给列表追加元素,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

  7. Pytest+Request+Allure+Jenkins实现接口自动化

    这篇文章介绍了Pytest+Request+Allure+Jenkins实现接口自动化的方法,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

  8. 利用python实现简单的情感分析实例教程

    商品评论挖掘、电影推荐、股市预测……情感分析大有用武之地,下面这篇文章主要给大家介绍了关于利用python实现简单的情感分析的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下

  9. 利用Python上传日志并监控告警的方法详解

    这篇文章将详细为大家介绍如何通过阿里云日志服务搭建一套通过Python上传日志、配置日志告警的监控服务,感兴趣的小伙伴可以了解一下

  10. Pycharm中运行程序在Python console中执行,不是直接Run问题

    这篇文章主要介绍了Pycharm中运行程序在Python console中执行,不是直接Run问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

返回
顶部