Tutorial 4: Label transfer for DLPFC dataset - cmzuo11/stClinic GitHub Wiki

Here, we demonstrate how to transfer labels from an stClinic model trained on three slices of the human DLPFC dataset (151673–151675) to a fourth slice (151676). The processed data can be accessed at https://github.com/cmzuo11/stClinic/tree/main/Datasets/DLPFC, where each slice, together with its annotations, is stored in its own folder.

Preparation

import os
import anndata
import scanpy as sc
import random
import torch
import numpy as np
import pandas as pd
import warnings
import matplotlib.pyplot as plt

from pathlib import Path
import stClinic as stClinic

warnings.filterwarnings("ignore")

# Set seeds so NumPy, Python's `random`, and PyTorch (CPU and CUDA) are reproducible.
seed = 666
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Set parameters: run on the first GPU when available, otherwise on the CPU.
used_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
parser = stClinic.parameter_setting()
args = parser.parse_args()
# Build the output directory with os.path.join so the path is correct whether
# or not `input_dir` ends with a separator. The trailing '' keeps a trailing
# separator, preserving the original 'stClinic/' form in case downstream code
# appends file names to `out_dir` by plain string concatenation (assumed — verify).
args.out_dir = os.path.join(args.input_dir, 'stClinic', '')
Path(args.out_dir).mkdir(parents=True, exist_ok=True)

Load data

# Reference slices are used for training; the query slice receives transferred labels.
section_ids_dict = {'Reference': ['151673', '151674', '151675'], 'Query': ['151676']}
data_list = []  # one concatenated AnnData per group, in dict order (Reference, then Query)
for (k, v) in section_ids_dict.items():
    print(f'Loading {k} Datasets...')
    Batch_list, adj_list = [], []
    for section_id in v:
        print(f'Slice ID: {section_id}')
        input_dir = os.path.join(args.input_dir, section_id)
        adata = sc.read_visium(path=input_dir, count_file='filtered_feature_bc_matrix.h5', load_images=True)
        adata.var_names_make_unique(join="++")
        # Read the layer annotation; spots with a missing layer are labelled "unknown".
        Ann_df = pd.read_csv(os.path.join(input_dir, section_id + '_annotation.txt'), sep='\t', header=0, index_col=0)
        Ann_df.loc[Ann_df['Layer'].isna(), 'Layer'] = "unknown"
        adata.obs['Ground Truth'] = Ann_df.loc[adata.obs_names, 'Layer'].astype('category')
        adata.obs['Identity'] = k
        # Suffix barcodes with the slice ID so they stay unique after concatenation.
        adata.obs_names = [x + '_' + section_id for x in adata.obs_names]
        # Construct the spatial network (radius 150); presumably stored in
        # adata.uns['adj'], which is read below.
        Cal_Spatial_Net(adata, rad_cutoff=150)
        # HVG selection with flavor="seurat_v3" runs on raw counts, so it comes
        # before normalization and log-transform.
        sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=5000)
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        # Copy explicitly rather than keeping a view that PCA would implicitly copy.
        adata = adata[:, adata.var['highly_variable']].copy()
        sc.tl.pca(adata, n_comps=10, random_state=seed)
        Batch_list.append(adata)
        adj_list.append(adata.uns['adj'])

    # Concatenate once per group, after all of its slices have been loaded.
    adata_concat = anndata.concat(Batch_list, label="slice_name", keys=v)
    # Re-cast to categorical after concatenation; expose the slice label as 'batch_name'.
    adata_concat.obs['Ground Truth'] = adata_concat.obs['Ground Truth'].astype('category')
    adata_concat.obs['batch_name'] = adata_concat.obs['slice_name'].astype('category')
    # Combine the per-slice adjacency matrices into one graph for the group.
    adj_concat = inter_linked_graph(adj_list, v, mnn_dict=None)
    adata_concat.uns['adj'], adata_concat.uns['edgeList'] = adj_concat, np.nonzero(adj_concat)
    data_list.append(adata_concat)

Pretraining and fine-tuning

# Train stClinic on the reference group and map the query group (pretrained_model=False).
adata_ref_map = VGAEX_zero_shot(data_list, pretrained_model=False, params_dir=args.out_dir, device=used_device)
# Drop spots without a layer annotation before transferring labels.
adata_ref_map = adata_ref_map[adata_ref_map.obs['Ground Truth'] != 'unknown'].copy()
# Split into reference and query. Copy so the new 'labels' columns are written
# to independent objects rather than views.
is_reference = adata_ref_map.obs['Identity'] == 'Reference'
is_query = adata_ref_map.obs['Identity'] == 'Query'
adata_refer = adata_ref_map[is_reference, :].copy()
adata_query = adata_ref_map[is_query, :].copy()
# Label transfer: query labels are predicted from the reference using the
# 'stClinic' representation; reference spots keep their ground-truth layer.
adata_query.obs['labels'] = spatial_match(adata_refer, adata_query, label='Ground Truth', use_rep='stClinic')
adata_refer.obs['labels'] = adata_refer.obs['Ground Truth']

Accuracy

import seaborn as sns

from matplotlib.pyplot import rc_context

# Light-to-purple sequential colormap for the heatmap.
cmap = sns.light_palette("purple", as_cmap=True)

# Rows: transferred labels; columns: ground-truth layers. Each column is
# normalized to sum to 1, i.e. the fraction of each true layer given each label.
label_vs_truth = pd.crosstab(
    adata_query.obs['labels'],
    adata_query.obs['Ground Truth'],
    normalize='columns',
)

with rc_context({'figure.figsize': (9, 6)}):
    sns.heatmap(label_vs_truth, annot=True, fmt='.2f', cmap=cmap)

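[Figure: heatmap of the column-normalized crosstab between transferred labels and ground-truth layers for query slice 151676]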

Visualization

# Small figures with 10-pt text for the UMAP panel.
plt.rcParams.update({"figure.figsize": (3, 3), "font.size": 10})

# Merge reference and query spots into one object (outer join on genes).
adata_stClinic = anndata.concat([adata_refer, adata_query], join='outer')

# Build the neighbour graph and UMAP from the shared stClinic embedding,
# then color the spots by the slice they came from.
sc.pp.neighbors(adata_stClinic, use_rep='stClinic', random_state=seed)
sc.tl.umap(adata_stClinic, random_state=seed)
sc.pl.umap(adata_stClinic, color=['batch_name'], wspace=0.5, show=False)

[Figure: UMAP of the stClinic embedding for reference and query spots, colored by slice (batch_name)]