Python Snippets

A collection of useful Python snippets.

uv script template

Run a standalone Python script whose dependencies are declared inline (PEP 723 script metadata) using uv.

#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.13"
# dependencies = [
#   "fire",
# ]
# ///
# pyright: reportMissingModuleSource=false
# pyright: reportMissingImports=false


import fire


def main(name: str) -> None:
    """Print a greeting for ``name``.

    Invoke it through uv, which provisions the dependencies declared in the
    inline script metadata at the top of the file::

        uv run script.py --name="You"

    Alternatively, make the file executable and rely on its shebang::

        chmod +x script.py
        ./script.py --name="You"
    """
    greeting = f"Hello {name}!"
    print(greeting)


if __name__ == "__main__":
    # python-fire turns command-line flags such as --name into arguments of main.
    fire.Fire(main)

Return types that depend on an argument (`typing.overload`)

Let a type checker infer the exact return type when the value of an argument (here `order`) changes the shape of the result.

from typing import Literal, overload


@overload
def fit_poly(order: Literal[2]) -> tuple[float, float]: ...


@overload
def fit_poly(order: Literal[3]) -> tuple[float, float, float]: ...


def fit_poly(order: Literal[2, 3]) -> tuple[float, float] | tuple[float, float, float]:
    if order == 2:
        return (1.0, 2.0)
    elif order == 3:
        return (1.0, 2.0, 3.0)
    else:
        raise ValueError("order must be 2 or 3")


if __name__ == "__main__":
    a, b = fit_poly(2)  # works: matches the Literal[2] overload
    a, b, c = fit_poly(3)  # works: matches the Literal[3] overload
    try:
        # A type checker rejects this line: the Literal[2] overload returns a
        # 2-tuple, which cannot be unpacked into three names. It fails at
        # runtime too, so the error is caught to keep the demo runnable.
        a, b, c = fit_poly(2)  # won't work
    except ValueError as exc:
        print(f"Unpacking fit_poly(2) into three names failed: {exc}")

Interactive matplotlib plot with a slider in Jupyter notebooks

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

# IPython magic: only valid inside a Jupyter/IPython cell, not in a plain .py
# file. NOTE(review): the "widget" backend presumably needs the ipympl
# package installed — confirm in the target environment.
%matplotlib widget
# Turn on pyplot interactive mode.
plt.ion()


def f(params, x):
    """Evaluate the demo curve ``a / (1 + exp(b * (x + c)) + d + e * x)``.

    ``params`` supplies ``a, b, c, d, e`` as its first five entries (indexed
    0..4). ``x`` may be a scalar or a NumPy array; the result follows
    NumPy broadcasting.
    """
    amplitude = params[0]
    rate, shift = params[1], params[2]
    offset, slope = params[3], params[4]
    exponential = np.exp(rate * (x + shift))
    # Offset and linear term live inside the denominator, which keeps the
    # plotted curves within the 0..1 y-range used below.
    return amplitude / (1 + exponential + offset + slope * x)


# Two parameter sets for f; the slider blends between them.
params0 = np.array([5, -3.0, -4.0, 1.0, 1.0])
params1 = np.array([4, -2.0, -7.0, 2.0, 2.0])
x = np.linspace(0, 10, 100)
y0 = f(params0, x)
y1 = f(params1, x)

fig, ax = plt.subplots()
# Reserve a strip at the bottom of the figure for the slider so it does not
# sit on top of the x tick labels and the "x" axis label.
fig.subplots_adjust(bottom=0.2)
alpha_slider = Slider(
    # Attach the slider axes explicitly to this figure (rect is in figure
    # coordinates: left, bottom, width, height).
    ax=fig.add_axes([0.2, 0.01, 0.65, 0.03]),
    label="alpha",
    valmin=0,
    valmax=1,
    valinit=0.5,
    valstep=0.01,
)
ax.plot(x, y0, label="params0")
ax.plot(x, y1, label="params1")


