5. Fine-Tuning for Classification

What is fine-tuning

Fine-tuning is the process of taking a pre-trained model that has learned general language patterns from vast amounts of data and adapting it to perform a specific task or to understand domain-specific language. This is achieved by continuing the training of the model on a smaller, task-specific dataset, allowing it to adjust its parameters to better suit the nuances of the new data while leveraging the broad knowledge it has already acquired. Fine-tuning enables the model to deliver more accurate and relevant results in specialized applications without the need to train a new model from scratch.

Because pre-training an LLM that "understands" text is very expensive, it is usually easier and cheaper to fine-tune an open-source pre-trained model for the specific task we want it to perform.

Preparing the data set

Data set size

To fine-tune a model you need labelled, structured data for the task you want the LLM to specialise in. In the example proposed in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, GPT2 is fine-tuned to detect whether a message is spam or not (the entries in this data set are SMS texts), using the data from https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip.

This data set contains many more examples of "not spam" than of "spam", so the book suggests using only as many "not spam" examples as there are "spam" examples (that is, dropping the extra "not spam" examples from the training data). In this case, that leaves 747 examples of each.

Then, 70% of the data set is used for training, 10% for validation and 20% for testing.

  • The validation set is used during the training phase to fine-tune the model's hyperparameters and make decisions about model architecture, effectively helping to prevent overfitting by providing feedback on how the model performs on unseen data. It allows for iterative improvements without biasing the final evaluation.

  • In contrast, the test set is used only after the model has been fully trained and all adjustments are complete; it provides an unbiased assessment of the model's ability to generalize to new, unseen data. This final evaluation on the test set gives a realistic indication of how the model is expected to perform in real-world applications.
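The balancing and the 70/10/20 split described above can be sketched in plain Python. This is only an illustration under stated assumptions: the function name `balance_and_split`, the `"spam"`/`"ham"` label strings, the seed, and the toy rows are made up here, and the number of "not spam" rows (4000) is arbitrary. Only the 747 spam examples and the split fractions come from the text.

```python
import random


def balance_and_split(rows, train_frac=0.7, val_frac=0.1, seed=123):
    """Undersample the majority class, shuffle, then split into train/val/test.

    rows is a list of (label, text) tuples with label "spam" or "ham".
    The remaining fraction (here 20%) becomes the test set.
    """
    rng = random.Random(seed)
    spam = [row for row in rows if row[0] == "spam"]
    ham = [row for row in rows if row[0] == "ham"]

    # Keep as many examples of each class as the smaller class has.
    n_per_class = min(len(spam), len(ham))
    balanced = rng.sample(spam, n_per_class) + rng.sample(ham, n_per_class)
    rng.shuffle(balanced)

    train_end = int(len(balanced) * train_frac)
    val_end = train_end + int(len(balanced) * val_frac)
    return balanced[:train_end], balanced[train_end:val_end], balanced[val_end:]


# Toy rows standing in for the real data set (hypothetical contents).
rows = [("spam", f"spam message {i}") for i in range(747)]
rows += [("ham", f"normal message {i}") for i in range(4000)]

train, val, test = balance_and_split(rows)
print(len(train), len(val), len(test))
```

With 747 examples per class the balanced set has 1494 rows, which this sketch splits into 1045 for training, 149 for validation and 300 for testing.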

Entry length

Because the training code stacks entries (message texts in this case) into batches, every entry must have the same length. The example therefore pads every entry to the length of the longest one by appending the token id of <|endoftext|>.
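As a minimal sketch of that padding step, assuming the texts are already tokenized into lists of ids: the helper name `pad_to_longest` and the example ids are hypothetical, while 50256 is the id of <|endoftext|> in the GPT-2 tokenizer.

```python
PAD_TOKEN_ID = 50256  # id of <|endoftext|> in the GPT-2 tokenizer


def pad_to_longest(encoded_texts, pad_token_id=PAD_TOKEN_ID):
    """Right-pad every list of token ids to the length of the longest one."""
    max_length = max(len(ids) for ids in encoded_texts)
    return [
        ids + [pad_token_id] * (max_length - len(ids))
        for ids in encoded_texts
    ]


# Made-up token ids for three messages of different lengths.
encoded = [[101, 102, 103], [201], [301, 302]]
padded = pad_to_longest(encoded)
print(padded)
```

Padding on the right keeps the real tokens at the start of each entry, so all entries can be stacked into one tensor of shape (number of entries, longest length).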

Initialize the model

Initialize the model to train with the open-source pre-trained weights. We have already done this before, and you can easily do it again by following the instructions in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb.

Classification head

In this specific example (predicting whether a text is spam or not), we are not interested in predicting over the complete GPT2 vocabulary; we only want the new model to say whether the message is spam (1) or not (0). Therefore, we replace the final layer, which outputs a score for every token in the vocabulary, with one that outputs only two scores, spam or not spam (like a vocabulary of 2 words).

# Replace the final layer with a Linear layer that has 2 outputs
num_classes = 2
model.out_head = torch.nn.Linear(
    in_features=BASE_CONFIG["emb_dim"],
    out_features=num_classes
)

Parameters to tune

To fine-tune faster, it is easier not to update all the parameters but only the final ones. The lower layers generally capture basic language structures and semantics that apply across tasks, so fine-tuning just the last layers is usually enough and much faster.

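# Make all the parameters of the model untrainable
for param in model.parameters():
    param.requires_grad = False

# The classification head must stay trainable. If it was attached before
# the freeze above, it was frozen too, so re-enable it explicitly.
for param in model.out_head.parameters():
    param.requires_grad = True

# Allow fine-tuning of the last transformer block
for param in model.trf_blocks[-1].parameters():
    param.requires_grad = True

# Allow fine-tuning of the final layer norm
for param in model.final_norm.parameters():
    param.requires_grad = True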

Entries to use for training

In previous sections the LLM was trained by reducing the loss over every predicted token, even though almost all the predicted tokens were already in the input sentence (only the one at the end was really new), so that the model would learn the language better.

In this case we only care whether the model can predict if the message is spam or not, so we only care about the last predicted token. Because of the causal attention mask, the last position is the only one that has attended to every token in the input. Therefore, we need to modify our previous training loss functions to take only that token into account.

This is implemented in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb as:

def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]  # Logits of last output token
            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Logits of last output token
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss

Note how for each batch we are only interested in the logits of the last token predicted.
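To make the last-token selection concrete, here is a small worked example in plain Python rather than PyTorch. It is only a sketch: the logits are made up, and the `cross_entropy` helper is written out by hand to show what `torch.nn.functional.cross_entropy` computes for a single example (the negative log of the softmax probability of the target class).

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one example: -log(softmax(logits)[target])."""
    # Subtract the maximum before exponentiating for numerical stability.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(
        sum(math.exp(x - max_logit) for x in logits)
    )
    return log_sum_exp - logits[target]


# Made-up model output for ONE input: 3 token positions, 2 class logits each.
model_output = [
    [0.2, -0.1],  # position 0
    [0.5, 0.3],   # position 1
    [-1.0, 2.0],  # position 2 (last token)
]

# Equivalent of model(input_batch)[:, -1, :] for a single example.
last_token_logits = model_output[-1]

# Equivalent of torch.argmax(logits, dim=-1): index of the largest logit.
predicted_label = max(
    range(len(last_token_logits)), key=last_token_logits.__getitem__
)

loss_if_spam = cross_entropy(last_token_logits, target=1)
loss_if_not_spam = cross_entropy(last_token_logits, target=0)
print(predicted_label, round(loss_if_spam, 4), round(loss_if_not_spam, 4))
```

The prediction here is class 1 (spam). The loss is small (about 0.049) when the true label is 1 and large (about 3.049) when the true label is 0. The logits at positions 0 and 1 never enter the calculation.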

Complete GPT2 fine-tune classification code

You can find all the code to fine-tune GPT2 to be a spam classifier in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb
