Capire le reti neurali attraverso circuiti sparsi
Abbiamo addestrato i modelli a pensare in passaggi più semplici e tracciabili, in modo da capire meglio come funzionano.
Le reti neurali alimentano i sistemi di IA attualmente più capaci, ma rimangono difficili da comprendere. Non scriviamo questi modelli con istruzioni esplicite e passo dopo passo. Infatti imparano regolando miliardi di connessioni interne, o “pesi”, fino a quando non conoscono a fondo una specifica attività. Progettiamo le regole dell'addestramento, ma non i comportamenti specifici che emergono, e il risultato è una fitta rete di connessioni che nessun essere umano può decifrare facilmente.
Man mano che i sistemi di IA diventano più capaci e hanno un impatto reale sulle decisioni in fatto di scienza, istruzione e sanità, capire come funzionano è essenziale. L'interpretabilità si riferisce ai metodi che ci aiutano a comprendere perché un modello ha generato un determinato output. Sono molti i modi in cui potremmo riuscirci.
Ad esempio, i modelli di ragionamento sono incentivati a spiegare come lavorano mentre si avvicinano a una risposta finale. L'interpretabilità della catena di pensiero sfrutta tali spiegazioni per monitorare il comportamento del modello. L'utilità è immediata: le catene di pensiero degli attuali modelli di ragionamento sembrano essere informative rispetto a comportamenti preoccupanti come l'inganno. Tuttavia, fare completo affidamento su tale proprietà è una strategia fragile e potrebbe fallire nel tempo.
D'altra parte, l'interpretabilità meccanicistica, che è al centro di questo lavoro, cerca di analizzare i calcoli di un modello con un approccio di ingegneria inversa. Finora nell'immediato non è risultato molto utile, ma in linea di principio potrebbe offrire una spiegazione più completa del comportamento del modello. Cercando di spiegare il comportamento del modello a livello più dettagliato, l'interpretabilità meccanicistica può generare meno ipotesi e trasmettere maggiore fiducia. Tuttavia il percorso dai dettagli di basso livello alle spiegazioni di comportamenti complessi è molto più lungo e complesso.
L'interpretabilità permette di raggiungere diversi obiettivi chiave, come ad esempio consentire una migliore supervisione e fornire segnali precoci di comportamenti non sicuri o strategicamente non allineati. Inoltre integra le nostre altre iniziative in materia di sicurezza, come la supervisione scalabile, l'addestramento avversario e la simulazione di attacco.
In questo lavoro, mostriamo che spesso possiamo addestrare i modelli in modi che li rendono più facili da interpretare. Consideriamo il nostro lavoro come una promettente integrazione dell'analisi post-hoc delle reti dense.
Questa è una scommessa molto ambiziosa: per arrivare alla piena comprensione dei comportamenti complessi dei nostri modelli più potenti, la strada è lunga. Tuttavia, per i comportamenti semplici, vediamo che i modelli sparsi addestrati con il nostro metodo contengono piccoli circuiti disaccoppiati che sono sia comprensibili sia sufficienti per eseguire il comportamento. Questo suggerisce che potrebbe esserci un percorso praticabile per l'addestramento di sistemi più grandi di cui possiamo comprendere i meccanismi.
I precedenti lavori sull'interpretabilità meccanicistica sono iniziati da reti dense e intricate, cercando di districarle. In queste reti, ogni singolo neurone è connesso a migliaia di altri neuroni. La maggior parte dei neuroni sembra svolgere molte funzioni distinte, la cui comprensione sembra apparentemente impossibile.
Ma cosa succederebbe se addestrassimo reti neurali non intrecciate, con molti più neuroni, ma dove ogni neurone ha solo poche decine di connessioni? Allora forse la rete risultante sarebbe più semplice e più facile da comprendere. Questa è la scommessa centrale della nostra ricerca.
Tenendo presente questo principio, abbiamo addestrato modelli linguistici con un'architettura molto simile a quella dei modelli linguistici esistenti come GPT‑2, con una piccola modifica: forziamo la stragrande maggioranza dei pesi del modello a essere zero. Questo ha vincolato il modello a utilizzare solo pochissime delle possibili connessioni tra i suoi neuroni. Questa è una semplice modifica che, a nostro avviso, semplifica sostanzialmente i calcoli interni del modello.
Nelle normali reti neurali dense, ogni neurone è connesso a ciascun neurone nello strato successivo. Nei nostri modelli sparsi, ogni neurone si collega solo ad alcuni neuroni nello strato successivo. Speriamo che questo renda i neuroni e la rete nel suo insieme più facili da comprendere.
Desideriamo misurare fino a che punto le computazioni dei nostri modelli sparsi sono disaccoppiate. Abbiamo considerato vari comportamenti semplici del modello e verificato se potessimo isolare le parti del modello responsabili di ciascun comportamento, da noi chiamate circuiti.
Abbiamo selezionato manualmente una serie di semplici attività algoritmiche. Per ciascuna, abbiamo ridotto il modello al circuito più piccolo che può ancora svolgere l'attività ed esaminato quanto sia semplice tale circuito. (Per maggiori dettagli, consulta il nostro articolo(si apre in una nuova finestra).) Abbiamo scoperto che con l'addestramento di modelli più grandi e più sparsi, potevamo produrre modelli sempre più capaci con circuiti sempre più semplici.
Tracciamo l'interpretabilità rispetto alla capacità nei modelli (in basso a sinistra è migliore). Per un modello di dimensioni fisse e sparse, aumentando la sparsità, ossia impostando più pesi a zero, si riduce la capacità ma si aumenta l'interpretabilità. L'aumento delle dimensioni del modello sposta questa frontiera verso l'esterno, suggerendo che possiamo costruire modelli più grandi che siano sia capaci che interpretabili.
Per rendere concreto questo concetto, considera un'attività in cui un modello addestrato sul codice Python deve completare una stringa con il tipo corretto di virgolette. In Python, ‘hello’ deve terminare con un apostrofo singolo, mentre “hello” deve terminare con un apostrofo doppio. Il modello può svolgere il compito ricordando quale tipo di virgolette ha aperto la stringa e riproducendole alla fine.
I nostri modelli più interpretabili sembrano contenere circuiti disaccoppiati che implementano esattamente quell'algoritmo.

