“Generative Pretraining from Pixels” Summarized
https://proceedings.icml.cc/static/paper_files/icml/2020/6022-Paper.pdf (2020/06)
1. History of Unsupervised Pre-training in Computer Vision
- Mid~Late 2000’s: Approaches such as Deep Belief Networks and Denoising Autoencoders were used to learn features that benefit subsequent supervised modeling.
- Early~Mid 2010’s: Piecewise linear activation functions, improved initializations, and normalization strategies removed the need for pre-training.
- Mid 2010’s: Generative models such as GANs and VAEs have been developed and evaluated for their representation learning capabilities.
- Mid~Late 2010’s: In NLP, unsupervised pre-training such as BERT and GPT pushed SOTA forward.
- Late 2010’s: New training objectives, more recent architectures, and increased model capacity have allowed self-supervised pre-training to achieve SOTA in low-data settings and some transfer learning settings. In particular, contrastive losses that compare various views and transformations of input images have recently driven significant progress in self-supervised learning.
Many self-supervised approaches in computer vision focused on designing auxiliary objectives that support the learning of useful representations without attempting to directly model the input data. In contrast, the authors studied generative pre-training on images with a transformer decoder. They call the model Image-GPT (iGPT).
2. Pre-training Objectives
Two objectives were studied: auto-regressive and BERT.
1. Auto-regressive objective
We model the density $p(x)$ auto-regressively, and minimize the negative log-likelihood of the data.
\[p(x) = \prod_{i=1}^n p(x_i|x_1,...,x_{i-1}, \theta)\]
\[L_{AR} = \mathbb{E}_{x\sim X}[-\log p(x)]\]
2. BERT objective
We sample a sub-sequence $M$ such that each index $i$ independently has probability of 0.15 of appearing in $M$. Then we train the model by minimizing the negative log-likelihood of the “masked” elements $x_M$ conditioned on the “unmasked” ones.
\[L_{BERT} = \mathbb{E}_{x\sim X} \mathbb{E}_M \sum_{i\in M} [-\log p(x_i|x_{[1,n]\setminus M})]\]
3. Architecture
The authors use the GPT-2 formulation of the transformer decoder block.
When training with the AR objective, we apply the standard upper triangular mask to the $n \times n$ matrix of attention logits.
When training with the BERT objective, no attention logit masking is required: after applying content embeddings to the input sequence, we zero out the positions in $M$.
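The two masking schemes above can be sketched in NumPy as follows. This is a minimal illustration, not the authors' code: the function names `causal_attention` and `bert_mask_inputs` and the toy sizes are assumptions made for this example.

```python
import numpy as np

def causal_attention(logits):
    """Softmax over an (n, n) matrix of attention logits with the AR mask applied."""
    n = logits.shape[-1]
    # True strictly above the diagonal: position i must not attend to any j > i.
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    masked = np.where(future, -np.inf, logits)
    masked = masked - masked.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(masked)
    return weights / weights.sum(axis=-1, keepdims=True)

def bert_mask_inputs(embeddings, rng, mask_prob=0.15):
    """Sample M (each index independently with prob. mask_prob) and zero those rows."""
    n = embeddings.shape[0]
    in_M = rng.random(n) < mask_prob
    masked = np.where(in_M[:, None], 0.0, embeddings)
    return masked, in_M

rng = np.random.default_rng(0)
n, d = 6, 4  # toy sequence length and embedding size (hypothetical)
attn = causal_attention(rng.normal(size=(n, n)))
emb = rng.normal(size=(n, d))
masked_emb, in_M = bert_mask_inputs(emb, rng)
```

Each row of `attn` sums to 1 and places zero weight on later positions, while `masked_emb` keeps the unmasked rows of `emb` unchanged.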
Note that the BERT model has to learn all spatial relationships between positions at train time, while the AR model partially exploits spatial relationships through the ordering of the conditionals. Either way, the transformer decoder architecture contrasts strongly with convolutional neural networks, which incorporate the inductive bias that features should arise from spatially proximate elements.
The authors trained 4 models with different sizes.
Model | L (# layers) | d (embedding size) | # parameters |
---|---|---|---|
iGPT-XL | 60 | 3072 | 6.8B |
iGPT-L | 48 | 1536 | 1.4B |
iGPT-M | 36 | 1024 | 455M |
iGPT-S | 24 | 512 | 76M |
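As a rough sanity check on the table, a common back-of-the-envelope estimate for GPT-style decoders is about $12 \cdot L \cdot d^2$ parameters. The sketch below assumes a standard GPT-2 block with a 4x-wide MLP and ignores embeddings, biases, and layer norms; these are assumptions of this example, not figures taken from the paper.

```python
# (name, layers L, embedding size d, parameter count reported in the table)
configs = [
    ("iGPT-XL", 60, 3072, 6.8e9),
    ("iGPT-L", 48, 1536, 1.4e9),
    ("iGPT-M", 36, 1024, 455e6),
    ("iGPT-S", 24, 512, 76e6),
]

def approx_block_params(num_layers, d_model):
    """Rough count: 4*d^2 for attention projections plus 8*d^2 for a 4x-wide MLP."""
    return 12 * num_layers * d_model ** 2

estimates = {name: approx_block_params(L, d) for name, L, d, _ in configs}
for name, L, d, reported in configs:
    est = estimates[name]
    print(f"{name}: estimated {est / 1e6:,.0f}M, reported {reported / 1e6:,.0f}M")
```

The estimates land within a few percent of the reported counts, which is consistent with the models being plain stacks of decoder blocks.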
As models grow larger, validation loss on the auto-regressive objective decreases and linear probe accuracy (explained in the next section) increases. This suggests that bigger generative models learn better representations.
4. Transfer Methods
The authors used two approaches to measure representation quality: fine-tuning and linear probing.
1. Fine-tuning
We average pool across the sequence dimension to extract a $d$-dimensional vector of features per example.
While fine-tuning on the cross-entropy loss alone yields reasonable downstream performance, the authors found empirically that adding the generative loss ($L_{AR}$ or $L_{BERT}$) and training on the joint objective works even better.
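A minimal NumPy sketch of the pooling step and the joint objective, under stated assumptions: the generative term is the AR loss, the two terms are added without weighting, and the classification head is a single hypothetical linear layer `(W, b)`. The random arrays stand in for real model outputs.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def ar_loss(pixel_logits, tokens):
    """-log p(x): summed negative log-likelihood of each token.

    pixel_logits[i] is assumed to be the prediction for tokens[i] given tokens[:i].
    """
    logp = log_softmax(pixel_logits)
    return -logp[np.arange(len(tokens)), tokens].sum()

def joint_loss(features, pixel_logits, tokens, W, b, label):
    pooled = features.mean(axis=0)          # average pool over the sequence -> (d,)
    class_logits = pooled @ W + b           # hypothetical linear classification head
    l_clf = -log_softmax(class_logits)[label]
    l_gen = ar_loss(pixel_logits, tokens)
    return l_gen + l_clf, pooled            # unweighted sum (an assumption)

rng = np.random.default_rng(0)
n, d, vocab, num_classes = 16, 8, 512, 10   # toy sizes, not the paper's
features = rng.normal(size=(n, d))          # stand-in for final-layer activations
pixel_logits = rng.normal(size=(n, vocab))  # stand-in for next-pixel logits
tokens = rng.integers(0, vocab, size=n)
W, b = rng.normal(size=(d, num_classes)) * 0.1, np.zeros(num_classes)
loss, pooled = joint_loss(features, pixel_logits, tokens, W, b, label=3)
```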
2. Linear Probing
We use the pre-trained model as a frozen feature extractor. To extract features, average pooling is applied to the output of a chosen layer. A linear classifier is then trained on the extracted features and the targets of the downstream task.
The authors found that the best features often lie in the middle of the network.
This behavior suggests that these generative models operate in two phases.
- Each position gathers information from its surrounding context in order to build a more global image representation.
- This contextualized input is used to solve the conditional next pixel prediction task.
Linear probing captures the intuition that good features should linearly separate the classes of transfer tasks. Furthermore, linear probes help disentangle feature quality from model architecture.
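The linear probing procedure can be sketched as below. This is a simplified illustration on synthetic data: the activations are random stand-ins for a frozen layer's output, and the classifier is fit with plain full-batch gradient descent rather than the SGD-with-momentum setup the authors used. All function names are hypothetical.

```python
import numpy as np

def pool_features(activations):
    """Average pool (batch, seq, d) activations over the sequence axis."""
    return activations.mean(axis=1)

def train_linear_probe(feats, labels, num_classes, lr=0.5, steps=200):
    """Fit a softmax classifier on frozen features with plain gradient descent."""
    n, d = feats.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(steps):
        logits = feats @ W + b
        logits -= logits.max(axis=1, keepdims=True)
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n          # gradient of mean cross-entropy w.r.t. logits
        W -= lr * feats.T @ grad
        b -= lr * grad.sum(axis=0)
    return W, b

def probe_accuracy(feats, labels, W, b):
    return float(((feats @ W + b).argmax(axis=1) == labels).mean())

rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=200)
# Toy "layer activations": noise plus a class-dependent shift in one feature.
acts = rng.normal(size=(200, 16, 8))
acts[:, :, 0] += 2.0 * labels[:, None] - 1.0
feats = pool_features(acts)
W, b = train_linear_probe(feats[:150], labels[:150], num_classes=2)
acc = probe_accuracy(feats[150:], labels[150:], W, b)
```

Repeating this for the activations of each layer and comparing `acc` is how one would locate the layer with the most linearly separable features.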
5. Dataset
- ImageNet as a proxy for a large unlabeled corpus
- ImageNet, CIFAR-10, CIFAR-100, STL-10 as proxies for downstream tasks
Augmentations
- ImageNet: When training, randomly resize an image such that the shorter side length is in the range [256, 384] and then take a random 224 x 224 crop. When evaluating, resize the image such that the shorter side length is 224, and use the single 224 x 224 center crop.
- CIFAR-10, CIFAR-100: When fine-tuning, 4 pixels are reflection padded on each side, and a 32 x 32 crop is randomly sampled from the padded image or its horizontal flip.
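The CIFAR augmentation described above can be sketched with NumPy. The function name `augment_cifar` and the flip probability of 0.5 are assumptions of this example; `mode="reflect"` mirrors the image without repeating the edge pixel.

```python
import numpy as np

def augment_cifar(image, rng, pad=4, crop=32):
    """Reflection-pad, optionally flip horizontally, then take a random crop."""
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="reflect")
    if rng.random() < 0.5:                    # flip probability is an assumption
        padded = padded[:, ::-1, :]
    top = rng.integers(0, 2 * pad + 1)        # valid offsets are 0..2*pad inclusive
    left = rng.integers(0, 2 * pad + 1)
    return padded[top:top + crop, left:left + crop, :]

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)  # stand-in image
out = augment_cifar(image, rng)
```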
Once optimal hyperparameters were found on the validation set, the authors folded the validation set back into the training set, retrained the model, and reported numbers on the respective test set.
6. Context Reduction
The memory requirements of the transformer decoder scale quadratically with context length when using dense attention, so the input resolution (IR) has to be lowered. Since human performance on image classification begins to drop rapidly below a size of $32^2 \times 3$, the authors used an IR of either $32^2 \times 3$, $48^2 \times 3$, or $64^2 \times 3$.
After resizing the input image to one of the above resolutions, the authors went further and reduced the number of channels from 3 to 1. They created a 9-bit color palette by clustering (R, G, B) pixel values using $k$-means with $k=512$. The resulting context length is $32^2$, $48^2$, or $64^2$, which we call the model resolution (MR).
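A minimal sketch of the palette step, assuming plain Lloyd's $k$-means on randomly generated pixels and raster-order flattening. The functions `fit_palette` and `quantize` are hypothetical names, and a real palette would be fit on pixels from the training images.

```python
import numpy as np

def fit_palette(pixels, k, rng, iters=5):
    """Plain Lloyd's k-means over (R, G, B) rows; returns k cluster centers."""
    pixels = pixels.astype(float)
    centers = pixels[rng.choice(len(pixels), size=k, replace=False)].copy()
    for _ in range(iters):
        # Squared distance from every pixel to every center: (num_pixels, k).
        d2 = ((pixels[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        assign = d2.argmin(axis=1)
        for j in range(k):
            members = pixels[assign == j]
            if len(members):                  # empty clusters keep their old center
                centers[j] = members.mean(axis=0)
    return centers

def quantize(image, centers):
    """Map an (H, W, 3) image to a length H*W sequence of palette indices."""
    flat = image.reshape(-1, 3).astype(float)
    d2 = ((flat[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    return d2.argmin(axis=1)

rng = np.random.default_rng(0)
sample_pixels = rng.integers(0, 256, size=(2048, 3))   # stand-in for training pixels
centers = fit_palette(sample_pixels, k=512, rng=rng)
image = rng.integers(0, 256, size=(32, 32, 3))
tokens = quantize(image, centers)

# Context length and dense-attention cost with and without the palette.
n_rgb, n_palette = 32 * 32 * 3, 32 * 32
attention_ratio = n_rgb ** 2 // n_palette ** 2
```

With 512 palette entries each pixel is one 9-bit token, so the sequence is 3 times shorter and the $n \times n$ attention matrix has 9 times fewer entries.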
7. Training Details
Pre-training
- batch size: 64(iGPT-XL), 128(others)
- n iterations: 2M(iGPT-XL), 1M(others)
- optimizer: Adam with $\beta_1=0.9$ and $\beta_2=0.95$
- learning rate: tried 0.01, 0.003, 0.001, … , stopping once the final validation loss starts increasing
- learning rate scheduler: warm up for 1 epoch, decays to 0 following a cosine schedule
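The warm-up and cosine decay schedule can be sketched as follows. The linear shape of the warm-up and the concrete values in the usage lines (`peak_lr=0.003`, `warmup_steps=10_000`) are assumptions for illustration; the notes above only state that warm-up lasts one epoch and that the rate then decays to 0.

```python
import math

def learning_rate(step, peak_lr, warmup_steps, total_steps):
    """Linear warm-up to peak_lr, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * min(1.0, progress)))

# Hypothetical settings: 1M iterations, warm-up of 10,000 steps.
peak_lr, warmup_steps, total_steps = 0.003, 10_000, 1_000_000
schedule = [learning_rate(s, peak_lr, warmup_steps, total_steps)
            for s in (0, warmup_steps, total_steps // 2, total_steps)]
```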
Fine-tuning
- learning rate scheduler: no cosine schedule
Linear probe
- optimizer: SGD with momentum 0.9 and a high learning rate (30, 10, 3, … in the manner described above)
8. Experiment Results
1. Linear Probe
On CIFAR and STL-10
iGPT achieved SOTA across the entire spectrum of pre-training approaches.
On ImageNet
iGPT, which uses 32x32 resolution, falls behind other methods that use the original high resolution of ImageNet images. Note that this is not due to architectural or methodological inferiority.
2. Fine-tuning
In linear probing, attaching the classification head to a middle layer yielded the best representations. In fine-tuning, however, attaching the head at the final layer performs best.
On CIFAR, iGPT outperforms GPipe, the best model which pre-trains using ImageNet labels.
On ImageNet, iGPT achieves 66.3% accuracy after fine-tuning at MR $32^2$. If we train the network without pre-training, the accuracy drops by 19.4%. iGPT still slightly underperforms Isometric Neural Nets.
3. Low-Data CIFAR-10 Classification
Evaluations of unsupervised representations often reuse supervised learning datasets which have thousands to millions of labeled examples. However, a representation which has robustly encoded a semantic concept should be exceedingly data efficient.
Large models are difficult to fine-tune in the ultra-low data regime, since they quickly memorize the training set. For the low-data evaluation of iGPT, the authors therefore used only linear probing.
4. BERT Objective
Linear probe accuracy at every layer is worse than that of the auto-regressive model. However, during fine-tuning, BERT makes up much of this gap.
Because inputs to the BERT model are masked at training time, we must also mask them at evaluation time to keep inputs in-distribution. So we sample 5 independent masks for each input and take the modal prediction, which is a form of ensembling. This technique improved ImageNet accuracy.
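The mask ensembling step can be sketched as below. `predict_fn` is a hypothetical callable that stands in for the masked model plus classifier, and reusing the training mask probability of 0.15 at evaluation time is an assumption of this example.

```python
from collections import Counter
import numpy as np

def ensemble_predict(predict_fn, tokens, rng, num_masks=5, mask_prob=0.15):
    """Classify one input under several random masks and return the modal class."""
    votes = []
    for _ in range(num_masks):
        in_M = rng.random(len(tokens)) < mask_prob   # independent mask per pass
        votes.append(predict_fn(tokens, in_M))
    # Ties are broken by first occurrence, which is an arbitrary choice.
    return Counter(votes).most_common(1)[0][0]

# Toy stand-in classifier: replays a fixed sequence of per-mask predictions.
scripted = iter([1, 2, 2, 3, 2])
def toy_predict(tokens, in_M):
    return next(scripted)

rng = np.random.default_rng(0)
tokens = rng.integers(0, 512, size=1024)
prediction = ensemble_predict(toy_predict, tokens, rng)
```

Here three of the five scripted votes are class 2, so `prediction` is 2.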