您的位置:首页 > 其它

关于Pytorch中Dataset和Dataloader的理解

2019-06-23 17:05 696 查看

使用Pytorch自定义读取数据时步骤如下:
1)创建Dataset对象
2)将Dataset对象作为参数传递到Dataloader中

详述步骤1)创建Dataset对象:
需要编写继承Dataset的类,并且覆写__getitem__和__len__方法,代码如下

class dataset(Dataset):
def process(self):
#对数据进行处理
pass

def __getitem__(self,index):
pass

def __len__(self):
#返回数据的长度
pass

(1)其中__getitem__函数的作用是根据索引index遍历数据
(2)__len__函数的作用是返回数据集的长度
(3)在创建的dataset类中可根据自己的需求对数据进行处理。可编写独立的数据处理函数,在__getitem__函数中进行调用,例如上述代码片段中的process函数;或者直接将数据处理方法写在__getitem__函数中。

详述步骤2)创建将Dataset对象作为参数传递到Dataloader中:
只需要将步骤1)创建的Dataset对象作为参数传递到Dataloader中,代码如下:

#创建对象
dataset_object = dataset()
#将dataset_object传递到Dataloader中
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
num_workers=0, collate_fn=default_collate,
pin_memory=False,
drop_last=False)

需要注意的是,Dataloader中存在一个默认的collate_fn函数,需要根据自己的需求重写collate_fn函数:
(1)该函数的作用是将数据整理成一个batch,即根据batch_size的大小一次性在数据集中取出batch_size个数据。例如数据集中有4条数据,batch_size的值为2,则每次在4条数据中取出2条数据。
(2)collate_fn函数的输入是一个list,list中的每个元素为自己编写的dataset类中__getitem__函数的返回值。
(3)Dataloader中drop_last代表将不足一个batch_size的数据是否保留,即假如有4条数据,batch_size的值为3,将取出一个batch_size之后剩余的1条数据是否仍然作为训练数据。

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: