Pasar al contenido principal
OpenAI

28 de junio de 2022

Publicación

Mitigaciones previas al entrenamiento de DALL·E 2

Vista aérea de una multitud de personas con gorros y banderas que miran hacia otro lado

DALL·E

Cargando...

Para compartir la magia de DALL·E 2 con un público amplio, necesitamos disminuir los riesgos relacionados con los potentes modelos de generación de imágenes. Con este fin, implementamos diversas barreras de seguridad(se abre en una nueva ventana) para evitar que las imágenes generadas infrinjan nuestra política de contenido(se abre en una nueva ventana).

Esta publicación se enfoca en las mitigaciones previas al entrenamiento, un subconjunto de estas barreras de seguridad que modifican directamente los datos desde los cuales aprende DALL·E 2. En especial, DALL·E 2 se entrena con cientos de millones de imágenes con leyenda obtenidas de internet, algunas de las cuales eliminamos y reponderamos para cambiar lo que aprende el modelo.

Esta publicación está organizada en tres secciones, en las que se describen distintas mitigaciones previas al entrenamiento:

  • En la primera sección, describimos cómo filtramos las imágenes violentas y sexuales del conjunto de datos de entrenamiento de DALL·E 2. Sin esta mitigación, el modelo aprendería a producir imágenes gráficas o explícitas si se le solicitan, e incluso podría arrojar esas imágenes sin intención en respuesta a indicaciones que parecen inofensivas.
  • En la segunda sección, descubrimos que filtrar los datos de entrenamiento puede aumentar los sesgos y describimos nuestra técnica para mitigar este efecto. Por ejemplo, sin esta mitigación, observamos que los modelos entrenados con datos filtrados a veces generaban más imágenes que representaban a hombres y menos que representaban a mujeres, en comparación con el conjunto de datos original.
  • En la sección final, nos centramos en el problema de la memorización y vemos que a veces los modelos como DALL·E 2 pueden reproducir imágenes incluidas en su entrenamiento en lugar de crear imágenes nuevas. En la práctica, descubrimos que esta regurgitación de imágenes se debe a que hay imágenes replicadas muchas veces en el conjunto de datos, y mitigamos el problema al eliminar las imágenes que son visualmente similares a otras dentro del conjunto de datos.

Reducción de los datos de entrenamiento gráficos o explícitos

Dado que los datos de entrenamiento afectan las capacidades de cualquier modelo aprendido, filtrar datos es una herramienta eficaz para limitar las capacidades no deseables del modelo. Aplicamos este enfoque a dos categorías, imágenes de violencia gráfica y contenido sexual, mediante clasificadores para filtrar imágenes de estas categorías y eliminarlas del conjunto de datos antes de entrenar a DALL·E 2. Entrenamos estos clasificadores de imágenes de forma interna y seguimos estudiando los efectos de filtrar el conjunto de datos en nuestro modelo entrenado.

Para entrenar a nuestros clasificadores de imágenes, reutilizamos un enfoque que empleamos anteriormente para filtrar datos de entrenamiento de GLIDE(se abre en una nueva ventana). Los pasos básicos de este enfoque son los siguientes: en primer lugar, creamos una especificación para las categorías de imágenes que queremos etiquetar; en segundo lugar, reunimos cientos de ejemplos positivos y negativos para cada categoría; en tercer lugar, usamos un procedimiento de aprendizaje activo para recopilar más datos y mejorar la compensación entre precisión y recuperación; y por último, ejecutamos el clasificador resultante en el conjunto de datos completo con un umbral de clasificación conservador para favorecer la recuperación por sobre la precisión. Para establecer estos umbrales, dimos prioridad a filtrar todos los datos malos por sobre dejar todos los datos buenos. Esto se debe a que más adelante tendremos la oportunidad de hacer un ajuste de precisión de nuestro modelo con más datos para enseñarle cosas nuevas, pero es mucho más difícil lograr que el modelo olvide algo que ya aprendió.

Cargando...

