Cell type annotation

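UniST performs cell type annotation by transferring labels from the original cells to newly generated synthetic coordinates via nearest-neighbour search in a hybrid space that combines transcriptomic embeddings (from GAE-INR) with spatial coordinates: each generated point receives the label of the single original cell that is closest to it under a weighted sum of the expression and spatial distances.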

! pip install scanpy
Requirement already satisfied: scanpy in /usr/local/lib/python3.12/dist-packages (1.12)
Requirement already satisfied: anndata>=0.10.8 in /usr/local/lib/python3.12/dist-packages (from scanpy) (0.12.10)
Requirement already satisfied: fast-array-utils>=1.2.1 in /usr/local/lib/python3.12/dist-packages (from fast-array-utils[accel,sparse]>=1.2.1->scanpy) (1.3.1)
Requirement already satisfied: h5py>=3.11 in /usr/local/lib/python3.12/dist-packages (from scanpy) (3.15.1)
Requirement already satisfied: joblib in /usr/local/lib/python3.12/dist-packages (from scanpy) (1.5.3)
Requirement already satisfied: legacy-api-wrap>=1.5 in /usr/local/lib/python3.12/dist-packages (from scanpy) (1.5)
Requirement already satisfied: matplotlib>=3.9 in /usr/local/lib/python3.12/dist-packages (from scanpy) (3.10.0)
Requirement already satisfied: natsort in /usr/local/lib/python3.12/dist-packages (from scanpy) (8.4.0)
Requirement already satisfied: networkx>=2.8.8 in /usr/local/lib/python3.12/dist-packages (from scanpy) (3.6.1)
Requirement already satisfied: numba>=0.60 in /usr/local/lib/python3.12/dist-packages (from scanpy) (0.60.0)
Requirement already satisfied: numpy>=2 in /usr/local/lib/python3.12/dist-packages (from scanpy) (2.0.2)
Requirement already satisfied: packaging>=25 in /usr/local/lib/python3.12/dist-packages (from scanpy) (26.0)
Requirement already satisfied: pandas>=2.2.2 in /usr/local/lib/python3.12/dist-packages (from scanpy) (2.2.2)
Requirement already satisfied: patsy in /usr/local/lib/python3.12/dist-packages (from scanpy) (1.0.2)
Requirement already satisfied: pynndescent>=0.5.13 in /usr/local/lib/python3.12/dist-packages (from scanpy) (0.6.0)
Requirement already satisfied: scikit-learn>=1.4.2 in /usr/local/lib/python3.12/dist-packages (from scanpy) (1.6.1)
Requirement already satisfied: scipy>=1.13 in /usr/local/lib/python3.12/dist-packages (from scanpy) (1.16.3)
Requirement already satisfied: seaborn>=0.13.2 in /usr/local/lib/python3.12/dist-packages (from scanpy) (0.13.2)
Requirement already satisfied: session-info2 in /usr/local/lib/python3.12/dist-packages (from scanpy) (0.4)
Requirement already satisfied: statsmodels>=0.14.5 in /usr/local/lib/python3.12/dist-packages (from scanpy) (0.14.6)
Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from scanpy) (4.67.3)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.12/dist-packages (from scanpy) (4.15.0)
Requirement already satisfied: umap-learn>=0.5.7 in /usr/local/lib/python3.12/dist-packages (from scanpy) (0.5.11)
Requirement already satisfied: array-api-compat>=1.7.1 in /usr/local/lib/python3.12/dist-packages (from anndata>=0.10.8->scanpy) (1.13.0)
Requirement already satisfied: zarr!=3.0.*,>=2.18.7 in /usr/local/lib/python3.12/dist-packages (from anndata>=0.10.8->scanpy) (3.1.5)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.9->scanpy) (1.3.3)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.9->scanpy) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.9->scanpy) (4.61.1)
Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.9->scanpy) (1.4.9)
Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.9->scanpy) (11.3.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.9->scanpy) (3.3.2)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.9->scanpy) (2.9.0.post0)
Requirement already satisfied: llvmlite<0.44,>=0.43.0dev0 in /usr/local/lib/python3.12/dist-packages (from numba>=0.60->scanpy) (0.43.0)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas>=2.2.2->scanpy) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas>=2.2.2->scanpy) (2025.3)
Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=1.4.2->scanpy) (3.6.0)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.7->matplotlib>=3.9->scanpy) (1.17.0)
Requirement already satisfied: donfig>=0.8 in /usr/local/lib/python3.12/dist-packages (from zarr!=3.0.*,>=2.18.7->anndata>=0.10.8->scanpy) (0.8.1.post1)
Requirement already satisfied: google-crc32c>=1.5 in /usr/local/lib/python3.12/dist-packages (from zarr!=3.0.*,>=2.18.7->anndata>=0.10.8->scanpy) (1.8.0)
Requirement already satisfied: numcodecs>=0.14 in /usr/local/lib/python3.12/dist-packages (from zarr!=3.0.*,>=2.18.7->anndata>=0.10.8->scanpy) (0.16.5)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.12/dist-packages (from donfig>=0.8->zarr!=3.0.*,>=2.18.7->anndata>=0.10.8->scanpy) (6.0.3)
import random
import seaborn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import scanpy as sc

We illustrate cell type annotation using a 2D slice (slice 440) from the Spateo dataset.

%cd external/SUICA_pro
adata_path = 'data/slice_440.h5ad'
slice = sc.read(adata_path)
slice
View of AnnData object with n_obs × n_vars = 8382 × 17649
    obs: 'area', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'n_counts', 'louvain', 'transf_anno_E95', 'mapped_celltype', 'head_or_body', 'heart_region'
    obsm: '3d_align_spatial', 'X_pca', 'bbox', 'global_align_spatial', 'spatial'
#### All cell types and their colors

# Hex colour for every fine-grained cell type ('mapped_celltype'); the same palette
# is used for the original and for the predicted annotations.
# Some types share a colour: both liver sinusoidal endothelial entries use "#CAE8E1",
# and the last four entries all use "#FF0000".
anno_colors = {
    "Sclerotome": "#2DA048",
    "Dermomyotome": "#C49C94",
    "Lateral plate mesoderm": "#1F77B4",
    "Limb progenitors and lateral plate mesoderm": "#A1C299",
    "Head and facial mesenchyme": "#F57E20",
    "Cranial mesoderm": "#9268AC",
    "Gut mesenchyme": "#8C574C",
    "Cardiopharyngeal mesoderm": "#DBDB8C",
    "Renal mesenchyme": "#1DBFCF",
    "Mesodermal progenitors": "#4C64AE",
    "Hepatic mesenchyme": "#FF00FF",
    "Intermediate mesoderm": "#E5E4E2",
    "Anterior intermediate mesoderm": "#FCB415",
    "Primitive erythroid cells": "#FCBA78",
    "Hematoendothelial progenitors": "#456D75",
    "Endothelium": "#D52928",
    "Arterial endothelial cells": "#F8A993",
    "Liver sinusoidal endothelial cell precursors": "#CAE8E1",
    "Liver sinusoidal endothelial cells": "#CAE8E1",
    "Vessel endothelial cells": "#0E1033",
    "Epithelial precursors": "#AEC7E8",
    "Placode and neural crest-derived": "#F69696",
    "Gut": "#4B3B53",
    "Otic epithelial cells": "#003349",
    "Olfactory epithelial cells": "#CC1F47",
    "Amniotic ectoderm": "#352411",
    "Hepatocytes": "#0C2013",
    "Ectoderm-derived": "#C7DF92",
    "Glutamatergic neurons": "#D06327",
    "Spinal cord motor neurons": "#05A6D8",
    "Cranial motor neurons": "#A079B6",
    "Neural progenitor cells": "#A93493",
    "Spinal cord ventral progenitors": "#E0C4DE",
    "Spinal cord dorsal progenitors": "#798D66",
    "Telencephalon neuroectoderm": "#F6B5D1",
    "Spinal cord neuroectoderm": "#9ED089",
    "Midbrain neuroectoderm": "#7F7F7F",
    "Hindbrain neuroectoderm": "#C4B0D5",
    "Midbrain-hindbrain boundary": "#C7C7C6",
    "Diencephalon neuroectoderm": "#7B504B",
    "Hypothalamus neuroectoderm": "#DEEEFA",
    "Hypothalamus (Sim1+) neuroectoderm": "#038470",
    "Dorsal telencephalon neuroectoderm": "#124B99",
    "Neuroectoderm-derived": "#BCBD32",
    "Anterior floor plate": "#C0B9B2",
    "Anterior roof plate": "#2B0F1B",
    "Posterior roof plate": "#6C2062",
    "Eye field": "#62615A",
    "NMPs and spinal cord progenitors": "#EE3780",
    "Neural crest (PNS glia)": "#9FDAE5",
    "Neural crest (PNS neurons)": "#6B7A34",
    "LV/atrioventricular canal/common atrium cardiomyocytes": "#FF0000",
    "Second heart field-derived cardiomyocytes": "#FF0000",
    "Proepicardium": "#FF0000",
    "Endocardial cells": "#FF0000",
}
# Plot every cell of the original slice at its aligned position, coloured by its
# fine-grained annotation ('mapped_celltype').
spatial_coords = slice.obsm['global_align_spatial']
x_coords = spatial_coords[:, 0]
y_coords = spatial_coords[:, 1]

cell_types = slice.obs['mapped_celltype']

# Get unique types in this slice
unique_types = cell_types.unique()
n_types = len(unique_types)

fig, ax1 = plt.subplots(1, 1, figsize=(7.5, 6))

for cell_type in unique_types:
    mask = (cell_types == cell_type).to_numpy()
    # Fall back to grey (the same default the major-class plots use) instead of
    # raising KeyError for a type, or a NaN label, missing from anno_colors.
    ax1.scatter(x_coords[mask], y_coords[mask],
                c=[anno_colors.get(cell_type, "#CCCCCC")],
                label=cell_type,
                s=1, alpha=0.7)

ax1.set_axis_off()
ax1.set_title(f'Cell Types, Slice 440, {slice.shape[0]} cells')
Text(0.5, 1.0, 'Cell Types, Slice 440, 8382 cells')
_images/67e3af428022767324674ee307d330cca516251b51f1a29e0823a8584618f8b8.png

Annotation legend:

from matplotlib.lines import Line2D

# Stand-alone legend for the fine-grained cell-type plot: one round marker per
# entry of `unique_types` (computed by the plotting cell above).
legend_elements = [
    Line2D(
        [0], [0],
        marker='o',
        color='none',
        label=cell_type,
        # Grey fallback (as in the major-class plots) instead of a KeyError for a
        # type that has no entry in anno_colors.
        markerfacecolor=anno_colors.get(cell_type, "#CCCCCC"),
        markersize=8
    )
    for cell_type in unique_types
]

# Figure height grows with the number of legend entries.
fig_leg, ax_leg = plt.subplots(figsize=(4, 0.25 * len(unique_types)))
ax_leg.legend(
    handles=legend_elements,
    loc='center left',
    frameon=False,
    fontsize=9
)
ax_leg.axis('off')
(np.float64(0.0), np.float64(1.0), np.float64(0.0), np.float64(1.0))
_images/e1e3043f2224b053a2b115e3f31879fffa86140476908d8720a40e769ebf6345.png
### Mapping cell types to the 9 major cell types defined in the Spateo paper

# Fine-grained annotation ('mapped_celltype') -> one of 9 major classes.
# Types that are not keys here become NaN in 'mapped_celltype_nine'.
celltype_to_9major = {
    "Spinal cord motor neurons": "CNS neurons",
    "Spinal cord dorsal progenitors": "CNS neurons",
    "Spinal cord ventral progenitors": "CNS neurons",
    "Cranial motor neurons": "CNS neurons",
    "Glutamatergic neurons": "CNS neurons",
    "Neural progenitor cells": "CNS neurons",
    "Neuroectoderm-derived": "CNS neurons",
    "Neural crest (PNS neurons)": "PNS neurons",
    "Neural crest (PNS glia)": "PNS glia",
    "Epithelial precursors": "Epithelium",
    "Placode and neural crest-derived": "Epithelium",
    "Ectoderm-derived": "Epithelium",
    "Amniotic ectoderm": "Epithelium",
    "Anterior intermediate mesoderm": "Epithelium",
    "Intermediate mesoderm": "Epithelium",
    "Olfactory epithelial cells": "Epithelium",
    "Otic epithelial cells": "Epithelium",
    "Gut": "Epithelium",
    "Endothelium": "Endothelium",
    "Vessel endothelial cells": "Endothelium",
    "Arterial endothelial cells": "Endothelium",
    "Hematoendothelial progenitors": "Endothelium",
    "Endocardial cells": "Endothelium",
    "Liver sinusoidal endothelial cell precursors": "Endothelium",
    # Added: this type has a colour in anno_colors but was missing here, so it was
    # silently mapped to NaN. Classified like its precursors above.
    # NOTE(review): confirm against the Spateo paper.
    "Liver sinusoidal endothelial cells": "Endothelium",
    "Primitive erythroid cells": "Blood lineages",
    "Second heart field-derived cardiomyocytes": "Cardiomyocytes",
    "LV/atrioventricular canal/common atrium cardiomyocytes": "Cardiomyocytes",
    "Cranial mesoderm": "Mesoderm",
    "Head and facial mesenchyme": "Mesoderm",
    "Cardiopharyngeal mesoderm": "Mesoderm",
    "Dermomyotome": "Mesoderm",
    "Lateral plate mesoderm": "Mesoderm",
    "Sclerotome": "Mesoderm",
    "Proepicardium": "Mesoderm",
    "Mesodermal progenitors": "Mesoderm",
    "Hepatic mesenchyme": "Mesoderm",
    "Gut mesenchyme": "Mesoderm",
    "Renal mesenchyme": "Mesoderm",
    "Limb progenitors and lateral plate mesoderm": "Mesoderm",
    "Spinal cord neuroectoderm": "Neuroectoderm",
    "Hindbrain neuroectoderm": "Neuroectoderm",
    "Midbrain-hindbrain boundary": "Neuroectoderm",
    "Midbrain neuroectoderm": "Neuroectoderm",
    "Telencephalon neuroectoderm": "Neuroectoderm",
    "Diencephalon neuroectoderm": "Neuroectoderm",
    "Hypothalamus (Sim1+) neuroectoderm": "Neuroectoderm",
    "Hypothalamus neuroectoderm": "Neuroectoderm",
    "Dorsal telencephalon neuroectoderm": "Neuroectoderm",
    "Anterior floor plate": "Neuroectoderm",
    "Anterior roof plate": "Neuroectoderm",
    "Posterior roof plate": "Neuroectoderm",
    "Eye field": "Neuroectoderm",
    "NMPs and spinal cord progenitors": "Neuroectoderm"
    # "Hepatocytes" (present in anno_colors) is intentionally left unmapped: its
    # major class is not specified in this list.
}

slice.obs['mapped_celltype_nine'] = slice.obs['mapped_celltype'].map(celltype_to_9major)

# Report types that fell through the mapping (they are NaN above) instead of
# dropping them silently.
_unmapped_types = sorted(
    set(slice.obs['mapped_celltype'].dropna().unique()).difference(celltype_to_9major)
)
if _unmapped_types:
    print(f"Cell types without a major-class mapping (left as NaN): {_unmapped_types}")
/tmp/ipython-input-1172833971.py:59: ImplicitModificationWarning: Trying to modify attribute `.obs` of view, initializing view as actual.
  slice.obs['mapped_celltype_nine'] = slice.obs['mapped_celltype'].map(celltype_to_9major)
# Hex colour for each of the nine major classes stored in 'mapped_celltype_nine'.
major_class_colors = {
    "CNS neurons":    "#A93493",
    "PNS neurons":    "#6B7A34",
    "PNS glia":       "#9FDAE5",
    "Epithelium":     "#F69696",
    "Endothelium":    "#FFDAB9",
    "Blood lineages": "#FCBA78",
    "Mesoderm":       "#F57E20",
    "Neuroectoderm":  "#124B99",
    "Cardiomyocytes": "#FF0000",
}
from matplotlib.lines import Line2D

# Original slice coloured by major (9-class) cell type. Classes missing from
# major_class_colors, and NaN labels, are drawn in grey; only non-missing labels
# get a legend entry, listed alphabetically.
coords = slice.obsm['global_align_spatial']
labels = slice.obs['mapped_celltype_nine']

unique_labels = sorted(labels.dropna().unique())

label_to_color = {name: major_class_colors.get(name, "#CCCCCC") for name in unique_labels}
colors = [label_to_color.get(name, "#CCCCCC") for name in labels]

fig_major, ax_major = plt.subplots(figsize=(8, 6))
ax_major.scatter(coords[:, 0], coords[:, 1], c=colors, s=1, alpha=0.8)
ax_major.set_xlabel('X')
ax_major.set_ylabel('Y')
ax_major.set_title(f'Major cell types, slice440, {slice.shape[0]} cells')
ax_major.axis('equal')
ax_major.set_frame_on(False)  # what plt.box(False) does on the current axes
ax_major.axis('off')

# Proxy artists: one round marker per major class.
legend_handles = [
    Line2D([0], [0], marker='o', color='w',
           markerfacecolor=label_to_color[name], markersize=8, label=name)
    for name in unique_labels
]
ax_major.legend(handles=legend_handles, bbox_to_anchor=(1.05, 1),
                loc='upper left', frameon=False)

fig_major.tight_layout()
plt.show()
_images/71434cb42b6570ed07d9210a89fd1b34205f59960dc7db70187146514647b2c4.png

Read in the embedded features of the original slice and of the generated slice.

# Latent embeddings of the original cells (the reference side of the label transfer).
emb_path = 'logs/GAE-2D/2d/lightning_logs/version_0/embedded-all.h5ad'
emb = sc.read(emb_path)
emb
AnnData object with n_obs × n_vars = 8382 × 17649
    obsm: 'embeddings', 'spatial'
# Generated points: fitted embeddings plus raw and normalized spatial coordinates.
res_path = 'logs/GAE+FFN-2D/2d/lightning_logs/version_1/reconstructed-original.h5ad'
res = sc.read(res_path)
res
AnnData object with n_obs × n_vars = 16764 × 1
    obsm: 'fitted_embd', 'reconstructed_raw', 'spatial', 'spatial_normalized'

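Transfer annotations by nearest-neighbour search. For each generated point $j$ we pick the reference cell $i^{*} = \arg\min_i \left[\alpha\, d_{\mathrm{expr}}(i, j) + (1 - \alpha)\, d_{\mathrm{spatial}}(i, j)\right]$ and copy its label, where $d_{\mathrm{expr}}$ is the Euclidean distance between latent embeddings and $d_{\mathrm{spatial}}$ is the Euclidean distance between normalized spatial coordinates.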

emb.obs['mapped_celltype'] = slice.obs['mapped_celltype'].values
from scipy.spatial.distance import cdist

alpha = 0.05  # weight: 0.0 = pure spatial, 1.0 = pure latent gene expression

# At the end of this tutorial, we illustrate how to choose the alpha value in order to
# balance the influence of spatial coord and latent gene expression

fitted_embd_dense = res.obsm['fitted_embd']
if hasattr(fitted_embd_dense, 'toarray'):
    # sparse matrix -> dense ndarray before row slicing / cdist
    fitted_embd_dense = fitted_embd_dense.toarray()

# Loop-invariant inputs, looked up once rather than on every iteration.
ref_embd = emb.obsm['embeddings']
ref_spatial = emb.obsm['spatial']
query_spatial = res.obsm['spatial_normalized']
# Plain array of reference labels; indexing it avoids building a DataFrame row
# (emb.obs.iloc[...]) for every generated point.
ref_labels = emb.obs['mapped_celltype'].to_numpy()

# For each generated point, copy the label of the single reference cell that
# minimises alpha * d_expr + (1 - alpha) * d_spatial.
annotations = []
for i in range(res.shape[0]):
    d_expr = cdist(ref_embd, fitted_embd_dense[i, np.newaxis, :], metric='euclidean').flatten()
    ### note that we used normalized spatial coord ###
    d_spatial = cdist(ref_spatial, query_spatial[i, np.newaxis, :], metric='euclidean').flatten()
    d_combined = alpha * d_expr + (1 - alpha) * d_spatial
    best_idx = d_combined.argmin()
    annotations.append(ref_labels[best_idx])
res.obs['predicted_celltype'] = annotations
# Plot the generated points at their 'spatial' coordinates (the label transfer
# used 'spatial_normalized'), coloured by the transferred fine-grained label.
spatial_coords = res.obsm['spatial']
x_coords = spatial_coords[:, 0]
y_coords = spatial_coords[:, 1]

cell_types = res.obs['predicted_celltype']

# Get unique types in this slice
unique_types = cell_types.unique()
n_types = len(unique_types)

fig, ax1 = plt.subplots(1, 1, figsize=(7.5, 6))

for cell_type in unique_types:
    mask = (cell_types == cell_type).to_numpy()
    # Fall back to grey (the same default the major-class plots use) instead of
    # raising KeyError for a label, or NaN, missing from anno_colors.
    ax1.scatter(x_coords[mask], y_coords[mask],
                c=[anno_colors.get(cell_type, "#CCCCCC")],
                label=cell_type,
                s=1, alpha=0.7)

ax1.set_axis_off()
ax1.set_title(f'Predicted Cell Types, Slice 440, {res.shape[0]} cells')
Text(0.5, 1.0, 'Predicted Cell Types, Slice 440, 16764 cells')
_images/53d87408bf210cee50a5c47057fd336533344b9bc1653d0d13c07c2d145b2b06.png

We can also directly predict the nine major cell types.

emb.obs['mapped_celltype_nine'] = slice.obs['mapped_celltype_nine'].values
from scipy.spatial.distance import cdist

alpha = 0.05  # weight: 0.0 = pure spatial, 1.0 = pure gene expression

fitted_embd_dense = res.obsm['fitted_embd']
if hasattr(fitted_embd_dense, 'toarray'):
    # sparse matrix -> dense ndarray before row slicing / cdist
    fitted_embd_dense = fitted_embd_dense.toarray()

# Loop-invariant inputs, looked up once rather than on every iteration.
ref_embd = emb.obsm['embeddings']
ref_spatial = emb.obsm['spatial']
query_spatial = res.obsm['spatial_normalized']
# Plain array of reference labels; indexing it avoids building a DataFrame row
# (emb.obs.iloc[...]) for every generated point.
ref_labels = emb.obs['mapped_celltype_nine'].to_numpy()

# For each generated point, copy the major-class label of the single reference
# cell that minimises alpha * d_expr + (1 - alpha) * d_spatial.
# After the loop, d_expr / d_spatial hold the last generated point's distances;
# the alpha-selection cells further down read them.
annotations = []
for i in range(res.shape[0]):
    d_expr = cdist(ref_embd, fitted_embd_dense[i, np.newaxis, :], metric='euclidean').flatten()
    ### note that we used normalized spatial coord ###
    d_spatial = cdist(ref_spatial, query_spatial[i, np.newaxis, :], metric='euclidean').flatten()
    d_combined = alpha * d_expr + (1 - alpha) * d_spatial
    best_idx = d_combined.argmin()
    annotations.append(ref_labels[best_idx])
res.obs['predicted_celltype_nine'] = annotations
from matplotlib.lines import Line2D

# Generated points coloured by predicted major (9-class) type. Classes missing
# from major_class_colors, and NaN labels, are drawn in grey; only non-missing
# labels get a legend entry, listed alphabetically.
coords = res.obsm['spatial']
labels = res.obs['predicted_celltype_nine']

unique_labels = sorted(labels.dropna().unique())

label_to_color = {name: major_class_colors.get(name, "#CCCCCC") for name in unique_labels}
colors = [label_to_color.get(name, "#CCCCCC") for name in labels]

fig_pred, ax_pred = plt.subplots(figsize=(8, 6))
ax_pred.scatter(coords[:, 0], coords[:, 1], c=colors, s=1, alpha=0.8)
ax_pred.set_xlabel('X')
ax_pred.set_ylabel('Y')
ax_pred.set_title(f'Predicted Major cell types, slice440, {res.shape[0]} cells')
ax_pred.axis('equal')
ax_pred.set_frame_on(False)  # what plt.box(False) does on the current axes
ax_pred.axis('off')

# Proxy artists: one round marker per major class.
legend_handles = [
    Line2D([0], [0], marker='o', color='w',
           markerfacecolor=label_to_color[name], markersize=8, label=name)
    for name in unique_labels
]
ax_pred.legend(handles=legend_handles, bbox_to_anchor=(1.05, 1),
               loc='upper left', frameon=False)

fig_pred.tight_layout()
plt.show()
_images/935d7af2cc1b365f3f5a7ac40bc30a7c93654d879e35ae67ae7243ef5e3e4a00.png

How to choose the alpha value?

# Median latent-expression distance from the last generated point processed by the
# loop above to every reference cell (d_expr is overwritten on each iteration).
np.median(d_expr)
np.float64(115.56246011613429)
# Median spatial distance for that same last generated point (d_spatial is
# overwritten on each iteration of the loop above).
np.median(d_spatial)
np.float64(0.6717270550719289)
# Choose alpha so both terms are equal at their medians:
#   alpha * mE == (1 - alpha) * mS   =>   alpha = mS / (mE + mS)
# The tiny epsilon only guards against division by zero.
# NOTE(review): d_expr / d_spatial are the distances of the last generated point
# only; pooling over many query points would give a steadier estimate.
mE, mS = np.median(d_expr), np.median(d_spatial)
alpha_balanced = mS / (mE + mS + 1e-12)
alpha_balanced
np.float64(0.005779083343891829)

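Alpha would need to be around 0.006 for the latent-expression and spatial terms to contribute at comparable magnitudes (i.e., to have roughly equal influence on the combined distance). Note that these medians are computed from the distances of a single generated point (the last one processed in the loop above), so they are only a rough guide; pooling the distances of many generated points gives a more stable estimate.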

# Weighted expression term under the balanced alpha (last generated point's distances).
d_expr * alpha_balanced
array([0.57150895, 0.56836565, 0.79764584, ..., 0.75834499, 0.76417003,
       0.65952493])
# Weighted spatial term under the balanced alpha (last generated point's distances).
d_spatial * (1 - alpha_balanced)
array([1.06020115, 1.05720464, 1.18767899, ..., 0.32496874, 0.34960844,
       0.37166664])
alpha = alpha_balanced

fitted_embd_dense = res.obsm['fitted_embd']
if hasattr(fitted_embd_dense, 'toarray'):
    # sparse matrix -> dense ndarray before row slicing / cdist
    fitted_embd_dense = fitted_embd_dense.toarray()

# Loop-invariant inputs, looked up once rather than on every iteration.
ref_embd = emb.obsm['embeddings']
ref_spatial = emb.obsm['spatial']
query_spatial = res.obsm['spatial_normalized']
# Plain array of reference labels; indexing it avoids building a DataFrame row
# (emb.obs.iloc[...]) for every generated point.
ref_labels = emb.obs['mapped_celltype_nine'].to_numpy()

# Re-run the major-class transfer with the balanced alpha: each generated point
# copies the label of the reference cell minimising alpha * d_expr + (1 - alpha) * d_spatial.
annotations = []
for i in range(res.shape[0]):
    d_expr = cdist(ref_embd, fitted_embd_dense[i, np.newaxis, :], metric='euclidean').flatten()
    ### note that we used normalized spatial coord ###
    d_spatial = cdist(ref_spatial, query_spatial[i, np.newaxis, :], metric='euclidean').flatten()
    d_combined = alpha * d_expr + (1 - alpha) * d_spatial
    best_idx = d_combined.argmin()
    annotations.append(ref_labels[best_idx])
res.obs['predicted_celltype_nine'] = annotations
from matplotlib.lines import Line2D

# Generated points coloured by predicted major (9-class) type using the balanced
# alpha. Classes missing from major_class_colors, and NaN labels, are drawn in
# grey; only non-missing labels get a legend entry, listed alphabetically.
coords = res.obsm['spatial']
labels = res.obs['predicted_celltype_nine']

unique_labels = sorted(labels.dropna().unique())

label_to_color = {name: major_class_colors.get(name, "#CCCCCC") for name in unique_labels}
colors = [label_to_color.get(name, "#CCCCCC") for name in labels]

fig_pred, ax_pred = plt.subplots(figsize=(8, 6))
ax_pred.scatter(coords[:, 0], coords[:, 1], c=colors, s=1, alpha=0.8)
ax_pred.set_xlabel('X')
ax_pred.set_ylabel('Y')
ax_pred.set_title(f'Predicted Major cell types, slice440, {res.shape[0]} cells')
ax_pred.axis('equal')
ax_pred.set_frame_on(False)  # what plt.box(False) does on the current axes
ax_pred.axis('off')

# Proxy artists: one round marker per major class.
legend_handles = [
    Line2D([0], [0], marker='o', color='w',
           markerfacecolor=label_to_color[name], markersize=8, label=name)
    for name in unique_labels
]
ax_pred.legend(handles=legend_handles, bbox_to_anchor=(1.05, 1),
               loc='upper left', frameon=False)

fig_pred.tight_layout()
plt.show()
_images/16efdb83895db3b96ae49a2ed68abbbd81d9ac7b721c97f3ef02d55472f4206c.png

Previously we chose alpha = 0.05, which gives the expression term substantially greater weight: at the median distances shown below, it contributes roughly nine times as much as the spatial term.

# Reset alpha to the earlier value and show the median weighted expression term.
# d_expr still holds the last generated point's distances from the loop above.
alpha = 0.05
np.median(alpha * d_expr)
np.float64(5.778123005806715)
# Median weighted spatial term at alpha = 0.05, for the same last generated point.
np.median((1 - alpha) * d_spatial)
np.float64(0.6381407023183324)