Code for our ICLR 2021 paper Wandering Within a World: Online Contextualized Few-Shot Learning [arxiv]
Although our code base is MIT licensed, the RoamingRooms dataset is not since it is derived from the Matterport3D dataset.
To download the RoamingRooms dataset, you first need to sign the agreement for a non-commercial license here.
Then, you need to submit a request here. We will manually approve your request afterwards.
For inquiries, please email: roaming-rooms@googlegroups.com
The whole dataset is around 60 GB. It has 1.2M video frames with 7k unique object instance classes. Please refer to our paper for more statistics of the dataset.
Our code is tested on Ubuntu 18.04 with GPU capability. We provide docker files for reproducible environments. We recommend at least 20GB CPU memory and 11GB GPU memory. 2-4 GPUs are required for multi-GPU experiments. Our code is based on TensorFlow 2.
Install protoc from here.
Run make to build proto buffer configuration files.
Install docker and nvidia-docker.
Build the docker container using ./build_docker.sh.
Modify the environment paths. You need to change DATA_DIR and OURPUT_DIR in setup_environ.sh. DATA_DIR is the main folder where datasets are placed and OUTPUT_DIR is the main folder where training models are saved.
Install protoc from here.
Run make to build proto buffer configuration files.
Modify the environment paths. You need to change DATA_DIR and OURPUT_DIR in setup_environ.sh. DATA_DIR is the main folder where datasets are placed and OUTPUT_DIR is the main folder where training models are saved.
Create a conda environment:
conda create -n oc-fewshot python=3.6
conda activate oc-fewshot
conda install pip
Install CUDA 10.1
Install OpenMPI 4.0.0
Install NCCL 2.6.4 for CUDA 10.1
Modify installation paths in install.sh
Run install.sh
To set up the Omniglot dataset, run script/download_omniglot.sh. This script will download the Omniglot dataset to DATA_DIR.
To set up the Uppsala texture dataset (for spatiotemporal cue experiments), run script/download_uppsala.sh. This script will download the Uppsala texture dataset to DATA_DIR.
To run training on your own, use the following command.
./run_docker.sh {GPU_ID} python -m fewshot.experiments.oc_fewshot \
--config {MODEL_CONFIG_PROTOTXT} \
--data {EPISODE_CONFIG_PROTOTXT} \
--env configs/environ/roaming-omniglot-docker.prototxt \
--tag {TAG} \
[--eval]
MODEL_CONFIG_PROTOTXT can be found in configs/models.EPISODE_CONIFG_PROTOTXT can be found in configs/episodes.TAG is the name of the saved checkpoint folder.--eval flag to evaluate.For example, to train CPM on the semisupervised benchmark:
./run_docker.sh 0 python -m fewshot.experiments.oc_fewshot \
--config configs/models/roaming-omniglot/cpm.prototxt \
--data configs/episodes/roaming-omniglot/roaming-omniglot-150-ssl.prototxt \
--env configs/environ/roaming-omniglot-docker.prototxt \
--tag roaming-omniglot-ssl-cpm
All of our code is tested using GTX 1080 Ti with 11GB GPU memory. Note that the above command uses a single GPU. Our original experiments in the paper is performed using two GPUs, with twice the batch size and doubled learning rate. To run that setting, use the following command:
./run_docker_hvd_01.sh python -m fewshot.experiments.oc_fewshot_hvd \
--config {MODEL_CONFIG_PROTOTXT} \
--data {EPISODE_CONFIG_PROTOTXT} \
--env configs/environ/roaming-omniglot-docker.prototxt \
--tag {TAG}
Below we include command to run experiments on RoamingRooms. Our original experiments in the paper is performed using four GPUs, with batch size to be 8. To run that setting, use the following command:
./run_docker_hvd_0123.sh python -m fewshot.experiments.oc_fewshot_hvd \
--config {MODEL_CONFIG_PROTOTXT} \
--data {EPISODE_CONFIG_PROTOTXT} \
--env configs/environ/roaming-rooms-docker.prototxt \
--tag {TAG}
When evaluate, use --eval --usebest to pick the checkpoint with the highest validation performance.
Table 1: RoamingOmniglot Results (Supervised)
| Method | AP | 1-shot Acc. | 3-shot Acc. | Checkpoint |
|---|---|---|---|---|
| LSTM | 64.34 | 61.00 ± 0.22 | 81.85 ± 0.21 | link |
| DNC | 81.30 | 78.87 ± 0.19 | 91.01 ± 0.15 | link |
| OML-U | 77.38 | 70.98 ± 0.21 | 89.13 ± 0.16 | link |
| OML-U++ | 86.85 | 88.43 ± 0.14 | 92.07 ± 0.14 | link |
| Online MatchingNet | 88.69 | 84.82 ± 0.15 | 95.55 ± 0.11 | link |
| Online IMP | 90.15 | 85.74 ± 0.15 | 96.66 ± 0.09 | link |
| Online ProtoNet | 90.49 | 85.68 ± 0.15 | 96.95 ± 0.09 | link |
| CPM (Ours) | 94.17 | 91.99 ± 0.11 | 97.74 ± 0.08 | link |
Table 2: RoamingOmniglot Results (Semi-supervised)
| Method | AP | 1-shot Acc. | 3-shot Acc. | Checkpoint |
|---|---|---|---|---|
| LSTM | 54.34 | 68.30 ± 0.20 | 76.38 ± 0.49 | link |
| DNC | 81.37 | 88.56 ± 0.12 | 93.81 ± 0.26 | link |
| OML-U | 66.70 | 74.65 ± 0.19 | 90.81 ± 0.34 | link |
| OML-U++ | 81.39 | 89.07 ± 0.19 | 89.40 ± 0.18 | link |
| Online MatchingNet | 84.39 | 88.77 ± 0.13 | 97.28 ± 0.17 | link |
| Online IMP | 81.62 | 88.68 ± 0.13 | 97.09 ± 0.19 | link |
| Online ProtoNet | 84.61 | 88.71 ± 0.13 | 97.61 ± 0.17 | link |
| CPM (Ours) | 90.42 | 93.18 ± 0.16 | 97.89 ± 0.15 | link |
Table 3: RoamingRooms Results (Supervised)
| Method | AP | 1-shot Acc. | 3-shot Acc. | Checkpoint |
|---|---|---|---|---|
| LSTM | 45.67 | 59.90 ± 0.40 | 61.85 ± 0.45 | link |
| DNC | 80.86 | 82.15 ± 0.32 | 87.30 ± 0.30 | link |
| OML-U | 76.27 | 73.91 ± 0.37 | 83.99 ± 0.33 | link |
| OML-U++ | 88.03 | 88.32 ± 0.27 | 89.61 ± 0.29 | link |
| Online MatchingNet | 85.91 | 82.82 ± 0.32 | 89.99 ± 0.26 | link |
| Online IMP | 87.33 | 85.28 ± 0.31 | 90.83 ± 0.25 | link |
| Online ProtoNet | 86.01 | 84.89 ± 0.31 | 89.58 ± 0.28 | link |
| CPM (Ours) | 89.14 | 88.39 ± 0.27 | 91.31 ± 0.26 | link |
Table 4: RoamingRooms Results (Semi-supervised)
| Method | AP | 1-shot Acc. | 3-shot Acc. | Checkpoint |
|---|---|---|---|---|
| LSTM | 33.32 | 52.71 ± 0.38 | 55.83 ± 0.76 | link |
| DNC | 73.49 | 80.27 ± 0.33 | 87.87 ± 0.49 | link |
| OML-U | 63.40 | 70.67 ± 0.38 | 85.25 ± 0.56 | link |
| OML-U++ | 81.90 | 84.79 ± 0.31 | 89.80 ± 0.47 | link |
| Online MatchingNet | 78.99 | 80.08 ± 0.34 | 92.43 ± 0.41 | link |
| Online IMP | 75.36 | 84.57 ± 0.31 | 91.17 ± 0.43 | link |
| Online ProtoNet | 76.36 | 80.67 ± 0.34 | 88.83 ± 0.49 | link |
| CPM (Ours) | 84.12 | 86.17 ± 0.30 | 91.16 ± 0.44 | link |
Table 5: RoamingImageNet Results (Supervised)
| Method | AP | 1-shot Acc. | 3-shot Acc. | Checkpoint |
|---|---|---|---|---|
| LSTM | 7.73 | 11.60 ± 0.12 | 43.93 ± 0.27 | link |
LSTM* |
22.54 | 28.14 ± 0.20 | 52.07 ± 0.27 | link |
| DNC | 7.20 | 10.55 ± 0.11 | 42.22 ± 0.27 | link |
DNC* |
26.80 | 33.45 ± 0.19 | 55.78 ± 0.27 | link |
| OML-U | 21.89 | 15.06 ± 0.14 | 52.52 ± 0.27 | link |
| OML-U Cos | 10.87 | 24.45 ± 0.18 | 30.89 ± 0.24 | link |
| Online MatchingNet | 13.05 | 20.61 ± 0.15 | 38.73 ± 0.24 | link |
| Online IMP | 14.25 | 22.92 ± 0.16 | 41.01 ± 0.25 | link |
| Online ProtoNet | 15.51 | 22.95 ± 0.17 | 44.98 ± 0.25 | link |
Online ProtoNet* |
23.10 | 32.82 ± 0.19 | 49.98 ± 0.25 | link |
| CPM (Ours) | 34.43 | 40.40 ± 0.21 | 60.29 ± 0.26 | link |
* denotes using pretrained CNN.
Table 6: RoamingImageNet Results (Semi-supervised)
| Method | AP | 1-shot Acc. | 3-shot Acc. | Checkpoint |
|---|---|---|---|---|
| LSTM | 4.03 | 22.53 ± 0.18 | 41.34 ± 0.55 | link |
LSTM* |
13.50 | 30.02 ± 0.20 | 46.95 ± 0.56 | link |
| DNC | 3.66 | 22.37 ± 0.18 | 37.83 ± 0.54 | link |
DNC* |
16.50 | 39.53 ± 0.19 | 54.10 ± 0.54 | link |
| OML-U | 10.16 | 22.74 ± 0.17 | 55.81 ± 0.55 | link |
| OML-U Cos | 5.65 | 23.37 ± 0.16 | 32.79 ± 0.50 | link |
| Online MatchingNet | 9.32 | 25.96 ± 0.16 | 55.32 ± 0.51 | link |
| Online IMP | 4.55 | 20.70 ± 0.15 | 51.23 ± 0.53 | link |
| Online ProtoNet | 7.10 | 26.87 ± 0.16 | 42.40 ± 0.52 | link |
Online ProtoNet* |
15.76 | 36.69 ± 0.18 | 55.47 ± 0.53 | link |
| CPM (Ours) | 24.75 | 44.58 ± 0.21 | 58.72 ± 0.53 | link |
* denotes using pretrained CNN.
If you use our code, please consider cite the following:
@inproceedings{ren21ocfewshot,
author = {Mengye Ren and
Michael L. Iuzzolino and
Michael C. Mozer and
Richard S. Zemel},
title = {Wandering Within a World: Online Contextualized Few-Shot Learning},
booktitle = {9th International Conference on Learning Representations, {ICLR}},
year = {2021}
}