Durante la etapa de aprendizaje activo, mejoramos iterativamente nuestros clasificadores al recopilar etiquetas humanas para imágenes que pueden ser difíciles o estar clasificadas de manera incorrecta. En particular, usamos dos técnicas de aprendizaje activo para elegir imágenes de nuestro conjunto de datos (que contiene cientos de millones de imágenes sin etiquetar) y presentarlas para el etiquetado humano. En primer lugar, para bajar la tasa de falsos positivos de nuestro clasificador (la frecuencia con la que clasifica incorrectamente una imagen buena como violenta o sexual), asignamos etiquetas humanas a imágenes que el modelo actual clasificó como positivas. Para que este paso funcionara bien, ajustamos nuestro umbral de clasificación para una recuperación cercana al 100 %, pero una tasa alta de falsos positivos. De esta manera, los encargados etiquetarían principalmente casos negativos reales. Aunque esta técnica permite reducir los falsos positivos y la necesidad de que los encargados de etiquetar vean imágenes posiblemente dañinas, no ayuda a encontrar casos más positivos que el modelo está omitiendo.

A fin de disminuir la tasa de falsos negativos del clasificador, utilizamos una segunda técnica de aprendizaje activo: la búsqueda del vecino más cercano. Específicamente, ejecutamos una validación cruzada de numerosas iteraciones para encontrar en nuestro actual conjunto de datos etiquetado muestras positivas que el modelo tendía a clasificar incorrectamente como negativas (para lograrlo, literalmente entrenamos cientos de versiones del clasificador con distintas divisiones de entrenamiento y validación). A continuación, escaneamos nuestra gran colección de imágenes no etiquetadas en busca de vecinos más cercanos de estas muestras en un espacio perceptual de características y asignamos etiquetas humanas a las imágenes descubiertas. Gracias a nuestra infraestructura de cálculo, fue fácil aumentar la escala de entrenamiento del clasificador y la búsqueda de vecino más cercano a muchas GPU, lo que permitió que el paso de aprendizaje activo tarde una cantidad de minutos en lugar de horas o días.

Para verificar la eficacia de nuestros filtros de datos, entrenamos dos modelos GLIDE con los mismos hiperparámetros: uno en datos no filtrados y otro en el conjunto de datos después de filtrarlo. Denominamos al primero como modelo no filtrado y al segundo como modelo filtrado. Como cabe esperar, descubrimos que el modelo filtrado generalmente produce menos contenido gráfico o explícito como respuesta a las solicitudes de ese tipo de contenido. Sin embargo, también encontramos un efecto secundario inesperado al filtrar los datos: creaba o amplificaba los sesgos del modelo hacia ciertas clasificaciones demográficas.

Cargando...

Corrección del sesgo introducido por los filtros de datos

Los modelos generativos intentan igualar la distribución de sus datos de entrenamiento, lo que incluye los sesgos que contienen. En consecuencia, filtrar los datos de entrenamiento tiene la capacidad de crear o amplificar los sesgos en los modelos de salida. En general, corregir los sesgos en el conjunto de datos original es una tarea sociotécnica difícil, que seguimos estudiando y que está más allá del alcance de esta publicación. El problema que abordamos aquí es la amplificación de sesgos provocada específicamente por filtrar los datos. Con nuestro enfoque, queremos evitar que el modelo filtrado tenga más sesgos que el modelo no filtrado, principalmente mediante la disminución del cambio de distribución generado al filtrar los datos.

Como ejemplo concreto de amplificación del sesgo debido al filtro, considere la indicación “CEO”. Cuando nuestro modelo no filtrado generó imágenes para esta indicación, tendió a producir más imágenes de hombres que de mujeres; esperamos que gran parte de este sesgo sea un reflejo de nuestros datos de entrenamiento actuales. No obstante, cuando ejecutamos la misma indicación en nuestro modelo filtrado, el sesgo se amplificó, ya que las generaciones eran casi exclusivamente imágenes de hombres.

