BERT is probably one of the most exciting developments in NLP in recent years. Just a few months ago, Google announced that it is using BERT in Search, calling it the "biggest leap forward" in understanding search in the past five years. That is a huge statement coming from Google, about Search. That's how significant BERT is.

There are some amazing resources for understanding BERT, Transformers, and attention networks in detail (attention and Transformers are the building blocks of BERT); I have linked them in the footnote. This article is about understanding the architecture and the parameters better, once you have already understood BERT at a basic level.
The model is fortunately very easy to load in Python using Keras (and keras_bert). The following code loads the model and prints a summary of all its layers.
import keras
from keras_bert import get_base_dict, get_model, compile_model, gen_batch_inputs
# Build & train the model
model = get_model(
token_num=30000,
head_num=12,
transformer_num=12,
embed_dim=768,
feed_forward_dim=3072,
seq_len=512,
pos_num=512,
dropout_rate=0.05,
)
compile_model(model)
model.summary()
All the parameters I have used here, including the token count, are from the BERT base model (BERT has two variants: a smaller one called base and a larger one called large). The parameters are (remember the notation, as we will use it later):
Token number (T) = 30k. This is the number of distinct tokens, derived from WordPiece tokenization, which breaks single words into component pieces to improve coverage. For example, "playing" is converted to (play, ##ing). So as long as the model knows the word "sleep", it can infer the meaning of "sleeping" even if it sees that word for the first time (see the short tokenizer sketch after this list)
head_num (A) = 12. Total 12 attention heads per Transformer layer
Transformer num (L) = 12
embed_dim (H) = Embedding length =768
Feed forward Dim (FFD) = H*4 =3072
seq_len (S) = maximum number of tokens in an input sequence = 512
pos_num (P) = Positions to be encoded = S = 512
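To see WordPiece splitting in action, here is a minimal sketch using the Tokenizer that ships with keras_bert and a toy hand-built vocabulary (the real BERT vocabulary has ~30k entries; the tokens below are just for illustration):
from keras_bert import Tokenizer

# Toy WordPiece vocabulary; the real BERT vocabulary has ~30k entries
token_dict = {
    '[CLS]': 0,
    '[SEP]': 1,
    '[UNK]': 2,
    'sleep': 3,
    '##ing': 4,
}
tokenizer = Tokenizer(token_dict)

# "sleeping" is split into a known root plus a suffix piece
print(tokenizer.tokenize('sleeping'))
# expected: ['[CLS]', 'sleep', '##ing', '[SEP]']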
Back to the model: the model.summary() call above should give you a very long summary of all the layers in BERT, which looks like this:
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
Input-Token (InputLayer)        (None, 512)          0
__________________________________________________________________________________________________
Input-Segment (InputLayer)      (None, 512)          0
__________________________________________________________________________________________________
Embedding-Token (TokenEmbedding [(None, 512, 768), ( 23040000    Input-Token[0][0]
__________________________________________________________________________________________________
Embedding-Segment (Embedding)   (None, 512, 768)     1536        Input-Segment[0][0]
__________________________________________________________________________________________________
Embedding-Token-Segment (Add)   (None, 512, 768)     0           Embedding-Token[0][0]
                                                                 Embedding-Segment[0][0]
__________________________________________________________________________________________________
Embedding-Position (PositionEmb (None, 512, 768)     393216      Embedding-Token-Segment[0][0]
__________________________________________________________________________________________________
Embedding-Dropout (Dropout)     (None, 512, 768)     0           Embedding-Position[0][0]
__________________________________________________________________________________________________
Embedding-Norm (LayerNormalizat (None, 512, 768)     1536        Embedding-Dropout[0][0]
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 512, 768)     2362368     Embedding-Norm[0][0]
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 512, 768)     0           Encoder-1-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 512, 768)     0           Embedding-Norm[0][0]
                                                                 Encoder-1-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 512, 768)     1536        Encoder-1-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-1-FeedForward (FeedForw (None, 512, 768)     4722432     Encoder-1-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-1-FeedForward-Dropout ( (None, 512, 768)     0           Encoder-1-FeedForward[0][0]
__________________________________________________________________________________________________
Encoder-1-FeedForward-Add (Add) (None, 512, 768)     0           Encoder-1-MultiHeadSelfAttention-
                                                                 Encoder-1-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-1-FeedForward-Norm (Lay (None, 512, 768)     1536        Encoder-1-FeedForward-Add[0][0]
__________________________________________________________________________________________________

... (Encoder-2 through Encoder-12 repeat the same eight rows, each block consuming the previous block's FeedForward-Norm output) ...

__________________________________________________________________________________________________
MLM-Dense (Dense)               (None, 512, 768)     590592      Encoder-12-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
MLM-Norm (LayerNormalization)   (None, 512, 768)     1536        MLM-Dense[0][0]
__________________________________________________________________________________________________
Extract (Extract)               (None, 768)          0           Encoder-12-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
MLM-Sim (EmbeddingSimilarity)   (None, 512, 30000)   30000       MLM-Norm[0][0]
                                                                 Embedding-Token[0][1]
__________________________________________________________________________________________________
Input-Masked (InputLayer)       (None, 512)          0
__________________________________________________________________________________________________
NSP-Dense (Dense)               (None, 768)          590592      Extract[0][0]
__________________________________________________________________________________________________
MLM (Masked)                    (None, 512, 30000)   0           MLM-Sim[0][0]
                                                                 Input-Masked[0][0]
__________________________________________________________________________________________________
NSP (Dense)                     (None, 2)            1538        NSP-Dense[0][0]
==================================================================================================
Total params: 109,705,010
Trainable params: 109,705,010
Non-trainable params: 0
That's a daunting list. The total number of trainable parameters is ~110M, just as the BERT paper mentions, which is reassuring: the model we loaded is the right one.
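You can also read the total off the model object directly instead of scanning the summary; count_params() is a standard Keras model method:
# Total parameter count of the loaded model
print(model.count_params())  # 109705010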
The same can also be visualized in an image which helps us understand the computation graph better:
from keras.utils import plot_model

# Plotting requires pydot and graphviz to be installed
plot_model(model, to_file='bert.png')


Here's a brief overview of the steps in the model:
Two inputs: one for word tokens, one for segment IDs
The token and segment embeddings are added together, and a third embedding, the position embedding, is added on top; this is followed by dropout and layer normalization
Then come the multi-head self-attention layers. Each encoder block has 8 steps (all rows starting with Encoder-1 in the summary and image above), and there are 12 such blocks, so 96 of the rows above just capture these. If we understand these well, we understand the architecture almost completely
Following these 12 layers, there are two outputs — one for NSP (Next Sentence Prediction) and one for MLM (Masked Language Modeling)
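As a quick sanity check on the per-block step count, you can filter the layer names of the model loaded earlier (a small sketch; the layer names are the ones keras_bert assigns, as seen in the summary):
# Layers whose names start with "Encoder-1-" make up the first Transformer block
encoder_1_layers = [layer.name for layer in model.layers
                    if layer.name.startswith('Encoder-1-')]
print(len(encoder_1_layers))  # expected: 8
print(encoder_1_layers)       # attention plus its dropout/add/norm, then the same four for the feed-forward part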
Layer-wise accounting:
Going through the layers from top to bottom, we see the following:
1. Inputs: the Token and Segment inputs do not have any trainable parameters, as expected.
2. Token embedding parameters = 23040000 (T * H), because each of the 30k (T) tokens needs a representation of dimension 768 (H)
3. Segment embedding parameters = 1536 (2 * H), because we need two vectors, each of length H, representing Segment A and Segment B respectively
4. The token and segment embeddings are added to the position embedding. Parameters = 393216 (P * H), because it needs to generate P vectors, each of length H, for positions 1 to 512 (P). The position embeddings in BERT are trained, not fixed as in "Attention Is All You Need". A dropout is then applied, followed by layer normalization
5. Layer normalization parameters = 1536 (2 * H). Layer normalization learns two parameters for each embedding dimension, a scale (gamma) and a shift (beta), hence 2 * H
6. Encoder: MultiHeadSelfAttention parameters = 2362368
This needs a bit of explanation. Here is what's happening inside this step [ref]:


There are 12 heads in total, with an input of dimension 768, so each head produces an embedding of length 768/12 = 64. Three projections are generated per head: Q, K, and V. That's 768 * 64 * 3 parameters per head, or 12 * 768 * 64 * 3 for all heads. Adding the biases for Q, K, and V gives another 768 * 3. After concatenating all the heads, an additional weight matrix (W_O, towards the right in the image above) is applied. That is a fully connected dense layer whose output dimension equals its input dimension, hence parameters (with bias) = 768 * 768 + 768. So the total for this step = A * H * (H/A) * 3 + H * 3 + H * H + H = 12 * 768 * 64 * 3 + 768 * 3 + 768 * 768 + 768 = 2362368
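Here is the same arithmetic as a small sanity check in plain Python, using the notation defined earlier:
A, H = 12, 768            # attention heads, embedding dimension
d_head = H // A           # 64 dimensions per head

qkv_params = A * H * d_head * 3 + H * 3   # Q, K, V projections plus their biases
output_params = H * H + H                 # W_O projection plus bias

print(qkv_params + output_params)         # 2362368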
7. Another layer normalization, following the same logic as #5
8. FeedForward: this is a feed-forward network with two fully connected layers. It transforms the input dimension (H) to FFD and back to H, with a GELU activation in between (BERT uses GELU rather than ReLU). So the total parameters with biases = (H * FFD + FFD) + (FFD * H + H) = (768 * 3072 + 3072) + (3072 * 768 + 768) = 4722432. This is followed by another dropout layer
9. Another layer normalization, following the same logic as #5
Steps 6–9 cover a single Transformer layer, and the same set repeats 12 (L) times
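Putting steps 6–9 together, each Transformer layer contributes the following parameter count, and twelve such layers account for the bulk of BERT base (plain-Python arithmetic, reusing the numbers derived above):
H, FFD, L = 768, 3072, 12

attention = 2362368                              # multi-head self-attention (step 6)
attention_norm = 2 * H                           # layer normalization (step 7)
feed_forward = (H * FFD + FFD) + (FFD * H + H)   # two dense layers (step 8)
ff_norm = 2 * H                                  # layer normalization (step 9)

per_layer = attention + attention_norm + feed_forward + ff_norm
print(per_layer)       # 7087872
print(per_layer * L)   # 85054464 across all 12 Transformer layers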
These are followed by two output objectives: one for MLM (Masked Language Modeling) and one for NSP (Next Sentence Prediction). Let's look at their parameters:
10. MLM-Dense: this takes an embedding as input and tries to predict the masked word's embedding. So parameters (with bias) = H * H + H = 768 * 768 + 768 = 590592
11. MLM-Norm: a normalization layer, with its parameter count following the same logic as #5
12. MLM-Sim: EmbeddingSimilarity: this computes the similarity between the output of MLM-Norm and the embedding of the masked input token. This layer also learns a token-level bias, which adds T (= 30k) parameters (intuitively, I understand this as something like token-level priors, but please correct me if I am wrong)
13. NSP-Dense: Dense: this converts the input embedding of length H to another embedding of length H. Parameters = H * H + H = 590592
14. NSP: Dense: the H-length embedding from the previous layer is then transformed into two values, representing IsNext and NotNext respectively. Hence, parameters = 2 * H + 2 = 1538
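Adding everything up reproduces the total that model.summary() reported. Here is a sketch in plain Python, reusing the per-layer count derived above:
T, H, P, L = 30000, 768, 512, 12

embeddings = T * H + 2 * H + P * H + 2 * H     # token + segment + position embeddings + embedding norm
encoders = 7087872 * L                         # 12 Transformer layers (steps 6-9)
mlm_head = (H * H + H) + 2 * H + T             # MLM-Dense + MLM-Norm + MLM-Sim token bias
nsp_head = (H * H + H) + (2 * H + 2)           # NSP-Dense + NSP output

print(embeddings + encoders + mlm_head + nsp_head)   # 109705010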
This concludes the overview of the whole network. Going through it answered the following questions for me:
Where the segment and position embeddings come from, and the fact that both are trainable
A detailed understanding of what's happening inside the Transformer cell
The role of layer normalization
How the MLM and NSP tasks are trained at once
Footnote:
PyTorch walkthrough implementation of Attention
Three types of Embeddings in BERT
Colab notebook to understand attention in BERT. This also has a cool interactive visualization that explains how the Q, K, and V embeddings interact with each other to produce the attention distribution