Everyone is debating RAG vs fine-tuning, which I find rather pointless: both are tools in the AI engineer's toolkit. In this post, I set out to determine whether fine-tuning an LLM on retrieval-augmented prompts from the training set works better than just fine-tuning on zero-shot prompts...
- I use three popular fine-tuning datasets (SQL, functional representation and GSM8K) and find multi-percentage-point accuracy improvements across all three when fine-tuning on RAG-augmented samples rather than zero-shot samples.
- For these tasks, the train dataset doubles as a RAG knowledge base during both training and evaluation.
- I made the repo tunetherag, which can be used to augment any Hugging Face dataset with retrieval-augmented examples.
This experiment builds on Meta's recent paper, which essentially shows that augmenting prompts with RAG examples on "knowledge-intensive tasks" leads to accuracy improvements. This post extends the paper's findings by showing that the technique also works well for non-knowledge-intensive tasks, where the training set doubles as a knowledge base.
Accuracy-by-task on eval set. Performance improvements occur for each task. Blue is when fine-tuned for zero-shot inference and green is when fine-tuned on retrieval augmented prompts (with retrieval-augmented examples coming from the train set). Augmenting with RAG examples leads to considerable accuracy improvements.
All experiments were conducted with a CodeLlama 7B model quantised to 8 bits, training only a LoRA adapter. OpenAI's text-embedding-ada-002 is used to find the closest examples in the train set/knowledge base for generating retrieval-augmented prompts. (My other guide has more details on how to fine-tune CodeLlama, and my assortment of messy training scripts used for this post can also be found here.)
Generating retrieval-augmented datasets
An example SQL retrieval-augmented prompt is this:
You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question. Given the following example: ### Input: How many artists do we have? ### Context: CREATE TABLE artist (Id VARCHAR) ### Response: SELECT count(*) FROM artist Please generate the SQL query that answers the following: ### Input: How many singers do we have? ### Context: CREATE TABLE singer (Id VARCHAR) ### Response:
Here, "How many singers do we have?" is the embedded document and "How many artists do we have?" is the closest semantic match from the train set/knowledge base, which we use as our retrieval-augmented example.
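A minimal sketch of how a prompt like this could be assembled from a target sample and its retrieved examples. The field names (`question`, `context`, `answer`), the function names and the single-space layout are assumptions made for illustration; they are not taken from tunetherag.

```python
# Hypothetical helper names and field names; the wording mirrors the prompt above.
SYSTEM = (
    "You are a powerful text-to-SQL model. Your job is to answer questions "
    "about a database. You are given a question and context regarding one or "
    "more tables. You must output the SQL query that answers the question."
)


def format_sample(sample, include_response):
    """Render one sample; the answer is only included for retrieved examples."""
    text = (
        f"### Input: {sample['question']} "
        f"### Context: {sample['context']} "
        "### Response:"
    )
    if include_response:
        text += f" {sample['answer']}"
    return text


def build_prompt(target, examples):
    """Join the instruction, the retrieved examples and the unanswered target."""
    parts = [SYSTEM]
    for example in examples:
        parts.append(
            "Given the following example: " + format_sample(example, True)
        )
    parts.append(
        "Please generate the SQL query that answers the following: "
        + format_sample(target, False)
    )
    return " ".join(parts)


example = {
    "question": "How many artists do we have?",
    "context": "CREATE TABLE artist (Id VARCHAR)",
    "answer": "SELECT count(*) FROM artist",
}
target = {
    "question": "How many singers do we have?",
    "context": "CREATE TABLE singer (Id VARCHAR)",
    "answer": "SELECT count(*) FROM singer",
}
prompt = build_prompt(target, [example])
print(prompt)
```

During fine-tuning the target's answer would be appended as the completion; it is deliberately left out of the prompt itself.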
The repo tunetherag basically does the following:
- Given a HF dataset, embed each sample from the training set with an embedding model of choice and store it in the RAG knowledge base. I use Chroma as the local embedding store.
- Then go through the dataset again (this time both train and test splits) and for each sample, embed it and find the n closest (non-identical) documents in the train set knowledge base.
- Given the n closest documents, shove them all together into a retrieval-augmented prompt like the example above where n=1.
- Save the augmented dataset to a file and fine-tune the model on it.
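The retrieval steps above can be sketched in plain Python. This is not the tunetherag code: to stay dependency-free, a bag-of-words vector stands in for the real embedding model and an in-memory list stands in for Chroma, and all function names are hypothetical.

```python
import math
import re
from collections import Counter


def toy_embed(text):
    """Stand-in embedding: a bag-of-words count vector, not a real model."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(count * b[token] for token, count in a.items())
    norm_a = math.sqrt(sum(c * c for c in a.values()))
    norm_b = math.sqrt(sum(c * c for c in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def build_knowledge_base(train_questions, embed=toy_embed):
    """Embed every training sample once (stand-in for the embedding store)."""
    return [(question, embed(question)) for question in train_questions]


def closest_examples(query, knowledge_base, n=1, embed=toy_embed):
    """Return the n most similar training samples, skipping identical text."""
    query_vec = embed(query)
    scored = [
        (cosine(query_vec, vec), text)
        for text, vec in knowledge_base
        if text != query  # a train sample must not retrieve itself
    ]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [text for _, text in scored[:n]]


train = [
    "How many artists do we have?",
    "What is the average age of all singers?",
    "List the names of all stadiums.",
]
kb = build_knowledge_base(train)
print(closest_examples("How many singers do we have?", kb, n=1))
```

Swapping `toy_embed` for a call to a real embedding model, and the list for a vector store, keeps the same overall shape: index the train split once, then query it for every train and test sample.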
Investigating different embedding functions
Accuracy-by-embedding-model on the SQL eval set. This graph is for prompts augmented with 1 example from the train set/knowledge base. text-embedding-ada-002 works best, followed by a random sample, followed by all-MiniLM-L6-v2. A poor, unrepresentative embedding model does worse than the random-sample baseline.
Investigating different numbers of examples
Accuracy-by-no-of-examples used to augment prompts on the SQL eval set. The best performance is achieved when the model is given 1 retrieval-augmented example. Interestingly, without any fine-tuning, using more examples leads to better performance: 5 examples give 67.3% execution accuracy while 1 example gives 63.4%. (All prompts are augmented using the all-MiniLM-L6-v2 model, and CodeLlama is fine-tuned to convergence on these retrieval-augmented datasets.)
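The post does not spell out how execution accuracy is computed. One common approach, sketched here as an assumption rather than as the evaluation actually used, is to run the predicted and reference queries against the same database and compare the returned rows. The function names are hypothetical and SQLite is used only because it ships with Python.

```python
import sqlite3
from collections import Counter


def run_query(conn, sql):
    """Return the rows as a multiset, or None if the query fails to execute."""
    try:
        return Counter(conn.execute(sql).fetchall())
    except sqlite3.Error:
        return None


def execution_match(conn, predicted_sql, gold_sql):
    """True when both queries run and return the same rows (order ignored)."""
    gold = run_query(conn, gold_sql)
    predicted = run_query(conn, predicted_sql)
    return gold is not None and predicted == gold


def execution_accuracy(conn, pairs):
    """Fraction of (predicted, gold) query pairs whose results match."""
    if not pairs:
        return 0.0
    hits = sum(execution_match(conn, pred, gold) for pred, gold in pairs)
    return hits / len(pairs)


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE singer (Id VARCHAR)")
conn.executemany("INSERT INTO singer VALUES (?)", [("1",), ("2",)])

pairs = [
    # Different SQL text, same result: counts as correct.
    ("SELECT count(*) FROM singer", "SELECT COUNT(Id) FROM singer"),
    # References a table that does not exist: counts as wrong.
    ("SELECT Id FROM singers", "SELECT Id FROM singer"),
]
print(execution_accuracy(conn, pairs))
```

Comparing results rather than query strings means two differently written but equivalent queries both count as correct, which an exact-match metric would miss.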
Conflation with training set/very similar examples that aren’t helpful
Something I noticed with SQL was that the examples retrieved for a given sample often shared very similar words but were not necessarily the best functional example for the query at hand.
For example, if the input sample we want to augment is:
How many farms are there?
the semantically closest example in the train set is:
What are the maximum and minimum number of cows across all farms.
which clearly doesn't give the model much insight into how to do a counting task: it belongs to the same farm dummy database from the dataset but offers little useful functional information.
Because of this, I implemented a constraint that prevented retrieval-augmented examples being from the same database. Only the closest semantic samples from different dummy databases could be used.
For the "How many farms are there?" example, the closest example from a different database becomes:
How many entrepreneurs are there?
which is a much better example for the model to use to learn how to do this kind of count task. This constraint by itself bumps up eval accuracy from 77.6% to 79%.
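A minimal sketch of the different-database constraint, assuming each sample carries a database identifier. The `db_id` field, the function names and the word-overlap similarity are all illustrative stand-ins, not the repo's code or the embedding models used in the experiments.

```python
import re


def token_overlap(a, b):
    """Stand-in similarity: Jaccard overlap of lowercase word sets."""
    tokens_a = set(re.findall(r"[a-z0-9]+", a.lower()))
    tokens_b = set(re.findall(r"[a-z0-9]+", b.lower()))
    union = tokens_a | tokens_b
    return len(tokens_a & tokens_b) / len(union) if union else 0.0


def closest_from_other_databases(sample, train_set, n=1,
                                 similarity=token_overlap):
    """Rank only the train samples that come from a different database."""
    candidates = [
        other for other in train_set
        if other["db_id"] != sample["db_id"]  # the constraint described above
    ]
    candidates.sort(
        key=lambda other: similarity(sample["question"], other["question"]),
        reverse=True,
    )
    return candidates[:n]


train_set = [
    {"question": "What are the maximum and minimum number of cows "
                 "across all farms.", "db_id": "farm"},
    {"question": "How many entrepreneurs are there?", "db_id": "entrepreneur"},
    {"question": "List the names of all stadiums.", "db_id": "stadium"},
]
sample = {"question": "How many farms are there?", "db_id": "farm"}
print(closest_from_other_databases(sample, train_set, n=1))
```

Filtering by database also removes the sample itself from the candidates, so no separate check for identical documents is needed in this variant.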
Another interesting observation is that using retrieval-augmented samples rather than zero-shot samples can lead to great stability improvements across a fine-tuning job.
Both graphs use a Llama 2 7B model at 8-bit quantisation with a LoRA adapter:
Training loss on SQL dataset with zero-shot non-retrieval augmented prompts.
Training loss on SQL dataset with one-shot retrieval-augmented prompts.
This set of experiments suggests that for fine-tuning tasks, knowledge-intensive or not, allowing the model to leverage both in-context learning and fine-tuning leads to better performance than just fine-tuning for zero-shot inference.
Email me at sam[at]ragntune.com if you have any questions!