Created
April 25, 2024 14:25
-
-
Save jyingl3/66d8bc3893a9b50a04bf208175e2d617 to your computer and use it in GitHub Desktop.
PJRT plugin tutorial
# Install JAX
~$ pip install -U "jax[cpu]"
# Build the .so file: clone XLA and change into the checkout
~$ git clone https://github.com/openxla/xla
~$ cd xla
# Optional: check out the branch that adds extra VLOG logging
~/xla$ git checkout remotes/origin/test_626168031
# Build the CPU plugin
~/xla$ bazel build xla/pjrt/c:pjrt_c_api_cpu_plugin.so
# Check the exported symbol. The output should contain `T GetPjrtApi@@VERS_1.0`.
~/xla$ nm -gD bazel-bin/xla/pjrt/c/pjrt_c_api_cpu_plugin.so | grep GetPjrt
# Use this plugin in JAX by setting PJRT_NAMES_AND_LIBRARY_PATHS to <plugin_name>:<path to .so>.
# TF_CPP_VMODULE and TF_CPP_MIN_LOG_LEVEL=0 enable the verbose client logs shown below.
~/xla$ PJRT_NAMES_AND_LIBRARY_PATHS=cpu_plugin:bazel-bin/xla/pjrt/c/pjrt_c_api_cpu_plugin.so ENABLE_PJRT_COMPATIBILITY=1 TF_CPP_VMODULE=cpu_client=3,pjrt_c_api_wrapper_impl=3 TF_CPP_MIN_LOG_LEVEL=0 python
>>> import jax
>>> from jax._src import xla_bridge
>>> jax.config.update("jax_platform_name", "cpu_plugin")
>>> client = xla_bridge.get_backend()
Platform 'cpu_plugin' is experimental and not all JAX functionality may be correctly supported!
I0000 00:00:1712356514.251375 99055 cpu_client.cc:424] TfrtCpuClient created.
2024-04-22 17:42:10.410579: I external/xla/xla/pjrt/pjrt_c_api_client.cc:134] PjRtCApiClient created.
>>> xla_bridge.backend_pjrt_c_api_version()
(0, 49)
>>> client.platform
'cpu'
>>> client.devices()
[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]
>>> import numpy as np
>>> x = np.arange(12.).reshape((3, 4)).astype("float32")
>>> device_x = jax.device_put(x)
>>> @jax.jit
... def fn(x):
...     return x * x
...
>>> x_shape = jax.ShapeDtypeStruct(shape=(16, 16), dtype=jax.numpy.dtype('float32'))
>>> lowered = fn.lower(x_shape)
>>> executable = lowered.compile()._executable
>>> executable.as_text()
'HloModule jit_fn, entry_computation_layout={(f32[16,16]{1,0})->f32[16,16]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}\n\nENTRY %main.3 (Arg_0.1: f32[16,16]) -> f32[16,16] {\n %Arg_0.1 = f32[16,16]{1,0} parameter(0)\n ROOT %multiply.2 = f32[16,16]{1,0} multiply(f32[16,16]{1,0} %Arg_0.1, f32[16,16]{1,0} %Arg_0.1), metadata={op_name="jit(fn)/jit(main)/mul" source_file="<stdin>" source_line=3}\n}\n\n'
# JAX: 1 + 1
>>> jax.numpy.add(1, 1)
Array(2, dtype=int32, weak_type=True)
# jit
>>> jax.jit(lambda x: x * 2)(1.)
Array(2., dtype=float32, weak_type=True)
# pmap (4 devices in this example)
>>> arr = jax.numpy.arange(jax.device_count())
>>> jax.pmap(lambda x: x + jax.lax.psum(x, 'i'), axis_name='i')(arr)
Array([6, 7, 8, 9], dtype=int32)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment