TensorFlow alapozó 7.

TensorFlow.js, avagy neurális hálók futtatása és tanítása böngészőben és Node.js-en

TensorFlow alapozó 7.

TensorFlow.js, avagy neurális hálók futtatása és tanítása böngészőben és Node.js-en

Forrás: https://towardsdatascience.com/online-machine-learning-with-tensorflow-js-2ae232352901

A TensorFlow.js a TensorFlow JavaScript portja aminek segítségével böngészőben, vagy Node.js-en futtathatjuk és taníthatjuk a neurális hálóinkat. Habár a lehetőségek jóval korlátozottabbak mint a Python implementáció esetén, sok esetben jól jöhet a JavaScript megvalósítás. Ilyen lehet az, ha egy meglévő Node.js backendbe szeretnénk neurális hálókat csempészni vagy kliens oldalon futtatni egy objektum felismerő rendszert vagy bármi hasonlót. A teljesítményre sem lehet igazán panasz, hiszen a böngészős megvalósítás WebGL-t használ a GPU elérésére, a Node.js implementáció alatt pedig ugyanaz a natív réteg van, amit a Python implementáció is használ, így mindkét platformból kihozza a lehető legtöbbet. Ebben a részben a cikksorozat első részében megismert konvolúciós hálót fogjuk JavaScript-re portolni, majd böngészőben futtatva betanítani és használni. A kódok természetesen elérhetőek GitHub-on, sőt, mivel JavaScriptről van szó, innen bárki futtathatja is őket a böngészőjében.

Elsőként vessünk egy pillantást a HTML fájlra. A jsDelivr CDN szerverről behúzzuk a TensorFlow.js valamint a tfjs-vis programkönyvtárakat. Az utóbbi egy hasznos vizualizációs tool TensorFlow-hoz amivel logolhatjuk a történéseket és mindenféle szép grafikonokat rajzolhatunk a tanításról. Ezek azok a dolgok amiket Python esetén a matplotlibbel rajzolgattunk. A következő két sor behúzza a data.js-t, ami a CIFAR-10-es adatokat tölti be és teszi elérhetővé tenzorok formájában (lásd első rész), valamint a script_train.js-t, ami a tényleges kódot tartalmazza.

A script_train.js első pár sorában behúzzuk a CIFAR-10 adathalmazt, majd a tfvis segítségével megjelenítjük az első 10 képet. A következő rész a model létrehozása. Mivel a model pöccre ugyanaz, mint az első részben, így nincs értelme róla sokat beszélni, hisz ott elmondtam már róla mindent amit kell. Ugyancsak ismerős lehet a compile és a fit metódus, amivel a tanítást végezzük. Ami egyedül újdonság, az talán csak a fitCallbacks rész. Ez arra szolgál, hogy a tanítás alatt minden batch feldolgozása után meghívja a rendszer a tfvis-t, ami egy grafikonra rajzolja az aktuális állapotot, így szépen nyomonkövethető a tanítás. Aki gondolja, ki is próbálhatja a fenti link segítségével. Kell hozzá egy kis türelem, mert a CIFAR-10 minták betöltése kb. fél-egy percig tart, amíg nem történik semmi. Ha ezen átküszködte magát a böngésző, utána már látjuk a grafikont meg a kék vonalat ami szépen kúszik felfelé az idő teltével. Maga a tanítás is eltart némi ideig. Géptől függően sok 10 perc is lehet mire kidobja a végén a modelt. Ha valakinek nincs kedve ezt végigvárni, az használhatja a repo-ban lévő modellt, amit már előre betanítottam. A másik lehetőség, hogy Python-ban tanítjuk be a modellt (ez sokkal gyorsabban megvan), majd átkonvertáljuk Tensorflow.js által emészthető formára. Ez utóbbihoz telepítsük a tensorflowjs_converter-t:

pip install tensorflowjs

Majd futtassuk le a Python-ból mentett h5 modellünkre:

tensorflowjs_converter — input_format keras ./my_model.h5 ./tfjs_model

Az eredmény egy JSON és egy bin fájl lesz a megadott könyvtárban. Előbbi a modell leírását, míg utóbbi a súlyokat tartalmazza. Ezt lehetőség szerint érdemes így csinálni. Tehát szerver oldalon, erős GPU-val (vagy akár TPU-val), Pythonban betanítani a modellt, majd konvertálni Tensorflow.js formára a böngészőben való használathoz.

Most hogy megvan a modellünk, ideje hogy használjuk. A következő egyszerű kód betölti a modellt, majd véletlenszerűen előszed egy képet a teszthalmazból és megpróbálja megmondani, hogy mi van rajta.

A model betöltésére a tf.loadLayersModel szolgál, ami paraméterként a JSON fájl URL-jét várja. A következő pár sor a kirajzolja a véletlenszerűen előkapott képet, majd a jól megszokott predict függvénnyel futtatjuk a hálózatot. Ahogyan az első részből már tudhatjuk, a predict függvény egy 10 elemű vektort fog visszaadni, ami azt mutatja meg, hogy a hálózat szerint a kép melyik kategóriába mennyire tartozik bele. Azt, ahol ez az érték a legnagyobb, zölddel jelöljük meg. Ez a hálózat “tippje”.

Nagyon dióhéjban ennyit szerettem volna mondani a TensorFlow.js-ről. Ha valaki JavaScript fejlesztőként szeretne elkezdeni neurális hálókkal foglalkozni, annak mindenképp jó kiindulópont, de komolyabb dolgokhoz már mindenképp megéri megtanulni Pythonban fejelszteni, mert egyfelől sokkal több a lehetőség, másfelől minden példakód Pythonban érhető el. Persze egy-egy adott problémára elegendő lehet a JavaScript megvalósítás Node.js-el, illetve ha a cél neurális hálók böngészőben való futtatása, akkor nyilván a Tensorflow.js a(z egyik) megoldás.

Ha tetszett az írás, itt megtalálhatod az előző részeket:

A következő rész pedig itt érhető el: