Tutorial 3: Application on Mouse Brain Dataset - yihe-csu/SAGE GitHub Wiki

This tutorial demonstrates the analysis of spatial transcriptomics data from the mouse brain, a complex tissue with fine-grained and highly organized anatomical structures. To evaluate the generalization capability and structural robustness of our unsupervised model SAGE, we apply it to the 10x Visium mouse brain dataset, using a fixed number of clusters (n = 26).

The dataset includes a sagittal section composed of anterior and posterior brain slices. We integrate these two slices to assess the continuity of gene expression and structural consistency at the slice junctions. The inferred spatial domains are validated against the Allen Mouse Brain Atlas to ensure biological relevance and accuracy.

Due to the large size of the processed dataset, we have uploaded it to Figshare: 🔗 https://figshare.com/articles/dataset/Mouse_Brain_Merge/29492798. Please download the dataset and place it in the appropriate directory before running this tutorial.

Identify HSGs

1. Load SAGE and its dependent packages

import os
import sys
import numpy as np
import pandas as pd
import torch
from sklearn import metrics
import matplotlib.pyplot as plt
import scanpy as sc
import importlib
# Make the local SAGE checkout importable (replace the path with your own checkout).
sage_root = os.path.abspath("C://Users//heyi//Desktop/code_iteration/SAGE-main/")
sys.path.append(sage_root)
import SAGE

# Print the package metadata to confirm which SAGE build was picked up.
for meta_attr in ("__version__", "__author__", "__email__"):
    print(getattr(SAGE, meta_attr))
1.1.8 
Yi He
[email protected]

2. Set up the working environment and import data

# Run device: falls back to 'cpu' when CUDA is unavailable. We recommend using a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Location of R, which is necessary for the mclust algorithm.
# Please replace the path below with your local R installation path.
os.environ['R_HOME'] = 'E:\\R-4.4.1'

# Root of the SAGE checkout and name of the dataset folder.
# If you want to use additional data, change these two values.
base_path = "C:/Users/heyi/Desktop/code_iteration/SAGE-main/"
Dataset = "Mouse_Brain_Merge"

# Build every path with os.path.join so the tutorial runs unchanged on
# Windows, Linux and macOS (a hard-coded '\\' separator only works on Windows).
# The trailing "" keeps a trailing separator on file_fold, as before.
file_fold = os.path.join(base_path, "Dataset", Dataset, "")
output_dir = os.path.join(base_path, "Result", Dataset)
output_process_dir = os.path.join(output_dir, "process_data")
output_result_dir = os.path.join(output_dir, "result_data")

# makedirs creates missing parents (output_dir included); exist_ok=True makes
# re-running the cell harmless when the folders already exist.
for out_path in (output_dir, output_process_dir, output_result_dir):
    os.makedirs(out_path, exist_ok=True)


# Read the merged dataset from the input directory.
adata = sc.read_h5ad(os.path.join(file_fold, 'filtered_feature_bc_matrix.h5ad'))
adata.var_names_make_unique()
# Append a per-barcode occurrence counter ("_0", "_1", ...) so that observation
# names which occur more than once become unique
# (presumably barcodes shared by the two merged slices - confirm against the data).
adata.obs.index = adata.obs.index + "_" + adata.obs.groupby(adata.obs.index).cumcount().astype(str)
adata
AnnData object with n_obs × n_vars = 6050 × 32285
    obs: 'in_tissue', 'array_row', 'array_col'
    obsm: 'spatial'

3. Expression data preprocessing

# Expression preprocessing. The order of the steps below matters; do not reorder.
# Work on a copy of the loaded object.
adata = adata.copy()
# Gene filtering with thresholds min_cells=5 and min_counts=5.
sc.pp.filter_genes(adata, min_cells=5)
sc.pp.filter_genes(adata, min_counts=5)
# Flag the top 3000 highly variable genes. This runs BEFORE normalisation;
# NOTE(review): the 'seurat_v3' flavor is documented by scanpy as expecting raw counts - confirm.
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
# Library-size normalisation to 1e4 counts followed by log1p.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
# 50-component PCA restricted to the genes flagged in adata.var['highly_variable'].
sc.pp.pca(adata, n_comps=50, mask_var="highly_variable", svd_solver='arpack')

