Tutorial 1: Application on 10x Visium human dorsolateral prefrontal cortex (DLPFC) dataset - yihe-csu/SAGE GitHub Wiki

Here, we present a re-analysis of the representative slice 151673 from the human dorsolateral prefrontal cortex (DLPFC) dataset. Maynard et al. manually annotated the cortical layers and white matter (WM) of the DLPFC based on morphological features and gene expression markers. The processed data can be accessed at the following link: 🔗 https://github.com/yihe-csu/SAGE/tree/main/Dataset/DLPFC_151673. This tutorial demonstrates how to identify spatial domains in 10x Visium spatial transcriptomics data using our unsupervised model, SAGE.

Identify high-specificity genes (HSGs)

1. Load SAGE and its dependent packages

import os
import sys
import numpy as np
import pandas as pd
import torch
from sklearn import metrics
import matplotlib.pyplot as plt
import scanpy as sc
import importlib
# Make the local SAGE checkout importable (replace this path with your own copy).
sage_repo = os.path.abspath("C://Users//heyi//Desktop/code_iteration/SAGE-main/")
sys.path.append(sage_repo)
import SAGE

# Report the package metadata: version, author, contact e-mail (in that order).
for _attr in ("__version__", "__author__", "__email__"):
    print(getattr(SAGE, _attr))
1.1.8 
Yi He
[email protected]

2. Set up the working environment and import data

# Run device: the first CUDA GPU when available, otherwise CPU (a GPU is recommended).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Location of the R installation, needed for the mclust algorithm.
# Please replace the path below with your local R installation path.
# NOTE(review): this is set after `import SAGE`; if SAGE initializes R (rpy2) at
# import time, R_HOME may need to be set before that import — confirm.
os.environ['R_HOME'] = 'E:\\R-4.4.1'

base_path = "C:/Users/heyi/Desktop/code_iteration/SAGE-main"
Dataset = "DLPFC_151673"

# Input folder for this slice (note: it ends with a trailing '/').
file_fold = f'{base_path}/Dataset/{Dataset}/'
# Output directories (if you want to use additional data, change these paths).
output_dir = f"{base_path}/Result/{Dataset}"
output_process_dir = output_dir + "/process_data"
output_result_dir = output_dir + "/result_data"
# exist_ok=True replaces the check-then-create pattern: no race, no redundant checks.
os.makedirs(output_dir, exist_ok=True)
os.makedirs(output_process_dir, exist_ok=True)
os.makedirs(output_result_dir, exist_ok=True)


# Read the Visium count matrix and tissue images from file_fold.
adata = sc.read_visium(file_fold, count_file='filtered_feature_bc_matrix.h5', load_images=True)
adata.var_names_make_unique()
adata.raw = adata.copy()

# Add the manual layer annotation as ground truth.
# file_fold already ends with '/', so no extra separator is prepended here.
df_meta = pd.read_csv(file_fold + 'metadata.tsv', sep='\t')
df_meta_layer = df_meta['layer_guess']
# NOTE(review): this assignment is positional — it assumes metadata.tsv rows are in
# the same order as adata.obs; confirm (or align on barcodes) for other datasets.
adata.obs['ground_truth'] = df_meta_layer.values
# Filter out spots without an annotation (NA).
adata = adata[~pd.isnull(adata.obs['ground_truth'])]
adata
View of AnnData object with n_obs × n_vars = 3611 × 33538
    obs: 'in_tissue', 'array_row', 'array_col', 'ground_truth'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'spatial'
    obsm: 'spatial'

3. Expression data preprocessing

# Turn the filtered view into an independent AnnData object before modifying it.
adata = adata.copy()
# Gene filtering: first require >= 5 expressing spots, then >= 5 total counts
# (filter_genes accepts one criterion per call, hence two calls).
for gene_filter in ({"min_cells": 5}, {"min_counts": 5}):
    sc.pp.filter_genes(adata, **gene_filter)
# Flag 3000 highly variable genes on raw counts (seurat_v3 flavor), then
# library-size normalize to 1e4 counts per spot and log-transform.
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
# 50-component PCA restricted to the highly variable genes.
sc.pp.pca(adata, n_comps=50, mask_var="highly_variable", svd_solver='arpack')

4. Multi-resolution Consensus Clustering

