Near-perfect arithmetic! The National University of Singapore releases Goat, which beats GPT-4 with only 7 billion parameters and initially supports multiplication and addition of numbers up to 16 digits

Although large language models have demonstrated superior performance on a variety of natural language processing tasks, arithmetic remains a major difficulty: even the most powerful model, GPT-4, still struggles with basic calculation problems.

 

Recently, researchers from the National University of Singapore proposed Goat, a model dedicated to arithmetic. Fine-tuned from the LLaMA model, it achieves significantly better arithmetic ability than GPT-4.

 

Paper link:

https://arxiv.org/pdf/2305.14201.pdf

 

By fine-tuning on a synthetic arithmetic dataset, Goat achieves state-of-the-art performance on the BIG-bench arithmetic subtasks.

 

Goat achieves near-perfect accuracy on addition and subtraction of large numbers through supervised fine-tuning alone, surpassing all previous pre-trained language models such as BLOOM, OPT, and GPT-NeoX. Notably, zero-shot Goat-7B even exceeds the accuracy of PaLM-540B after few-shot learning.

 

The researchers attribute Goat's superior performance to LLaMA's consistent tokenization of numbers.

 

To solve more challenging tasks such as large-number multiplication and division, the researchers also propose classifying tasks according to their learnability, and then using basic arithmetic principles to decompose non-learnable tasks (such as multi-digit multiplication and division) into a sequence of learnable tasks.

 

Comprehensive experiments verify that the decomposition steps proposed in the paper effectively improve arithmetic performance.

 

Goat-7B can be trained efficiently with LoRA on a GPU with 24 GB of VRAM, so other researchers can easily reproduce the experiments. The model, the dataset, and the Python script for generating the dataset will soon be open-sourced.

 

A language model that can count

Language model

LLaMA is a set of open-source pre-trained language models trained on trillions of tokens using publicly available datasets and achieves state-of-the-art performance on multiple benchmarks.

 

Previous research has shown that tokenization is important for the arithmetic ability of LLMs, but commonly used tokenization techniques do not represent numbers well; for example, numbers with many digits may be split inconsistently.

LLaMA splits a number into individual digit tokens, which ensures a consistent representation of numbers. The researchers believe that the extraordinary arithmetic ability shown in the experiments is mainly due to this consistent tokenization of numbers.
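To see the difference in practice, the snippet below inspects how different tokenizers split the same arithmetic string. This is a minimal sketch using the Hugging Face `transformers` library; the checkpoint names are examples (not from the paper), and the exact token splits depend on the tokenizer version.

```python
# Sketch: compare how a SentencePiece-based LLaMA tokenizer and a BPE
# tokenizer (GPT-2) split numbers. Checkpoint names are illustrative.
from transformers import AutoTokenizer

for name in ["huggyllama/llama-7b", "gpt2"]:
    tok = AutoTokenizer.from_pretrained(name)
    # LLaMA's vocabulary encodes each digit as its own token, giving every
    # number a consistent representation; BPE tokenizers may merge digits
    # into arbitrary multi-digit chunks.
    print(name, tok.tokenize("74815 + 9265"))
```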

 

In experiments, other fine-tuned language models, such as Bloom, OPT, GPT-NeoX, and Pythia, could not match the arithmetic capabilities of LLaMA.

 

Learnability of Arithmetic Tasks

 

Previous theoretical analyses of solving composite tasks with intermediate supervision have shown that such tasks, which are not directly learnable, can be decomposed into a polynomial number of simple subtasks.

 

That is, non-learnable composite problems can be learned by using intermediate supervision or a step-by-step chain of thought (CoT).

 

Based on this analysis, the researchers first experimentally categorized learnable and non-learnable tasks.

 

In the context of arithmetic computing, learnable tasks are typically those for which a model can be successfully trained to directly generate answers, achieving a sufficiently high accuracy within a predefined number of training epochs.

 

Non-learnable tasks are those for which the model has difficulty learning correctly and generating direct answers even after extensive training.

 

Although the exact reasons behind the variation in task learnability are not fully understood, it can be hypothesized that this is related to the complexity of the underlying schema and the size of working memory required to complete the task.

The researchers experimentally examine the learnability of these tasks by fine-tuning the model specifically for each task in a simplified synthetic environment.

Learnable and non-learnable tasks

 

This task classification also matches human intuition. With practice, humans can add and subtract two large numbers mentally, writing the final answer directly from left (most significant digit) to right (least significant digit) without working it out on paper.

 

Mentally computing multiplication and division of large numbers, by contrast, is a challenging task.

 

It can also be observed that this task classification is consistent with the performance of GPT-4: GPT-4 is good at generating direct answers for large-number addition and subtraction, but its accuracy drops significantly on multi-digit multiplication and division tasks.

 

The inability of models as powerful as GPT-4 to directly solve non-learnable tasks may also indicate that generating direct answers for these tasks is extremely challenging, even after extensive training.

 

It is worth noting that tasks that are learnable for LLaMA may not necessarily be learnable for other LLMs.

 

Furthermore, not all tasks classified as non-learnable are completely impossible for the model to learn.

 

For example, multiplying a 2-digit number by a 2-digit number is classified as a non-learnable task, but the model can still directly generate answers by overfitting the training set, provided that the training set enumerates all possible 2-digit multiplications.

 

However, this requires nearly 10 epochs to reach an accuracy of about 90%.
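For a sense of scale, the enumeration mentioned above is small enough to memorize. The snippet below (an illustrative sketch, not the paper's data script) generates every 2-digit-by-2-digit product:

```python
# All 2-digit-by-2-digit products: 90 * 90 = 8,100 pairs -- a set small
# enough for a model to overfit and produce "direct answers" from memory.
pairs = [(a, b, a * b) for a in range(10, 100) for b in range(10, 100)]
print(len(pairs), pairs[0])   # 8100 (10, 10, 100)
```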

 

By inserting the CoT proposed in the paper before the final answer, however, the model achieves quite good accuracy on two-digit multiplication after a single epoch of training, which is consistent with previous research conclusions that intermediate supervision aids the learning process.

 

Addition and Subtraction

These two arithmetic operations are learnable, and with only supervised fine-tuning, the model demonstrates a remarkable ability to accurately generate direct numerical answers.

 

Even though the model was trained on only a very limited subset of the addition data, it captured the fundamental patterns of the arithmetic operation, as evidenced by near-perfect accuracy on an unseen test set, all without using CoT.

 

Multiplication

The researchers experimentally verified that n-digit by 1-digit multiplication is learnable, while multi-digit multiplication is not.

 

To overcome this problem, the researchers fine-tune the LLM to generate a CoT before generating the answer, decomposing multi-digit multiplication into five learnable subtasks (a sketch of the decomposition follows the list):

1. Extraction: extract the arithmetic expression from the natural language instruction

2. Split: split the smaller of the two numbers into place values

3. Expansion: expand the sum based on the distributive property

4. Product: compute each product simultaneously

5. Adding term by term: add the first two terms, copy the remaining terms, and obtain the final sum

Each of these tasks is learnable.
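The following is a minimal sketch of how such a multiplication CoT could be generated, loosely following the split / expansion / product / term-by-term-addition steps above. The function name and exact string format are illustrative, not the paper's actual data format.

```python
def multiplication_cot(a: int, b: int) -> str:
    """Illustrative CoT for multi-digit multiplication following the
    split -> expand -> product -> add-term-by-term decomposition."""
    small, big = sorted((a, b))          # split the smaller of the two numbers
    digits = str(small)
    # Split into place values, e.g. 24 -> [20, 4]
    parts = [int(d) * 10 ** (len(digits) - i - 1)
             for i, d in enumerate(digits) if d != "0"]
    # Distributive expansion, e.g. 397 * 24 = 397 * 20 + 397 * 4
    expansion = " + ".join(f"{big} * {p}" for p in parts)
    products = [big * p for p in parts]
    steps = [f"{a} * {b} = {expansion}",
             "= " + " + ".join(str(p) for p in products)]
    # Add term by term: sum the first two terms, copy the remaining terms
    while len(products) > 1:
        merged = [products[0] + products[1]] + products[2:]
        steps.append("= " + " + ".join(str(x) for x in merged))
        products = merged
    return "\n".join(steps)

print(multiplication_cot(397, 24))
# 397 * 24 = 397 * 20 + 397 * 4
# = 7940 + 1588
# = 9528
```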

 

Division

Similarly, it can be observed experimentally that n-digit by 1-digit division is learnable, while multi-digit division is not.

 

Building on the recurrence used in slow division, the researchers designed a new chain-of-thought prompt.

The main idea is to subtract multiples of the divisor from the dividend until the remainder is less than the divisor.
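Below is a minimal sketch of this idea: at each step, subtract the largest "digit times power of ten" multiple of the divisor from the running remainder, until the remainder is smaller than the divisor. The function name and output format are illustrative, not the paper's exact CoT template.

```python
def division_cot(dividend: int, divisor: int) -> str:
    """Illustrative slow-division CoT: repeatedly subtract multiples of the
    divisor from the dividend until the remainder is less than the divisor."""
    steps, remainder, quotient = [], dividend, 0
    while remainder >= divisor:
        # Largest power of ten k such that divisor * 10**k still fits
        k = len(str(remainder // divisor)) - 1
        chunk = (remainder // (divisor * 10 ** k)) * 10 ** k
        steps.append(f"{remainder} - {divisor} * {chunk} = "
                     f"{remainder - divisor * chunk}")
        remainder -= divisor * chunk
        quotient += chunk
    steps.append(f"quotient = {quotient}, remainder = {remainder}")
    return "\n".join(steps)

print(division_cot(8914, 64))
# 8914 - 64 * 100 = 2514
# 2514 - 64 * 30 = 594
# 594 - 64 * 9 = 18
# quotient = 139, remainder = 18
```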

Dataset

The experiments cover addition and subtraction of two positive integers of up to 16 digits each; the result of a subtraction may be negative.

 

To limit the maximum length of the generated sequence, the result of a multiplication is a positive integer of at most 12 digits; for division of two positive integers, the dividend has at most 12 digits and the quotient at most 6 digits.

 

The researchers synthesized a dataset of about one million question-answer pairs using a Python script; each answer contains the proposed CoT and the final numerical output. All numbers are generated randomly, so the probability of repeated instances is very low, although small numbers may be sampled multiple times.
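As a rough illustration of what such a generator might look like for the addition/subtraction portion, here is a minimal sketch. The field names and sampling scheme are assumptions for illustration, not the paper's released script.

```python
import random

def sample_pair() -> dict:
    """Illustrative generator for one addition/subtraction example with
    operands of up to 16 digits; subtraction results may be negative."""
    op = random.choice(["+", "-"])
    a = random.randint(1, 10 ** 16 - 1)
    b = random.randint(1, 10 ** 16 - 1)
    answer = a + b if op == "+" else a - b
    return {"question": f"{a} {op} {b}", "answer": str(answer)}

dataset = [sample_pair() for _ in range(1_000)]
print(dataset[0])
```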

 

Fine-tuning

To enable the model to solve arithmetic problems based on instructions and facilitate natural language question answering, the researchers used ChatGPT to generate hundreds of instruction templates.

 

During instruction tuning, a template is randomly selected for each arithmetic input, and LLaMA-7B is fine-tuned in a manner similar to Alpaca.

Goat-7B can be fine-tuned using LoRA on a GPU with 24 GB of VRAM; fine-tuning on 100,000 samples to near-perfect accuracy takes only about 1.5 hours on an A100 GPU.
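For readers who want to reproduce this kind of setup, the snippet below is a minimal LoRA configuration sketch using Hugging Face `transformers` and `peft`. The checkpoint name, rank, and target modules are illustrative assumptions, not the exact hyperparameters reported in the paper.

```python
# Sketch: attach LoRA adapters to a LLaMA-7B base model for fine-tuning.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

base = "huggyllama/llama-7b"                  # assumed LLaMA-7B checkpoint
tokenizer = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(base, torch_dtype="auto")

lora_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,   # illustrative hyperparameters
    target_modules=["q_proj", "v_proj"],      # adapt attention projections only
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()            # only a small fraction of weights train
```

Because only the low-rank adapter weights receive gradients, the memory footprint stays within a 24 GB GPU, which is what makes the reported training setup accessible.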

 

Experimental results

Comparing Goat and GPT-4 on large-number multiplication and division may seem unfair, because GPT-4 generates answers directly while Goat relies on the designed chain of thought; therefore, when evaluating GPT-4, "Solve it step by step" is still appended to the end of each prompt.

However, it can be observed that although GPT-4's intermediate steps in long multiplication and division are sometimes wrong, the final answer can still be correct, which suggests that GPT-4 does not use the intermediate supervision of the chain of thought to improve its final output.

 

Finally, three common mistakes were identified in GPT-4's solutions:

1. Misalignment of corresponding digits

2. Repeated numbers

3. Incorrect intermediate results when multiplying n digits by 1 digit

 

From the experimental results, it can be seen that GPT-4 performs quite well on 8D+8D and 16D+16D tasks, but gets most 16D+8D tasks wrong, even though intuitively 16D+8D should be easier than 16D+16D.

 

While the exact reason is unknown, one possible factor is GPT-4's inconsistent tokenization of numbers, which makes it difficult to align the digits of the two operands.

 

References:

https://huggingface.co/papers/2305.14201

 

 
