{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2418600b",
   "metadata": {},
   "source": [
    "# An Introduction to JAX\n",
    "\n",
    "In addition to what’s in Anaconda, this lecture will need the following libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b372d947",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "!pip install jax"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3d4fd37",
   "metadata": {},
   "source": [
    "This lecture provides a short introduction to [Google JAX](https://github.com/google/jax).\n",
    "\n",
    "Here we are focused on using JAX on the CPU, rather than on accelerators such as\n",
    "GPUs or TPUs.\n",
    "\n",
    "This means we will only see a small amount of the possible benefits from using\n",
    "JAX.\n",
    "\n",
    "At the same time, JAX computing on the CPU is a good place to start, since the\n",
    "JAX just-in-time compiler seamlessly handles transitions across different\n",
    "hardware platforms.\n",
    "\n",
    "(In other words, if you do want to shift to using GPUs, you will almost never\n",
    "need to modify your code.)\n",
    "\n",
    "For a discusson of JAX on GPUs, see [our JAX lecture series](https://jax.quantecon.org/intro.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "481e919b",
   "metadata": {},
   "source": [
    "## JAX as a NumPy Replacement\n",
    "\n",
    "One way to use JAX is as a plug-in NumPy replacement. Let’s look at the\n",
    "similarities and differences."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a55e476",
   "metadata": {},
   "source": [
    "### Similarities\n",
    "\n",
    "The following import is standard, replacing `import numpy as np`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3608b7cf",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3fcc217",
   "metadata": {},
   "source": [
    "Now we can use `jnp` in place of `np` for the usual array operations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2682ca0",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "a = jnp.asarray((1.0, 3.2, -1.5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd18a532",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6ba6424",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "print(jnp.sum(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6db633b1",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "print(jnp.mean(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1b00cda",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "print(jnp.dot(a, a))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ce91974",
   "metadata": {},
   "source": [
    "However, the array object `a` is not a NumPy array:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "605ac9d7",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4e3b2db",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "type(a)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39d99a34",
   "metadata": {},
   "source": [
    "Even scalar-valued maps on arrays return JAX arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1562cc5b",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "jnp.sum(a)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "316d4248",
   "metadata": {},
   "source": [
    "Operations on higher dimensional arrays are also similar to NumPy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e61ee56d",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "A = jnp.ones((2, 2))\n",
    "B = jnp.identity(2)\n",
    "A @ B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f466bb14",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "from jax.numpy import linalg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7849e072",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "linalg.inv(B)   # Inverse of identity is identity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd18dea4",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "linalg.eigh(B)  # Computes eigenvalues and eigenvectors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97af9b1b",
   "metadata": {},
   "source": [
    "### Differences\n",
    "\n",
    "One difference between NumPy and JAX is that JAX uses 32 bit floats by default.\n",
    "\n",
    "This is because JAX is often used for GPU computing, and most GPU computations use 32 bit floats.\n",
    "\n",
    "Using 32 bit floats can lead to significant speed gains with small loss of precision.\n",
    "\n",
    "However, for some calculations precision matters.\n",
    "\n",
    "In these cases 64 bit floats can be enforced via the command"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfc0e9f9",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "jax.config.update(\"jax_enable_x64\", True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "232b5894",
   "metadata": {},
   "source": [
    "Let’s check this works:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abc6c491",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "jnp.ones(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3f99759",
   "metadata": {},
   "source": [
    "As a NumPy replacement, a more significant difference is that arrays are treated as **immutable**.\n",
    "\n",
    "For example, with NumPy we can write"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98ac3cd0",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "a = np.linspace(0, 1, 3)\n",
    "a"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7c9b509",
   "metadata": {},
   "source": [
    "and then mutate the data in memory:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "157e0ef0",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "a[0] = 1\n",
    "a"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b6dd49d",
   "metadata": {},
   "source": [
    "In JAX this fails:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7861a352",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "a = jnp.linspace(0, 1, 3)\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cce737f8",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "a[0] = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a3f8405",
   "metadata": {},
   "source": [
    "In line with immutability, JAX does not support inplace operations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95268527",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "a = np.array((2, 1))\n",
    "a.sort()\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46d7a1cc",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "a = jnp.array((2, 1))\n",
    "a_new = a.sort()\n",
    "a, a_new"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "078269d0",
   "metadata": {},
   "source": [
    "The designers of JAX chose to make arrays immutable because JAX uses a\n",
    "functional programming style.  More on this below.\n",
    "\n",
    "However, JAX provides a functionally pure equivalent of in-place array modification\n",
    "using the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4daf4194",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "a = jnp.linspace(0, 1, 3)\n",
    "id(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10e50f01",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3aa4c73",
   "metadata": {},
   "source": [
    "Applying `at[0].set(1)` returns a new copy of `a` with the first element set to 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa8a1ec8",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "a = a.at[0].set(1)\n",
    "a"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d9d3828",
   "metadata": {},
   "source": [
    "Inspecting the identifier of `a` shows that it has been reassigned"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2857b55b",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "id(a)"
   ]
  },
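  {
   "cell_type": "markdown",
   "id": "5f3a9c10",
   "metadata": {},
   "source": [
    "The `at` property supports other out-of-place updates besides `set`, such as `add` and slice assignment.\n",
    "\n",
    "Here is a minimal sketch; the names `b`, `c` and `d` are illustrative and are not used elsewhere in this lecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f3a9c11",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "b = jnp.zeros(4)\n",
    "c = b.at[1].add(2.5)    # out-of-place analogue of b[1] += 2.5\n",
    "d = c.at[2:].set(1.0)   # out-of-place analogue of c[2:] = 1.0\n",
    "b, c, d                 # b is unchanged by both updates"
   ]
  },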
  {
   "cell_type": "markdown",
   "id": "f775c426",
   "metadata": {},
   "source": [
    "## Random Numbers\n",
    "\n",
    "Random numbers are also a bit different in JAX, relative to NumPy.  Typically, in JAX, the state of the random number generator needs to be controlled explicitly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc14621c",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import jax.random as random"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd3c611d",
   "metadata": {},
   "source": [
    "First we produce a key, which seeds the random number generator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8b68b5e",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "key = random.PRNGKey(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78eb9b93",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "type(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "889b9937",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "print(key)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "194316d9",
   "metadata": {},
   "source": [
    "Now we can use the key to generate some random numbers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b5df53d",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "x = random.normal(key, (3, 3))\n",
    "x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d74667be",
   "metadata": {},
   "source": [
    "If we use the same key again, we initialize at the same seed, so the random numbers are the same:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af3715eb",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "random.normal(key, (3, 3))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2efb7e63",
   "metadata": {},
   "source": [
    "To produce a (quasi-) independent draw, best practice is to “split” the existing key:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71bf5990",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "key, subkey = random.split(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf58438b",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "random.normal(key, (3, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fb14dbd",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "random.normal(subkey, (3, 3))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b8af9ff",
   "metadata": {},
   "source": [
    "The function below produces `k` (quasi-) independent random `n x n` matrices using this procedure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21f306c0",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "def gen_random_matrices(key, n, k):\n",
    "    matrices = []\n",
    "    for _ in range(k):\n",
    "        key, subkey = random.split(key)\n",
    "        matrices.append(random.uniform(subkey, (n, n)))\n",
    "    return matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f42e0959",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "matrices = gen_random_matrices(key, 2, 2)\n",
    "for A in matrices:\n",
    "    print(A)"
   ]
  },
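  {
   "cell_type": "markdown",
   "id": "5f3a9c12",
   "metadata": {},
   "source": [
    "`random.split` also accepts a second argument giving the number of keys to produce, so the Python loop can be replaced by one split and one batched draw.\n",
    "\n",
    "The sketch below is one possible variant, not a replacement for the function above: the name `gen_random_matrices_vmap` is hypothetical, and the draws differ from those of the loop version because the keys are generated differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f3a9c13",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.random as random\n",
    "\n",
    "def gen_random_matrices_vmap(key, n, k):\n",
    "    # Hypothetical variant: split once into k subkeys, then draw in one batch\n",
    "    subkeys = random.split(key, k)\n",
    "    draw = lambda subkey: random.uniform(subkey, (n, n))\n",
    "    return jax.vmap(draw)(subkeys)   # stacked array of shape (k, n, n)\n",
    "\n",
    "batch = gen_random_matrices_vmap(random.PRNGKey(1), 2, 3)\n",
    "batch.shape"
   ]
  },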
  {
   "cell_type": "markdown",
   "id": "4ce1b773",
   "metadata": {},
   "source": [
    "One point to remember is that JAX expects tuples to describe array shapes, even for flat arrays.  Hence, to get a one-dimensional array of normal random draws we use `(len, )` for the shape, as in"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d8b7bd5",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "random.normal(key, (5, ))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f21260fe",
   "metadata": {},
   "source": [
    "## JIT compilation\n",
    "\n",
    "The JAX just-in-time (JIT) compiler accelerates logic within functions by fusing linear\n",
    "algebra operations into a single optimized kernel."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43c179a5",
   "metadata": {},
   "source": [
    "### A first example\n",
    "\n",
    "To see the JIT compiler in action, consider the following function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "332821c1",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "def f(x):\n",
    "    a = 3*x + jnp.sin(x) + jnp.cos(x**2) - jnp.cos(2*x) - x**2 * 0.4 * x**1.5\n",
    "    return jnp.sum(a)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d8e7d3b",
   "metadata": {},
   "source": [
    "Let’s build an array to call the function on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44cf9b30",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "n = 50_000_000\n",
    "x = jnp.ones(n)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c39f4293",
   "metadata": {},
   "source": [
    "How long does the function take to execute?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "634c954c",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "%time f(x).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7eeb3a7",
   "metadata": {},
   "source": [
    ">**Note**\n",
    ">\n",
    ">Here, in order to measure actual speed, we use the `block_until_ready()` method\n",
    "to hold the interpreter until the results of the computation are returned.\n",
    "This is necessary because JAX uses asynchronous dispatch, which\n",
    "allows the Python interpreter to run ahead of numerical computations.\n",
    "\n",
    "If we run it a second time it becomes faster again:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4343d184",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "%time f(x).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a02c5617",
   "metadata": {},
   "source": [
    "This is because the built in functions like `jnp.cos` are JIT compiled and the\n",
    "first run includes compile time.\n",
    "\n",
    "Why would JAX want to JIT-compile built in functions like `jnp.cos` instead of\n",
    "just providing pre-compiled versions, like NumPy?\n",
    "\n",
    "The reason is that the JIT compiler can specialize on the *size* of the array\n",
    "being used, which is helpful for parallelization.\n",
    "\n",
    "For example, in running the code above, the JIT compiler produced a version of `jnp.cos` that is\n",
    "specialized to floating point arrays of size `n = 50_000_000`.\n",
    "\n",
    "We can check this by calling `f` with a new array of different size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0d308b3",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "m = 50_000_001\n",
    "y = jnp.ones(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "279b6406",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "%time f(y).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21282c71",
   "metadata": {},
   "source": [
    "Notice that the execution time increases, because now new versions of\n",
    "the built-ins like `jnp.cos` are being compiled, specialized to the new array\n",
    "size.\n",
    "\n",
    "If we run again, the code is dispatched to the correct compiled version and we\n",
    "get faster execution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07c8b996",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "%time f(y).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b34f1ee",
   "metadata": {},
   "source": [
    "The compiled versions for the previous array size are still available in memory\n",
    "too, and the following call is dispatched to the correct compiled code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b95ce4bb",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "%time f(x).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59acbd8a",
   "metadata": {},
   "source": [
    "### Compiling the outer function\n",
    "\n",
    "We can do even better if we manually JIT-compile the outer function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a648c09",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "f_jit = jax.jit(f)   # target for JIT compilation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1703e7da",
   "metadata": {},
   "source": [
    "Let’s run once to compile it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b836e214",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "f_jit(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2972aa7",
   "metadata": {},
   "source": [
    "And now let’s time it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5890823",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "%time f_jit(x).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01a4768f",
   "metadata": {},
   "source": [
    "Note the speed gain.\n",
    "\n",
    "This is because the array operations are fused and no intermediate arrays are created.\n",
    "\n",
    "Incidentally, a more common syntax when targetting a function for the JIT\n",
    "compiler is"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba22d83a",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def f(x):\n",
    "    a = 3*x + jnp.sin(x) + jnp.cos(x**2) - jnp.cos(2*x) - x**2 * 0.4 * x**1.5\n",
    "    return jnp.sum(a)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c34cc19",
   "metadata": {},
   "source": [
    "## Functional Programming\n",
    "\n",
    "From JAX’s documentation:\n",
    "\n",
    "*When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale”.*\n",
    "\n",
    "In other words, JAX assumes a functional programming style.\n",
    "\n",
    "The major implication is that JAX functions should be pure.\n",
    "\n",
    "A pure function will always return the same result if invoked with the same inputs.\n",
    "\n",
    "In particular, a pure function has\n",
    "\n",
    "- no dependence on global variables and  \n",
    "- no side effects  \n",
    "\n",
    "\n",
    "JAX will not usually throw errors when compiling impure functions but execution becomes unpredictable.\n",
    "\n",
    "Here’s an illustration of this fact, using global variables:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d55a477",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "a = 1  # global\n",
    "\n",
    "@jax.jit\n",
    "def f(x):\n",
    "    return a + x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96aa1e42",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "x = jnp.ones(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6c874a1",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "f(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "341c0cbe",
   "metadata": {},
   "source": [
    "In the code above, the global value `a=1` is fused into the jitted function.\n",
    "\n",
    "Even if we change `a`, the output of `f` will not be affected — as long as the same compiled version is called."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "350f0786",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "a = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c48b955",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "f(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e4ecce0",
   "metadata": {},
   "source": [
    "Changing the dimension of the input triggers a fresh compilation of the function, at which time the change in the value of `a` takes effect:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78780d6e",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "x = jnp.ones(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54b7888b",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "f(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16a1a80d",
   "metadata": {},
   "source": [
    "Moral of the story: write pure functions when using JAX!"
   ]
  },
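  {
   "cell_type": "markdown",
   "id": "5f3a9c14",
   "metadata": {},
   "source": [
    "One way to make the function above pure is to pass the value in as an argument instead of reading it from the global scope.\n",
    "\n",
    "The sketch below uses the illustrative names `f_pure` and `v`, which do not appear elsewhere in this lecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f3a9c15",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "@jax.jit\n",
    "def f_pure(x, a):\n",
    "    # a is an explicit input, so each call uses the value that is passed in\n",
    "    return a + x\n",
    "\n",
    "v = jnp.ones(2)\n",
    "f_pure(v, 1), f_pure(v, 42)"
   ]
  },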
  {
   "cell_type": "markdown",
   "id": "6c897925",
   "metadata": {},
   "source": [
    "## Gradients\n",
    "\n",
    "JAX can use automatic differentiation to compute gradients.\n",
    "\n",
    "This can be extremely useful for optimization and solving nonlinear systems.\n",
    "\n",
    "We will see significant applications later in this lecture series.\n",
    "\n",
    "For now, here’s a very simple illustration involving the function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96aa8273",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "def f(x):\n",
    "    return (x**2) / 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e4cab5e",
   "metadata": {},
   "source": [
    "Let’s take the derivative:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8373434",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "f_prime = jax.grad(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ce9fb77",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "f_prime(10.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7410a227",
   "metadata": {},
   "source": [
    "Let’s plot the function and derivative, noting that $ f'(x) = x $."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55dbdf86",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "x_grid = jnp.linspace(-4, 4, 200)\n",
    "ax.plot(x_grid, f(x_grid), label=\"$f$\")\n",
    "ax.plot(x_grid, [f_prime(x) for x in x_grid], label=\"$f'$\")\n",
    "ax.legend(loc='upper center')\n",
    "plt.show()"
   ]
  },
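  {
   "cell_type": "markdown",
   "id": "5f3a9c16",
   "metadata": {},
   "source": [
    "The list comprehension in the plotting code calls the derivative once per grid point.\n",
    "\n",
    "As a possible alternative, `jax.vmap` can map the derivative over a whole array in one call. In the sketch below, `g`, `g_prime_vec` and `pts` are illustrative names; `g` is the same quadratic as `f`, renamed so that the cell is self-contained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f3a9c17",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def g(x):\n",
    "    # Same quadratic as f above\n",
    "    return (x**2) / 2\n",
    "\n",
    "g_prime_vec = jax.vmap(jax.grad(g))   # scalar derivative mapped over an array\n",
    "pts = jnp.linspace(-4, 4, 5)\n",
    "g_prime_vec(pts)                      # should match pts, since g'(x) = x"
   ]
  },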
  {
   "cell_type": "markdown",
   "id": "94e66107",
   "metadata": {},
   "source": [
    "We defer further exploration of automatic differentiation with JAX until [Adventures with Autodiff](https://jax.quantecon.org/autodiff.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bec59ad8",
   "metadata": {},
   "source": [
    "## Writing vectorized code\n",
    "\n",
    "Writing fast JAX code requires shifting repetitive tasks from loops to array processing operations, so that the JAX compiler can easily understand the whole operation and generate more efficient machine code.\n",
    "\n",
    "This procedure is called **vectorization** or **array programming**, and will be\n",
    "familiar to anyone who has used NumPy or MATLAB.\n",
    "\n",
    "In most ways, vectorization is the same in JAX as it is in NumPy.\n",
    "\n",
    "But there are also some differences, which we highlight here.\n",
    "\n",
    "As a running example, consider the function\n",
    "\n",
    "$$\n",
    "f(x,y) = \\frac{\\cos(x^2 + y^2)}{1 + x^2 + y^2}\n",
    "$$\n",
    "\n",
    "Suppose that we want to evaluate this function on a square grid of $ x $ and $ y $ points and then plot it.\n",
    "\n",
    "To clarify, here is the slow `for` loop version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1641bff",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def f(x, y):\n",
    "    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)\n",
    "\n",
    "n = 80\n",
    "x = jnp.linspace(-2, 2, n)\n",
    "y = x\n",
    "\n",
    "z_loops = np.empty((n, n))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef3f23e3",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "for i in range(n):\n",
    "    for j in range(n):\n",
    "        z_loops[i, j] = f(x[i], y[j])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5073ceb",
   "metadata": {},
   "source": [
    "Even for this very small grid, the run time is extremely slow.\n",
    "\n",
    "(Notice that we used a NumPy array for `z_loops` because we wanted to write to it.)\n",
    "\n",
    "OK, so how can we do the same operation in vectorized form?\n",
    "\n",
    "If you are new to vectorization, you might guess that we can simply write"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a141ab7",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "z_bad = f(x, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e957bfb7",
   "metadata": {},
   "source": [
    "But this gives us the wrong result because JAX doesn’t understand the nested for loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe6d5c7e",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "z_bad.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c617252",
   "metadata": {},
   "source": [
    "Here is what we actually wanted:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "443c41d1",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "z_loops.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "573b29e5",
   "metadata": {},
   "source": [
    "To get the right shape and the correct nested for loop calculation, we can use a `meshgrid` operation designed for this purpose:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f259f563",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "x_mesh, y_mesh = jnp.meshgrid(x, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce48c482",
   "metadata": {},
   "source": [
    "Now we get what we want and the execution time is very fast."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "042881ca",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "z_mesh = f(x_mesh, y_mesh).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e921d4c9",
   "metadata": {},
   "source": [
    "Let’s run again to eliminate compile time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f6a3524",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "z_mesh = f(x_mesh, y_mesh).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d62265c1",
   "metadata": {},
   "source": [
    "Let’s confirm that we got the right answer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30e0fe26",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "jnp.allclose(z_mesh, z_loops)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43fc91bd",
   "metadata": {},
   "source": [
    "Now we can set up a serious grid and run the same calculation (on the larger grid) in a short amount of time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f46f0e8",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "n = 6000\n",
    "x = jnp.linspace(-2, 2, n)\n",
    "y = x\n",
    "x_mesh, y_mesh = jnp.meshgrid(x, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1775b7c",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "z_mesh = f(x_mesh, y_mesh).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f016549d",
   "metadata": {},
   "source": [
    "But there is one problem here: the mesh grids use a lot of memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1157ff53",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "x_mesh.nbytes + y_mesh.nbytes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5febe51",
   "metadata": {},
   "source": [
    "By comparison, the flat array `x` is just"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9294f362",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "x.nbytes  # and y is just a pointer to x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7227e32f",
   "metadata": {},
   "source": [
    "This extra memory usage can be a big problem in actual research calculations.\n",
    "\n",
    "So let’s try a different approach using [jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html)\n",
    "\n",
    "First we vectorize `f` in `y`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a526271",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "f_vec_y = jax.vmap(f, in_axes=(None, 0))  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f946dc99",
   "metadata": {},
   "source": [
    "In the line above, `in_axes=(None, 0)` tells `vmap` to hold the first argument, `x`, fixed and to vectorize over the second argument, `y`.\n",
    "\n",
    "Next, we vectorize in the first argument, which is `x`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64a3706c",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "f_vec = jax.vmap(f_vec_y, in_axes=(0, None))"
   ]
  },
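  {
   "cell_type": "markdown",
   "id": "4d9a0c57",
   "metadata": {},
   "source": [
    "To see how the two `in_axes` settings combine, here is a small sketch on a toy function.\n",
    "\n",
    "The names `g_demo`, `a_demo` and `b_demo` are ours and are not used elsewhere in the lecture.\n",
    "\n",
    "The toy function is deliberately not symmetric in its arguments, so we can check that entry `[i, j]` of the output is `g_demo(a_demo[i], b_demo[j])`.\n",
    "\n",
    "This is the layout produced by `jnp.meshgrid` with `indexing='ij'`; the default `indexing='xy'` gives the transpose, and the two coincide only in special cases such as a symmetric function evaluated on two identical grids."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e86b21f3",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def g_demo(a, b):\n",
    "    # Hypothetical scalar function, not symmetric in (a, b)\n",
    "    return a + 10 * b\n",
    "\n",
    "a_demo = jnp.array([0.0, 1.0, 2.0])\n",
    "b_demo = jnp.array([0.0, 1.0])\n",
    "\n",
    "# The inner vmap maps over b, the outer vmap maps over a\n",
    "g_demo_vec = jax.vmap(jax.vmap(g_demo, in_axes=(None, 0)), in_axes=(0, None))\n",
    "z_demo = g_demo_vec(a_demo, b_demo)\n",
    "\n",
    "print(z_demo.shape)   # (3, 2): rows follow a_demo, columns follow b_demo\n",
    "\n",
    "# Same values from an explicit mesh built with indexing='ij'\n",
    "a_mesh_demo, b_mesh_demo = jnp.meshgrid(a_demo, b_demo, indexing='ij')\n",
    "print(jnp.allclose(z_demo, g_demo(a_mesh_demo, b_mesh_demo)))   # True"
   ]
  },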
  {
   "cell_type": "markdown",
   "id": "3863a76e",
   "metadata": {},
   "source": [
    "With this construction, we can now call the function $ f $ on flat (low memory) arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80eb1d52",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "z_vmap = f_vec(x, y).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bac3635a",
   "metadata": {},
   "source": [
    "The execution time is essentially the same as for the mesh operation, but we are using much less memory.\n",
    "\n",
    "And we produce the same result as the mesh-based calculation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "857f6194",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "jnp.allclose(z_vmap, z_mesh)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15ee8b2c",
   "metadata": {},
   "source": [
    "## Exercises"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2b7d11b",
   "metadata": {},
   "source": [
    "## Exercise 19.1\n",
    "\n",
    "In the Exercise section of [a lecture on Numba and parallelization](https://python-programming.quantecon.org/parallelization.html), we used Monte Carlo to price a European call option.\n",
    "\n",
    "The code was accelerated by Numba-based multithreading.\n",
    "\n",
    "Try writing a version of this operation for JAX, using all the same\n",
    "parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e755581",
   "metadata": {},
   "source": [
    "## Solution to [Exercise 19.1](https://python-programming.quantecon.org/#jax_intro_ex2)\n",
    "\n",
    "Here is one solution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51f7d9f9",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "M = 10_000_000  # number of simulated paths\n",
    "\n",
    "n, β, K = 20, 0.99, 100\n",
    "μ, ρ, ν, S0, h0 = 0.0001, 0.1, 0.001, 10, 0\n",
    "\n",
    "@jax.jit\n",
    "def compute_call_price_jax(β=β,\n",
    "                           μ=μ,\n",
    "                           S0=S0,\n",
    "                           h0=h0,\n",
    "                           K=K,\n",
    "                           n=n,\n",
    "                           ρ=ρ,\n",
    "                           ν=ν,\n",
    "                           M=M,\n",
    "                           key=jax.random.PRNGKey(1)):\n",
    "\n",
    "    # Every argument takes its default value in the calls below, so M and n\n",
    "    # are ordinary Python integers when the function is traced.\n",
    "    s = jnp.full(M, jnp.log(S0))  # log price, one entry per path\n",
    "    h = jnp.full(M, h0)           # log volatility, one entry per path\n",
    "    for t in range(n):\n",
    "        key, subkey = jax.random.split(key)  # fresh randomness each period\n",
    "        Z = jax.random.normal(subkey, (2, M))\n",
    "        s = s + μ + jnp.exp(h) * Z[0, :]\n",
    "        h = ρ * h + ν * Z[1, :]\n",
    "    # Average the payoff max(S_n - K, 0) across paths\n",
    "    expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))\n",
    "\n",
    "    return β**n * expectation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b62adb10",
   "metadata": {},
   "source": [
    "Let’s run it once to compile it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10bfe352",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "compute_call_price_jax().block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7250121b",
   "metadata": {},
   "source": [
    "And now let’s time it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8d1a9d0",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "compute_call_price_jax().block_until_ready()"
   ]
  }
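  ,
  {
   "cell_type": "markdown",
   "id": "2fa7c9d4",
   "metadata": {},
   "source": [
    "The solution above uses a Python `for` loop, which `jax.jit` unrolls when it traces the function.\n",
    "\n",
    "As one possible variation, here is a sketch that runs the same update with `jax.lax.fori_loop`, which keeps the traced program compact.\n",
    "\n",
    "The function name `call_price_fori`, the smaller value of `M`, and the choice to mark `M` and `n` as static arguments are our assumptions rather than part of the original exercise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91c05eb8",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "@partial(jax.jit, static_argnames=('M', 'n'))\n",
    "def call_price_fori(key, M=100_000, n=20, β=0.99, K=100.0,\n",
    "                    μ=0.0001, ρ=0.1, ν=0.001, S0=10.0, h0=0.0):\n",
    "    # Hypothetical helper: same update rule as the solution above, with the\n",
    "    # time loop expressed through jax.lax.fori_loop.\n",
    "\n",
    "    def update(t, state):\n",
    "        s, h, key = state\n",
    "        key, subkey = jax.random.split(key)\n",
    "        Z = jax.random.normal(subkey, (2, M))\n",
    "        s = s + μ + jnp.exp(h) * Z[0, :]\n",
    "        h = ρ * h + ν * Z[1, :]\n",
    "        return s, h, key\n",
    "\n",
    "    # Adding to jnp.zeros(M) gives arrays in JAX's default float type, so the\n",
    "    # loop state keeps the same dtype from one iteration to the next.\n",
    "    s_init = jnp.log(S0) + jnp.zeros(M)\n",
    "    h_init = h0 + jnp.zeros(M)\n",
    "    s, h, key = jax.lax.fori_loop(0, n, update, (s_init, h_init, key))\n",
    "    return β**n * jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))\n",
    "\n",
    "\n",
    "# A smaller M than in the solution keeps this sketch quick to run\n",
    "price_fori = call_price_fori(jax.random.PRNGKey(1), M=100_000, n=20)\n",
    "print(price_fori)"
   ]
  }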
 ],
 "metadata": {
  "date": 1754011682.3286355,
  "filename": "jax_intro.md",
  "kernelspec": {
   "display_name": "Python",
   "language": "python3",
   "name": "python3"
  },
  "title": "An Introduction to JAX"
 },
 "nbformat": 4,
 "nbformat_minor": 5
}