7.2. Fine-Tuning to follow instructions

The goal of this section is to show how to fine-tune an already pre-trained model to follow instructions rather than just generate text, for example, to respond to tasks as a chatbot.

Dataset

In order to fine-tune an LLM to follow instructions, you need a dataset of instructions paired with the expected responses. There are different prompt formats for training an LLM to follow instructions, for example:

  • Alpaca prompt style example:

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Calculate the area of a circle with a radius of 5 units.

### Response:
The area of a circle is calculated using the formula \( A = \pi r^2 \). Plugging in the radius of 5 units:

\( A = \pi (5)^2 = \pi \times 25 = 25\pi \) square units.

  • Phi-3 prompt style example:

<|User|>
Can you explain what gravity is in simple terms?

<|Assistant|>
Absolutely! Gravity is a force that pulls objects toward each other.

Training an LLM on this kind of dataset instead of just raw text helps it learn that it must give specific responses to the questions it receives.

Therefore, one of the first things to do with a dataset that contains requests and answers is to convert that data into the desired prompt format, like:

# Code from https://github.com/rasbt/LLMs-from-scratch/blob/main/ch07/01_main-chapter-code/ch07.ipynb
def format_input(entry):
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""

    return instruction_text + input_text

# `data` is the loaded dataset: a list of dicts with "instruction", "input" and "output" keys
model_input = format_input(data[50])
desired_response = f"\n\n### Response:\n{data[50]['output']}"
print(model_input + desired_response)
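For comparison, the same kind of entry can be rendered in the `<|User|>` / `<|Assistant|>` style shown earlier. This is only a minimal sketch: the function names `format_input_phi` and `format_response_phi` and the sample entry are hypothetical, and the tags are copied from the example above rather than from any tokenizer's official chat template.

```python
def format_input_phi(entry):
    """Format the request part of an entry in the <|User|> style shown above."""
    user_text = entry["instruction"]
    # Append the optional input field only when it is not empty
    if entry.get("input"):
        user_text += "\n" + entry["input"]
    return f"<|User|>\n{user_text}"


def format_response_phi(entry):
    """Format the expected answer in the <|Assistant|> style shown above."""
    return f"\n\n<|Assistant|>\n{entry['output']}"


# Hypothetical entry, used only for illustration
example_entry = {
    "instruction": "Can you explain what gravity is in simple terms?",
    "input": "",
    "output": "Absolutely! Gravity is a force that pulls objects toward each other.",
}

phi_prompt = format_input_phi(example_entry) + format_response_phi(example_entry)
print(phi_prompt)
```

Whichever style is chosen, the same format has to be used consistently for training and for inference.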

Then, as usual, the dataset needs to be split into training, validation and test sets.
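The split can be sketched as below. The 85% / 10% / 5% proportions, the fixed seed and the name `split_dataset` are assumptions chosen for illustration, not values prescribed by this page.

```python
import random


def split_dataset(data, train_frac=0.85, test_frac=0.10, seed=123):
    """Shuffle a copy of `data` and split it into train/validation/test lists."""
    shuffled = list(data)                  # leave the caller's list untouched
    random.Random(seed).shuffle(shuffled)  # reproducible shuffle
    train_end = int(len(shuffled) * train_frac)
    test_end = train_end + int(len(shuffled) * test_frac)
    train_data = shuffled[:train_end]
    test_data = shuffled[train_end:test_end]
    val_data = shuffled[test_end:]         # whatever remains
    return train_data, val_data, test_data


# Toy dataset, used only to show the call
toy_data = [
    {"instruction": f"task {i}", "input": "", "output": f"answer {i}"}
    for i in range(100)
]
train_data, val_data, test_data = split_dataset(toy_data)
print(len(train_data), len(val_data), len(test_data))
```

Shuffling before splitting avoids putting all entries of one kind into a single set when the dataset file is ordered.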

Batching & Data Loaders

Next, all the inputs and expected outputs must be batched for training. This requires:

  • Tokenizing the texts

  • Padding all the samples to the same length (the target length can be as large as the context length used to pre-train the LLM, although padding only to the longest sample in each batch is common)

  • Creating the expected tokens by shifting the input by 1 position in a custom collate function

  • Replacing some padding tokens with -100 to exclude them from the training loss: after the first endoftext token, replace all the other endoftext tokens with -100 (cross_entropy(..., ignore_index=-100) ignores targets equal to -100)

  • [Optional] Also masking with -100 all the tokens that belong to the question, so the LLM only learns how to generate the answer. In the Alpaca style this means masking everything up to ### Response:
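The padding, shifting and masking steps above can be sketched as a collate function. To stay dependency-free, this version works on plain Python lists; in real training the two results would be converted to tensors and the function passed to the data loader as its collate function. The names `collate_batch` and `prompt_lengths` are assumptions made for this sketch, and 50256 is the `<|endoftext|>` ID of the GPT-2 tokenizer, which is assumed here.

```python
IGNORE_INDEX = -100   # targets with this value are skipped by the loss
PAD_TOKEN_ID = 50256  # <|endoftext|> in the GPT-2 tokenizer (assumed tokenizer)


def collate_batch(batch, prompt_lengths=None, max_length=None):
    """Pad, shift and mask a batch of token-ID lists.

    batch: list of token-ID lists (prompt + response, already tokenized)
    prompt_lengths: optional number of prompt tokens per sample; when given,
        the targets that belong to the prompt are masked as well
    max_length: optional cap on the sequence length
    """
    # +1 so every sample ends with at least one end-of-text token
    batch_max = max(len(item) for item in batch) + 1
    inputs_batch, targets_batch = [], []
    for i, item in enumerate(batch):
        n = len(item)
        padded = item + [PAD_TOKEN_ID] * (batch_max - n)
        inputs = padded[:-1]   # everything except the last token
        targets = padded[1:]   # same sequence shifted by one position
        # targets[n - 1] is the first end-of-text token and is kept;
        # every padding token after it is excluded from the loss
        for j in range(n, len(targets)):
            targets[j] = IGNORE_INDEX
        # Optional: exclude the prompt so only the answer is learned
        if prompt_lengths is not None:
            for j in range(prompt_lengths[i] - 1):
                targets[j] = IGNORE_INDEX
        if max_length is not None:
            inputs, targets = inputs[:max_length], targets[:max_length]
        inputs_batch.append(inputs)
        targets_batch.append(targets)
    return inputs_batch, targets_batch


# Toy token IDs, used only to show the resulting shapes
toy_batch = [[0, 1, 2, 3, 4], [5, 6], [7, 8, 9]]
inputs, targets = collate_batch(toy_batch)
print(inputs)
print(targets)
```

The masking is done by position rather than by searching for the padding value, so an end-of-text token that appears inside a sample is not masked by accident.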

With this in place, it's time to create the data loaders for each dataset (training, validation and test).

Load pre-trained LLM & Fine tune & Loss Checking

A pre-trained LLM must be loaded in order to fine-tune it. This was already discussed in other pages. Then, the previously used training function can be reused to fine-tune the LLM.

During training, it's also possible to track how the training loss and the validation loss change across epochs, to check whether the loss is decreasing and whether overfitting is occurring. Remember that overfitting occurs when the training loss keeps decreasing but the validation loss stops decreasing or even increases. The simplest way to avoid this is to stop the training at the epoch where this behaviour starts.
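The stopping rule described above can be sketched as a small helper that scans the recorded validation losses. The name `find_stopping_epoch`, the `patience` parameter and the loss values are hypothetical and only illustrate the idea.

```python
def find_stopping_epoch(val_losses, patience=1, min_delta=0.0):
    """Return the 0-based epoch with the best validation loss.

    The scan stops once the validation loss has failed to improve by more
    than `min_delta` for `patience` consecutive epochs.
    """
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            best_loss = loss
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return best_epoch


# Made-up losses: training keeps improving while validation turns upward
train_losses = [2.6, 1.1, 0.8, 0.6, 0.4, 0.3]
val_losses = [2.5, 1.2, 0.9, 0.95, 1.0, 1.1]
best_epoch = find_stopping_epoch(val_losses, patience=2)
print(best_epoch)
```

In practice the model weights saved at the returned epoch would be the ones to keep.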

Response Quality

Unlike classification fine-tuning, where the loss can be trusted more, here it's also important to check the quality of the responses on the test set. Therefore, it's recommended to gather the generated responses for the whole test set and review them manually for wrong answers. Note that the LLM may produce the correct format and syntax for a response while the content is completely wrong, and the loss will not reflect this. This review can also be performed by passing the generated responses and the expected responses to another LLM and asking it to evaluate them.
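The LLM-based review can be sketched as follows. Everything here is an assumption made for illustration: the prompt wording, the 0 to 100 scale, the `model_response` key and the `judge` callable, which stands for whatever function sends a prompt to the evaluating model and returns its reply as a string.

```python
def build_judge_prompt(instruction, expected, generated):
    """Build the scoring request sent to the evaluating LLM."""
    return (
        f"Given the instruction `{instruction}` "
        f"and the correct output `{expected}`, "
        f"score the model response `{generated}` "
        "on a scale from 0 to 100, where 100 is the best score. "
        "Respond with the integer number only."
    )


def parse_score(judge_reply):
    """Return the score as an int, or None if the reply is not a 0-100 integer."""
    try:
        score = int(judge_reply.strip())
    except ValueError:
        return None
    return score if 0 <= score <= 100 else None


def average_score(entries, judge):
    """Average the valid scores that `judge` gives to each entry."""
    scores = []
    for entry in entries:
        prompt = build_judge_prompt(
            entry["instruction"], entry["output"], entry["model_response"]
        )
        score = parse_score(judge(prompt))
        if score is not None:  # skip replies that are not a usable score
            scores.append(score)
    return sum(scores) / len(scores) if scores else None


def fake_judge(prompt):
    """Stand-in for a real call to an evaluating LLM; always answers 80."""
    return "80"


# Hypothetical test entries with the fine-tuned model's responses attached
test_entries = [
    {"instruction": "Name a primary color.", "output": "Red.", "model_response": "Blue."},
    {"instruction": "What is 2 + 2?", "output": "4", "model_response": "4"},
]
print(average_score(test_entries, fake_judge))
```

Replies that cannot be parsed are skipped rather than counted as zero, so it is worth also tracking how many were dropped.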

Other tests that can be run to verify the quality of the responses:

  1. Measuring Massive Multitask Language Understanding (MMLU): MMLU evaluates a model's knowledge and problem-solving abilities across 57 subjects, including humanities, sciences, and more. It uses multiple-choice questions to assess understanding at various difficulty levels, from elementary to advanced professional.

  2. LMSYS Chatbot Arena: This platform allows users to compare responses from different chatbots side by side. Users input a prompt, and multiple chatbots generate responses that can be directly compared.

  3. AlpacaEval: AlpacaEval is an automated evaluation framework where an advanced LLM like GPT-4 assesses the responses of other models to various prompts.

  4. General Language Understanding Evaluation (GLUE): GLUE is a collection of nine natural language understanding tasks, including sentiment analysis, textual entailment, and question answering.

  5. SuperGLUE: Building upon GLUE, SuperGLUE includes more challenging tasks designed to be difficult for current models.

  6. Beyond the Imitation Game Benchmark (BIG-bench): BIG-bench is a large-scale benchmark with over 200 tasks that test a model's abilities in areas like reasoning, translation, and question answering.

  7. Holistic Evaluation of Language Models (HELM): HELM provides a comprehensive evaluation across various metrics like accuracy, robustness, and fairness.

  8. OpenAI Evals: An open-source evaluation framework by OpenAI that allows for the testing of AI models on custom and standardized tasks.

  9. HumanEval: A collection of programming problems used to evaluate code generation abilities of language models.

  10. Stanford Question Answering Dataset (SQuAD): SQuAD consists of questions about Wikipedia articles, where models must comprehend the text to answer accurately.

  11. TriviaQA: A large-scale dataset of trivia questions and answers, along with evidence documents.

And many more.

Follow instructions fine-tuning code

You can find example code to perform this fine-tuning at https://github.com/rasbt/LLMs-from-scratch/blob/main/ch07/01_main-chapter-code/gpt_instruction_finetuning.py
