Verified Commit 12350df5 authored by Yaoyao Liu

Update transductive branch

parent fc05190d
*.pyc
*.npy
*.tar
run.sh
bashrc
miniimagenet
data/miniImageNet
ckpts/miniImageNet
cache
logs
data
pretrain_model
pretrain_model_1
*.pyc
__pycache__
# An Ensemble of Epoch-wise Empirical Bayes for Few-shot Learning
## Transductive Experiments
[![LICENSE](https://img.shields.io/github/license/yaoyao-liu/E3BM?style=flat-square)](https://github.com/yaoyao-liu/E3BM/blob/master/LICENSE)
[![Python](https://img.shields.io/badge/python-3.6-blue.svg?style=flat-square)](https://www.python.org/)
[![PyTorch](https://img.shields.io/badge/pytorch-1.2.0-%237732a8?style=flat-square)](https://pytorch.org/)
[![CodeFactor](https://img.shields.io/codefactor/grade/github/yaoyao-liu/E3BM/inductive?style=flat-square)](https://www.codefactor.io/repository/github/yaoyao-liu/e3bm)
This repository contains the PyTorch implementation for the paper "[An Ensemble of Epoch-wise Empirical Bayes for Few-shot Learning](https://arxiv.org/pdf/1904.08479)". If you have any questions about this repository or the related paper, feel free to [create an issue](https://github.com/yaoyao-liu/E3BM/issues/new) or send me an email.
<br>
Email address: yaoyao.liu (at) mpi-inf.mpg.de
#### Summary
* [Introduction](#introduction)
* [Installation](#installation)
* [Inductive Experiments](#inductive-experiments)
* [Transductive Experiments](#transductive-experiments)
* [Citation](#citation)
* [Acknowledgements](#acknowledgements)
## Introduction
Few-shot learning aims to train efficient predictive models with a few examples. The lack of training data leads to poor models that make high-variance or low-confidence predictions. In this paper, we propose to meta-learn the ensemble of epoch-wise empirical Bayes models (E<sup>3</sup>BM) to achieve robust predictions. "Epoch-wise" means that each training epoch has a Bayes model whose parameters are specifically learned and deployed. "Empirical" means that the hyperparameters, e.g., those used for learning and ensembling the epoch-wise models, are generated by hyperprior learners conditional on task-specific data. We introduce four kinds of hyperprior learners by considering inductive vs. transductive, and epoch-dependent vs. epoch-independent, in the paradigm of meta-learning. We conduct extensive experiments for five-class few-shot tasks on three challenging benchmarks: miniImageNet, tieredImageNet, and FC100, and achieve top performance using the epoch-dependent transductive hyperprior learner, which captures the richest information. Our ablation study shows that both "epoch-wise ensemble" and "empirical" encourage high efficiency and robustness in the model performance.
<p align="center">
<img src="https://yyliu.net/images/misc/e3bm.png" width="800"/>
</p>
> Figure: Conceptual illustrations of the model adaptation on the blue, red and yellow tasks. (a) MAML is the classical inductive method that meta-learns a network initialization θ that is used to learn a single base-learner on each task. (b) SIB is a transductive method that formulates a variational posterior as a function of both labeled training data T<sup>(tr)</sup> and unlabeled test data x<sup>(te)</sup>. It also uses a single base-learner and optimizes the learner by running several synthetic gradient steps on x<sup>(te)</sup>. (c) Our E<sup>3</sup>BM is a generic method that learns to combine the epoch-wise base-learners, and to generate task-specific learning rates α and combination weights v that encourage robust adaptation.
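
To make the adaptation scheme concrete, the following is a minimal, self-contained sketch of the epoch-wise ensemble idea with a toy linear base-learner. It is not the code in this repository; the names `theta`, `alphas`, and `v` are hypothetical stand-ins for the meta-learned initialization, the generated learning rates, and the generated combination weights.

```python
import torch
import torch.nn.functional as F

def e3bm_adapt_and_predict(theta, x_train, y_train, x_query, alphas, v):
    # theta: [d, n_way] linear base-learner initialization (toy stand-in for a network)
    # alphas: one learning rate per adaptation epoch; v: one combination weight per epoch
    w = theta
    base_logits = []
    for alpha in alphas:                              # one base-learner per epoch
        loss = F.cross_entropy(x_train @ w, y_train)  # fit the support set
        (grad,) = torch.autograd.grad(loss, w, create_graph=True)
        w = w - alpha * grad                          # update with this epoch's learning rate
        base_logits.append(x_query @ w)               # keep this epoch's query prediction
    weights = torch.softmax(v, dim=0)                 # normalize the combination weights
    return sum(wm * lm for wm, lm in zip(weights, base_logits))

# Toy usage: one 5-way task, 16-dim features, 3 adaptation epochs.
theta = torch.randn(16, 5, requires_grad=True)
x_tr, y_tr = torch.randn(25, 16), torch.arange(5).repeat(5)
x_q = torch.randn(75, 16)
pred = e3bm_adapt_and_predict(theta, x_tr, y_tr, x_q, alphas=[0.1, 0.05, 0.01], v=torch.ones(3))
print(pred.shape)  # torch.Size([75, 5]) -- ensemble logits over the query set
```

In the paper, `alphas` and `v` are produced per task by the hyperprior learners (inductively from the support set, or transductively from the unlabeled query images as well), rather than fixed as above.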
### Installation
In order to run this repository, we advise you to install Python 3.6 and PyTorch 1.2.0 with Anaconda.
You may download Anaconda and read the installation instructions on their official website:
<https://www.anaconda.com/download/>
Create a new environment and install PyTorch and torchvision in it:
```bash
conda create --name e3bm-pytorch python=3.6
conda activate e3bm-pytorch
conda install pytorch=1.2.0 -c pytorch
conda install torchvision -c pytorch
```
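You can optionally run a quick sanity check (not part of the repository) to confirm the versions:
```python
import torch, torchvision
print(torch.__version__, torchvision.__version__)  # expect 1.2.0 and a matching torchvision
print(torch.cuda.is_available())                   # True if a CUDA-enabled build is installed
```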
Install other requirements:
```bash
pip install -r requirements.txt
```
Run meta-training with default settings (data and the pre-trained model will be downloaded automatically):
```bash
python main.py --phase_sib=meta_train
```
### Inductive Experiments
#### Performance (ResNet-12)
Experiment results for 5-way few-shot classification with a ResNet-12 backbone (the same backbone as in [this repository](https://github.com/kjunelee/MetaOptNet)).
| Method | Backbone |𝑚𝑖𝑛𝑖 1-shot | 𝑚𝑖𝑛𝑖 5-shot | 𝒕𝒊𝒆𝒓𝒆𝒅 1-shot | 𝒕𝒊𝒆𝒓𝒆𝒅 5-shot |
| -------------- |---------- | ---------- | ---------- |------------ | ------------ |
| [`ProtoNet`](https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch) | ResNet-12 |`60.37 ± 0.83` | `78.02 ± 0.57` | `65.65 ± 0.92` | `83.40 ± 0.65` |
| [`MatchNet`](https://github.com/gitabcworld/MatchingNetworks) | ResNet-12 |`63.08 ± 0.80` | `75.99 ± 0.60` | `68.50 ± 0.92` | `80.60 ± 0.71` |
| [`MetaOptNet`](https://github.com/kjunelee/MetaOptNet) | ResNet-12 |`62.64 ± 0.61` | `78.63 ± 0.46` | `65.99 ± 0.72` | `81.56 ± 0.53` |
| [`Meta-Baseline`](https://github.com/cyvius96/few-shot-meta-baseline) | ResNet-12 |`63.17 ± 0.23` | `79.26 ± 0.17` | `68.62 ± 0.27` | `83.29 ± 0.18` |
| [`CAN`](https://github.com/blue-blue272/fewshot-CAN) | ResNet-12 |`63.85 ± 0.48` | `79.44 ± 0.34` | `69.89 ± 0.51` | `84.93 ± 0.38` |
| `E3BM (Ours)` | ResNet-12 |`64.09 ± 0.37` | `80.29 ± 0.25` | `71.34 ± 0.41` | `85.82 ± 0.29` |
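The ± values are 95% confidence intervals over the test episodes. For reference, here is a minimal sketch of how such an interval is computed from per-episode accuracies (a common protocol; this is not this repository's evaluation code):
```python
import numpy as np

def mean_ci95(episode_accs):
    # Mean accuracy and 95% confidence interval over few-shot test episodes.
    a = np.asarray(episode_accs)
    ci95 = 1.96 * a.std(ddof=1) / np.sqrt(len(a))
    return a.mean(), ci95
```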
#### Running experiments
Run meta-training with default settings:
```bash
python main.py -backbone resnet12 -shot 1 -way 5 -mode meta_train -dataset miniimagenet
python main.py -backbone resnet12 -shot 5 -way 5 -mode meta_train -dataset miniimagenet
python main.py -backbone resnet12 -shot 1 -way 5 -mode meta_train -dataset tieredimagenet
python main.py -backbone resnet12 -shot 5 -way 5 -mode meta_train -dataset tieredimagenet
```
Run meta-test with our checkpoint (data and the checkpoint will be downloaded automatically):
```bash
python main.py --phase_sib=meta_eval
```
Run pre-training with default settings:
```bash
python main.py -backbone resnet12 -mode pre_train -dataset miniimagenet
python main.py -backbone resnet12 -mode pre_train -dataset tieredimagenet
```
Run meta-test with other checkpoints:
```bash
python main.py --phase_sib=meta_eval --meta_eval_load_path=<your_ckpt_dir>
```
#### Download resources
All the datasets and pre-trained models will be downloaded automatically.
You may also download the resources on Google Drive using the following links:
<br>
Dataset: [miniImageNet](https://drive.google.com/file/d/1vv3m14kusJcRpCsG-brG_Xk9MnetY9Bt/view?usp=sharing), and [tieredImageNet](https://drive.google.com/file/d/1T-4NVTSa5T6CXKSRbymYLnWp_OrtF-mo/view?usp=sharing)
<br>
Pre-trained models: [Google Drive](https://drive.google.com/file/d/13pzlvn9s4psbZlGpIsYCi9fwQnWeSIkP/view?usp=sharing)
<br>
Meta-trained checkpoints: [Google Drive](https://drive.google.com/drive/folders/17qTMpovfgEV6mRi8M4FkMYLIfBm3smgc?usp=sharing)
### Transductive Experiments
See the transductive setting experiments in this branch: <https://github.com/yaoyao-liu/E3BM/tree/transductive>.
### Citation
Please cite our paper if it is helpful to your work:
```bibtex
@inproceedings{Liu2020E3BM,
  author    = {Yaoyao Liu and
               Bernt Schiele and
               Qianru Sun},
  title     = {An Ensemble of Epoch-wise Empirical Bayes for Few-shot Learning},
  booktitle = {European Conference on Computer Vision (ECCV)},
  year      = {2020}
}
```
### Acknowledgements
Our implementations use the source code from the following repositories:
* [Learning Embedding Adaptation for Few-Shot Learning](https://github.com/Sha-Lab/FEAT)
* [Empirical Bayes Transductive Meta-Learning with Synthetic Gradients](https://github.com/hushell/sib_meta_learn)
* [DeepEMD: Differentiable Earth Mover's Distance for Few-Shot Learning](https://github.com/icoz69/DeepEMD)
# Download the pre-trained feature network checkpoint (netFeatBest.pth) from Google Drive.
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1bL_l3OtOM-bYlIDpsODVmbfhbXhA2i6X' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1bL_l3OtOM-bYlIDpsODVmbfhbXhA2i6X" -O netFeatBest.pth && rm -rf /tmp/cookies.txt
mkdir -p ckpts/miniImageNet
mv netFeatBest.pth ckpts/miniImageNet/

# Download the E3BM meta-trained checkpoint.
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1vEFTrtcShLQcm8AqEf-a6EifOXY6LM_5' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1vEFTrtcShLQcm8AqEf-a6EifOXY6LM_5" -O e3bm_ckpt.pth && rm -rf /tmp/cookies.txt
mv e3bm_ckpt.pth ckpts/miniImageNet/

# Download and unpack the miniImageNet images; drop the unused train_val/train_test
# splits and rename train_train to train.
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1FzkXVCIA8VbhcOoKQPJKX5OVe7Q0aFKp' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1FzkXVCIA8VbhcOoKQPJKX5OVe7Q0aFKp" -O Mini-ImageNet.zip && rm -rf /tmp/cookies.txt
unzip Mini-ImageNet.zip
rm Mini-ImageNet.zip
rm -r Mini-ImageNet/train_val Mini-ImageNet/train_test
mv Mini-ImageNet/train_train Mini-ImageNet/train
mv Mini-ImageNet data/miniImageNet

# Download the fixed validation-episode files (1000 episodes each, 5-way 5-shot and 1-shot).
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1ia1mUe9OVIVfs_ggjiycBS6ttQd8Ke1i' -O val1000Episode_5_way_5_shot.json
mv val1000Episode_5_way_5_shot.json data/miniImageNet/
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1G206V-8_Ls5QH05KKFtuWXmJ3iv-MRQn' -O val1000Episode_5_way_1_shot.json
mv val1000Episode_5_way_1_shot.json data/miniImageNet/
# Copyright (c) 2020 Yaoyao Liu. All Rights Reserved.
# Some files of this repository are modified from https://github.com/hushell/sib_meta_learn
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================
import os.path as osp
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
from os.path import expanduser


class DatasetLoader(Dataset):
    """The class to load the CIFAR-FS dataset."""
    def __init__(self, setname, args=None):
        if args is None or args.datadir is None:
            home_dir = expanduser("~")
            DATASET_DIR = os.path.join(home_dir, 'dataset/cifar_fs')
        else:
            DATASET_DIR = os.path.join(args.datadir, 'cifar_fs')
        print('****input with 84******')
        # Set the path according to train, val and test
        if setname == 'train':
            THE_PATH = osp.join(DATASET_DIR, 'meta-train')
            label_list = os.listdir(THE_PATH)
        elif setname == 'test':
            THE_PATH = osp.join(DATASET_DIR, 'meta-test')
            label_list = os.listdir(THE_PATH)
        elif setname == 'val':
            THE_PATH = osp.join(DATASET_DIR, 'meta-val')
            label_list = os.listdir(THE_PATH)
        else:
            raise ValueError('Wrong setname.')
        # Generate empty lists for data and label
        data = []
        label = []
        # Get the folders' names
        folders = [osp.join(THE_PATH, label) for label in label_list if os.path.isdir(osp.join(THE_PATH, label))]
        # Get the images' paths and labels
        for idx, this_folder in enumerate(folders):
            this_folder_images = os.listdir(this_folder)
            for image_path in this_folder_images:
                data.append(osp.join(this_folder, image_path))
                label.append(idx)
        # Set data, label and class number to be accessible from outside
        self.data = data
        self.label = label
        self.num_class = len(set(label))
        # Transformation
        if setname == 'train':
            image_size = 84
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023))])
            # transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
            #                      np.array([x / 255.0 for x in [63.0, 62.1, 66.7]]))])
        else:
            image_size = 84
            self.transform = transforms.Compose([
                transforms.Resize([92, 92]),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023))])
            # transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
            #                      np.array([x / 255.0 for x in [63.0, 62.1, 66.7]]))])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path, label = self.data[i], self.label[i]
        image = self.transform(Image.open(path).convert('RGB'))
        return image, label


if __name__ == '__main__':
    dataset = DatasetLoader('train')
    print('num image in this set:', dataset.__len__())
    print('num class in this set:', np.unique(dataset.label).__len__())
    # test 160 class, 206209 image
    # val 97 class, 124261 image
    # train 351 class, 448695 image
    a = dataset.__getitem__(0)
# Copyright (c) 2020 Yaoyao Liu. All Rights Reserved.
# Some files of this repository are modified from https://github.com/hushell/sib_meta_learn
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================
import os.path as osp
import PIL
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms

THIS_PATH = osp.dirname(__file__)
ROOT_PATH = osp.abspath(osp.join(THIS_PATH, '..', '..'))
IMAGE_PATH = osp.join(ROOT_PATH, 'data/cub/images')
SPLIT_PATH = osp.join(ROOT_PATH, 'data/cub/split')


class CUB(Dataset):
    def __init__(self, setname, args):
        txt_path = osp.join(SPLIT_PATH, setname + '.csv')
        with open(txt_path, 'r') as f:
            lines = [x.strip() for x in f.readlines()][1:]
        data = []
        label = []
        lb = -1
        self.wnids = []
        for l in lines:
            context = l.split(',')
            name = context[0]
            wnid = context[1]
            path = osp.join(IMAGE_PATH, name)
            if wnid not in self.wnids:
                self.wnids.append(wnid)
                lb += 1
            data.append(path)
            label.append(lb)
        self.data = data
        self.label = label
        self.num_class = np.unique(np.array(label)).shape[0]
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        self.transform = transforms.Compose([
            transforms.Resize(84, interpolation=PIL.Image.BICUBIC),
            transforms.CenterCrop(84),
            transforms.ToTensor(),
            normalize])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path, label = self.data[i], self.label[i]
        image = self.transform(Image.open(path).convert('RGB'))
        return image, label
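

# Minimal usage sketch (an assumption for illustration, not part of the original file):
# the loader expects data/cub/images and data/cub/split/<setname>.csv to exist, and
# `args` is unused here, so None suffices for a quick test.
#     dataset = CUB('val', None)
#     image, label = dataset[0]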
# Copyright (c) 2020 Yaoyao Liu. All Rights Reserved.
# Some files of this repository are modified from https://github.com/hushell/sib_meta_learn
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================
import numpy as np
import torchvision.transforms as transforms


def dataset_setting(dataset, nSupport):
    """Return the transforms, input size, data paths and class count for the given dataset."""
    if dataset == 'MiniImageNet':
        mean = [x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]
        std = [x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]]
        normalize = transforms.Normalize(mean=mean, std=std)
        trainTransform = transforms.Compose([transforms.RandomCrop(80, padding=8),
                                             transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                                             transforms.RandomHorizontalFlip(),
                                             lambda x: np.asarray(x),
                                             transforms.ToTensor(),
                                             normalize
                                             ])
        valTransform = transforms.Compose([transforms.CenterCrop(80),
                                           lambda x: np.asarray(x),
                                           transforms.ToTensor(),
                                           normalize])
        inputW, inputH, nbCls = 80, 80, 64
        trainDir = './data/miniImageNet/train/'
        valDir = './data/miniImageNet/val/'
        testDir = './data/miniImageNet/test/'
        episodeJson = './data/miniImageNet/val1000Episode_5_way_1_shot.json' if nSupport == 1 \
            else './data/miniImageNet/val1000Episode_5_way_5_shot.json'
    else:
        raise ValueError('Do not support other datasets yet.')
    return trainTransform, valTransform, inputW, inputH, trainDir, valDir, testDir, episodeJson, nbCls
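

# Hypothetical usage sketch (not in the original file): build the SIB-style
# miniImageNet setting for 1-shot episodes and inspect the input geometry.
if __name__ == '__main__':
    (trainTransform, valTransform, inputW, inputH,
     trainDir, valDir, testDir, episodeJson, nbCls) = dataset_setting('MiniImageNet', nSupport=1)
    print(inputW, inputH, nbCls)  # 80 80 64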
# Copyright (c) 2020 Yaoyao Liu. All Rights Reserved.
# Some files of this repository are modified from https://github.com/hushell/sib_meta_learn
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================
import os.path as osp
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
from os.path import expanduser


class DatasetLoader(Dataset):
    """The class to load the FC100 dataset."""
    def __init__(self, setname, args=None):
        data_base_dir = 'data/FC100'
        if not os.path.exists(data_base_dir):
            print('Download FC100 from Google Drive.')
            os.makedirs(data_base_dir)
            os.system('sh scripts/download_fc100.sh')
        TRAIN_PATH = 'data/FC100/train'
        VAL_PATH = 'data/FC100/val'
        TEST_PATH = 'data/FC100/test'
        # Set the path according to train, val and test
        if setname == 'train':
            THE_PATH = TRAIN_PATH
            label_list = os.listdir(THE_PATH)
        elif setname == 'test':
            THE_PATH = TEST_PATH
            label_list = os.listdir(THE_PATH)
        elif setname == 'val':
            THE_PATH = VAL_PATH
            label_list = os.listdir(THE_PATH)
        else:
            raise ValueError('Wrong setname.')
        # Generate empty lists for data and label
        data = []
        label = []
        # Get the folders' names
        folders = [osp.join(THE_PATH, label) for label in label_list if os.path.isdir(osp.join(THE_PATH, label))]
        # Get the images' paths and labels
        for idx, this_folder in enumerate(folders):
            this_folder_images = os.listdir(this_folder)
            for image_path in this_folder_images:
                data.append(osp.join(this_folder, image_path))
                label.append(idx)
        # Set data, label and class number to be accessible from outside
        self.data = data
        self.label = label
        self.num_class = len(set(label))
        # Transformation
        if setname == 'train':
            image_size = 84
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
                                     np.array([x / 255.0 for x in [63.0, 62.1, 66.7]]))])
        else:
            image_size = 84
            self.transform = transforms.Compose([
                transforms.Resize([92, 92]),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
                                     np.array([x / 255.0 for x in [63.0, 62.1, 66.7]]))])

    def __len__(self):
        return len(self.data)
    def __getitem__(self, i):
        path, label = self.data[i], self.label[i]
        image = self.transform(Image.open(path).convert('RGB'))
        return image, label
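
# Minimal usage sketch (an assumption for illustration, not part of the original file);
# the loader downloads FC100 on first use if data/FC100 is missing:
#     dataset = DatasetLoader('train')
#     image, label = dataset[0]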
# Copyright (c) 2020 Yaoyao Liu. All Rights Reserved.
# Some files of this repository are modified from https://github.com/hushell/sib_meta_learn
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================
import os.path as osp
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
from os.path import expanduser

class MiniImageNet(Dataset):
    """miniImageNet loader based on the split csv files (downloads the data if missing)."""
    def __init__(self, setname, args):
        data_base_dir = 'data/miniimagenet'
        if not os.path.exists(data_base_dir):
            print('Download miniImageNet from Google Drive.')
            os.makedirs(data_base_dir)
            os.system('sh scripts/download_miniimagenet.sh')
        IMAGE_PATH = 'data/miniimagenet/images'
        SPLIT_PATH = 'data/miniimagenet/split'
        csv_path = osp.join(SPLIT_PATH, setname + '.csv')
        lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:]
        data = []
        label = []
        lb = -1
        self.wnids = []
        for l in lines:
            name, wnid = l.split(',')
            path = osp.join(IMAGE_PATH, name)
            if wnid not in self.wnids:
                self.wnids.append(wnid)
                lb += 1
            data.append(path)
            label.append(lb)
        self.data = data
        self.label = label
        self.num_class = len(set(label))


THIS_PATH = osp.dirname(__file__)
ROOT_PATH = osp.abspath(osp.join(THIS_PATH, '..'))
TRAIN_PATH = osp.join(ROOT_PATH, 'data/miniImageNet/train')
VAL_PATH = osp.join(ROOT_PATH, 'data/miniImageNet/val')
TEST_PATH = osp.join(ROOT_PATH, 'data/miniImageNet/test')


class MiniImageNet2(Dataset):
    """miniImageNet loader based on per-class image folders."""
    def __init__(self, setname, args):
        if setname == 'train':
            THE_PATH = TRAIN_PATH
            label_list = os.listdir(TRAIN_PATH)
        elif setname == 'test':
            THE_PATH = TEST_PATH
            label_list = os.listdir(THE_PATH)
        elif setname == 'val':
            THE_PATH = VAL_PATH
            label_list = os.listdir(THE_PATH)
        else:
            raise ValueError('Wrong setname.')
        data = []
        label = []
        folders = [osp.join(THE_PATH, label) for label in label_list if os.path.isdir(osp.join(THE_PATH, label))]
        for idx in range(len(folders)):
            this_folder = folders[idx]
            this_folder_images = os.listdir(this_folder)
            for image_path in this_folder_images:
                data.append(osp.join(this_folder, image_path))
                label.append(idx)
        self.data = data
        self.label = label
        self.num_class = len(set(label))
        if setname == 'val' or setname == 'test':
            image_size = 84
            if args.pre_augmentation:
                image_size