Integration, label transfer and multi-scale analysis with scPoli

In this notebook we demonstrate an example workflow of data integration, reference mapping, label transfer and multi-scale analysis of sample and cell embeddings using scPoli. We integrate pancreas data obtained from the scArches reproducibility repository. The data can be downloaded from figshare.

[1]:
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report
from scarches.models.scpoli import scPoli

import warnings
warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2
WARNING:root:In order to use the mouse gastrulation seqFISH datsets, please install squidpy (see https://github.com/scverse/squidpy).
INFO:lightning_fabric.utilities.seed:Global seed set to 0
WARNING:root:In order to use sagenet models, please install pytorch geometric (see https://pytorch-geometric.readthedocs.io) and
 captum (see https://github.com/pytorch/captum).
WARNING:root:mvTCR is not installed. To use mvTCR models, please install it first using "pip install mvtcr"
WARNING:root:multigrate is not installed. To use multigrate models, please install it first using "pip install multigrate".
[2]:
sc.set_figure_params(dpi=100, frameon=False, figsize=(3, 3))
plt.rcParams['figure.dpi'] = 100
plt.rcParams['figure.figsize'] = (3, 3)

Data download

[3]:
!mkdir tmp
!wget -O tmp/pancreas.h5ad https://figshare.com/ndownloader/files/41581626
mkdir: tmp: File exists
--2023-07-17 12:47:23--  https://figshare.com/ndownloader/files/41581626
Resolving figshare.com (figshare.com)... 34.250.148.102, 34.242.105.80
Connecting to figshare.com (figshare.com)|34.250.148.102|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/41581626/pancreas_sparse.h5ad?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIYCQYOYV5JSSROOA/20230717/eu-west-1/s3/aws4_request&X-Amz-Date=20230717T104723Z&X-Amz-Expires=10&X-Amz-SignedHeaders=host&X-Amz-Signature=88c2fa94548ab7326e567ed762e1b9275fb98b77b8c56b3bbf838213c24f1db7 [following]
--2023-07-17 12:47:23--  https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/41581626/pancreas_sparse.h5ad?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIYCQYOYV5JSSROOA/20230717/eu-west-1/s3/aws4_request&X-Amz-Date=20230717T104723Z&X-Amz-Expires=10&X-Amz-SignedHeaders=host&X-Amz-Signature=88c2fa94548ab7326e567ed762e1b9275fb98b77b8c56b3bbf838213c24f1db7
Resolving s3-eu-west-1.amazonaws.com (s3-eu-west-1.amazonaws.com)... 52.218.122.88, 52.218.116.136, 52.92.37.0, ...
Connecting to s3-eu-west-1.amazonaws.com (s3-eu-west-1.amazonaws.com)|52.218.122.88|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 83620692 (80M) [application/octet-stream]
Saving to: ‘tmp/pancreas.h5ad’

tmp/pancreas.h5ad   100%[===================>]  79.75M  4.80MB/s    in 16s

2023-07-17 12:47:40 (4.86 MB/s) - ‘tmp/pancreas.h5ad’ saved [83620692/83620692]

[4]:
adata = sc.read('tmp/pancreas.h5ad')
adata
[4]:
AnnData object with n_obs × n_vars = 16382 × 4000
    obs: 'study', 'cell_type'
[5]:
sc.pp.neighbors(adata)
sc.tl.umap(adata)
WARNING: You’re trying to run this on 4000 dimensions of `.X`, if you really want this, set `use_rep='X'`.
         Falling back to preprocessing with `sc.pp.pca` and default params.
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
[6]:
sc.pl.umap(adata, color=['study', 'cell_type'], wspace=0.5, frameon=False)
_images/scpoli_surgery_pipeline_7_0.png
[7]:
early_stopping_kwargs = {
    "early_stopping_metric": "val_prototype_loss",
    "mode": "min",
    "threshold": 0,
    "patience": 20,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}

condition_key = 'study'
cell_type_key = 'cell_type'
reference = [
    'inDrop1',
    'inDrop2',
    'inDrop3',
    'inDrop4',
    'fluidigmc1',
    'smartseq2',
    'smarter'
]
query = ['celseq', 'celseq2']

Reference - query split

We split our data into a group of reference datasets used to build the reference, and a group of query datasets that we will map onto it.

In order to simulate an unknown cell type scenario, we manually remove alpha cells from the reference.

[8]:
adata.obs['query'] = adata.obs[condition_key].isin(query)
adata.obs['query'] = adata.obs['query'].astype('category')
source_adata = adata[adata.obs.study.isin(reference)].copy()
source_adata = source_adata[~source_adata.obs.cell_type.str.contains('alpha')].copy()
target_adata = adata[adata.obs.study.isin(query)].copy()
[9]:
source_adata, target_adata
[9]:
(AnnData object with n_obs × n_vars = 8634 × 4000
     obs: 'study', 'cell_type', 'query'
     uns: 'neighbors', 'umap', 'study_colors', 'cell_type_colors'
     obsm: 'X_pca', 'X_umap'
     obsp: 'distances', 'connectivities',
 AnnData object with n_obs × n_vars = 3289 × 4000
     obs: 'study', 'cell_type', 'query'
     uns: 'neighbors', 'umap', 'study_colors', 'cell_type_colors'
     obsm: 'X_pca', 'X_umap'
     obsp: 'distances', 'connectivities')

Train reference scPoli model on fully labeled reference data

Explanation of scPoli parameters:

Model:

- condition_keys: obs column names of the covariate(s) you want to use for integration. If a list of names is passed, the model will use independent embeddings for each covariate.
- cell_type_keys: obs column names of the cell type annotation(s) to use for prototype learning. If a list is passed, the model will compute the prototype loss in parallel for each set of annotations passed.
- embedding_dims: embedding dimensionality. If an integer is passed, the model will use embeddings of the same dimensionality for each covariate; if you wish to define different dimensionalities for each covariate, a list needs to be provided (see the sketch after this list).

Training:

- pretraining_epochs: number of epochs for which the model is trained in an unsupervised fashion.
- n_epochs: total number of training epochs; the number of fine-tuning epochs is therefore (n_epochs - pretraining_epochs).
- eta: weight of the prototype loss.
- prototype_training: flag that can be used to turn off prototype training.
- unlabeled_prototype_training: flag that can be used to skip unlabeled prototype computation. This step involves Louvain clustering and can be time-consuming on big datasets. Unlabeled prototypes can be useful for downstream analyses but are not used at training time.
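For illustration, here is a minimal sketch of a multi-covariate, multi-annotation setup. The obs columns 'tech' and 'cell_subtype' are hypothetical and do not exist in this pancreas dataset; only the keyword arguments themselves come from the parameter description above.

# Illustrative sketch only: 'tech' and 'cell_subtype' are hypothetical obs columns.
scpoli_multi = scPoli(
    adata=source_adata,
    condition_keys=['study', 'tech'],              # one embedding per covariate
    cell_type_keys=['cell_type', 'cell_subtype'],  # prototype loss per annotation level
    embedding_dims=[5, 3],                         # per-covariate embedding dimensionality
    recon_loss='nb',
)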

[10]:
scpoli_model = scPoli(
    adata=source_adata,
    condition_keys=condition_key,
    cell_type_keys=cell_type_key,
    embedding_dims=5,
    recon_loss='nb',
)
scpoli_model.train(
    n_epochs=50,
    pretraining_epochs=40,
    early_stopping_kwargs=early_stopping_kwargs,
    eta=5,
)
Embedding dictionary:
        Num conditions: [7]
        Embedding dim: [5]
Encoder Architecture:
        Input Layer in, out and cond: 4000 64 5
        Mean/Var Layer in/out: 64 10
Decoder Architecture:
        First Layer in, out and cond:  10 64 5
        Output Layer in/out:  64 4000

Initializing dataloaders
Starting training
 |████████████████████| 100.0%  - val_loss: 1088.52 - val_cvae_loss: 1076.07 - val_prototype_loss:   12.45 - val_labeled_loss:    2.49

We recommend using a pretraining-to-total-epochs ratio of approximately 80-90%. If you train for more total epochs you should use a higher ratio, whereas if you are training for only a few epochs, this ratio can be smaller. If the model is trained with the prototype loss for too many epochs, it can lead to very concentrated clusters in the latent space.
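For example, one simple way to follow this rule of thumb when setting up a training run (purely illustrative):

# Illustrative: choose pretraining epochs as roughly 90% of the total epoch budget.
n_epochs = 100
pretraining_epochs = int(0.9 * n_epochs)  # 90 unsupervised epochs, 10 fine-tuning epochs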

Reference mapping of unlabeled query datasets

[11]:
scpoli_query = scPoli.load_query_data(
    adata=target_adata,
    reference_model=scpoli_model,
    labeled_indices=[],
)
Embedding dictionary:
        Num conditions: [9]
        Embedding dim: [5]
Encoder Architecture:
        Input Layer in, out and cond: 4000 64 5
        Mean/Var Layer in/out: 64 10
Decoder Architecture:
        First Layer in, out and cond:  10 64 5
        Output Layer in/out:  64 4000

[12]:
scpoli_query.train(
    n_epochs=50,
    pretraining_epochs=40,
    eta=10
)
Warning: Labels in adata.obs[cell_type] is not a subset of label-encoder!
The missing labels are: {'alpha'}
Therefore integer value of those labels is set to -1
Warning: Labels in adata.obs[cell_type] is not a subset of label-encoder!
The missing labels are: {'alpha'}
Therefore integer value of those labels is set to -1
Initializing dataloaders
Starting training
 |████████████████----| 80.0%  - val_loss: 1768.56 - val_cvae_loss: 1768.56
Initializing unlabeled prototypes with Leiden with an unknown number of  clusters.
Clustering succesful. Found 20 clusters.
 |████████████████████| 100.0%  - val_loss: 1759.32 - val_cvae_loss: 1759.32 - val_prototype_loss:    0.00 - val_unlabeled_loss:    0.76

Label transfer from reference to query

The uncertainty returned by the model for each cell is the distance in latent space between that cell and the closest labeled prototype. This distance does not have an upper bound, and its scale can provide information on the heterogeneity of the dataset. We also offer the option to scale the uncertainties between 0 and 1.
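Conceptually, scaling the unbounded distances to the [0, 1] range can be done with a simple min-max transform over the query cells. The sketch below is an illustration of the idea, not necessarily the exact transform used internally:

# Illustrative min-max scaling of raw prototype distances to [0, 1].
def minmax_scale(uncert):
    uncert = np.asarray(uncert, dtype=float)
    return (uncert - uncert.min()) / (uncert.max() - uncert.min())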

[13]:
results_dict = scpoli_query.classify(target_adata, scale_uncertainties=True)

Let’s check the label transfer performance we achieved.

[14]:
preds = results_dict[cell_type_key]["preds"]
classification_df = pd.DataFrame(
    classification_report(
        y_true=target_adata.obs[cell_type_key],
        y_pred=preds,
        output_dict=True,
    )
).transpose()
print(classification_df)
                    precision    recall  f1-score      support
acinar               0.958416  0.964143  0.961271   502.000000
activated_stellate   0.908333  1.000000  0.951965   109.000000
alpha                0.000000  0.000000  0.000000  1034.000000
beta                 0.946288  0.988449  0.966909   606.000000
delta                0.759259  0.972332  0.852686   253.000000
ductal               0.959322  0.967521  0.963404   585.000000
endothelial          1.000000  1.000000  1.000000    26.000000
epsilon              0.018519  1.000000  0.036364     5.000000
gamma                0.163265  1.000000  0.280702   128.000000
macrophage           0.882353  0.937500  0.909091    16.000000
mast                 1.000000  0.714286  0.833333     7.000000
quiescent_stellate   1.000000  0.692308  0.818182    13.000000
schwann              1.000000  1.000000  1.000000     5.000000
t_cell               0.000000  0.000000  0.000000     0.000000
accuracy             0.667984  0.667984  0.667984     0.667984
macro avg            0.685411  0.802610  0.683850  3289.000000
weighted avg         0.605955  0.667984  0.623204  3289.000000
[15]:
#get latent representation of reference data
scpoli_query.model.eval()
data_latent_source = scpoli_query.get_latent(
    source_adata,
    mean=True
)

adata_latent_source = sc.AnnData(data_latent_source)
adata_latent_source.obs = source_adata.obs.copy()

#get latent representation of query data
data_latent = scpoli_query.get_latent(
    target_adata,
    mean=True
)

adata_latent = sc.AnnData(data_latent)
adata_latent.obs = target_adata.obs.copy()

#get label annotations
adata_latent.obs['cell_type_pred'] = results_dict['cell_type']['preds'].tolist()
adata_latent.obs['cell_type_uncert'] = results_dict['cell_type']['uncert'].tolist()
adata_latent.obs['classifier_outcome'] = (
    adata_latent.obs['cell_type_pred'] == adata_latent.obs['cell_type']
)

#get prototypes
labeled_prototypes = scpoli_query.get_prototypes_info()
labeled_prototypes.obs['study'] = 'labeled prototype'
unlabeled_prototypes = scpoli_query.get_prototypes_info(prototype_set='unlabeled')
unlabeled_prototypes.obs['study'] = 'unlabeled prototype'

#join adatas
adata_latent_full = adata_latent_source.concatenate(
    [adata_latent, labeled_prototypes, unlabeled_prototypes],
    batch_key='query'
)
adata_latent_full.obs.loc[adata_latent_full.obs['query'].isin(['0']), 'cell_type_pred'] = np.nan
sc.pp.neighbors(adata_latent_full, n_neighbors=15)
sc.tl.umap(adata_latent_full)
[16]:
#get adata without prototypes
adata_no_prototypes = adata_latent_full[adata_latent_full.obs['query'].isin(['0', '1'])]
[17]:
sc.pl.umap(
    adata_no_prototypes,
    color='cell_type_pred',
    show=False,
    frameon=False,
)
[17]:
<Axes: title={'center': 'cell_type_pred'}, xlabel='UMAP1', ylabel='UMAP2'>
_images/scpoli_surgery_pipeline_26_1.png
[18]:
sc.pl.umap(
    adata_no_prototypes,
    color='study',
    show=False,
    frameon=False,
)
[18]:
<Axes: title={'center': 'study'}, xlabel='UMAP1', ylabel='UMAP2'>
_images/scpoli_surgery_pipeline_27_1.png

Inspect uncertainty

We can look at the uncertainty of each prediction and select a threshold either after visual inspection or by looking at percentiles of the uncertainty distribution.
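For example, one possible (illustrative) way to flag low-confidence predictions is a percentile cutoff; the 90th percentile and the 'unknown' label are arbitrary choices for this sketch:

# Illustrative: mark predictions with uncertainty above the 90th percentile as 'unknown'.
threshold = np.percentile(adata_latent.obs['cell_type_uncert'], 90)
adata_latent.obs['cell_type_pred_filtered'] = adata_latent.obs['cell_type_pred'].astype(str)
adata_latent.obs.loc[
    adata_latent.obs['cell_type_uncert'] > threshold, 'cell_type_pred_filtered'
] = 'unknown'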

[19]:
sc.pl.umap(
    adata_no_prototypes,
    color='cell_type_uncert',
    show=False,
    frameon=False,
    cmap='magma',
    vmax=1
)
[19]:
<Axes: title={'center': 'cell_type_uncert'}, xlabel='UMAP1', ylabel='UMAP2'>
_images/scpoli_surgery_pipeline_29_1.png

Inspect prototypes

[20]:
fig, ax = plt.subplots(1, 1, figsize=(6, 5))
adata_labeled_prototypes = adata_latent_full[adata_latent_full.obs['query'].isin(['2'])].copy()
adata_unlabeled_prototypes = adata_latent_full[adata_latent_full.obs['query'].isin(['3'])].copy()
adata_labeled_prototypes.obs['cell_type_pred'] = adata_labeled_prototypes.obs['cell_type_pred'].astype('category')
adata_unlabeled_prototypes.obs['cell_type_pred'] = adata_unlabeled_prototypes.obs['cell_type_pred'].astype('category')
adata_unlabeled_prototypes.obs['cell_type'] = adata_unlabeled_prototypes.obs['cell_type'].astype('category')

sc.pl.umap(
    adata_no_prototypes,
    alpha=0.2,
    show=False,
    ax=ax
)
ax.legend([])
# plot labeled prototypes
sc.pl.umap(
    adata_labeled_prototypes,
    size=200,
    color=f'{cell_type_key}_pred',
    ax=ax,
    show=False,
    frameon=False,
)
cell_types = adata_labeled_prototypes.obs[f'{cell_type_key}_pred'].cat.categories
color_ct = adata_labeled_prototypes.uns[f'{cell_type_key}_pred_colors']
color_dict = dict(zip(cell_types, color_ct))
# plot unlabeled prototypes
sc.pl.umap(
    adata_unlabeled_prototypes,
    size=100,
    color=f'{cell_type_key}_pred',
    palette=color_dict,
    ax=ax,
    show=False,
    frameon=False,
    alpha=0.5,
)
sc.pl.umap(
    adata_unlabeled_prototypes,
    size=0,
    color=cell_type_key,
    #palette=color_dict,
    frameon=False,
    show=False,
    ax=ax,
    legend_loc='on data',
    legend_fontsize=5,
)
ax.set_title('Landmarks')
h, l = ax.get_legend_handles_labels()
ax.legend().remove()
ax.legend(handles=h[:13], labels=l[:13], frameon=False, bbox_to_anchor=(1, 1))
fig.tight_layout()
_images/scpoli_surgery_pipeline_31_0.png

After inspecting the prototypes, we observe that unlabeled prototypes 4, 5, 7, 8, 11 and 13 fall into the region of high uncertainty. With this knowledge, we can add a new labeled prototype by assigning the missing 'alpha' cell type to a set of unlabeled prototypes.

[21]:
scpoli_query.add_new_cell_type(
    "alpha",
    cell_type_key,
    [3, 4, 5, 12, 13, 15, 19]
)
[22]:
results_dict = scpoli_query.classify(target_adata)
[23]:
#get latent representation of reference data
scpoli_query.model.eval()
data_latent_source = scpoli_query.get_latent(
    source_adata,
    mean=True
)

adata_latent_source = sc.AnnData(data_latent_source)
adata_latent_source.obs = source_adata.obs.copy()

#get latent representation of query data
data_latent = scpoli_query.get_latent(
    target_adata,
    mean=True
)

adata_latent = sc.AnnData(data_latent)
adata_latent.obs = target_adata.obs.copy()

#get label annotations
adata_latent.obs['cell_type_pred'] = results_dict['cell_type']['preds'].tolist()
adata_latent.obs['cell_type_uncert'] = results_dict['cell_type']['uncert'].tolist()
adata_latent.obs['classifier_outcome'] = (
    adata_latent.obs['cell_type_pred'] == adata_latent.obs['cell_type']
)

#join adatas
adata_latent_full = adata_latent_source.concatenate(
    [adata_latent, labeled_prototypes, unlabeled_prototypes],
    batch_key='query'
)
adata_latent_full.obs.loc[adata_latent_full.obs['query'].isin(['0']), 'cell_type_pred'] = np.nan
sc.pp.neighbors(adata_latent_full, n_neighbors=15)
sc.tl.umap(adata_latent_full)
[24]:
sc.pl.umap(
    adata_latent_full,
    color='cell_type_pred',
    show=False,
    frameon=False,
)
[24]:
<Axes: title={'center': 'cell_type_pred'}, xlabel='UMAP1', ylabel='UMAP2'>
_images/scpoli_surgery_pipeline_36_1.png

We can now see that the alpha cell cluster is correctly classified.
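As a quick, purely illustrative sanity check, we could tabulate the new predictions for cells annotated as alpha in the query (assuming the predictions array is aligned with target_adata.obs):

# Illustrative check: predicted labels for query cells annotated as 'alpha'.
alpha_mask = (target_adata.obs['cell_type'] == 'alpha').values
print(pd.Series(np.asarray(results_dict['cell_type']['preds'])[alpha_mask]).value_counts())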

Sample embeddings

We can extract the conditional embeddings learnt by scPoli and analyse them.

[25]:
adata_emb = scpoli_query.get_conditional_embeddings()
[26]:
from sklearn.decomposition import KernelPCA
pca = KernelPCA(n_components=2, kernel='linear')
emb_pca = pca.fit_transform(adata_emb.X)
conditions = scpoli_query.conditions_['study']
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
sns.scatterplot(x=emb_pca[:, 0], y=emb_pca[:, 1], hue=conditions, ax=ax)
ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
for i, c in enumerate(conditions):
    ax.plot([0, emb_pca[i, 0]], [0, emb_pca[i, 1]])
    ax.text(emb_pca[i, 0], emb_pca[i, 1], c)
sns.despine()
_images/scpoli_surgery_pipeline_41_0.png
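
Beyond this 2D projection, one could also inspect the pairwise distances between the learned sample embeddings directly. A minimal sketch using scipy and seaborn, assuming adata_emb and conditions from the cells above and a dense adata_emb.X:

from scipy.spatial.distance import pdist, squareform

# Illustrative: pairwise Euclidean distances between the study embeddings.
dist = squareform(pdist(adata_emb.X))
dist_df = pd.DataFrame(dist, index=conditions, columns=conditions)
sns.clustermap(dist_df, cmap='viridis', figsize=(5, 5))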