Understanding reinforcement learning for model training from scratch


Rohit Patel

An intuitive treatment of RLHF, TRPO, PPO, GRPO, DPO and RLAIF. This article follows my paper here: https://arxiv.org/abs/2509.04501

In this article, we are going to talk about reinforcement learning for model training (RLMT) in the context of LLMs. We will start with a pre-trained model capable of predicting the next token, and walk through everything we need to train it into an instruction-tuned model. In the previous article, titled “Understanding LLMs from scratch using middle-school math”, we built a large language model from scratch using basic math operations. That article was fully self-contained, required only middle-school math, and was widely read, becoming one of Medium’s top articles for 2024 (it was cool to be featured alongside other popular authors). So I am going to assume that people liked the no-nonsense, no-jargon, understand-everything-from-first-principles approach. We will follow the same approach here, but with two changes:

  • This article is self-contained only when read together with the earlier article, not on its own. What this means is that I am not going to repeat the content of the first article, and we are going to assume you know what LLMs are.
  • We are going to assume a small amount of high-school knowledge. Basically, we will assume that you know what a variable is, what a vector is, what probabilities are, and that the Σ symbol signifies addition.

Let’s get started. Where we left off was that we had a model trained to predict the next token in the English language. For simplicity, let’s think of a token as a word, and as such, what the model does is that if you give it some text, it will try to predict the most likely word as a continuation of that text. These are called pre-trained models, with examples being GPT and Llama.

While this might seem like glorified auto-complete, pre-trained models can in fact do remarkable things. For example, if we ask the model to complete the sentence “The capital of Germany is”, it will complete it with “Berlin”. Moreover, you can recursively feed the generated text back to the model to get it to write longer text and articles, as shown in Figure 1. For example, one might start by giving the model the following text: “This article explores Florida’s alligators, covering their habitat, physical characteristics, diet, reproduction, interaction with humans, and fascinating facts.” Continuing this text will cause a good pre-trained model to complete the entire article on Florida’s alligators along with the aspects mentioned in the seed text.


Figure 1: Recursively feeding generated tokens to LLMs for longer text generation
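To make the loop in Figure 1 concrete, here is a minimal sketch in Python. The names predict_next_token, tokenize, and detokenize are placeholders for whatever pre-trained model and tokenizer you have; they are not from any specific library.

```python
def generate(model, tokenize, detokenize, seed_text, max_new_tokens=200):
    """Recursively feed generated tokens back into the model (Figure 1)."""
    tokens = tokenize(seed_text)              # text -> list of token ids
    for _ in range(max_new_tokens):
        # The model looks at all tokens so far and predicts the next one.
        next_token = model.predict_next_token(tokens)
        tokens.append(next_token)             # append it and go again
    return detokenize(tokens)                 # token ids -> text
```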

At this point, it is conceivable that anything you may want the model to do can be reframed as a completion objective, e.g. “This is a 1000 word essay on…” instead of “Write me a 1000 word essay on…”. So why do we need instruction tuned models?

Instruction tuning models

Meaningfully using pre-trained models for practical purposes would require users to change their behavior (reframing every request as a completion), which is difficult. Instruction-tuned models, on the other hand, can be used in a natural question-answering format, which leads to ease of use and adoption. This is evident from the fact that LLMs really caught on when the first instruction-tuned models were made available.

Being text completion models, pre-trained models are trying to predict the most likely next token based on the corpus of data that they are trained on. This means that pre-trained models may:

  • Respond to questions with additional questions since training data may contain lists of questions
  • Not do a great job of following instructions
  • Make up facts or names/places that sound plausible but don’t exist
  • Generate offensive or harmful content without much consideration

The fine-tuning process is often used to further train the models to address these shortcomings. The biggest visible change this leads to is that the models follow instructions in a question-answer format, and as such we will call these models instruction-tuned models (IT models); see e.g. Ouyang et al. (2022), Thoppilan et al. (2022), Touvron et al. (2023).

Now, how do we get the model to answer questions? The simplest thing we can think of is that if the model were trained on data where questions were followed by answers, it would learn to do text completions in a way that completes questions with answers. In essence, nothing about the model changes except that it is now conditioned more strongly to complete any text that’s a question with an answer. This is the key to instruction tuning.

Yet, this doesn’t get us all the way to where we would like to be. For example, if you enter “alligator and crocodile” in any LLM today, it will respond with a description of the two animals and their similarities and differences. If this were merely text completion, the expectation would be for the model to continue the sentence in some reasonable way, rather than to treat it as a question and return an answer. How is this achieved? One simple approach would be to structure the training data so that questions are preceded by “User” and answers by “Model”. Something like this:

  • User: What is an antelope? Model: An antelope is a type of mammal that belongs to the Bovidae family…
  • User: How many legs does a spider have? Model: A spider has eight legs…
  • User: Who wrote Romeo and Juliet? Model: Romeo and Juliet, the iconic tragic love story, was written by…

By structuring the questions and answers in the training data in this manner, whenever the model sees “User:” followed by some text followed by “Model:”, it will be conditioned to treat the text as a question and complete the whole sequence with something that looks like a complete response related to that text. This is similar in spirit to the chat markup language suggested in Ouyang et al. (2022).

This works relatively well, but there can be issues. For example, what if the string “User:” or “Model:” occurs naturally in the question or the answer, as in “How is the delimiter ‘User:’ typically used in model training?”, which is a valid question to ask an LLM. Can we improve this scheme? Let’s go back to the basics of how text is fed into these neural networks. Remember, neural networks can only ingest and output numbers. As such, all text is broken up into subword “tokens” and then fed into the network. Each of these tokens is represented by a vector, which is its “embedding”. So, in essence, when you feed text to an LLM, what you are really feeding is a sequence of vectors, each representing a token. If this feels unfamiliar, you may want to recap Understanding LLMs from Scratch Using Middle School Math.

One nice way to delineate a question from an answer would be to simply mint two new tokens (and corresponding embedding vectors) that are placed in the position of User: and Model:. This is similar to inventing a new character for delimiting files, because all the existing characters may already appear somewhere. Let’s call these newly minted tokens <USER> and <MODEL> for the sake of convenience. This does not mean that the strings <USER> and <MODEL> have special meaning to the model — we are just using them here on paper to denote the special tokens because we need some way of denoting them. In reality, if someone were to feed these strings to the model, they would be tokenized the same as anything else, e.g. <THISSTRING> or <THATSTRING>. These tokens simply represent delimiters (we could have called them <DELIM1> and <DELIM2> and it would make absolutely no difference anywhere except in this discussion). With this new scheme, embeddings for the new tokens can be trained, and you can have the model always follow the special tokens with an answer-type completion. Modern instruction-tuned models use some variation of this scheme, which was suggested in Askell et al. (2021).
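As an illustration of what minting new tokens looks like in practice, here is a sketch using the Hugging Face transformers library (the article itself is library-agnostic; the model name and the exact token strings are just placeholders):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any causal language model works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Mint two brand-new tokens to act as delimiters. Their string form is
# arbitrary; what matters is that each gets its own id and embedding.
tokenizer.add_special_tokens({"additional_special_tokens": ["<USER>", "<MODEL>"]})
model.resize_token_embeddings(len(tokenizer))  # new rows in the embedding table

# Training examples are then formatted with the new delimiters:
example = "<USER> How many legs does a spider have? <MODEL> A spider has eight legs."
print(tokenizer(example)["input_ids"])
```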

Now we have a pretty good scheme for instruction tuning a model and making sure a text completion model starts to behave more and more like a question answering model. If we can design a good corpus of questions and answers, then doing additional training on the model using the very same next token prediction scheme will give us what we need. Moreover, this stage can now be used to enforce other values on the model that are of interest. For example we can:

  • Put questions in the training data with specific instructions, and answers that follow those instructions, so the model gets better at instruction following in its answers
  • Have questions that can lead to harmful or offensive content in responses and then answers that are refusals so that the model will learn to refuse answers to potentially harmful questions (e.g. respond with “Sorry I cannot provide that information” when asked a question “How to make a dirty bomb?”)

These are not the only things one can address during training. Moreover, the above-mentioned objectives of instruction following and responsible AI are often achieved using a mixture of approaches and not simply what is described here.

Supervised fine tuning

Supervised fine-tuning is no different from pre-training. In both cases, you are doing the exact same thing: training the model to predict the next token (other objectives exist, such as the masked language modeling used in Devlin et al. (2019), but they are not nearly as prevalent and are outside the scope of this article). The key difference is that during pre-training you use a much larger corpus, such as Common Crawl, whereas for SFT you have a much smaller and higher-quality dataset.

Let’s actually write down the loss function for the next-token prediction objective, since we did not do so in the previous article. Let’s assume that we have an LLM with a vocabulary size of 32,000. This means that the output layer of the LLM will contain 32,000 numbers. In the previous article, we talked about the softmax function and the motivation behind it. After softmax is applied, we get 32k numbers that are all between 0 and 1 and add up to 1, and as such they can be treated as the vector of probabilities for the 32k tokens. Let’s name these 32k numbers with indexed variables such that the first number is p_1, the second one is p_2, and so on, with the iᵗʰ number being p_i, going all the way to p_32000.
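As a quick refresher on how those probabilities come about, here is a small sketch of the softmax step, with the vocabulary cut down to 5 so the numbers are printable (illustrative only, not any particular model’s code):

```python
import math

logits = [2.1, -0.3, 0.8, 4.0, 1.5]        # raw outputs of the final layer
exps = [math.exp(x) for x in logits]
probs = [e / sum(exps) for e in exps]       # softmax: each p_i in (0, 1), summing to 1

print([round(p, 3) for p in probs])         # approximately [0.116, 0.011, 0.032, 0.778, 0.064]
print(sum(probs))                           # 1.0 (up to floating point error)
```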

Now, in the next-token prediction case we have the training data. Let’s say the sentence “The quick brown fox jumps over the lazy dog” is part of our training data. What we’re doing here is that if we give the model part of the text, say “The quick brown fox”, then the model should predict “jumps”. What this means is that, of all the 32k tokens in the vocabulary, the probability of the token “jumps” should be the highest. In practice, tokens are not the same as words: words are often broken down into one or more tokens. For example, the word “jumps” could be broken down into two tokens. For the sake of comprehensibility, we will assume that our language model uses words as tokens. Let’s say “jumps” is the iᵗʰ token; then what we want to do is maximize p_i.

Now, the value of each p_i will be different depending on which previous tokens were fed in. For example, if we now feed “The quick brown fox jumps” to the network, we want to maximize another p_j where the jᵗʰ token is “over”. We need some concise way of representing this. Basically, what we want to capture is the probability of “jumps” conditional on “The quick brown fox” having been fed to the model, and we want to be able to represent that easily for many words. Let’s label all the tokens in the sentence tok_1, tok_2, tok_3 and so on.

Let’s use the vector text to denote the sequence of tokens, which contains T total tokens tok_1, tok_2, …, tok_T. It gets tiresome to write all the tokens repeatedly, so let’s use the notation a_t to denote tok_t and s_t to denote tok_1, tok_2, …, tok_{t-1}, which will make it easier to express things. Note that s_t contains the first t-1 tokens, not the first t tokens. Let’s use π to denote the probability of an event, and ‘|’ to denote “conditional on”. So we can write the above probability as:
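Using only the symbols just defined:

π(a_t | s_t) = Probability(tok_t | tok_1, tok_2, …, tok_{t-1})

For our example, with “The quick brown fox” as s_5 and “jumps” as a_5, this reads π(“jumps” | “The quick brown fox”).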

Now, if we want the model to learn the whole sentence, we would be predicting these probabilities recursively, one token at a time. Multiplying these conditional probabilities together gives the total probability of the whole sequence being generated by the model (this is simply the chain rule of probability, since each term is already conditioned on all the tokens that come before it). As such, the objective is:
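Written out with the notation above (for our 9-token example sentence, T = 9):

maximize:  π(a_1 | s_1) · π(a_2 | s_2) · … · π(a_T | s_T) = ∏_{t=1}^{T} π(a_t | s_t)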


These probability numbers are extremely small. Each of them is one of 32k numbers that all add up to 1. So you are really multiplying, and trying to maximize, really tiny numbers. If all probabilities were equal, each would be 1/32000 = 0.00003125, and many will be even smaller. Multiplying them for even a short sequence like ours means that the number you are trying to maximize is around (1/32000)⁹ = 2.84 × 10⁻⁴¹, which has 40 zeros after the decimal point before the digits 284! That is a difficult number to work with. Fortunately, if we take a log, then log(1/32000) = -10.37, which is a much more manageable number. Moreover, log has the nice property that log(a·b) = log(a) + log(b), so if we take the log of these expressions, instead of multiplying really tiny numbers and getting tinier numbers still, we simply add up a few reasonably sized numbers. Since log is a monotonic function, maximizing x and maximizing log(x) mean the same thing. So why don’t we write our objective function in a more manageable way:
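Taking the log turns the product into a sum:

maximize:  ∑_{t=1}^{T} log π(a_t | s_t)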

Now we have nicer numbers to maximize. Keep in mind that all these log values will always be negative, since the log of anything under 1 is negative and all probabilities are under 1. So maximizing a negative number means minimizing its magnitude, since -5 is greater than -10 on the real line. One nice, clean option is to rewrite the problem as a minimization problem with a negative sign instead. That way, you have positive numbers that you are trying to minimize. Moreover, not all sequences are of length 9, so let’s use a more general sequence length T.

Here, what you are really doing is changing the parameters of the model to minimize this loss, which changes these probabilities. If we want to be really precise, we should note somewhere in the loss function that the probabilities come from a specific model. One way to do that is simply to put the model in the condition, so that what you are really saying is something like: the probability of “jumps” conditional on “The quick brown fox” being fed to the model, where the model is M. Let’s use NLL(text, M) to denote the loss. This makes the loss look something like this:
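With the model M made explicit as a subscript on π, the negative log-likelihood of the text under model M is:

NLL(text, M) = - ∑_{t=1}^{T} log π_M(a_t | s_t)

Minimizing NLL(text, M) is exactly the same as maximizing the sum of log probabilities above.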

All the mathematical notation may make it look complicated, but as we have seen, it is actually quite simple. We are just trying to maximize the probability that the model assigns to the specific tokens in the training data. This formulation is also called “negative log-likelihood” because the probability of the tokens under the model can also be thought of as the likelihood of the model given the tokens. A model that assigns higher probabilities to the observed tokens is a more likely model. So you can say you are maximizing the likelihood function and trying to find the model that gives you its highest value. We have already talked about how we can use gradient descent to minimize a loss once we have a loss function. Supervised fine-tuning does just that, with the negative log-likelihood as the loss function. This is the same loss that pre-training uses, at a much larger scale. The NLL loss is also referred to as the “cross entropy loss” in the deep learning literature, which is a more general concept that we discuss in An intuitive treatment of Negative log-likelihood, Cross entropy, KL divergence, and Importance sampling. Fisher (1922), Kullback and Leibler (1951), Shannon (1948), Hopfield (1987) and Bengio et al. (2003) lay some of these foundations.

One last thing to note here is that the loss written above is for a single example. Usually, when training, you have a lot of training examples to run supervised fine-tuning on, and you take the average loss over all of them. Let’s say we denote the set of all training samples by S and the number of samples in it by |S|; then the loss looks something like this:
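Averaging the per-example loss over the dataset (text_i is the i-th training sample):

Loss(M) = (1/|S|) ∑_{i=1}^{|S|} NLL(text_i, M)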

One thing to note is that text is the concatenation of the question and the answer for a particular sample. This also means that if we want the model to learn only how to generate answers to questions, we can compute the NLL starting from the token t where the answer begins. This does not change anything we discussed above, other than a slight change to the NLL definition.
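As a small sketch of that change, here is a toy function that sums the per-token negative log-likelihoods only over answer positions (the names token_log_probs and answer_start are made up for this illustration):

```python
def answer_only_nll(token_log_probs, answer_start):
    """NLL over the answer tokens only.

    token_log_probs[t] holds log pi_M(a_t | s_t) for position t of the
    concatenated question+answer text; answer_start is the index of the
    first answer token. Question tokens contribute nothing to the loss.
    """
    return -sum(token_log_probs[answer_start:])

# Example: a 5-token question followed by a 3-token answer
log_probs = [-2.3, -1.9, -4.1, -0.7, -3.0, -0.2, -0.9, -1.4]
print(answer_only_nll(log_probs, answer_start=5))  # 0.2 + 0.9 + 1.4 = 2.5
```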

So we’re ready to instruction tune the model using SFT, but the issue now is finding a corpus of question-answer pairs that covers a wide range of topics. The written world is full of books, articles, papers, etc., but not nearly enough text exists in Q&A form. One option is to have humans write question-answer pairs, or to source questions from somewhere (e.g. users asking questions online) and have humans write answers, to supplement whatever Q&A datasets one can find online. All of these are time-consuming and expensive approaches. Nonetheless, SFT remains a critical first step in training models to be instruction tuned and getting them to a somewhat respectable place in generating answers.

Rejection Sampling

Our model can now answer questions, somewhat. What can we do to scale our training and make it better? One idea is to use the model itself to generate more question-answering data. If we could source questions from somewhere and have the model generate answers, we would have more question-answer pairs. This would substantially reduce the work of creating such a dataset for fine-tuning. The issue, however, is that since the model isn’t yet well trained, it may not generate the best data. How can we get over this?

Why don’t we generate many responses for each prompt from the model and keep only the ones that turn out to be good responses, effectively rejecting all the others? This allows us to curate a model-generated dataset that is of higher quality than the average model response. Moreover, it is a lot easier for people to select the best response from a group than to type up a good response to a question. The process looks something like this (a code sketch follows the list):

  • Start with a prompt
  • Use the current model to generate G different completions (responses) for that prompt, by sampling with enough randomness (temperature) to get a variety.
  • Rank the G responses.
  • Select the top response (or top few responses) and discard the rest.
  • Fine-tune the model on the prompt paired with that selected best response treating it as SFT data.
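Here is what that loop might look like in Python. Everything here is a placeholder sketch: generate, score, and sft_update stand in for your sampling routine, your human (or, later, model-based) ranking step, and a round of supervised fine-tuning, respectively.

```python
def rejection_sampling_round(model, prompts, generate, score, sft_update,
                             G=8, temperature=1.0):
    """One round of rejection sampling followed by SFT on the kept samples."""
    sft_data = []
    for prompt in prompts:
        # Sample G diverse completions for the same prompt.
        candidates = [generate(model, prompt, temperature=temperature)
                      for _ in range(G)]
        # Rank them (initially by humans, later by a preference/reward model)
        # and keep only the best one; the rest are rejected.
        best = max(candidates, key=lambda response: score(prompt, response))
        sft_data.append((prompt, best))
    # Treat the kept (prompt, response) pairs as ordinary SFT data.
    return sft_update(model, sft_data)
```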

We now have a way of using the model to generate the dataset for supervised fine-tuning. It is a model generated dataset but nonetheless we are going to use it to train the model using SFT.

How can we further improve things? While it is easier for humans to pick good responses from a sample than to write them, it is still a manual process. What if we had a model pick the better responses? We could train a “preference model” that is capable of selecting the better of two responses when provided with both. We would still need human-labeled data to train this preference model, but after that initial labeling work, we would be able to bootstrap an automatic process using the preference model. To train the preference model, we need data consisting of a prompt and two generated responses where a human has selected the preferred one. In essence, you have two texts, text_i and text_j, where each is a combination of the prompt and one of the generated responses. Let hp denote the actual human preference recorded in the data, such that hp can take the values zero and one, and if hp = 1 then text_i is the human-preferred text. The model takes the two texts as input and outputs the probability of each being preferred by a human (i.e. it returns a single probability of the first input being preferred, let’s call it P(text_i), since the other probability is simply P(text_j) = 1 - P(text_i)). Now what we want is to maximize P(text_i) when hp = 1 and maximize P(text_j) when hp = 0. We can write the loss using the same trick as above to convert probabilities to a negative log-likelihood and turn it into a minimization problem:
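Per labeled pair, this gives the familiar binary cross-entropy form:

Loss = - [ hp · log P(text_i) + (1 - hp) · log P(text_j) ]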

And so, when the human preference hp = 1, P(text_i) is being maximized, and when hp = 0, P(text_j) = 1 - P(text_i) is being maximized. This is the same as before, and in the same way, to get the average loss we simply average over all prompts in the training dataset. This loss function is called “binary cross entropy loss”. The general concept of rejection sampling was suggested by von Neumann (1951), whereas Zelikman et al. (2022) and Yuan et al. (2023) demonstrate its usage in LLM training. We can also denote text_w as the winning (preferred) text and simplify the expression:
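Since one of the two terms is always zero, the per-pair loss reduces to:

Loss = - log P(text_w)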

Rejection sampling followed by supervised fine-tuning in a near-automated loop sounds like a great way to improve models. But it has certain issues and areas for improvement:

  • Discrete learning: When we are rejection sampling, we are learning in discrete steps: we sample a lot of responses on a lot of prompts and then run a round of SFT for improvement. This means that multiple parameter updates are performed using the same batch of rejection-sampled data.
  • No learning from mistakes: When we throw away the bad outputs, the model doesn’t learn why they were bad. It only gets a positive signal from the single best answer. If the model repeatedly produces a certain kind of error in the rejected samples, this method doesn’t explicitly penalize that error. The feedback is purely positive (on the chosen answer) and never negative (on the others). In essence, the model isn’t told what not to do, only what to do more of. This can limit improvement or require many examples for the model to implicitly figure out the boundaries of bad responses.
  • High computational cost: Generating many samples per prompt is expensive, especially for large models. If we generate 10 candidates for each of 100k prompts, that’s 1 million generations just to sift out the 100k best samples. This is much more work than a single generation per prompt. It can sometimes be mitigated by parallel generation or clever sampling, but it’s still a factor to consider.
  • Model collapse and bias: This is the most important issue with rejection sampling. If we always pick the single highest-scoring answer according to a fixed criterion, we risk over-optimizing the model on that criterion. The model might start giving very narrow, optimized responses that score well but lack diversity or even coherence. For example, if the reward model or human annotator inadvertently prefers verbose answers, the model might converge to always giving overly long answers. In extreme cases, the model could exploit weaknesses in the scoring system, a phenomenon akin to reward hacking. Without any counterbalance, repeatedly fine-tuning on only top outputs can drive the model distribution to collapse around patterns that the scorer loves, even if those patterns are unnatural. We may be pushing the model into areas where the scoring model is not well calibrated, causing garbage outputs that the scorer mistakenly rates highly. This issue is made worse by the fact that the “no learning from mistakes” problem necessitates multiple rounds of rejection sampling followed by SFT, increasing the chances of model collapse.

Reinforcement learning

How can we do better than rejection sampling? Let’s try to fix the issues we listed above. Much of this section will deal with reinforcement learning methods, pioneered by Brown (1951), Bellman (1957), Barto et al. (1983), Sutton (1988), Watkins (1989), Littman (1994), Börgers and Sarin (1997), Hu and Wellman (2003), and many others. However, we will discuss them in the context of LLM training. In the reinforcement learning literature, the set of rules that determine the value of a_t given the current s_t is called a “policy”, and a_t is referred to as the “action” at time t whereas s_t is referred to as the “state” at time t. In essence, a policy is something that provides a probability distribution over the set of actions (in our case, the set of actions is the model vocabulary, and the output of the final softmax layer is the probability distribution) given the current state s_t (in our case the current state is the string, i.e. the tokens, up to point t-1). You then sample the action a_t from this probability distribution. For example, you could simply take the token with highest probability as a_t (and that would be called greedy decoding). In our case, the model is the policy, and as such the terms model and policy are interchangeable for our purposes.
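To tie the RL vocabulary to what the model already does, here is a small sketch of acting under a policy: the state is the token sequence so far, the policy is the softmax output over the vocabulary, and the action is the sampled next token. The helper names are made up for this sketch; temperature = 0 is treated as greedy decoding.

```python
import math
import random

def sample_action(logits, temperature=1.0):
    """Pick the next token (action) from the policy's distribution over the vocabulary."""
    if temperature == 0:                          # greedy decoding: take the argmax token
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [x / temperature for x in logits]    # higher temperature => more randomness
    exps = [math.exp(x - max(scaled)) for x in scaled]
    probs = [e / sum(exps) for e in exps]         # softmax => pi_M(a_t | s_t)
    return random.choices(range(len(probs)), weights=probs, k=1)[0]
```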

REINFORCE

One of the things we can do is that, instead of treating the model-generated samples as SFT data and running multiple parameter updates on them, we generate new samples after each parameter update (or, more practically, after a small number of updates, say fewer than five), thereby leading to more continuous improvement and reducing the possibility of the model overfitting to a particular set of outputs. However, this would increase the compute burden many-fold: we already have to do a lot of generations to get a single training sample. So we first have to solve our issue of not using all the generated samples before we can do this.

How about using all the generated samples in the loss by simply weighing them by their goodness? The SFT formulation from the equation above simply averages the loss over all samples, effectively giving every rejection-sampled example a weight of one. What if we had a function, let’s call it a “reward function”, that gave us a score for each sample in terms of how good or bad it is? This way, we could use all generated samples during training and simply weigh the less good samples appropriately (or even negatively) in the loss function. It would look something like this:
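One plausible way to write this, reusing NLL(text_i, M) from the SFT section and a reward R(text_i) for each sample:

Loss(M) = (1/|S|) ∑_{i=1}^{|S|} R(text_i) · NLL(text_i, M)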

Where R(text_i) is the reward for text sample text_i and NLL(text_i, M) is as defined previously. This is called the REINFORCE algorithm and is due to Williams (1992). In practice, this loss function can lead to high variance in the gradients. To mitigate that, a baseline is subtracted from the reward. The baseline can be anything, but what is commonly used is something that depends on the input text being fed to the model for next-token prediction (i.e. s_t). Let’s use V_M(s_t) to denote the baseline at time t. With a per-timestep baseline, we can no longer keep the reward outside the sum over tokens, so we need to expand the entire term. This is what it looks like:
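One common way to write the resulting per-token, baseline-subtracted loss (T_i is the length of text_i; the symbols are defined in the next paragraph):

Loss(M) = - (1/|S|) ∑_{i=1}^{|S|} ∑_{t=1}^{T_i} ( R(text_i) - V_M(s_it) ) · log π_M(a_it | s_it)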

Where text_i is the i-th sample text in the dataset, a_it denotes the t-th token of text i, and s_it denotes the sequence of tokens of text i from 1 to t-1 (and as such, a whole text of length T can also be denoted by s_{T+1}). The baseline function V_M(s_t) depends on the text tokens we have up to t and on the model M. Finally, π_M(a_t | s_t) is the probability the model M assigns to token a_t as the next token when the sequence of tokens s_t (which is t-1 tokens long) is fed into the model, as discussed before.

Recall that each training sample text_i is essentially a combination of a prompt and a generated completion (in the case of instruction tuning, the completion would be a response). This whole prompt+response sequence constitutes a single training example. If multiple responses were generated for one prompt, we would have multiple text sequences. One interesting thing to note is that once we have a reward function, we no longer necessarily need to do multiple generations for the same prompt. We can simply do one generation, improve the model, and do another generation with the improved model. Every generated response can be used for training, which makes the process a lot more efficient than rejection sampling. Another thing to note is that, just as in the SFT case, since we want the model to learn only the answers and not the questions themselves, we start the inner sum at the value of t corresponding to the first token of the answer within text_i. This can be done for all reinforcement learning algorithms, and as such we will not revisit this note again.

The process of taking actions, getting rewards, and learning from those rewards in a loop to improve future decisions is called reinforcement learning. REINFORCE is one of the simplest reinforcement learning algorithms, and it addresses the first three points that we made about rejection sampling. Let’s finish our discussion of REINFORCE by understanding where these rewards come from — something that is relevant to many reinforcement learning algorithms.

Value function

Converting LaTeX to Medium is tiring. The rest of the discussion on the value function, the advantage function, TRPO, PPO, GRPO, DPO, and everything else can be found in a much better formatted, freely available PDF here: https://arxiv.org/abs/2509.04501