The Complete ViT Architecture
At this point, all ViT components have been successfully created. Therefore, we can now use them to assemble the complete Vision Transformer architecture. Take a look at Codeblock 17 below to see how I do it.
```python
# Codeblock 17
class ViT(nn.Module):
    def __init__(self):
        super().__init__()

        #self.patcher = PatcherUnfold()
        self.patcher = PatcherConv()    #(1)
        self.pos_embedding = PosEmbedding()
        self.transformer_encoders = nn.Sequential(
            *[TransformerEncoder() for _ in range(NUM_ENCODERS)]    #(2)
        )
        self.mlp_head = MLPHead()

    def forward(self, x):
        x = self.patcher(x)                 # image -> sequence of patch embeddings
        x = self.pos_embedding(x)           # class token + position embeddings
        x = self.transformer_encoders(x)    # pass through the stacked encoders
        x = x[:, 0]    #(3) keep only the class token
        x = self.mlp_head(x)                # map the class token to class logits
        return x
```
There are several things I want to emphasize regarding the above code. First, at line #(1) we can use either PatcherUnfold() or PatcherConv(), since both play the same role, i.e., performing the patch flattening and linear projection step. In this case, I use the latter for no particular reason. Second, the Transformer Encoder block is repeated NUM_ENCODERS (12) times (#(2)) since we are going to implement ViT-Base as stated in Figure 3. Lastly, don't forget to slice the tensor output by the Transformer Encoder, since our MLP head will only process the class token part of the output (#(3)).
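To illustrate why the two patchers are interchangeable, here is a minimal standalone sketch. It does not reuse the PatcherUnfold() and PatcherConv() classes from the earlier codeblocks. Instead, it builds both variants directly from torch.nn.Conv2d and from torch.nn.Unfold followed by torch.nn.Linear. The local constants PATCH_SIZE and EMBED_DIM are assumptions matching ViT-Base/16 (16×16 patches, 768-dimensional embeddings). The sketch copies the convolution's weights into the linear layer so that the two outputs can be compared directly.

```python
import torch
import torch.nn as nn

# Assumed ViT-Base/16 sizes (illustrative local constants).
PATCH_SIZE = 16
EMBED_DIM = 768

torch.manual_seed(0)
x = torch.randn(1, 3, 224, 224)

# Variant 1: a Conv2d whose kernel size and stride both equal the patch size.
conv = nn.Conv2d(3, EMBED_DIM, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
out_conv = conv(x).flatten(2).transpose(1, 2)     # [1, 768, 14, 14] -> [1, 196, 768]

# Variant 2: cut non-overlapping patches with Unfold, then project them with Linear.
unfold = nn.Unfold(kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
linear = nn.Linear(3 * PATCH_SIZE * PATCH_SIZE, EMBED_DIM)

# Share the convolution's parameters so both variants compute the same function.
# Unfold orders each patch as (channel, row, col), matching the Conv2d weight layout.
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(EMBED_DIM, -1))
    linear.bias.copy_(conv.bias)

out_unfold = linear(unfold(x).transpose(1, 2))    # [1, 768, 196] -> [1, 196, 768]

print(out_conv.shape, out_unfold.shape)
print(torch.allclose(out_conv, out_unfold, atol=1e-4))
```

With shared weights, both variants produce the same [1, 196, 768] tensor up to floating-point rounding. That is why the choice at line #(1) does not affect what the model can learn.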
We can test whether our ViT model works properly using the following code.
```python
# Codeblock 18
vit = ViT().to(device)
x = torch.randn(1, 3, 224, 224).to(device)
print(vit(x).size())
```
You can see here that the input, whose size is 1×3×224×224, has been transformed to 1×10, which indicates that our model works as expected.
Note: you need to comment out all the prints to make the output look as concise as this.
```
# Codeblock 18 output
torch.Size([1, 10])
```
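The 1×10 result follows from simple arithmetic on the input size. The sketch below is a hypothetical helper named vit_shapes(). It assumes the ViT-Base/16 settings (224×224 input, 16×16 patches, 768-dimensional embeddings) and the 10 output classes seen above, and it traces the tensor shape after each stage of the forward() method.

```python
def vit_shapes(batch=1, image_size=224, patch_size=16, embed_dim=768, num_classes=10):
    """Trace the tensor shape after each stage of ViT.forward() (hypothetical helper)."""
    num_patches = (image_size // patch_size) ** 2   # 14 * 14 = 196 patches
    seq_len = num_patches + 1                       # one extra token for the class token
    return {
        "patcher": (batch, num_patches, embed_dim),
        "pos_embedding": (batch, seq_len, embed_dim),
        "transformer_encoders": (batch, seq_len, embed_dim),  # encoders keep the shape
        "class_token_slice": (batch, embed_dim),              # x[:, 0]
        "mlp_head": (batch, num_classes),
    }


for stage, shape in vit_shapes().items():
    print(f"{stage:<22}{shape}")
```

Changing the batch size, e.g. vit_shapes(batch=8), only changes the first dimension. The 196-to-197 step is where the class token joins the patch sequence, which matches the PosEmbedding row in the summary below.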
Furthermore, we can also see the detailed structure of the network using the summary() function we imported at the beginning of the code. You can observe that the total number of parameters is around 86 million, which matches the number stated in Figure 3.
```python
# Codeblock 19
summary(vit, input_size=(1, 3, 224, 224))
```
```
# Codeblock 19 output
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ViT                                      [1, 10]                   --
├─PatcherConv: 1-1                       [1, 196, 768]             --
│    └─Conv2d: 2-1                       [1, 768, 14, 14]          590,592
│    └─Flatten: 2-2                      [1, 768, 196]             --
├─PosEmbedding: 1-2                      [1, 197, 768]             152,064
│    └─Dropout: 2-3                      [1, 197, 768]             --
├─Sequential: 1-3                        [1, 197, 768]             --
│    └─TransformerEncoder: 2-4           [1, 197, 768]             --
│    │    └─LayerNorm: 3-1               [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-2      [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-3               [1, 197, 768]             1,536
│    │    └─Sequential: 3-4              [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-5           [1, 197, 768]             --
│    │    └─LayerNorm: 3-5               [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-6      [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-7               [1, 197, 768]             1,536
│    │    └─Sequential: 3-8              [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-6           [1, 197, 768]             --
│    │    └─LayerNorm: 3-9               [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-10     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-11              [1, 197, 768]             1,536
│    │    └─Sequential: 3-12             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-7           [1, 197, 768]             --
│    │    └─LayerNorm: 3-13              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-14     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-15              [1, 197, 768]             1,536
│    │    └─Sequential: 3-16             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-8           [1, 197, 768]             --
│    │    └─LayerNorm: 3-17              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-18     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-19              [1, 197, 768]             1,536
│    │    └─Sequential: 3-20             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-9           [1, 197, 768]             --
│    │    └─LayerNorm: 3-21              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-22     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-23              [1, 197, 768]             1,536
│    │    └─Sequential: 3-24             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-10          [1, 197, 768]             --
│    │    └─LayerNorm: 3-25              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-26     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-27              [1, 197, 768]             1,536
│    │    └─Sequential: 3-28             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-11          [1, 197, 768]             --
│    │    └─LayerNorm: 3-29              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-30     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-31              [1, 197, 768]             1,536
│    │    └─Sequential: 3-32             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-12          [1, 197, 768]             --
│    │    └─LayerNorm: 3-33              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-34     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-35              [1, 197, 768]             1,536
│    │    └─Sequential: 3-36             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-13          [1, 197, 768]             --
│    │    └─LayerNorm: 3-37              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-38     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-39              [1, 197, 768]             1,536
│    │    └─Sequential: 3-40             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-14          [1, 197, 768]             --
│    │    └─LayerNorm: 3-41              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-42     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-43              [1, 197, 768]             1,536
│    │    └─Sequential: 3-44             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-15          [1, 197, 768]             --
│    │    └─LayerNorm: 3-45              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-46     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-47              [1, 197, 768]             1,536
│    │    └─Sequential: 3-48             [1, 197, 768]             4,722,432
├─MLPHead: 1-4                           [1, 10]                   --
│    └─LayerNorm: 2-16                   [1, 768]                  1,536
│    └─Linear: 2-17                      [1, 768]                  590,592
│    └─GELU: 2-18                        [1, 768]                  --
│    └─Linear: 2-19                      [1, 10]                   7,690
==========================================================================================
Total params: 86,396,938
Trainable params: 86,396,938
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 173.06
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 102.89
Params size (MB): 231.59
Estimated Total Size (MB): 335.08
==========================================================================================
```
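As a cross-check, the total in the summary can be reproduced by hand. The sketch below uses a hypothetical helper, vit_base_params(), that assumes the layer sizes visible in the summary. These are a 16×16 patch convolution, 768-dimensional embeddings, a 3072-unit MLP inside each of the 12 encoders, and a head with a 768-unit hidden layer and 10 outputs. The PosEmbedding count of 152,064 is treated as a 197×768 position table plus a 768-dimensional class token. That split is consistent with the number but is an assumption about how the earlier codeblock stores these parameters.

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out                  # weight + bias


def layernorm_params(dim):
    return 2 * dim                               # scale (gamma) + shift (beta)


def vit_base_params(image_size=224, patch_size=16, in_channels=3, dim=768,
                    mlp_dim=3072, depth=12, num_classes=10):
    """Count ViT parameters from its layer sizes (hypothetical helper)."""
    num_patches = (image_size // patch_size) ** 2
    patcher = in_channels * patch_size * patch_size * dim + dim   # Conv2d weight + bias
    pos_embedding = (num_patches + 1) * dim + dim                 # position table + class token
    attention = 4 * dim * dim + 4 * dim    # nn.MultiheadAttention: Q/K/V projections + output projection
    mlp = linear_params(dim, mlp_dim) + linear_params(mlp_dim, dim)
    encoder = 2 * layernorm_params(dim) + attention + mlp
    head = (layernorm_params(dim) + linear_params(dim, dim)
            + linear_params(dim, num_classes))
    return patcher + pos_embedding + depth * encoder + head


print(f"{vit_base_params():,}")   # 86,396,938
```

Nearly 99% of the parameters sit in the 12 encoders (about 7.09 million each). This is why ViT variants are usually described by their depth and embedding dimension rather than by the patcher or the head.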