Python Snippets
A collection of useful Python snippets.
uv script template
Run a Python file with its dependencies using uv.
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.13"
# dependencies = [
# "fire",
# ]
# ///
# pyright: reportMissingModuleSource=false
# pyright: reportMissingImports=false
import fire
def main(name: str) -> None:
    """Print a greeting for ``name``.

    Run this script with the command:

        `uv run script.py --name="You"`

    Or:

        `chmod +x script.py`
        `./script.py --name="You"`

    Args:
        name: Name inserted into the greeting.
    """
    greeting = f"Hello {name}!"
    print(greeting)
if __name__ == "__main__":
    # Expose `main` as the command line entry point, so that `--name=...`
    # on the command line is passed as the `name` argument.
    fire.Fire(main)
overload functionality
Make type hints work when toggling a flag that changes the return type.
from typing import Literal, overload
@overload
def fit_poly(order: Literal[2]) -> tuple[float, float]: ...
@overload
def fit_poly(order: Literal[3]) -> tuple[float, float, float]: ...
def fit_poly(order: Literal[2, 3]) -> tuple[float, float] | tuple[float, float, float]:
if order == 2:
return (1.0, 2.0)
elif order == 3:
return (1.0, 2.0, 3.0)
else:
raise ValueError("order must be 2 or 3")
if __name__ == "__main__":
    # The Literal[2] overload promises a 2-tuple, so this unpacking type-checks.
    a, b = fit_poly(2)  # works
    # The Literal[3] overload promises a 3-tuple, so this unpacking type-checks.
    a, b, c = fit_poly(3)  # works
    # A 2-tuple cannot be unpacked into three names.
    a, b, c = fit_poly(2)  # won't work
Interactive matplotlib animation in notebooks
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
# IPython/Jupyter line magic (not plain Python): run this in a notebook cell
# to select the "widget" matplotlib backend.
%matplotlib widget
# Switch matplotlib to interactive mode.
plt.ion()
def f(params, x):
    """Evaluate the demo curve ``a / (1 + exp(b * (x + c)) + d + e * x)``.

    Args:
        params: Indexable holding ``a, b, c, d, e`` at positions 0 to 4.
        x: Point(s) at which to evaluate the curve.

    Returns:
        The curve value(s) at ``x``.
    """
    a, b, c, d, e = (params[i] for i in range(5))
    denominator = 1 + np.exp(b * (x + c)) + d + e * x
    return a / denominator
# Two parameter sets to blend between, and the grid they are evaluated on.
params0 = np.array([5, -3.0, -4.0, 1.0, 1.0])
params1 = np.array([4, -2.0, -7.0, 2.0, 2.0])
x = np.linspace(0, 10, num=100)
y0, y1 = f(params0, x), f(params1, x)

fig, ax = plt.subplots()

# Slider for the blend weight, placed on its own thin axes.
alpha_slider = Slider(
    ax=plt.axes([0.2, 0.01, 0.65, 0.03]),
    label="alpha",
    valinit=0.5,
    valmin=0,
    valmax=1,
    valstep=0.01,
)

# Draw the two reference curves once, before any slider callback has run.
ax.plot(x, y0, label="params0")
ax.plot(x, y1, label="params1")
def update(val):
    """Redraw the axes for the slider's current blend weight.

    Plots both reference curves, the blend of the two curves, and the curve
    obtained from the blended parameters.

    Args:
        val: Value passed in by the slider callback. It is not used; the
            weight is read from ``alpha_slider.val``.
    """
    alpha = alpha_slider.val
    ax.clear()
    curves = (
        (y0, "params0"),
        (y1, "params1"),
        # Blend the function values ...
        (alpha * y0 + (1 - alpha) * y1, "interp function"),
        # ... versus blending the parameters and evaluating once.
        (f(alpha * params0 + (1 - alpha) * params1, x), "interp params"),
    )
    for values, name in curves:
        ax.plot(x, values, label=name)
    # `ax.clear()` wiped the labels, grid, legend and limits, so set them again.
    ax.set_xlabel("x")
    ax.set_ylabel("f(x)")
    ax.grid()
    ax.legend()
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 1)
# Register `update` as the callback for slider value changes.
alpha_slider.on_changed(update)
A container for constants
from dataclasses import dataclass
from enum import StrEnum, auto
class DISTANCE(StrEnum):
KILOMETERS = auto()
MILES = auto()
class WEIGHT(StrEnum):
KILOGRAMS = auto()
POUNDS = auto()
@dataclass(frozen=True)
class CONSTANTS:
    """Namespace that groups the constant enums under a single name."""

    # Plain class attributes without annotations, so the dataclass has no
    # fields; the enums are reached as CONSTANTS.DISTANCE / CONSTANTS.WEIGHT.
    DISTANCE = DISTANCE
    WEIGHT = WEIGHT