Esempio di circuito in un transformer sparso che prevede se terminare una stringa con virgolette singole o doppie. Questo circuito utilizza solo cinque canali residui (linee grigie verticali), due neuroni MLP nello strato 0 e un canale di query-chiave dell'attenzione e un canale di valore nello strato 10. Il modello (1) codifica le virgolette singole in un canale residuo e le virgolette doppie in un altro; il (2) utilizza uno strato MLP per convertirle in un canale che rileva qualsiasi tipo di virgolette e un altro che discerne tra virgolette singole e doppie; il (3) utilizza un'operazione di attenzione per ignorare i token intermedi, trovare la virgoletta precedente e copiarne il tipo nel token finale; e il (4) prevede la virgoletta di chiusura corrispondente.
Nella nostra definizione, le connessioni esatte mostrate sopra sono sufficienti per svolgere l'attività: se rimuoviamo il resto del modello, questo piccolo circuito funziona ancora. Inoltre sono necessarie: eliminare questi pochi archi causa il fallimento del modello.
Abbiamo anche esaminato alcuni comportamenti più complessi. I nostri circuiti per questi comportamenti (ad esempio il binding delle variabili mostrato di seguito) sono più difficili da spiegare completamente. Anche in questo caso, possiamo comunque ottenere spiegazioni parziali relativamente semplici che sono predittive del comportamento del modello.
Un altro circuito di esempio, con meno dettagli. Per determinare il tipo di una variabile chiamata current, un'operazione di attenzione copia il nome della variabile nel token set() quando viene definito, e un'altra operazione avanti copia il tipo dal token set() in un uso successivo della variabile, permettendo al modello di dedurre il token successivo corretto.
Questo lavoro è un primo passo verso un obiettivo più grande: rendere i calcoli del modello più facili da comprendere. Ma la strada da percorrere è ancora lunga. I nostri modelli sparsi sono molto più piccoli dei modelli di frontiera e gran parte del loro calcolo non viene ancora interpretata.
Successivamente, speriamo di scalare le nostre tecniche a modelli più grandi e di spiegare meglio il comportamento dei modelli. Enumerando i motivi circuitali che sottendono ragionamenti più complessi in modelli sparsi capaci, potremmo sviluppare una comprensione che ci aiuti a indirizzare meglio le indagini sui modelli di frontiera.
Per superare l'inefficienza dell'addestramento di modelli sparsi, vediamo due strade da percorrere. Una possibilità è estrarre circuiti sparsi da modelli densi esistenti, piuttosto che effettuare l'addestramento di modelli sparsi da zero. I modelli densi sono fondamentalmente più efficienti da implementare rispetto ai modelli sparsi. L'altra strada è sviluppare tecniche più efficienti per addestrare i modelli per l'interpretabilità, che potrebbero essere più facili da mettere in produzione.
Sottolineiamo che i nostri risultati qui non garantiscono che questo approccio si estenda a sistemi più capaci, ma i primi risultati sono promettenti. Il nostro obiettivo è espandere gradualmente quanto di un modello possiamo interpretare in modo affidabile e costruire strumenti che semplifichino l'analisi, il debug e la valutazione di sistemi futuri.
Autori
Leo Gao, Achyuta Rajaram, Jacob Coxon, Soham V. Govande, Bowen Baker e Dan Mossing


