Google Research Proposes MaskGIT: A New Deep Learning Technique Based on Bi-Directional Generative Transformers For High-Quality and Fast Image Synthesis
Generative Adversarial Networks (GANs), with their capacity to produce high-quality images, have been the leading technology in image generation for the past few years. Nevertheless, their minimax training objective introduces well-known limitations, such as training instability and mode collapse (i.e., when the generator produces samples from only a small subset of the data distribution).
Recently, generative Transformer models have begun to match, or even surpass, the performance of GANs. The idea is simple: learn a function that encodes the input image into a sequence of quantized tokens, then train an autoregressive Transformer on a sequence prediction task (i.e., predict the next image token given all the previous ones). This training is based on maximum likelihood and is therefore not affected by the same issues as GANs.
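As a worked formula, the autoregressive objective factorizes the likelihood of the token sequence left to right. In the notation below, which is ours rather than the article's, y = (y_1, ..., y_N) denotes the N visual tokens of one image. Training maximizes this log-likelihood, or equivalently minimizes its negative:

```latex
\log p(\mathbf{y}) = \sum_{i=1}^{N} \log p\left(y_i \mid y_1, \dots, y_{i-1}\right)
```

Each term conditions only on the tokens that come earlier in the fixed order. This is the one-directional constraint that MaskGIT removes.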
Nevertheless, the main problem of state-of-the-art generative Transformers such as VQ-GAN is that images are treated as 1-D sequences in a counterintuitive way, following a left-to-right, line-by-line (raster-scan) order. As the paper puts it: “Imagine how an artwork is created. A painter starts with a sketch and then progressively refines it by filling or tweaking the details, which is in clear contrast to the line-by-line printing used in previous work”.
For this reason, the Google Research team introduced the Masked Generative Image Transformer, or MaskGIT, a new bidirectional Transformer for image synthesis. The first phase of MaskGIT is the same as in VQ-GAN, while the second phase is inspired by the bidirectional masking mechanism of BERT (Bidirectional Encoder Representations from Transformers). More precisely, images are not generated token by token as in autoregressive Transformers. Instead, a masking mechanism lets the model predict all tokens in parallel, keep only the most confident predictions, and repeat the process over a small number of iterations until every token is generated. This greatly reduces generation time. The comparison between these two approaches is shown in the image below.
The first phase of training: Tokenization
As mentioned above, the authors did not focus on the tokenization stage of the approach but relied on VQ-GAN. In this first stage of training, an encoder converts an image into a grid of latent vectors, and a learned codebook turns them into discrete, quantized latents. Each latent vector is replaced by its closest entry (codeword) in the codebook, and the index of that entry becomes a visual token. The quantized representation (the “visual tokens” in the image below) is passed to a decoder, which tries to reconstruct the original image. The encoder-decoder structure is trained with a perceptual loss (i.e., a squared-error distance computed on the internal feature representations of a pretrained network) and an adversarial loss. At the same time, the codebook is learned with two alignment losses, whose goal is to pull the codebook vectors toward the encoder outputs and vice versa.
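The nearest-codeword lookup described above can be sketched in a few lines of plain Python. This is a minimal illustration under simplifying assumptions: the latents are a flat list of small vectors rather than a 2-D feature map, and the codebook values are made up. A real VQ-GAN works on batched tensors and learns the codebook with the losses mentioned above, which are omitted here.

```python
def quantize(latents, codebook):
    """Map each latent vector to the index of its nearest codeword (squared L2)."""
    tokens = []
    for z in latents:
        # squared Euclidean distance from z to every codebook entry
        dists = [sum((a - b) ** 2 for a, b in zip(z, e)) for e in codebook]
        # the index of the closest entry is the visual token
        tokens.append(min(range(len(codebook)), key=dists.__getitem__))
    return tokens


def dequantize(tokens, codebook):
    """Look up the codeword for each visual token (what the decoder receives)."""
    return [codebook[k] for k in tokens]


# Toy example with hypothetical values: 2-D latents and a 3-entry codebook
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
latents = [[0.1, -0.2], [0.9, 1.2], [-0.8, 1.7], [0.6, 0.5]]
tokens = quantize(latents, codebook)  # tokens == [0, 1, 2, 1]
```

The second phase only ever sees the integer tokens. This is why the Transformer can be trained as a classifier over codebook indices.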
The second phase of training: Masked Visual Token Modeling (MVTM)
In the second phase, the encoder and the decoder are frozen, and a bidirectional Transformer is trained. An image is passed through the encoder, and its visual tokens are obtained through the codebook. This array of visual tokens is then masked (figure above). A mask scheduling function determines how many tokens are replaced with a special [MASK] token, and the positions to mask are chosen at random. The aim is to learn to predict the original visual tokens from the masked sequence. The authors borrowed this idea from BERT, where a masked word in a sentence is predicted by “reading” the sentence both left to right and right to left, i.e., by attending to the context on both sides. The model is trained to minimize the negative log-likelihood of the ground-truth tokens at the masked positions.
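Below is a minimal sketch of how one training example could be corrupted and scored, assuming a cosine schedule, a mask ratio r drawn uniformly from [0, 1), and a hypothetical `MASK_ID` of -1 for the [MASK] token. The Transformer itself is not shown. `masked_nll` simply takes whatever per-position probabilities a model would output.

```python
import math
import random

MASK_ID = -1  # hypothetical id standing in for the [MASK] token


def cosine_schedule(r):
    # fraction of tokens to mask for a ratio r in [0, 1)
    return math.cos(math.pi * r / 2)


def mask_tokens(tokens, rng):
    """MVTM-style masking: the schedule decides how many tokens, randomness decides which."""
    n = len(tokens)
    r = rng.random()  # r ~ U[0, 1)
    num_masked = math.ceil(cosine_schedule(r) * n)  # always between 1 and n
    masked_positions = set(rng.sample(range(n), num_masked))
    corrupted = [MASK_ID if i in masked_positions else t for i, t in enumerate(tokens)]
    return corrupted, masked_positions


def masked_nll(probs, targets, masked_positions):
    """Negative log-likelihood of the ground-truth tokens, summed over masked positions only."""
    return -sum(math.log(probs[i][targets[i]]) for i in masked_positions)


rng = random.Random(0)
tokens = [3, 1, 4, 1, 5, 9, 2, 6]
corrupted, positions = mask_tokens(tokens, rng)
```

Unmasked positions contribute nothing to the loss. The model is only asked to fill in the blanks, using the visible tokens on both sides as context.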
Inference
After the two training phases, a new image is generated starting from a blank canvas in which all the tokens are masked out. The model then runs for T iterations:

1) It predicts the probabilities for all the masked tokens in parallel (on the first iteration, when the canvas is blank, this means all the tokens).
2) At each masked location, it samples a token from the predicted distribution over the codebook entries and uses that token's predicted probability as a confidence score.
3) It computes the number of tokens to keep masked with the selected scheduling function.
4) It builds the next masked sequence by masking again the sampled tokens with the lowest confidence scores and keeping the most confident ones.

The mask ratio decreases at every iteration until all tokens are generated.
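The four steps above can be sketched as a short loop. This is an illustration under stated assumptions, not the authors' implementation. `dummy_model` is a hypothetical stand-in that returns random distributions instead of a trained bidirectional Transformer, `MASK_ID` is a hypothetical placeholder id, the schedule is cosine, and any additional sampling refinements used in practice are omitted.

```python
import math
import random

MASK_ID = -1  # hypothetical id for the [MASK] token


def cosine_schedule(r):
    return math.cos(math.pi * r / 2)


def dummy_model(tokens, vocab_size, rng):
    """Hypothetical stand-in for the bidirectional Transformer:
    returns a probability distribution over the codebook for every position."""
    probs = []
    for _ in tokens:
        weights = [rng.random() + 1e-6 for _ in range(vocab_size)]
        total = sum(weights)
        probs.append([w / total for w in weights])
    return probs


def maskgit_decode(num_tokens, vocab_size, steps=8, seed=0):
    rng = random.Random(seed)
    tokens = [MASK_ID] * num_tokens  # blank canvas: everything masked
    for t in range(steps):
        # 1) predict probabilities for all positions in parallel
        probs = dummy_model(tokens, vocab_size, rng)
        candidates = list(tokens)
        confidence = [math.inf] * num_tokens  # already-decoded tokens are never re-masked
        for i, tok in enumerate(tokens):
            if tok == MASK_ID:
                # 2) sample a token and use its probability as the confidence score
                k = rng.choices(range(vocab_size), weights=probs[i])[0]
                candidates[i] = k
                confidence[i] = probs[i][k]
        # 3) how many tokens should remain masked after this step (0 on the last step)
        n_mask = math.floor(cosine_schedule((t + 1) / steps) * num_tokens)
        # 4) re-mask the n_mask least confident of the newly sampled tokens
        for i in sorted(range(num_tokens), key=lambda j: confidence[j])[:n_mask]:
            candidates[i] = MASK_ID
        tokens = candidates
    return tokens


image_tokens = maskgit_decode(num_tokens=64, vocab_size=16, steps=8, seed=1)
```

The number of model calls equals `steps`, regardless of `num_tokens`. An autoregressive decoder, by contrast, needs one call per token. Because the cosine schedule is decreasing, the re-masked count never exceeds the number of tokens sampled in that step.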
Scheduling Function
BERT uses a fixed masking ratio of 15%, meaning that it always masks 15% of the tokens. This cannot work here, because generation starts from a fully masked canvas and proceeds through progressively lower mask ratios, so the model must handle any ratio. The authors therefore experimented with continuous, decreasing functions bounded between 0 and 1 (0 means that no token is masked, 1 means that all tokens are masked). Among the functions they tried, a cosine function worked best.
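To see how the choice of scheduling function changes decoding, the sketch below compares a few candidate shapes of γ(r). The set shown (linear, square, cubic, cosine) is illustrative, and the exact set compared in the paper may differ. The 256 tokens and 8 steps are likewise assumed example values. For each schedule, it prints how many tokens remain masked after each iteration.

```python
import math

# Candidate mask-scheduling functions gamma(r): 1 at r = 0 (all masked), ~0 at r = 1.
SCHEDULES = {
    "linear": lambda r: 1 - r,
    "square": lambda r: 1 - r ** 2,
    "cubic": lambda r: 1 - r ** 3,
    "cosine": lambda r: math.cos(math.pi * r / 2),
}


def masked_counts(gamma, num_tokens, steps):
    """Number of tokens still masked after each of `steps` decoding iterations."""
    return [math.floor(gamma((t + 1) / steps) * num_tokens) for t in range(steps)]


for name, gamma in SCHEDULES.items():
    print(name, masked_counts(gamma, num_tokens=256, steps=8))
```

Compared with the linear schedule, the cosine schedule stays close to 1 for small r. It therefore commits to only a few tokens in the early iterations and reveals most of them toward the end.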
Results
On the ImageNet benchmark, MaskGIT decoded images up to 64x faster than autoregressive decoding. It also obtained a higher classification accuracy score (CAS) and a better (lower) FID than state-of-the-art models, including a generative Transformer (VQ-GAN), a GAN (BigGAN), and a diffusion model (ADM).
Several applications of MaskGIT are shown below. Particularly noteworthy is a new application, class-conditional image editing ((b) in the figure), where the content inside a bounding box is regenerated with a different class while the surrounding context is kept. This would be infeasible with classic autoregressive models, which generate tokens in a fixed raster order and cannot condition on the context that follows the region in that order. For MaskGIT, it is straightforward.
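One way to picture the editing setup is to build the starting canvas for the iterative decoder directly. Instead of a fully masked canvas, the tokens inside the bounding box are masked and all other tokens are kept as fixed context. The sketch below assumes a row-major token grid and a hypothetical `MASK_ID` of -1. Decoding from this partially masked canvas, conditioned on the new class, is only described here, not implemented.

```python
MASK_ID = -1  # hypothetical id for the [MASK] token


def mask_bounding_box(token_grid, top, left, height, width):
    """Flatten a 2-D grid of visual tokens (row-major), masking only those inside the box.
    Tokens outside the box stay fixed and act as context for the iterative decoder."""
    rows, cols = len(token_grid), len(token_grid[0])
    flat = []
    for r in range(rows):
        for c in range(cols):
            inside = top <= r < top + height and left <= c < left + width
            flat.append(MASK_ID if inside else token_grid[r][c])
    return flat


# 4x4 grid of hypothetical token ids 0..15; mask the central 2x2 box
grid = [[4 * r + c for c in range(4)] for r in range(4)]
canvas = mask_bounding_box(grid, top=1, left=1, height=2, width=2)
# canvas == [0, 1, 2, 3, 4, -1, -1, 7, 8, -1, -1, 11, 12, 13, 14, 15]
```

Because the Transformer is bidirectional, the masked tokens can attend to context on every side of the box, including the tokens below and to the right of it.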
Conclusion
MaskGIT defines a new paradigm for image generation, outperforming the state of the art on almost every task evaluated. As this is the first implementation of the idea, we can’t wait to see which other synthesis tasks will be mastered by improved bidirectional generative Transformers in the coming years!
Paper: https://arxiv.org/pdf/2202.04200v1.pdf