Geometric Models in Python Machine Learning

Standard machine learning treats every input as a flat vector, ignoring the structure hiding in the data. Geometric machine learning flips that assumption on its head. By encoding symmetries, curvatures, and graph topologies directly into model architectures, geometric approaches are producing results that flat-space models simply cannot match -- from molecular simulations to social network analysis. Python sits at the center of this shift, with libraries like PyTorch Geometric, e3nn, and Geomstats making these powerful techniques accessible to practitioners.

Classical neural networks treat inputs as vectors in Euclidean space. An image becomes a grid of pixel values. A molecule becomes a list of atom features. A social network becomes a row in an adjacency matrix. In each case, the underlying structure -- spatial relationships between atoms, connection patterns in a graph, rotational symmetries in physical systems -- gets flattened away before the model ever sees it.

Geometric machine learning takes the opposite approach. Instead of asking the model to rediscover structure from raw data, geometric methods bake that structure into the architecture itself. The result is models that need fewer training examples, generalize better to unseen configurations, and respect the physical laws governing the data they process.

What Makes Machine Learning "Geometric"?

At its core, geometric machine learning is about respecting the structure of your data. That structure can take several forms, each leading to different modeling strategies.

Symmetry is the first and arguably the most fundamental. When a property of interest does not change under certain transformations -- rotating a molecule does not change its energy, translating an image does not change what object it contains -- a well-designed model should reflect that invariance. Convolutional neural networks already encode translation symmetry for images. Geometric models generalize this idea to rotations, reflections, permutations, and other groups of transformations.
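The invariance idea is easy to check numerically. Any quantity built only from pairwise distances, for example, is unchanged when every point is rotated. A minimal NumPy sketch, independent of any particular library:

```python
import numpy as np

# Five random points in the plane
rng = np.random.default_rng(0)
points = rng.normal(size=(5, 2))

# A rotation by 40 degrees
theta = np.deg2rad(40.0)
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
rotated = points @ R.T

def pairwise_distances(x):
    """Matrix of Euclidean distances between all pairs of points."""
    diff = x[:, None, :] - x[None, :, :]
    return np.linalg.norm(diff, axis=-1)

# Distances are invariant under rotation; raw coordinates are not
assert np.allclose(pairwise_distances(points), pairwise_distances(rotated))
assert not np.allclose(points, rotated)
```

A model whose predictions depend only on such invariant quantities automatically respects the symmetry, no matter how the input is oriented.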

Graph topology captures relationships between entities. Atoms bonded in a molecule, users connected in a social network, and devices linked in a communication mesh are all naturally represented as graphs. Graph neural networks operate directly on these structures, passing messages along edges to update node representations.

Manifold geometry addresses the fact that data frequently lives on curved surfaces rather than in flat space. Rotation matrices form a curved manifold called SO(3). Brain connectivity patterns can be modeled as points on the manifold of symmetric positive definite matrices. Working directly on these manifolds, rather than forcing data into Euclidean coordinates, preserves geometric relationships that would otherwise be distorted.
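The distortion from ignoring curvature can be quantified even in the simplest case. On the unit circle, the straight-line (chord) distance between two points systematically underestimates the geodesic (arc) distance. A small NumPy illustration, not tied to any manifold library:

```python
import numpy as np

def chord_distance(a, b):
    """Euclidean distance between two angular positions on the unit circle."""
    pa = np.array([np.cos(a), np.sin(a)])
    pb = np.array([np.cos(b), np.sin(b)])
    return np.linalg.norm(pa - pb)

def geodesic_distance(a, b):
    """Arc length between two points on the unit circle."""
    d = abs(a - b) % (2 * np.pi)
    return min(d, 2 * np.pi - d)

a, b = 0.0, 0.9 * np.pi  # nearly antipodal points
print(f"chord:    {chord_distance(a, b):.3f}")
print(f"geodesic: {geodesic_distance(a, b):.3f}")
# The chord never exceeds 2, while the arc approaches pi
```

For nearly antipodal points the two notions of distance disagree by more than 40 percent, and the gap only grows on higher-dimensional manifolds.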

Note

The unifying framework behind these ideas was formalized by Michael Bronstein, Joan Bruna, Taco Cohen, and Petar Velickovic in their "Geometric Deep Learning" blueprint. Their work shows how CNNs, GNNs, Transformers, and other architectures can all be derived from symmetry principles applied to different geometric domains.

Graph Neural Networks with PyTorch Geometric

PyTorch Geometric (PyG) is the leading Python library for building graph neural networks. Currently at version 2.7 (with version 2.8 in development on the main branch), PyG provides implementations of well over 100 GNN architectures, a rich collection of benchmark datasets, and efficient data loaders that handle both batches of small graphs and single massive graphs.

The core abstraction in PyG is the Data object, which bundles a graph's node features, edge indices, edge attributes, and any labels into a single structure. A basic GNN training pipeline looks like this:

import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn.functional as F

# Define a simple two-layer Graph Convolutional Network
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

# Create a small example graph
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3],
    [1, 0, 2, 1, 3, 2]
], dtype=torch.long)

node_features = torch.randn(4, 16)  # 4 nodes, 16 features each
data = Data(x=node_features, edge_index=edge_index)

# Initialize and run the model
model = GCN(in_channels=16, hidden_channels=32, out_channels=7)
output = model(data.x, data.edge_index)
print(output.shape)  # torch.Size([4, 7])

Each GCNConv layer implements the message-passing paradigm: every node gathers information from its neighbors, aggregates those messages, and updates its own representation. The edge_index tensor defines which nodes are connected, and PyG handles the sparse indexing behind the scenes.
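The mechanics of message passing can be sketched in a few lines of plain NumPy: for each directed edge, the source node's features are sent to the target and averaged. This toy version uses mean aggregation and no learned weights, purely to illustrate the pattern rather than reproduce GCNConv's exact normalization:

```python
import numpy as np

# Same 4-node graph as above: rows are source and target node indices
edge_index = np.array([
    [0, 1, 1, 2, 2, 3],   # source nodes
    [1, 0, 2, 1, 3, 2],   # target nodes
])
x = np.arange(4 * 2, dtype=float).reshape(4, 2)  # 4 nodes, 2 features each

def mean_message_passing(x, edge_index):
    """One round of message passing with mean aggregation."""
    out = np.zeros_like(x)
    counts = np.zeros(x.shape[0])
    for src, dst in edge_index.T:
        out[dst] += x[src]           # message: copy source features to target
        counts[dst] += 1
    counts[counts == 0] = 1          # avoid division by zero for isolated nodes
    return out / counts[:, None]     # aggregate: mean over incoming messages

updated = mean_message_passing(x, edge_index)
print(updated)
# Node 1 now holds the mean of nodes 0 and 2; node 0 copies node 1
```

A real GCN layer adds a learned weight matrix and degree-based normalization on top of this loop, and PyG replaces the Python loop with sparse scatter operations.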

PyG goes well beyond basic graph convolutions. The library includes attention-based models like GAT (Graph Attention Network), spectral methods like ChebNet, and heterogeneous graph support for datasets where nodes and edges have different types. Recent versions have added integration with large language models, letting practitioners combine GNN-based knowledge graph reasoning with text generation.
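The idea behind attention-based layers like GAT can also be sketched without the library: instead of weighting all neighbors equally, each incoming message receives a score, normalized with a softmax over the neighborhood. This is a toy NumPy version with a fixed score vector standing in for a learned scoring function, not PyG's actual GATConv implementation:

```python
import numpy as np

def attention_aggregate(x, neighbors, scores):
    """Aggregate neighbor features weighted by softmax-normalized scores."""
    w = np.exp(scores - scores.max())   # numerically stable softmax
    w = w / w.sum()
    return w @ x[neighbors]

x = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
neighbors = np.array([0, 1, 2])        # hypothetical neighborhood of one node
scores = np.array([2.0, 0.0, 0.0])     # stand-in for a learned scoring function

out = attention_aggregate(x, neighbors, scores)
print(out)  # dominated by node 0's features because of its higher score
```

In GAT the scores themselves are computed from the features of each node pair, so the model learns which neighbors matter most for each prediction.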

Pro Tip

PyG supports torch.compile() for significant speedups. After defining your model, wrap it with torch_geometric.compile(model) to take advantage of PyTorch's compilation backend. The PyG team has reported runtime improvements of up to 300% on certain workloads.

For temporal data -- graphs whose structure evolves over time -- the companion library PyTorch Geometric Temporal extends PyG with recurrent GNN layers. It handles tasks like traffic flow prediction, where the road network is a graph and traffic volumes change at each timestep.

Equivariant Networks and the e3nn Library

Graph neural networks handle relational structure, but many scientific and engineering problems demand something stronger: the model's output must transform predictably when the input is rotated, reflected, or translated. This property is called equivariance, and it is central to applications in molecular dynamics, protein structure prediction, and materials science.
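Equivariance has a precise meaning: rotating the input and then applying the function f gives the same result as applying f first and then rotating, i.e. f(Rx) = R f(x). A minimal NumPy check with a hand-built equivariant map (scaling each vector by its own norm, which commutes with rotation because norms are rotation-invariant):

```python
import numpy as np

def f(x):
    """Scale each 3D vector by its norm -- a rotation-equivariant map."""
    return x * np.linalg.norm(x, axis=-1, keepdims=True)

# A rotation about the z-axis
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]])

rng = np.random.default_rng(1)
x = rng.normal(size=(10, 3))

# f(R x) == R f(x): rotating before or after gives the same answer
assert np.allclose(f(x @ R.T), f(x) @ R.T)
```

Libraries like e3nn generalize this property from a single hand-picked function to entire trainable networks, guaranteeing it holds for every layer by construction.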

The e3nn library (Euclidean Neural Networks) provides a modular framework for building neural networks that are equivariant with respect to E(3) -- the group of all rotations, translations, and reflections in three-dimensional space. The library's foundation rests on irreducible representations (irreps) of the orthogonal group O(3), which describe how different types of geometric quantities transform under rotations and reflections.

In e3nn's notation, a scalar is represented as 0e (the even irrep of order 0), a vector as 1o (the odd irrep of order 1), and higher-order tensors follow the same pattern. This compact notation lets you specify exactly how each feature channel should behave under symmetry operations:

import torch
from e3nn import o3

# Define irreducible representations
# "2x0e" means two scalar channels
# "3x1o" means three vector channels
irreps_input = o3.Irreps("2x0e + 3x1o")
irreps_output = o3.Irreps("4x0e + 2x1o")

# Create an equivariant linear layer
linear = o3.Linear(
    irreps_in=irreps_input,
    irreps_out=irreps_output
)

# Generate random input matching the irreps
x = irreps_input.randn(-1)
y = linear(x)

print(f"Input shape:  {x.shape}")   # 11 values (2 scalars + 3*3 vector components)
print(f"Output shape: {y.shape}")   # 10 values (4 scalars + 2*3 vector components)

The real power of e3nn emerges when combining inputs through tensor products. An equivariant tensor product takes two geometric objects -- say a set of node features and a set of edge direction vectors -- and produces a new set of features that still transforms correctly under rotation. This is the fundamental building block of equivariant message passing:

from e3nn import o3

# Tensor product of a scalar+vector with itself
irreps = o3.Irreps("0e + 1o")
tp = o3.FullTensorProduct(
    irreps_in1=irreps,
    irreps_in2=irreps
)

x = irreps.randn(-1)
result = tp(x, x)
print(f"Tensor product output irreps: {tp.irreps_out}")
# Produces scalars, vectors, and rank-2 tensors

This approach has proven transformative in computational chemistry. The NequIP architecture, built on e3nn, demonstrated that equivariant networks can achieve state-of-the-art accuracy on molecular dynamics benchmarks while requiring up to three orders of magnitude less training data than invariant approaches. The efficiency comes from preserving directional information through the network rather than discarding it at each layer.

Note

The e3nn library is available in both PyTorch (e3nn) and JAX (e3nn-jax) versions. The JAX version offers a different API built around the IrrepsArray class and can deliver faster training times on certain hardware configurations, particularly for models like MACE that are used in production molecular simulations.

Manifold Learning: From scikit-learn to Geomstats

Not all geometric structure involves graphs or 3D symmetries. In many practical settings, high-dimensional data concentrates near a lower-dimensional curved surface -- a manifold -- embedded within the ambient space. Manifold learning methods aim to uncover that hidden structure.

scikit-learn provides several classical manifold learning algorithms through its sklearn.manifold module. These include Isomap, which preserves geodesic distances between data points; Locally Linear Embedding (LLE), which maintains local neighborhood relationships; t-SNE, which excels at producing visually meaningful 2D projections of clustered data; and Spectral Embedding, which uses the graph Laplacian to find a low-dimensional representation.

from sklearn.manifold import Isomap, TSNE
from sklearn.datasets import make_swiss_roll
import numpy as np

# Generate a 3D Swiss Roll dataset
X, color = make_swiss_roll(n_samples=1500, noise=0.1)

# Isomap: preserves geodesic distances
isomap = Isomap(n_components=2, n_neighbors=12)
X_isomap = isomap.fit_transform(X)

# t-SNE: optimizes for visual cluster separation
tsne = TSNE(n_components=2, perplexity=30, random_state=42)
X_tsne = tsne.fit_transform(X)

print(f"Isomap output shape: {X_isomap.shape}")
print(f"t-SNE output shape:  {X_tsne.shape}")

These tools work well for visualization and exploratory analysis, but they treat manifolds as something to flatten. When the data naturally lives on a known manifold -- rotation matrices on SO(3), covariance matrices on the SPD manifold, directional data on the hypersphere -- flattening introduces distortion. This is where Geomstats enters the picture.

Geomstats is a Python library for Riemannian geometry in machine learning. It provides implementations of manifolds (hyperspheres, hyperbolic spaces, Lie groups, SPD matrices, and more), equipped with Riemannian metrics, geodesics, exponential and logarithmic maps, and parallel transport. On top of this geometric foundation, Geomstats implements learning algorithms -- K-means, PCA, regression -- that operate natively on curved spaces.

import geomstats.backend as gs
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.frechet_mean import FrechetMean

# Work on the 2-sphere (surface of a ball in 3D)
sphere = Hypersphere(dim=2)

# Generate random points on the sphere
points = sphere.random_point(n_samples=100)

# Compute the Frechet mean (the "average" on a manifold)
mean_estimator = FrechetMean(sphere)
mean_estimator.fit(points)
frechet_mean = mean_estimator.estimate_

print(f"Points shape:  {points.shape}")    # (100, 3)
print(f"Mean on sphere: {frechet_mean}")   # A point on S^2
print(f"Lies on sphere: {sphere.belongs(frechet_mean)}")  # True

Notice that computing a simple average on the sphere requires special treatment. Taking the arithmetic mean of points on a sphere does not produce a point that lies on the sphere. The Frechet mean solves this by minimizing the sum of squared geodesic distances, ensuring the result respects the manifold's curvature.
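The problem with the naive average is easy to see numerically: the arithmetic mean of unit vectors has norm strictly less than one whenever the points are not all identical, so it falls inside the ball rather than on the sphere. A quick NumPy check, independent of Geomstats:

```python
import numpy as np

rng = np.random.default_rng(0)
points = rng.normal(size=(100, 3))
points /= np.linalg.norm(points, axis=1, keepdims=True)  # project onto S^2

naive_mean = points.mean(axis=0)
print(f"Norm of arithmetic mean: {np.linalg.norm(naive_mean):.4f}")
# Strictly less than 1: the naive mean sits inside the ball, off the sphere
```

Renormalizing the naive mean back onto the sphere is a common quick fix, but it only approximates the Frechet mean and breaks down entirely when points are spread widely over the manifold.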

Geomstats also supports hyperbolic embeddings, which have become popular for representing hierarchical data. Unlike Euclidean space, hyperbolic space expands exponentially with radius -- a property that mirrors the exponential growth of nodes in tree-like structures. The library provides tools for embedding graphs into the Poincare ball model and running downstream clustering tasks using Riemannian K-means:

from geomstats.geometry.poincare_ball import PoincareBall
from geomstats.learning.kmeans import RiemannianKMeans

# Create a 2D Poincare ball (hyperbolic space)
hyperbolic = PoincareBall(dim=2)

# Assume embeddings are already computed (e.g., from graph embedding)
# embeddings = ... (array of points inside the unit disk)

# Cluster using Riemannian K-means
kmeans = RiemannianKMeans(
    space=hyperbolic,
    n_clusters=3
)
# kmeans.fit(embeddings)
# labels = kmeans.predict(embeddings)
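The exponential-growth property becomes concrete when hyperbolic distance is computed by hand: points near the boundary of the unit disk are far apart even when their Euclidean separation is tiny. A direct NumPy implementation of the standard Poincare-ball distance formula (curvature -1), shown here for illustration alongside Geomstats' own metric classes:

```python
import numpy as np

def poincare_distance(u, v):
    """Geodesic distance in the Poincare ball model with curvature -1."""
    sq = np.sum((u - v) ** 2)
    denom = (1 - np.sum(u ** 2)) * (1 - np.sum(v ** 2))
    return np.arccosh(1 + 2 * sq / denom)

near_center = poincare_distance(np.array([0.0, 0.0]), np.array([0.1, 0.0]))
near_edge = poincare_distance(np.array([0.90, 0.0]), np.array([0.99, 0.0]))

print(f"0.10 apart near the center:   {near_center:.3f}")
print(f"0.09 apart near the boundary: {near_edge:.3f}")
# A similar Euclidean gap is hyperbolically far larger near the boundary
```

This stretching near the boundary is exactly what gives hyperbolic embeddings room to place exponentially many tree leaves at roughly equal distances from the root.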

Pro Tip

Geomstats follows the scikit-learn API pattern -- classes have fit(), predict(), and transform() methods. If you are already familiar with scikit-learn, picking up Geomstats requires learning the geometric concepts rather than a new programming interface.

Putting It All Together: Choosing the Right Tool

The geometric ML landscape in Python can feel fragmented, but the choice of tool depends on a few clear questions about the data and the task.

If the data is naturally represented as a graph -- nodes with features connected by edges -- start with PyTorch Geometric. It covers node classification, link prediction, and graph-level classification with a mature, well-documented API. For graphs that change over time, extend the pipeline with PyTorch Geometric Temporal.

If the problem involves 3D spatial data and the model needs to respect physical symmetries like rotation and reflection, e3nn is the right foundation. This is particularly relevant in molecular property prediction, protein structure analysis, materials discovery, and any domain where E(3) equivariance translates to better data efficiency and physical consistency.

If the data lives on a known manifold -- rotation matrices, covariance matrices, directional statistics, or hierarchical structures suited to hyperbolic space -- use Geomstats. Its implementations of Riemannian metrics, geodesics, and manifold-native learning algorithms ensure computations respect the underlying geometry.

For visualization and exploration of high-dimensional data where the manifold is unknown, scikit-learn's manifold learning module (Isomap, t-SNE, UMAP via the separate umap-learn package) remains the practical starting point.

These tools are not mutually exclusive. A molecular dynamics pipeline might use e3nn for the equivariant network architecture, PyTorch Geometric for the graph data handling, and Geomstats for analyzing the resulting conformational space. The Python ecosystem makes this kind of composition straightforward.

Key Takeaways

  1. Structure is signal: Encoding geometric structure -- symmetries, graph topology, manifold curvature -- into model architectures produces models that learn faster, generalize better, and make physically consistent predictions.
  2. PyTorch Geometric handles graphs: With version 2.7 stable and 2.8 in development, PyG provides over 100 GNN architectures, efficient data loading for graphs of all sizes, and torch.compile() integration for performance.
  3. e3nn enforces 3D symmetry: By building on irreducible representations and equivariant tensor products, e3nn ensures that models respect rotation, translation, and reflection symmetry -- a requirement in fields from drug discovery to materials science.
  4. Geomstats brings Riemannian geometry to ML: When data naturally lives on curved manifolds, Geomstats provides the geometric primitives and scikit-learn-compatible algorithms needed to work with it correctly.
  5. Match the tool to the geometry: Graph structure calls for PyG, 3D physical symmetry calls for e3nn, known manifold structure calls for Geomstats, and unknown manifold exploration can start with scikit-learn's dimensionality reduction tools.

Geometric machine learning is not a niche specialization -- it is a shift in how models relate to the world they describe. As the Python libraries supporting these ideas continue to mature, the barrier to entry keeps dropping. The question is no longer whether geometric structure matters in machine learning, but how much performance and insight gets left on the table by ignoring it.