def update(val):
    """Redraw the plot for slider value ``val``.

    Registered with ``alpha_slider.on_changed``, which calls it with the
    slider's new value. ``val`` is the blend weight ``alpha``: alpha=1
    reproduces params0 / y0 and alpha=0 reproduces params1 / y1. Two
    interpolation strategies are drawn for comparison:

    * "interp function": blend the two curves pointwise.
    * "interp params": blend the parameter vectors, then evaluate ``f``.
    """
    alpha = val
    ax.clear()
    ax.plot(x, y0, label="params0")
    ax.plot(x, y1, label="params1")
    ax.plot(x, alpha * y0 + (1 - alpha) * y1, label="interp function")
    ax.plot(x, f(alpha * params0 + (1 - alpha) * params1, x), label="interp params")
    ax.set_xlabel("x")
    ax.set_ylabel("f(x)")
    ax.grid()
    ax.legend()
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 1)
    # Request a repaint explicitly instead of relying on pyplot's interactive
    # auto-redraw being active.
    ax.figure.canvas.draw_idle()


alpha_slider.on_changed(update)
# Render the interpolated curves, labels, grid, legend and limits immediately
# rather than only after the first slider move.
update(alpha_slider.val)

A container for constants

from dataclasses import dataclass
from enum import StrEnum, auto


class DISTANCE(StrEnum):
    KILOMETERS = auto()
    MILES = auto()


class WEIGHT(StrEnum):
    KILOGRAMS = auto()
    POUNDS = auto()


@dataclass(frozen=True)
class CONSTANTS:
    """Namespace that groups the unit enums under one name.

    Members are reached as ``CONSTANTS.DISTANCE.KILOMETERS`` etc. The body
    has no annotated fields, so the dataclass has no fields. ``frozen=True``
    only blocks attribute assignment on *instances*; it does not prevent
    ``CONSTANTS.DISTANCE = ...`` from rebinding the class attribute.
    """

    DISTANCE = DISTANCE  # the module-level DISTANCE enum, exposed as a class attribute
    WEIGHT = WEIGHT  # the module-level WEIGHT enum, exposed as a class attribute


def dummy(x: CONSTANTS.DISTANCE) -> CONSTANTS.WEIGHT:
    if x == CONSTANTS.DISTANCE.KILOMETERS:
        return CONSTANTS.WEIGHT.KILOGRAMS
    elif x == CONSTANTS.DISTANCE.MILES:
        return CONSTANTS.WEIGHT.POUNDS
    else:
        raise ValueError("Invalid distance type")

Boilerplate for fitting model parameters with `scipy.optimize.minimize`

from typing import Callable

import numpy as np
import pandas as pd
import polars as pl
from scipy.optimize import minimize


def ansatz(
    params: np.ndarray,
    df: pl.DataFrame,
    model_column: str,
) -> pl.DataFrame:
    """Evaluate the model and attach its prediction as a column.

    The model is ``params[0] * test``, where ``test`` is the column of that
    name in ``df``; only the first parameter is used.

    Args:
        params: Model parameters proposed by the optimizer.
        df: Input frame; must contain a ``"test"`` column.
        model_column: Name of the column that receives the prediction.

    Returns:
        ``df`` with the prediction added (or replaced) under ``model_column``.
    """
    prediction = params[0] * pl.col("test")
    return df.with_columns(prediction.alias(model_column))


def loss(
    params: np.ndarray,
    data: pl.DataFrame,
    model_column: str,
    reference_column: str,
    metric: Callable[[np.ndarray, np.ndarray], float],
) -> float:
    """Evaluate the model for ``params`` and score it against the reference.

    Args:
        params: Model parameters proposed by the optimizer.
        data: Frame holding the model inputs and the reference column.
        model_column: Column name under which ``ansatz`` stores its output.
        reference_column: Column holding the target values.
        metric: Callable ``(prediction, reference) -> float`` to minimize.

    Returns:
        The metric value as a plain ``float`` (the scalar objective
        ``scipy.optimize.minimize`` expects).
    """
    model = ansatz(params, data, model_column)
    # Hand the metric NumPy arrays, as its declared signature promises,
    # rather than polars Series.
    prediction = model[model_column].to_numpy()
    reference = data[reference_column].to_numpy()
    return float(metric(prediction, reference))


def metric(x: np.ndarray, y: np.ndarray) -> float:
    """Return the L2 norm of the elementwise difference ``x - y``.

    For 1-D inputs this is the Euclidean distance between the two arrays.
    """
    residual = x - y
    return np.linalg.norm(residual)


data = pl.DataFrame({"test": [1.0, 2, 3], "reference": [2.0, 3, 4]})
# One starting value per parameter that ansatz actually reads (only
# params[0]); an extra entry would be a dimension the loss does not depend on.
x0 = np.array([1.0])
res = minimize(loss, x0, args=(data, "model", "reference", metric))