Pretrain a BERT Model from Scratch - MachineLearningMastery.com


BERT is a transformer-based model for NLP tasks. As an encoder-only model, it has a highly regular architecture. In this article, you will learn how to create and pretrain a BERT model from scratch using PyTorch.

Let’s get started.

Overview

This article is divided into three parts; they are:

  • Creating a BERT Model the Easy Way
  • Creating a BERT Model from Scratch with PyTorch
  • Pre-training the BERT Model

Creating a BERT Model the Easy Way

If your goal is to create a BERT model so that you can train it on your own data, using the Hugging Face transformers library is the easiest way to get started. First, install the library:
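pip install transformers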

You can then create a BERT model by using the BertModel class. For example, you can load a pretrained BERT model from the Hugging Face model hub with the following code:

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")

This will download the BERT model from the Hugging Face model hub and load it into a PyTorch model object. You can also create a new BERT model with a different configuration by using the BertConfig class. For example, to create a BERT model with 12 layers, 768 hidden dimensions, and 12 attention heads, you can use the following code:

from transformers import BertConfig, BertModel

config = BertConfig(

    num_hidden_layers=12,

    hidden_size=768,

    num_attention_heads=12

)

model = BertModel(config=config)

This will create a new, untrained BERT model with the specified configuration.
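To confirm that the model works end to end, you can tokenize a sentence and run a forward pass. The snippet below is a minimal sketch (not from the original article) that assumes the bert-base-uncased tokenizer:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("BERT is an encoder-only transformer.", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, sequence_length, 768)
print(outputs.pooler_output.shape)      # (1, 768)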

Creating a BERT Model from Scratch with PyTorch

Using the transformers library is convenient, but you lose the flexibility to customize the model architecture. However, building a BERT model from scratch with PyTorch is not very difficult. Let’s revisit the architecture of BERT:

The BERT architecture

As you can see, BERT is a stack of transformer blocks. Each block consists of a self-attention layer and a feed-forward layer with GeLU activation, and post-norm LayerNorm is applied after each of the two sub-layers. You can implement one transformer block in PyTorch with the following code:


import torch

import torch.nn as nn

class BertBlock(nn.Module):

    def __init__(self, hidden_size, num_heads, dropout_prob):

        super().__init__()

        self.attention = nn.MultiheadAttention(hidden_size, num_heads,

                                               dropout=dropout_prob, batch_first=True)

        self.attn_norm = nn.LayerNorm(hidden_size)

        self.ff_norm = nn.LayerNorm(hidden_size)

        self.dropout = nn.Dropout(dropout_prob)

        self.feed_forward = nn.Sequential(

            nn.Linear(hidden_size, 4 * hidden_size),

            nn.GELU(),

            nn.Linear(4 * hidden_size, hidden_size)

        )

    def forward(self, x, pad_mask):

        # self-attention with padding mask and post-norm

        attn_output, _ = self.attention(x, x, x, key_padding_mask=pad_mask)

        x = self.attn_norm(x + attn_output)

        # feed-forward with GeLU activation and post-norm

        ff_output = self.feed_forward(x)

        x = self.ff_norm(x + self.dropout(ff_output))

        return x
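As a quick check of the padding-mask semantics, the following sketch (mine, not part of the original article) runs one block on a random batch in which the last two positions of each sequence are treated as padding:

block = BertBlock(hidden_size=768, num_heads=12, dropout_prob=0.1)
x = torch.randn(2, 16, 768)                      # (batch, seq_len, hidden)
pad_mask = torch.zeros(2, 16, dtype=torch.bool)  # True marks positions the attention should ignore
pad_mask[:, -2:] = True                          # treat the last two tokens as padding
out = block(x, pad_mask)
print(out.shape)  # torch.Size([2, 16, 768])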

The BERT model requires a pooler that transforms the hidden state of the [CLS] token for classification tasks. The [CLS] token is a special placeholder used to represent the entire sequence, so its representation should be distinguished from other token states. The pooler is simply a linear layer with a tanh activation function. You can implement it with the following code:

class BertPooler(nn.Module):

    def __init__(self, hidden_size):

        super().__init__()

        self.dense = nn.Linear(hidden_size, hidden_size)

        self.activation = nn.Tanh()

    def forward(self, x):

        x = self.dense(x)

        x = self.activation(x)

        return x

Now you can implement the BERT model using the building blocks above. The BERT model takes a sequence of integer tokens as input, and these tokens must be converted to embedding vectors before the transformer blocks can process them. Moreover, a padding mask is applied so that the model does not attend to padding tokens.

The BERT model can be implemented as follows:


class BertModel(nn.Module):

    def __init__(self, config):

        super().__init__()

        # embedding layers

        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_id)

        self.type_embeddings = nn.Embedding(config.num_types, config.hidden_size)

        self.position_embeddings = nn.Embedding(config.max_seq_len, config.hidden_size)

        self.embeddings_norm = nn.LayerNorm(config.hidden_size)

        self.embeddings_dropout = nn.Dropout(config.dropout_prob)

        # transformer blocks

        self.blocks = nn.ModuleList([

            BertBlock(config.hidden_size, config.num_heads, config.dropout_prob)

            for _ in range(config.num_layers)

        ])

        # [CLS] pooler layer

        self.pooler = BertPooler(config.hidden_size)

    def forward(self, input_ids, token_type_ids, pad_id=0):

        # create attention mask for padding tokens

        pad_mask = input_ids == pad_id

        # convert integer tokens to embedding vectors

        batch_size, seq_len = input_ids.shape

        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        position_embeddings = self.position_embeddings(position_ids)

        type_embeddings = self.type_embeddings(token_type_ids)

        token_embeddings = self.word_embeddings(input_ids)

        x = token_embeddings + type_embeddings + position_embeddings

        x = self.embeddings_norm(x)

        x = self.embeddings_dropout(x)

        # process the sequence with transformer blocks

        for block in self.blocks:

            x = block(x, pad_mask)

        # pool the hidden state of the `[CLS]` token

        pooled_output = self.pooler(x[:, 0, :])

        return x, pooled_output

The BERT model embeds not only the input tokens but also the token type. Moreover, BERT uses learned position embeddings. You need to add the three embeddings together and pass them to the transformer blocks. The normalization and dropout applied after the embeddings help regularize the model and stabilize training.

The model returns the hidden state of the entire sequence and the pooled output of the [CLS] token, which are useful for the MLM and NSP tasks, respectively.

Notice that the model is instantiated with a config object. This is helpful to avoid listing all the hyperparameters in the constructor of the BertModel class. The config object is simply defined as:

import dataclasses

@dataclasses.dataclass

class BertConfig:

    """Configuration for BERT model."""

    vocab_size: int = 30522

    num_layers: int = 12

    hidden_size: int = 768

    num_heads: int = 12

    dropout_prob: float = 0.1

    pad_id: int = 0

    max_seq_len: int = 512

    num_types: int = 2
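Before adding the pretraining heads, a quick sanity check (my own sketch, not part of the original article) is to count the parameters of the backbone. With the default configuration, the total should land near the roughly 110 million parameters of bert-base-uncased:

backbone = BertModel(BertConfig())
num_params = sum(p.numel() for p in backbone.parameters())
print(f"Parameters: {num_params:,}")  # roughly 110 million with the default configuration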

This completes the BERT model backbone. When you pre-train the model, you need to add pretraining heads to generate predictions for the MLM and NSP tasks. Let's implement the pretraining model that wraps the BERT backbone and adds the pretraining heads.


class BertPretrainingModel(nn.Module):

    def __init__(self, config):

        super().__init__()

        self.bert = BertModel(config)

        self.mlm_head = nn.Sequential(

            nn.Linear(config.hidden_size, config.hidden_size),

            nn.GELU(),

            nn.LayerNorm(config.hidden_size),

            nn.Linear(config.hidden_size, config.vocab_size),

        )

        self.nsp_head = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids, token_type_ids, pad_id=0):

        # Process the sequence with the BERT model backbone

        x, pooled_output = self.bert(input_ids, token_type_ids, pad_id)

        # Predict the masked tokens for the MLM task and the classification for the NSP task

        mlm_logits = self.mlm_head(x)

        nsp_logits = self.nsp_head(pooled_output)

        return mlm_logits, nsp_logits
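Similarly, you can verify that the pretraining heads produce outputs of the expected shapes. This is a sketch of mine using random token IDs, not part of the original article:

config = BertConfig()
pretrain_model = BertPretrainingModel(config)
input_ids = torch.randint(1, config.vocab_size, (2, 128))   # random tokens; avoid the pad ID 0 for this check
token_type_ids = torch.zeros(2, 128, dtype=torch.long)
mlm_logits, nsp_logits = pretrain_model(input_ids, token_type_ids)
print(mlm_logits.shape)  # torch.Size([2, 128, 30522]): one vocabulary distribution per token
print(nsp_logits.shape)  # torch.Size([2, 2]): two-class NSP logits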

Pre-training the BERT Model

Pre-training the BERT model requires a labeled dataset; you can refer to the previous post for how to create it.

The first step is to create a data loader for the pretraining dataset. Like most other models, BERT operates on batches of data rather than individual samples. The data loader helps you shuffle and batch the data for training and allows you to customize the data to fit the training pipeline.

Let’s see how you can create a PyTorch DataLoader object with the labeled dataset.


import datasets

import torch

dataset = datasets.Dataset.from_parquet("wikitext-2_train_data.parquet")

def collate_fn(batch):

    """Custom collate function to handle variable-length sequences in dataset."""

    # always at max length: tokens, segment_ids; always singleton: is_random_next

    input_ids = torch.tensor([item["tokens"] for item in batch])

    token_type_ids = torch.tensor([item["segment_ids"] for item in batch]).abs()

    is_random_next = torch.tensor([item["is_random_next"] for item in batch]).to(int)

    # variable length: masked_positions, masked_labels

    masked_pos = [(idx, pos) for idx, item in enumerate(batch) for pos in item["masked_positions"]]

    masked_labels = torch.tensor([label for item in batch for label in item["masked_labels"]])

    return input_ids, token_type_ids, is_random_next, masked_pos, masked_labels

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,

                                         collate_fn=collate_fn, num_workers=8)

The dataset object is created using the Hugging Face datasets library, which loads the parquet file created in the previous post. The dataset is preprocessed to create labels for the MLM and NSP tasks. Each sample in the dataset is a Python dictionary with the keys tokens (the integer tokens of the sequence), segment_ids (the segment labels), is_random_next (a Boolean label indicating whether the next sentence is from another document), masked_positions (a list of masked positions), and masked_labels (the original tokens at the masked positions).
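To make the structure concrete, here is a toy batch of mine with made-up values (real samples are padded to the full sequence length) and what the collate_fn defined above returns for it:

toy_batch = [
    {"tokens": [101, 2023, 103, 2003, 1037, 3231, 102, 0],  # made-up token IDs; the trailing 0 is padding
     "segment_ids": [0, 0, 0, 0, 1, 1, 1, -1],              # -1 marks padding positions
     "is_random_next": False,
     "masked_positions": [2],
     "masked_labels": [2742]},
    {"tokens": [101, 103, 2003, 103, 2742, 102, 0, 0],
     "segment_ids": [0, 0, 0, 1, 1, 1, -1, -1],
     "is_random_next": True,
     "masked_positions": [1, 3],
     "masked_labels": [2023, 1037]},
]
input_ids, token_type_ids, is_random_next, masked_pos, masked_labels = collate_fn(toy_batch)
print(input_ids.shape)       # torch.Size([2, 8])
print(token_type_ids.shape)  # torch.Size([2, 8]); the -1 values become 1 after .abs()
print(is_random_next)        # tensor([0, 1])
print(masked_pos)            # [(0, 2), (1, 1), (1, 3)]
print(masked_labels)         # tensor([2742, 2023, 1037])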

PyTorch’s DataLoader can help you shuffle and batch the data. You should consider setting num_workers appropriately to utilize multiple CPU cores so that data loading is not a bottleneck in your training.

You can set a custom collate function in the DataLoader to transform the data into tensors that can be fed into the model. Note that the segment_ids in the previous post use -1 for padding tokens, but we did not set up any embedding for this value. Since the padding locations are ignored, you can simply set those values to 1 for convenience.

The masked positions and masked labels need to be handled differently. Each sample may have a different number of masked positions, so you cannot stack them into a single tensor. Instead, you keep the masked positions as a list of tuples denoting the batch index and positions. The masked labels are simply a flattened tensor of the original tokens at the masked positions. The reason will become clear when you implement the training loop, as follows:


...

# Training parameters

epochs = 10

learning_rate = 1e-4

batch_size = 32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BertPretrainingModel(BertConfig()).to(device)

model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

loss_fn = nn.CrossEntropyLoss()

for epoch in range(epochs):

    pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

    for batch in pbar:

        # get batched data

        input_ids, token_type_ids, is_random_next, masked_pos, masked_labels = batch

        input_ids = input_ids.to(device)

        token_type_ids = token_type_ids.to(device)

        is_random_next = is_random_next.to(device)

        masked_labels = masked_labels.to(device)

        # extract output from model

        mlm_logits, nsp_logits = model(input_ids, token_type_ids)

        # MLM loss: masked_positions is a list of tuples of (B, S), extract the

        # corresponding logits from tensor mlm_logits of shape (B, S, V)

        batch_indices, token_positions = zip(*masked_pos)

        mlm_logits = mlm_logits[batch_indices, token_positions]

        mlm_loss = loss_fn(mlm_logits, masked_labels)

        # Compute the loss for the NSP task

        nsp_loss = loss_fn(nsp_logits, is_random_next)

        # backward with total loss

        total_loss = mlm_loss + nsp_loss

        pbar.set_postfix(MLM=mlm_loss.item(), NSP=nsp_loss.item(), Total=total_loss.item())

        optimizer.zero_grad()

        total_loss.backward()

        optimizer.step()

    # decay the learning rate once per epoch
    scheduler.step()

    pbar.close()

# Save the model

torch.save(model.state_dict(), "bert_pretraining_model.pth")

torch.save(model.bert.state_dict(), "bert_model.pth")

This is a standard training loop in PyTorch, much simplified compared with the training procedure described in Section A.2 of the original BERT paper. You set up the optimizer, scheduler, and loss function, then iterate over the data loader and update the model parameters. The tqdm library is used to visualize training progress. The pretraining model outputs a sequence of logits for the MLM task and a pair of logits for the NSP task. Calculating the NSP loss is straightforward since the output is a tensor of shape (B, 2) and the target is a vector of 0s and 1s. The MLM loss is calculated only on the masked tokens, which make up about 15% of the input. You need to extract the logits corresponding to those masked positions as mlm_logits and then compare them with masked_labels. The overall loss is the sum of the MLM loss and the NSP loss.
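The indexing step for the MLM logits deserves a closer look. Here is a tiny standalone sketch (mine, with made-up numbers) of how a list of (batch index, position) pairs selects rows from a tensor of shape (B, S, V):

import torch

logits = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)  # (B=2, S=4, V=3)
masked_pos = [(0, 1), (1, 0), (1, 3)]                            # (batch index, token position) pairs
batch_indices, token_positions = zip(*masked_pos)
selected = logits[batch_indices, token_positions]                # advanced indexing: one row per pair
print(selected.shape)  # torch.Size([3, 3]): one vector of V logits per masked token
print(selected)
# tensor([[ 3.,  4.,  5.],
#         [12., 13., 14.],
#         [21., 22., 23.]])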

This training loop runs for 10 epochs (the original paper trained for roughly 40). Depending on your hardware, it may take an hour to complete even on the smaller WikiText-2 dataset; with the larger WikiText-103 dataset, it may take a day. The trained model is saved to the file bert_pretraining_model.pth. However, you typically do not need the full pretraining model, since the pretraining heads are not useful for other tasks. You can simply extract the backbone BERT model and save it on its own.
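For example, to reuse the saved backbone for a downstream task, you can load bert_model.pth into a fresh BertModel instance (a sketch of mine; it assumes the same BertConfig used during pretraining):

backbone = BertModel(BertConfig())
backbone.load_state_dict(torch.load("bert_model.pth"))
backbone.eval()
# From here you would typically attach a task-specific head, e.g. a classifier on the pooled output.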

For completeness, here is the complete code:


import dataclasses

import datasets

import torch

import torch.nn as nn

import tqdm

@dataclasses.dataclass

class BertConfig:

    """Configuration for BERT model."""

    vocab_size: int = 30522

    num_layers: int = 12

    hidden_size: int = 768

    num_heads: int = 12

    dropout_prob: float = 0.1

    pad_id: int = 0

    max_seq_len: int = 512

    num_types: int = 2

class BertBlock(nn.Module):

    """One transformer block in BERT."""

    def __init__(self, hidden_size: int, num_heads: int, dropout_prob: float):

        super().__init__()

        self.attention = nn.MultiheadAttention(hidden_size, num_heads,

                                               dropout=dropout_prob, batch_first=True)

        self.attn_norm = nn.LayerNorm(hidden_size)

        self.ff_norm = nn.LayerNorm(hidden_size)

        self.dropout = nn.Dropout(dropout_prob)

        self.feed_forward = nn.Sequential(

            nn.Linear(hidden_size, 4 * hidden_size),

            nn.GELU(),

            nn.Linear(4 * hidden_size, hidden_size),

        )

    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:

        # self-attention with padding mask and post-norm

        attn_output, _ = self.attention(x, x, x, key_padding_mask=pad_mask)

        x = self.attn_norm(x + attn_output)

        # feed-forward with GeLU activation and post-norm

        ff_output = self.feed_forward(x)

        x = self.ff_norm(x + self.dropout(ff_output))

        return x

class BertPooler(nn.Module):

    """Pooler layer for BERT to process the [CLS] token output."""

    def __init__(self, hidden_size: int):

        super().__init__()

        self.dense = nn.Linear(hidden_size, hidden_size)

        self.activation = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        x = self.dense(x)

        x = self.activation(x)

        return x

class BertModel(nn.Module):

    """Backbone of BERT model."""

    def __init__(self, config: BertConfig):

        super().__init__()

        # embedding layers

        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size,

                                            padding_idx=config.pad_id)

        self.type_embeddings = nn.Embedding(config.num_types, config.hidden_size)

        self.position_embeddings = nn.Embedding(config.max_seq_len, config.hidden_size)

        self.embeddings_norm = nn.LayerNorm(config.hidden_size)

        self.embeddings_dropout = nn.Dropout(config.dropout_prob)

        # transformer blocks

        self.blocks = nn.ModuleList([

            BertBlock(config.hidden_size, config.num_heads, config.dropout_prob)

            for _ in range(config.num_layers)

        ])

        # [CLS] pooler layer

        self.pooler = BertPooler(config.hidden_size)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, pad_id: int = 0

                ) -> tuple[torch.Tensor, torch.Tensor]:

        # create attention mask for padding tokens

        pad_mask = input_ids == pad_id

        # convert integer tokens to embedding vectors

        batch_size, seq_len = input_ids.shape

        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        position_embeddings = self.position_embeddings(position_ids)

        type_embeddings = self.type_embeddings(token_type_ids)

        token_embeddings = self.word_embeddings(input_ids)

        x = token_embeddings + type_embeddings + position_embeddings

        x = self.embeddings_norm(x)

        x = self.embeddings_dropout(x)

        # process the sequence with transformer blocks

        for block in self.blocks:

            x = block(x, pad_mask)

        # pool the hidden state of the `[CLS]` token

        pooled_output = self.pooler(x[:, 0, :])

        return x, pooled_output

class BertPretrainingModel(nn.Module):

    def __init__(self, config: BertConfig):

        super().__init__()

        self.bert = BertModel(config)

        self.mlm_head = nn.Sequential(

            nn.Linear(config.hidden_size, config.hidden_size),

            nn.GELU(),

            nn.LayerNorm(config.hidden_size),

            nn.Linear(config.hidden_size, config.vocab_size),

        )

        self.nsp_head = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, pad_id: int = 0

                ) -> tuple[torch.Tensor, torch.Tensor]:

        # Process the sequence with the BERT model backbone

        x, pooled_output = self.bert(input_ids, token_type_ids, pad_id)

        # Predict the masked tokens for the MLM task and the classification for the NSP task

        mlm_logits = self.mlm_head(x)

        nsp_logits = self.nsp_head(pooled_output)

        return mlm_logits, nsp_logits

# Training parameters

epochs = 10

learning_rate = 1e-4

batch_size = 32

# Load dataset and set up dataloader

dataset = datasets.Dataset.from_parquet("wikitext-2_train_data.parquet")

def collate_fn(batch: list[dict]):

    """Custom collate function to handle variable-length sequences in dataset."""

    # always at max length: tokens, segment_ids; always singleton: is_random_next

    input_ids = torch.tensor([item["tokens"] for item in batch])

    token_type_ids = torch.tensor([item["segment_ids"] for item in batch]).abs()

    is_random_next = torch.tensor([item["is_random_next"] for item in batch]).to(int)

    # variable length: masked_positions, masked_labels

    masked_pos = [(idx, pos) for idx, item in enumerate(batch) for pos in item["masked_positions"]]

    masked_labels = torch.tensor([label for item in batch for label in item["masked_labels"]])

    return input_ids, token_type_ids, is_random_next, masked_pos, masked_labels

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,

                                         collate_fn=collate_fn, num_workers=8)

# train the model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BertPretrainingModel(BertConfig()).to(device)

model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

loss_fn = nn.CrossEntropyLoss()

for epoch in range(epochs):

    pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

    for batch in pbar:

        # get batched data

        input_ids, token_type_ids, is_random_next, masked_pos, masked_labels = batch

        input_ids = input_ids.to(device)

        token_type_ids = token_type_ids.to(device)

        is_random_next = is_random_next.to(device)

        masked_labels = masked_labels.to(device)

        # extract output from model

        mlm_logits, nsp_logits = model(input_ids, token_type_ids)

        # MLM loss: masked_positions is a list of tuples of (B, S), extract the

        # corresponding logits from tensor mlm_logits of shape (B, S, V)

        batch_indices, token_positions = zip(*masked_pos)

        mlm_logits = mlm_logits[batch_indices, token_positions]

        mlm_loss = loss_fn(mlm_logits, masked_labels)

        # Compute the loss for the NSP task

        nsp_loss = loss_fn(nsp_logits, is_random_next)

        # backward with total loss

        total_loss = mlm_loss + nsp_loss

        pbar.set_postfix(MLM=mlm_loss.item(), NSP=nsp_loss.item(), Total=total_loss.item())

        optimizer.zero_grad()

        total_loss.backward()

        optimizer.step()

    # decay the learning rate once per epoch
    scheduler.step()

    pbar.close()

# Save the model

torch.save(model.state_dict(), "bert_pretraining_model.pth")

torch.save(model.bert.state_dict(), "bert_model.pth")


Summary

In this article, you learned how to create a BERT model from scratch using PyTorch. Specifically, you learned:

  • How to create a BERT model from scratch using PyTorch
  • How to use PyTorch DataLoader to batch data and handle variable-length sequences
  • How to pre-train a BERT model using the MLM and NSP tasks
