Tutorial 3: Application on Mouse Brain Dataset - yihe-csu/SAGE GitHub Wiki
This tutorial demonstrates the analysis of spatial transcriptomics data from the mouse brain, a complex tissue with fine-grained and highly organized anatomical structures. To evaluate the generalization capability and structural robustness of our unsupervised model SAGE, we apply it to the 10x Visium mouse brain dataset, using a fixed number of clusters (n = 26).
The dataset includes a sagittal section composed of anterior and posterior brain slices. We integrate these two slices to assess the continuity of gene expression and structural consistency at the slice junctions. The inferred spatial domains are validated against the Allen Mouse Brain Atlas to ensure biological relevance and accuracy.
Due to the large size of the processed dataset, we have uploaded it to Figshare: https://figshare.com/articles/dataset/Mouse_Brain_Merge/29492798. Please download the dataset and place it in the dataset directory read by this tutorial (`<base_path>/Dataset/Mouse_Brain_Merge/`) before running it.
Identify HSGs
1. Load SAGE and its dependent packages
import os
import sys
import numpy as np
import pandas as pd
import torch
from sklearn import metrics
import matplotlib.pyplot as plt
import scanpy as sc
import importlib
# Make the local SAGE checkout importable (replace the path with your own clone),
# then report which build of the package was picked up.
sys.path.append(os.path.abspath("C://Users//heyi//Desktop/code_iteration/SAGE-main/"))
import SAGE

# One value per line: version, author, contact e-mail.
print(SAGE.__version__, SAGE.__author__, SAGE.__email__, sep="\n")
1.1.8
Yi He
[email protected]
2. Set up the working environment and import data
# Run device: the package runs on 'cpu' by default; using a GPU is recommended.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Location of R, which is required by the mclust algorithm.
# Replace the path below with your local R installation path.
os.environ['R_HOME'] = 'E:\\R-4.4.1'

# Root of the local SAGE checkout; replace with your own path.
base_path = "C:/Users/heyi/Desktop/code_iteration/SAGE-main/"
Dataset = "Mouse_Brain_Merge"
file_fold = f'{base_path}/Dataset/{Dataset}/'

# Output directories (change these paths if you want to use additional data).
output_dir = f"{base_path}/Result/{Dataset}"
output_process_dir = output_dir + "/process_data"
output_result_dir = output_dir + "/result_data"

# exist_ok=True makes directory creation idempotent and free of the
# check-then-create race; missing parent directories are created as well.
for directory in (output_dir, output_process_dir, output_result_dir):
    os.makedirs(directory, exist_ok=True)

# Read the dataset from file_fold. os.path.join is used instead of a
# hard-coded '\\' separator so the path also resolves on non-Windows systems.
adata = sc.read_h5ad(os.path.join(file_fold, 'filtered_feature_bc_matrix.h5ad'))
adata.var_names_make_unique()