Nuestra hipótesis es que este caso particular de amplificación se deriva de dos situaciones: primero, aunque hombres y mujeres tienen aproximadamente la misma representación en el conjunto de datos original, este tiende a presentar a las mujeres en contextos más sexualizados; y segundo, nuestros propios clasificadores pueden contener sesgos debido a la implementación o definición de clase, a pesar de nuestros esfuerzos para garantizar que esto no se diera durante las etapas de recopilación y validación de datos. Debido a estos dos efectos, nuestro filtro puede eliminar más imágenes de mujeres que de hombres, lo que cambia la proporción de género que el modelo observa en el entrenamiento.

Para investigar con mayor profundidad el sesgo creado por el filtro, necesitábamos una forma de medir cuánto afectaban nuestros filtros de datos los sesgos hacia diversos conceptos. Nuestros filtros de contenido violento y sexual se basan exclusivamente en imágenes, pero la naturaleza multimodal de nuestro conjunto de datos nos permite medir en forma directa los efectos de estos filtros en el texto. Dado que cada imagen va acompañada por una leyenda, pudimos observar la frecuencia relativa de palabras clave seleccionadas a mano en los conjuntos de datos filtrado y no filtrado, a fin de calcular cuánto afectaban los filtros un concepto determinado.

Para lograrlo, utilizamos Apache Spark para calcular las frecuencias de una serie de palabras clave (como “progenitor”, “mujer”, “niño”) en todas las leyendas de los conjuntos de datos filtrado y no filtrado. Aunque nuestro conjunto de datos contiene cientos de millones de pares de texto-imagen, calcular la frecuencia de estas palabras clave solo tomó unos minutos con nuestro clúster de cálculo.

Después de calcular la frecuencia de las palabras clave, pudimos confirmar que los filtros de nuestro conjunto de datos efectivamente sesgaban las frecuencias de ciertas palabras clave más que otras. Por ejemplo, los filtros disminuyeron la frecuencia de la palabra “mujer” en un 14 %, mientras que la frecuencia de la palabra “hombre” solo bajó un 6 %. Esto confirmó, a gran escala, lo que habíamos observado anecdóticamente mediante el muestreo de los modelos GLIDE entrenados con ambos conjuntos de datos.

Cargando...

Ahora que teníamos un indicador para medir el sesgo producido por el filtro, necesitábamos una forma de mitigarlo. Para enfrentar este problema, nos centramos en reponderar el conjunto de datos filtrado para que su distribución coincidiera mejor con la distribución de imágenes no filtradas. Como ejemplo para ilustrar la idea, supongamos que nuestro conjunto de datos está compuesto por un 50 % de fotos de gatos y un 50 % de fotos de perros, y nuestros filtros de datos eliminan el 75 % de los perros y solo el 50% de los gatos. El conjunto de datos final sería ⅔ de gatos y ⅓ de perros, y un modelo generativo basado en probabilidades entrenado con este conjunto de datos generaría más imágenes de gatos que de perros. Podemos corregir este desequilibrio al multiplicar la pérdida de entrenamiento de las imágenes de perro por 2, lo que simula el efecto de repetir dos veces cada imagen de perro. Podemos escalar este enfoque a nuestros conjuntos de datos y modelos reales de una manera en gran parte automática, es decir, no necesitamos seleccionar a mano las características que queremos reponderar.

Calculamos las ponderaciones de las imágenes en el conjunto de datos filtrado usando probabilidades de un clasificador especial, parecido al enfoque utilizado por Choi et al. (2019)(se abre en una nueva ventana). Para entrenar a este clasificador, sacamos de manera uniforme muestras de imágenes de ambos conjuntos de datos y hacemos una predicción sobre el conjunto de datos de origen de la imagen. Específicamente, este modelo predice P(sin filtro|imagen), dado un P(sin filtro) = 0,5 anterior. En la práctica, no queremos que este modelo sea tan potente, ya que podría aprender la función exacta implementada por nuestros filtros en primer lugar. En lugar de eso, deseamos que el modelo sea más fluido que nuestros filtros de datos originales, que capture categorías amplias afectadas por los filtros y a la vez no esté seguro si se filtrará o no una imagen específica. Para esto, entrenamos una sonda lineal junto con un pequeño modelo CLIP.

