BERT is a transformer-based model for NLP tasks. As an encoder-only model, it has a highly regular architecture. In this article, you will learn how to create and pretrain a BERT model from scratch using PyTorch.
Let’s get started.
Overview
This article is divided into three parts; they are:
- Creating a BERT Model the Easy Way
- Creating a BERT Model from Scratch with PyTorch
- Pre-training the BERT Model
Creating a BERT Model the Easy Way
If your goal is to create a BERT model so that you can train it on your own data, using the Hugging Face transformers library is the easiest way to get started. You can install the library with pip install transformers.
You can then create a BERT model by using the BertModel class. For example, you can load a pretrained BERT model from the Hugging Face model hub with the following code:
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
This will download the BERT model from the Hugging Face model hub and load it into a PyTorch model object. You can also create a new BERT model with a different configuration by using the BertConfig class. For example, to create a BERT model with 12 layers, 768 hidden dimensions, and 12 attention heads, you can use the following code:
from transformers import BertConfig, BertModel

config = BertConfig(
    num_hidden_layers=12,
    hidden_size=768,
    num_attention_heads=12
)
model = BertModel(config=config)
This will create a new, untrained BERT model with the specified configuration.
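Either way, you can verify that the model works by running a tokenized sentence through it. Below is a minimal sketch that assumes the pretrained bert-base-uncased weights and its matching tokenizer; an untrained model produces the same shapes, just meaningless values.

import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

# tokenize a sentence and run it through the model without computing gradients
inputs = tokenizer("BERT is an encoder-only transformer.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # (1, num_tokens, 768): one vector per token
print(outputs.pooler_output.shape)      # (1, 768): pooled [CLS] representation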
Creating a BERT Model from Scratch with PyTorch
Using the transformers library is convenient, but you lose the flexibility to customize the model architecture. However, building a BERT model from scratch with PyTorch is not very difficult. Let’s revisit the architecture of BERT:

The BERT architecture
As you can see, BERT is a stack of transformer blocks. Each transformer block consists of a self-attention layer and a feed-forward layer with GeLU activation, and the blocks use the post-norm arrangement, applying LayerNorm after each residual connection. You can implement one transformer block in PyTorch with the following code:
import torch
import torch.nn as nn

class BertBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout_prob):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_size, num_heads,
                                               dropout=dropout_prob, batch_first=True)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.ff_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size)
        )

    def forward(self, x, pad_mask):
        # self-attention with padding mask and post-norm
        attn_output, _ = self.attention(x, x, x, key_padding_mask=pad_mask)
        x = self.attn_norm(x + attn_output)
        # feed-forward with GeLU activation and post-norm
        ff_output = self.feed_forward(x)
        x = self.ff_norm(x + self.dropout(ff_output))
        return x
The BERT model requires a pooler that transforms the hidden state of the [CLS] token for classification tasks. The [CLS] token is a special placeholder used to represent the entire sequence, so its representation should be distinguished from other token states. The pooler is simply a linear layer with a tanh activation function. You can implement it with the following code:
class BertPooler(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, x):
        x = self.dense(x)
        x = self.activation(x)
        return x
Now, you can implement the BERT model using the above building blocks. The BERT model takes a sequence of integer tokens as input, and these tokens must be converted to embedding vectors before the transformer blocks can process them. In addition, a padding mask is applied so that the attention layers do not attend to padding tokens.
The BERT model can be implemented as follows:
class BertModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        # embedding layers
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_id)
        self.type_embeddings = nn.Embedding(config.num_types, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_seq_len, config.hidden_size)
        self.embeddings_norm = nn.LayerNorm(config.hidden_size)
        self.embeddings_dropout = nn.Dropout(config.dropout_prob)
        # transformer blocks
        self.blocks = nn.ModuleList([
            BertBlock(config.hidden_size, config.num_heads, config.dropout_prob)
            for _ in range(config.num_layers)
        ])
        # [CLS] pooler layer
        self.pooler = BertPooler(config.hidden_size)

    def forward(self, input_ids, token_type_ids, pad_id=0):
        # create attention mask for padding tokens
        pad_mask = input_ids == pad_id
        # convert integer tokens to embedding vectors
        batch_size, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        position_embeddings = self.position_embeddings(position_ids)
        type_embeddings = self.type_embeddings(token_type_ids)
        token_embeddings = self.word_embeddings(input_ids)
        x = token_embeddings + type_embeddings + position_embeddings
        x = self.embeddings_norm(x)
        x = self.embeddings_dropout(x)
        # process the sequence with transformer blocks
        for block in self.blocks:
            x = block(x, pad_mask)
        # pool the hidden state of the `[CLS]` token
        pooled_output = self.pooler(x[:, 0, :])
        return x, pooled_output
The BERT model embeds not only the input tokens but also the token type. Moreover, BERT uses learned position embeddings. You need to add the three embeddings together and pass them to the transformer blocks. The normalization and dropout applied after the embeddings help regularize the model and stabilize training.
The model returns the hidden state of the entire sequence and the pooled output of the [CLS] token, which are useful for the MLM and NSP tasks, respectively.
Notice that the model is instantiated with a config object. This is helpful to avoid listing all the hyperparameters in the constructor of the BertModel class. The config object is simply defined as:
import dataclasses

@dataclasses.dataclass
class BertConfig:
    """Configuration for BERT model."""
    vocab_size: int = 30522
    num_layers: int = 12
    hidden_size: int = 768
    num_heads: int = 12
    dropout_prob: float = 0.1
    pad_id: int = 0
    max_seq_len: int = 512
    num_types: int = 2
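Because every hyperparameter lives in the config, you can also scale the model down for quick experiments. The following sketch builds a deliberately small model (the smaller sizes are an arbitrary choice, not the standard BERT-base values) and runs a batch of random token IDs through it to confirm the output shapes:

small_config = BertConfig(num_layers=2, hidden_size=128, num_heads=4, max_seq_len=128)
model = BertModel(small_config)

input_ids = torch.randint(1, small_config.vocab_size, (2, 16))  # dummy token IDs, no padding
token_type_ids = torch.zeros(2, 16, dtype=torch.long)           # single-segment inputs

hidden_states, pooled_output = model(input_ids, token_type_ids)
print(hidden_states.shape)  # torch.Size([2, 16, 128]): one hidden state per token
print(pooled_output.shape)  # torch.Size([2, 128]): pooled [CLS] representation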
With the classes and the config in place, the BERT model backbone is complete. When you pre-train the model, you need to add pretraining heads to generate predictions for the MLM and NSP tasks. Let's implement the pretraining model that uses the BERT model backbone and adds the pretraining heads.
class BertPretrainingModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.bert = BertModel(config)
        self.mlm_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.vocab_size),
        )
        self.nsp_head = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids, token_type_ids, pad_id=0):
        # Process the sequence with the BERT model backbone
        x, pooled_output = self.bert(input_ids, token_type_ids, pad_id)
        # Predict the masked tokens for the MLM task and the classification for the NSP task
        mlm_logits = self.mlm_head(x)
        nsp_logits = self.nsp_head(pooled_output)
        return mlm_logits, nsp_logits
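Hooked onto the backbone, the MLM head turns the per-token hidden states into vocabulary logits, while the NSP head turns the pooled [CLS] output into a two-way classification. A quick shape check with random token IDs (dummy inputs, not real data) illustrates this:

model = BertPretrainingModel(BertConfig())
input_ids = torch.randint(1, 30522, (2, 16))            # dummy token IDs, no padding
token_type_ids = torch.zeros(2, 16, dtype=torch.long)   # single-segment inputs

mlm_logits, nsp_logits = model(input_ids, token_type_ids)
print(mlm_logits.shape)  # torch.Size([2, 16, 30522]): per-token vocabulary logits
print(nsp_logits.shape)  # torch.Size([2, 2]): is-next vs. random-next logits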
Pre-training the BERT Model
Pre-training the BERT model requires a labeled dataset. You can refer to the previous post for how to create the labeled dataset.
The first step is to create a data loader for the pretraining dataset. Like most other models, BERT operates on batches of data rather than individual samples. The data loader helps you shuffle and batch the data for training and allows you to customize the data to fit the training pipeline.
Let’s see how you can create a PyTorch DataLoader object with the labeled dataset.
import datasets
import torch

dataset = datasets.Dataset.from_parquet("wikitext-2_train_data.parquet")

def collate_fn(batch):
    """Custom collate function to handle variable-length sequences in dataset."""
    # always at max length: tokens, segment_ids; always singleton: is_random_next
    input_ids = torch.tensor([item["tokens"] for item in batch])
    token_type_ids = torch.tensor([item["segment_ids"] for item in batch]).abs()
    is_random_next = torch.tensor([item["is_random_next"] for item in batch]).to(int)
    # variable length: masked_positions, masked_labels
    masked_pos = [(idx, pos) for idx, item in enumerate(batch)
                  for pos in item["masked_positions"]]
    masked_labels = torch.tensor([label for item in batch for label in item["masked_labels"]])
    return input_ids, token_type_ids, is_random_next, masked_pos, masked_labels

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                         collate_fn=collate_fn, num_workers=8)
The dataset object is created using the Hugging Face datasets library, which loads the parquet file created in the previous post. The dataset is preprocessed to create labels for the MLM and NSP tasks. Each sample in the dataset is a Python dictionary with the keys tokens (the integer tokens of the sequence), segment_ids (the segment labels), is_random_next (a Boolean label indicating whether the next sentence is from another document), masked_positions (a list of masked positions), and masked_labels (the original tokens at the masked positions).
PyTorch’s DataLoader can help you shuffle and batch the data. You should consider setting num_workers appropriately to utilize multiple CPU cores so that data loading is not a bottleneck in your training.
You can set a custom collate function in the DataLoader to transform the data into tensors that can be fed into the model. Note that the segment_ids in the previous post use -1 for padding tokens, but we did not set up any embedding for this value. Since the padding locations are ignored, you can simply set those values to 1 for convenience.
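For example, a hypothetical row of segment_ids from the previous post would be mapped like this; the padding positions become 1, which is a valid index into the type embedding table, and those positions are excluded from attention by the padding mask anyway:

segment_ids = torch.tensor([[0, 0, 1, 1, -1, -1]])  # hypothetical row: two segments plus padding
print(segment_ids.abs())
# tensor([[0, 0, 1, 1, 1, 1]])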
The masked positions and masked labels need to be handled differently. Each sample may have a different number of masked positions, so you cannot stack them into a single tensor. Instead, you keep the masked positions as a list of tuples denoting the batch index and positions. The masked labels are simply a flattened tensor of the original tokens at the masked positions. The reason will become clear when you implement the training loop, as follows:
...
# Training parameters
epochs = 10
learning_rate = 1e-4
batch_size = 32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertPretrainingModel(BertConfig()).to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(epochs):
    pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch in pbar:
        # get batched data
        input_ids, token_type_ids, is_random_next, masked_pos, masked_labels = batch
        input_ids = input_ids.to(device)
        token_type_ids = token_type_ids.to(device)
        is_random_next = is_random_next.to(device)
        masked_labels = masked_labels.to(device)
        # extract output from model
        mlm_logits, nsp_logits = model(input_ids, token_type_ids)
        # MLM loss: masked_positions is a list of tuples of (B, S), extract the
        # corresponding logits from tensor mlm_logits of shape (B, S, V)
        batch_indices, token_positions = zip(*masked_pos)
        mlm_logits = mlm_logits[batch_indices, token_positions]
        mlm_loss = loss_fn(mlm_logits, masked_labels)
        # Compute the loss for the NSP task
        nsp_loss = loss_fn(nsp_logits, is_random_next)
        # backward with total loss
        total_loss = mlm_loss + nsp_loss
        pbar.set_postfix(MLM=mlm_loss.item(), NSP=nsp_loss.item(), Total=total_loss.item())
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    scheduler.step()
    pbar.update(1)
    pbar.close()

# Save the model
torch.save(model.state_dict(), "bert_pretraining_model.pth")
torch.save(model.bert.state_dict(), "bert_model.pth")
This is a standard training loop in PyTorch, but much simplified from the training procedure as described in section A.2 of the original BERT paper. You set up the optimizer, scheduler, and loss function, then iterate over the data loader and update the model parameters. The tqdm library is used to visualize training progress. The pre-training model outputs are the sequence of logits for the MLM task and the logits for the NSP task. Calculating the NSP loss is straightforward since the output is a tensor of shape (B, 2) and the target is a vector of either 0 or 1. The MLM loss is calculated only on 15% of the input tokens. You need to extract the logits corresponding to those masked positions as mlm_logits and then compare them with masked_labels. The overall loss is the sum of the MLM loss and the NSP loss.
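The indexing trick that pulls out the masked positions is worth a closer look. Given the list of (batch index, position) tuples, zip(*masked_pos) yields two parallel sequences, and indexing a (B, S, V) tensor with them selects one V-dimensional row per masked token. A small standalone illustration with made-up numbers:

import torch

logits = torch.randn(2, 8, 30522)       # pretend MLM output of shape (B, S, V)
masked_pos = [(0, 3), (0, 5), (1, 2)]   # hypothetical (batch index, position) pairs
batch_indices, token_positions = zip(*masked_pos)
selected = logits[batch_indices, token_positions]
print(selected.shape)  # torch.Size([3, 30522]): one row of logits per masked token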
This training loop runs for 10 epochs (the original paper suggested 40). Depending on your hardware, it may take an hour to complete even for the smaller WikiText-2 dataset. If you're using the larger WikiText-103 dataset, it may take a day to complete. The trained model is saved to the file bert_pretraining_model.pth. However, you typically do not need the full pretraining model afterwards, since the pretraining heads are not useful for other tasks. You can simply extract the backbone BERT model and save it on its own.
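When you later fine-tune on a downstream task, you can rebuild the backbone and load the saved weights back in. A minimal sketch, assuming the same BertConfig that was used during pre-training:

config = BertConfig()
backbone = BertModel(config)
state_dict = torch.load("bert_model.pth", map_location="cpu")  # weights saved above
backbone.load_state_dict(state_dict)
backbone.eval()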
For completeness, here is the complete code:
import dataclasses

import datasets
import torch
import torch.nn as nn
import tqdm


@dataclasses.dataclass
class BertConfig:
    """Configuration for BERT model."""
    vocab_size: int = 30522
    num_layers: int = 12
    hidden_size: int = 768
    num_heads: int = 12
    dropout_prob: float = 0.1
    pad_id: int = 0
    max_seq_len: int = 512
    num_types: int = 2


class BertBlock(nn.Module):
    """One transformer block in BERT."""
    def __init__(self, hidden_size: int, num_heads: int, dropout_prob: float):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_size, num_heads,
                                               dropout=dropout_prob, batch_first=True)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.ff_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
        # self-attention with padding mask and post-norm
        attn_output, _ = self.attention(x, x, x, key_padding_mask=pad_mask)
        x = self.attn_norm(x + attn_output)
        # feed-forward with GeLU activation and post-norm
        ff_output = self.feed_forward(x)
        x = self.ff_norm(x + self.dropout(ff_output))
        return x


class BertPooler(nn.Module):
    """Pooler layer for BERT to process the [CLS] token output."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.dense(x)
        x = self.activation(x)
        return x


class BertModel(nn.Module):
    """Backbone of BERT model."""
    def __init__(self, config: BertConfig):
        super().__init__()
        # embedding layers
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_id)
        self.type_embeddings = nn.Embedding(config.num_types, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_seq_len, config.hidden_size)
        self.embeddings_norm = nn.LayerNorm(config.hidden_size)
        self.embeddings_dropout = nn.Dropout(config.dropout_prob)
        # transformer blocks
        self.blocks = nn.ModuleList([
            BertBlock(config.hidden_size, config.num_heads, config.dropout_prob)
            for _ in range(config.num_layers)
        ])
        # [CLS] pooler layer
        self.pooler = BertPooler(config.hidden_size)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, pad_id: int = 0
                ) -> tuple[torch.Tensor, torch.Tensor]:
        # create attention mask for padding tokens
        pad_mask = input_ids == pad_id
        # convert integer tokens to embedding vectors
        batch_size, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        position_embeddings = self.position_embeddings(position_ids)
        type_embeddings = self.type_embeddings(token_type_ids)
        token_embeddings = self.word_embeddings(input_ids)
        x = token_embeddings + type_embeddings + position_embeddings
        x = self.embeddings_norm(x)
        x = self.embeddings_dropout(x)
        # process the sequence with transformer blocks
        for block in self.blocks:
            x = block(x, pad_mask)
        # pool the hidden state of the `[CLS]` token
        pooled_output = self.pooler(x[:, 0, :])
        return x, pooled_output


class BertPretrainingModel(nn.Module):
    def __init__(self, config: BertConfig):
        super().__init__()
        self.bert = BertModel(config)
        self.mlm_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.vocab_size),
        )
        self.nsp_head = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor,
                pad_id: int = 0
                ) -> tuple[torch.Tensor, torch.Tensor]:
        # Process the sequence with the BERT model backbone
        x, pooled_output = self.bert(input_ids, token_type_ids, pad_id)
        # Predict the masked tokens for the MLM task and the classification for the NSP task
        mlm_logits = self.mlm_head(x)
        nsp_logits = self.nsp_head(pooled_output)
        return mlm_logits, nsp_logits


# Training parameters
epochs = 10
learning_rate = 1e-4
batch_size = 32

# Load dataset and set up dataloader
dataset = datasets.Dataset.from_parquet("wikitext-2_train_data.parquet")

def collate_fn(batch: list[dict]):
    """Custom collate function to handle variable-length sequences in dataset."""
    # always at max length: tokens, segment_ids; always singleton: is_random_next
    input_ids = torch.tensor([item["tokens"] for item in batch])
    token_type_ids = torch.tensor([item["segment_ids"] for item in batch]).abs()
    is_random_next = torch.tensor([item["is_random_next"] for item in batch]).to(int)
    # variable length: masked_positions, masked_labels
    masked_pos = [(idx, pos) for idx, item in enumerate(batch)
                  for pos in item["masked_positions"]]
    masked_labels = torch.tensor([label for item in batch for label in item["masked_labels"]])
    return input_ids, token_type_ids, is_random_next, masked_pos, masked_labels

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                         collate_fn=collate_fn, num_workers=8)

# train the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertPretrainingModel(BertConfig()).to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(epochs):
    pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch in pbar:
        # get batched data
        input_ids, token_type_ids, is_random_next, masked_pos, masked_labels = batch
        input_ids = input_ids.to(device)
        token_type_ids = token_type_ids.to(device)
        is_random_next = is_random_next.to(device)
        masked_labels = masked_labels.to(device)
        # extract output from model
        mlm_logits, nsp_logits = model(input_ids, token_type_ids)
        # MLM loss: masked_positions is a list of tuples of (B, S), extract the
        # corresponding logits from tensor mlm_logits of shape (B, S, V)
        batch_indices, token_positions = zip(*masked_pos)
        mlm_logits = mlm_logits[batch_indices, token_positions]
        mlm_loss = loss_fn(mlm_logits, masked_labels)
        # Compute the loss for the NSP task
        nsp_loss = loss_fn(nsp_logits, is_random_next)
        # backward with total loss
        total_loss = mlm_loss + nsp_loss
        pbar.set_postfix(MLM=mlm_loss.item(), NSP=nsp_loss.item(), Total=total_loss.item())
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    scheduler.step()
    pbar.update(1)
    pbar.close()

# Save the model
torch.save(model.state_dict(), "bert_pretraining_model.pth")
torch.save(model.bert.state_dict(), "bert_model.pth")
Further Reading
Below are some resources that you may find useful:
- Devlin et al. (2018). BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
- Google’s BERT implementation
- Hugging Face’s BERT implementation in the transformers library
- Hugging Face’s BERT documentation
Summary
In this article, you learned how to create a BERT model from scratch using PyTorch. Specifically, you learned:
- How to create a BERT model from scratch using PyTorch
- How to use PyTorch DataLoader to batch data and handle variable-length sequences
- How to pre-train a BERT model using the MLM and NSP tasks