您的位置:首页 > 其它

利用torch.utils.data.Dataset自定义数据加载类

2020-04-24 10:16 351 查看
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np

import torchvision.transforms as T
transforms = T.Compose([
  T.Resize(224),
  T.CenterCrop(224),
  T.ToTensor(),
  T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
# 继承Dataset类要重写__getitem__()和__len__()
class CatDog(data.Dataset):
  def __init__(self, root, transforms=None):

    imgs = os.listdir(root)
    self.imgs = [os.path.join(root, img) for img in imgs]

    self.transforms = transforms
  def __getitem__(self, index):
    label = 1 if dog else 0
    data = Image.open(self.imgs[index])
    if self.transform:
      data = self.transform(data)
    return data, label
  def __len__(self):
    return len(self.imgs)
  • 点赞 1
  • 收藏
  • 分享
  • 文章举报
枫叶 发布了21 篇原创文章 · 获赞 3 · 访问量 322 私信 关注
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