4. Multi-resolution Consensus Clustering

# Settings for the multi-resolution consensus clustering step.
con_method = "leiden"  # one of ["leiden", "mclust"]
# Presumably (start, stop, step) of the resolution sweep - the log below reports 29 runs.
con_range = (0.1, 3, 0.1)
con_use_rep = 'X_pca'
n_neig_coord = 6  # defined here but not passed to the call below
n_neig_feat = 6
con_dim = 25
con_radius = 20
con_refine = True

# Collect the keyword arguments once, then run the clustering on adata.
consensus_kwargs = dict(
    method=con_method,
    resolution_range=con_range,
    n_neighbors=n_neig_feat,
    use_rep=con_use_rep,
    dims=con_dim,
    radius=con_radius,
    refinement=con_refine,
)
SAGE.consensus_clustering(adata, **consensus_kwargs)
>>> Starting leiden clustering...
Running leiden clustering: 100%|███████████████████████████████████████████████████████| 29/29 [00:08<00:00,  3.31it/s]
>>> leiden clustering finished. Time elapsed: 15.43 seconds
>>> Computing consensus matrix...
>>> Consensus matrix computed. Time elapsed: 3.07 seconds
>>> Performing final Leiden clustering on consensus matrix...
>>> Consensus clustering completed.
>>> adata.obs['pre_domain'] generated!
>>> adata.obsm['consensus_freq'] generated!
>>> adata.obsm['clusters_results'] generated!

5. Topic Selection via Supervised Learning

# Topic selection with n_topics=40. According to the log printed below, this
# computes per-gene Moran's I (adata.var['genes_mri']), runs NMF, filters the
# topics by Moran's I and by Random Forest importance, and stores the outcome in
# adata.obsm['W_nmf'], adata.varm['H_nmf'] and adata.uns['final_topics'].
SAGE.preprocess.topics_selection(adata, n_topics=40)
>>> Start the genes_mri calculation ...
>>> Finish genes_mri calculation! Time taken: 5.12 seconds
>>> adata.var['genes_mri'] generated!
>>> Step1: Starting NMF calculation...
>>> NMF completed. 40 topics identified. Time taken: 61.36 seconds
>>> Step2: Filtering based on Topics Moran's I values...
>>> Stpe3: Filtering based on Random Forest importances...
>>> 32 topics selected !
>>> adata.obsm['W_nmf'], adata.varm['H_nmf'] generate!
>>> adata.uns['topic_mri_list'] generate!
>>> adata.uns['sorted_indices_imp'] generate!
>>> adata.uns['final_topics'] generate!

6. Draw Topics detected by RF

# Draw the topics stored under adata.uns["final_topics"], 8 panels per row,
# without a background image, legend or colorbar.
topic_plot_opts = dict(
    uns_key="final_topics",
    img_key=None,
    spot_size=300,
    ncols=8,
    figsize=(4, 4),
    fontsize=20,
    frameon=False,
    legend_loc=False,
    colorbar_loc=None,
    show=False,
)
SAGE.plot.plot_topics(adata, **topic_plot_opts)

7. Construction of high-specificity genes (HSGs)

# Build the high-specificity gene (HSG) set. According to the log printed below,
# the call reports a per-topic percentile table and stores adata.uns['HSG'],
# adata.varm['gene_topic_corr'] and adata.uns['marker_genes_all_dict'].
# NOTE(review): lower_p / upper_p look like percentile cut-offs on the per-topic
# gene ranking (see the rank_lower / rank_upper columns) - confirm in the SAGE docs.
SAGE.preprocess.genes_selection(adata, n_genes=3000, lower_p=10, upper_p=99)
 percentile_genes
    Topic gene_lower  rank_lower     gene_upper  rank_upper
