The vanilla VAE exhibits distinct clusters, whereas the CVAE has a more homogeneous distribution. The vanilla VAE encodes class and class variation into the latent space since there is no conditional signal provided. The CVAE, however, does not need to learn class distinction, and its latent space can focus on the variation within classes. Therefore, a CVAE can potentially learn more information, since it does not rely on having to learn basic class conditioning.
Two model architectures were created to test image generation. The first architecture is a convolutional CVAE with a concatenating conditional approach. All networks were built for Fashion-MNIST images of size 28×28 (784 total pixels).
import torch
import torch.nn as nn

class ConcatConditionalVAE(nn.Module):
    def __init__(self, latent_dim=128, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.flatten_size = 128 * 4 * 4
        # Conditional embedding
        self.label_embedding = nn.Embedding(num_classes, 32)
        # Latent space (with concatenated condition)
        self.fc_mu = nn.Linear(self.flatten_size + 32, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size + 32, latent_dim)
        # Decoder
        self.decoder_input = nn.Linear(latent_dim + 32, 4 * 4 * 128)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 2, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, x, c):
        x = self.encoder(x)
        c = self.label_embedding(c)
        # Concatenate condition with encoded input
        x = torch.cat([x, c], dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        c = self.label_embedding(c)
        # Concatenate condition with latent vector
        z = torch.cat([z, c], dim=1)
        z = self.decoder_input(z)
        z = z.view(-1, 128, 4, 4)
        return self.decoder(z)

    def forward(self, x, c):
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var
The CVAE encoder consists of 3 convolutional layers, each followed by a ReLU non-linearity. The output of the encoder is then flattened. The class number is passed through an embedding layer and concatenated with the encoder output. The reparameterization trick is then used with two linear layers to obtain μ and σ in the latent space. Once sampled, the output of the reparameterized latent space is passed to the decoder, again concatenated with the output of the class-number embedding layer. The decoder consists of 3 transposed convolutional layers. The first two are followed by a ReLU non-linearity, with the last layer followed by a sigmoid non-linearity. The output of the decoder is a 28×28 generated image.
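As a minimal sanity check of this architecture, a forward pass with dummy tensors confirms the expected shapes. This sketch assumes only the class defined above; the batch contents are illustrative, not real data.

# Shape check for the concatenating CVAE (illustrative dummy batch).
model = ConcatConditionalVAE(latent_dim=128, num_classes=10)
x = torch.rand(16, 1, 28, 28)        # dummy batch of 28x28 grayscale "images" in [0, 1]
c = torch.randint(0, 10, (16,))      # dummy class labels
recon, mu, log_var = model(x, c)
print(recon.shape, mu.shape, log_var.shape)
# torch.Size([16, 1, 28, 28]) torch.Size([16, 128]) torch.Size([16, 128])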
The other model architecture follows the same approach, but adds the conditional input instead of concatenating it. A major question was whether adding or concatenating would lead to better reconstruction or generation results.
class AdditiveConditionalVAE(nn.Module):
    def __init__(self, latent_dim=128, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.flatten_size = 128 * 4 * 4
        # Conditional embedding
        self.label_embedding = nn.Embedding(num_classes, self.flatten_size)
        # Latent space (without concatenation)
        self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size, latent_dim)
        # Decoder condition embedding
        self.decoder_label_embedding = nn.Embedding(num_classes, latent_dim)
        # Decoder
        self.decoder_input = nn.Linear(latent_dim, 4 * 4 * 128)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 2, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, x, c):
        x = self.encoder(x)
        c = self.label_embedding(c)
        # Add condition to encoded input
        x = x + c
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        # Add condition to latent vector
        c = self.decoder_label_embedding(c)
        z = z + c
        z = self.decoder_input(z)
        z = z.view(-1, 128, 4, 4)
        return self.decoder(z)

    def forward(self, x, c):
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var
The same loss function, from the equation shown above, is used for all CVAEs.
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar):
    """Computes the loss = -ELBO = Negative Log-Likelihood + KL Divergence.
    Args:
        recon_x: Decoder output.
        x: Ground truth.
        mu: Mean of Z.
        logvar: Log-variance of Z.
    """
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
In order to assess model-generated images, three quantitative metrics are commonly used. Mean Squared Error (MSE) was calculated by summing the squares of the pixel-wise difference between the generated image and a ground-truth image. Structural Similarity Index Measure (SSIM) is a metric that evaluates image quality by comparing two images based on structural information, luminance, and contrast [3]. SSIM can be used to compare images of any size, while MSE is relative to image size. The SSIM score ranges from -1 to 1, where 1 indicates identical images. Fréchet inception distance (FID) is a metric for quantifying the realism and diversity of generated images. As FID is a distance measure, lower scores indicate a better reconstruction of a set of images.
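As a rough illustration of how the first two metrics might be computed for a single image pair, the helper below uses scikit-image for SSIM; it is illustrative and not taken from the project code.

import numpy as np
from skimage.metrics import structural_similarity as ssim

def evaluate_pair(generated, ground_truth):
    """generated, ground_truth: 2D numpy arrays with pixel values in [0, 1]."""
    mse = np.sum((generated - ground_truth) ** 2)            # pixel-wise sum of squared differences
    ssim_score = ssim(generated, ground_truth, data_range=1.0)
    return mse, ssim_score

FID, by contrast, is normally computed with a library implementation (for example torchmetrics' FrechetInceptionDistance) over sets of real and generated images rather than a single pair.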
Before scaling up to full text-to-image, the CVAEs are tested on image reconstruction and generation on Fashion-MNIST. Fashion-MNIST is an MNIST-like dataset consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28×28 grayscale image, associated with a label from 10 classes [4].
Preprocessing functions were created to extract the relevant keyword containing the class name from the input short text using regular expression matching. Additional descriptors (synonyms) were used for most classes to account for similar fashion items included in each class (e.g. Coat & Jacket).
import re

classes = {
    'T-shirt': 0,
    'Top': 0,
    'Trouser': 1,
    'Pants': 1,
    'Pullover': 2,
    'Sweater': 2,
    'Hoodie': 2,
    'Dress': 3,
    'Coat': 4,
    'Jacket': 4,
    'Sandal': 5,
    'Shirt': 6,
    'Sneaker': 7,
    'Shoe': 7,
    'Bag': 8,
    'Ankle boot': 9,
    'Boot': 9
}

def word_to_text(input_str, classes, model, device):
    label = class_embedding(input_str, classes)
    if label == -1: return Exception("No valid label")
    samples = sample_images(model, num_samples=4, label=label, device=device)
    plot_samples(samples, input_str, torch.tensor([label]))
    return

def class_embedding(input_str, classes):
    for key in list(classes.keys()):
        template = rf'(?i)\b{key}\b'
        output = re.search(template, input_str)
        if output: return classes[key]
    return -1
The class name was then converted to its class number and used as the conditional input to the CVAE. In order to generate an image, the class label extracted from the short text description is passed into the decoder along with random samples from a Gaussian distribution as the latent-space variable.
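The sample_images helper referenced in word_to_text is not shown above; a minimal sketch consistent with this description (sample z from a standard Gaussian and decode it conditioned on the class label) could look like the following. Exact details of the real helper may differ.

def sample_images(model, num_samples, label, device):
    """Draw latent samples from N(0, I) and decode them conditioned on a class label.
    Sketch of the helper referenced by word_to_text, not the project's exact code."""
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, model.latent_dim).to(device)           # random latent vectors
        c = torch.full((num_samples,), label, dtype=torch.long).to(device)  # repeated class label
        samples = model.decode(z, c)                                        # (num_samples, 1, 28, 28)
    return samples.cpu()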
Before testing generation, image reconstruction is tested to ensure the functionality of the CVAE. Because this is a convolutional network operating on 28×28 images, the network can be trained in less than an hour with fewer than 100 epochs.
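The training loop itself is not included here; the sketch below shows one way to train the concatenating model under assumed settings (Adam with a 1e-3 learning rate, batch size 128, and the loss_function defined earlier; these hyperparameters are illustrative, not necessarily the ones used in this project).

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Assumed data pipeline and hyperparameters, for illustration only.
train_data = datasets.FashionMNIST("data", train=True, download=True,
                                   transform=transforms.ToTensor())
loader = DataLoader(train_data, batch_size=128, shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConcatConditionalVAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(50):
    model.train()
    total = 0.0
    for x, c in loader:
        x, c = x.to(device), c.to(device)
        recon, mu, log_var = model(x, c)
        loss = loss_function(recon, x, mu, log_var)  # -ELBO from the loss above
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    print(f"epoch {epoch}: loss {total / len(train_data):.3f}")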
Reconstructions contain the general shape of the ground-truth images, but sharp, high-frequency features are missing. Any text or intricate design patterns are blurred in the model output. Inputting any short text containing a Fashion-MNIST class gives generated outputs resembling the reconstructed images.
The generated images have an MSE of 11 and an SSIM of 0.76. These constitute good generations, signifying that for simple, small images, CVAEs can generate quality images. GANs and DDPMs will produce higher-quality images with complex features, but CVAEs can handle simple cases.
When scaling up image generation to text of any length, more robust methods are needed besides regular expression matching. To do this, OpenAI's CLIP is used to convert text into a high-dimensional embedding vector. The embedding model is used in its ViT-B/32 configuration, which outputs embeddings of length 512. A limitation of the CLIP model is that it has a maximum token length of 77, with research showing an even smaller effective length of 20 [5]. Thus, in scenarios where the input text contains multiple sentences, the text is split up by sentence and passed through the CLIP encoder. The resulting embeddings are averaged together to create the final output embedding.
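As a standalone sketch of that step (mirroring the encode_condition method in the model shown below), the naive split on "." and the function name here are assumptions for illustration.

import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, _ = clip.load("ViT-B/32", device=device)

def embed_long_text(text):
    """Split text into sentences, encode each with CLIP, and average the embeddings."""
    sentences = [s.strip() for s in text.split(".") if s.strip()]
    with torch.no_grad():
        embs = [clip_model.encode_text(clip.tokenize(s).to(device)).float()
                for s in sentences]
    return torch.mean(torch.stack(embs), dim=0)  # final 512-dim conditioning vector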
A longer-text model requires far more complicated training data than Fashion-MNIST, so the COCO dataset was used. The COCO dataset has annotations (which are not completely robust, but that will be discussed later) that can be passed into CLIP to get embeddings. However, COCO images are of size 640×480, meaning that even with cropping transforms, a larger network is needed. The adding and concatenating conditional-input architectures are both tested for long-text-to-image generation, but the concatenating approach is shown here:
import clip

class cVAE(nn.Module):
    def __init__(self, latent_dim=128):
        super().__init__()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.clip_model, _ = clip.load("ViT-B/32", device=device)
        self.clip_model.eval()
        for param in self.clip_model.parameters():
            param.requires_grad = False

        self.latent_dim = latent_dim
        # Modified encoder for 128x128 input
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),    # 64x64
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),   # 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1), # 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, 4, stride=2, padding=1), # 4x4
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Flatten()
        )
        self.flatten_size = 512 * 4 * 4  # Flattened size from encoder
        # Process CLIP embeddings for encoder
        self.condition_processor_encoder = nn.Sequential(
            nn.Linear(512, 1024)
        )
        self.fc_mu = nn.Linear(self.flatten_size + 1024, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size + 1024, latent_dim)
        self.decoder_input = nn.Linear(latent_dim + 512, 512 * 4 * 4)
        # Modified decoder for 128x128 output
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),  # 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), # 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),   # 64x64
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),   # 128x128
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 3, 3, stride=1, padding=1),             # 128x128
            nn.Sigmoid()
        )

    def encode_condition(self, text):
        with torch.no_grad():
            embeddings = []
            for sentence in text:
                embeddings.append(self.clip_model.encode_text(
                    clip.tokenize(sentence).to('cuda')).type(torch.float32))
            return torch.mean(torch.stack(embeddings), dim=0)

    def encode(self, x, c):
        x = self.encoder(x)
        c = self.condition_processor_encoder(c)
        x = torch.cat([x, c], dim=1)
        return self.fc_mu(x), self.fc_var(x)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        z = torch.cat([z, c], dim=1)
        z = self.decoder_input(z)
        z = z.view(-1, 512, 4, 4)
        return self.decoder(z)

    def forward(self, x, c):
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var
Another major point of investigation was image generation and reconstruction on images of different sizes, specifically resizing COCO images to 64×64, 128×128, and 256×256. After training the network, reconstruction results should first be examined.
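A sketch of the assumed preprocessing pipeline for one of these sizes is shown below; the paths, crop strategy, and batch size are placeholders rather than the project's exact settings.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

IMG_SIZE = 128  # also run with 64 and 256

# Placeholder paths; COCO images and caption annotations are downloaded separately.
coco = datasets.CocoCaptions(
    root="coco/train2017",
    annFile="coco/annotations/captions_train2017.json",
    transform=transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
    ]),
)
loader = DataLoader(coco, batch_size=32, shuffle=True,
                    collate_fn=lambda batch: tuple(zip(*batch)))  # captions are variable-length lists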
All image sizes lead to a reconstructed background with some feature outlines and correct colors. However, as image size increases, more features are able to be recovered. This makes sense, as although it takes much longer to train a model on larger images, there is more information that can be captured and learned by the model.
With image generation, it is extremely difficult to generate high-quality images. Most images have backgrounds to some degree and blurred features. This is to be expected for image generation from a CVAE. It occurs for both concatenation and addition of the conditional input, but the concatenated approach performs better. This is likely because concatenated conditional inputs do not interfere with important features and ensure information is preserved distinctly; conditions can be ignored if they are irrelevant. Additive conditional inputs, however, can interfere with existing features and completely throw off the network when updating weights during backpropagation.
All of the COCO-generated images have a far lower SSIM, about 0.4, compared to the SSIM on Fashion-MNIST. MSE is proportional to image size, so it is difficult to quantify differences. FIDs for the COCO image generations are in the 200s, further evidence that the COCO CVAE-generated images are not robust.
The biggest limitation in trying to use CVAEs for image generation is, well, the CVAE. The amount of information that can be contained and reconstructed/generated is extremely dependent on the size of the latent space. A latent space that is too small won't capture any meaningful information, and the required size is proportional to the size of the output image. A 28×28 image needs a much smaller latent space than a 64×64 image (as the pixel count grows with the square of the image size). However, a latent space bigger than the actual image adds unnecessary information and at that point just creates a 1-to-1 mapping. For the COCO dataset, a latent space of at least 512 is needed to capture some features. And while CVAEs are generative models, a convolutional encoder and decoder is a rather rudimentary network. The training style of a GAN or the complex denoising process of a DDPM allows for far more sophisticated image generation.
Another major limitation in image generation is the dataset trained on. Although the COCO dataset has annotations, they are not extensively detailed. In order to train complex generative models, a different dataset should be used for training. COCO does not provide locations or further information for background details. A complex feature vector from the CLIP encoder cannot be effectively utilized by a CVAE trained on COCO.
Although CVAEs and image generation on COCO have their limitations, they create a workable image generation model. More code and details can be provided, just reach out!
[1] Kingma, Diederik P., et al. "Auto-Encoding Variational Bayes." arXiv:1312.6114 (2013).
[2] Sohn, Kihyuk, et al. "Learning Structured Output Representation using Deep Conditional Generative Models." NeurIPS Proceedings (2015).
[3] Nilsson, J., et al. "Understanding SSIM." arXiv:2102.12037 (2020).
[4] Xiao, Han, et al. "Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms." arXiv:1708.07747 (2017) (MIT license).
[5] Zhang, B., et al. "Long-CLIP: Unlocking the Long-Text Capability of CLIP." arXiv:2403.15378 (2024).
A shout-out to my group project partners Jake Hession (Deloitte Consultant), Ashley Hong (Google SWE), and Julian Kuppel (Quant)!