JAX برای یادگیری ماشین: چگونه کار میکند و چرا باید آن را یاد بگیریم
JAX در دنیای هوش مصنوعی و دیپ لرنینگ فریم ورک خیلی جدیدی است که وعده میدهد که برنامهنویسی ML را بصریتر، ساختارمندتر و تمیزتر کند. این فریمورک میتواند جایگزین ابزارهایی مانند Tensorflow و PyTorch شود، اگرچه که در هسته، بسیار با آنها متفاوت است. این مقاله، به بررسی JAX و دلیلی که باید از آن به جای کتابخانههای دیگر استفاده کرد خوهد پرداخت.
JAX چیست؟
JAX یک کتابخانه پایتون است که برای تحقیقات یادگیریماشین با کارایی بالا طراحی شده است. به عبارت ساده چیزی جز یک کتابخانه محاسباتی عددی نظیر Numpy نیست. اما با برخی ویژگیهای کلیدی پیشرفته. این ابزار توسط گوگل توسعه داده شده و به صورت داخلی توسط تیمهای گوگل و Deep-mind استفاده میشود.
نصب JAX
قبل از اینکه در مورد مزایای اصلی JAX صحبت کنیم، به شما پیشنهاد میکنم JAX را در محیط پایتون یا در Google colab نصب کنید تا بتوانید خودتان کدها را دنبال کرده و اجرا کنید. البته لینک کد کامل را در انتهای مطلب میگذارم.
برای نصب JAX، به سادگی میتوانیم در خط فرمان (command line) از pip استفاده کنیم:
$ pip install --upgrade jax jaxlib
توجه داشته باشید که این نسخه فقط از اجرا در CPU پشتیبانی میکند. اگر میخواهید از GPU نیز پشتیبانی کنید، ابتدا به CUDA و cuDNN نیاز دارید و سپس دستور زیر را اجرا کنید (حتماً نسخه jaxlib را با نسخه CUDA خود منطبق کنید):
$ pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
برای عیبیابی، دستورالعملهای رسمی Github را بررسی کنید.
حالا بیایید JAX را در کنار Numpy استفاده کنیم. ما از Numpy برای مقایسهی موارد استفادهی مختلف بهره خواهیم برد.
import jax import jax.numpy as jnp import numpy as np
مبانی JAX
بیایید با اصول اولیه شروع کنیم. همانطور که قبلاً گفتیم، تنها هدف اصلی JAX انجام عملیات عددی به روشی قابل بیان و با کارایی بالا است. این بدان معنی است که سینتکس آن تقریباً مشابه Numpy است. به عنوان مثال، اگر بخواهیم یک آرایه از صفرها ایجاد کنیم، خواهیم داشت:
x = np.zeros(10) y= jnp.zeros(10)
تفاوتها در پشت صحنه قرار دارند.
DeviceArray
یکی از مزیتهای اصلی JAX این است که میتوانیم همان برنامه را بدون هیچ تغییری در شتاب دهندههای سخت افزاری مانند GPU و TPU اجرا کنیم.
این کار توسط یک ساختار زیربنایی به نام DeviceArray انجام میشود که اساساً جایگزین آرایه استاندارد Numpy میشود.
DeviceArrayها تنبل یا lazy هستند، به این معنی که مقادیر را در شتابدهنده نگه میدارند و فقط در صورت نیاز آنها را pull میکنند.
x # array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) y # DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
همانطور که از آرایههای استاندارد استفاده میکنیم، میتوانیم از DeviceArrays هم استفاده کنیم و آن را به کتابخانههای دیگر هم ارسال کنیم، نمودارهایی را بر اساس آنها ترسیم کنیم، مشتق بگیریم و کارهایی از این دست. همچنین توجه داشته باشید که اکثر APIهای Numpy (توابع و عملیاتها) توسط JAX پشتیبانی میشوند، بنابراین کد JAX شما تقریباً مشابه Numpy خواهد بود.
نکته مهم دیگر سرعت است. JAX سریعتر است؛ خیلی سریعتر. بیایید به یک مثال ساده نگاه کنیم. دو آرایه با سایز (۱۰۰۰، ۱۰۰۰) یکی با Numpy و دیگری با JAX ایجاد میکنیم و حاصلضرب داخلی با خودش را محاسبه میکنیم.
این دو عملیات را با استفاده از دستور timeit از نظر زمانی مقایسه میکنیم:
x = np.random.rand(1000,1000) y = jnp.array(x) %timeit -n 1 -r 1 np.dot(x,x) # 1 loop, best of 1: 52.6 ms per loop %timeit -n 1 -r 1 jnp.dot(y,y).block_until_ready() # 1 loop, best of 1: 1.47 ms per loop
تاثیرگذار است. درست است؟ انتظار هم همین است. محاسبات در GPU سریعتر است. همچنین به تابع block_until_ready() توجه کنید. از آنجایی که JAX به صورت asynchronous کار میکند، باید منتظر بمانیم تا اجرا کامل شود تا زمان را به درستی اندازهگیری کنیم.
شما که نمیتوانید باور کنید که این تمام چیزی است که JAX برای ارائه دارد؟
اکنون به قسمتهای جذابتر میرسیم!
چرا JAX؟
اگر سرعت و پشتیبانی خودکار از پردازندههای گرافیکی برای شما کافی نیست، شما را سرزنش نمیکنم. به نظر میرسد که هر کتابخانه دیگری میتواند این موارد را مدیریت کند. برای درک بیشتر مزایای JAX، باید عمیقتر غواصی کنیم. JAX را میتوان به عنوان مجموعهای از تبدیلهای توابع پایتون و Numpy معمولی دید.
نمونهای از این تحولات مشتقگیری است. آیا JAX از مشتقگیری خودکار پشتیبانی میکند؟
مطمئناً درست حدس زدید.
مشتقگیری خودکار با استفاده از تابع grad
JAX قادر است از انواع توابع پایتون و NumPy از جمله حلقهها، شاخهها، بازگشتها و موارد دیگر مشتقگیری کند.
این برای برنامههای Deep Learning بسیار مفید است، زیرا میتوانیم back propagation را تقریباً بدون زحمت اجرا کنیم. تابع اصلی برای انجام این کار grad() نام دارد. به عنوان مثال، ما یک تابع درجه دوم ساده تعریف میکنیم و مشتق آن را در نقطه 1.0 محاسبه میکنیم.
برای اینکه ثابت کنیم نتیجه درست است، مشتق را به صورت دستی هم محاسبه میکنیم.
from jax import grad def f(x): return 3*x**2 + 2*x + 5 def f_prime(x): return 6*x +2 grad(f)(1.0) # DeviceArray(8., dtype=float32) f_prime(1.0) # 8.0
یک نکتهی بسیار شگفتانگیز برای من این بود که JAX در واقع به جای استفاده از تکنیکهای فانتزی دیگر، حل تحلیلی گرادیان را در لایههای زیرین انجام میدهد. به سادگی شکل تابع را میگیرد و قانون زنجیره را اجرا میکند. از آنجایی که مشتقگیری خودکار بسیار پیچیدهتر از این است، برای درک کاملتر به شدت توصیه میکنم به راهنمای رسمی نگاه کنید.
جبر خطی تسریعشده (کامپایلر XLA)
یکی از عواملی که JAX را بسیار سریع میکند، شتابدهندهی جبر خطی یا XLA است.
XLA یک کامپایلر مخصوص برای جبر خطی است که به طور گسترده توسط TensorFlow استفاده شده است.
به منظور انجام هر چه سریعتر عملیات ماتریسها، کد در مجموعهای از هستههای محاسباتی کامپایل میشود که میتوانند به طور گسترده بر اساس ماهیت کد بهینهسازی شوند.
نمونه ای از این بهینهسازیها عبارتند از:
- ترکیب عملیات: نتایج میانی در حافظه ذخیره نمیشوند
- طرحبندی بهینهشده: بهینهسازی “شکل” یک آرایه در حافظه نمایش داده شده است
کامپایل در لحظه/just in time compilation یا jit
کامپایل در لحظه، دست در دست XLA وارد میشود. برای استفاده از قدرت XLA، کد باید در هستههای XLA کامپایل شود. اینجاست که jit وارد عمل میشود.
jit روشی برای اجرای کدهای کامپیوتری است که شامل کامپایل در طول اجرای یک برنامه (در زمان اجرا) به جای قبل از اجرا است.
برای استفاده از XLA و jit میتوان از تابع jit() یا حاشیه نویسی @jit استفاده کرد.
from jax import jit x = np.random.rand(1000,1000) y = jnp.array(x) def f(x): for _ in range(10): x = 0.5*x + 0.1* jnp.sin(x) return x g = jit(f) %timeit -n 5 -r 5 f(y).block_until_ready() # 5 loops, best of 5: 10.8 ms per loop %timeit -n 5 -r 5 g(y).block_until_ready() # 5 loops, best of 5: 341 µs per loop
یک بار دیگر بهبود در زمان اجرا آشکار میشود. jit را میتوان با تابع grad (یا هر تبدیل دیگری) نیز ترکیب کرد که باعث میشود back propagation بسیار سریع باشد.
البته، توجه داشته باشید که jit دارای کاستیهایی است: به عنوان مثال، اگر نتواند عملکرد را به طور دقیق نشان دهد (که معمولاً با شاخههای “if” اتفاق میافتد)، احتمالاً از کار خواهد افتاد. با این حال، برای اکثر موارد استفاده مربوط به یادگیری عمیق، فوقالعاده مفید است.
محاسبات را در دستگاههای مختلف با pmap تکرار کنید
pmap تبدیل دیگری است که ما را قادر میسازد محاسبات را در چندین هسته یا دستگاه تکرار کنیم و آنها را به صورت موازی اجرا کنیم (p در pmap مخفف موازی است).
این تبدیل به طور خودکار محاسبات را در تمام دستگاههای فعلی توزیع میکند و تمام ارتباطات بین آنها را مدیریت میکند. برای بررسی دستگاههای موجود، میتوانید jax.devices() را اجرا کنید.
from jax import pmap def f(x): return jnp.sin(x) + x**2 f(np.arange(4)) #DeviceArray([0. , 1.841471 , 4.9092975, 9.14112 ], dtype=float32) pmap(f)(np.arange(4)) #ShardedDeviceArray([0. , 1.841471 , 4.9092975, 9.14112 ], dtype=float32)
توجه داشته باشید که DeviceArray اکنون به ShardedDeviceArray تبدیل شده است، که ساختاری است که اجرای موازی را مدیریت میکند.
یکی دیگر از کارهای جالبی که JAX به ما اجازه میدهد انجام دهیم، ارتباط جمعی بین دستگاهها است. فرض کنید که میخواهیم یک عملیات «کاهش» یا reduce بین مقادیر در همه دستگاهها انجام دهیم (مثلاً جمع ببندیم). برای انجام این کار، باید تمام دادهها را از همه دستگاهها جمعآوری کنیم و مجموع را محاسبه کنیم. این کار به راحتی به شرح زیر قابل انجام است:
from functools import partial from jax.lax import psum @partial(pmap, axis_name="i") def normalize(x): return x/ psum(x,'i') normalize(np.arange(8.))
کد بالا بردار x را در تمام دستگاهها ترسیم میکند و یک عملیات ارتباط جمعی را برای اجرای psum (مجموع موازی) اجرا میکند. به عبارت دیگر، تمام «x» را از دستگاهها جمعآوری میکند، آنها را خلاصه میکند و نتیجه را به هر دستگاه برمیگرداند تا به محاسبات موازی ادامه دهد. من مثال بالا را از این سخنرانی عالی Matthew Johnson در جریان GTC 2020 قرض گرفتم.
همچنین میتوانید تصور کنید که با pmap میتوانیم الگوهای محاسباتی خود را تعریف کنیم و از دستگاههای خود به بهترین شکل ممکن بهرهبرداری کنیم. درست مانند کاری که معمولاً با CUDA برای هستههای جداگانه انجام میدهیم، اما این بار برای دستگاههای جداگانه!
تبدیل برداری خودکار با vmap
vmap همانطور که از نام آن پیداست، تبدیل تابعی است که به ما امکان میدهد توابع را بر روی بردارها اعمال کنیم (v مخفف vector است!).
میتوانیم تابعی را بگیریم که روی یک نقطه داده عمل میکند و آن را برداری کنیم تا بتواند دستهای از این نقاط داده (یا بردار) با اندازه دلخواه را بپذیرد. به عنوان مثال:
from jax import vmap def f(x): return jnp.square(x) f(jnp.arange(10)) #DeviceArray([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32) vmap(f)(jnp.arange(10)) #DeviceArray([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)
ممکن است تعجب کنید که در اینجا چه چیزی به دست آمد. برای درک این موضوع، بیایید نگاهی بیاندازیم به اینکه وقتی f(x) بدون vmap اجرا می شود چه اتفاقی میافتد:
- یک لیست خروجی مقداردهی اولیه میشود.
- مربع 0 محاسبه شده و برگردانده میشود.
- نتیجه 0 به لیست اضافه میشود.
- مربع 1 محاسبه شده و برگردانده میشود.
- نتیجه 1 به لیست اضافه میشود.
- مربع 2 محاسبه شده و برگردانده میشود.
- نتیجه 4 به لیست اضافه میشود.
- و الی آخر …
کاری که vmap انجام میدهد این است که عملیات مربع را فقط یک بار انجام میدهد، زیرا تمام مقادیر را با هم دستهبندی میکند و آنها را از تابع عبور میدهد. و این باعث افزایش سرعت و مصرف حافظه میشود.
تحولات ذکر شده مواردی هستند که قطعاً باید بدانید، در ادامه میخواهم به چند مورد دیگر اشاره کنم که در طول سفر JAX من را شگفتزده کرد.
مولد اعداد شبه تصادفی
مولد اعداد تصادفی JAX کمی متفاوت از Numpy عمل میکند. به جای اینکه یک مولد اعداد شبه تصادفی (PRNGs) استاندارد باشد، مانند Numpy و Scipy، همهی توابع تصادفی JAX نیاز به یک حالت PRNG صریح دارند که به عنوان آرگومان اول ارسال شود.
یک مولد اعداد تصادفی تنها یک حالت/state دارد. عدد “تصادفی” بعدی تابعی از عدد قبلی و seed/state است. دنباله مقادیر تصادفی محدود است و تکرار میشود.
نکته مهمی که باید به آن توجه کرد این است که PRNGها هم از نظر برداری و هم از نظر محاسبات موازی بین دستگاهها به خوبی کار میکنند.
from jax import random key = random.PRNGKey(5) random.uniform(key)
ارسال asynchronous
یکی دیگر از جنبههای JAX که من را تحت تأثیر قرار داد این است که از ارسال asynchronous استفاده میکند. این بدان معناست که قبل از بازگرداندن کنترل به برنامه پایتون منتظر نمیماند تا عملیات تکمیل شود. در عوض، DeviceArray را برمیگرداند که یک future است (درست مانند Completable future در Java)
future مقداری است که در آینده در یک دستگاه شتابدهنده تولید میشود، اما لزوماً فوراً در دسترس نیست.
future را میتوان بدون انتظار برای تکمیل محاسبات به سایر عملیاتها منتقل کرد. به این ترتیب JAX به کد پایتون اجازه میدهد جلوتر از شتابدهنده اجرا شود و اطمینان حاصل شود که میتواند بدون نیاز به صبر کردن، عملیاتها را برای شتابدهنده سختافزاری (مثلاً GPU) در صف قرار دهد.
پروفایل کردن JAX و حافظهی دستگاه
آخرین ویژگیای که میخواهم به آن اشاره کنم، پروفایل کردن است. از اینکه Tensoboard از پروفایل JAX پشتیبانی میکند خوشحال خواهید شد.
همین امر در مورد Nsight در Nvidia نیز صدق میکند، که برای اشکالزدایی و پروفایل کردن کد GPU استفاده میشود. علاوه بر این، میتوان از ابزار پروفایلکردن حافظه داخلی JAX نیز استفاده کرد که نحوه اجرای کد JAX در GPU و TPU را نشان میدهد. در اینجا یک قطعه از راهنما آورده شده است:
import jax import jax.numpy as jnp import jax.profiler def func1(x): return jnp.tile(x, 10) * 0.5 def func2(x): y = func1(x) return y, jnp.tile(x, 10) + 1 x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000)) y, z = func2(x) z.block_until_ready() jax.profiler.save_device_memory_profile("memory.prof")
اگر pprof را که یک کتابخانه گوگل است نصب کردهاید، میتوانید دستور زیر را اجرا کنید که یک پنجره مرورگر با تمام اطلاعات لازم باز میشود.
$ pprof --web memory.prof
این شگفتانگیز نیست؟
پیشنهاد میکنیم که حتما ویژگیهای مختلف آن را امتحان کنید.
نتیجهگیری
در این پست، سعی کردم مروری بر مزایای JAX نسبت به سایر کتابخانهها داشته باشم و تکههای کد ساده را برای یادگیری سینتکس اولیه و پیچیدگیهای آن ارائه کنم. به هر حال، شما می توانید کد کامل را در این نوتبوک colab یا در github پیدا کنید.
در مقالات بعدی، ما آن را یک گام فراتر خواهیم برد و چگونگی ساخت و آموزش شبکههای عصبی عمیق با JAX و همچنین نگاهی به چارچوبهای مختلف ساختهشده در بالای آن را بررسی خواهیم کرد.
اگر این مقاله برای شما جالب بود، فراموش نکنید که آن را در شبکه های اجتماعی به اشتراک بگذارید.
منبع:
https://theaisummer.com/jax/
دیدگاهتان را بنویسید