# Multi-resolution consensus clustering settings.
con_method = "mclust"  # alternatives: "leiden" or "mclust"
# Sweep settings; the logged run performs 7 mclust fits, consistent with range(2, 9, 1).
con_range = (2, 9, 1)
con_use_rep = 'X_pca'
# NOTE(review): n_neig_coord is defined here but not passed to consensus_clustering.
n_neig_coord = 6
n_neig_feat = 6
con_dim = 25
con_radius = 20
con_refine = True

consensus_kwargs = {
    "method": con_method,
    "resolution_range": con_range,
    "n_neighbors": n_neig_feat,
    "use_rep": con_use_rep,
    "dims": con_dim,
    "radius": con_radius,
    "refinement": con_refine,
}
SAGE.consensus_clustering(adata, **consensus_kwargs)
>>> Starting mclust clustering...
Running mclust clustering:   0%|                                                                 | 0/7 [00:00<?, ?it/s]R[write to console]:                    __           __ 
   ____ ___  _____/ /_  _______/ /_
  / __ `__ \/ ___/ / / / / ___/ __/
 / / / / / / /__/ / /_/ (__  ) /_  
/_/ /_/ /_/\___/_/\__,_/____/\__/   version 6.1.1
Type 'citation("mclust")' for citing this R package in publications.

fitting ...
  |======================================================================| 100%
Running mclust clustering:  14%|████████▏                                                | 1/7 [00:08<00:52,  8.81s/it]
fitting ...
  |======================================================================| 100%
Running mclust clustering:  29%|████████████████▎                                        | 2/7 [00:17<00:44,  8.87s/it]
fitting ...
  |======================================================================| 100%
Running mclust clustering:  43%|████████████████████████▍                                | 3/7 [00:26<00:35,  8.85s/it]
fitting ...
  |======================================================================| 100%
Running mclust clustering:  57%|████████████████████████████████▌                        | 4/7 [00:37<00:29,  9.72s/it]
fitting ...
  |======================================================================| 100%
Running mclust clustering:  71%|████████████████████████████████████████▋                | 5/7 [00:48<00:20, 10.10s/it]
fitting ...
  |======================================================================| 100%
Running mclust clustering:  86%|████████████████████████████████████████████████▊        | 6/7 [00:59<00:10, 10.33s/it]
fitting ...
  |======================================================================| 100%
Running mclust clustering: 100%|█████████████████████████████████████████████████████████| 7/7 [01:10<00:00, 10.10s/it]
>>> mclust clustering finished. Time elapsed: 70.71 seconds
>>> Computing consensus matrix...
>>> Consensus matrix computed. Time elapsed: 0.36 seconds
>>> Performing final Leiden clustering on consensus matrix...
>>> Consensus clustering completed.
>>> adata.obs['pre_domain'] generated!
>>> adata.obsm['consensus_freq'] generated!
>>> adata.obsm['clusters_results'] generated!

5. Topic Selection via Supervised Learning

# Run SAGE topic selection with 30 NMF topics; per the logged output, the topics are
# then filtered by Moran's I and random-forest importances.
n_topics = 30
SAGE.preprocess.topics_selection(adata, n_topics=n_topics)
>>> Start the genes_mri calculation ...
>>> Finish genes_mri calculation! Time taken: 8.41 seconds
>>> adata.var['genes_mri'] generated!
>>> Step1: Starting NMF calculation...
>>> NMF completed. 30 topics identified. Time taken: 32.20 seconds
>>> Step2: Filtering based on Topics Moran's I values...
>>> Stpe3: Filtering based on Random Forest importances...
>>> 17 topics selected !
>>> adata.obsm['W_nmf'], adata.varm['H_nmf'] generate!
>>> adata.uns['topic_mri_list'] generate!
>>> adata.uns['sorted_indices_imp'] generate!
>>> adata.uns['final_topics'] generate!

6. Plot the topics selected by the random forest (RF)

# Spatial maps of the topics stored under adata.uns["final_topics"], five per row.
topic_plot_kwargs = dict(
    uns_key="final_topics",
    img_key=None,
    spot_size=180,
    ncols=5,
    figsize=(4, 4),
    fontsize=10,
    frameon=False,
    legend_loc=False,
    colorbar_loc=None,
    show=True,
)
SAGE.plot.plot_topics(adata, **topic_plot_kwargs)

7. Construction of high-specificity genes (HSGs)

# Construct the high-specificity gene (HSG) set; stored in adata.uns['HSG'] per the log.
# NOTE(review): presumably n_genes should match the 3000 HVGs selected earlier — confirm.
n_genes = 3000
SAGE.preprocess.genes_selection(adata, n_genes=n_genes)
percentile_genes
    Topic gene_lower  rank_lower gene_upper  rank_upper
0       0      WASF1         113      DCTN6        2927
1       1      RPL13          45   TMEM179B        2852
2       2     DNAJC6         123      SNTA1        2924
3       3  MTRNR2L12          17    C3orf18        2657
4       4      RPS3A         116      CHPF2        2921
5       5     COX6A1          33       NRAS        2731
6       6       NEFL          78      TRPC1        2898
7       7      ACOT7         152       VAT1        2928
8       8      RPLP2          74       PBLD        2897
9       9       EEF2          67       H1F0        2908
10     10      COX5B          47      MEF2A        2896
11     11     PODXL2         138       KLF3        2925
12     12   SLC9A3R1          65     CC2D1A        2893
13     13     TMSB4X          95     ZNF593        2908
number_svgs:707
number_no_svgs:1289
>>> HSGs selected!
>>> adata.uns['HSG'] generate!
>>> adata.varm['gene_topic_corr'] generate! (method='pearson')
>>> adata.uns['marker_genes_all_dict'] generate! (corr_threshold=0.2)

Clustering

1. Consensus-driven Graph Construction

# Build the coordinate graph and the feature graph; per the log these are written to
# adata.obsm['adj_opt'] and adata.obsm['adj_feat'].
graph_kwargs = {
    "n_neig_coord": 6,
    "cut_side_thr": 0.3,
    "min_neighbor": 2,
    "n_neig_feat": 15,
}
SAGE.preprocess.optimize_graph_topology(adata, **graph_kwargs)
>>> Graph_coord constructed!
>>> adata.obsm['adj_opt'] generate!
>>> Graph_feat constructed!
>>> adata.obsm['adj_feat'] generate!

2. Plot the SAG edges

# Visualize the graph edges kept/cut by optimize_graph_topology on the tissue layout.
neighbor_plot_kwargs = {
    "img_key": None,
    "spot_size": 180,
    "figsize": (5, 5),
    "frameon": False,
    "legend_loc": None,
    "colorbar_loc": None,
    "show": True,
}
SAGE.plot.plot_neighbors_cut(adata, **neighbor_plot_kwargs)

3. Run SAGE

# Define the SAGE model on the prepared AnnData and train it for 800 epochs.
# train() returns the AnnData; later cells read adata.obsm["emb_latent_att"] from it.
n_epochs = 800
model = SAGE.SAGE(adata, device=device, epochs=n_epochs)
adata = model.train()
Optimization finished for ST data

4. Domain clustering

# One domain per annotated layer. nunique() ignores NaN, so unannotated spots can
# never be counted as an extra "layer" (len(unique()) would count NaN).
n_clusters = adata.obs["ground_truth"].nunique()
SAGE.utils.clustering(adata,
                      data=adata.obsm["emb_latent_att"],
                      method='mclust',
                      n_clusters=n_clusters,
                      # NOTE(review): res is presumably only used by resolution-based
                      # methods (e.g. leiden), not by mclust — confirm.
                      res=0.3,
                      radius=30,
                      refinement=True)
fitting ...
  |======================================================================| 100%
>>> Clustering completed using mclust with 7 clusters.
>>> adata.obsm['domain'] & ['mclust'] generate!

5. Save results

# Persist the AnnData. Note: the file is written in H5AD format despite the ".h5"
# name; it is read back with sc.read in the plotting step.
adata.write_h5ad(f"{output_result_dir}/result.h5")
# Domain labels, one "index,label" row per spot, without a header line.
clusters = adata.obs["domain"]
clusters.to_csv(f"{output_result_dir}/clusters.csv", header=False)
# Latent embedding as a space-delimited text matrix (np.savetxt default).
embedding = adata.obsm["emb_latent_att"]
np.savetxt(f"{output_result_dir}/embedding.txt", embedding)

# SAGE exports (H z-scores and correlations, as the function names indicate).
SAGE.utils.export_H_zscore_to_csv(adata, out_dir=output_process_dir)
SAGE.utils.export_Corr_to_csv(adata, out_dir=output_process_dir)

6. Plot clustering results

# Reload the saved result (H5AD content, ".h5" file name).
adata = sc.read(f"{output_result_dir}/result.h5")

# Palette shared by the predicted domains ('0'-'7') and the annotated layers; the colours
# appear chosen so that domains reuse the colour of a matching layer.
# NOTE(review): '0' and '3' share '#F4E8B8' — harmless only if label '0' never occurs
# (mclust labels may start at 1); confirm.
domain_colors = {
    '0': '#F4E8B8', '1': '#058187', '2': '#632621', '3': '#F4E8B8',
    '4': '#B44541', '5': '#0C4DA7', '6': '#EEC30E', '7': '#8F0100',
}
layer_colors = {
    'Layer1': '#EEC30E', 'Layer2': '#0C4DA7', 'Layer3': '#F4E8B8', 'Layer4': '#B44541',
    'Layer5': '#632621', 'Layer6': '#058187', 'WM': '#8F0100',
}
combine_palette = {**domain_colors, **layer_colors}
plt.rcParams["figure.figsize"] = (4, 4)

# Agreement between the manual annotation and the predicted domains (NMI).
labels_true, labels_pred = adata.obs["ground_truth"], adata.obs["domain"]
NMI = metrics.normalized_mutual_info_score(labels_true, labels_pred)
sc.pl.spatial(
    adata,
    img_key=None,
    spot_size=180,
    color=["ground_truth", "domain"],
    title=["Manual annotation", f'151673 (NMI={NMI:.3f})'],
    palette=combine_palette,
    na_in_legend=False,
    frameon=False,
    ncols=2,
    size=1,
    show=False,
)

7. Plot UMAP and PAGA graph

# Reload the embedding written by np.savetxt (space-delimited, no header).
# `delim_whitespace=True` is deprecated since pandas 2.2; sep=r"\s+" is the replacement.
SAGE_embed = pd.read_csv(output_result_dir + "/embedding.txt", header=None, sep=r"\s+")
labels_true = adata.obs["ground_truth"].copy()
# The rows were saved from this same AnnData, so they follow its spot order.
SAGE_embed.index = labels_true.index
adata.obsm["SAGE"] = SAGE_embed

# kNN graph on the SAGE embedding, then UMAP and a PAGA graph over the annotated layers.
plt.rcParams["figure.figsize"] = (6, 5)
sc.pp.neighbors(adata, use_rep="SAGE")
sc.tl.umap(adata)
sc.tl.paga(adata, groups='ground_truth')

# Side-by-side figure: UMAP of the SAGE embedding (left) and the PAGA graph (right).
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

sc.pl.umap(adata,
           color='ground_truth',
           palette=combine_palette,
           legend_loc='on data',
           legend_fontoutline=5,
           add_outline=True, s=150,
           outline_width=(0.8, 0.05),
           legend_fontsize=25,
           frameon=False,
           ax=axes[0],
           show=False)

axes[0].set_title('UMAP', fontsize=25)

umap_coords = pd.DataFrame(adata.obsm['X_umap'],
                           index=adata.obs.index,
                           columns=['UMAP1', 'UMAP2'])

# Place each PAGA node at the mean UMAP position of its layer.
clusters = adata.obs['ground_truth'].astype(str)
# PAGA nodes follow the category order of adata.obs['ground_truth'], whereas groupby
# orders rows by sorted key. Reindex explicitly so row i of `pos` is node i's centre
# rather than relying on the two orders happening to coincide.
node_order = pd.Categorical(adata.obs['ground_truth']).categories.astype(str)
cluster_centers = umap_coords.groupby(clusters).mean().reindex(node_order)

sc.pl.paga(adata,
           color='ground_truth',
           pos=cluster_centers.values,
           node_size_scale=30,
           fontoutline=5,
           frameon=False,
           edge_width_scale=3,
           fontsize=25,
           fontweight='bold',
           ax=axes[1],
           show=False)

axes[1].set_title('PAGA', fontsize=25)