Tutorial 4: Label transfer for DLPFC dataset - cmzuo11/stClinic GitHub Wiki
Here, we demonstrate label transfer: knowledge from an stClinic model trained on three slices (151673-151675) is transferred to slice 151676 of the human DLPFC dataset. The processed data can be accessed at https://github.com/cmzuo11/stClinic/tree/main/Datasets/DLPFC, where each slice, together with its cluster annotations, is stored in its own folder.
Preparation
import os
import anndata
import scanpy as sc
import random
import torch
import numpy as np
import pandas as pd
import warnings
import matplotlib.pyplot as plt
from pathlib import Path
import stClinic as stClinic
# Suppress all Python warnings for the rest of the session.
warnings.filterwarnings("ignore")

# Set seed: fix every random number generator used below so runs are repeatable.
seed = 666
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Set parameters: run on the first GPU when CUDA is available, otherwise on the CPU.
used_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
parser = stClinic.parameter_setting()
args = parser.parse_args()

# Build the output directory with os.path.join so the path is correct whether
# or not args.input_dir ends with a path separator. The trailing '' component
# keeps the trailing separator that the previous string concatenation
# produced, in case downstream code appends file names to args.out_dir.
args.out_dir = os.path.join(args.input_dir, 'stClinic', '')
Path(args.out_dir).mkdir(parents=True, exist_ok=True)
Load data
# Slice IDs grouped by role; each group is loaded, preprocessed and
# concatenated into one AnnData that is appended to data_list.
section_ids_dict = {'Reference': ['151673', '151674', '151675'], 'Query': ['151676']}
data_list = []

# NOTE(review): the stClinic helpers below are called through the package
# namespace because only `import stClinic as stClinic` is in scope; this
# assumes they are exported at package top level like parameter_setting().
for k, v in section_ids_dict.items():
    print(f'Loading {k} Datasets...')
    Batch_list, adj_list = [], []

    for section_id in v:
        print(f'Slice ID: {section_id}')
        input_dir = os.path.join(args.input_dir, section_id)
        adata = sc.read_visium(path=input_dir, count_file='filtered_feature_bc_matrix.h5', load_images=True)
        adata.var_names_make_unique(join="++")

        # Read the annotation; spots without a layer label are marked "unknown".
        Ann_df = pd.read_csv(os.path.join(input_dir, section_id + '_annotation.txt'), sep='\t', header=0, index_col=0)
        Ann_df.loc[Ann_df['Layer'].isna(), 'Layer'] = "unknown"
        adata.obs['Ground Truth'] = Ann_df.loc[adata.obs_names, 'Layer'].astype('category')
        adata.obs['Identity'] = k

        # Suffix barcodes with the slice ID so names stay unique after concatenation.
        adata.obs_names = [f'{x}_{section_id}' for x in adata.obs_names]

        # Constructing the spatial network.
        # Presumably this stores the adjacency in adata.uns['adj'], which is
        # read at the end of this loop body -- confirm against stClinic.
        stClinic.Cal_Spatial_Net(adata, rad_cutoff=150)

        # Normalization. The "seurat_v3" flavor expects raw counts, so the
        # highly variable genes are selected before normalize_total/log1p.
        sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=5000)
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)

        # Take an explicit copy so PCA writes into a real object, not a view.
        adata = adata[:, adata.var['highly_variable']].copy()
        sc.tl.pca(adata, n_comps=10, random_state=seed)

        Batch_list.append(adata)
        adj_list.append(adata.uns['adj'])

    # Concat scanpy objects; the originating slice ID is recorded in obs['slice_name'].
    adata_concat = anndata.concat(Batch_list, label="slice_name", keys=v)

    # Re-cast the labels as categoricals and expose the slice ID as 'batch_name'.
    adata_concat.obs['Ground Truth'] = adata_concat.obs['Ground Truth'].astype('category')
    adata_concat.obs['batch_name'] = adata_concat.obs['slice_name'].astype('category')

    # Combine the per-slice graphs into one adjacency for the whole group and
    # keep the indices of its non-zero entries as the edge list.
    adj_concat = stClinic.inter_linked_graph(adj_list, v, mnn_dict=None)
    adata_concat.uns['adj'], adata_concat.uns['edgeList'] = adj_concat, np.nonzero(adj_concat)

    data_list.append(adata_concat)
Pretraining and fine-tuning
# Train without a pretrained model, passing both the reference and query groups.
# NOTE(review): the stClinic helpers are called through the package namespace
# because only `import stClinic as stClinic` is in scope; this assumes they are
# exported at package top level like parameter_setting(). The returned AnnData
# presumably holds both groups with the embedding stored under 'stClinic' -- confirm.
adata_ref_map = stClinic.VGAEX_zero_shot(data_list, pretrained_model=False, params_dir=args.out_dir, device=used_device)

# Drop spots that have no ground-truth layer annotation.
adata_ref_map = adata_ref_map[adata_ref_map.obs['Ground Truth'] != 'unknown']

# Split into reference and query spots. Both subsets are copied so that the
# obs['labels'] assignments below write to real objects rather than to views.
is_reference = adata_ref_map.obs['Identity'] == 'Reference'
is_query = adata_ref_map.obs['Identity'] == 'Query'
adata_refer = adata_ref_map[is_reference, :].copy()
adata_query = adata_ref_map[is_query, :].copy()

# Label transfer: query spots receive labels matched from the reference,
# while reference spots simply keep their ground-truth labels.
adata_query.obs['labels'] = stClinic.spatial_match(adata_refer, adata_query, label='Ground Truth', use_rep='stClinic')
adata_refer.obs['labels'] = adata_refer.obs['Ground Truth']
Accuracy
import seaborn as sns
from matplotlib.pyplot import rc_context

# Light-to-dark purple colormap for the heatmap cells.
cmap = sns.light_palette("purple", as_cmap=True)

# Cross-tabulate transferred labels (rows) against ground truth (columns);
# normalizing over columns makes every ground-truth column sum to 1.
label_vs_truth = pd.crosstab(
    adata_query.obs['labels'],
    adata_query.obs['Ground Truth'],
    normalize='columns',
)

# Draw the annotated heatmap with a temporary 9x6 figure size.
with rc_context({'figure.figsize': (9, 6)}):
    sns.heatmap(label_vs_truth, annot=True, fmt='.2f', cmap=cmap)
Visualization
# Global plot settings for the UMAP figure: 3x3 panels, 10 pt font.
plt.rcParams.update({"figure.figsize": (3, 3), 'font.size': 10})

# Stack reference and query spots into one object (outer join on variables).
adata_stClinic = anndata.concat(
    [adata_refer, adata_query],
    join='outer',
)

# Build the neighbor graph on the 'stClinic' representation, then embed with UMAP.
sc.pp.neighbors(adata_stClinic, use_rep='stClinic', random_state=seed)
sc.tl.umap(adata_stClinic, random_state=seed)

# Color the embedding by batch_name; show=False keeps scanpy from calling show itself.
sc.pl.umap(
    adata_stClinic,
    color=['batch_name'],
    wspace=0.5,
    show=False,
)