
Hem entrenat un sistema que resol problemes de matemàtiques de primària amb gairebé el doble de precisió que un model GPT‑3 ajustat finament. Resol aproximadament el 90% dels problemes que resolen nens reals: una petita mostra de nens de 9 a 12 anys va obtenir un 60% en una prova del nostre conjunt de dades, mentre que el nostre sistema va obtenir un 55% en aquests mateixos problemes.
Per què és important
Això és important perquè la IA actual encara és força feble en el raonament de sentit comú de diversos passos, que és fàcil fins i tot per a nens de primària. Hem aconseguit aquests resultats entrenant el nostre model perquè reconegui els seus errors, de manera que pugui intentar-ho repetidament fins que trobi una solució que funcioni.
Els grans models de llenguatge com GPT‑3 tenen moltes habilitats impressionants, inclosa la seva capacitat d’imitar molts estils d’escriptura i el seu ampli coneixement factual. Tanmateix, tenen dificultats per dur a terme tasques que requereixen un raonament precís de diversos passos, com ara resoldre problemes de matemàtiques redactats de primària. Tot i que el model pot imitar el ritme de solucions correctes, produeix regularment errors crítics de lògica.
Per igualar el rendiment humà en dominis lògics complexos, els nostres models han d’aprendre a reconèixer els seus errors i a triar amb cura els seus passos. Amb aquest objectiu, entrenem verificadors perquè avaluïn si una solució proposada és correcta o no. Per resoldre un problema nou, fem servir verificadors per seleccionar la millor entre moltes solucions proposades. Hem recopilat el nou conjunt de dades GSM8K per avaluar els nostres mètodes, i publiquem aquest conjunt de dades per facilitar la recerca.
En els deu exemples següents mostrem solucions generades pel nostre nou mètode, la verificació, i pel nostre mètode de referència, l’ajust fi.
GSM8K consta de 8,5 mil problemes de matemàtiques redactats de primària d’alta qualitat. Cada problema requereix entre 2 i 8 passos per resoldre’s, i les solucions impliquen principalment fer una seqüència de càlculs elementals amb operacions aritmètiques bàsiques (+ − × ÷) per arribar a la resposta final. Els models de llenguatge d’última generació ajustats finament obtenen mals resultats en aquest conjunt de dades, principalment a causa de l’alta diversitat dels problemes. Alhora, les solucions de GSM8K només depenen de conceptes elementals, de manera que aconseguir un alt rendiment en la prova és un objectiu assequible.
Les solucions de GSM8K estan escrites en llenguatge natural en lloc d’expressions purament matemàtiques. En mantenir-nos en el llenguatge natural, les solucions generades pel model són més fàcils d’interpretar per als humans, i els nostres mètodes continuen sent relativament independents del domini.
Un repte important del raonament matemàtic és l’alta sensibilitat als errors individuals. Els models autoregressius, que generen cada solució segment a segment, no tenen cap mecanisme per corregir els seus propis errors. Les solucions que es desvien del camí correcte aviat esdevenen irrecuperables, com es pot veure en els exemples proporcionats.
Abordem aquest problema entrenant verificadors perquè avaluïn la correcció de les solucions generades pel model. Als verificadors se’ls donen moltes solucions possibles, totes escrites pel mateix model, i s’entrenen per decidir quines, si n’hi ha cap, són correctes.
Per resoldre un problema nou en el moment de la prova, generem 100 solucions candidates i després seleccionem la que el verificador classifica més amunt. Els verificadors es beneficien d’aquesta optionalitat inherent, així com del fet que la verificació sovint és una tasca més senzilla que la generació.
Hem observat que la verificació aporta una millora notable del rendiment, sempre que el conjunt de dades sigui prou gran. Amb conjunts de dades massa petits, creiem que els verificadors sobreajusten memoritzant les respostes finals del conjunt d’entrenament, en lloc d’aprendre propietats més útils del raonament matemàtic.
Sobre el conjunt d’entrenament complet, la verificació amb 6B paràmetres supera lleugerament un model de 175B paràmetres ajustat finament, cosa que proporciona una millora del rendiment aproximadament equivalent a multiplicar per 30 la mida del model. A més, sembla que la verificació escala de manera més eficaç amb dades addicionals, si extrapolem a partir dels resultats actuals.
Produir arguments correctes i reconèixer-ne d’incorrectes són reptes clau en el desenvolupament d’una IA més general. Les matemàtiques de primària són un banc de proves ideal per a aquestes capacitats. Els problemes de GSM8K són conceptualment senzills, però un sol error subtil n’hi ha prou per fer descarrilar tota una solució. Identificar i evitar aquests errors és una habilitat crucial que els nostres models han de desenvolupar. En entrenar verificadors, ensenyem als nostres models a separar les bones solucions de les que no han acabat de funcionar. Esperem que aquestes habilitats siguin cada cop més rellevants a mesura que intentem aplicar els nostres models a dominis lògicament més complexos.
Autors
Agraïments
Gràcies a l’equip de Surge AI per dur a terme la recopilació de dades de GSM8K.
Gràcies als coautors del nostre article: Mohammad Bavarian, Mark Chen, Heewoo Jun, Lukasz Kaiser, Matthias Plappert, Jerry Tworek, Jacob Hilton, Reiichiro Nakano i Christopher Hesse.
Gràcies a les persones que van aportar comentaris sobre aquesta publicació: Dan Hendrycks, Leo Gao, Alec Radford, Giambattista Parascandolo, Harri Edwards, Yura Burda, Nick Ryder, Ilya Sutskever, Mira Murati, Sam Altman, Aris Konstantinidis, Andrew Mayne, Hannah Wong i Steve Dowling.
Gràcies als estudiants que es van oferir voluntaris per fer la nostra prova!