0       0       Caly          70          Fem1a        2935
1       1     Gabra1          80          Cep83        2933
2       2      Kcnc1          79          Rbm4b        2927
3       3       Syn1          74          Kmt2a        2924
4       4        Pkm         106        Abhd14b        2943
5       5       Rpl4         103           Asns        2941
6       6      Ywhag          92  F730043M19Rik        2943
7       7      Rtl8a          72          Rab15        2923
8       8      Efnb3          60          Rhpn2        2913
9       9     Atp1b1          82           Evi5        2931
10     10       Napb         108         Ndufs1        2942
11     11      Rpl38          53          Cdk16        2882
12     12       Rpl9          35         Actr3b        2856
13     13    Camk2n1          61         Dcaf11        2920
14     14   Hsp90ab1         101         Chmp1a        2946
15     15     Slc1a3          87           Wdr1        2938
16     16        Dst         106          Rusc2        2941
17     17       Dnm1          79          Sf3b3        2926
18     18     Cyfip2          77           Rtl6        2933
19     19        Abr          97         Timm29        2940
20     20      Trim9          75          Srxn1        2940
21     21     Tmsb4x          22          Gabrd        2705
22     22      Ywhab          73         Bckdha        2927
23     23   Serpine2          78         Pdgfrb        2935
24     24     Cox4i1          35           Tdrp        2919
25     25      Calb1          98         Ccl27a        2939
26     26      Nupr1          62        Klhdc8b        2927
27     27       Ppia          18         Atxn1l        2858
28     28       Mmd2          46         Stxbp1        2916
29     29      Rnf13          48         Wdr45b        2899
30     30      Sumo2          86        Ppp2r3a        2938
31     31       Mmd2          47           Cr1l        2922
number_svgs:932
number_no_svgs:766
>>> HSGs selected!
>>> adata.uns['HSG'] generate!
>>> adata.varm['gene_topic_corr'] generate! (method='pearson')
>>> adata.uns['marker_genes_all_dict'] generate! (corr_threshold=0.2)

Clustering

1. Consensus-driven Graph Construction

# Build the graphs used for training. According to the log printed below, this
# constructs a coordinate graph (adata.obsm['adj_opt']) and a feature graph
# (adata.obsm['adj_feat']).
# NOTE(review): cut_side_thr / min_neighbor presumably control edge pruning and
# the minimum number of neighbours kept per spot - confirm in the SAGE docs.
SAGE.preprocess.optimize_graph_topology(
    adata,
    n_neig_coord=6,
    cut_side_thr=0.2,
    min_neighbor=2,
    n_neig_feat=15,
)
>>> Graph_coord constructed!
>>> adata.obsm['adj_opt'] generate!
>>> Graph_feat constructed!
>>> adata.obsm['adj_feat'] generate!

2. Draw SAG edge

# Visualise the graph edges on the tissue, without a background image,
# legend or colorbar.
edge_plot_opts = dict(
    img_key=None,
    spot_size=200,
    figsize=(5, 5),
    frameon=False,
    legend_loc=None,
    colorbar_loc=None,
    show=False,
)
SAGE.plot.plot_neighbors_cut(adata, **edge_plot_opts)

3. Run SAGE

# Define the SAGE model on the selected device and train it for 800 epochs.
model = SAGE.SAGE(adata, device=device, epochs=800)
# train() returns the AnnData object, which replaces adata here.
# The clustering step afterwards reads adata.obsm["emb_latent_att"]
# (presumably written by training - confirm in the SAGE docs).
adata = model.train()
>>> SAGE model construction completed, training begins.
>>> Optimization finished for ST data.

4. Domain clustering