Cuando ya tenemos un clasificador que predice la probabilidad de que una imagen provenga del conjunto de datos no filtrado, necesitamos convertir esa predicción en una ponderación para la imagen. Por ejemplo, suponga que P(sin filtro|imagen) = 0,8. Esto significa que hay 4 veces más probabilidades de encontrar la muestra en los datos no filtrados que en los filtrados, y una ponderación de 4 debería corregir el desequilibrio. En forma más general, podemos usar la ponderación P(sin filtro|imagen)/P(sin filtro|imagen).A

¿Cuánto mitiga realmente el sesgo amplificado este esquema de reponderación? Cuando hicimos el ajuste de precisión en el modelo filtrado anterior con el nuevo esquema de ponderación, el comportamiento de dicho modelo igualó mucho más el modelo no filtrado en los ejemplos sesgados que habíamos encontrado antes. Aunque esto era alentador, también deseábamos evaluar con mayor profundidad esta mitigación con nuestra heurística de sesgo basada en palabras clave. Para medir la frecuencia de las palabras clave y a la vez tomar en cuenta nuestro nuevo esquema de ponderación, podemos simplemente ponderar cada instancia de una palabra clave en el conjunto de datos filtrado por la ponderación de la muestra que la contiene. Con esto obtenemos un nuevo grupo de frecuencias de palabras clave que refleja las ponderaciones de la muestra en el conjunto de datos filtrado.

En la mayoría de las palabras clave que revisamos, el esquema de reponderación disminuía el cambio de frecuencia inducido al filtrar. En nuestro ejemplo de “hombre” y “mujer”, la disminución de la frecuencia relativa llegó a 1 % y -1 %, en comparación con los valores anteriores de 14 % y 6 %, respectivamente. Aunque esta métrica es solo un indicador para el sesgo real del filtro, confirma que nuestro esquema de reponderación basado en imágenes en realidad mejora de forma significativa la métrica basada en texto.

Seguimos investigando los sesgos de DALL·E 2, en parte a través de evaluaciones mayores del comportamiento del modelo e investigaciones sobre la manera en que el filtro afecta el sesgo y el desarrollo de capacidades.

Prevención de la regurgitación de imágenes

Observamos que los predecesores internos de DALL·E 2 a veces reproducían exactamente las imágenes de entrenamiento. Este comportamiento no coincidía con nuestro objetivo, ya que queríamos que DALL·E 2 creara imágenes originales y únicas de manera predeterminada en lugar de “pegar” partes de imágenes existentes. Asimismo, reproducir exactamente las imágenes de entrenamiento puede generar cuestionamientos legales en relación con la infracción de derechos de autor, propiedad y privacidad (si hubiera fotos de personas en los datos de entrenamiento).

Para comprender mejor el tema de la regurgitación de imágenes, recopilamos un conjunto de datos de indicaciones que frecuentemente producían imágenes duplicadas. Con este fin, usamos un modelo entrenado para sacar muestras de imágenes para 50 000 indicaciones de nuestro conjunto de datos de entrenamiento y ordenamos las muestras por semejanza perceptual con la imagen de entrenamiento correspondiente. Finalmente, revisamos las coincidencias principales a mano y encontramos solo unos cientos de pares realmente duplicados, de un total de 50 000 indicaciones. Aunque la tasa de regurgitación parecía ser inferior al 1 %, sentimos que era necesario llevar la tasa a 0 por los motivos señalados anteriormente.

Cuando estudiamos nuestro conjunto de datos de imágenes regurgitadas, descubrimos dos patrones. Primero, casi todas las imágenes eran gráficos vectoriales simples, fáciles de memorizar por su bajo contenido de información. Segundo, y lo más importante, todas las imágenes tenían muchos casi duplicados en el conjunto de datos de entrenamiento. Por ejemplo, podría haber un gráfico vectorial que parece un reloj marcando la 1 p. m., pero luego descubríamos una muestra de entrenamiento que tenía el mismo reloj marcando las 2 p. m., las 3 p. m., etc. Cuando nos dimos cuenta de esto, usamos una búsqueda distribuida de vecino más cercano para verificar que, efectivamente, todas las imágenes regurgitadas tenían duplicados de semejanza perceptual en el conjunto de datos. En otros(se abre en una nueva ventana) trabajos(se abre en una nueva ventana) se ha observado un fenómeno parecido en los grandes modelos de lenguaje, donde la duplicación de datos está muy vinculada con la memorización.

