13. Multi-Language Programming
13.1. Cython
13.1.1. What is Cython?
Cython (/ˈsaɪθɒn/) is a superset of the programming language Python, which allows developers to write Python code (with optional, C-inspired syntax extensions) that yields performance comparable to that of C. Cython works by producing a standard Python module. However, the behavior differs from standard Python in that the module code, originally written in Python, is translated into C. (Wikipedia)
Many widely used scientific Python packages include Cython code, for instance SciPy, NumPy, pandas and scikit-learn.
A typical Cython workflow involves one or more of the following:
Convert Python code into Cython code
Import C code into Cython code together with Python code
Write Cython code inside Python code
Converting code to Cython is called cythonizing it. In some cases, this can be done automatically with a single command: for example, cythonize -i module.py compiles a Python module into an extension module in place.
To use Cython, you first need to make sure it is installed. To install it, do:
[2]:
!pip install Cython
Collecting Cython
Using cached cython-3.2.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (3.6 kB)
Using cached cython-3.2.0-cp312-cp312-macosx_11_0_arm64.whl (3.0 MB)
Installing collected packages: Cython
Successfully installed Cython-3.2.0
13.1.2. Cython magic
In a Jupyter notebook, we can use the %%cython cell magic to write Cython code; the cell is compiled when it runs.
Beforehand, we load the extension:
[3]:
%load_ext Cython
Let us look at a simple example with a nested loop.
First, in Python:
[4]:
def python_nested_loop(matrix):
    total = 0
    for i in range(len(matrix)):
        for j in range(len(matrix[0])):
            total += matrix[i][j]
    return total
[5]:
import numpy as np
# Create a 2D matrix with size 1000x1000
matrix = [[i * j for j in range(1000)] for i in range(1000)]
matrix_np = np.array(matrix, dtype=np.int32)
[6]:
%timeit -r 10 -n 10 python_nested_loop(matrix)
34.4 ms ± 1.17 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
Now in Cython:
[7]:
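%%cython
import numpy as np
cimport numpy as np
# Cython function to sum elements of a 2D matrix
def cython_nested_loop(np.ndarray[np.int32_t, ndim=2] matrix):
    # 64-bit accumulator: the total for the 1000x1000 matrix (about 2.5e11)
    # does not fit in a 32-bit C int and would overflow
    cdef long long total = 0
    # Py_ssize_t is the C type Cython recommends for array indices
    cdef Py_ssize_t i, j
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            total += matrix[i, j]
    return total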
[8]:
# Use timeit with -r and -n for the Cython implementation
%timeit -r 10 -n 10 cython_nested_loop(matrix_np)
951 μs ± 296 μs per loop (mean ± std. dev. of 10 runs, 10 loops each)
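The Cython implementation is many times faster than the Python implementation: with the timings above (measured on a MacBook Pro M1), roughly 36x. Note that the two functions receive different inputs, a list of lists for Python and a NumPy array for Cython.
The Cython nested loop is faster for several reasons related to how Python and Cython handle variable types and execution: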
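Speed only matters if the answers agree, so a quick correctness check is worthwhile. Because entry (i, j) is i * j, the total is (0 + 1 + ... + 999)² = 499500² ≈ 2.5 × 10¹¹, far above the largest 32-bit signed integer (2³¹ − 1 ≈ 2.1 × 10⁹). An accumulator declared as a 32-bit C int (cdef int) would therefore overflow on this matrix, so the Cython accumulator needs a 64-bit type such as long long. The sketch below, using only NumPy and the Python function from above, reproduces that wraparound with an np.int32 reduction:

```python
import numpy as np

# Same 1000x1000 test matrix as above: entry (i, j) equals i * j
n = 1000
matrix = [[i * j for j in range(n)] for i in range(n)]
matrix_np = np.array(matrix, dtype=np.int32)

def python_nested_loop(matrix):
    total = 0
    for i in range(len(matrix)):
        for j in range(len(matrix[0])):
            total += matrix[i][j]
    return total

# Closed form: sum_i sum_j i*j = (sum_i i) * (sum_j j) = (n(n-1)/2)**2
expected = (n * (n - 1) // 2) ** 2
exact = python_nested_loop(matrix)           # Python ints never overflow
wide = int(matrix_np.sum(dtype=np.int64))    # 64-bit accumulator
narrow = int(matrix_np.sum(dtype=np.int32))  # 32-bit accumulator wraps around

print(expected, exact, wide, narrow)
print("fits in int32:", expected <= np.iinfo(np.int32).max)
```

The 64-bit sum matches the exact Python result, while the 32-bit one differs from it by a multiple of 2³².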
Static Typing in Cython:
Python: Variables are dynamically typed, so every operation (e.g., total += matrix[i][j]) requires:
Checking the types of total, matrix[i], and matrix[i][j].
Dynamically dispatching the appropriate operation based on those types.
Cython: Variables declared with cdef (here total, i and j) are statically typed, meaning:
The types are known at compile time.
There’s no need for type checking or dynamic dispatch during execution.
The addition operation (total += matrix[i, j]) is compiled into efficient, low-level machine code.
Avoiding Python Overhead:
Python: Every iteration of the loop involves Python’s internal overhead:
Reference counting and memory management for objects.
Function calls for accessing and operating on the matrix elements.
Cython: The loops and operations are converted into C-level loops with no per-iteration Python overhead. Accessing matrix[i, j] translates directly into efficient pointer arithmetic.
Efficient Array Handling with NumPy in Cython:
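Python: Even with NumPy, the indexing matrix[i][j] or matrix[i, j] involves:
A Python function call to NumPy’s internal methods.
Boundary checks and type checks at runtime.
Cython: Using a typed np.ndarray buffer with Cython avoids these Python-level calls:
Indexing operations (matrix[i, j]) are performed directly in C, bypassing the Python interpreter.
The cimported numpy module and the static type declaration (e.g., np.ndarray[np.int32_t, ndim=2]) eliminate the per-element type checks. Bounds checking and negative-index wraparound are still performed by default; after cimport cython, they can be disabled with the @cython.boundscheck(False) and @cython.wraparound(False) decorators when the indices are known to be valid.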
Compiler Optimizations:
Cython translates the code into C, which the C compiler then optimizes, enabling:
Loop unrolling: The compiler may unroll repetitive loop structures to reduce loop-control overhead.
CPU-specific instructions: For example, SIMD operations when possible.
Python’s interpreter cannot perform these low-level optimizations.
(Note: SIMD stands for Single Instruction, Multiple Data, a parallel computing paradigm used in processors to perform the same operation on multiple pieces of data simultaneously.)
Reduced Function Call Overhead:
In Python, every loop iteration may implicitly call functions to:
Fetch the next index (range(len(matrix)) involves dynamic object iteration).
Access elements (matrix[i][j] involves Python’s __getitem__ method).
Cython eliminates these function calls by working directly with compiled C loops and array pointers.
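The dynamic dispatch and implicit calls described above can be observed with the standard library's dis module, which lists the bytecode CPython runs for python_nested_loop. The instructions for iteration (FOR_ITER), subscripting and addition (BINARY_OP on recent CPython versions, INPLACE_ADD on older ones) carry no type information, so the operand types are checked when they execute. Exact opcode names vary between Python versions, so treat the printed list as illustrative:

```python
import dis

def python_nested_loop(matrix):
    total = 0
    for i in range(len(matrix)):
        for j in range(len(matrix[0])):
            total += matrix[i][j]
    return total

# Generic bytecode instructions for iteration, subscripts and the addition.
# None of them encodes operand types, so CPython checks the types of the
# objects involved at run time.
ops = [ins.opname for ins in dis.get_instructions(python_nested_loop)]
print([op for op in ops if op.startswith(("BINARY", "INPLACE", "FOR_ITER"))])

# The same `+=` dispatches to a different implementation depending on the
# runtime type of the right-hand operand
for value in (1, 1.5, 2 + 3j):
    acc = 0
    acc += value
    print(type(value).__name__, "->", type(acc).__name__)
```

A cdef-typed Cython loop has no such per-instruction dispatch: the types are fixed when the C code is generated.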
For a 1000 x 1000 matrix:
Python: The loop body runs one million times, and each iteration involves several type checks, bounds checks and method calls. In CPython, the standard Python interpreter, this adds up to significant time.
Cython: Executes the same nested loop in plain C, with just raw addition operations and pointer arithmetic.
This results in a speedup of more than an order of magnitude with Cython.
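Dividing the %timeit results by the number of elements makes this concrete. The values below are copied from the runs above on a MacBook Pro M1 and will differ on other machines:

```python
# Timings reported by %timeit above (MacBook Pro M1); other machines will differ
python_time = 34.4e-3   # seconds per call, pure Python loop
cython_time = 951e-6    # seconds per call, Cython loop
elements = 1000 * 1000  # one loop iteration per matrix element

print(f"Python: {python_time / elements * 1e9:.1f} ns per element")
print(f"Cython: {cython_time / elements * 1e9:.2f} ns per element")
print(f"Speedup: {python_time / cython_time:.0f}x")
```

This prints about 34.4 ns versus 0.95 ns per element, a speedup of about 36x.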
When Cython Does Not Help:
If the task involves very high-level operations (e.g., NumPy’s built-in vectorized functions like np.sum()), the advantage of using Cython is reduced because NumPy already performs these operations in optimized C.
However, for nested loops and tasks requiring manual element-by-element iteration over arrays (e.g., matrices, vectors) that cannot easily be expressed with vectorized NumPy operations, Cython can be much faster than pure Python.
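The point about vectorized NumPy functions can be checked directly. The sketch below loops element by element over a NumPy array from pure Python and compares it with np.sum, which already runs a compiled loop; loop_sum is a hypothetical name used only here, and the printed timings are machine-dependent:

```python
import timeit
import numpy as np

matrix_np = np.array([[i * j for j in range(1000)] for i in range(1000)], dtype=np.int32)

def loop_sum(arr):
    # Element-by-element loop from Python: each arr[i, j] creates a NumPy
    # scalar object. int() keeps the accumulator a Python int; under NumPy 2
    # promotion rules, adding int32 scalars directly would keep a 32-bit type.
    total = 0
    for i in range(arr.shape[0]):
        for j in range(arr.shape[1]):
            total += int(arr[i, j])
    return total

# Vectorized reduction with a 64-bit accumulator, executed in compiled code
vectorized = int(np.sum(matrix_np, dtype=np.int64))
assert loop_sum(matrix_np) == vectorized

t_loop = timeit.timeit(lambda: loop_sum(matrix_np), number=1)
t_numpy = timeit.timeit(lambda: np.sum(matrix_np, dtype=np.int64), number=10) / 10
print(f"Python loop: {t_loop * 1e3:.1f} ms, np.sum: {t_numpy * 1e3:.3f} ms")
```

Since np.sum already runs at compiled speed, rewriting this particular reduction in Cython would gain little.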
13.1.3. Cythonizing pure Python
It is easy to convert a pure Python package into Cython code. Let us give a working example with our pygbm package seen in the example class.
The core code consists of the following files:
pyproject.toml
pygbm
├── __init__.py
├── base_pygbm.py
└── gbm_simulation.py
To cythonize the package, follow these four steps:
Move the Python files into a src/pygbm_x folder.
Rename the .py files to .pyx files.
Modify the pyproject.toml file:
[build-system]
requires = ["setuptools", "Cython", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "pygbm_x"
version = "0.0.1b2" # or whatever you want
description = "A package"
[tool.setuptools]
package-dir = {"" = "src"}
Add a setup.py file next to the pyproject.toml file, containing the following:
from setuptools import setup, Extension
from Cython.Build import cythonize
# Define the extensions (Cython modules)
extensions = [
    Extension("pygbm_x.base_pygbm", ["src/pygbm_x/base_pygbm.pyx"]),
    Extension("pygbm_x.gbm_simulation", ["src/pygbm_x/gbm_simulation.pyx"]),
]
# Call setup with cythonized extensions
setup(
    ext_modules=cythonize(extensions,
                          compiler_directives={'language_level': "3"}),
    package_dir={"": "src"},
    packages=["pygbm_x"],
    # Include only .so/.pyd files (compiled extensions), exclude source files
    package_data={"pygbm_x": ["*.so", "*.pyd"]},
    exclude_package_data={"pygbm_x": ["*.pyx", "*.py"]},
    # Ensure that wheels can be built
    zip_safe=False,
)
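Steps 1 and 2 can also be scripted. The helper below is a hypothetical sketch using pathlib and shutil: it copies the modules of the original pygbm folder into src/pygbm_x/ and renames them to .pyx, leaving the original folder untouched. It assumes that __init__.py stays plain Python, since the setup.py above lists only base_pygbm and gbm_simulation as extensions:

```python
import shutil
from pathlib import Path

def move_to_src_pyx(pkg_dir, new_name="pygbm_x", keep_py=("__init__.py",)):
    """Hypothetical helper for steps 1 and 2.

    Copies <pkg_dir>/*.py into src/<new_name>/ (next to pkg_dir), renaming
    modules to .pyx. Files named in keep_py are copied unchanged (assumption:
    __init__.py is not compiled, because setup.py does not list it).
    """
    pkg_dir = Path(pkg_dir)
    target = pkg_dir.parent / "src" / new_name
    target.mkdir(parents=True, exist_ok=True)
    for py_file in sorted(pkg_dir.glob("*.py")):
        suffix = ".py" if py_file.name in keep_py else ".pyx"
        shutil.copy2(py_file, target / (py_file.stem + suffix))
    return sorted(p.name for p in target.iterdir())

# Usage from the project root (copies, so the original pygbm/ folder is kept):
# move_to_src_pyx("pygbm")
```

Copying rather than moving makes it easy to compare against the pure Python version until the compiled build works.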
To build the package we then do:
python setup.py build_ext --inplace
This translates the .pyx files into C, compiles them, and places the resulting .c and .so files (.pyd on Windows) in the same folder as the .pyx files. The .so files are the compiled extension modules (machine code) that can be imported in Python.
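To confirm that Python imports the compiled extension rather than a source file, one option is to compare the module's __file__ with the extension-module suffixes the interpreter accepts. The is_compiled helper below is a hypothetical sketch using only the standard library; the EXT_SUFFIX it prints (for example .cpython-312-darwin.so) also shows why compiled files are tied to one Python version and platform:

```python
import importlib.machinery
import sysconfig

def is_compiled(module):
    """Hypothetical helper: True if the module was loaded from an extension file."""
    path = getattr(module, "__file__", None) or ""
    return path.endswith(tuple(importlib.machinery.EXTENSION_SUFFIXES))

# Suffix given to extension modules built for this interpreter
print(sysconfig.get_config_var("EXT_SUFFIX"))

# Usage after building (assumes the pygbm_x package from this section):
# import pygbm_x.base_pygbm
# print(is_compiled(pygbm_x.base_pygbm))
```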
We then create the wheel with:
python setup.py bdist_wheel
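which creates a dist folder containing the wheel (running setup.py directly is deprecated in recent setuptools versions; python -m build --wheel, from the build package, is the recommended equivalent). Unlike a pure-Python wheel, whose name ends in py3-none-any, this wheel has the specific Python version and the platform in its name. It looks like: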
pygbm_x-0.0.1b2-cp39-cp39-macosx_11_0_arm64.whl
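The name follows the pattern {name}-{version}(-{build})?-{python tag}-{abi tag}-{platform tag}.whl from the wheel binary distribution format. The wheel_tags function below is a simplified sketch (no validation) that splits a filename into these fields, contrasting the compiled wheel with a hypothetical pure-Python one; for real use, the packaging library provides packaging.utils.parse_wheel_filename:

```python
def wheel_tags(filename):
    """Split a wheel filename into its fields (simplified sketch, no validation)."""
    if not filename.endswith(".whl"):
        raise ValueError("not a wheel filename")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) not in (5, 6):  # 6 fields when an optional build tag is present
        raise ValueError("unexpected number of fields")
    name, version = parts[0], parts[1]
    python_tag, abi_tag, platform_tag = parts[-3:]
    return {"name": name, "version": version, "python": python_tag,
            "abi": abi_tag, "platform": platform_tag}

print(wheel_tags("pygbm_x-0.0.1b2-cp39-cp39-macosx_11_0_arm64.whl"))
print(wheel_tags("pygbm-0.0.1-py3-none-any.whl"))  # hypothetical pure-Python wheel
```

The cp39 tags restrict the compiled wheel to CPython 3.9, and macosx_11_0_arm64 to Apple Silicon macOS 11 or later, whereas py3-none-any installs anywhere.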
The wheel can then be installed with:
pip install dist/pygbm_x-0.0.1b2-cp39-cp39-macosx_11_0_arm64.whl
And we can test the package.
[1]:
import pygbm_x as pg
[2]:
simulator = pg.GBMSimulator(y0=1.0, mu=0.05, sigma=0.2)
t_values, y_values = simulator.simulate_path(T=1.0, N=100)
simulator.plot_path(t_values, y_values)
13.1.4. Creating wheels for multiple platforms
Because the compiled wheel is specific to one Python version and platform, this is a case where it becomes relevant to build wheels for multiple platforms. We use cibuildwheel for this; it runs on macOS, Linux and Windows.
On macOS and Windows, it can use Docker containers to build Linux wheels.
You can install it with:
pip install cibuildwheel
To build the wheel, run the following from the root of the package:
CIBW_BUILD="cp311-manylinux_x86_64" CIBW_ARCHS="x86_64" cibuildwheel --platform linux
where you specify the Python version and the architecture you want to build for. Note that for this to work on macOS or Windows, Docker must be running in the background.
This is an important part of continuous integration, and this step would typically be run in a CI/CD pipeline (e.g., on GitHub Actions); see here for more details.
On an Apple Silicon Mac (ARM64), you can still build x86_64 wheels by running your terminal under Rosetta. To do so, open the Applications/Utilities folder, select the Terminal application and choose Get Info. In the window that opens, check “Open using Rosetta”. You can then run the cibuildwheel command in that terminal.
13.1.5. Turning C/C++ into Python
Cython can also be used to turn C/C++ libraries into Python packages. We won’t cover an example of this here, but you can find more information here.
In the jargon, we say we are wrapping the C/C++ code for Python. The C/C++ functions that we want to wrap are declared in a Cython declaration file (.pxd extension), typically with cdef extern from blocks that mirror the C header, and the Python-callable wrapper functions that call them are defined in a Cython source file (.pyx extension).
This is extremely useful because it means that you can use Python syntax to call C/C++ functions (generally much more optimised) without the overhead of calling a Python function.
Notable examples are codes that rely on OpenMP parallelisation or that use BLAS/LAPACK libraries.
However, this procedure does not allow you to perform automatic differentiation easily.
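Cython wrappers need a C compiler and a build step, so they are not shown here. As a lighter illustration of the underlying idea, declaring a C function's signature so that Python can call the compiled code, the sketch below uses the standard library's ctypes module (a different wrapping mechanism from Cython) to call cos from the system C maths library. It assumes a Unix-like system where ctypes.util.find_library("m") can locate that library:

```python
import ctypes
import ctypes.util
import math

# Locate the C maths library (e.g. "libm.so.6" on Linux); assumes a Unix-like OS
libm = ctypes.CDLL(ctypes.util.find_library("m"))

# Declare the C signature double cos(double), much as a .pxd file would
libm.cos.argtypes = [ctypes.c_double]
libm.cos.restype = ctypes.c_double

print(libm.cos(0.5), math.cos(0.5))
```

Unlike a Cython wrapper, every ctypes call still goes through a Python-level call with argument conversion, so it does not remove the call overhead discussed above.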
13.2. Differentiable Programming
Writing differentiable code is crucial for machine learning pipelines, as it allows efficient gradient computations that are exact up to floating-point precision. This technique is called automatic differentiation; it works by applying the chain rule to each elementary operation in the program.
PyTorch, TensorFlow and JAX provide differentiable programming frameworks in Python, i.e., they can be imported in Python code. Julia is a separate language whose ecosystem also supports automatic differentiation (here through the Zygote package).
We will look at the same simple example in all four cases.
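The chain-rule mechanism can be sketched without any framework using dual numbers (forward-mode automatic differentiation): every value carries its derivative, and each elementary operation updates that derivative with its own differentiation rule. The Dual class and dsin function below are hypothetical illustrations for f(a, x) = sin(ax)/x, the function used in this section's examples; the gradient functions used in the framework examples (torch.autograd.grad, tf.GradientTape, jax.grad, Zygote's gradient) instead work in reverse mode and are far more general:

```python
import math
from dataclasses import dataclass

@dataclass
class Dual:
    """Forward-mode AD value: val + der * eps, with eps**2 = 0."""
    val: float
    der: float = 0.0

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __rmul__ = __mul__

    def __truediv__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Quotient rule: (u/v)' = (u'v - uv') / v**2
        return Dual(self.val / other.val,
                    (self.der * other.val - self.val * other.der) / other.val ** 2)

def dsin(z):
    # Chain rule for an elementary function: sin(u)' = cos(u) * u'
    return Dual(math.sin(z.val), math.cos(z.val) * z.der)

def f(a, x):
    return dsin(a * x) / x

a, x = 0.5, 3.0
# Seed der = 1 on the input we differentiate with respect to
df_da = f(Dual(a, 1.0), Dual(x)).der
df_dx = f(Dual(a), Dual(x, 1.0)).der
print(df_da, df_dx)
```

Each derivative is exact up to floating-point rounding, unlike a finite-difference approximation.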
Consider the function:
\[ f(a, x) = \frac{\sin(a x)}{x}. \]
For \(a=0.5\) and \(x \in [-50, 50]\), this function looks like:
[9]:
import numpy as np
import matplotlib.pyplot as plt
# Create x values
x = np.linspace(-50, 50, 1000)
# Set a
a = 0.5
# Calculate f(x)
f = np.sin(a * x) / x
plt.figure()
plt.plot(x, f)
plt.grid(True)
plt.title(f'f(x) = sin({a}x)/x')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.show()
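For reference, the derivatives that the frameworks below compute automatically can also be obtained by hand, \(\partial f/\partial a\) with the chain rule and \(\partial f/\partial x\) with the quotient rule. They are useful for checking the plots:

```latex
\frac{\partial f}{\partial a} = \cos(a x),
\qquad
\frac{\partial f}{\partial x} = \frac{a x \cos(a x) - \sin(a x)}{x^{2}}
```

For large \(|x|\), \(\partial f/\partial x\) behaves like \(a\cos(ax)/x\), so its oscillations decay as \(1/|x|\), while \(\partial f/\partial a\) keeps unit amplitude.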
13.2.1. PyTorch
[10]:
import torch
import matplotlib.pyplot as plt
import numpy as np
# Define the function
def f(a, x):
    return torch.sin(a * x) / x
# Create grid for a and x
a = torch.linspace(-1, 1, 100, requires_grad=True)
x = torch.linspace(-50, 50, 200, requires_grad=True)
A, X = torch.meshgrid(a, x, indexing='ij')
# Compute the function
F = f(A, X)
# Compute derivatives
grad_a, grad_x = torch.autograd.grad(
F.sum(), [A, X], create_graph=True # see https://pytorch.org/blog/computational-graphs-constructed-in-pytorch/
)
# Plot the derivatives
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.title("df/da (PyTorch)")
# origin='lower' puts row 0 (a = -1) at the bottom, matching extent
plt.imshow(grad_a.detach().numpy(), extent=[-50, 50, -1, 1], aspect='auto', origin='lower')
plt.colorbar()
plt.subplot(1, 2, 2)
plt.title("df/dx (PyTorch)")
plt.imshow(grad_x.detach().numpy(), extent=[-50, 50, -1, 1], aspect='auto', origin='lower')
plt.colorbar()
plt.show()
13.2.2. TensorFlow
[12]:
!pip install tensorflow
Successfully installed absl-py-2.3.1 astunparse-1.6.3 flatbuffers-25.9.23 gast-0.6.0 google_pasta-0.2.0 grpcio-1.76.0 h5py-3.15.1 keras-3.12.0 libclang-18.1.1 markdown-3.10 markdown-it-py-4.0.0 mdurl-0.1.2 ml_dtypes-0.5.3 namex-0.1.0 opt_einsum-3.4.0 optree-0.17.0 rich-14.2.0 tensorboard-2.20.0 tensorboard-data-server-0.7.2 tensorflow-2.20.0 termcolor-3.2.0 werkzeug-3.1.3 wheel-0.45.1 wrapt-2.0.1
[13]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Define the function
@tf.function
def f(a, x):
    return tf.sin(a * x) / x
# Create grid for a and x
a = tf.linspace(-1.0, 1.0, 100)
x = tf.linspace(-50.0, 50.0, 200)
A, X = tf.meshgrid(a, x, indexing='ij')
with tf.GradientTape() as tape_a, tf.GradientTape() as tape_x:
    tape_a.watch(A)
    tape_x.watch(X)
    F = f(A, X)
grad_a = tape_a.gradient(F, A)
grad_x = tape_x.gradient(F, X)
# Plot the derivatives
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.title("df/da (TensorFlow)")
# origin='lower' puts row 0 (a = -1) at the bottom, matching extent
plt.imshow(grad_a.numpy(), extent=[-50, 50, -1, 1], aspect='auto', origin='lower')
plt.colorbar()
plt.subplot(1, 2, 2)
plt.title("df/dx (TensorFlow)")
plt.imshow(grad_x.numpy(), extent=[-50, 50, -1, 1], aspect='auto', origin='lower')
plt.colorbar()
plt.show()
13.2.3. Jax
[15]:
!pip install jax
Successfully installed jax-0.8.0 jaxlib-0.8.0
[16]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
# Define the function
def f(a, x):
    return jnp.sin(a * x) / x
# Create grid for a and x
a = jnp.linspace(-1.0, 1.0, 100)
x = jnp.linspace(-50.0, 50.0, 200)
A, X = jnp.meshgrid(a, x, indexing='ij')
# Compute derivatives
df_da = jax.grad(lambda a, x: f(a, x).sum(), argnums=0)(A, X)
df_dx = jax.grad(lambda a, x: f(a, x).sum(), argnums=1)(A, X)
# Plot the derivatives
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.title("df/da (JAX)")
# origin='lower' puts row 0 (a = -1) at the bottom, matching extent
plt.imshow(df_da, extent=[-50, 50, -1, 1], aspect='auto', origin='lower')
plt.colorbar()
plt.subplot(1, 2, 2)
plt.title("df/dx (JAX)")
plt.imshow(df_dx, extent=[-50, 50, -1, 1], aspect='auto', origin='lower')
plt.colorbar()
plt.show()
13.2.4. Julia
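Change the notebook kernel to a Julia kernel (this requires the IJulia package to be installed beforehand, for example from the Julia REPL with using Pkg; Pkg.add("IJulia")). Then execute the following.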
[1]:
println("Hello")
Hello
[2]:
using Pkg
Pkg.add("IJulia")
Pkg.add("Zygote")
Pkg.add("Plots")
Updating registry at `~/.julia/registries/General.toml`
Resolving package versions...
No Changes to `~/.julia/environments/v1.11/Project.toml`
No Changes to `~/.julia/environments/v1.11/Manifest.toml`
Resolving package versions...
No Changes to `~/.julia/environments/v1.11/Project.toml`
No Changes to `~/.julia/environments/v1.11/Manifest.toml`
Resolving package versions...
No Changes to `~/.julia/environments/v1.11/Project.toml`
No Changes to `~/.julia/environments/v1.11/Manifest.toml`
[3]:
using Zygote, Plots
# Define the function
f(a, x) = x == 0 ? a : sin(a * x) / x # Handle division by zero
# Create grid for a and x
a = LinRange(-1, 1, 100) # 100 points between -1 and 1
x = LinRange(-50, 50, 200) # 200 points between -50 and 50
# Compute derivatives
df_da = [gradient(t -> f(t, xi), ai)[1] for ai in a, xi in x] # Gradient w.r.t. `a` at each (a, x)
df_dx = [gradient(t -> f(ai, t), xi)[1] for ai in a, xi in x] # Gradient w.r.t. `x` at each (a, x)
# Plot the derivatives
heatmap(x, a, df_da, title="df/da (Julia)", xlabel="x", ylabel="a", colorbar=true)
[4]:
heatmap(x, a, df_dx, title="df/dx (Julia)", xlabel="x", ylabel="a", colorbar=true)