def dummy(x: CONSTANTS.DISTANCE) -> CONSTANTS.WEIGHT:
    """Map a distance unit to the weight unit it is paired with.

    Args:
        x: Distance unit to look up.

    Returns:
        ``KILOGRAMS`` for ``KILOMETERS`` and ``POUNDS`` for ``MILES``.

    Raises:
        ValueError: If ``x`` is neither of the two distance units.
    """
    distance, weight = CONSTANTS.DISTANCE, CONSTANTS.WEIGHT
    if x == distance.KILOMETERS:
        return weight.KILOGRAMS
    if x == distance.MILES:
        return weight.POUNDS
    raise ValueError("Invalid distance type")
Boilerplate minimization code
from typing import Callable
import numpy as np
import pandas as pd
import polars as pl
from scipy.optimize import minimize
def ansatz(
    params: np.ndarray,
    df: pl.DataFrame,
    model_column: str,
) -> pl.DataFrame:
    """Define the model: the "test" column scaled by ``params[0]``.

    Args:
        params: Model parameters; only ``params[0]`` is read.
        df: Frame containing a "test" column.
        model_column: Name of the column that receives the model output.

    Returns:
        ``df`` with the model output stored under ``model_column``.
    """
    prediction = params[0] * pl.col("test")
    return df.with_columns(prediction.alias(model_column))
def loss(
    params: np.ndarray,
    data: pl.DataFrame,
    model_column: str,
    reference_column: str,
    metric: Callable[[np.ndarray, np.ndarray], float],
) -> float:
    """Evaluate the model and compute its loss against the reference column.

    Args:
        params: Model parameters forwarded to ``ansatz``.
        data: Frame holding the model inputs and the reference column.
        model_column: Name of the column ``ansatz`` writes its output to.
        reference_column: Name of the column in ``data`` to compare against.
        metric: Function reducing (model, reference) arrays to a scalar error.

    Returns:
        The value returned by ``metric``.
    """
    model = ansatz(params, data, model_column)
    # `metric` is declared to take NumPy arrays, so convert the polars
    # columns instead of handing it `pl.Series` objects.
    prediction = model[model_column].to_numpy()
    reference = data[reference_column].to_numpy()
    return metric(prediction, reference)
def metric(x: np.ndarray, y: np.ndarray) -> float:
    """Compute the L2 norm of the difference between two arrays.

    Args:
        x: First array (any array-like is accepted).
        y: Second array, compatible in shape with ``x``.

    Returns:
        The L2 norm of ``x - y`` as a built-in ``float``.
    """
    # Coerce up front so plain sequences work too, not only objects that
    # already implement subtraction.
    difference = np.asarray(x) - np.asarray(y)
    # `np.linalg.norm` returns a NumPy scalar; convert it to honour the
    # declared `float` return type.
    return float(np.linalg.norm(difference))
# Toy data: "test" is the column the model reads, "reference" the target column.
data = pl.DataFrame({"test": [1.0, 2, 3], "reference": [2.0, 3, 4]})
# Initial guess for the optimizer.
# NOTE(review): `ansatz` only reads params[0], so the second entry is not used
# by the model — confirm the two-element guess is intended.
x0 = np.array([1.0, 1.0])
# `args` supplies the arguments of `loss` that follow the parameter vector:
# (data, model_column, reference_column, metric).
res = minimize(loss, x0, args=(data, "model", "reference", metric))