El resultado anterior sugería que si deduplicamos nuestro conjunto de datos, podríamos resolver el problema de la regurgitación. Para lograrlo, planeamos usar una red neuronal para identificar grupos de imágenes que se parecen y luego eliminar todas las imágenes de cada grupo, excepto una.B

Sin embargo, esto requeriría revisar, para cada imagen, si era un duplicado de otra imagen del conjunto de datos. Dado que nuestro conjunto de datos contiene cientos de millones de imágenes, necesitaríamos revisar cientos de miles de billones de pares de imágenes para encontrar todos los duplicados. Aunque esto está técnicamente fuera de alcance, especialmente en un clúster de cálculo de gran tamaño, encontramos una alternativa mucho más eficaz, que funciona casi tan bien por una fracción pequeña del costo. Considera qué ocurre si agrupamos nuestro conjunto de datos antes de hacer la deduplicación. Dado que las muestras cercanas generalmente se agrupan juntas, la mayor parte de los pares duplicados no cruzarían los límites de decisión del clúster. A continuación, podríamos deduplicar las muestras dentro de cada clúster sin buscar duplicados fuera de este, con lo que perderíamos solo una pequeña fracción de los pares de duplicados. Esto es mucho más rápido que el enfoque anterior, ya que no tenemos que revisar cada par de imágenes.C

Cuando probamos empíricamente este enfoque en un pequeño subconjunto de datos, encontró el 85 % de los pares duplicados al usar clústeres K=1024. Para mejorar el índice de éxito de este algoritmo, aprovechamos una observación clave: cuando se agrupan distintos subconjuntos al azar de un conjunto de datos, los límites de decisión del clúster a menudo son muy diferentes. Por lo tanto, si un par duplicado cruza un límite del clúster en una agrupación de datos, el par podría quedar dentro de un mismo clúster en una agrupación diferente. Mientras más agrupaciones intente, más probable es que descubra un par duplicado determinado. Usamos cinco agrupaciones, lo que significa que buscamos duplicados de cada imagen en la unión de cinco clústeres distintos. En la práctica, se encontró el 97 % de los pares de duplicados en un subconjunto de nuestros datos.

Sorprendentemente, casi un cuarto de nuestro conjunto de datos fue eliminado en la deduplicación. Al observar los pares de casi duplicados que encontramos, muchos de ellos presentaban cambios significativos. Recordemos el ejemplo anterior del reloj: el conjunto de datos incluiría muchas imágenes del mismo reloj en distintas horas. Aunque es probable que estas imágenes hagan que el modelo memorice la apariencia de ese reloj en particular, también pueden ayudar a que el modelo aprenda a distinguir las horas en un reloj. Dada la cantidad de datos eliminados, nos preocupó que borrar imágenes como esta perjudicara el desempeño del modelo.

A fin de probar el efecto de la deduplicación en nuestros modelos, entrenamos dos modelos con hiperparámetros idénticos: uno en el conjunto de datos completo y otro en la versión deduplicada del conjunto de datos. Comparamos los modelos con las mismas evaluaciones humanas que utilizamos en nuestro modelo GLIDE original. Sorprendentemente, descubrimos que los evaluadores humanos preferían levemente el modelo entrenado con datos deduplicados, lo que sugiere que la gran cantidad de imágenes redundantes del conjunto de datos de hecho afectaba el desempeño.

Una vez que tuvimos un modelo entrenado con datos deduplicados, volvimos a ejecutar la búsqueda de regurgitación que hicimos previamente con 50 000 indicaciones del conjunto de datos de entrenamiento. El nuevo modelo nunca regurgitaba una imagen de entrenamiento cuando se le daba la misma indicación que en el conjunto de datos de entrenamiento. Para avanzar otro paso en esta prueba, también realizamos una búsqueda de vecino más cercano en el conjunto de datos de entrenamiento completo para cada una de las 50 000 imágenes generadas. De esta manera, pensamos que podríamos descubrir que el modelo regurgitaba una imagen distinta a la relacionada con una indicación específica. Aún con esta revisión más profunda, nunca encontramos un caso de regurgitación de imágenes.

