Gene Imputation (2D)

The example data used for this tutorial can be downloaded from:

https://drive.google.com/file/d/16yGWgy4CUiEZG6yOz9XJzWCOdthx_Qy5/view?usp=sharing

!git clone https://github.com/lanshui98/UniST.git
%cd UniST
/content/UniST
!pip install -r requirements.txt
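Before moving on, it can help to confirm that the main dependencies are importable. The sketch below uses only the standard library. The package list is a subset chosen for illustration, not an official check shipped with UniST.

```python
import importlib.util
from importlib import metadata

# Import names of a few packages pulled in by requirements.txt (illustrative subset).
# Some import names differ from pip names (e.g. scikit-learn is imported as sklearn).
REQUIRED = ["torch", "einops", "lightning", "anndata", "scanpy", "open3d", "pyvista"]


def check_packages(names):
    """Return (versions, missing): installed versions keyed by import name,
    and the list of top-level names that cannot be found."""
    versions, missing = {}, []
    for name in names:
        # find_spec locates a module without importing it (cheap even for torch)
        if importlib.util.find_spec(name) is None:
            missing.append(name)
            continue
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Import name has no matching distribution name; version unknown
            versions[name] = "unknown"
    return versions, missing


if __name__ == "__main__":
    versions, missing = check_packages(REQUIRED)
    for name, ver in versions.items():
        print(f"{name:10s} {ver}")
    if missing:
        print("Missing:", ", ".join(missing))
```

Note that requirements.txt pins einops==0.7.0, so on Colab the install replaces the pre-installed einops 0.8.2; checking the printed version is a quick way to confirm the pin took effect.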

Note that the spatial coordinates should be stored in adata.obsm['spatial'].
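If your coordinates live in two columns of adata.obs instead, they can be copied into adata.obsm['spatial'] before saving. This is a minimal sketch: the column names "x" and "y" are hypothetical placeholders, and the (n_obs, 2) shape check is an assumption based on this being the 2D tutorial, not a documented UniST requirement.

```python
import numpy as np


def set_spatial_from_obs(adata, x_key="x", y_key="y"):
    """Copy 2D coordinates from two obs columns into adata.obsm['spatial'].

    x_key / y_key are hypothetical column names; replace them with the
    columns that actually hold the coordinates in your data.
    """
    coords = np.column_stack(
        [np.asarray(adata.obs[x_key], dtype=float),
         np.asarray(adata.obs[y_key], dtype=float)]
    )
    adata.obsm["spatial"] = coords
    return coords


def check_spatial(adata):
    """Raise if adata.obsm['spatial'] is missing, not (n_obs, 2), or non-finite."""
    if "spatial" not in adata.obsm:
        raise KeyError("adata.obsm['spatial'] is missing")
    coords = np.asarray(adata.obsm["spatial"], dtype=float)
    if coords.shape != (adata.n_obs, 2):
        raise ValueError(f"expected shape ({adata.n_obs}, 2), got {coords.shape}")
    if not np.isfinite(coords).all():
        raise ValueError("spatial coordinates contain NaN or inf")
    return coords
```

With a real dataset you would load it with anndata.read_h5ad(path), call check_spatial(adata), and only fall back to set_spatial_from_obs if the check fails.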

Place the AnnData file in external/SUICA_pro/data/ (relative to the UniST repository root).
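One way to do this from the notebook is to copy the downloaded file with pathlib and shutil. In this sketch, repo_root is the UniST checkout (the working directory after %cd UniST), and the helper name and any source filename you pass in are hypothetical.

```python
import shutil
from pathlib import Path


def place_in_data_dir(src, repo_root="."):
    """Copy an AnnData file into <repo_root>/external/SUICA_pro/data/.

    repo_root defaults to the current directory, i.e. the UniST checkout
    after `%cd UniST`. Returns the destination path.
    """
    src = Path(src)
    if not src.is_file():
        raise FileNotFoundError(src)
    data_dir = Path(repo_root) / "external" / "SUICA_pro" / "data"
    data_dir.mkdir(parents=True, exist_ok=True)  # create the folder if absent
    dest = data_dir / src.name
    shutil.copy2(src, dest)  # copy contents and file metadata
    return dest
```

For example, place_in_data_dir("/content/example.h5ad") would copy a file you uploaded to /content, where example.h5ad stands in for the actual name of the downloaded data.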

%cd external/SUICA_pro
/content/UniST/external/SUICA_pro
! pip install scanpy
Requirement already satisfied: array-api-compat>=1.7.1 in /usr/local/lib/python3.12/dist-packages (from anndata>=0.10.8->scanpy) (1.13.0)
Requirement already satisfied: zarr!=3.0.*,>=2.18.7 in /usr/local/lib/python3.12/dist-packages (from anndata>=0.10.8->scanpy) (3.1.5)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.9->scanpy) (1.3.3)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.9->scanpy) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.9->scanpy) (4.61.1)
Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.9->scanpy) (1.4.9)
Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.9->scanpy) (11.3.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.9->scanpy) (3.3.2)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.9->scanpy) (2.9.0.post0)
Requirement already satisfied: llvmlite<0.44,>=0.43.0dev0 in /usr/local/lib/python3.12/dist-packages (from numba>=0.60->scanpy) (0.43.0)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas>=2.2.2->scanpy) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas>=2.2.2->scanpy) (2025.3)
Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=1.4.2->scanpy) (3.6.0)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.7->matplotlib>=3.9->scanpy) (1.17.0)
Requirement already satisfied: donfig>=0.8 in /usr/local/lib/python3.12/dist-packages (from zarr!=3.0.*,>=2.18.7->anndata>=0.10.8->scanpy) (0.8.1.post1)
Requirement already satisfied: google-crc32c>=1.5 in /usr/local/lib/python3.12/dist-packages (from zarr!=3.0.*,>=2.18.7->anndata>=0.10.8->scanpy) (1.8.0)
Requirement already satisfied: numcodecs>=0.14 in /usr/local/lib/python3.12/dist-packages (from zarr!=3.0.*,>=2.18.7->anndata>=0.10.8->scanpy) (0.16.5)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.12/dist-packages (from donfig>=0.8->zarr!=3.0.*,>=2.18.7->anndata>=0.10.8->scanpy) (6.0.3)

Read the Data

import scanpy as sc
adata_path = 'data/slice_440.h5ad'
a = sc.read(adata_path)
a
AnnData object with n_obs × n_vars = 8382 × 17649
    obs: 'area', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'n_counts'
    obsm: 'X_pca', 'bbox', 'spatial'

Visualize gene expression

import numpy as np
import matplotlib.pyplot as plt
from scipy import sparse

# Extract the 2D spatial coordinates of each spot
coordinates = a.obsm['spatial'][:, 0:2]

# Column 8293 of a.X holds the gene Hba-x in this dataset
gene_idx = 8293

# Handle both sparse and dense expression matrices
if sparse.issparse(a.X):
    values = a.X[:, gene_idx].toarray().flatten()
else:
    values = np.asarray(a.X[:, gene_idx]).flatten()

# Visualization
plt.figure(figsize=(6, 6))
plt.scatter(coordinates[:, 0], coordinates[:, 1], c=values, cmap='viridis_r', s=1, alpha=0.7)
plt.xlabel('X')
plt.ylabel('Y')
plt.title("Expression of gene Hba-x, slice 440")
plt.axis('equal')
plt.colorbar(label='Expression Level')
#plt.savefig('slice440_Hba-x.png', format='png', dpi=300)
plt.show()
[Figure: Expression of gene Hba-x, slice 440]
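Indexing by column number (8293) ties the plot to this file's gene order. The sketch below looks the gene up by name instead. It assumes `a.var_names` holds gene symbols such as `Hba-x`. `gene_values` is a hypothetical helper that accepts either dense NumPy arrays or SciPy sparse matrices.

```python
import numpy as np


def gene_values(X, var_names, gene):
    """Return the 1D expression vector of `gene` from an (n_obs, n_vars) matrix.

    Works for dense NumPy arrays and SciPy sparse matrices (anything with .toarray()).
    """
    matches = np.flatnonzero(np.asarray(var_names) == gene)
    if matches.size == 0:
        raise KeyError(f"gene {gene!r} not found in var_names")
    col = X[:, int(matches[0])]
    if hasattr(col, "toarray"):  # sparse column slice -> dense (n_obs, 1)
        col = col.toarray()
    return np.asarray(col).ravel()
```

`values = gene_values(a.X, a.var_names, 'Hba-x')` would replace the index-based extraction above. If gene names are unique, AnnData alone also works: `a.var_names.get_loc('Hba-x')` returns the column position.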

Step1: Train GAE

!python train.py --mode embedder --conf ./configs/ST/embedder_gae.yaml
./configs/ST/embedder_gae.yaml
Current Configs:
{
    'case': '2d',
    'dataset': {
        'type': 'GraphST2D',
        'data_file': 'data/slice_440.h5ad',
        'val_proportion': 0.2,
        'keep_ratio': True,
        'n_neighbors': 4,
        'require_coordnorm': True
    },
    'pipeline': {
        'embedder': {
            'model': 'GAE',
            'dim_hidden': [2048, 512, 128],
            'dim_latent': 64
        },
        'optimization': {
            'seed': 8848,
            'epochs': 1000,
            'lr': 1e-05,
            'val_freq': 50,
            'logs': 'logs/GAE-2D/2d',
            'batch_size': 512
        },
        'predict_mode': 'all',
        'embedded_data': 'embedded-all.h5ad'
    }
}
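The `GraphST2D` dataset type and `n_neighbors: 4` suggest that each spot is linked to its nearest spatial neighbours before the GAE embedder runs. This log does not show how SUICA_pro actually builds that graph (library, symmetrization, self-loops). The following is therefore only a brute-force NumPy sketch of a k-nearest-neighbour edge list on 2D coordinates. `knn_edges` is a hypothetical name.

```python
import numpy as np


def knn_edges(coords, k=4):
    """Return an (n*k, 2) array of directed edges (i, j), j among i's k nearest spots.

    Brute force: builds the full n x n distance matrix, so memory grows as n^2.
    """
    coords = np.asarray(coords, dtype=float)
    n = len(coords)
    if not 0 < k < n:
        raise ValueError("k must be between 1 and n_spots - 1")
    # Squared Euclidean distances via |a|^2 + |b|^2 - 2 a.b
    sq = (coords ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * coords @ coords.T
    np.fill_diagonal(d2, np.inf)  # a spot is not its own neighbour
    # Column indices of the k smallest distances per row (order within the k is arbitrary)
    nbrs = np.argpartition(d2, k - 1, axis=1)[:, :k]
    src = np.repeat(np.arange(n), k)
    return np.stack([src, nbrs.ravel()], axis=1)
```

For the 8382 spots in this slice, the distance matrix alone takes about 560 MB as float64. For real data, a KD-tree is the usual choice, for example `scipy.spatial.cKDTree(coords).query(coords, k=5)`. That query's first hit for each spot is the spot itself.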
[CONF] conf     = /content/UniST/external/SUICA_pro/configs/ST/embedder_gae.yaml
[CONF] datafile = /content/UniST/external/SUICA_pro/data/slice_440.h5ad
[CONF] logs_dir = /content/UniST/external/SUICA_pro/logs/GAE-2D/2d
Seed set to 8848
/usr/local/lib/python3.12/dist-packages/torch/__init__.py:1617: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)
  _C._set_float32_matmul_precision(precision)
6705 1677
val_idx[:10]=[3637, 1964, 1408, 7304, 7282, 1721, 4720, 7473, 964, 4825]
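The line `6705 1677` is consistent with splitting the 8382 spots using `val_proportion: 0.2`. Flooring 8382 × 0.8 = 6705.6 gives 6705 training spots, which leaves 1677 for validation. Below is a sketch of such a split, assuming a seeded random permutation; `split_indices` is a hypothetical name. It reuses the config seed 8848, but it will not reproduce the `val_idx` printed above, because SUICA_pro's own sampling code is not shown here.

```python
import numpy as np


def split_indices(n_obs, val_proportion=0.2, seed=8848):
    """Randomly split spot indices into training and validation sets."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_obs)
    n_train = int(n_obs * (1 - val_proportion))  # floor; the remainder goes to validation
    return perm[:n_train], perm[n_train:]


train_idx, val_idx = split_indices(8382)
print(len(train_idx), len(val_idx))  # 6705 1677
```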
/content/UniST/external/SUICA_pro/datasets.py:516: ImplicitModificationWarning: Trying to modify attribute `.obsm` of view, initializing view as actual.
  self.coordinates[:,0] = (self.coordinates[:,0] - x_min) / x_range
/content/UniST/external/SUICA_pro/datasets.py:518: ImplicitModificationWarning: Trying to modify attribute `.obsm` of view, initializing view as actual.
  self.coordinates[:,1] = (self.coordinates[:,1] - y_min) / y_range
/content/UniST/external/SUICA_pro/datasets.py:516: ImplicitModificationWarning: Trying to modify attribute `.obsm` of view, initializing view as actual.
  self.coordinates[:,0] = (self.coordinates[:,0] - x_min) / x_range
/content/UniST/external/SUICA_pro/datasets.py:518: ImplicitModificationWarning: Trying to modify attribute `.obsm` of view, initializing view as actual.
  self.coordinates[:,1] = (self.coordinates[:,1] - y_min) / y_range
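The ImplicitModificationWarning lines come from AnnData. Writing into `.obsm` of a view, such as a subset created by indexing, makes AnnData turn the view into a real copy. This is harmless here, and calling `.copy()` on a subset before modifying it avoids the warning.

The quoted lines show per-axis min-max scaling, `(x - x_min) / x_range`, which `require_coordnorm: True` presumably switches on. The sketch below implements that scaling. It assumes `keep_ratio=True` means one shared range for both axes, so the tissue shape is not stretched. `normalize_coords` is a hypothetical name.

```python
import numpy as np


def normalize_coords(coords, keep_ratio=True):
    """Min-max scale (n, 2) spatial coordinates into [0, 1].

    keep_ratio=True (assumed meaning): divide both axes by the larger span,
    so the longer axis spans [0, 1] and the aspect ratio is preserved.
    """
    coords = np.asarray(coords, dtype=float)
    mins = coords.min(axis=0)
    spans = coords.max(axis=0) - mins
    if keep_ratio:
        spans = np.full_like(spans, spans.max())
    return (coords - mins) / spans
```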
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: UserWarning: This DataLoader will create 18 worker processes in total. Our suggested max number of worker in current system is 12, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(
/content/UniST/external/SUICA_pro/utils.py:111: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  ax.scatter(x, y, c=z_norm, s=spot_size, cmap=cmap)
2026-02-10 22:29:15.118362: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-10 22:29:15.135094: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1770762555.154251   16319 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1770762555.161560   16319 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1770762555.180699   16319 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770762555.180725   16319 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770762555.180729   16319 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770762555.180731   16319 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
2026-02-10 22:29:15.185638: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
┏━━━┳━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃   ┃ Name          ┃ Type ┃ Params ┃ Mode  ┃ FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ fitting_model │ GAE  │ 74.6 M │ train │     0 │
└───┴───────────────┴──────┴────────┴───────┴───────┘
Trainable params: 74.6 M                                                        
Non-trainable params: 0                                                         
Total params: 74.6 M                                                            
Total estimated model params size (MB): 298                                     
Modules in train mode: 26                                                       
Modules in eval mode: 0                                                         
Total FLOPs: 0                                                                  
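The 74.6 M parameters (298 MB) in the summary can be reproduced with simple arithmetic under one assumption. Suppose the GAE is a mirrored encoder/decoder in which each layer holds one weight matrix and one bias vector; a graph-convolution layer has the same parameter shapes as a dense layer. The dimensions then run 17649 genes → 2048 → 512 → 128 → 64 latent and back. This is a consistency check under that assumption, not a description of SUICA_pro's modules. `autoencoder_params` is a hypothetical helper.

```python
def autoencoder_params(n_in, hidden, latent):
    """Count weights + biases of a mirrored encoder/decoder stack.

    Encoder: n_in -> hidden[0] -> ... -> hidden[-1] -> latent; the decoder reverses it.
    Each layer mapping d_in -> d_out contributes d_in * d_out weights and d_out biases.
    """
    dims = [n_in, *hidden, latent]
    enc = sum(d_in * d_out + d_out for d_in, d_out in zip(dims[:-1], dims[1:]))
    rev = dims[::-1]
    dec = sum(d_in * d_out + d_out for d_in, d_out in zip(rev[:-1], rev[1:]))
    return enc + dec


# 17649 genes (n_vars of the AnnData) and the dims from embedder_gae.yaml
total = autoencoder_params(17649, [2048, 512, 128], 64)
print(f"{total / 1e6:.1f} M params, {total * 4 / 1e6:.0f} MB as float32")
# 74.6 M params, 298 MB as float32
```

Almost all of the parameters sit in the two layers that touch the 17649-gene input and output.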
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: 
UserWarning: This DataLoader will create 18 worker processes in total. Our 
suggested max number of worker in current system is 12, which is smaller than 
what this DataLoader is going to create. Please be aware that excessive worker 
creation might get DataLoader running slow or even freeze, lower the worker 
number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Extra variable passed to LightningModule for validation.
/content/UniST/external/SUICA_pro/utils.py:111: UserWarning: No data for 
colormapping provided via 'c'. Parameters 'cmap' will be ignored
  ax.scatter(x, y, c=z_norm, s=spot_size, cmap=cmap)
{
    'val/mean_absolute_error': 0.028083937242627144,
    'val/mean_absolute_error_mask': 1.7222709655761719,
    'val/mean_squared_error': 0.07008925825357437,
    'val/mean_squared_error_mask': 5.899291038513184,
    'val/root_mean_squared_error': 0.14427080750465393,
    'val/cosine_similarity': np.float32(0.04632988),
    'val/cosine_similarity_mask': np.float32(0.4095784),
    'val/sam': np.float32(87.34438),
    'val/iou': np.float64(0.497095889794876),
    'val/pearsonr': np.float32(0.00021684644),
    'val/spearmanr': np.float64(-0.0012409286483568336),
    'val/pearsonr_mask': np.float32(0.013973423),
    'val/spearmanr_mask': np.float64(0.013366765028195914)
}
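This first dictionary appears to come from a validation pass before the first training epoch (`val/pearsonr` ≈ 0). Later dictionaries follow every `val_freq: 50` epochs. The metric definitions live in SUICA_pro's code, which is not shown here, so the sketch below computes a few of them under assumed definitions:

- `_mask` variants are restricted to entries where the ground truth is nonzero.
- `iou` is the overlap of nonzero patterns.
- `sam` is the mean per-spot angle between the predicted and true expression vectors, in degrees.

The `sam` assumption fits the log: arccos(0.0463) ≈ 87.3°, close to the reported 87.34. `st_metrics` is a hypothetical name.

```python
import numpy as np


def st_metrics(pred, truth, eps=1e-8):
    """Illustrative reconstruction metrics for (n_spots, n_genes) expression matrices.

    Assumed definitions (not taken from SUICA_pro's code):
      *_mask -> computed only where the ground truth is nonzero
      iou    -> overlap of the nonzero patterns of prediction and ground truth
      sam    -> mean per-spot angle between expression vectors, in degrees
    """
    pred = np.asarray(pred, dtype=float)
    truth = np.asarray(truth, dtype=float)
    err = pred - truth
    mask = truth != 0

    # Per-spot cosine similarity, then the spectral angle in degrees
    cos = (pred * truth).sum(axis=1) / (
        np.linalg.norm(pred, axis=1) * np.linalg.norm(truth, axis=1) + eps)
    sam = np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))

    # Intersection over union of the nonzero supports
    pred_nz = pred != 0
    union = np.logical_or(pred_nz, mask).sum()
    iou = np.logical_and(pred_nz, mask).sum() / union if union else 1.0

    return {
        "mae": np.abs(err).mean(),
        "mae_mask": np.abs(err[mask]).mean(),
        "mse": (err ** 2).mean(),
        "mse_mask": (err[mask] ** 2).mean(),
        "cosine_similarity": cos.mean(),
        "sam": sam.mean(),
        "iou": iou,
    }
```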
Extra variable passed to LightningModule for training.
Extra variable passed to LightningModule for validation.
Epoch 49/999 ━━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.22it/s v_num: 0.000
{
    'val/mean_absolute_error': 0.029639828950166702,
    'val/mean_absolute_error_mask': 1.522835612297058,
    'val/mean_squared_error': 0.060804177075624466,
    'val/mean_squared_error_mask': 4.905801773071289,
    'val/root_mean_squared_error': 0.1431884765625,
    'val/cosine_similarity': np.float32(0.31824356),
    'val/cosine_similarity_mask': np.float32(0.44534045),
    'val/sam': np.float32(71.23616),
    'val/iou': np.float64(0.8071045249447657),
    'val/pearsonr': np.float32(0.30976525),
    'val/spearmanr': np.float64(0.14675876537935323),
    'val/pearsonr_mask': np.float32(0.19899252),
    'val/spearmanr_mask': np.float64(0.13719981057302652)
}
Epoch 99/999 ━━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.22it/s v_num: 0.000
{
    'val/mean_absolute_error': 0.030393056571483612,
    'val/mean_absolute_error_mask': 1.5173572301864624,
    'val/mean_squared_error': 0.060910664498806,
    'val/mean_squared_error_mask': 4.886770248413086,
    'val/root_mean_squared_error': 0.14323948323726654,
    'val/cosine_similarity': np.float32(0.3218206),
    'val/cosine_similarity_mask': np.float32(0.44557717),
    'val/sam': np.float32(71.01856),
    'val/iou': np.float64(0.8030794786604388),
    'val/pearsonr': np.float32(0.31330216),
    'val/spearmanr': np.float64(0.1494861487532204),
    'val/pearsonr_mask': np.float32(0.20012583),
    'val/spearmanr_mask': np.float64(0.14252946805711664)
}
Epoch 149/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.22it/s v_num: 0.000
{
    'val/mean_absolute_error': 0.029320217669010162,
    'val/mean_absolute_error_mask': 1.5355247259140015,
    'val/mean_squared_error': 0.061091333627700806,
    'val/mean_squared_error_mask': 4.95554780960083,
    'val/root_mean_squared_error': 0.14330101013183594,
    'val/cosine_similarity': np.float32(0.32103342),
    'val/cosine_similarity_mask': np.float32(0.43635625),
    'val/sam': np.float32(71.067764),
    'val/iou': np.float64(0.8072152885398143),
    'val/pearsonr': np.float32(0.31247544),
    'val/spearmanr': np.float64(0.15087799625514137),
    'val/pearsonr_mask': np.float32(0.19976543),
    'val/spearmanr_mask': np.float64(0.14402902884770022)
}
Epoch 199/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.20it/s v_num: 0.000
{
    'val/mean_absolute_error': 0.03034846857190132,
    'val/mean_absolute_error_mask': 1.5174301862716675,
    'val/mean_squared_error': 0.06109563633799553,
    'val/mean_squared_error_mask': 4.891817092895508,
    'val/root_mean_squared_error': 0.14330480992794037,
    'val/cosine_similarity': np.float32(0.3231247),
    'val/cosine_similarity_mask': np.float32(0.44093636),
    'val/sam': np.float32(70.93985),
    'val/iou': np.float64(0.8142766937236496),
    'val/pearsonr': np.float32(0.31467605),
    'val/spearmanr': np.float64(0.15399071242432696),
    'val/pearsonr_mask': np.float32(0.20048651),
    'val/spearmanr_mask': np.float64(0.1438991809525124)
}
Epoch 249/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.21it/s v_num: 0.000
{
    'val/mean_absolute_error': 0.03007148765027523,
    'val/mean_absolute_error_mask': 1.5170608758926392,
    'val/mean_squared_error': 0.06119992956519127,
    'val/mean_squared_error_mask': 4.895423889160156,
    'val/root_mean_squared_error': 0.14334069192409515,
    'val/cosine_similarity': np.float32(0.3218703),
    'val/cosine_similarity_mask': np.float32(0.43794894),
    'val/sam': np.float32(71.0178),
    'val/iou': np.float64(0.8313530111128634),
    'val/pearsonr': np.float32(0.31353167),
    'val/spearmanr': np.float64(0.15950158128502934),
    'val/pearsonr_mask': np.float32(0.19990644),
    'val/spearmanr_mask': np.float64(0.14589835398649542)
}
Epoch 299/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.20it/s v_num: 0.000
{
    'val/mean_absolute_error': 0.02915032021701336,
    'val/mean_absolute_error_mask': 1.5230721235275269,
    'val/mean_squared_error': 0.061537064611911774,
    'val/mean_squared_error_mask': 4.921195983886719,
    'val/root_mean_squared_error': 0.14349360764026642,
    'val/cosine_similarity': np.float32(0.31731674),
    'val/cosine_similarity_mask': np.float32(0.4291934),
    'val/sam': np.float32(71.29776),
    'val/iou': np.float64(0.8600320503743273),
    'val/pearsonr': np.float32(0.30927128),
    'val/spearmanr': np.float64(0.16666608762419458),
    'val/pearsonr_mask': np.float32(0.19841315),
    'val/spearmanr_mask': np.float64(0.14727969568448404)
}
Epoch 349/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.19it/s v_num: 0.000
{
    'val/mean_absolute_error': 0.02959413081407547,
    'val/mean_absolute_error_mask': 1.5166385173797607,
    'val/mean_squared_error': 0.061724793165922165,
    'val/mean_squared_error_mask': 4.901226043701172,
    'val/root_mean_squared_error': 0.14362160861492157,
    'val/cosine_similarity': np.float32(0.31543517),
    'val/cosine_similarity_mask': np.float32(0.42918235),
    'val/sam': np.float32(71.41383),
    'val/iou': np.float64(0.8709318204821506),
    'val/pearsonr': np.float32(0.30740613),
    'val/spearmanr': np.float64(0.16945093142722298),
    'val/pearsonr_mask': np.float32(0.19596536),
    'val/spearmanr_mask': np.float64(0.14584184640389625)
}
Epoch 399/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.17it/s v_num: 0.000
{
    'val/mean_absolute_error': 0.02896631695330143,
    'val/mean_absolute_error_mask': 1.5210232734680176,
    'val/mean_squared_error': 0.06185933202505112,
    'val/mean_squared_error_mask': 4.913853645324707,
    'val/root_mean_squared_error': 0.14370954036712646,
    'val/cosine_similarity': np.float32(0.3123556),
    'val/cosine_similarity_mask': np.float32(0.42315567),
    'val/sam': np.float32(71.60048),
    'val/iou': np.float64(0.8866406196547326),
    'val/pearsonr': np.float32(0.3046314),
    'val/spearmanr': np.float64(0.1737976001047576),
    'val/pearsonr_mask': np.float32(0.19617787),
    'val/spearmanr_mask': np.float64(0.1471115017211118)
}
Epoch 449/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.19it/s v_num: 0.000
{
    'val/mean_absolute_error': 0.028319593518972397,
    'val/mean_absolute_error_mask': 1.5265610218048096,
    'val/mean_squared_error': 0.062094103544950485,
    'val/mean_squared_error_mask': 4.934778213500977,
    'val/root_mean_squared_error': 0.14382480084896088,
    'val/cosine_similarity': np.float32(0.30782494),
    'val/cosine_similarity_mask': np.float32(0.41729882),
    'val/sam': np.float32(71.8792),
    'val/iou': np.float64(0.901905085351709),
    'val/pearsonr': np.float32(0.30036226),
    'val/spearmanr': np.float64(0.17734303424187287),
    'val/pearsonr_mask': np.float32(0.19404253),
    'val/spearmanr_mask': np.float64(0.14693384021629097)
}
Epoch 499/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.19it/s v_num: 0.000
{
    'val/mean_absolute_error': 0.027380086481571198,
    'val/mean_absolute_error_mask': 1.5379695892333984,
    'val/mean_squared_error': 0.06242884323000908,
    'val/mean_squared_error_mask': 4.977259635925293,
    'val/root_mean_squared_error': 0.1439480185508728,
    'val/cosine_similarity': np.float32(0.30068845),
    'val/cosine_similarity_mask': np.float32(0.40648776),
    'val/sam': np.float32(72.30999),
    'val/iou': np.float64(0.9166496347000109),
    'val/pearsonr': np.float32(0.29356638),
    'val/spearmanr': np.float64(0.1819286545124142),
    'val/pearsonr_mask': np.float32(0.1913012),
    'val/spearmanr_mask': np.float64(0.1466541880243045)
}
Epoch 549/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.16it/s v_num: 0.000
{
    'val/mean_absolute_error': 0.027036204934120178,
    'val/mean_absolute_error_mask': 1.5434268712997437,
    'val/mean_squared_error': 0.06277874112129211,
    'val/mean_squared_error_mask': 5.00020694732666,
    'val/root_mean_squared_error': 0.1441282331943512,
    'val/cosine_similarity': np.float32(0.2925464),
    'val/cosine_similarity_mask': np.float32(0.39768386),
    'val/sam': np.float32(72.79444),
    'val/iou': np.float64(0.9267387278999227),
    'val/pearsonr': np.float32(0.28576452),
    'val/spearmanr': np.float64(0.18216221180758593),
    'val/pearsonr_mask': np.float32(0.18781704),
    'val/spearmanr_mask': np.float64(0.1417430201185983)
}
Epoch 599/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.18it/s v_num: 0.000
/content/UniST/external/SUICA_pro/utils.py:111: UserWarning: No data for 
colormapping provided via 'c'. Parameters 'cmap' will be ignored
  ax.scatter(x, y, c=z_norm, s=spot_size, cmap=cmap)
{
    'val/mean_absolute_error': 0.027154333889484406,
    'val/mean_absolute_error_mask': 1.5423911809921265,
    'val/mean_squared_error': 0.0629739835858345,
    'val/mean_squared_error_mask': 5.005171298980713,
    'val/root_mean_squared_error': 0.144180566072464,
    'val/cosine_similarity': np.float32(0.29276124),
    'val/cosine_similarity_mask': np.float32(0.39701024),
    'val/sam': np.float32(72.79195),
    'val/iou': np.float64(0.9300722482925716),
    'val/pearsonr': np.float32(0.28582567),
    'val/spearmanr': np.float64(0.184738354490406),
    'val/pearsonr_mask': np.float32(0.18552475),
    'val/spearmanr_mask': np.float64(0.14219993533833666)
}
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: 
UserWarning: This DataLoader will create 18 worker processes in total. Our 
suggested max number of worker in current system is 12, which is smaller than 
what this DataLoader is going to create. Please be aware that excessive worker 
creation might get DataLoader running slow or even freeze, lower the worker 
number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Extra variable passed to LightningModule for validation.
Epoch 649/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.19it/s v_num: 0.000
/content/UniST/external/SUICA_pro/utils.py:111: UserWarning: No data for 
colormapping provided via 'c'. Parameters 'cmap' will be ignored
  ax.scatter(x, y, c=z_norm, s=spot_size, cmap=cmap)
{
    'val/mean_absolute_error': 0.027094826102256775,
    'val/mean_absolute_error_mask': 1.5422779321670532,
    'val/mean_squared_error': 0.06326215714216232,
    'val/mean_squared_error_mask': 5.008144378662109,
    'val/root_mean_squared_error': 0.14428547024726868,
    'val/cosine_similarity': np.float32(0.29089102),
    'val/cosine_similarity_mask': np.float32(0.3932744),
    'val/sam': np.float32(72.90521),
    'val/iou': np.float64(0.9347183953696802),
    'val/pearsonr': np.float32(0.2841145),
    'val/spearmanr': np.float64(0.18684038168519324),
    'val/pearsonr_mask': np.float32(0.18491563),
    'val/spearmanr_mask': np.float64(0.14270913433917304)
}
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: 
UserWarning: This DataLoader will create 18 worker processes in total. Our 
suggested max number of worker in current system is 12, which is smaller than 
what this DataLoader is going to create. Please be aware that excessive worker 
creation might get DataLoader running slow or even freeze, lower the worker 
number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Extra variable passed to LightningModule for validation.
Epoch 699/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.17it/s v_num: 0.000
/content/UniST/external/SUICA_pro/utils.py:111: UserWarning: No data for 
colormapping provided via 'c'. Parameters 'cmap' will be ignored
  ax.scatter(x, y, c=z_norm, s=spot_size, cmap=cmap)
{
    'val/mean_absolute_error': 0.02682727761566639,
    'val/mean_absolute_error_mask': 1.5453301668167114,
    'val/mean_squared_error': 0.06334993243217468,
    'val/mean_squared_error_mask': 5.0213303565979,
    'val/root_mean_squared_error': 0.14430907368659973,
    'val/cosine_similarity': np.float32(0.2884502),
    'val/cosine_similarity_mask': np.float32(0.39061114),
    'val/sam': np.float32(73.05527),
    'val/iou': np.float64(0.9397320961509772),
    'val/pearsonr': np.float32(0.28174844),
    'val/spearmanr': np.float64(0.18878563888875197),
    'val/pearsonr_mask': np.float32(0.18372907),
    'val/spearmanr_mask': np.float64(0.14175068963837717)
}
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: 
UserWarning: This DataLoader will create 18 worker processes in total. Our 
suggested max number of worker in current system is 12, which is smaller than 
what this DataLoader is going to create. Please be aware that excessive worker 
creation might get DataLoader running slow or even freeze, lower the worker 
number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Extra variable passed to LightningModule for validation.
Epoch 749/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.22it/s v_num: 0.000
/content/UniST/external/SUICA_pro/utils.py:111: UserWarning: No data for 
colormapping provided via 'c'. Parameters 'cmap' will be ignored
  ax.scatter(x, y, c=z_norm, s=spot_size, cmap=cmap)
{
    'val/mean_absolute_error': 0.026957865804433823,
    'val/mean_absolute_error_mask': 1.542768120765686,
    'val/mean_squared_error': 0.06357229501008987,
    'val/mean_squared_error_mask': 5.011989593505859,
    'val/root_mean_squared_error': 0.14440415799617767,
    'val/cosine_similarity': np.float32(0.288576),
    'val/cosine_similarity_mask': np.float32(0.38965374),
    'val/sam': np.float32(73.04668),
    'val/iou': np.float64(0.9418707716453445),
    'val/pearsonr': np.float32(0.2820041),
    'val/spearmanr': np.float64(0.18942662889573197),
    'val/pearsonr_mask': np.float32(0.18454649),
    'val/spearmanr_mask': np.float64(0.1420135173670543)
}
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: 
UserWarning: This DataLoader will create 18 worker processes in total. Our 
suggested max number of worker in current system is 12, which is smaller than 
what this DataLoader is going to create. Please be aware that excessive worker 
creation might get DataLoader running slow or even freeze, lower the worker 
number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Extra variable passed to LightningModule for validation.
Epoch 799/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.21it/s v_num: 0.000
/content/UniST/external/SUICA_pro/utils.py:111: UserWarning: No data for 
colormapping provided via 'c'. Parameters 'cmap' will be ignored
  ax.scatter(x, y, c=z_norm, s=spot_size, cmap=cmap)
{
    'val/mean_absolute_error': 0.026840120553970337,
    'val/mean_absolute_error_mask': 1.5448095798492432,
    'val/mean_squared_error': 0.06366516649723053,
    'val/mean_squared_error_mask': 5.022266864776611,
    'val/root_mean_squared_error': 0.14446642994880676,
    'val/cosine_similarity': np.float32(0.28718692),
    'val/cosine_similarity_mask': np.float32(0.38682854),
    'val/sam': np.float32(73.13253),
    'val/iou': np.float64(0.945434512677269),
    'val/pearsonr': np.float32(0.2806806),
    'val/spearmanr': np.float64(0.19066986782705772),
    'val/pearsonr_mask': np.float32(0.18386592),
    'val/spearmanr_mask': np.float64(0.14018593419688083)
}
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: 
UserWarning: This DataLoader will create 18 worker processes in total. Our 
suggested max number of worker in current system is 12, which is smaller than 
what this DataLoader is going to create. Please be aware that excessive worker 
creation might get DataLoader running slow or even freeze, lower the worker 
number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Extra variable passed to LightningModule for validation.
Epoch 849/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.20it/s v_num: 0.000
/content/UniST/external/SUICA_pro/utils.py:111: UserWarning: No data for 
colormapping provided via 'c'. Parameters 'cmap' will be ignored
  ax.scatter(x, y, c=z_norm, s=spot_size, cmap=cmap)
{
    'val/mean_absolute_error': 0.026184504851698875,
    'val/mean_absolute_error_mask': 1.5530517101287842,
    'val/mean_squared_error': 0.0637301653623581,
    'val/mean_squared_error_mask': 5.053819179534912,
    'val/root_mean_squared_error': 0.14446799457073212,
    'val/cosine_similarity': np.float32(0.28274015),
    'val/cosine_similarity_mask': np.float32(0.37981784),
    'val/sam': np.float32(73.40073),
    'val/iou': np.float64(0.9498157420654428),
    'val/pearsonr': np.float32(0.27645215),
    'val/spearmanr': np.float64(0.19197728158128463),
    'val/pearsonr_mask': np.float32(0.1816219),
    'val/spearmanr_mask': np.float64(0.13969355386051974)
}
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: 
UserWarning: This DataLoader will create 18 worker processes in total. Our 
suggested max number of worker in current system is 12, which is smaller than 
what this DataLoader is going to create. Please be aware that excessive worker 
creation might get DataLoader running slow or even freeze, lower the worker 
number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Extra variable passed to LightningModule for validation.
Epoch 899/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.19it/s v_num: 0.000
/content/UniST/external/SUICA_pro/utils.py:111: UserWarning: No data for 
colormapping provided via 'c'. Parameters 'cmap' will be ignored
  ax.scatter(x, y, c=z_norm, s=spot_size, cmap=cmap)
{
    'val/mean_absolute_error': 0.02686235122382641,
    'val/mean_absolute_error_mask': 1.5426524877548218,
    'val/mean_squared_error': 0.06416705995798111,
    'val/mean_squared_error_mask': 5.022475242614746,
    'val/root_mean_squared_error': 0.14464764297008514,
    'val/cosine_similarity': np.float32(0.28372008),
    'val/cosine_similarity_mask': np.float32(0.38507763),
    'val/sam': np.float32(73.34439),
    'val/iou': np.float64(0.9501635304569747),
    'val/pearsonr': np.float32(0.27737164),
    'val/spearmanr': np.float64(0.19326709239351259),
    'val/pearsonr_mask': np.float32(0.1820783),
    'val/spearmanr_mask': np.float64(0.13895742027556515)
}
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: 
UserWarning: This DataLoader will create 18 worker processes in total. Our 
suggested max number of worker in current system is 12, which is smaller than 
what this DataLoader is going to create. Please be aware that excessive worker 
creation might get DataLoader running slow or even freeze, lower the worker 
number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Extra variable passed to LightningModule for validation.
Epoch 949/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.19it/s v_num: 0.000
/content/UniST/external/SUICA_pro/utils.py:111: UserWarning: No data for 
colormapping provided via 'c'. Parameters 'cmap' will be ignored
  ax.scatter(x, y, c=z_norm, s=spot_size, cmap=cmap)
{
    'val/mean_absolute_error': 0.02628360688686371,
    'val/mean_absolute_error_mask': 1.5502204895019531,
    'val/mean_squared_error': 0.0640205591917038,
    'val/mean_squared_error_mask': 5.046116828918457,
    'val/root_mean_squared_error': 0.1445959359407425,
    'val/cosine_similarity': np.float32(0.2825891),
    'val/cosine_similarity_mask': np.float32(0.37997043),
    'val/sam': np.float32(73.41318),
    'val/iou': np.float64(0.9532520158913278),
    'val/pearsonr': np.float32(0.27639192),
    'val/spearmanr': np.float64(0.1942360538665618),
    'val/pearsonr_mask': np.float32(0.18236578),
    'val/spearmanr_mask': np.float64(0.14020226226504764)
}
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: 
UserWarning: This DataLoader will create 18 worker processes in total. Our 
suggested max number of worker in current system is 12, which is smaller than 
what this DataLoader is going to create. Please be aware that excessive worker 
creation might get DataLoader running slow or even freeze, lower the worker 
number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Extra variable passed to LightningModule for validation.
Epoch 999/999 ━━━━━━━━━━━━━━━━━━━━ 14/14 0:00:07 • 0:00:00 2.18it/s v_num: 0.000
/content/UniST/external/SUICA_pro/utils.py:111: UserWarning: No data for 
colormapping provided via 'c'. Parameters 'cmap' will be ignored
  ax.scatter(x, y, c=z_norm, s=spot_size, cmap=cmap)
{
    'val/mean_absolute_error': 0.02662677876651287,
    'val/mean_absolute_error_mask': 1.5443346500396729,
    'val/mean_squared_error': 0.06423915922641754,
    'val/mean_squared_error_mask': 5.029016017913818,
    'val/root_mean_squared_error': 0.14470060169696808,
    'val/cosine_similarity': np.float32(0.28242445),
    'val/cosine_similarity_mask': np.float32(0.3837467),
    'val/sam': np.float32(73.428185),
    'val/iou': np.float64(0.9540528620483256),
    'val/pearsonr': np.float32(0.27614704),
    'val/spearmanr': np.float64(0.19516697390749269),
    'val/pearsonr_mask': np.float32(0.18025337),
    'val/spearmanr_mask': np.float64(0.13865657675801757)
}
`Trainer.fit` stopped: `max_epochs=1000` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
Extra variable passed to LightningModule for prediction.
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/prediction_loop.py:257: predict returned None if it was on purpose, ignore this warning...
Writing to /content/UniST/external/SUICA_pro/logs/GAE-2D/2d/lightning_logs/version_0/embedded-all.h5ad ... It may take some time ...
Predicting ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 17/17 0:00:01 • 0:00:00 10.59it/s 
Current Configs:
{
    'case': '2d',
    'dataset': {
        'type': 'GraphST2D',
        'data_file': '/content/UniST/external/SUICA_pro/data/slice_440.h5ad',
        'val_proportion': 0.2,
        'keep_ratio': True,
        'n_neighbors': 4,
        'require_coordnorm': True
    },
    'pipeline': {
        'embedder': {
            'model': 'GAE',
            'dim_hidden': [
                2048,
                512,
                128
            ],
            'dim_latent': 64,
            'dim_in': 17649
        },
        'optimization': {
            'seed': 8848,
            'epochs': 1000,
            'lr': 1e-05,
            'val_freq': 50,
            'logs': '/content/UniST/external/SUICA_pro/logs/GAE-2D/2d',
            'batch_size': 512
        },
        'predict_mode': 'all',
        'embedded_data': 'embedded-all.h5ad'
    }
}
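Each validation dict compares reconstructed and observed expression. Two relationships help when reading them. SAM (spectral angle mapper) is the angle between the predicted and observed expression vectors, so it tracks cosine similarity: arccos(0.28) ≈ 73.7°, close to the logged `val/sam` of about 73.4 next to `val/cosine_similarity` ≈ 0.28 (the match is not exact, since averaging per-spot angles is not the same as taking the angle of the averaged cosine). The `_mask` error metrics are far larger than the unmasked ones (MAE ≈ 1.54 vs ≈ 0.027), which is what you would expect if they are restricted to entries where the observed expression is non-zero; that reading is an assumption. The numpy sketch below computes these quantities; the function names are hypothetical, and SUICA_pro's own implementation (averaging axis, epsilon, masking rule) may differ.

```python
import numpy as np


def spot_cosine(pred: np.ndarray, true: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Cosine similarity between predicted and observed expression, one value per spot (row)."""
    num = np.sum(pred * true, axis=1)
    den = np.linalg.norm(pred, axis=1) * np.linalg.norm(true, axis=1) + eps
    return num / den


def mean_sam_degrees(pred: np.ndarray, true: np.ndarray) -> float:
    """Spectral angle mapper: per-spot angle between the two vectors, in degrees, averaged."""
    cos = np.clip(spot_cosine(pred, true), -1.0, 1.0)
    return float(np.degrees(np.arccos(cos)).mean())


def pearson_all(pred: np.ndarray, true: np.ndarray) -> float:
    """Pearson correlation over all entries, flattened."""
    return float(np.corrcoef(pred.ravel(), true.ravel())[0, 1])


def masked_mae(pred: np.ndarray, true: np.ndarray) -> float:
    """MAE over entries where the observed value is non-zero (assumed meaning of `_mask`)."""
    mask = true != 0
    return float(np.abs(pred[mask] - true[mask]).mean())


# Toy example: 2 spots x 3 genes, mostly zeros as in sparse spatial transcriptomics data.
true = np.array([[0.0, 2.0, 0.0], [1.0, 0.0, 0.0]])
pred = np.array([[0.1, 1.5, 0.0], [0.8, 0.1, 0.0]])
print(round(masked_mae(pred, true), 3))              # 0.35
print(round(float(np.abs(pred - true).mean()), 3))   # 0.15
```

The toy numbers show the same pattern as the logs: zeros that are predicted well pull the unmasked MAE down, while the masked MAE only sees the harder non-zero entries.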

Visualize latent space

import scanpy as sc

emb = sc.read('/content/UniST/external/SUICA_pro/logs/GAE-2D/2d/lightning_logs/version_0/embedded-all.h5ad')
emb
AnnData object with n_obs × n_vars = 8382 × 17649
    obsm: 'embeddings', 'spatial'
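The embedded AnnData has 17649 variables, matching `dim_in` in the GAE config, and stores the latent codes (`dim_latent: 64`) in `obsm['embeddings']`. Mapping 64 latent dimensions back to 17649 genes dominates model size. A plain fully connected decoder that mirrors the encoder widths (64 → 128 → 512 → 2048 → 17649, with biases) has 37,287,793 parameters, which matches the 37.3 M reported for `Decoder` in the Step 2 model summary. That layer layout is inferred from this arithmetic, not read from the SUICA_pro source; the helper name is hypothetical.

```python
def mlp_param_count(layer_sizes: list[int]) -> int:
    """Number of weights plus biases in a stack of fully connected layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


# Assumed decoder layout: the GAE config's widths in reverse,
# from dim_latent (64) back up to dim_in (17649 genes).
decoder_sizes = [64, 128, 512, 2048, 17649]

per_layer = [n_in * n_out + n_out for n_in, n_out in zip(decoder_sizes, decoder_sizes[1:])]
print(per_layer)                        # [8320, 66048, 1050624, 36162801]
print(mlp_param_count(decoder_sizes))   # 37287793
```

About 97% of those parameters sit in the final 2048 → 17649 layer, so the gene count, not the latent size, drives the decoder's memory footprint.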
import matplotlib.pyplot as plt

embeddings = emb.obsm["embeddings"]
spatial = emb.obsm["spatial"]

fig, axes = plt.subplots(1, 3, figsize=(12, 3))

# Color each spot by one of the first three latent dimensions.
# The scatter handle is named `points`, not `sc`, so the scanpy alias is not overwritten.
for ax, dim in zip(axes, range(3)):
    points = ax.scatter(
        spatial[:, 0],
        spatial[:, 1],
        c=embeddings[:, dim],
        cmap="viridis",
        s=0.8,
    )
    ax.set_title(f"Embedding dim {dim + 1}")
    plt.colorbar(points, ax=ax, fraction=0.046)

plt.tight_layout()
plt.show()
[Figure: spots in spatial coordinates colored by GAE embedding dimensions 1, 2 and 3]
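Plotting the first three latent dimensions is an arbitrary choice: unlike PCA components, autoencoder latent dimensions are not ordered by importance. A simple alternative is to plot the dimensions with the largest variance across spots, sketched below with a hypothetical helper `top_variance_dims`. To use it, swap `range(3)` in the plotting loop above for `top_variance_dims(embeddings)`; the panel titles then report the actual (1-based) dimension index. Variance is only a heuristic, and a high-variance dimension is not guaranteed to be spatially coherent.

```python
import numpy as np


def top_variance_dims(embeddings: np.ndarray, k: int = 3) -> np.ndarray:
    """Indices of the k embedding dimensions with the largest variance across spots,
    ordered from largest to smallest variance."""
    variances = embeddings.var(axis=0)
    # argsort is ascending: take the last k indices, then reverse them.
    return np.argsort(variances)[-k:][::-1]


# Toy check: 4 spots x 5 latent dimensions with known per-dimension spread.
toy = np.zeros((4, 5))
toy[:, 1] = [0, 1, 0, 1]
toy[:, 3] = [0, 10, 0, 10]
toy[:, 4] = [0, 5, 0, 5]
print(top_variance_dims(toy))  # [3 4 1]
```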

Step 2: Train INR + fine-tune GAE decoder
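Step 2 fits an implicit neural representation (INR) to the GAE embeddings (`target: 'embeddings'` in the config printed by the command below), using a `FourierFeatureNet` (`model: 'FFN'`), while the pretrained decoder is loaded from `last.ckpt` and fine-tuned. Fourier-feature networks typically encode the low-dimensional input coordinates with sinusoids of random frequencies before the MLP, which helps the network represent fine spatial variation. The sketch below shows the standard Gaussian random Fourier feature mapping; how SUICA_pro parameterizes it (for example, what `phase: 1000` controls) is not shown in this tutorial, so `n_freqs` and `scale` are assumed values and `fourier_features` is a hypothetical name.

```python
import numpy as np


def fourier_features(coords: np.ndarray, n_freqs: int = 256, scale: float = 10.0,
                     seed: int = 0) -> np.ndarray:
    """Gaussian random Fourier feature encoding of low-dimensional coordinates.

    coords: (n_spots, 2) spatial coordinates, ideally normalized to a small range.
    Returns an (n_spots, 2 * n_freqs) array [cos(2*pi*x B), sin(2*pi*x B)].
    """
    rng = np.random.default_rng(seed)
    # Random frequency matrix; a larger `scale` admits higher frequencies (finer detail).
    B = rng.normal(0.0, scale, size=(coords.shape[1], n_freqs))
    proj = 2.0 * np.pi * coords @ B
    return np.concatenate([np.cos(proj), np.sin(proj)], axis=1)


xy = np.random.default_rng(1).uniform(-1, 1, size=(5, 2))
feats = fourier_features(xy)
print(feats.shape)  # (5, 512)
```

The encoded features, not the raw coordinates, would then be fed to the MLP; fixing `seed` keeps the frequency matrix identical between training and inference.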

!python train.py --mode inr --conf ./configs/ST/inr_embd.yaml
./configs/ST/inr_embd.yaml
Current Configs:
{
    'case': '2d',
    'dataset': {
        'type': 'ST2D',
        'data_file': 'logs/GAE-2D/2d/lightning_logs/version_0/embedded-all.h5ad',
        'val_proportion': 0.2,
        'keep_ratio': True,
        'require_coordnorm': False
    },
    'pipeline': {
        'target': 'embeddings',
        'inr': {
            'model': 'FFN',
            'num_hidden_layers': 3,
            'num_hidden_features': 2048,
            'phase': 1000,
            'decoder': {
                'ckpt': 'logs/GAE-2D/2d/lightning_logs/version_0/checkpoints/last.ckpt',
                'recon_loss': True,
                'finetune': True
            }
        },
        'optimization': {
            'seed': 8848,
            'epochs': 2000,
            'lr': 0.0001,
            'val_freq': 200,
            'logs': 'logs/GAE+FFN-2D/2d',
            'batch_size': 5000
        },
        'predict_mode': 'val',
        'reconstructed_data': 'reconstructed-val.h5ad'
    }
}
[CONF] conf     = /content/UniST/external/SUICA_pro/configs/ST/inr_embd.yaml
[CONF] datafile = /content/UniST/external/SUICA_pro/logs/GAE-2D/2d/lightning_logs/version_0/embedded-all.h5ad
[CONF] logs_dir = /content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d
Seed set to 8848
/usr/local/lib/python3.12/dist-packages/torch/__init__.py:1617: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)
  _C._set_float32_matmul_precision(precision)
Fitting ST embeddings with INR ...
6705 1677
val_idx[:10]=[3637, 1964, 1408, 7304, 7282, 1721, 4720, 7473, 964, 4825]
with pretrained decoder
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2026-02-11 03:58:16.666655: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-11 03:58:16.683301: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1770782296.702933   11867 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1770782296.709197   11867 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1770782296.725718   11867 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770782296.725745   11867 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770782296.725748   11867 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770782296.725750   11867 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
2026-02-11 03:58:16.730705: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
┏━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃   ┃ Name          ┃ Type              ┃ Params ┃ Mode  ┃ FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ fitting_model │ FourierFeatureNet │ 13.8 M │ train │     0 │
│ 1 │ decoder       │ Decoder           │ 37.3 M │ train │     0 │
└───┴───────────────┴───────────────────┴────────┴───────┴───────┘
Trainable params: 51.1 M                                                        
Non-trainable params: 0                                                         
Total params: 51.1 M                                                            
Total estimated model params size (MB): 204                                     
Modules in train mode: 24                                                       
Modules in eval mode: 0                                                         
Total FLOPs: 0                                                                  
{
    'val_fitting/mean_absolute_error': 9.803089141845703,
    'val_fitting/mean_absolute_error_mask': 10.628742218017578,
    'val_fitting/mean_squared_error': 171.34371948242188,
    'val_fitting/mean_squared_error_mask': 202.01722717285156,
    'val_fitting/root_mean_squared_error': 11.19957160949707,
    'val_fitting/cosine_similarity': np.float32(0.12160339),
    'val_fitting/cosine_similarity_mask': np.float32(0.1653341),
    'val_fitting/sam': np.float32(82.988594),
    'val_fitting/iou': 0,
    'val_fitting/pearsonr': np.float32(0.12167605),
    'val_fitting/spearmanr': np.float64(0.12038659760037218),
    'val_fitting/pearsonr_mask': np.float32(0.08505417),
    'val_fitting/spearmanr_mask': np.float64(0.08909855553886589)
}
Decoder saved to /content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d/lightning_logs/version_0/decoder.pth
{
    'val_recon/mean_absolute_error': 0.02165960893034935,
    'val_recon/mean_absolute_error_mask': 1.7172333002090454,
    'val_recon/mean_squared_error': 0.06840622425079346,
    'val_recon/mean_squared_error_mask': 5.838356018066406,
    'val_recon/root_mean_squared_error': 0.14547452330589294,
    'val_recon/cosine_similarity': np.float32(0.13697076),
    'val_recon/cosine_similarity_mask': np.float32(0.36834526),
    'val_recon/sam': np.float32(82.12168),
    'val_recon/iou': np.float64(0.8911825762653438),
    'val_recon/pearsonr': np.float32(0.1223479),
    'val_recon/spearmanr': np.float64(0.1042421679282781),
    'val_recon/pearsonr_mask': np.float32(0.06927782),
    'val_recon/spearmanr_mask': np.float64(0.06782675598513148)
}
{
    'val_fitting/mean_absolute_error': 3.2481529712677,
    'val_fitting/mean_absolute_error_mask': 3.3302149772644043,
    'val_fitting/mean_squared_error': 24.950271606445312,
    'val_fitting/mean_squared_error_mask': 26.7098388671875,
    'val_fitting/root_mean_squared_error': 4.792659759521484,
    'val_fitting/cosine_similarity': np.float32(0.96840703),
    'val_fitting/cosine_similarity_mask': np.float32(0.93181896),
    'val_fitting/sam': np.float32(12.749548),
    'val_fitting/iou': 0,
    'val_fitting/pearsonr': np.float32(0.9684848),
    'val_fitting/spearmanr': np.float64(0.9594725067220593),
    'val_fitting/pearsonr_mask': np.float32(0.9488256),
    'val_fitting/spearmanr_mask': np.float64(0.9343449446663876)
}
Decoder saved to /content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d/lightning_logs/version_0/decoder.pth
{
    'val_recon/mean_absolute_error': 0.02803116850554943,
    'val_recon/mean_absolute_error_mask': 1.5134501457214355,
    'val_recon/mean_squared_error': 0.06449563801288605,
    'val_recon/mean_squared_error_mask': 5.05856466293335,
    'val_recon/root_mean_squared_error': 0.1442723423242569,
    'val_recon/cosine_similarity': np.float32(0.29309714),
    'val_recon/cosine_similarity_mask': np.float32(0.38395637),
    'val_recon/sam': np.float32(72.903206),
    'val_recon/iou': np.float64(0.9386127525542001),
    'val_recon/pearsonr': np.float32(0.286002),
    'val_recon/spearmanr': np.float64(0.19761426273974408),
    'val_recon/pearsonr_mask': np.float32(0.16350925),
    'val_recon/spearmanr_mask': np.float64(0.14219608537234726)
}
{
    'val_fitting/mean_absolute_error': 3.025294303894043,
    'val_fitting/mean_absolute_error_mask': 3.116687059402466,
    'val_fitting/mean_squared_error': 20.233470916748047,
    'val_fitting/mean_squared_error_mask': 21.940380096435547,
    'val_fitting/root_mean_squared_error': 4.298580169677734,
    'val_fitting/cosine_similarity': np.float32(0.97599155),
    'val_fitting/cosine_similarity_mask': np.float32(0.9447646),
    'val_fitting/sam': np.float32(11.491507),
    'val_fitting/iou': 0,
    'val_fitting/pearsonr': np.float32(0.976072),
    'val_fitting/spearmanr': np.float64(0.967004216713519),
    'val_fitting/pearsonr_mask': np.float32(0.9587564),
    'val_fitting/spearmanr_mask': np.float64(0.9421800007643844)
}
Decoder saved to /content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d/lightning_logs/version_0/decoder.pth
{
    'val_recon/mean_absolute_error': 0.026943445205688477,
    'val_recon/mean_absolute_error_mask': 1.5258015394210815,
    'val_recon/mean_squared_error': 0.06269751489162445,
    'val_recon/mean_squared_error_mask': 4.935014247894287,
    'val_recon/root_mean_squared_error': 0.1441444456577301,
    'val_recon/cosine_similarity': np.float32(0.29973173),
    'val_recon/cosine_similarity_mask': np.float32(0.40790847),
    'val_recon/sam': np.float32(72.37264),
    'val_recon/iou': np.float64(0.9466365720964929),
    'val_recon/pearsonr': np.float32(0.29300696),
    'val_recon/spearmanr': np.float64(0.19903610981677888),
    'val_recon/pearsonr_mask': np.float32(0.18533199),
    'val_recon/spearmanr_mask': np.float64(0.14180284395187312)
}
{
    'val_fitting/mean_absolute_error': 3.096369981765747,
    'val_fitting/mean_absolute_error_mask': 3.183185577392578,
    'val_fitting/mean_squared_error': 21.11383819580078,
    'val_fitting/mean_squared_error_mask': 22.750463485717773,
    'val_fitting/root_mean_squared_error': 4.362051963806152,
    'val_fitting/cosine_similarity': np.float32(0.9738573),
    'val_fitting/cosine_similarity_mask': np.float32(0.94278264),
    'val_fitting/sam': np.float32(11.821716),
    'val_fitting/iou': 0,
    'val_fitting/pearsonr': np.float32(0.97394055),
    'val_fitting/spearmanr': np.float64(0.9646839594514013),
    'val_fitting/pearsonr_mask': np.float32(0.9551693),
    'val_fitting/spearmanr_mask': np.float64(0.9381571761969161)
}
Decoder saved to /content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d/lightning_logs/version_0/decoder.pth
{
    'val_recon/mean_absolute_error': 0.026225129142403603,
    'val_recon/mean_absolute_error_mask': 1.5400773286819458,
    'val_recon/mean_squared_error': 0.06302483379840851,
    'val_recon/mean_squared_error_mask': 4.989807605743408,
    'val_recon/root_mean_squared_error': 0.1443091481924057,
    'val_recon/cosine_similarity': np.float32(0.28978956),
    'val_recon/cosine_similarity_mask': np.float32(0.39597008),
    'val_recon/sam': np.float32(72.96944),
    'val_recon/iou': np.float64(0.9510904044113003),
    'val_recon/pearsonr': np.float32(0.2832987),
    'val_recon/spearmanr': np.float64(0.19789411294481812),
    'val_recon/pearsonr_mask': np.float32(0.18069202),
    'val_recon/spearmanr_mask': np.float64(0.14041022670312742)
}
{
    'val_fitting/mean_absolute_error': 3.186626434326172,
    'val_fitting/mean_absolute_error_mask': 3.26865816116333,
    'val_fitting/mean_squared_error': 21.915382385253906,
    'val_fitting/mean_squared_error_mask': 23.43694305419922,
    'val_fitting/root_mean_squared_error': 4.420367240905762,
    'val_fitting/cosine_similarity': np.float32(0.9725111),
    'val_fitting/cosine_similarity_mask': np.float32(0.9421381),
    'val_fitting/sam': np.float32(12.039261),
    'val_fitting/iou': 0,
    'val_fitting/pearsonr': np.float32(0.9725973),
    'val_fitting/spearmanr': np.float64(0.9633232475137661),
    'val_fitting/pearsonr_mask': np.float32(0.95341235),
    'val_fitting/spearmanr_mask': np.float64(0.9361080736791874)
}
Decoder saved to /content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d/lightning_logs/version_0/decoder.pth
{
    'val_recon/mean_absolute_error': 0.026105046272277832,
    'val_recon/mean_absolute_error_mask': 1.5455834865570068,
    'val_recon/mean_squared_error': 0.06362394243478775,
    'val_recon/mean_squared_error_mask': 5.014831066131592,
    'val_recon/root_mean_squared_error': 0.14459550380706787,
    'val_recon/cosine_similarity': np.float32(0.28314647),
    'val_recon/cosine_similarity_mask': np.float32(0.38843572),
    'val_recon/sam': np.float32(73.36646),
    'val_recon/iou': np.float64(0.9532606117376937),
    'val_recon/pearsonr': np.float32(0.2768542),
    'val_recon/spearmanr': np.float64(0.19572024679404396),
    'val_recon/pearsonr_mask': np.float32(0.17982924),
    'val_recon/spearmanr_mask': np.float64(0.13791711558277078)
}
{
    'val_fitting/mean_absolute_error': 3.2325429916381836,
    'val_fitting/mean_absolute_error_mask': 3.2905004024505615,
    'val_fitting/mean_squared_error': 22.396663665771484,
    'val_fitting/mean_squared_error_mask': 23.66312599182129,
    'val_fitting/root_mean_squared_error': 4.460636138916016,
    'val_fitting/cosine_similarity': np.float32(0.97105485),
    'val_fitting/cosine_similarity_mask': np.float32(0.9415784),
    'val_fitting/sam': np.float32(12.25332),
    'val_fitting/iou': 0,
    'val_fitting/pearsonr': np.float32(0.9711237),
    'val_fitting/spearmanr': np.float64(0.9614763466507652),
    'val_fitting/pearsonr_mask': np.float32(0.9522627),
    'val_fitting/spearmanr_mask': np.float64(0.9345506486621038)
}
Decoder saved to /content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d/lightning_logs/version_0/decoder.pth
{
    'val_recon/mean_absolute_error': 0.02602044865489006,
    'val_recon/mean_absolute_error_mask': 1.5499576330184937,
    'val_recon/mean_squared_error': 0.06409505009651184,
    'val_recon/mean_squared_error_mask': 5.0271382331848145,
    'val_recon/root_mean_squared_error': 0.14491058886051178,
    'val_recon/cosine_similarity': np.float32(0.27831817),
    'val_recon/cosine_similarity_mask': np.float32(0.3853728),
    'val_recon/sam': np.float32(73.649185),
    'val_recon/iou': np.float64(0.9548744368550195),
    'val_recon/pearsonr': np.float32(0.2721035),
    'val_recon/spearmanr': np.float64(0.19372869695017914),
    'val_recon/pearsonr_mask': np.float32(0.17886032),
    'val_recon/spearmanr_mask': np.float64(0.1357424999731391)
}
{
    'val_fitting/mean_absolute_error': 3.2233128547668457,
    'val_fitting/mean_absolute_error_mask': 3.2782351970672607,
    'val_fitting/mean_squared_error': 22.361495971679688,
    'val_fitting/mean_squared_error_mask': 23.596450805664062,
    'val_fitting/root_mean_squared_error': 4.457827568054199,
    'val_fitting/cosine_similarity': np.float32(0.969664),
    'val_fitting/cosine_similarity_mask': np.float32(0.94065034),
    'val_fitting/sam': np.float32(12.485819),
    'val_fitting/iou': 0,
    'val_fitting/pearsonr': np.float32(0.9697332),
    'val_fitting/spearmanr': np.float64(0.959894478409684),
    'val_fitting/pearsonr_mask': np.float32(0.9501587),
    'val_fitting/spearmanr_mask': np.float64(0.9320453913105698)
}
Decoder saved to
/content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d/lightning_logs/version_0/decoder.pth
{
    'val_recon/mean_absolute_error': 0.1991494596004486,
    'val_recon/mean_absolute_error_mask': 1.2084848880767822,
    'val_recon/mean_squared_error': 0.36856260895729065,
    'val_recon/mean_squared_error_mask': 3.7554783821105957,
    'val_recon/root_mean_squared_error': 0.30926695466041565,
    'val_recon/cosine_similarity': np.float32(0.17161262),
    'val_recon/cosine_similarity_mask': np.float32(0.6019463),
    'val_recon/sam': np.float32(80.09502),
    'val_recon/iou': np.float64(0.8566545458283872),
    'val_recon/pearsonr': np.float32(0.15381473),
    'val_recon/spearmanr': np.float64(0.1579520705668068),
    'val_recon/pearsonr_mask': np.float32(0.16576657),
    'val_recon/spearmanr_mask': np.float64(0.16806622520657605)
}
{
    'val_fitting/mean_absolute_error': 3.2233128547668457,
    'val_fitting/mean_absolute_error_mask': 3.2782351970672607,
    'val_fitting/mean_squared_error': 22.361495971679688,
    'val_fitting/mean_squared_error_mask': 23.596450805664062,
    'val_fitting/root_mean_squared_error': 4.457827568054199,
    'val_fitting/cosine_similarity': np.float32(0.969664),
    'val_fitting/cosine_similarity_mask': np.float32(0.94065034),
    'val_fitting/sam': np.float32(12.485819),
    'val_fitting/iou': 0,
    'val_fitting/pearsonr': np.float32(0.9697332),
    'val_fitting/spearmanr': np.float64(0.959894478409684),
    'val_fitting/pearsonr_mask': np.float32(0.9501587),
    'val_fitting/spearmanr_mask': np.float64(0.9320453913105698)
}
Decoder saved to
/content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d/lightning_logs/version_0/decoder.pth
{
    'val_recon/mean_absolute_error': 0.1963220238685608,
    'val_recon/mean_absolute_error_mask': 1.2238929271697998,
    'val_recon/mean_squared_error': 0.3862129747867584,
    'val_recon/mean_squared_error_mask': 3.795976161956787,
    'val_recon/root_mean_squared_error': 0.3143541216850281,
    'val_recon/cosine_similarity': np.float32(0.16858377),
    'val_recon/cosine_similarity_mask': np.float32(0.59696686),
    'val_recon/sam': np.float32(80.27106),
    'val_recon/iou': np.float64(0.862667912937159),
    'val_recon/pearsonr': np.float32(0.15103157),
    'val_recon/spearmanr': np.float64(0.15977832954235457),
    'val_recon/pearsonr_mask': np.float32(0.16347322),
    'val_recon/spearmanr_mask': np.float64(0.1636289730125434)
}
{
    'val_fitting/mean_absolute_error': 3.2233128547668457,
    'val_fitting/mean_absolute_error_mask': 3.2782351970672607,
    'val_fitting/mean_squared_error': 22.361495971679688,
    'val_fitting/mean_squared_error_mask': 23.596450805664062,
    'val_fitting/root_mean_squared_error': 4.457827568054199,
    'val_fitting/cosine_similarity': np.float32(0.969664),
    'val_fitting/cosine_similarity_mask': np.float32(0.94065034),
    'val_fitting/sam': np.float32(12.485819),
    'val_fitting/iou': 0,
    'val_fitting/pearsonr': np.float32(0.9697332),
    'val_fitting/spearmanr': np.float64(0.959894478409684),
    'val_fitting/pearsonr_mask': np.float32(0.9501587),
    'val_fitting/spearmanr_mask': np.float64(0.9320453913105698)
}
Decoder saved to
/content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d/lightning_logs/version_0/decoder.pth
{
    'val_recon/mean_absolute_error': 0.19346606731414795,
    'val_recon/mean_absolute_error_mask': 1.241257667541504,
    'val_recon/mean_squared_error': 0.40165582299232483,
    'val_recon/mean_squared_error_mask': 3.850717306137085,
    'val_recon/root_mean_squared_error': 0.3188914656639099,
    'val_recon/cosine_similarity': np.float32(0.16558865),
    'val_recon/cosine_similarity_mask': np.float32(0.5903364),
    'val_recon/sam': np.float32(80.44671),
    'val_recon/iou': np.float64(0.8680786158642505),
    'val_recon/pearsonr': np.float32(0.14828137),
    'val_recon/spearmanr': np.float64(0.16151258023185983),
    'val_recon/pearsonr_mask': np.float32(0.16016483),
    'val_recon/spearmanr_mask': np.float64(0.15970734219641153)
}
{
    'val_fitting/mean_absolute_error': 3.2233128547668457,
    'val_fitting/mean_absolute_error_mask': 3.2782351970672607,
    'val_fitting/mean_squared_error': 22.361495971679688,
    'val_fitting/mean_squared_error_mask': 23.596450805664062,
    'val_fitting/root_mean_squared_error': 4.457827568054199,
    'val_fitting/cosine_similarity': np.float32(0.969664),
    'val_fitting/cosine_similarity_mask': np.float32(0.94065034),
    'val_fitting/sam': np.float32(12.485819),
    'val_fitting/iou': 0,
    'val_fitting/pearsonr': np.float32(0.9697332),
    'val_fitting/spearmanr': np.float64(0.959894478409684),
    'val_fitting/pearsonr_mask': np.float32(0.9501587),
    'val_fitting/spearmanr_mask': np.float64(0.9320453913105698)
}
Decoder saved to
/content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d/lightning_logs/version_0/decoder.pth
{
    'val_recon/mean_absolute_error': 0.19486239552497864,
    'val_recon/mean_absolute_error_mask': 1.2670198678970337,
    'val_recon/mean_squared_error': 0.42833319306373596,
    'val_recon/mean_squared_error_mask': 3.91340970993042,
    'val_recon/root_mean_squared_error': 0.3265944719314575,
    'val_recon/cosine_similarity': np.float32(0.16456817),
    'val_recon/cosine_similarity_mask': np.float32(0.5851358),
    'val_recon/sam': np.float32(80.507416),
    'val_recon/iou': np.float64(0.8726037987049506),
    'val_recon/pearsonr': np.float32(0.14753704),
    'val_recon/spearmanr': np.float64(0.1632070860577605),
    'val_recon/pearsonr_mask': np.float32(0.15671371),
    'val_recon/spearmanr_mask': np.float64(0.1568066690774747)
}
{
    'val_fitting/mean_absolute_error': 3.2233128547668457,
    'val_fitting/mean_absolute_error_mask': 3.2782351970672607,
    'val_fitting/mean_squared_error': 22.361495971679688,
    'val_fitting/mean_squared_error_mask': 23.596450805664062,
    'val_fitting/root_mean_squared_error': 4.457827568054199,
    'val_fitting/cosine_similarity': np.float32(0.969664),
    'val_fitting/cosine_similarity_mask': np.float32(0.94065034),
    'val_fitting/sam': np.float32(12.485819),
    'val_fitting/iou': 0,
    'val_fitting/pearsonr': np.float32(0.9697332),
    'val_fitting/spearmanr': np.float64(0.959894478409684),
    'val_fitting/pearsonr_mask': np.float32(0.9501587),
    'val_fitting/spearmanr_mask': np.float64(0.9320453913105698)
}
Decoder saved to
/content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d/lightning_logs/version_0/decoder.pth
{
    'val_recon/mean_absolute_error': 0.18629270792007446,
    'val_recon/mean_absolute_error_mask': 1.2741972208023071,
    'val_recon/mean_squared_error': 0.4215638041496277,
    'val_recon/mean_squared_error_mask': 3.9775025844573975,
    'val_recon/root_mean_squared_error': 0.3236192464828491,
    'val_recon/cosine_similarity': np.float32(0.16184235),
    'val_recon/cosine_similarity_mask': np.float32(0.57593405),
    'val_recon/sam': np.float32(80.66488),
    'val_recon/iou': np.float64(0.8776891750781495),
    'val_recon/pearsonr': np.float32(0.14508367),
    'val_recon/spearmanr': np.float64(0.1643765121734796),
    'val_recon/pearsonr_mask': np.float32(0.15638234),
    'val_recon/spearmanr_mask': np.float64(0.15403514199020105)
}
Epoch 1999/1999 ━━━━━━━━━━━━━━━━━━━━ 2/2 0:00:01 • 0:00:00 9.22it/s v_num: 0.000
`Trainer.fit` stopped: `max_epochs=2000` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/prediction_loop.py:257: predict returned None if it was on purpose, ignore this warning...
Writing to
/content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d/lightning_logs/version_0/reconstructed-val.h5ad ... It may take some time ...
Predicting ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1/1 0:00:00 • 0:00:00 0.00it/s
Current Configs:
{
'case': '2d',
'dataset': {
│   │   'type': 'ST2D',
│   │   'data_file': '/content/UniST/external/SUICA_pro/logs/GAE-2D/2d/lightning_logs/version_0/embedded-all.h5ad',
│   │   'val_proportion': 0.2,
│   │   'keep_ratio': True,
│   │   'require_coordnorm': False
},
'pipeline': {
│   │   'target': 'embeddings',
│   │   'inr': {
│   │   │   'model': 'FFN',
│   │   │   'num_hidden_layers': 3,
│   │   │   'num_hidden_features': 2048,
│   │   │   'phase': 1000,
│   │   │   'decoder': {
│   │   │   │   'ckpt': 'logs/GAE-2D/2d/lightning_logs/version_0/checkpoints/last.ckpt',
│   │   │   │   'recon_loss': True,
│   │   │   │   'finetune': True
│   │   │   },
│   │   │   'dim_in': 2,
│   │   │   'dim_out': 64
│   │   },
│   │   'optimization': {
│   │   │   'seed': 8848,
│   │   │   'epochs': 2000,
│   │   │   'lr': 0.0001,
│   │   │   'val_freq': 200,
│   │   │   'logs': '/content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d',
│   │   │   'batch_size': 5000
│   │   },
│   │   'predict_mode': 'val',
│   │   'reconstructed_data': 'reconstructed-val.h5ad'
}
}

Step 3: Prediction/Imputation

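First, normalize the custom coordinates against the reference slice (data/slice_440.h5ad) so that they lie in the same normalized coordinate space as the training data.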

### The upsampled point cloud

import numpy as np
import matplotlib.pyplot as plt

file_path = "data/slice440_r2.xyz"

points = []
with open(file_path, 'r') as f:
    for line in f:
        parts = line.strip().split()
        if len(parts) >= 2:
            x, y = float(parts[0]), float(parts[1])
            points.append((x, y))

points = np.array(points)
n_points = points.shape[0]

plt.figure(figsize=(6, 6))
plt.scatter(points[:, 0], points[:, 1], s=1, alpha=0.5)
plt.xlabel('X')
plt.ylabel('Y')
plt.title(f"Upsampling, r=2 (slice 440, {n_points} cells)")
plt.axis('equal')
#plt.savefig('slice440_r2.png', format='png', dpi=300)
plt.show()
[Figure: upsampled point cloud, r=2, slice 440]
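If every line of the .xyz file holds whitespace-separated numbers (the preprocessing output below reports three columns per point), the reading loop above can be replaced by a single vectorized call. This is a minimal sketch; `load_xy` is a hypothetical helper, not part of UniST. Unlike the loop, `np.loadtxt` raises an error on a line with fewer than two columns instead of silently skipping it.

```python
import numpy as np

def load_xy(path):
    """Read the first two columns (x, y) of a whitespace-separated .xyz file."""
    # usecols keeps only x and y; ndmin=2 keeps a 2D shape even for a single point
    return np.loadtxt(path, usecols=(0, 1), ndmin=2)

# Usage with the tutorial's file:
# points = load_xy("data/slice440_r2.xyz")
```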
!python prepare_custom_coords.py --mode 2d --reference data/slice_440.h5ad --coords data/slice440_r2.xyz --output data/preprocessed_data/custom_coords_2d_norm.npy
Reading reference data: data/slice_440.h5ad
Reading coordinate file: data/slice440_r2.xyz
Input coordinate shape: (16764, 3)

Normalized coordinate ranges:
  X: [-0.9912, 0.9935]
  Y: [-0.7844, 0.7798]
Shape: (16764, 2)

✅ Saved 16764 2D coordinates to data/preprocessed_data/custom_coords_2d_norm.npy
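The internals of prepare_custom_coords.py are not shown here. The ranges it prints, together with those reported later by map_coords_back.py (`keep_ratio=True`, `scale_y=0.7879`), are consistent with centering both axes on the midpoint of the reference bounding box and dividing them by half of the box's longer side. A minimal sketch of that assumed transform; the function name and the exact formula are assumptions, not the script's code:

```python
import numpy as np

def normalize_keep_ratio(coords, ref_coords):
    """Map 2D coordinates into roughly [-1, 1] using the reference bounding box.

    Assumed transform: subtract the center of the reference box, then divide
    both axes by half of the box's longer side, which preserves the aspect ratio.
    """
    coords = np.asarray(coords, dtype=float)
    ref_coords = np.asarray(ref_coords, dtype=float)
    ref_min = ref_coords.min(axis=0)
    ref_max = ref_coords.max(axis=0)
    center = (ref_min + ref_max) / 2.0
    half_span = (ref_max - ref_min).max() / 2.0  # single scale shared by X and Y
    return (coords - center) / half_span

# Corners of the reference box, as reported by map_coords_back.py in this tutorial
ref_box = np.array([[4691.54, 7357.52], [8823.75, 10613.24]])
corners_norm = normalize_keep_ratio(ref_box, ref_box)
# X spans [-1, 1]; Y spans roughly +/- 0.788, matching the printed scale_y
```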
file_path = "data/preprocessed_data/custom_coords_2d_norm.npy"

# Load points from the .npy file
points = np.load(file_path)
n_points = points.shape[0]

plt.figure(figsize=(6, 6))
plt.scatter(points[:, 0], points[:, 1], s=1, alpha=0.5)
plt.xlabel('X')
plt.ylabel('Y')
plt.title(f"Normalized Custom 2D Coordinates ({n_points} cells)")
plt.axis('equal')
plt.show()
[Figure: normalized custom 2D coordinates]
# Run prediction
!python predict.py --mode inr --conf ./configs/ST/inr_pred.yaml
./configs/ST/inr_pred.yaml
Current Configs:
{
'case': '2d',
'dataset': {
│   │   'type': 'ST2D',
│   │   'data_file': 'logs/GAE-2D/2d/lightning_logs/version_0/embedded-all.h5ad',
│   │   'val_proportion': 0.2,
│   │   'keep_ratio': True,
│   │   'require_coordnorm': False
},
'pipeline': {
│   │   'target': 'embeddings',
│   │   'prediction': {
│   │   │   'ckpt': 'logs/GAE+FFN-2D/2d/lightning_logs/version_0/checkpoints/last.ckpt'
│   │   },
│   │   'inr': {
│   │   │   'model': 'FFN',
│   │   │   'num_hidden_layers': 3,
│   │   │   'num_hidden_features': 2048,
│   │   │   'phase': 1000,
│   │   │   'decoder': {
│   │   │   │   'ckpt': 'logs/GAE-2D/2d/lightning_logs/version_0/checkpoints/last.ckpt',
│   │   │   │   'recon_loss': True,
│   │   │   │   'finetune': True
│   │   │   }
│   │   },
│   │   'optimization': {
│   │   │   'seed': 8848,
│   │   │   'epochs': 2000,
│   │   │   'lr': 0.0001,
│   │   │   'val_freq': 200,
│   │   │   'logs': 'logs/GAE+FFN-2D/2d',
│   │   │   'batch_size': 8192
│   │   },
│   │   'predict_mode': 'custom',
│   │   'custom_coords_file': 'data/preprocessed_data/custom_coords_2d_norm.npy',
│   │   'reconstructed_data': 'reconstructed-custom_2d.h5ad'
}
}
Seed set to 8848
/usr/local/lib/python3.12/dist-packages/torch/__init__.py:1617: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)
  _C._set_float32_matmul_precision(precision)
 loaded from:
/content/UniST/external/SUICA_pro/logs/GAE-2D/2d/lightning_logs/version_0/checkpoints/last.ckpt
Detected custom coord dim: 2D
Predicting ST embeddings with INR ...
dim_in=2, dim_out=64
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
Start custom coordinate prediction...
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/prediction_loop.py:257: predict returned None if it was on purpose, ignore this warning...
Writing to
/content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d/lightning_logs/version_1/reconstructed-custom_2d.h5ad ... It may take some time ...
Predicting ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3/3 0:00:01 • 0:00:00 4.40it/s 
Custom prediction saved under:
/content/UniST/external/SUICA_pro/logs/GAE+FFN-2D/2d/lightning_logs/version_1
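With `batch_size: 8192` in the prediction config, the 16,764 custom coordinates would be processed in ceil(16764 / 8192) = 3 batches, which appears consistent with the `3/3` progress bar above. A minimal sketch of that chunking, where `predict_in_batches` and `fake_inr` are hypothetical stand-ins (the real trained INR maps 2D coordinates to 64-dimensional embeddings, per `dim_in=2, dim_out=64`):

```python
import math
import numpy as np

def predict_in_batches(coords, predict_fn, batch_size=8192):
    """Apply predict_fn to consecutive slices of coords and stack the results."""
    outputs = [
        predict_fn(coords[start:start + batch_size])
        for start in range(0, len(coords), batch_size)
    ]
    return np.concatenate(outputs, axis=0)

n_coords, batch_size = 16764, 8192
print(math.ceil(n_coords / batch_size))  # → 3 prediction steps

def fake_inr(xy):
    """Hypothetical placeholder for the INR: (n, 2) coordinates -> (n, 64) outputs."""
    return np.zeros((len(xy), 64), dtype=np.float32)

demo_out = predict_in_batches(np.zeros((n_coords, 2)), fake_inr, batch_size)
```

The last batch holds the remaining 16764 - 2 × 8192 = 380 coordinates.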
# Map reconstructed coords back to original space
!python map_coords_back.py --reconstructed logs/GAE+FFN-2D/2d/lightning_logs/version_1/reconstructed-custom_2d.h5ad --reference data/slice_440.h5ad --output logs/GAE+FFN-2D/2d/lightning_logs/version_1/reconstructed-original.h5ad --mode 2d
============================================================
Reading prediction results: logs/GAE+FFN-2D/2d/lightning_logs/version_1/reconstructed-custom_2d.h5ad
============================================================
Normalized coordinate shape: (16764, 2)
Normalized coordinate ranges:
  min: [-0.9911625 -0.784444 ]
  max: [0.99346364 0.77975   ]

============================================================
Reading reference data: data/slice_440.h5ad
============================================================

============================================================
Denormalizing coordinates...
============================================================
Parameters: keep_ratio=True
Reference data coordinate ranges:
  X: [4691.54, 8823.75], range=4132.22
  Y: [7357.52, 10613.24], range=3255.72
Scaling factors (keep_ratio=True): scale_x=1.0000, scale_y=0.7879

Denormalized coordinate ranges:
  min: [4709.79491889 7364.63086083]
  max: [ 8810.24711702 10596.42476476]

============================================================
Saving results to: logs/GAE+FFN-2D/2d/lightning_logs/version_1/reconstructed-original.h5ad
============================================================
✓ Done!

Output file contains:
  - obsm['spatial']: Original coordinates (mapped results)
  - obsm['spatial_normalized']: Normalized coordinates (backup)
  - Other prediction results remain unchanged
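The scale factors printed above (1.0000 for X, 0.7879 for Y) equal each axis range divided by the longer one, so both axes share one scale. Assuming the inverse of a center-and-scale normalization, the printed denormalized ranges can be checked by hand. This sketch is reconstructed from the log, not taken from map_coords_back.py; `denormalize_keep_ratio` is a hypothetical name:

```python
import numpy as np

def denormalize_keep_ratio(norm_coords, ref_min, ref_max):
    """Invert a keep-ratio normalization: scale by half the longer side, then re-center."""
    ref_min = np.asarray(ref_min, dtype=float)
    ref_max = np.asarray(ref_max, dtype=float)
    center = (ref_min + ref_max) / 2.0
    half_span = (ref_max - ref_min).max() / 2.0
    return np.asarray(norm_coords, dtype=float) * half_span + center

# Reference ranges and normalized extremes copied from the log above
ref_min, ref_max = [4691.54, 7357.52], [8823.75, 10613.24]
norm_extremes = [[-0.9911625, -0.784444], [0.99346364, 0.77975]]
restored = denormalize_keep_ratio(norm_extremes, ref_min, ref_max)
# Close to the printed min [4709.79, 7364.63] and max [8810.25, 10596.42];
# small differences come from the two-decimal reference ranges used here
```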
res = sc.read('logs/GAE+FFN-2D/2d/lightning_logs/version_1/reconstructed-original.h5ad')
res
AnnData object with n_obs × n_vars = 16764 × 1
    obsm: 'fitted_embd', 'reconstructed_raw', 'spatial', 'spatial_normalized'
### Visualize the imputed embeddings

embeddings = res.obsm["fitted_embd"]
spatial = res.obsm["spatial"]

fig, axes = plt.subplots(1, 3, figsize=(12, 3))

# Plot the first three embedding dimensions on the spatial coordinates
for dim, ax in enumerate(axes):
    # Named `scat` rather than `sc` so the scanpy alias used by sc.read() is not overwritten
    scat = ax.scatter(
        spatial[:, 0],
        spatial[:, 1],
        c=embeddings[:, dim].toarray().flatten(),  # densify the sparse column
        cmap="viridis",
        s=0.5
    )
    ax.set_title(f"Imputed embedding dim {dim+1}")
    fig.colorbar(scat, ax=ax, fraction=0.046)

plt.tight_layout()
plt.show()
[Figure: imputed embedding dimensions 1, 2 and 3 plotted on the spatial coordinates]
### Visualize the imputed gene expression

coordinates = res.obsm['spatial'][:, 0:2]
# Column 8293 of 'reconstructed_raw' holds the imputed values for gene Hba-x
values = res.obsm['reconstructed_raw'][:, 8293]

# Visualization
plt.figure(figsize=(6, 6))
plt.scatter(coordinates[:, 0], coordinates[:, 1], c=values, cmap='viridis_r', s=0.5, alpha=0.7)
plt.xlabel('X')
plt.ylabel('Y')
plt.title("Imputed expression of gene Hba-x, slice 440")
plt.axis('equal')
plt.colorbar(label='Expression Level')
#plt.savefig('slice440_Hba-x.png', format='png', dpi=300)
plt.show()
[Figure: imputed expression of gene Hba-x, slice 440]
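The gene column (8293) is hard-coded, and the result file stores only a single placeholder variable (`n_vars = 1`), so gene names cannot be read from it. Assuming the columns of `obsm['reconstructed_raw']` follow the gene order of the reference AnnData (data/slice_440.h5ad), the index can be looked up by name instead. A minimal sketch with a hypothetical helper `gene_column`:

```python
import numpy as np

def gene_column(var_names, gene):
    """Return the column index of `gene` in var_names; raise if absent or duplicated."""
    hits = np.flatnonzero(np.asarray(var_names) == gene)
    if hits.size != 1:
        raise KeyError(f"{gene!r} found {hits.size} times in var_names")
    return int(hits[0])

# Usage (assumes the reference gene order matches the reconstructed columns):
# import anndata as ad
# ref = ad.read_h5ad("data/slice_440.h5ad")
# values = res.obsm["reconstructed_raw"][:, gene_column(ref.var_names, "Hba-x")]
```

Because `var_names` is a pandas `Index`, `ref.var_names.get_loc("Hba-x")` is an equivalent lookup when the name is unique.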