Trainieren Sie Ihr neuronales Netz zehnmal schneller mit Jax auf einer TPU
Alle coolen Kids scheinen dieser Tage von JAX zu schwärmen. Deepmind nutzt es ausgiebig für seine Forschung und baut sogar sein eigenes Ökosystem darauf auf. Boris Dayma und sein Team haben DALL-E Mini in kürzester Zeit mit JAX und TPUs gebaut. Es lohnt sich, einen Blick auf Hugging Face zu werfen, wo Sie bereits über 5000 in JAX geschriebene Modelle finden. Aber was genau ist JAX und warum ist es so besonders? Laut ihrer Website bietet JAX automatische Differenzierung, Vektorisierung und Just-in-Time-Kompilierung für GPUs und TPUs über komponierbare Transformationen. Klingt kompliziert? Keine Sorge, in diesem Blogpost nehmen wir Sie mit auf eine Tour und zeigen Ihnen, wie JAX funktioniert, wie es sich von Tensorflow/Pytorch unterscheidet und warum wir denken, dass es ein super interessantes Framework ist.
JAX ist ein hochleistungsfähiges Framework für numerische Berechnungen und maschinelles Lernen von Google Research, das superschnell auf GPUs und TPUs läuft, ohne dass man sich um Low-Level-Details kümmern muss. Das Ziel von JAX war es, ein Framework zu entwickeln, das Hochleistung mit der Ausdruckskraft und Benutzerfreundlichkeit von Python kombiniert, damit Forscher mit neuen Modellen und Techniken experimentieren können, ohne sich um hochoptimierte Low-Level-C/C++-Implementierungen kümmern zu müssen. Dieses Ziel wird durch die Verwendung des XLA-Compilers (Accelerated Linear Algebra) von Google erreicht, der effizienten Maschinencode erzeugt, anstatt vorkompilierte Kernel zu verwenden. Das Tolle an JAX ist, dass es beschleunigerunabhängig ist, was bedeutet, dass derselbe Python-Code sowohl auf GPUs als auch auf TPUs effizient ausgeführt werden kann.
JAX arbeitet mit zusammensetzbaren Funktionstransformationen, d. h. JAX nimmt eine Funktion und erzeugt eine neue Funktion, die anders interpretiert wird, und es können mehrere Transformationen aneinandergereiht werden. Die automatische Differenzierung ist beispielsweise eine Transformation, die die Ableitung einer Funktion erzeugt, während die automatische Vektorisierung eine Funktion, die auf einen einzelnen Datenpunkt wirkt, in eine Funktion umwandelt, die auf einen Stapel von Datenpunkten wirkt. Durch diese Transformationen ermöglicht JAX dem Programmierer, in der High-Level-Python-Welt zu bleiben und den Compiler die harte Arbeit machen zu lassen, indem er den hocheffizienten Code generiert, der zum Trainieren komplexer Modelle benötigt wird. Wir werden diese Transformationen durchgehen und sie in einem Beispiel anwenden, in dem wir ein einfaches mehrschichtiges Perzeptron erstellen.
JAX ist ein compilerorientiertes Framework, was bedeutet, dass ein Compiler für die Umwandlung der Python-Funktionen in effizienten Maschinencode verantwortlich ist. Tensorflow und Pytorch hingegen haben vorkompilierte GPU- und TPU-Kernel für jede Operation. Während der Ausführung eines TensorFlow-Programms wird jede Operation einzeln abgearbeitet. Während die Operationen selbst sehr gut optimiert sind, erfordert ihre Zusammenführung eine Menge Speicheroperationen, was zu einem Engpass in der Leistung führt. Der XLA-Compiler kann Code für die gesamte Funktion erzeugen. Er kann all diese Informationen nutzen, um Operationen zu verschmelzen und eine Menge Speichervorgänge einzusparen und so insgesamt schnelleren Code zu erzeugen.
JAX ist auch leichtgewichtiger als Tensorflow und Pytorch, da nicht jede Operation, Funktion oder jedes Modell separat implementiert werden muss. Stattdessen implementiert JAX die NumPy-API mit einfacheren und Low-Level-Operationen, die als Bausteine verwendet und vom Compiler zu komplexen Modellen und Funktionen zusammengefügt werden können.
Das compilerorientierte Design ist weitaus leistungsfähiger, als man auf den ersten Blick vermuten könnte. Mit dem Compiler besteht keine Notwendigkeit mehr, Low-Level-Beschleunigercode zu implementieren. Dadurch können Forscher ihre Produktivität erheblich steigern und mit neuen Modellarchitekturen experimentieren. Forscher können sogar mit GPUs und TPUs experimentieren, ohne dass sie ihren Code neu schreiben müssen. Aber wie funktioniert das?
JAX kompiliert nicht direkt in Maschinencode, sondern in eine Zwischendarstellung, die vom High-Level-Python-Code und dem Maschinencode unabhängig ist. Der Compiler ist in ein Frontend, das Python-Funktionen in die IR kompiliert, und ein Backend, das die IR in plattformspezifischen Maschinencode kompiliert, aufgeteilt. Dieses Design ist nicht neu, ein Beispiel für einen Compiler, der ebenfalls diesem Design folgt, ist LLVM. Es gibt Frontends sowohl für C als auch für Rust, die High-Level-Code in die LLVM-IR übersetzen. Das Backend kann dann Maschinencode für eine Vielzahl von unterstützten Maschinentypen erzeugen, unabhängig davon, ob der ursprüngliche Code in C oder Rust geschrieben wurde.
Das ist von großer Bedeutung, denn dank dieses flexiblen Designs kann man einen neuen Beschleuniger bauen, ein XLA-Backend dafür schreiben und den JAX-Code, der zuvor auf GPUs/TPUs lief, auf dem neuen Beschleuniger ausführen. Andererseits könnte man auch ein Framework in einer anderen Programmiersprache erstellen, das mit der JAX-IR kompiliert wird und dank XLA GPUs und TPUs nutzen kann.
Wenn dieser compilerbasierte Ansatz so viel besser funktioniert als vorkompilierte Kernel, warum haben Tensorflow und Pytorch nicht von Anfang an davon Gebrauch gemacht? Die Antwort ist ziemlich einfach: Es ist wirklich schwer, einen guten numerischen Compiler zu entwickeln. Mit seiner automatischen Differenzierung, Vektorisierung und Jit-Kompilierung hat JAX einige wirklich mächtige Werkzeuge in petto. Allerdings ist JAX auch nicht die Wunderwaffe, denn all diese Vorzüge haben ihren Preis: Sie müssen ein paar neue Tricks und Konzepte der funktionalen Programmierung lernen.
JAX kann nicht jede beliebige Python-Funktion transformieren, sondern nur reine Funktionen. Eine reine Funktion kann als eine Funktion definiert werden, die nur von ihren Eingaben abhängt, was bedeutet, dass sie für eine gegebene Eingabe x immer dieselbe Ausgabe y zurückgibt und dass sie keine Nebeneffekte wie IO-Operationen oder Mutation globaler Variablen erzeugt. Die Dynamik von Python bedeutet, dass sich das Verhalten einer Funktion in Abhängigkeit von den Typen ihrer Eingaben ändert, und JAX will sich diese Dynamik zunutze machen, indem es Funktionen zur Laufzeit transformiert. Zu Beginn einer Transformation prüft JAX, wie sich die Funktion bei einer Reihe von Eingaben verhält, und transformiert die Funktion auf der Grundlage dieser Informationen. Unter der Haube verfolgt JAX die Funktion, genau wie der Python-Interpreter. Indem nur reine Funktionen zugelassen werden, wird die Umwandlung von Funktionen zur rechten Zeit viel einfacher und schneller.
Stellen Sie sich vor, dass der Tracer mit Seiteneffekten wie IO zu tun hat. Das bedeutet, dass unerwartetes Verhalten auftreten kann, z. B. wenn ein Benutzer ungültige Daten eingibt, was die Erstellung von effizientem Code erheblich erschwert, insbesondere wenn Beschleuniger im Spiel sind. Globale Variablen können sich zwischen zwei Funktionsaufrufen ändern und so das Verhalten der Funktion, in der sie verwendet werden, komplett verändern, wodurch eine transformierte Funktion ungültig wird. Wenn Sie sich für Compiler und die Feinheiten der JAX-Verfolgung interessieren, empfehlen wir Ihnen einen Blick in die Dokumentation, um mehr über die Funktionsweise zu erfahren.
Der einzige Nachteil von JAX ist, dass es nicht überprüfen kann, ob eine Funktion eine reine Funktion ist. Es obliegt dem Programmierer, dafür zu sorgen, dass er reine Funktionen schreibt, da JAX sonst die Funktion mit einem unerwarteten Verhalten umwandelt.
Die Arbeit mit reinen Funktionen wirkt sich auch darauf aus, wie Datenstrukturen verwendet werden. In anderen Frameworks werden Modelle des maschinellen Lernens oft zustandsorientiert dargestellt, was jedoch mit dem Paradigma der funktionalen Programmierung kollidiert, da es sich um die Mutation eines globalen Zustands handelt. Um dieses Problem zu überwinden, führt JAX Pytrees ein, baumartige Strukturen, die aus containerartigen Python-Objekten aufgebaut sind. ContainerähnlicheKlassen können in der pytree-Registry registriert werden, die standardmäßig Listen, Tupel und Dicts enthält. Pytrees können andere Pytrees enthalten und Klassen, die nicht in der pytree-Registry registriert sind, werden als Leafs bezeichnet. Leafs können als unveränderliche Eingaben für eine reine Funktion betrachtet werden. Für jede Klasse in der pytree-Registrierung gibt es eine Funktion, die einen pytree in ein Tupel mit seinen Kindern und optionalen Metadaten konvertiert, sowie eine Funktion, die Kinder und Metadaten zurück in einen containerähnlichen Typ konvertiert. Diese Funktionen können verwendet werden, um das Modell oder andere zustandsabhängige Objekte zu aktualisieren, die Sie verwenden.
Bevor wir in unser MLP-Beispiel eintauchen, zeigen wir die wichtigsten Transformationen in JAX.
Die erste Transformation ist die automatische Differenzierung, bei der wir eine Python-Funktion als Eingabe nehmen und eine Funktion zurückgeben, die den Gradienten dieser Funktion darstellt. Das Tolle an JAXs Autodiff ist, dass es Python-Funktionen differenzieren kann, die Python-Container, Konditionale, Schleifen usw. verwenden. Im folgenden Beispiel erstellen wir eine Funktion, die den Gradienten der Funktion "tanh" darstellt. Da JAX-Transformationen zusammensetzbar sind, können wir n verschachtelte Aufrufe der Funktion grad verwenden, um die n-teAbleitung zu berechnen.
Die automatische Differenzierung von JAX ist ein leistungsfähiges und umfangreiches Werkzeug. Wenn Sie mehr darüber erfahren möchten, wie es funktioniert, empfehlen wir Ihnen die Lektüre des JAX Autodiff Cookbook.
Beim Trainieren eines Modells wird in der Regel ein Stapel von Trainingsstichproben durch das Modell geleitet. Bei der Implementierung eines Modells müssen Sie sich also Ihre Vorhersagefunktion als eine Funktion vorstellen, die einen Stapel von Stichproben aufnimmt und eine Vorhersage für jede Stichprobe zurückgibt. Dies kann jedoch den Schwierigkeitsgrad der Implementierung erheblich erhöhen und die Lesbarkeit der Funktion im Vergleich zu einer Funktion, die mit einer einzelnen Stichprobe arbeitet, verringern. Nun kommt die zweite Transformation: die automatische Vektorisierung. Wir schreiben unsere Funktion so, als ob wir nur eine einzelne Probe verarbeiten würden, und vmap wandelt sie dann in eine vektorisierte Version um.
Am Anfang kann vmap etwas schwierig sein, besonders wenn man mit höheren Dimensionen arbeitet, aber es ist eine wirklich leistungsfähige Transformation. Wir empfehlen Ihnen, sich einige Beispiele in der Dokumentation anzuschauen, um ihr Potenzial vollständig zu verstehen.
Die dritte Funktionstransformation ist die Just-in-Time-Kompilierung. Das Ziel dieser Umwandlung ist es, die Leistung zu verbessern, den Code zu parallelisieren und ihn auf einem Beschleuniger auszuführen. JAX kompiliert nicht direkt in Maschinencode, sondern in eine Zwischendarstellung. Diese Zwischendarstellung ist unabhängig vom Python-Code und dem Maschinencode des Beschleunigers. Der XLA-Compiler kompiliert dann die Zwischendarstellung in effizienten Maschinencode.
Es ist nicht immer einfach zu entscheiden, wann und welchen Code Sie kompilieren sollten. Um den Compiler optimal zu nutzen, empfehlen wir Ihnen, sich die Dokumentation anzusehen. Später in diesem Blog werden wir etwas tiefer in das Design des Compilers einsteigen und erklären, warum JAX ein so leistungsfähiges Framework ist.
Nachdem wir nun die wichtigsten Transformationen kennengelernt haben, sind wir bereit, dieses Wissen in die Praxis umzusetzen. Wir werden ein MLP von Grund auf implementieren, um MNIST-Bilder zu klassifizieren und es superschnell auf einer TPU zu trainieren. Unser neuronales Netzwerk wird eine Eingabeschicht mit 728 Eingabevariablen haben, gefolgt von zwei versteckten Schichten mit 512 bzw. 256 Neuronen und einer Ausgabeschicht mit einem Knoten für jede Klasse.
Als erstes müssen wir eine Struktur erstellen, die unser Modell darstellt. Als Eingabe für unsere Initialisierungsfunktion haben wir eine Liste mit der Anzahl der Knoten in jeder Schicht unseres neuronalen Netzes. Wir haben eine Eingabeschicht, die der Anzahl der Pixel eines Bildes entspricht, gefolgt von zwei versteckten Schichten mit 512 bzw. 256 Neuronen und einer Ausgabeschicht, die der Anzahl der Klassen entspricht. Wir verwenden JAX Numpy-Arrays, um das Modell auf dem Beschleuniger zu initialisieren und vermeiden es, die Daten manuell zu kopieren.
Beachten Sie, dass die Erzeugung von Zufallszahlen etwas anders ist als bei Numpy. Wir wollen in der Lage sein, Zufallszahlen auf parallelen Beschleunigern zu erzeugen, und wir brauchen einen Zufallszahlengenerator, der gut mit dem Paradigma der funktionalen Programmierung funktioniert. Der Algorithmus von Numpy zur Erzeugung von Zufallszahlen ist für diese Zwecke nicht besonders geeignet. Weitere Informationen finden Sie in den JAX Design Notes und der Dokumentation.
Unser nächster Schritt besteht darin, eine Vorhersagefunktion zu schreiben, die einem Stapel von Bildern Kennzeichnungen zuweist. Wir werden die automatische Vektorisierung nutzen, um eine Funktion, die ein einzelnes Bild als Eingabe erhält und eine Bezeichnung ausgibt, in eine Funktion umzuwandeln, die Bezeichnungen für einen Stapel von Eingaben vorhersagt. Das Schreiben einer Vorhersagefunktion ist nicht besonders schwierig. Wir durchlaufen die versteckten Schichten des Netzes und wenden Gewichte und Vorspannungen über eine Matrixmultiplikation und Vektoraddition an und wenden die RELU-Aktivierungsfunktion an. Am Ende berechnen wir mit der RealSoftMax-Funktion die Ausgabebezeichnung. Sobald wir unsere Funktion zur Beschriftung eines einzelnen Bildes haben, können wir sie mit der Funktion vmap umwandeln, damit sie einen Stapel von Eingaben verarbeiten kann.
Die Verlustfunktion nimmt einen Stapel von Bildern und berechnet den mittleren absoluten Fehler. Wir rufen unsere gestapelten Vorhersagen auf und berechnen die Kennzeichnung für jedes Bild, vergleichen diese mit den kodierten Kennzeichnungen der Grundwahrheit und berechnen die mittlere Anzahl der Fehler.
Da wir nun unsere Vorhersage- und Verlustfunktion haben, werden wir eine Aktualisierungsfunktion implementieren, um unser Modell in jedem Trainingsschritt iterativ zu aktualisieren. Unsere Aktualisierungsfunktion nimmt einen Stapel von Bildern und ihre Grundwahrheitsbezeichnungen zusammen mit dem aktuellen Modell und einer Lernrate auf. Wir berechnen sowohl den Wert des Verlustes als auch den Wert seines Gradienten. Wir aktualisieren das Modell unter Verwendung der Lernrate und der Verlustgradienten. Da wir diese Funktion kompilieren wollen, müssen wir das aktualisierte Modell in einen pytree umwandeln. Wir geben auch den Wert des Verlustes zurück, um die Genauigkeit zu überwachen.
Jetzt, da wir die Aktualisierungsfunktion haben, werden wir sie kompilieren, damit sie auf einer TPU ausgeführt werden kann und ihre Leistung erheblich verbessert wird. Die verschachtelten Funktionen, die in der Aktualisierungsfunktion aufgerufen werden, werden ebenfalls kompiliert und optimiert. Der Grund, warum wir die Kompilierungstransformation nur auf die Aktualisierungsfunktion und nicht auf jede einzelne Funktion anwenden, ist, dass wir dem Compiler möglichst viele Informationen zur Verfügung stellen wollen, damit er den Code so weit wie möglich optimieren kann.
Wir können eine Genauigkeitsfunktion (und optional andere Metriken) definieren und eine Trainingsschleife mit unserer Aktualisierungsfunktion und dem Ausgangsmodell als Eingabe erstellen. Wir sind nun bereit, unser Modell mit einer TPU oder GPU zu trainieren.
Puh, wir haben heute eine Menge gelernt. Zuerst haben wir damit begonnen, JAX als ein Rahmenwerk mit zusammensetzbaren Funktionstransformationen zu beschreiben. Die vier Kerntransformationen sind automatische Vektorisierung, automatische Parallelisierung über mehrere Beschleuniger, automatische Differenzierung von Python-Funktionen und JIT-Kompilierung von Funktionen, um sie auf Beschleunigern auszuführen. Wir vertieften uns in das Innenleben von JAX und erfuhren, wie es in der Lage ist, so effiziente Funktionen zu erstellen, die sowohl auf GPUs als auch auf TPUs funktionieren, indem sie in eine IR kompiliert werden, die dann in XLA-Aufrufe umgewandelt wird. Dieser Ansatz ermöglicht es Forschern, mit neuen Techniken des maschinellen Lernens zu experimentieren, ohne sich um eine hoch optimierte Low-Level-Version ihres Codes kümmern zu müssen. Wir hoffen, dass auch Software-Ingenieure begeistert sind, so dass neue Bibliotheken auf der Grundlage von JAX entwickelt und potenzielle Beschleuniger schnell eingeführt werden können.