# Cluster the latent embedding into a fixed number of spatial domains with mclust.
n_clusters = 26
latent_embedding = adata.obsm["emb_latent_att"]
SAGE.utils.clustering(
    adata,
    data=latent_embedding,
    method='mclust',
    n_clusters=n_clusters,
    res=0.3,
    radius=30,
    refinement=False,
)
R[write to console]:                    __           __ 
   ____ ___  _____/ /_  _______/ /_
  / __ `__ \/ ___/ / / / / ___/ __/
 / / / / / / /__/ / /_/ (__  ) /_  
/_/ /_/ /_/\___/_/\__,_/____/\__/   version 6.1.1
Type 'citation("mclust")' for citing this R package in publications.

fitting ...
  |======================================================================| 100%
>>> Clustering completed using mclust with 26 clusters.
>>> adata.obsm['domain'] & ['mclust'] generate!

5. Save result

# Persist the annotated AnnData object. The file name is kept as "result.h5"
# because the plotting sections below reload it under that name.
adata.write_h5ad(os.path.join(output_result_dir, "result.h5"))

# Domain label per spot, written without a header row.
clusters = adata.obs["domain"]
clusters.to_csv(os.path.join(output_result_dir, "clusters.csv"), header=False)

# Latent embedding as whitespace-delimited text (np.savetxt default format).
embedding = adata.obsm["emb_latent_att"]
np.savetxt(os.path.join(output_result_dir, "embedding.txt"), embedding)

# Export SAGE's intermediate tables to the process-data folder.
SAGE.utils.export_H_zscore_to_csv(adata, out_dir=output_process_dir)
SAGE.utils.export_Corr_to_csv(adata, out_dir=output_process_dir)

6. Plot clustering results

# Reload the saved result so this section can be run on its own.
adata = sc.read(output_result_dir+'/result.h5')

plt.rcParams["figure.figsize"] = (8,8)
plt.rcParams['font.size'] = 10
labels_pred = adata.obs["domain"]
# Both scores are computed on the SPATIAL coordinates (not on the embedding),
# i.e. they measure how spatially compact the predicted domains are.
X = adata.obsm["spatial"]
SC = metrics.silhouette_score(X, labels_pred)
DB = metrics.davies_bouldin_score(X, labels_pred)

# Fixed colour list; it is cycled with a modulo below if there are more domains than colours.
custom_palette = [
    '#F4E8B8',  '#EEC30E',  '#8F0100',  '#058187',  '#0C4DA7',  
    '#B44541',  '#632621',  '#92C7A3',  '#D98882',  '#6A93CB',  
    '#F0C94A',  '#AD6448',  '#4F6A9C',  '#CCB9A1',  '#0B3434',  
    '#3C4F76',  '#C1D354',  '#7D5BA6',  '#F28522',  '#4A9586',
    '#FF6F61',  '#D32F2F',  '#1976D2',  '#388E3C',  '#FBC02D', 
    '#8E24AA',  '#0288D1',  '#7B1FA2',  '#F57C00',  '#C2185B',
    '#1B4F72',  '#117864',  '#D4AC0D',  '#922B21',  '#6C3483',
    '#1F618D',  '#A04000',  '#196F3D',  '#2C3E50',  '#F39C12',
    '#7D6608',  '#4A235A',  '#D68910',  '#B03A2E',  '#7B241C',
    '#2471A3',  '#148F77',  '#9C640C',  '#6E2C00',  '#512E5F',
    '#154360',  '#145A32'
]

# Map each domain label to a colour, in order of first appearance in adata.obs['domain'].
# NOTE(review): keys are cast to int - assumes the domain labels are integer-like and that
# the plotting functions expect int keys; confirm against the dtype of the categories.
cluster_palette = {int(cluster): custom_palette[i % len(custom_palette)] 
                   for i, cluster in enumerate(adata.obs['domain'].unique())}
# Spatial plot of the domains; the title reports the two scores computed above.
sc.pl.spatial(adata, 
            img_key = None, 
            spot_size = 160,
            color = ["domain"],
            title = [f'(SC={SC:.3f}  DB={DB:.3f})'], 
            na_in_legend = False,
            palette = cluster_palette,
            frameon=False,
            show= False)
plt.tight_layout()
plt.show()

7. Marker genes detection

# Map each spatial domain to marker genes. According to the log printed below,
# the pipeline (1) generates marker genes for the topics, (2) matches domains to
# their best topics and (3) builds the domain -> genes dictionary stored in
# adata.uns['domain_to_genes'].
SAGE.utils.run_domain_gene_mapping_pipeline(
    adata,
    useref="domain",
    topic_key="W_nmf",
    corr_threshold=0.2,
    topic_gene_corr_key="gene_topic_corr",
    domain_mapping_topic_key="domain_mapping_topic",
    marker_genes_dict_key="marker_genes_multi_dict",
    domain_to_genes_key="domain_to_genes"
)
>>> Step 1: Generating marker genes for topics...
>>> Step 2: Matching domains to best topics...
>>> Step 3: Generating domain → genes dictionary...
>>> generate adata.uns['domain_to_genes'],contain 20 domain。 
>>> Pipeline completed.
# Plot 8 marker genes (4 per row) for domain '23', reusing the colours in
# cluster_palette from the clustering plot.
marker_plot_opts = dict(
    userep="domain",
    domain='23',
    n_genes=8,
    ncols=4,
    spot_size=160,
    out_dir=None,
    figsize=(4, 4),
    fontsize=10,
    frameon=None,
    legend_loc='on data',
    colorbar_loc="right",
    show_title=True,
    palette_dict=cluster_palette,
)
SAGE.plot.plot_marker_genes(adata, **marker_plot_opts)

8. Plot UMAP and PAGA graph

# Reload the saved result so this section can be run on its own.
adata = sc.read(output_result_dir+'/result.h5')
adata.var_names_make_unique()

# Load the embedding written by np.savetxt (whitespace-delimited, no header).
# sep=r"\s+" replaces delim_whitespace=True, which is deprecated in recent pandas.
SAGE_embed = pd.read_csv(output_result_dir+"/embedding.txt", header=None, sep=r"\s+")
# Despite its name, this holds the predicted domain labels; it is only used for
# its index and is kept for backward compatibility.
labels_true = adata.obs["domain"].copy()
SAGE_embed.index = labels_true.index
# Use string column names before storing the frame in obsm.
SAGE_embed.columns = SAGE_embed.columns.astype(str)
adata.obsm["SAGE"] = SAGE_embed

# Neighbour graph, UMAP and PAGA computed on the SAGE embedding.
plt.rcParams["figure.figsize"] = (6,5)
sc.pp.neighbors(adata, use_rep="SAGE")
sc.tl.umap(adata)
sc.tl.paga(adata, groups='domain')

fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# UMAP
sc.pl.umap(adata,
           color='domain',
           legend_loc='on data',
           legend_fontoutline=3,
           add_outline=True, s=30,
           outline_width=(0.3, 0.05),
           legend_fontsize=10,
           frameon=False,
           ax=axes[0],
           show=False)

axes[0].set_title('UMAP', fontsize=20)

umap_coords = pd.DataFrame(adata.obsm['X_umap'],
                           index=adata.obs.index,
                           columns=['UMAP1', 'UMAP2'])

# Place each PAGA node at the UMAP centroid of its domain.
# PAGA orders its nodes by the categories of adata.obs['domain'], so the centroid
# rows must follow that same order. Grouping by string-cast labels would sort them
# lexicographically ('1', '10', '11', ..., '2') and could attach positions to the
# wrong nodes.
clusters = adata.obs['domain']
cluster_centers = (
    umap_coords.groupby(clusters, observed=False)
    .mean()
    .reindex(clusters.cat.categories)
)

# PAGA
sc.pl.paga(adata,
           color='domain',
           pos=cluster_centers.values,
           node_size_scale=10,
           fontoutline=3,
           frameon=False,
           edge_width_scale=1,
           fontsize=10,
           ax=axes[1],
           show=False)

axes[1].set_title('PAGA', fontsize=20)