Integration, label transfer and multi-scale analysis with scPoli¶
In this notebook we demonstrate an example workflow of data integration, reference mapping, label transfer, and multi-scale analysis of sample and cell embeddings with scPoli. We integrate the pancreas dataset from the scArches reproducibility repository; the data can be downloaded from figshare.
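If the file is not available locally, it can be fetched on first read via scanpy's backup_url argument. The snippet below is only a sketch: the placeholder has to be replaced with the direct figshare download link, and the local path adjusted to your setup.
import scanpy as sc

# Illustrative only: replace the placeholder with the direct figshare
# download link for pancreas.h5ad before running.
FIGSHARE_URL = "<direct figshare download link for pancreas.h5ad>"
adata = sc.read(
    "data/pancreas.h5ad",     # local cache path
    backup_url=FIGSHARE_URL,  # used only if the local file is missing
)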
[1]:
import os
import torch
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report
from sklearn.metrics.pairwise import cosine_similarity
from scarches.dataset.trvae.data_handling import remove_sparsity
from scarches.models.scpoli import scPoli
import warnings
warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2
WARNING:root:In order to use the mouse gastrulation seqFISH datsets, please install squidpy (see https://github.com/scverse/squidpy).
WARNING:root:In order to use sagenet models, please install pytorch geometric (see https://pytorch-geometric.readthedocs.io) and
captum (see https://github.com/pytorch/captum).
INFO:pytorch_lightning.utilities.seed:Global seed set to 0
/home/icb/carlo.dedonno/anaconda3/envs/scarches/lib/python3.10/site-packages/pytorch_lightning/utilities/warnings.py:53: LightningDeprecationWarning: pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6 and will be removed in v1.8. Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead.
new_rank_zero_deprecation(
/home/icb/carlo.dedonno/anaconda3/envs/scarches/lib/python3.10/site-packages/pytorch_lightning/utilities/warnings.py:58: LightningDeprecationWarning: The `pytorch_lightning.loggers.base.rank_zero_experiment` is deprecated in v1.7 and will be removed in v1.9. Please use `pytorch_lightning.loggers.logger.rank_zero_experiment` instead.
return new_rank_zero_deprecation(*args, **kwargs)
WARNING:root:mvTCR is not installed. To use mvTCR models, please install it first using "pip install mvtcr"
WARNING:root:multigrate is not installed. To use multigrate models, please install it first using "pip install multigrate".
[2]:
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
plt.rcParams['figure.dpi'] = 200
plt.rcParams['figure.figsize'] = (4, 4)
[3]:
adata = sc.read('../../../lataq_reproduce/data/pancreas.h5ad')
adata
[3]:
AnnData object with n_obs × n_vars = 16382 × 4000
obs: 'study', 'cell_type', 'pred_label', 'pred_score'
obsm: 'X_seurat', 'X_symphony'
[4]:
sc.pp.neighbors(adata)
sc.tl.umap(adata)
WARNING: You’re trying to run this on 4000 dimensions of `.X`, if you really want this, set `use_rep='X'`.
Falling back to preprocessing with `sc.pp.pca` and default params.
[5]:
sc.pl.umap(adata, color=['study', 'cell_type'], wspace=0.5)

[6]:
early_stopping_kwargs = {
"early_stopping_metric": "val_prototype_loss",
"mode": "min",
"threshold": 0,
"patience": 20,
"reduce_lr": True,
"lr_patience": 13,
"lr_factor": 0.1,
}
condition_key = 'study'
cell_type_key = ['cell_type']
reference = [
'inDrop1',
'inDrop2',
'inDrop3',
'inDrop4',
'fluidigmc1',
'smartseq2',
'smarter'
]
query = ['celseq', 'celseq2']
Reference - query split¶
We split our data into a group of reference datasets used for reference building and a group of query datasets that we will map onto the reference.
To simulate an unknown cell type scenario, we manually remove alpha cells from the reference.
[7]:
adata.obs['query'] = adata.obs[condition_key].isin(query)
adata.obs['query'] = adata.obs['query'].astype('category')
source_adata = adata[adata.obs.study.isin(reference)].copy()
source_adata = source_adata[~source_adata.obs.cell_type.str.contains('alpha')].copy()
target_adata = adata[adata.obs.study.isin(query)].copy()
[8]:
source_adata, target_adata
[8]:
(AnnData object with n_obs × n_vars = 8634 × 4000
obs: 'study', 'cell_type', 'pred_label', 'pred_score', 'query'
uns: 'neighbors', 'umap', 'study_colors', 'cell_type_colors'
obsm: 'X_seurat', 'X_symphony', 'X_pca', 'X_umap'
obsp: 'distances', 'connectivities',
AnnData object with n_obs × n_vars = 3289 × 4000
obs: 'study', 'cell_type', 'pred_label', 'pred_score', 'query'
uns: 'neighbors', 'umap', 'study_colors', 'cell_type_colors'
obsm: 'X_seurat', 'X_symphony', 'X_pca', 'X_umap'
obsp: 'distances', 'connectivities')
Train reference scPoli model on fully labeled reference data¶
[9]:
scpoli_model = scPoli(
adata=source_adata,
condition_key=condition_key,
cell_type_keys=cell_type_key,
embedding_dim=3,
)
Embedding dictionary:
Num conditions: 7
Embedding dim: 3
Encoder Architecture:
Input Layer in, out and cond: 4000 256 3
Hidden Layer 1 in/out: 256 64
Mean/Var Layer in/out: 64 10
Decoder Architecture:
First Layer in, out and cond: 10 64 3
Hidden Layer 1 in/out: 64 256
Output Layer in/out: 256 4000
We recommend using a pretraining/training epoch ratio of approximately 80–90%. If you train for more total epochs, you should use a higher ratio; if you train for only a few epochs, the ratio can be smaller. Training with the prototype loss for too many epochs can lead to overly concentrated clusters in the latent space.
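As a concrete illustration of this rule of thumb (a minimal sketch; the 80% ratio matches the training call below):
# Sketch: derive pretraining_epochs as a fixed fraction of the epoch budget.
n_epochs = 50
pretraining_ratio = 0.8                                  # ~80-90% recommended
pretraining_epochs = int(n_epochs * pretraining_ratio)   # -> 40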
[10]:
scpoli_model.train(
n_epochs=50,
pretraining_epochs=40,
early_stopping_kwargs=early_stopping_kwargs,
eta=5,
)
|████████████████████| 100.0% - val_loss: 1084.3780866350 - val_trvae_loss: 1076.3039550781 - val_prototype_loss: 8.0741387095 - val_labeled_loss: 1.6148277521
Saving best state of network...
Best State was in Epoch 48
Reference mapping of unlabeled query datasets¶
[11]:
scpoli_query = scPoli.load_query_data(
adata=target_adata,
reference_model=scpoli_model,
labeled_indices=[],
)
Embedding dictionary:
Num conditions: 9
Embedding dim: 3
Encoder Architecture:
Input Layer in, out and cond: 4000 256 3
Hidden Layer 1 in/out: 256 64
Mean/Var Layer in/out: 64 10
Decoder Architecture:
First Layer in, out and cond: 10 64 3
Hidden Layer 1 in/out: 64 256
Output Layer in/out: 256 4000
[12]:
scpoli_query.train(
n_epochs=50,
pretraining_epochs=40,
eta=5
)
Warning: Labels in adata.obs[cell_type] is not a subset of label-encoder!
Therefore integer value of those labels is set to -1
Warning: Labels in adata.obs[cell_type] is not a subset of label-encoder!
Therefore integer value of those labels is set to -1
|████████████████----| 80.0% - val_loss: 1849.8289388021 - val_trvae_loss: 1849.8289388021
Initializing unlabeled prototypes with Leiden-Clustering with an unknown number of clusters.
Leiden Clustering succesful. Found 18 clusters.
|████████████████████| 100.0% - val_loss: 1831.1854248047 - val_trvae_loss: 1831.1854248047 - val_prototype_loss: 0.0000000000 - val_unlabeled_loss: 0.0000000000
Saving best state of network...
Best State was in Epoch 49
Label transfer from reference to query¶
[13]:
results_dict = scpoli_query.classify(
target_adata.X,
target_adata.obs[condition_key],
)
Let’s check the label transfer performance we achieved.
[14]:
for i in range(len(cell_type_key)):
    preds = results_dict[cell_type_key[i]]["preds"]
    classification_df = pd.DataFrame(
        classification_report(
            y_true=target_adata.obs[cell_type_key[i]],
            y_pred=preds,
            output_dict=True,
        )
    ).transpose()
    print(classification_df)
precision recall f1-score support
acinar 0.941399 0.992032 0.966052 502.000000
activated_stellate 0.908333 1.000000 0.951965 109.000000
alpha 0.000000 0.000000 0.000000 1034.000000
beta 0.903177 0.985149 0.942384 606.000000
delta 0.772586 0.980237 0.864111 253.000000
ductal 0.985989 0.962393 0.974048 585.000000
endothelial 1.000000 1.000000 1.000000 26.000000
epsilon 0.007278 1.000000 0.014451 5.000000
gamma 0.375740 0.992188 0.545064 128.000000
macrophage 1.000000 0.937500 0.967742 16.000000
mast 0.833333 0.714286 0.769231 7.000000
quiescent_stellate 1.000000 0.615385 0.761905 13.000000
schwann 0.833333 1.000000 0.909091 5.000000
t_cell 0.000000 0.000000 0.000000 0.000000
accuracy 0.670721 0.670721 0.670721 0.670721
macro avg 0.682941 0.798512 0.690432 3289.000000
weighted avg 0.609399 0.670721 0.632230 3289.000000
[15]:
#get latent representation of reference data
scpoli_query.model.eval()
data_latent_source = scpoli_query.get_latent(
source_adata.X,
source_adata.obs[condition_key].values,
mean=True
)
adata_latent_source = sc.AnnData(data_latent_source)
adata_latent_source.obs = source_adata.obs.copy()
#get latent representation of query data
data_latent = scpoli_query.get_latent(
target_adata.X,
target_adata.obs[condition_key].values,
mean=True
)
adata_latent = sc.AnnData(data_latent)
adata_latent.obs = target_adata.obs.copy()
#get label annotations
adata_latent.obs['cell_type_pred'] = results_dict['cell_type']['preds'].tolist()
adata_latent.obs['cell_type_uncert'] = results_dict['cell_type']['uncert'].tolist()
adata_latent.obs['classifier_outcome'] = (
adata_latent.obs['cell_type_pred'] == adata_latent.obs['cell_type']
)
#get prototypes
labeled_prototypes = scpoli_query.get_prototypes_info()
labeled_prototypes.obs['study'] = 'labeled prototype'
unlabeled_prototypes = scpoli_query.get_prototypes_info(prototype_set='unlabeled')
unlabeled_prototypes.obs['study'] = 'unlabeled prototype'
#join adatas
adata_latent_full = adata_latent_source.concatenate(
[adata_latent, labeled_prototypes, unlabeled_prototypes],
batch_key='query'
)
adata_latent_full.obs.loc[adata_latent_full.obs['query'].isin(['0']), 'cell_type_pred'] = np.nan
sc.pp.neighbors(adata_latent_full, n_neighbors=15)
sc.tl.umap(adata_latent_full)
[16]:
#get adata without prototypes
adata_no_prototypes = adata_latent_full[adata_latent_full.obs['query'].isin(['0', '1'])].copy()
[17]:
sc.pl.umap(
adata_no_prototypes,
color='cell_type_pred',
show=False,
frameon=False,
)
[17]:
<AxesSubplot: title={'center': 'cell_type_pred'}, xlabel='UMAP1', ylabel='UMAP2'>

[18]:
sc.pl.umap(
adata_no_prototypes,
color='study',
show=False,
frameon=False,
)
[18]:
<AxesSubplot: title={'center': 'study'}, xlabel='UMAP1', ylabel='UMAP2'>

Inspect uncertainty¶
We can look at the uncertainty of each prediction and select a threshold either after visual inspection or by looking at percentiles of the uncertainty distribution.
[19]:
sc.pl.umap(
adata_no_prototypes,
color='cell_type_uncert',
show=False,
frameon=False,
cmap='magma',
vmax=1
)
[19]:
<AxesSubplot: title={'center': 'cell_type_uncert'}, xlabel='UMAP1', ylabel='UMAP2'>
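As a sketch of the percentile-based option, cells whose uncertainty exceeds a high percentile can be flagged as unknown. The 90th percentile and the cell_type_pred_filtered column name below are illustrative choices, reusing adata_latent and np from the cells above.
# Sketch: flag query cells whose prediction uncertainty exceeds the
# 90th percentile of the uncertainty distribution.
uncert = adata_latent.obs['cell_type_uncert'].astype(float)
threshold = np.percentile(uncert, 90)
adata_latent.obs['cell_type_pred_filtered'] = adata_latent.obs['cell_type_pred'].astype(str)
adata_latent.obs.loc[uncert > threshold, 'cell_type_pred_filtered'] = 'unknown'
adata_latent.obs['cell_type_pred_filtered'].value_counts()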

Inspect prototypes¶
[20]:
fig, ax = plt.subplots(1, 1, figsize=(6, 5))
adata_labeled_prototypes = adata_latent_full[adata_latent_full.obs['query'].isin(['2'])].copy()
adata_unlabeled_prototypes = adata_latent_full[adata_latent_full.obs['query'].isin(['3'])].copy()
adata_labeled_prototypes.obs['cell_type_pred'] = adata_labeled_prototypes.obs['cell_type_pred'].astype('category')
adata_unlabeled_prototypes.obs['cell_type_pred'] = adata_unlabeled_prototypes.obs['cell_type_pred'].astype('category')
adata_unlabeled_prototypes.obs['cell_type'] = adata_unlabeled_prototypes.obs['cell_type'].astype('category')
sc.pl.umap(
adata_no_prototypes,
alpha=0.2,
show=False,
ax=ax
)
ax.legend([])
# plot labeled prototypes
sc.pl.umap(
adata_labeled_prototypes,
size=200,
color=f'{cell_type_key[0]}_pred',
ax=ax,
show=False,
frameon=False,
)
# plot unlabeled prototypes
sc.pl.umap(
adata_unlabeled_prototypes,
size=100,
color=[cell_type_key[0] + '_pred'],
palette=adata_labeled_prototypes.uns['cell_type_pred_colors'],
ax=ax,
show=False,
frameon=False,
alpha=0.5,
)
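# annotate unlabeled prototypes with their cell_type labels (size=0 shows text only)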
sc.pl.umap(
adata_unlabeled_prototypes,
size=0,
color=cell_type_key[0],
frameon=False,
ax=ax,
legend_loc='on data',
legend_fontsize=5,
)
ax.set_title('Prototypes')
fig.tight_layout()

After inspecting the prototypes, we can observe that unlabeled prototypes 4, 5, 7, 8, 11, and 13 fall into the region of high uncertainty. With this knowledge, we can add a new labeled prototype for the missing alpha cell type based on these unlabeled prototypes.
[21]:
scpoli_query.add_new_cell_type(
"alpha",
cell_type_key[0],
[4, 5, 7, 8, 11, 13]
)
[22]:
results_dict = scpoli_query.classify(
target_adata.X,
target_adata.obs[condition_key],
)
[23]:
#get latent representation of reference data
scpoli_query.model.eval()
data_latent_source = scpoli_query.get_latent(
source_adata.X,
source_adata.obs[condition_key].values,
mean=True
)
adata_latent_source = sc.AnnData(data_latent_source)
adata_latent_source.obs = source_adata.obs.copy()
#get latent representation of query data
data_latent = scpoli_query.get_latent(
target_adata.X,
target_adata.obs[condition_key].values,
mean=True
)
adata_latent = sc.AnnData(data_latent)
adata_latent.obs = target_adata.obs.copy()
#get label annotations
adata_latent.obs['cell_type_pred'] = results_dict['cell_type']['preds'].tolist()
adata_latent.obs['cell_type_uncert'] = results_dict['cell_type']['uncert'].tolist()
adata_latent.obs['classifier_outcome'] = (
adata_latent.obs['cell_type_pred'] == adata_latent.obs['cell_type']
)
#join adatas
adata_latent_full = adata_latent_source.concatenate(
[adata_latent, labeled_prototypes, unlabeled_prototypes],
batch_key='query'
)
adata_latent_full.obs.loc[adata_latent_full.obs['query'].isin(['0']), 'cell_type_pred'] = np.nan
sc.pp.neighbors(adata_latent_full, n_neighbors=15)
sc.tl.umap(adata_latent_full)
[24]:
sc.pl.umap(
adata_latent_full,
color='cell_type_pred',
show=False,
frameon=False,
)
[24]:
<AxesSubplot: title={'center': 'cell_type_pred'}, xlabel='UMAP1', ylabel='UMAP2'>

We can now see that the alpha cell cluster is correctly classified.
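To quantify the improvement beyond the UMAP, we can recompute the classification report from above on the updated predictions:
# Re-evaluate label transfer after adding the 'alpha' prototype.
preds = results_dict[cell_type_key[0]]["preds"]
classification_df = pd.DataFrame(
    classification_report(
        y_true=target_adata.obs[cell_type_key[0]],
        y_pred=preds,
        output_dict=True,
    )
).transpose()
print(classification_df)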
Sample embeddings¶
We can extract the conditional embeddings learnt by scPoli and analyse them.
[25]:
adata_emb = scpoli_query.get_conditional_embeddings()
[28]:
from sklearn.decomposition import KernelPCA
pca = KernelPCA(n_components=2, kernel='linear')
emb_pca = pca.fit_transform(adata_emb.X)
conditions = scpoli_query.conditions_
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
sns.scatterplot(x=emb_pca[:, 0], y=emb_pca[:, 1], hue=conditions, ax=ax)
ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
for i, c in enumerate(conditions):
    ax.plot([0, emb_pca[i, 0]], [0, emb_pca[i, 1]])
    ax.text(emb_pca[i, 0], emb_pca[i, 1], c)
sns.despine()
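The cosine_similarity import from the first cell can be used here as well: comparing the sample embeddings pairwise gives a direct similarity measure between studies. This is a minimal sketch on top of scPoli's output, reusing adata_emb and conditions from the cells above.
# Pairwise cosine similarity between the learned condition (sample) embeddings.
sim = cosine_similarity(adata_emb.X)
sim_df = pd.DataFrame(sim, index=conditions, columns=conditions)

fig, ax = plt.subplots(1, 1, figsize=(5, 4))
sns.heatmap(sim_df, cmap='viridis', annot=True, fmt='.2f', ax=ax)
ax.set_title('Cosine similarity of sample embeddings')
fig.tight_layout()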