Próximos pasos

Aunque todas las mitigaciones analizadas anteriormente representan un avance considerable hacia nuestra meta de disminuir los riesgos relacionados con DALL·E 2, cada mitigación puede mejorar:

  • Mejores filtros previos al entrenamiento nos permitirían entrenar a DALL·E 2 con más datos y tal vez disminuir aún más el sesgo del modelo. Nuestros filtros actuales están ajustados para una tasa baja de pérdida a cambio de muchos falsos positivos. Por ese motivo, filtramos aproximadamente el 5 % del conjunto de datos completo, a pesar de que la mayoría de esas imágenes no infringía para nada nuestra política de contenido. Mejorar nuestros filtros permitiría recuperar algunos de estos datos de entrenamiento.
  • El sesgo es introducido y posiblemente amplificado en muchas etapas de desarrollo e implementación del sistema. Evaluar y mitigar el sesgo y el daño que este provoca en sistemas como DALL·E 2 es un problema interdisciplinario importante que seguimos estudiando en OpenAI como parte de nuestra misión general. Nuestro trabajo en este ámbito incluye crear evaluaciones para comprender mejor el problema, seleccionar nuevos conjuntos de datos y aplicar técnicas como el aporte humano y el ajuste de precisión para desarrollar tecnologías más sólidas y representativas.
  • También es fundamental que sigamos estudiando la memorización y generalización en los sistemas de aprendizaje profundo. Aunque la deduplicación es un buen paso para comenzar a prevenir la memorización, no nos explica detalladamente por qué o cómo los modelos como DALL·E 2 memorizan los datos de entrenamiento.

Notas al pie

  1. Cuando parametrizamos P(sin filtro|imagen) como sigmoide(f(x)), la ponderación es exp(f(x)). Esto se puede derivar usando la definición del sigmoide:

1/(1+ef(x))/(11/(1+ef(x))) 1/(1+e^−f(x))/(1−1/(1+e^−f(x))) =1/(1+ef(x))/((1+ef(x)1)/(1+ef(x))) = 1/(1+e^{-f(x)}) / ((1+e^{-f(x)} - 1)/(1+e^{-f(x)})) =1/(1+ef(x))/((ef(x))/(1+ef(x))) = 1/(1+e^{-f(x)}) / ((e^{-f(x)})/(1+e^{-f(x)})) =(1+ef(x))/(1+ef(x))/(ef(x)) = (1+e^-f(x))/(1+e^-f(x)) / (e^-f(x)) =1/(ef(x))=ef(x) = 1 / (e^{-f(x)}) = e^{f(x)}

  1. B

    Para lograrlo, podemos calcular un vector de característica viv_i para cada imagen de entrenamiento ii, y luego eliminar todas las imágenes jj de manera que exista un i<ji < j donde vivj||v_i - v_j|| <umbral. Para resolver este problema en forma ingenua, necesitaríamos calcular cada distancia por pares vivj||v_i - v_j||, una tarea que aumenta cuadráticamente con el tamaño de nuestro conjunto de datos.

  2. C

    Si K K representa el número de clústeres y N N el tamaño del conjunto de datos, este enfoque solo necesita cálculos de distancia por pares O(K(N/K)2)=O(N2/K) O(K*(N/K)^2) = O(N^2/K) , en lugar de O(N2) O(N^2) completo. Entretanto, aún está garantizado que ninguna imagen tendrá más de K K casi duplicados en el peor de los casos.

Contribuidores

Alex Nichol, Aditya Ramesh, Pamela Mishkin, Prafulla Dariwal, Joanne Jang, Mark Chen

Contribuciones escritas de

Greg Brockman, Aditya Ramesh, Pamela Mishkin, Mark Chen, Pranav Shyam, Casey Chu, Che Chang, Miles Brundage