Source code for spamosaic.MNN

"""Nearest-neighbor utilities and MNN matching for SpaMosaic.

Implements exact/approximate kNN (HNSW, Annoy) and mutual nearest neighbors (MNN) pairing.
"""

import numpy as np
import pandas as pd

from sklearn.neighbors import NearestNeighbors
from annoy import AnnoyIndex
import itertools
import networkx as nx
import hnswlib
from sklearn.preprocessing import normalize


def nn_approx(ds1, ds2, names1, names2, knn=50):
    """Approximate nearest-neighbor search using HNSW (hnswlib).

    Builds an L2 HNSW index over ``ds2`` and queries it with every row of
    ``ds1``.

    Parameters
    ----------
    ds1 : np.ndarray
        Query dataset of shape ``(N1, D)``.
    ds2 : np.ndarray
        Reference dataset of shape ``(N2, D)``.
    names1 : list of str
        Identifiers for rows in ``ds1``.
    names2 : list of str
        Identifiers for rows in ``ds2``.
    knn : int, default=50
        Number of nearest neighbors to find for each query. Values larger
        than ``N2`` are capped at ``N2``.

    Returns
    -------
    set[tuple[str, str]]
        Set of matched ``(query_name, reference_name)`` pairs. Empty when
        either dataset has no rows or ``knn`` is not positive.
    """
    num_queries = ds1.shape[0]
    num_elements = ds2.shape[0]

    # A query cannot have more neighbors than there are reference points;
    # asking hnswlib for more makes knn_query fail.
    k = min(knn, num_elements)
    if num_queries == 0 or k <= 0:
        return set()

    index = hnswlib.Index(space='l2', dim=ds2.shape[1])
    index.init_index(max_elements=num_elements, ef_construction=100, M=16)
    # hnswlib documents that ef must not be lower than the queried k, so keep
    # the historical floor of 10 but never go below k.
    index.set_ef(max(10, k))
    index.add_items(ds2)

    neighbor_ids, _ = index.knn_query(ds1, k=k)

    return {
        (names1[query_id], names2[ref_id])
        for query_id, ref_ids in enumerate(neighbor_ids)
        for ref_id in ref_ids
    }
def nn(ds1, ds2, names1, names2, knn=50, metric_p=2):
    """Exact nearest-neighbor search using scikit-learn.

    Fits a ``NearestNeighbors`` model on ``ds2`` and looks up the neighbors
    of every row of ``ds1``.

    Parameters
    ----------
    ds1 : np.ndarray
        Query dataset of shape ``(N1, D)``.
    ds2 : np.ndarray
        Reference dataset of shape ``(N2, D)``.
    names1 : list of str
        Identifiers for rows in ``ds1``.
    names2 : list of str
        Identifiers for rows in ``ds2``.
    knn : int
        Number of nearest neighbors to retrieve for each query. Values
        larger than ``N2`` are capped at ``N2``.
    metric_p : int
        Minkowski distance parameter (e.g., ``2`` for Euclidean).

    Returns
    -------
    set[tuple[str, str]]
        Set of matched ``(query_name, reference_name)`` pairs. Empty when
        either dataset has no rows or ``knn`` is not positive.
    """
    num_queries = ds1.shape[0]

    # scikit-learn rejects n_neighbors greater than the number of fitted
    # samples, so cap the request at the size of the reference set.
    k = min(knn, ds2.shape[0])
    if num_queries == 0 or k <= 0:
        return set()

    # n_neighbors must be passed by keyword: scikit-learn >= 1.0 makes
    # estimator constructor parameters keyword-only.
    searcher = NearestNeighbors(n_neighbors=k, p=metric_p)
    searcher.fit(ds2)
    neighbor_ids = searcher.kneighbors(ds1, return_distance=False)

    return {
        (names1[query_id], names2[ref_id])
        for query_id, ref_ids in enumerate(neighbor_ids)
        for ref_id in ref_ids
    }
