Verified Commit c791a17b authored by Yaoyao Liu's avatar Yaoyao Liu
Browse files

Update dataloader

parent b93f4817
import os.path as osp
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
from os.path import expanduser
class DatasetLoader(Dataset):
    """Image-folder dataset for the ``cifar_fs`` train / val / test splits.

    The dataset root is ``<args.datadir>/cifar_fs`` when ``args.datadir`` is
    set, otherwise ``~/dataset/cifar_fs``.  Each split lives in its own
    sub-directory (``meta-train``, ``meta-val``, ``meta-test``) that contains
    one folder per class; every file inside a class folder is one sample.

    Attributes:
        data: list of image file paths.
        label: list of integer class indices, parallel to ``data``.
        num_class: number of distinct class indices in ``label``.
        transform: torchvision transform applied to every loaded image.
    """

    # Public split name -> sub-directory of the dataset root that holds it.
    _SPLIT_DIRS = {
        'train': 'meta-train',
        'test': 'meta-test',
        'val': 'meta-val',
    }

    def __init__(self, setname, args=None):
        """Index the image files of one split and build its transform.

        Args:
            setname: one of ``'train'``, ``'val'`` or ``'test'``.
            args: optional object with a ``datadir`` attribute naming the
                directory that contains ``cifar_fs``.  When ``args`` is
                ``None`` or has no ``datadir``, the home directory default
                is used.

        Raises:
            ValueError: if ``setname`` is not a known split.
        """
        # ``args`` defaults to None, so it must not be dereferenced blindly.
        datadir = getattr(args, 'datadir', None)
        if datadir is None:
            dataset_dir = os.path.join(expanduser("~"), 'dataset/cifar_fs')
        else:
            dataset_dir = os.path.join(datadir, 'cifar_fs')
        print('****input with 84******')

        # Set the path according to train, val and test
        if setname not in self._SPLIT_DIRS:
            raise ValueError('Wrong setname.')
        split_path = osp.join(dataset_dir, self._SPLIT_DIRS[setname])

        # Set data, label and class number to be accessible from outside
        self.data, self.label = self._list_samples(split_path)
        self.num_class = len(set(self.label))

        # Transformation: random augmentation for training, deterministic
        # resize + center crop otherwise; both end at image_size x image_size.
        image_size = 84
        normalize = transforms.Normalize(
            (0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023))
        if setname == 'train':
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize])
        else:
            self.transform = transforms.Compose([
                transforms.Resize([92, 92]),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                normalize])

    @staticmethod
    def _list_samples(root):
        """Return ``(paths, labels)`` for every file in each class folder of *root*.

        Folders and files are visited in sorted order: ``os.listdir`` makes
        no ordering guarantee, and the class index is derived from the
        folder position, so sorting keeps labels reproducible.  Entries of
        *root* that are not directories are ignored.
        """
        folders = sorted(
            osp.join(root, name)
            for name in os.listdir(root)
            if osp.isdir(osp.join(root, name)))
        paths = []
        labels = []
        for idx, folder in enumerate(folders):
            for file_name in sorted(os.listdir(folder)):
                paths.append(osp.join(folder, file_name))
                labels.append(idx)
        return paths, labels

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data)

    def __getitem__(self, i):
        """Load sample *i* and return ``(transformed_image, label)``."""
        path, label = self.data[i], self.label[i]
        # Close the underlying file as soon as the pixels are converted.
        with Image.open(path) as img:
            image = self.transform(img.convert('RGB'))
        return image, label
if __name__ == '__main__':
    # Manual smoke test: index the training split from the default location
    # and report its size.
    from argparse import Namespace

    # The constructor reads ``args.datadir``; pass an explicit namespace
    # with datadir=None so the home-directory default is selected.
    dataset = DatasetLoader('train', Namespace(datadir=None))
    print('num image in this set:', len(dataset))
    print('num class in this set:', len(np.unique(dataset.label)))
    # Counts recorded by the original author (TODO confirm which dataset
    # they refer to):
    # test 160 class, 206209 image
    # val 97 class, 124261 image
    # train 351 class, 448695 image
    # Load one sample to check that decoding and the transform work.
    dataset[0]
\ No newline at end of file
...@@ -20,8 +20,7 @@ class DatasetLoader(Dataset): ...@@ -20,8 +20,7 @@ class DatasetLoader(Dataset):
TRAIN_PATH = 'data/FC100/train' TRAIN_PATH = 'data/FC100/train'
VAL_PATH = 'data/FC100/val' VAL_PATH = 'data/FC100/val'
TEST_PATH = 'data/FC100/test' TEST_PATH = 'data/FC100/test'
# Set the path according to train, val and test
if setname=='train': if setname=='train':
THE_PATH = TRAIN_PATH THE_PATH = TRAIN_PATH
label_list = os.listdir(THE_PATH) label_list = os.listdir(THE_PATH)
...@@ -33,27 +32,22 @@ class DatasetLoader(Dataset): ...@@ -33,27 +32,22 @@ class DatasetLoader(Dataset):
label_list = os.listdir(THE_PATH) label_list = os.listdir(THE_PATH)
else: else:
raise ValueError('Wrong setname.') raise ValueError('Wrong setname.')
# Generate empty list for data and label
data = [] data = []
label = [] label = []
# Get folders' name
folders = [osp.join(THE_PATH, label) for label in label_list if os.path.isdir(osp.join(THE_PATH, label))] folders = [osp.join(THE_PATH, label) for label in label_list if os.path.isdir(osp.join(THE_PATH, label))]
# Get the images' paths and labels
for idx, this_folder in enumerate(folders): for idx, this_folder in enumerate(folders):
this_folder_images = os.listdir(this_folder) this_folder_images = os.listdir(this_folder)
for image_path in this_folder_images: for image_path in this_folder_images:
data.append(osp.join(this_folder, image_path)) data.append(osp.join(this_folder, image_path))
label.append(idx) label.append(idx)
# Set data, label and class number to be accessable from outside
self.data = data self.data = data
self.label = label self.label = label
self.num_class = len(set(label)) self.num_class = len(set(label))
# Transformation
if setname == 'train': if setname == 'train':
image_size = 84 image_size = 84
self.transform = transforms.Compose([ self.transform = transforms.Compose([
...@@ -79,13 +73,3 @@ class DatasetLoader(Dataset): ...@@ -79,13 +73,3 @@ class DatasetLoader(Dataset):
path, label = self.data[i], self.label[i] path, label = self.data[i], self.label[i]
image = self.transform(Image.open(path).convert('RGB')) image = self.transform(Image.open(path).convert('RGB'))
return image, label return image, label
if __name__=='__main__':
dataset=DatasetLoader('train')
print ('num image in this set:',dataset.__len__())
print ('num class in this set:',np.unique(dataset.label).__len__())
#test 160 class,206209 image
#val 97 class, 124261 image
#train 351 class ,448695 image
a=dataset.__getitem__(0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment