Pytorch中自定义Dataset和Dataloader总结

  2019-4-5 


Dataset自定义数据集与DataLoader 数据集加载器

Dataset定义了数据集的存在,其中包含了每个数据的路径和标签,dataset类型不是拿来保存的!它只是相当于一个数据集索引器!并没有加载数据或封装数据!

Dataloader则定义了指定数据集的加载方法

在对Dataloader进行迭代的时候,Dataset指向的数据集才正式被加载

Dataset

Dataset 定义并加载自己的数据集
继承自Dataset必须重写三个函数
__init__(self, ):将任何需要的初始化,比如读取全部数据路径等,不限制参数个数
__getitem__(self,index):根据每次调用时的index返回对应元素和标签
__len__(self):负责返回数据集中的元素个数

以加载图片数据集为例

from torch.utils.data import Dataset
#结构
class myDataset(Dataset):
    #self, 代表可以补充其他自定参数
    def __init__(self,  ):
    pass
    def __len__(self):
        #返回最大长度
        return len
    def __getitem__(self,index):
        #返回每次应读取的单个数据
        return data,label

#例子
class myDataset(Dataset):
    def __init__(self,root,transform=None):
        # 所有图片的绝对路径
        imgs=os.listdir(root)
        #这句话可以使用glob快速加载 见66.
        self.imgs=[os.path.join(root,k) for k in imgs]
        self.transforms=transform

    def __getitem__(self, index):
        img_path = self.imgs[index]
        pil_img = Image.open(img_path)
        pil_img = pil_img.convert("RGB")
    if self.transforms:
            data = self.transforms(pil_img)
        else:
        pil_img = np.asarray(pil_img)
            data = torch.from_numpy(pil_img)
            label = xxxxx(这里省略,总之是得到这个图的标签)
        return data,label

    def __len__(self):
        return len(self.imgs)

#创建数据集实例并初始化
dataSet=FlameSet('./test',transform = transform)
#依然用Dataloader加载数据集
data = torch.utils.data.DataLoader(myDataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=0)

DataLoader

DataLoader将数据集对象和不同的取样器联合,如SequentitalSamplerRandomSampler,并使用单进程或多线程的迭代器,为我们提供批量数据。取样器是为算法提供数据的不同策略。
dataloader = torch.utils.data.DataLoader(trainSet,batch_size=BATCH_SIZE,shuffle=True,num_workers=0,collate_fn=fn)
Dataloader的储存数据形式是一个batch一个batch存,取也是一个batch一个batch取,每组数据的内容分别为一组batch的input和该组batch每个数据对应的label

for data in phase:
    inputs,label = data

注意: datalodaer.__len__()得到的是batch(分组)总数而不是数据总数
dataset.__len__得到的是数据总数

自定义DataLoader:

collate_fn是非常重要的,如nlp中经常在里面做padding
collate_fn是自定义函数来设计数据收集的方式,意思是已经通过上面的Dataset类中的getitem函数采样了batch_size数据,以一个包的形式传递给collate_fn所指定的函数
collate_fn的输入是一个list,list的长度是一个batch_size,list中的每个u岸数都是getitem得到的结果

以我一次项目中加载NLP数据集写的代码为例:

# 生成datasets
# 不进行转置,直接在网络中输入的时候调用RNN的函数进行转置
class textDataset(dataset.Dataset):
    def __init__(self,data,label):
        self.data = data
        self.label = label
    def __len__(self):
        return len(self.data)
    #在取数据的时候调用
    def __getitem__(self,index):
        data = torch.Tensor(self.data[index])
        # 这里是在getitem的时候才会转化为long
        #注意这里在转tensor的时候必须得是list格式,然后转成了之后再取出第0项即为单tensor元素
        label = torch.LongTensor([self.label[index]])[0]
        return data,label
    
def preprocess(batch_data):
    max_seq_size = 30 #最大句长 (看清楚这里处理的句子是,分割or。分割)
    vec_size = 200
    #经过dataloader调用__getitem__方法从dataset取得的数据是这样的形式(dataset[0],dataset[1],xxxx)一个batch长度,
    #也就是相当于((data0,label0),(data1,label1),xxx),所以用zip分别提取出来形成(data,label)
    #zip(*zipped) 是逆zip    
    datas,labels = zip(*batch_data)
    
    corpus_vec_resize = list()
    #遍历每一行(每一个seq)
    for corpus_vec in datas:
        #这里对每个data中的句子做padding   #【也可以使用掩码的方式做padding,生成掩码矩阵 应该有更简单的方法或工具】
        #将处理后的全部转为Tensor,如果转成narray再弄会出错
        #若长度过小,则从句尾开始填充
        if len(corpus_vec) < max_seq_size:
            pad_num = max_seq_size-len(corpus_vec) 
            padding = list([[0]*vec_size]*pad_num)
            seq_padded = np.append(corpus_vec,padding,axis=0)
            seq_padded = torch.Tensor(seq_padded)
            corpus_vec_resize.append(seq_padded)
            #若长度过长,则依次从句头和句末开始修剪
        if len(corpus_vec) > max_seq_size:
            cut_num = len(corpus_vec)-max_seq_size
            #句头剪cut_num//2 句尾剪cut-num-cut_num//2
            seq_cut = corpus_vec[cut_num//2:len(corpus_vec)-(cut_num-cut_num//2)] 
            seq_cut = torch.Tensor(seq_cut)
            corpus_vec_resize.append(seq_cut)
        if len(corpus_vec) == max_seq_size:
            corpus_vec = torch.Tensor(corpus_vec)
            corpus_vec_resize.append(corpus_vec)   
    #将datas转变为tensor   
    datas = torch.stack(corpus_vec_resize) 

    #合并处理labels
    label_res = list()
    for i in labels:
        #处理打错的标签(非0和1)
        if i.item()!=0 and i.item()!=1:
            label_res.append(1)
        else:
            label_res.append(i.item())
    labels = torch.LongTensor(label_res)
#     print(labels)
return datas,labels
#封装与加载
corpus_dataset = textDataset(data=corpus_vec,label=labels)
#这里定义了加载的具体方法:preprocess
textDataloader = dataloader.DataLoader(corpus_dataset,batch_size=batch_size,shuffle=False,num_workers=1,collate_fn=preprocess)

dataloader参数的更多详细资料还可以参考下面资料:

https://www.jianshu.com/p/8ea7fba72673
https://www.jianshu.com/p/bb90bff9f6e5

https://blog.csdn.net/weixin_30241919/article/details/95184794

dataloader的num_work子线程尽量数量开多点!不要让大量时间用去读数据.一般子线程设和核的个数相同

getitem()

如果类中定义了getitem()方法,那么它的实例对象(假设为P)就可以P[Key]这样取值。当实例对象做P[Key]运算的时候就会调用getitem()方法,返回的就是这个方法的return

dataset里有这个方法,当dataloader加载dataset的时候就会自动调用


且听风吟