Tiny Agents - Training Small LLMs to Use Tools with DPO and Synthetic Data
TL;DR: I've created a synthetic dataset for training LLMs to use tools, and am training a model using DPO to improve accuracy. The dataset is available here, and I'll be releasing trained models soon.
Why tool usage?
When ChatGPT emerged, its initial utility was as a search engine replacement and knowledge engine. However, over time, it became clear that it, and models of its class, possessed far more interesting skills: some degree of reasoning ability, and even the ability to perform tasks by outputting structured text (i.e. tool calls). This sparked excitement about its potential as an "agent" capable of interacting with the world. Demos followed, from buying things on the internet to more general tasks like making $100k(!). For a brief period after ChatGPT's launch, it seemed that all we needed to do was find the right agent framework, the correct system prompt, and boom, we'd have AGI.
Over a year later, this clearly hasn't happened yet. While impressive demos continue to be released, covering an increasingly wide range of tasks, each of them thus far has the same fatal flaw: a lack of reliability. On the latest benchmarks, such as AgentBench, even the best models only perform a task correctly 50-60% of the time, and often fail in ways that are hard to predict. That being said, such systems hold a lot of promise; it's painful to see that such tasks are technically possible in many cases, but not reliable enough to work. Adding to this, training models to perform specific tasks typically requires RLHF to some degree - a technique that requires lots of expensive ranked pairwise data and is notoriously hard to train.
DPO and Synthetic Data
However, in the last few months, two techniques have become popular that might massively help with these limitations: DPO and the rise of synthetic datasets. Simply put, DPO uses positive-negative pairs like RLHF, but doesn't require training a separate reward model; instead, the loss is calculated directly from log-probability ratios: how much more likely the model being trained makes the positive and the negative example, relative to a frozen reference model. This greatly reduces the compute needed for training; a common trick is to train the model as a LoRA adapter, so that enabling the adapter gives the trained model's logprobs and disabling it gives the baseline logprobs, without holding two copies of the model in memory. In some cases, you can tune a model with DPO on a single 3090.
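To make the loss concrete, here is a minimal sketch of the per-example DPO objective in plain Python. It assumes you already have the summed log-probabilities of each completion under the trained model and under the reference model; the function name, argument names, and the `beta` default are illustrative choices of mine, not taken from any particular library.

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Per-example DPO loss from summed sequence log-probabilities.

    All names and the beta default are illustrative assumptions.
    """
    # How much more likely each completion became relative to the reference.
    chosen_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_ratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_ratio - rejected_ratio)
    # loss = -log(sigmoid(margin)), written in a numerically stable form.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))

# When the trained model matches the reference, the margin is zero.
baseline = dpo_loss(-10.0, -12.0, -10.0, -12.0)
# Raising the chosen logprob and lowering the rejected one reduces the loss.
improved = dpo_loss(-8.0, -14.0, -10.0, -12.0)
```

The loss only falls when the trained model widens the gap between the positive and negative completion beyond what the reference model already assigns, which is why no separate reward model is needed.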
Synthetic datasets are exactly what the name implies: using LLMs to generate data to use for fine-tuning. A recent example is Cosmopedia, by Hugging Face, which uses a crawled corpus as the seed and then re-phrases the data into a format that is more effective for fine-tuning. The broader point, that a small, high-quality dataset can be enough, was demonstrated for general LLM alignment in the LIMA paper, which used a carefully curated dataset of only 1,000 samples; more recently, the Yi technical report showed that a dataset of 10k samples was enough to align a model for instruction following with DPO.
Putting these together, I saw an opportunity to train a model for tool use, entirely with synthetic data, using DPO to improve accuracy over traditional SFT approaches. Furthermore, I'd limit the number of distinct tools used; given lots of examples on a small set of tools, the hope is that I'd be able to reach a level of tool-use accuracy that meets or exceeds the current state of the art.
Designing the dataset
DPO needs a positive and a negative completion for every prompt, so the question becomes: How do we guarantee that we have one objectively 'correct' answer, and one objectively 'incorrect' answer? This is the trickiest part; we could generate N different calls and score them, as in the WizardLM paper, but that requires N * (no. of rows) completions, which can be prohibitively expensive. A similar approach was used in Anthropic's RLAIF paper - using an LLM to rate the completions, and then using the ratings to train a reward model. However, this is also expensive, and requires a lot of data to train the reward model.
The solution I landed on is to force the model to generate negative completions, using a form of 'dropout' - similar to dropping out labels in classification tasks, we can 'drop out' part of the tool definition, or remove it entirely. This massively increases the chance that the model will generate a negative completion. In total, we have 4 types of dropout:
- Remove the tools' descriptions, but keep the tool name and parameters.
- Reorder the tools' parameters.
- Remove the tools' parameters, but keep the tool name / description.
- Completely remove the tool from the prompt, so that the model thinks it doesn't have one.
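The four dropout types above can be sketched as small functions over a tool definition. This is a hedged illustration rather than the code used to build the dataset: the tool schema (a dict with `name`, `description`, and a `parameters` mapping), the function names, and the choice to reorder by rotating the parameters are all assumptions of mine.

```python
import copy

def drop_description(tool):
    """Keep the name and parameters, blank out the description."""
    out = copy.deepcopy(tool)
    out["description"] = ""
    return out

def reorder_parameters(tool):
    """Rotate the parameter order by one (changes it when there are 2+ parameters)."""
    out = copy.deepcopy(tool)
    names = list(out["parameters"])
    rotated = names[1:] + names[:1]
    out["parameters"] = {name: out["parameters"][name] for name in rotated}
    return out

def drop_parameters(tool):
    """Keep the name and description, remove every parameter."""
    out = copy.deepcopy(tool)
    out["parameters"] = {}
    return out

def drop_tool(tool):
    """Remove the tool entirely; the prompt is then built without it."""
    return None

# Hypothetical registry mapping a dropout name to its function.
DROPOUTS = {
    "drop_description": drop_description,
    "reorder_parameters": reorder_parameters,
    "drop_parameters": drop_parameters,
    "drop_tool": drop_tool,
}

def apply_dropout(tool, kind):
    """Return a degraded copy of `tool`; the original is never mutated."""
    return DROPOUTS[kind](tool)
```

The degraded definition is what goes into the prompt for the negative completion, while the positive completion is generated from the untouched definition.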
The full approach looks like this:
- Generate a list of seed use cases for tools present; for example, given a calculator tool, prompt the LLM to generate tasks like "add 2 and 3", "subtract 5 from 7", etc.
- Generate sample conversations from a user, of the user asking for help with a tool, and the agent responding with the correct answer.
- Generate negative pairs for each of the seed use cases, using the dropout techniques described above.
Not too complicated! However, this took lots of tweaks to get right in practice. The dropouts would cause too few failures, or the model would generate duplicate completions for either the negative samples or the positive ones. You can see the code used to generate this dataset here.
I'm still in the process of training models on this dataset; of course, generating the dataset is only half of the problem, but I'm happy enough with the quality that I'd like to share it more broadly. I'll update this post as I train the models and get results.
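One way to catch both failure modes is a filtering pass over the generated pairs. The sketch below is an assumption about how such a filter could look, not the author's actual code: the row keys (`chosen`, `rejected`) and the whitespace/case normalization are hypothetical.

```python
def normalize(text):
    """Lowercase and collapse whitespace so near-identical strings compare equal."""
    return " ".join(text.lower().split())

def filter_pairs(rows):
    """Drop pairs where dropout failed to change the output, and duplicate completions.

    Each row is assumed to be a dict with 'chosen' and 'rejected' strings.
    """
    seen_chosen, seen_rejected = set(), set()
    kept = []
    for row in rows:
        chosen = normalize(row["chosen"])
        rejected = normalize(row["rejected"])
        # The dropout did not cause a failure: nothing to learn from this pair.
        if chosen == rejected:
            continue
        # The same completion already appeared for another use case.
        if chosen in seen_chosen or rejected in seen_rejected:
            continue
        seen_chosen.add(chosen)
        seen_rejected.add(rejected)
        kept.append(row)
    return kept
```

Exact-match deduplication after normalization is deliberately crude; it will miss paraphrased duplicates, which would need a fuzzier similarity check.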
Takeaways
- Forgive errors when possible, and validate logs constantly. I had to stop the dataset generation every 1k rows or so to check for errors in formatting, etc; early on I hadn't asked the model to generate a detailed enough tool call, leading to a high percentage of completions that were just "I don't know" or "I can't help with that". Other issues are more subtle - such as slight re-phrasings leading to duplicate completions for totally different use cases.
- For most inference providers, you can find their throughput cap and grow the batch size progressively, rather than using a fixed batch size. This really helped with speeding up generation; I added a basic backoff system to avoid getting rate limited as well.
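The progressive batching idea from the last takeaway can be sketched as follows. Everything here is a hypothetical illustration: `send_batch` stands in for whatever client call your provider offers, `RateLimited` is a placeholder exception, and the doubling/halving policy is one simple choice among many.

```python
import time

class RateLimited(Exception):
    """Placeholder for whatever error the provider raises when throttling."""

def run_progressive(items, send_batch, start=1, max_batch=64,
                    base_delay=1.0, max_retries=5, sleep=time.sleep):
    """Grow the batch size until the provider pushes back, then back off.

    `send_batch(batch)` is a hypothetical callable returning one result per item.
    """
    results = []
    batch_size, delay, failures, i = start, base_delay, 0, 0
    while i < len(items):
        batch = items[i:i + batch_size]
        try:
            results.extend(send_batch(batch))
        except RateLimited:
            failures += 1
            if failures > max_retries:
                raise
            # Treat the halved size as the discovered throughput cap.
            batch_size = max(1, batch_size // 2)
            max_batch = batch_size
            sleep(delay)
            delay *= 2  # exponential backoff between retries
            continue
        i += len(batch)
        failures, delay = 0, base_delay
        batch_size = min(max_batch, batch_size * 2)
    return results
```

Passing `sleep` as a parameter makes the loop easy to exercise against a fake provider without actually waiting.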
Example rows
Seed dataset (found here):
Tool | User Question | Tool Call | Agent Response |
---|---|---|---|
{'name': 'find_timezone', 'description': '', 'parameters': {...}} | Find the timezone of Tokyo | {'location': 'Tokyo, Japan'} | The timezone of Tokyo, Japan is Asia/Tokyo. |
{'name': 'sunrise_sunset_times', 'description': '', 'parameters': {...}} | What time will the sun rise and set in Los Angeles on July 15, 2022? | {'location': 'Los Angeles', 'date': '2022-07-15'} | On July 15, 2022, in Los Angeles, the sun is estimated to rise at 5:45 AM and set at 8:15 PM. |
DPO pairs (found here):
Tool | User Question | Tool Call | Negative Tool Call | Agent Response | Negative Response |
---|---|---|---|---|---|
{'name': 'find_timezone', 'description': '', 'parameters': {...}} | Find the timezone of Tokyo | {'location': 'Tokyo, Japan'} | N/A | The timezone of Tokyo, Japan is Asia/Tokyo. | I'm sorry, I couldn't find the timezone for Tokyo. |
{'name': 'sunrise_sunset_times', 'description': '', 'parameters': {...}} | What time will the sun rise and set in Los Angeles on July 15, 2022? | {'position': 'Los Angeles'} | N/A | N/A | I'm sorry, the request made failed and I couldn't return the time for sunrise and sunset in Los Angeles on July 15, 2022. |
Note the incorrect parameters in the second negative tool call, and the lack of a tool call in the first negative tool call.
Disclaimer: I work at Stability AI, but all work mentioned here is a personal side project and isn't affiliated with Stability in any way.