def nn_annoy(ds1, ds2, names1, names2, norm=True, knn=20, metric='euclidean', n_trees=10, save_on_disk=False):
    """Approximate nearest-neighbor search using Annoy index.

    Builds an Annoy index over ``ds2`` (item ids are the zero-based row
    positions) and queries it with every row of ``ds1``.

    Parameters
    ----------
    ds1 : np.ndarray
        Query dataset of shape ``(N1, D)``.
    ds2 : np.ndarray
        Reference dataset of shape ``(N2, D)``.
    names1 : list of str
        Identifiers for rows in ``ds1``.
    names2 : list of str
        Identifiers for rows in ``ds2``.
    norm : bool, default=True
        Whether to L2-normalize datasets before indexing/search.
    knn : int
        Number of nearest neighbors to retrieve.
    metric : str, default='euclidean'
        Distance metric (e.g., ``'euclidean'``, ``'manhattan'``).
    n_trees : int, default=10
        Number of trees to build in the Annoy index.
    save_on_disk : bool, default=False
        If ``True``, build the index in the file ``annoy.index`` in the
        current working directory.

    Returns
    -------
    set[tuple[str, str]]
        Set of ``(query_name, reference_name)`` nearest-neighbor pairs.
        Empty when either dataset has no rows; in that case no index is
        built and no file is written.
    """
    if ds1.shape[0] == 0 or ds2.shape[0] == 0:
        return set()

    if norm:
        ds1 = normalize(ds1)
        ds2 = normalize(ds2)

    # Build index.
    index = AnnoyIndex(ds2.shape[1], metric=metric)
    if save_on_disk:
        index.on_disk_build('annoy.index')
    for item_id in range(ds2.shape[0]):
        index.add_item(item_id, ds2[item_id, :])
    index.build(n_trees)

    # Search and match in one pass. The id lists are consumed directly
    # instead of being stacked into an array: stacking fails on recent NumPy
    # if two queries return different numbers of neighbors.
    match = set()
    for query_id in range(ds1.shape[0]):
        neighbor_ids = index.get_nns_by_vector(ds1[query_id, :], knn, search_k=-1)
        query_name = names1[query_id]
        match.update((query_name, names2[ref_id]) for ref_id in neighbor_ids)

    return match
def mnn(ds1, ds2, names1, names2, knn1=20, knn2=20, approx=True, metric='euclidean', way='hnsw', norm=False):
    """Compute mutual nearest neighbors (MNN) between two datasets.

    Runs a nearest-neighbor search in both directions (``ds1 → ds2`` and
    ``ds2 → ds1``) and keeps only the pairs found by both searches.

    Parameters
    ----------
    ds1 : np.ndarray
        First dataset, shape ``(N1, D)``.
    ds2 : np.ndarray
        Second dataset, shape ``(N2, D)``.
    names1 : list of str
        Identifiers for rows in ``ds1``.
    names2 : list of str
        Identifiers for rows in ``ds2``.
    knn1 : int
        Number of neighbors for ``ds1 → ds2``.
    knn2 : int
        Number of neighbors for ``ds2 → ds1``.
    approx : bool, default=True
        If ``True``, use approximate search (HNSW/Annoy); otherwise exact
        kNN via :func:`nn` with its default distance parameter.
    metric : str, default='euclidean'
        Distance metric; only forwarded to the Annoy backend.
    way : str, default='hnsw'
        Approximation backend. ``'hnsw'`` selects HNSW; any other value
        selects Annoy.
    norm : bool, default=False
        Whether to normalize inputs before Annoy search (ignored for
        HNSW/exact).

    Returns
    -------
    set[tuple[str, str]]
        Set of mutual nearest-neighbor pairs, each ordered as
        ``(name_from_ds1, name_from_ds2)``.
    """
    if not approx:
        # Exact search in both directions.
        forward = nn(ds1, ds2, names1, names2, knn=knn1)
        backward = nn(ds2, ds1, names2, names1, knn=knn2)
    elif way == 'hnsw':
        # Each direction yields roughly (number of queries) * knn pairs.
        forward = nn_approx(ds1, ds2, names1, names2, knn=knn1)
        backward = nn_approx(ds2, ds1, names2, names1, knn=knn2)
    else:
        forward = nn_annoy(ds1, ds2, names1, names2, norm=norm, knn=knn1, metric=metric)
        backward = nn_annoy(ds2, ds1, names2, names1, norm=norm, knn=knn2, metric=metric)

    # Flip the reverse-direction pairs so both sets are keyed (ds1, ds2),
    # then keep only the pairs present in both.
    backward_flipped = {(name1, name2) for name2, name1 in backward}
    return forward & backward_flipped