Training a reasoning model for free (almost)

At Stanford ACM AI Clinic's February workshop, we discussed the "s1: Simple test-time scaling" paper by Muennighoff, Yang, Shi, Li, and others. Test-time compute is a technique in which a model receives additional computational resources during inference. For example, when a model faces a complex problem, a test-time compute strategy may have it introspect on and refine its answer before committing to a final one; the refinement step is what consumes the extra compute [1].

The s1 paper provides another way to improve a model's reasoning capability. The method consists of two steps. First, perform supervised fine-tuning (SFT) on a pretrained model using a dataset that contains another model's reasoning traces (the text a model outputs while it reasons). Concretely, the dataset contains many text examples of Gemini's reasoning output, and the paper finds that with 1,000 curated examples, a 32-billion-parameter (32B) Qwen model can be finetuned to imitate that reasoning behavior.
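For illustration, a single finetuning example pairs a question with the teacher model's reasoning trace and a final answer, laid out in the chat template the model is trained on. The think-tag delimiter follows the convention shown later in this post; the placeholder fields and the answer delimiter are our assumptions:

<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
<|im_start|>think
{Gemini's reasoning trace}
<|im_start|>answer
{final answer}<|im_end|>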

Second, during inference, we apply budget forcing to the model. Budget forcing is a method to have the model reason for a set number of tokens. Specifically, the model demarcates its reasoning output between think tags, and it continues outputting tokens in the think region until its output reaches the token budget. If the model stops reasoning before meeting the budget, we append "Wait" to its output and have it continue generating tokens. Appending "Wait" to the finetuned model's output can lead it to continue reasoning and double-check its answer.
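As a rough sketch, here is one way to implement this loop with HuggingFace transformers. The model id, the token budget, and the assumption that the model ends its reasoning by emitting its end-of-sequence token are ours, not the paper's exact setup:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"  # placeholder model id
BUDGET = 512                           # reasoning-token budget to enforce

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")

prompt = (
    "<|im_start|>user\nHow many r's are in raspberry?<|im_end|>\n"
    "<|im_start|>assistant\n<|im_start|>think\n"
)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
prompt_len = input_ids.shape[1]
wait_ids = tokenizer(
    "Wait", return_tensors="pt", add_special_tokens=False
).input_ids.to(model.device)

for _ in range(8):  # guard against looping forever
    remaining = BUDGET - (input_ids.shape[1] - prompt_len)
    if remaining <= 0:
        break
    input_ids = model.generate(
        input_ids,
        max_new_tokens=remaining,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    if input_ids[0, -1].item() == tokenizer.eos_token_id:
        # The model stopped reasoning before the budget was met: drop the
        # end token and append "Wait" so it continues thinking.
        input_ids = torch.cat([input_ids[:, :-1], wait_ids], dim=1)

print(tokenizer.decode(input_ids[0, prompt_len:]))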

These results look amazing, so the AI Clinic held a workshop to cover three topics:

  1. Implement s1 ourselves
  2. Apply techniques to make the most of a limited GPU
  3. Learn about the HuggingFace ecosystem

Our code has been released on our GitHub.

Technical notes

The initial version of this project was implemented in a Google Colab notebook using the free GPU. While it is possible to implement s1 and the notes below with a 1.5B Qwen model (the largest one that fits on Colab's GPU), Colab's timeout mechanisms may interfere with finetuning; if you aren't vigilant, the notebook will get shut down mid-training.

In order to work in a more suitable environment, we rented a GPU on the Vast.ai platform. For $5, we rented one 48 GB Nvidia A40 GPU for 10 hours. In that time, we trained a 3B Qwen model and reviewed its reasoning capability. You can follow along with the aforementioned GitHub repository.

Here are the key notes for applying the s1 methodology on a limited GPU:

  1. We train the Qwen model on a pared-down dataset. First, we take the s1 authors' full dataset of 59,000 reasoning traces. Then, we filter the dataset for reasoning traces that consist of no more than 2,300 tokens (see the filtering sketch after this list). We impose this limit because training the model to output significantly longer sequences does not fit in a limited GPU's memory. On a smaller GPU (such as Colab's free GPU), this limit likely needs to be decreased.
  2. In order to fit and train a 3B model, we quantize the model and perform LoRA updates (see the training sketch after this list). Quantizing is a "model compression technique that converts the weights and activations within an LLM from a high-precision data representation to a lower-precision data representation" [2]. For example, a model typically stores its weights as 32-bit floats (decimal numbers), and a quantized model may map those values to 4-bit representations.

    LoRA (low-rank adaptation) is a technique to efficiently finetune models by modeling weight changes as low-rank matrices. Importantly, a low-rank matrix can be represented compactly as the product of two much smaller matrices, reducing the amount of GPU memory required to adapt a model to a new dataset.
  3. Finally, we run SFT using HuggingFace's TRL library. During SFT, the model is trained on the provided reasoning traces, enabling it to learn the patterns and structures that underlie reasoning.
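To make step 1 concrete, here is a minimal filtering sketch. The dataset id and the field name holding the reasoning trace are placeholders rather than the exact ones in our repository:

from datasets import load_dataset
from transformers import AutoTokenizer

MAX_TRACE_TOKENS = 2300

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")

# Placeholder dataset id; substitute the actual 59K-trace dataset.
dataset = load_dataset("org/full-59k-reasoning-traces", split="train")

def fits_budget(example):
    # Keep only examples whose reasoning trace is at most 2,300 tokens.
    trace_tokens = tokenizer(example["reasoning_trace"])["input_ids"]
    return len(trace_tokens) <= MAX_TRACE_TOKENS

dataset = dataset.filter(fits_budget)
print(f"{len(dataset)} traces remain after filtering")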
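For steps 2 and 3, the sketch below combines 4-bit quantization (via bitsandbytes), a LoRA adapter (via peft), and TRL's SFTTrainer. The model id, LoRA hyperparameters, and the assumption that the filtered dataset exposes a ready-to-train "text" column are ours; treat this as an outline of the approach rather than our exact training script:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

# Load the base model with its weights quantized to 4-bit precision.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-3B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
)

# Model the weight updates as low-rank matrices instead of training
# the full weight matrices.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

# `dataset` is the filtered set from the previous sketch, with a "text"
# column holding each formatted question + reasoning trace + answer.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    args=SFTConfig(
        output_dir="qwen-3b-s1",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=1,
    ),
)
trainer.train()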

Reasoning Observations

In this workshop, we did not implement full budget forcing; we only appended "Wait" to the model's output once. However, it was still interesting to see how SFT affected the model's behavior. Specifically, after SFT, the model knew how to use think tags, and it also reasoned for many more tokens.

The base Qwen model does not make effective use of the think tag. See the example below: the initial query ends at <|im_start|>think, and instead of starting a new line after think, the model completes the query to thinker.

<|im_start|>user
How many r's are in raspberry?<|im_end|>
<|im_start|>assistant
<|im_start|>thinker
To find out how many 'r's' are in the word "raspberry,"

After SFT, the model properly handles think:

<|im_start|>user
How many r's are in raspberry?<|im_end|>
<|im_start|>assistant
<|im_start|>think
Let's break down the process of counting the number of 'r' letters in the 
word "raspberry".

The second observation is more informal: in a few examples, we observed the base model arriving at a final answer in fewer tokens than the finetuned model. Of course, budget forcing would lead the base model to continue reasoning until it reached the provided budget.

One really interesting output: when the finetuned model is asked to rearrange the letters in "LISTEN" into a related word, it attempts (albeit unsuccessfully) to rearrange the letters itself. The base model did not show this behavior.

Back to the initial thought of rearranging [abbrv.] the letters:

*   LISETN (no word)
*   LITSEN (no word)
*   LITSEN (no word)
*   LITSEN (no word)

Final notes

In our initial tests, the finetuned 3B Qwen model did not give correct answers to questions like "How many r's are in raspberry?" or "Rearrange the letters in 'LISTEN' to form another related word." Perhaps this behavior could be corrected by implementing budget forcing.

We note that this behavior can be further explained by the project's limitations:

  1. The model we finetune has 3B parameters, not the 32B used in the paper.
  2. The reasoning traces are limited to those under 2,300 tokens, which reduces the quality of our dataset.

The workshop still taught the AI Clinic a lot about topics like efficient training, SFT, and the HuggingFace ecosystem, in addition to revealing the interesting behaviors of the finetuned model discussed above.