jax | Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more | Machine Learning library

by google | Python | Version: 0.4.20 | License: Apache-2.0

kandi X-RAY | jax Summary

jax is a Python library typically used in Artificial Intelligence, Machine Learning, Deep Learning and NumPy applications. jax has no bugs and no vulnerabilities, it has a build file available, it has a permissive license, and it has high support. You can install it with 'pip install jax' or download it from GitHub or PyPI.

JAX is Autograd and XLA, brought together for high-performance machine learning research. With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.

What’s new is that JAX uses XLA to compile and run your NumPy programs on GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX also lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API, jit. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without leaving Python. You can even program multiple GPUs or TPU cores at once using pmap, and differentiate through the whole thing.

Dig a little deeper, and you'll see that JAX is really an extensible system for composable function transformations. Both grad and jit are instances of such transformations. Others are vmap for automatic vectorization and pmap for single-program multiple-data (SPMD) parallel programming of multiple accelerators, with more to come.
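Because the transformations compose as ordinary higher-order functions, they can be stacked freely. Below is a minimal sketch using the public jax.grad, jax.vmap and jax.jit functions. The linear-model loss and the data are hypothetical and serve only to show per-example gradients being vectorized and compiled. A third derivative taken by nesting grad is included as well.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Hypothetical squared error of a linear model for one example (x, y).
    pred = jnp.dot(w, x)
    return (pred - y) ** 2

# grad: derivative of `loss` with respect to its first argument, w.
grad_loss = jax.grad(loss)

# vmap: map the per-example gradient over a batch of (x, y) pairs,
# sharing w across the batch (in_axes=None for w).
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the composed function with XLA.
fast_per_example_grads = jax.jit(per_example_grads)

w = jnp.array([1.0, -2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([1.0, 0.0, 0.0])
g = fast_per_example_grads(w, xs, ys)  # one gradient row per example

# Derivatives of derivatives: the third derivative of sin is -cos.
third = jax.grad(jax.grad(jax.grad(jnp.sin)))
print(g)
print(third(0.0))
```

The output of `vmap(grad(...))` has one gradient per example. A single `grad` over a summed loss would return only the aggregate.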

            Support

              jax has a highly active ecosystem.
              It has 23518 star(s) with 2197 fork(s). There are 312 watchers for this library.
              There were 9 major release(s) in the last 6 months.
              There are 1161 open issues and 3107 have been closed. On average issues are closed in 99 days. There are 305 open pull requests and 0 closed requests.
              It has a positive sentiment in the developer community.
              The latest version of jax is 0.4.20

            Quality

              jax has 0 bugs and 0 code smells.

            Security

              jax has no vulnerabilities reported, and its dependent libraries have no vulnerabilities reported.
              jax code analysis shows 0 unresolved vulnerabilities.
              There are 0 security hotspots that need review.

            License

              jax is licensed under the Apache-2.0 License. This license is Permissive.
              Permissive licenses have the least restrictions, and you can use them in most projects.

            Reuse

              jax releases are available to install and integrate.
              Deployable package is available in PyPI.
              Build file is available. You can build the component from source.
              Installation instructions, examples and code snippets are available.
              jax saves you 49636 person hours of effort in developing the same functionality from scratch.
              It has 111930 lines of code, 11371 functions and 373 files.
              It has high code complexity. Code complexity directly impacts maintainability of the code.

            Top functions reviewed by kandi - BETA

            kandi has reviewed jax and identified the functions below as its top functions. This is intended to give you quick insight into the functionality jax implements and to help you decide whether it suits your requirements.
            • Apply a function to a function.
            • Convert a function to a function.
            • Apply a function to each axis.
            • Wrapper around pjit.
            • Compute an XLA computation.
            • Gather the gather index.
            • Helper function for matplotlib.
            • Turn a jaxpr expression into a Function Dialect.
            • Rewrite the expression.
            • Applies a function to each axis.

            jax Key Features

            No Key Features are available at this moment for jax.

            jax Examples and Code Snippets

            Autodidax: JAX core from scratch - Part 4: linearize and vjp (and grad!) - linearize
            Python | Lines of Code: 362 | License: Permissive (Apache-2.0)
            y, f_lin = linearize(f, x)
            y_dot = f_lin(x_dot)
            y, y_dot = jvp(f, (x,), (x_dot,))
            jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)
            def split_half(lst: List[Any]) -> Tuple[List[Any], List[Any]]:
              assert not len(lst)   
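Autodidax builds these transformations from scratch. The released library exposes the same two entry points as jax.linearize and jax.jvp. The following minimal sketch uses a hypothetical test function f(x) = x·sin(x) and checks that both calls give the same primal and tangent outputs. It also checks that the tangent matches the hand-derived derivative sin(x) + x·cos(x).

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar test function.
    return jnp.sin(x) * x

x = jnp.array(1.5)
x_dot = jnp.array(1.0)

# linearize: evaluate f at x and return a function that applies the JVP at x.
y, f_lin = jax.linearize(f, x)
y_dot = f_lin(x_dot)

# jvp: compute the primal output and the tangent output in one call.
y2, y_dot2 = jax.jvp(f, (x,), (x_dot,))
print(y, y_dot)
print(y2, y_dot2)
```

`f_lin` can be called again with other tangents at the same point x without recomputing the primal output.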
            def jit(f):
              def f_jitted(*args):
                avals_in = [raise_to_shaped(get_aval(x)) for x in args]
                jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)
                outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts))
                return tree_unf  
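The autodidax jit above traces f into a jaxpr with its own make_jaxpr and then hands that jaxpr to XLA. Released JAX provides the corresponding public functions: jax.make_jaxpr inspects the traced program and jax.jit compiles it. A minimal sketch with a hypothetical function f:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # Hypothetical function used only for illustration.
    return 2.0 * jnp.sin(x) + y

# Inspect the jaxpr that tracing produces for scalar float inputs.
print(jax.make_jaxpr(f)(1.0, 2.0))

# Compile f with XLA. The first call traces and compiles; later calls with
# the same input shapes and dtypes reuse the compiled executable.
f_jit = jax.jit(f)
result = f_jit(1.0, 2.0)
print(result)
```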
            Autodidax: JAX core from scratch - Part 2: Jaxprs - Building jaxprs with tracing
            Python | Lines of Code: 263 | License: Permissive (Apache-2.0)
            def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:
              assert 0 <= n <= len(lst)
              return lst[:n], lst[n:]
            def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:
              assert len(bs) == len(l)
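Tracing records each primitive operation as an equation in the jaxpr, and ordinary Python control flow over static values runs at trace time. Below is a small sketch with a hypothetical function g. It inspects the equations through the `.jaxpr`, `.eqns` and `.primitive` attributes of the object that jax.make_jaxpr returns:

```python
import jax
import jax.numpy as jnp

def g(x):
    # A Python loop over a static range is unrolled during tracing,
    # so the jaxpr contains one `cos` equation per iteration.
    for _ in range(2):
        x = jnp.cos(x)
    return x * 3.0

closed = jax.make_jaxpr(g)(0.5)   # ClosedJaxpr: a jaxpr plus its constants
jaxpr = closed.jaxpr
prims = [eqn.primitive.name for eqn in jaxpr.eqns]
print(closed)
print(prims)
```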
            jax - differentially private sgd
            Python | Lines of Code: 142 | License: Non-SPDX (Apache License 2.0)
            # Copyright 2019 The JAX Authors.
            # Licensed under the Apache License, Version 2.0 (the "License");
            # you may not use this file except in compliance with the License.
            # You may obtain a copy of the License at
            #     https://www.apache.org/licenses  
            jax - mnist vae
            Python | Lines of Code: 87 | License: Non-SPDX (Apache License 2.0)
            # Copyright 2018 The JAX Authors.
            # Licensed under the Apache License, Version 2.0 (the "License");
            # you may not use this file except in compliance with the License.
            # You may obtain a copy of the License at
            #     https://www.apache.org/licenses  
            jax - gaussian process regression
            Python | Lines of Code: 86 | License: Non-SPDX (Apache License 2.0)
            # Copyright 2018 The JAX Authors.
            # Licensed under the Apache License, Version 2.0 (the "License");
            # you may not use this file except in compliance with the License.
            # You may obtain a copy of the License at
            #     https://www.apache.org/licenses  

            Community Discussions


            Parameters do not converge at a lower tolerance in nonlinear least square implementation in python
            Asked 2022-Apr-17 at 14:20

            I am translating some of my R code to Python as a learning exercise, in particular trying JAX for autodiff.

            In my functions implementing nonlinear least squares, when I set the tolerance to 1e-8, the estimated parameters are nearly identical after several iterations, but the algorithm never appears to converge.

            However, the R code converges at the 12th iteration with tol=1e-8 and at the 14th iteration with tol=1e-9. The estimated parameters are almost the same as those from the Python implementation.

            I think this has something to do with floating point, but I'm not sure which step I could improve to make it converge as quickly as it does in R.

            Here is my code; most steps are the same as in R.



            Answered 2022-Apr-17 at 14:20

            One thing to be aware of is that by default, JAX performs computations in 32-bit, while tools like R and numpy perform computations in 64-bit. Since 1E-8 is at the edge of 32-bit floating point precision, I suspect this is why your program is failing to converge.

            You can enable 64-bit computation by putting this at the beginning of your script:
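Below is a minimal sketch of the usual way to do this. `jax.config.update("jax_enable_x64", True)` is JAX's documented configuration flag for 64-bit mode, and it should run at startup, before any JAX arrays are created. The comparison afterwards shows why a 1e-8 tolerance is a problem in 32-bit. An increment of 1e-8 is smaller than float32 resolution near 1.0 (about 1.2e-7), but it is well within float64 resolution.

```python
import jax
# Enable 64-bit mode at startup, before creating any arrays.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

x = jnp.array(1.0)
print(x.dtype)

# In float32, adding 1e-8 to 1.0 leaves it unchanged...
lost_in_f32 = bool(jnp.float32(1.0) + jnp.float32(1e-8) == jnp.float32(1.0))
# ...while in float64 the increment is preserved.
kept_in_f64 = bool(x + 1e-8 != x)
print(lost_in_f32, kept_in_f64)
```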

            Source https://stackoverflow.com/questions/71902257


            Bean Validation on Jax-RS Resource stops working while using CDI on Apache TomEE 8.0.10
            Asked 2022-Mar-16 at 22:46

            I'm having trouble getting Bean Validation to work with the following minimal project, which consists only of these three Java files plus pom.xml. I'm using Apache TomEE 8.0.10.




            Answered 2022-Mar-15 at 15:29

            This appears to be a bug in OpenWebBeans or TomEE. What's happening is that in the first case, the actual instance of the bean is managed by JAX-RS, and in the second case, the bean is managed by the CDI container. In the second case, there needs to be some sort of interceptor that invokes the Bean Validation framework.

            I would start a discussion on the mailing list and open a bug in JIRA. If you can create a sample project that reproduces the problem, it helps the developers tremendously.

            As a workaround, you can @Inject private Validator validator and, if any constraint violations are returned, throw new ConstraintViolationException(constraintViolations);.

            Source https://stackoverflow.com/questions/71453728


            Using Keycloak adapter with Wildfly 26 does not provide "KEYCLOAK" as mechanism
            Asked 2022-Mar-16 at 19:01

            I have a JAX-RS application deployed in WildFly. The application's endpoints shall be protected by Keycloak with Access Type: bearer-only. This works perfectly fine for WildFly versions up to 24.

            Starting with WildFly 25, the Keycloak adapter is deprecated and one should migrate to the new Elytron subsystem. According to WildFly issue https://issues.redhat.com/browse/WFLY-15485, however, the OIDC adapter is not yet ready to work with bearer-only. The issue does mention that this should still be possible using the Keycloak WildFly adapter.

            The latest Keycloak documentation and this thread in Google Groups also state this.

            So I installed the adapter from this location and ran the installation script:


            ./bin/jboss-cli.sh --file=bin/adapter-elytron-install-offline.cli -Dserver.config=standalone-full.xml

            When deploying the application, I get the following error message:

            java.lang.IllegalStateException: The required mechanism 'KEYCLOAK' is not available in mechanisms [BASIC, CLIENT_CERT, DIGEST, FORM] from the HttpAuthenticationFactory


            • WildFly 26 (Jakarta EE 8)
            • Keycloak 16.1.1




            Answered 2022-Feb-01 at 07:31

            I finally got it working without the Keycloak adapter, i.e. using the new built-in Elytron subsystem.

            oidc.json (located in the WEB-INF directory)

            Source https://stackoverflow.com/questions/70922622


            Mask a numpy array after a given value
            Asked 2022-Mar-07 at 17:38

            I have two NumPy arrays like:



            Answered 2022-Mar-07 at 16:57

            You can use a cumulative sum.
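The arrays from the question are not shown here, so the following is a hypothetical sketch of the cumulative-sum idea. First mark where the array equals the given value, then take the cumulative sum of those marks. Every position from the first match onward then has a positive count. The `inclusive` flag and the helper name are illustrative assumptions. They control whether the matching element itself is masked.

```python
import numpy as np

def mask_from_value(a, value, inclusive=True):
    # Hypothetical helper: True from the first occurrence of `value` onward.
    hit = (a == value)
    counts = np.cumsum(hit)       # number of matches seen up to each position
    if not inclusive:
        counts = counts - hit     # do not count the current element itself
    return counts > 0

a = np.array([3, 1, 4, 1, 5, 9, 2])
mask = mask_from_value(a, 4)
masked = np.ma.masked_array(a, mask=mask)
print(mask)
print(masked)
```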

            Source https://stackoverflow.com/questions/71384567


            Is there a module to convert a tensorflow NN to Jax?
            Asked 2022-Feb-08 at 11:05

            There is a library to convert JAX functions to TensorFlow functions. Is there a similar library to convert TensorFlow functions to JAX functions?



            Answered 2021-Dec-14 at 22:16

            No, there is no library supported by the JAX team that converts TensorFlow into JAX the way jax.experimental.jax2tf converts JAX code to TensorFlow, and I have not seen any such library developed by others.

            Source https://stackoverflow.com/questions/70356126


            Tomcat context is empty when accessed via executor and runnable
            Asked 2022-Feb-07 at 22:45

            Hello, I have a web application running on apache-tomee-plus-8.0.1. My problem is getting an environment variable from a Runnable in a custom executor. The variable is defined in /conf/context.xml:



            Answered 2022-Feb-07 at 22:45

            JNDI lookups depend on some context information on the running thread, usually the context class loader.

            On a Java EE/Jakarta EE server you should not spawn new (unmanaged) threads yourself, but use the ManagedExecutorService provided by the container. This service automatically propagates some kinds of contexts from the calling thread:

            The types of contexts to be propagated from a contextualizing application component include JNDI naming context, classloader, and security information. Containers must support propagation of these context types.

            (Jakarta Concurrency Specification, emphasis mine)

            You can inject a ManagedExecutorService using a @Resource annotation:

            Source https://stackoverflow.com/questions/70970270


            Compute efficiently Hessian matrices in JAX
            Asked 2022-Jan-04 at 14:16

            In JAX's Quickstart tutorial I found that the Hessian matrix can be computed efficiently for a differentiable function fun using the following lines of code:



            Answered 2022-Jan-04 at 14:16

            The answer to your question is within the JAX documentation; see for example this section: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev

            To quote its discussion of jacrev and jacfwd:

            These two functions compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for “wide” Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.
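Here is a small sketch of the tall/wide distinction using jax.jacfwd and jax.jacrev on two hypothetical functions. jacfwd builds the Jacobian from one forward-mode JVP per input dimension, and jacrev builds it from one reverse-mode VJP per output dimension. The first therefore needs fewer passes when there are few inputs (a tall Jacobian), and the second when there are few outputs (a wide Jacobian). The values agree either way.

```python
import jax
import jax.numpy as jnp

W_tall = jnp.arange(12.0).reshape(6, 2)  # hypothetical weights, R^2 -> R^6
W_wide = W_tall.T                        # hypothetical weights, R^6 -> R^2

def tall(x):
    # f: R^2 -> R^6, so the Jacobian is 6 x 2 ("tall").
    return jnp.tanh(W_tall @ x)

def wide(x):
    # f: R^6 -> R^2, so the Jacobian is 2 x 6 ("wide").
    return jnp.tanh(W_wide @ x)

x2 = jnp.array([0.1, -0.2])
x6 = jnp.linspace(-0.1, 0.1, 6)

J_tall_fwd = jax.jacfwd(tall)(x2)  # 2 JVPs, one per input
J_tall_rev = jax.jacrev(tall)(x2)  # 6 VJPs, one per output
J_wide_fwd = jax.jacfwd(wide)(x6)  # 6 JVPs
J_wide_rev = jax.jacrev(wide)(x6)  # 2 VJPs
print(J_tall_fwd.shape, J_wide_fwd.shape)
```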

            and further down,

            To implement hessian, we could have used jacfwd(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two. But forward-over-reverse is typically the most efficient. That’s because in the inner Jacobian computation we’re often differentiating a function with a wide Jacobian (maybe like a loss function 𝑓:ℝⁿ→ℝ), while in the outer Jacobian computation we’re differentiating a function with a square Jacobian (since ∇𝑓:ℝⁿ→ℝⁿ), which is where forward-mode wins out.

            Since your function looks like 𝑓:ℝⁿ→ℝ, then jit(jacfwd(jacrev(fun))) is likely the most efficient approach.
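Below is a minimal sketch of the forward-over-reverse composition, applied to a hypothetical 𝑓:ℝⁿ→ℝ whose Hessian is known in closed form. jax.hessian is JAX's built-in equivalent and is also defined as jacfwd over jacrev, so the two results should agree.

```python
import jax
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
    # Forward-over-reverse: jacrev(fun) is the gradient map R^n -> R^n,
    # and jacfwd differentiates that square-Jacobian map.
    return jit(jacfwd(jacrev(fun)))

def fun(x):
    # Hypothetical scalar-valued function R^n -> R; its Hessian is diag(-sin(x)).
    return jnp.sum(jnp.sin(x))

x = jnp.array([0.5, 1.0, 2.0])
H = hessian(fun)(x)
H_builtin = jax.hessian(fun)(x)
print(H)
```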

            As for why you can't implement a Hessian with grad: grad is designed only for derivatives of functions with scalar outputs. A Hessian is, by definition, a composition of vector-valued Jacobians, not a composition of scalar gradients.
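This sketch uses a hypothetical f: ℝ³→ℝ. Calling jax.grad(f) works because f is scalar-valued. Applying grad a second time fails because the gradient is vector-valued. Taking a Jacobian of the gradient instead produces the Hessian.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar-valued function; its Hessian is diag(6 * x).
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 3.0])
grad_f = jax.grad(f)        # R^3 -> R^3: the gradient is vector-valued

try:
    jax.grad(grad_f)(x)     # grad requires a scalar-output function
    raised = False
except TypeError as err:
    raised = True
    print(err)

H = jax.jacfwd(grad_f)(x)   # Jacobian of the gradient: the 3 x 3 Hessian
print(H)
```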

            Source https://stackoverflow.com/questions/70572362


            Combine scipy.root and Jax Jacobian
            Asked 2021-Dec-21 at 15:32

            I am having trouble using the Jacobian from JAX with scipy.root. In the below example, the root works without the Jacobian, while it fails with the Jacobian. Any ideas on what I need to rewrite in order to get the code below working with the Jacobian?



            Answered 2021-Dec-19 at 14:01

            There are two issues:

            1. To perform automatic differentiation, JAX relies on replacing values with tracers. This means your approach of printing and evaluating the string representation of the value will not work.
            2. Additionally, you are attempting to assign traced values to a standard NumPy array. You should use a JAX array instead, since it knows how to handle traced values.

            With this in mind, you can rewrite your function this way, and it should work as long as your equations use only Python arithmetic operations and JAX functions (not things like np.exp):
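The asker's equations are not reproduced here, so the following sketch uses a hypothetical two-equation system and assumes scipy.optimize.root with its default hybr method. The residual function uses only jax.numpy (jnp rather than np, and no string evaluation). jax.jacfwd supplies the Jacobian, and the results are converted back to NumPy arrays for SciPy. 64-bit mode is enabled so that the residuals are computed in float64, which is what SciPy works in.

```python
import jax
jax.config.update("jax_enable_x64", True)  # compute residuals in float64
import jax.numpy as jnp
import numpy as np
from scipy.optimize import root

def equations(x):
    # Hypothetical system: the circle x0^2 + x1^2 = 4 intersected with x0 = x1.
    return jnp.array([x[0] ** 2 + x[1] ** 2 - 4.0, x[0] - x[1]])

# Jacobian of the residuals, compiled once and reused on every call.
jacobian = jax.jit(jax.jacfwd(equations))

sol = root(
    lambda x: np.asarray(equations(x)),
    x0=np.array([1.0, 0.5]),
    jac=lambda x: np.asarray(jacobian(x)),
)
print(sol.success, sol.x)
```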

            Source https://stackoverflow.com/questions/70409729


            Fastest way to multiply and sum 4D array with 2D array in python?
            Asked 2021-Dec-17 at 15:19

            Here's my problem. I have two matrices A and B, with complex entries, of dimensions (n,n,m,m) and (n,n) respectively.

            Below is the operation I perform to get a matrix C:



            Answered 2021-Dec-17 at 15:19
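The question's exact operation is not shown here. As a hypothetical reading of the title, assume C[k, l] = Σ_{i,j} A[i, j, k, l] · B[i, j]. Under that assumption, the usual way to speed this up is to replace the Python loops with a single np.einsum or np.tensordot call. The following sketch checks that both agree with an explicit loop on random complex data.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 4, 3
A = rng.standard_normal((n, n, m, m)) + 1j * rng.standard_normal((n, n, m, m))
B = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))

# Assumed operation: C[k, l] = sum over i, j of A[i, j, k, l] * B[i, j].
C_loop = np.zeros((m, m), dtype=complex)
for i in range(n):
    for j in range(n):
        C_loop += A[i, j] * B[i, j]

# Same contraction without Python loops.
C_einsum = np.einsum("ijkl,ij->kl", A, B)
C_tensordot = np.tensordot(B, A, axes=([0, 1], [0, 1]))
print(C_einsum.shape)
```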


            Automatic Differentiation with respect to rank-based computations
            Asked 2021-Dec-03 at 16:44

            I'm new to automatic differentiation programming, so this may be a naive question. Below is a simplified version of what I'm trying to solve.

            I have two input arrays, a vector A of size N and a matrix B of shape (N, M), as well as a parameter vector theta of size M. I define a new array C(theta) = B * theta (a matrix-vector product) to get a new vector of size N. I then obtain the indices of the elements that fall in the upper and lower quartiles of C, and use them to create new arrays A_low(theta) = A[lower-quartile indices of C] and A_high(theta) = A[upper-quartile indices of C]. Clearly these two depend on theta, but is it possible to differentiate A_low and A_high w.r.t. theta?

            My attempts so far suggest not: I have tried the Python libraries autograd, JAX and TensorFlow, but they all return a gradient of zero. (The approaches I have tried so far use argsort or extract the relevant sub-arrays with tf.top_k.)

            What I'm seeking help with is either a proof that the derivative is not defined (or cannot be computed analytically) or, if it does exist, a suggestion on how to estimate it. My eventual goal is to minimize some function f(A_low, A_high) w.r.t. theta.



            Answered 2021-Dec-03 at 16:44

            This is the JAX computation that I wrote based on your description:
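The answerer's original code is not reproduced here. This hypothetical sketch of the described computation uses a random B, A = 0..N-1, and an illustrative objective sum(A_high) - sum(A_low). With these, jax.grad returns exactly zero. The quartile index sets come from argsort, which is piecewise constant in theta. Small changes in theta therefore do not change which elements of A are selected, so the derivative is zero wherever it is defined.

```python
import jax
import jax.numpy as jnp

N, M = 8, 3
A = jnp.arange(N, dtype=jnp.float32)             # hypothetical data
B = jax.random.normal(jax.random.PRNGKey(0), (N, M))

def f(theta):
    C = B @ theta                  # C(theta): matrix-vector product, size N
    order = jnp.argsort(C)         # integer indices, piecewise constant in theta
    q = N // 4
    A_low = A[order[:q]]           # A at the lower-quartile indices of C
    A_high = A[order[-q:]]         # A at the upper-quartile indices of C
    return jnp.sum(A_high) - jnp.sum(A_low)      # hypothetical objective

theta = jnp.array([0.5, -1.0, 2.0])
g = jax.grad(f)(theta)
print(g)
```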

            Source https://stackoverflow.com/questions/70214451

            Community Discussions, Code Snippets contain sources that include Stack Exchange Network


            No vulnerabilities reported

            Install jax

            JAX is written in pure Python, but it depends on XLA, which needs to be installed as the jaxlib package. Use the following instructions to install a binary package with pip, or to build JAX from source. We support installing or building jaxlib on Linux (Ubuntu 16.04 or later) and macOS (10.12 or later) platforms. Windows users can use JAX on CPU and GPU via the Windows Subsystem for Linux. There is some initial native Windows support, but since it is still somewhat immature, there are no binary releases and it must be built from source.
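After installing, you can confirm that jax and jaxlib import correctly and see which backend JAX picked. One way is to print the version, the visible devices and the default backend, then run a tiny computation. A minimal sketch: on a CPU-only install, jax.devices() lists CPU devices.

```python
import jax
import jax.numpy as jnp

print(jax.__version__)
print(jax.devices())          # CPU, GPU or TPU devices, depending on the install
print(jax.default_backend())

x = jnp.arange(3.0)
y = x + 1
print(y)
```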


            For details about the JAX API, see the reference documentation. For getting started as a JAX developer, see the developer documentation.
          • PyPI

            pip install jax

          • Clone (GitHub CLI)

            gh repo clone google/jax