# Append a per-name occurrence counter ("_0", "_1", ...) to every observation
# name so that repeated names become unique
# (presumably barcodes shared by the two merged slices -- confirm).
adata.obs.index = adata.obs.index + "_" + adata.obs.groupby(adata.obs.index).cumcount().astype(str)
adata
AnnData object with n_obs × n_vars = 6050 × 32285
obs: 'in_tissue', 'array_row', 'array_col'
obsm: 'spatial'
3. Expression data preprocessing
# Work on an explicit copy of the loaded AnnData object.
adata = adata.copy()
# Gene filtering: keep genes detected in at least 5 observations (spots),
# then genes with at least 5 total counts.
sc.pp.filter_genes(adata, min_cells=5)
sc.pp.filter_genes(adata, min_counts=5)
# Flag the 3000 most variable genes. This call is placed before
# normalization/log1p; per scanpy's documentation the "seurat_v3" flavor
# expects count data, so do not reorder these steps.
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
# Scale every observation to a total of 1e4 counts, then apply log1p.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
# 50-component PCA restricted to the genes flagged as "highly_variable".
sc.pp.pca(adata, n_comps=50, mask_var="highly_variable", svd_solver='arpack')
4. Multi-resolution Consensus Clustering
# Clustering method used for the individual runs; the trailing comment lists the options.
con_method = "leiden" #["leiden","mclust"]
# Resolution sweep passed as resolution_range -- presumably (start, stop, step);
# the log printed after this cell reports 29 clustering runs.
con_range = ( 0.1, 3 , 0.1)
# Key in adata.obsm holding the representation to cluster on.
con_use_rep = 'X_pca'
# NOTE(review): n_neig_coord is defined here but is not passed to
# consensus_clustering below; only n_neig_feat is used (as n_neighbors).
n_neig_coord=6
n_neig_feat =6
# Passed as dims / radius / refinement respectively.
# NOTE(review): presumably the number of components used, and the spatial
# radius for label refinement -- confirm against the SAGE documentation.
con_dim =25
con_radius=20
con_refine=True
# Run the multi-resolution consensus clustering. According to the printed log,
# it generates adata.obs['pre_domain'], adata.obsm['consensus_freq'] and
# adata.obsm['clusters_results'].
SAGE.consensus_clustering(adata,
method=con_method,
resolution_range=con_range,
n_neighbors=n_neig_feat,
use_rep=con_use_rep,
dims=con_dim,
radius=con_radius,
refinement=con_refine)
>>> Starting leiden clustering...
Running leiden clustering: 100%|███████████████████████████████████████████████████████| 29/29 [00:08<00:00, 3.31it/s]
>>> leiden clustering finished. Time elapsed: 15.43 seconds
>>> Computing consensus matrix...
>>> Consensus matrix computed. Time elapsed: 3.07 seconds
>>> Performing final Leiden clustering on consensus matrix...
>>> Consensus clustering completed.
>>> adata.obs['pre_domain'] generated!
>>> adata.obsm['consensus_freq'] generated!
>>> adata.obsm['clusters_results'] generated!
5. Topic Selection via Supervised Learning
# Decompose the expression data into 40 NMF topics and keep the informative ones.
# According to the printed log, topics are filtered first by Moran's I and then
# by Random Forest importance; results are stored in adata.obsm['W_nmf'],
# adata.varm['H_nmf'] and adata.uns['final_topics'] (among others).
SAGE.preprocess.topics_selection(adata, n_topics=40)
>>> Start the genes_mri calculation ...
>>> Finish genes_mri calculation! Time taken: 5.12 seconds
>>> adata.var['genes_mri'] generated!
>>> Step1: Starting NMF calculation...
>>> NMF completed. 40 topics identified. Time taken: 61.36 seconds
>>> Step2: Filtering based on Topics Moran's I values...
>>> Stpe3: Filtering based on Random Forest importances...
>>> 32 topics selected !
>>> adata.obsm['W_nmf'], adata.varm['H_nmf'] generate!
>>> adata.uns['topic_mri_list'] generate!
>>> adata.uns['sorted_indices_imp'] generate!
>>> adata.uns['final_topics'] generate!
6. Draw Topics detected by RF
# Plot the topics referenced by adata.uns["final_topics"], 8 panels per row.
# Frame, legend and colorbar are switched off; show=False leaves display to the caller.
# NOTE(review): img_key=None presumably omits the tissue image, and figsize is
# presumably per panel -- confirm against the SAGE plotting API.
SAGE.plot.plot_topics(adata,
uns_key="final_topics",
img_key=None,
spot_size=300,
ncols=8,
figsize=(4, 4),
fontsize=20,
frameon=False,
legend_loc=False,
colorbar_loc=None,
show=False)
7. Construction of high-specificity genes (HSGs)
# Construct the high-specificity gene (HSG) set. According to the printed log,
# this generates adata.uns['HSG'], adata.varm['gene_topic_corr'] and
# adata.uns['marker_genes_all_dict'].
# NOTE(review): lower_p / upper_p are presumably percentile cut-offs applied per
# topic (cf. the printed "percentile_genes" table) -- confirm in the SAGE docs.
SAGE.preprocess.genes_selection(adata, n_genes=3000, lower_p=10, upper_p=99)
percentile_genes
Topic gene_lower rank_lower gene_upper rank_upper
0 0 Caly 70 Fem1a 2935
1 1 Gabra1 80 Cep83 2933
2 2 Kcnc1 79 Rbm4b 2927
3 3 Syn1 74 Kmt2a 2924
4 4 Pkm 106 Abhd14b 2943
5 5 Rpl4 103 Asns 2941
6 6 Ywhag 92 F730043M19Rik 2943
7 7 Rtl8a 72 Rab15 2923
8 8 Efnb3 60 Rhpn2 2913
9 9 Atp1b1 82 Evi5 2931
10 10 Napb 108 Ndufs1 2942
11 11 Rpl38 53 Cdk16 2882
12 12 Rpl9 35 Actr3b 2856
13 13 Camk2n1 61 Dcaf11 2920
14 14 Hsp90ab1 101 Chmp1a 2946
15 15 Slc1a3 87 Wdr1 2938
16 16 Dst 106 Rusc2 2941
17 17 Dnm1 79 Sf3b3 2926
18 18 Cyfip2 77 Rtl6 2933
19 19 Abr 97 Timm29 2940
20 20 Trim9 75 Srxn1 2940
21 21 Tmsb4x 22 Gabrd 2705
22 22 Ywhab 73 Bckdha 2927
23 23 Serpine2 78 Pdgfrb 2935
24 24 Cox4i1 35 Tdrp 2919
25 25 Calb1 98 Ccl27a 2939
26 26 Nupr1 62 Klhdc8b 2927
27 27 Ppia 18 Atxn1l 2858
28 28 Mmd2 46 Stxbp1 2916
29 29 Rnf13 48 Wdr45b 2899
30 30 Sumo2 86 Ppp2r3a 2938
31 31 Mmd2 47 Cr1l 2922
number_svgs:932
number_no_svgs:766
>>> HSGs selected!
>>> adata.uns['HSG'] generate!
>>> adata.varm['gene_topic_corr'] generate! (method='pearson')
>>> adata.uns['marker_genes_all_dict'] generate! (corr_threshold=0.2)
Clustering
1. Consensus-driven Graph Construction
# Build the graphs used by the model. According to the printed log, this
# generates adata.obsm['adj_opt'] (coordinate graph) and
# adata.obsm['adj_feat'] (feature graph).
# NOTE(review): cut_side_thr and min_neighbor presumably control pruning of
# spatial edges and the minimum neighbors kept per spot -- confirm.
SAGE.preprocess.optimize_graph_topology(
adata,
n_neig_coord=6,
cut_side_thr=0.2,
min_neighbor=2,
n_neig_feat=15,
)
>>> Graph_coord constructed!
>>> adata.obsm['adj_opt'] generate!
>>> Graph_feat constructed!
>>> adata.obsm['adj_feat'] generate!
2. Draw SAG edge
# Draw the graph edges on the spatial layout without frame, legend or colorbar.
# NOTE(review): judging by the function name, this presumably highlights which
# neighbor edges were cut during graph optimization -- confirm.
SAGE.plot.plot_neighbors_cut(adata,
img_key=None,
spot_size=200,
figsize=(5, 5),
frameon=False,
legend_loc=None,
colorbar_loc=None,
show=False)
3. Run SAGE
# Define the model on the selected device and train it for 800 epochs.
model = SAGE.SAGE(adata, device=device, epochs=800)
# train() returns the AnnData object, which replaces `adata`; the following
# steps read the learned embedding from adata.obsm["emb_latent_att"].
adata = model.train()
>>> SAGE model construction completed, training begins.
>>> Optimization finished for ST data.
4. Domain clustering
# Number of spatial domains to recover (fixed at 26 for this dataset).
n_clusters=26
# Cluster the learned embedding with mclust, which needs the R installation
# configured through R_HOME. Spatial refinement of labels is disabled here.
# The result is read back later from adata.obs["domain"].
# NOTE(review): `res` presumably only matters for resolution-based methods
# and is unused with method='mclust' -- confirm.
SAGE.utils.clustering(adata,
data= adata.obsm["emb_latent_att"],
method='mclust',
n_clusters=n_clusters,
res = 0.3,
radius=30,
refinement=False)
R[write to console]: __ __
____ ___ _____/ /_ _______/ /_
/ __ `__ \/ ___/ / / / / ___/ __/
/ / / / / / /__/ / /_/ (__ ) /_
/_/ /_/ /_/\___/_/\__,_/____/\__/ version 6.1.1
Type 'citation("mclust")' for citing this R package in publications.
fitting ...
|======================================================================| 100%
>>> Clustering completed using mclust with 26 clusters.
>>> adata.obsm['domain'] & ['mclust'] generate!
5. Save result
# Persist the full annotated object; the plotting sections reload this file.
adata.write_h5ad(f"{output_result_dir}/result.h5")

# Domain label per observation, written without a header row.
clusters = adata.obs["domain"]
clusters.to_csv(f"{output_result_dir}/clusters.csv", header=False)

# Learned embedding as plain whitespace-separated text.
embedding = adata.obsm["emb_latent_att"]
np.savetxt(f"{output_result_dir}/embedding.txt", embedding)

# Intermediate tables go to the process-data directory.
SAGE.utils.export_H_zscore_to_csv(adata, out_dir=output_process_dir)
SAGE.utils.export_Corr_to_csv(adata, out_dir=output_process_dir)
6. Plot clustering results
# Read the saved result back from disk.
adata = sc.read(output_result_dir+'/result.h5')
# Global matplotlib defaults for this figure.
plt.rcParams["figure.figsize"] = (8,8)
plt.rcParams['font.size'] = 10
# Silhouette (SC) and Davies-Bouldin (DB) scores of the domain labels.
# Both are computed on the spatial coordinates, not on the learned embedding.
labels_pred = adata.obs["domain"]
X = adata.obsm["spatial"]
SC = metrics.silhouette_score(X, labels_pred)
DB = metrics.davies_bouldin_score(X, labels_pred)
# Fixed list of 52 hex colors used for the domains.
custom_palette = [
'#F4E8B8', '#EEC30E', '#8F0100', '#058187', '#0C4DA7',
'#B44541', '#632621', '#92C7A3', '#D98882', '#6A93CB',
'#F0C94A', '#AD6448', '#4F6A9C', '#CCB9A1', '#0B3434',
'#3C4F76', '#C1D354', '#7D5BA6', '#F28522', '#4A9586',
'#FF6F61', '#D32F2F', '#1976D2', '#388E3C', '#FBC02D',
'#8E24AA', '#0288D1', '#7B1FA2', '#F57C00', '#C2185B',
'#1B4F72', '#117864', '#D4AC0D', '#922B21', '#6C3483',
'#1F618D', '#A04000', '#196F3D', '#2C3E50', '#F39C12',
'#7D6608', '#4A235A', '#D68910', '#B03A2E', '#7B241C',
'#2471A3', '#148F77', '#9C640C', '#6E2C00', '#512E5F',
'#154360', '#145A32'
]
# Map each domain label (cast to int) to a color, in order of first appearance;
# the modulo wraps around if there are more domains than colors.
cluster_palette = {int(cluster): custom_palette[i % len(custom_palette)]
for i, cluster in enumerate(adata.obs['domain'].unique())}
# Spatial plot of the domains with both scores in the title; show=False so the
# layout can be tightened before the figure is displayed.
sc.pl.spatial(adata,
img_key = None,
spot_size = 160,
color = ["domain"],
title = [f'(SC={SC:.3f} DB={DB:.3f})'],
na_in_legend = False,
palette = cluster_palette,
frameon=False,
show= False)
plt.tight_layout()
plt.show()
7. Marker genes detection
# Link domains to topics and topics to genes. According to the printed log the
# pipeline (1) generates marker genes per topic, (2) matches each domain to its
# best topic and (3) builds a domain -> genes dictionary in
# adata.uns['domain_to_genes'].
# The *_key arguments name the adata entries the pipeline reads from / writes to.
SAGE.utils.run_domain_gene_mapping_pipeline(
adata,
useref="domain",
topic_key="W_nmf",
corr_threshold=0.2,
topic_gene_corr_key="gene_topic_corr",
domain_mapping_topic_key="domain_mapping_topic",
marker_genes_dict_key="marker_genes_multi_dict",
domain_to_genes_key="domain_to_genes"
)
>>> Step 1: Generating marker genes for topics...
>>> Step 2: Matching domains to best topics...
>>> Step 3: Generating domain → genes dictionary...
>>> generate adata.uns['domain_to_genes']，contain 20 domain。
>>> Pipeline completed.
# Plot 8 marker genes (4 panels per row) for domain '23', reusing the domain
# color mapping built for the clustering figure. out_dir=None means no output
# directory is passed.
# NOTE(review): `domain` is given as the string '23' while cluster_palette is
# keyed by int -- confirm that plot_marker_genes handles the conversion.
SAGE.plot.plot_marker_genes(
adata,
userep="domain",
domain='23',
n_genes=8,
ncols=4,
spot_size=160,
out_dir=None,
figsize=(4, 4),
fontsize=10,
frameon=None,
legend_loc='on data',
colorbar_loc="right",
show_title=True,palette_dict=cluster_palette
)
8. Plot UMAP and PAGA graph
# Reload the saved result so this section can be run independently.
adata = sc.read(output_result_dir+'/result.h5')
adata.var_names_make_unique()

# Load the embedding written by np.savetxt (whitespace-separated, no header).
# sep=r"\s+" replaces delim_whitespace=True, which pandas has deprecated.
SAGE_embed = pd.read_csv(output_result_dir+"/embedding.txt", header=None, sep=r"\s+")
# NOTE: despite the name, these are the predicted domain labels; they are only
# used here to align the embedding rows with the observation names.
labels_true = adata.obs["domain"].copy()
SAGE_embed.index = labels_true.index
# Column names are converted to strings before the frame is stored in obsm.
SAGE_embed.columns = SAGE_embed.columns.astype(str)
adata.obsm["SAGE"] = SAGE_embed

# Plot
plt.rcParams["figure.figsize"] = (6,5)
# Neighborhood graph, UMAP and PAGA are all computed on the SAGE embedding.
sc.pp.neighbors(adata, use_rep="SAGE")
sc.tl.umap(adata)
sc.tl.paga(adata, groups='domain')
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# UMAP
sc.pl.umap(adata,
           color='domain',
           # palette=combine_palette,
           legend_loc='on data',
           legend_fontoutline=3,
           add_outline=True, s=30,
           outline_width=(0.3, 0.05),
           legend_fontsize=10,
           frameon=False,
           ax=axes[0],
           show=False)
axes[0].set_title('UMAP', fontsize=20)

# Centroid of every domain in UMAP space, used as the PAGA node positions.
umap_coords = pd.DataFrame(adata.obsm['X_umap'],
                           index=adata.obs.index,
                           columns=['UMAP1', 'UMAP2'])
clusters = adata.obs['domain'].astype(str)
# groupby on string labels sorts them lexicographically ('1', '10', '11', ...),
# whereas `pos` must follow the order of the 'domain' categories used for the
# PAGA nodes. Reindex explicitly so row i belongs to category i.
domain_order = adata.obs['domain'].cat.categories.astype(str)
cluster_centers = umap_coords.groupby(clusters).mean().reindex(domain_order)

# PAGA
sc.pl.paga(adata,
           color='domain',
           pos=cluster_centers.values,
           node_size_scale=10,
           fontoutline=3,
           frameon=False,
           edge_width_scale=1,
           fontsize=10,
           # fontweight='bold',
           ax=axes[1],
           show=False)
axes[1].set_title('PAGA', fontsize=20)