Cell2Sentence-Embedding

Overview

Cell2Sentence (C2S) is a framework for applying Large Language Models (LLMs) to single-cell transcriptomics. It transforms single-cell RNA sequencing (scRNA-seq) data into text that LLMs can process directly. The core idea is to convert the gene expression vector of each cell into a "cell sentence": a string of gene names ordered from highest to lowest expression. This lets LLMs process, analyze, and generate insights from large volumes of transcriptomic data.

The primary function of this model is to generate compact and informative "cell embeddings." These embeddings are vector representations that capture the essential biological features of each cell. They can be used for a wide range of downstream tasks, including cell clustering, data visualization, classification, perturbation response prediction, and complex biological reasoning. By bridging the gap between single-cell genomics and natural language processing, Cell2Sentence enables deeper analysis and the development of "virtual cells" for advanced biological research.

Model Architecture

The Cell2Sentence model is built upon a Large Language Model with a Pythia-style architecture. It has been trained on a massive corpus of over one billion tokens, comprising transcriptomic data, biological text, and metadata.

The workflow is as follows:

  1. Input: The model takes single-cell expression data (a matrix of cells vs. genes), a list of gene names, and associated cell metadata.
  2. Transformation: The expression data for each cell is converted into a "cell sentence" by rank-ordering gene names based on their expression values.
  3. Embedding: The LLM processes these cell sentences to generate high-dimensional cell embeddings that encapsulate the biological state of each cell.
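
The transformation step above can be sketched in a few lines. This is a minimal illustration, not the packaged model's implementation: the function name `to_cell_sentence`, the `top_k` parameter, the dropping of zero-expression genes, and the tie-breaking rule are all assumptions made for the example.

```python
import numpy as np

def to_cell_sentence(expression, gene_names, top_k=None):
    """Return gene names ordered by decreasing expression, joined by spaces.

    Assumptions for this sketch: genes with zero expression are dropped,
    ties keep the original gene order, and `top_k` (hypothetical) optionally
    truncates the sentence.
    """
    expression = np.asarray(expression, dtype=float)
    # A stable sort on the negated values gives descending order
    # with deterministic handling of ties.
    order = np.argsort(-expression, kind="stable")
    ranked = [gene_names[i] for i in order if expression[i] > 0]
    if top_k is not None:
        ranked = ranked[:top_k]
    return " ".join(ranked)

genes = ["CD3E", "MS4A1", "LYZ", "NKG7"]
cell = [0.0, 2.5, 7.1, 2.5]
sentence = to_cell_sentence(cell, genes)
print(sentence)  # LYZ MS4A1 NKG7
```

The endpoint performs this conversion itself, so clients send the numeric matrix and gene names rather than pre-built sentences.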

This model is specifically packaged to perform the cell embedding task.

Sample inputs and outputs (for real time inference)

Sample Inference Script

Due to the large size of single-cell expression data, inference is typically performed using a client script that reads data from a file (such as .h5ad), prepares the JSON payload, and sends it to the model endpoint in chunks.

Important Notes:

  • Payload Size Limit: Each request is limited to approximately 100 MB. The script validates the size of each payload before sending it; its help text recommends 50-100 cells per chunk (the default is 100).
  • Input Format: The model expects a specific JSON structure with expression data, gene names, and metadata, serialized as a JSON string in a 'payload' field.
  • Chunking: Large datasets should be split into smaller chunks (via the --chunk-size parameter of the following sample inference script) to stay within size limits and improve reliability.

import numpy as np
import pandas as pd
import requests
import json
import logging
import anndata
import scanpy as sc
import os
import random
import argparse
import sys
import time

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Replace with your actual endpoint configuration ---
ENDPOINT_URL = "YOUR_ENDPOINT_URL"
API_KEY = "YOUR_API_KEY" 
DEPLOYMENT_NAME = "YOUR_DEPLOYMENT_NAME"

def preprocess_data(adata: anndata.AnnData, min_genes: int = 200, min_cells: int = 3, normalize: bool = True) -> anndata.AnnData:
    """Preprocess raw data following Cell2Sentence requirements."""
    logger.info("Starting preprocessing...")
    
    # Set random seed for reproducibility
    random.seed(1234)
    np.random.seed(1234)
    
    # Basic filtering
    logger.info("Filtering cells and genes...")
    sc.pp.filter_cells(adata, min_genes=min_genes)
    sc.pp.filter_genes(adata, min_cells=min_cells)
    
    # Calculate QC metrics
    logger.info("Calculating QC metrics...")
    adata.var["mt"] = adata.var_names.str.startswith("MT-")
    sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True)
    
    if normalize:
        logger.info("Normalizing data...")
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata, base=10)
    
    logger.info(f"Preprocessing complete. Final shape: {adata.shape}")
    return adata

def validate_payload_size(payload_dict, max_size_mb=100):
    """Validate payload size limits."""
    json_str = json.dumps(payload_dict)
    total_size_mb = len(json_str.encode('utf-8')) / 1024 / 1024
    
    logger.info(f"Payload size: {total_size_mb:.2f} MB")
    
    if total_size_mb > max_size_mb:
        logger.warning(f"Payload size ({total_size_mb:.2f} MB) exceeds limit ({max_size_mb} MB)")
        return False
    
    return True

def run_cell2sentence_inference(data_path: str, output_path: str, chunk_size: int = 100, timeout: int = 300):
    """
    Run Cell2Sentence inference on single-cell data.
    
    Args:
        data_path: Path to input .h5ad file
        output_path: Path to save output .h5ad file with embeddings
        chunk_size: Number of cells per request (recommended: 50-100 for Azure ML)
        timeout: Request timeout in seconds
    """
    try:
        # Load raw data
        logger.info(f"Loading raw data from {data_path}")
        adata = anndata.read_h5ad(data_path)
        logger.info(f"Initial data shape: {adata.shape}")
        
        # Preprocess data
        adata = preprocess_data(adata)
        
        # Convert data to array
        logger.info("Converting data to array format...")
        if isinstance(adata.X, np.ndarray):
            expression_data = adata.X
        else:
            expression_data = adata.X.toarray()
        
        # Check data characteristics
        logger.info(f"Expression data shape: {expression_data.shape}")
        logger.info(f"Expression data type: {expression_data.dtype}")
        logger.info(f"Expression data range: [{expression_data.min():.3f}, {expression_data.max():.3f}]")
        
        # Prepare inference data with chunking for Azure ML
        logger.info(f"Preparing inference data with chunk size: {chunk_size}")
        total_cells = expression_data.shape[0]
        
        all_embeddings = []
        
        # Define required metadata columns
        required_obs_cols = ["cell_type", "tissue", "batch_condition", "organism", "sex"]
        
        # Check if all required columns are present in the data
        missing_cols = [col for col in required_obs_cols if col not in adata.obs.columns]
        if missing_cols:
            logger.warning(f"Missing required metadata columns: {missing_cols}")
            logger.info("Adding default values for missing columns...")
            for col in missing_cols:
                adata.obs[col] = "unknown"

        # Process in chunks
        num_chunks = (total_cells + chunk_size - 1) // chunk_size
        logger.info(f"Processing {total_cells} cells in {num_chunks} chunks")

        for i in range(0, total_cells, chunk_size):
            chunk_end = min(i + chunk_size, total_cells)
            chunk_num = i // chunk_size + 1
            
            logger.info(f"Processing chunk {chunk_num}/{num_chunks}, cells {i} to {chunk_end}")
            
            metadata_chunk = adata.obs[required_obs_cols][i:chunk_end].to_dict(orient='records')

            # Create payload for this chunk - IMPORTANT: This is the new JSON format
            payload_dict = {
                "expression_data": expression_data[i:chunk_end].astype(np.float32).tolist(),
                "gene_names": adata.var_names.tolist(),
                "metadata": metadata_chunk
            }

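            # Validate payload size (validate_payload_size defaults to a 100 MB limit)
            if not validate_payload_size(payload_dict):
                # Skipping a chunk would leave fewer embeddings than cells and break
                # the obsm assignment when results are saved, so fail fast instead.
                raise ValueError(f"Chunk {chunk_num} payload too large; rerun with a smaller --chunk-size")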

            # Create DataFrame with single 'payload' column (MLflow format)
            df = pd.DataFrame([{
                "payload": json.dumps(payload_dict)
            }])

            # Format for Azure ML endpoint
            chunk_data = json.loads(df.to_json(orient="split"))
            request_payload = {
                "input_data": chunk_data
            }
            
            # Prepare headers
            headers = {
                'Content-Type': 'application/json',
                'Authorization': f'Bearer {API_KEY}'
            }
            if DEPLOYMENT_NAME:
                headers['azureml-model-deployment'] = DEPLOYMENT_NAME
            
            # Send request to Azure endpoint with retries
            max_retries = 3
            for retry in range(max_retries):
                try:
                    logger.info(f"Sending request to Azure endpoint (attempt {retry + 1}/{max_retries})...")
                    start_time = time.time()
                    
                    response = requests.post(
                        ENDPOINT_URL,
                        headers=headers,
                        data=json.dumps(request_payload),
                        timeout=timeout
                    )
                    
                    end_time = time.time()
                    logger.info(f"Request completed in {end_time - start_time:.2f} seconds")
                    
                    if response.status_code == 200:
                        break
                    else:
                        logger.warning(f"Request failed with status {response.status_code}: {response.text[:200]}...")
                        if retry == max_retries - 1:
                            raise Exception(f"Request failed after {max_retries} attempts")
                        time.sleep(5)
                        
                except requests.RequestException as e:
                    if retry == max_retries - 1:
                        logger.error(f"Request to Azure endpoint failed: {e}")
                        if hasattr(e, 'response') and e.response is not None:
                            logger.error(f"Response: {e.response.text}")
                        raise
                    logger.warning(f"Retry {retry + 1}/{max_retries} after error: {e}")
                    time.sleep(5)
            
            # Process response
            try:
                results = response.json()
                logger.info(f"Response status: {response.status_code}")
                
                # Handle the response format from Azure ML/MLflow
                if isinstance(results, list) and len(results) > 0 and isinstance(results[0], dict) and "predictions" in results[0]:
                    # Azure ML format: [{"predictions": "json_string"}]
                    prediction_string = results[0]["predictions"]
                    output_data = json.loads(prediction_string)
                    chunk_embeddings = np.array(output_data["cell_embeddings"])
                elif isinstance(results, dict) and "predictions" in results:
                    # Standard MLflow format: {"predictions": ["json_string"]}
                    prediction_string = results["predictions"][0]
                    output_data = json.loads(prediction_string)
                    chunk_embeddings = np.array(output_data["cell_embeddings"])
                else:
                    logger.error(f"Unexpected response format. Type: {type(results)}")
                    raise Exception("Unexpected response format")
                
                logger.info(f"Received embeddings with shape: {chunk_embeddings.shape}")
                all_embeddings.append(chunk_embeddings)
                    
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse JSON response: {str(e)}")
                logger.error(f"Response text: {response.text}")
                raise
            except Exception as e:
                logger.error(f"Error processing response: {str(e)}")
                logger.error(f"Response text (first 500 chars): {response.text[:500]}...")
                raise

            logger.info(f"Processed {chunk_end}/{total_cells} cells")
            
            # Add small delay between requests to be nice to the endpoint
            time.sleep(1)
        
        # Combine embeddings
        if all_embeddings:
            embeddings = np.concatenate(all_embeddings, axis=0)
            logger.info(f"Final embeddings shape: {embeddings.shape}")
            
            # Save results
            logger.info(f"Saving results to {output_path}")
            adata.obsm["c2s_cell_embeddings"] = embeddings
            adata.write_h5ad(output_path)
            logger.info("Save complete")
            
            # Log some statistics
            logger.info("Embedding statistics:")
            logger.info(f"  Shape: {embeddings.shape}")
            logger.info(f"  Mean: {embeddings.mean():.6f}")
            logger.info(f"  Std: {embeddings.std():.6f}")
            logger.info(f"  Min: {embeddings.min():.6f}")
            logger.info(f"  Max: {embeddings.max():.6f}")
            
            return embeddings
        else:
            logger.error("No embeddings were generated!")
            return None
        
    except Exception as e:
        logger.error(f"Error in run_cell2sentence_inference: {str(e)}")
        raise

def test_simple_endpoint():
    """Test endpoint with simple synthetic data."""
    logger.info("=== TESTING ENDPOINT WITH SIMPLE DATA ===")
    
    # Create simple test data
    n_cells, n_genes = 5, 100
    expression_data = np.random.lognormal(mean=0, sigma=1, size=(n_cells, n_genes)).astype(np.float32)
    gene_names = [f"GENE_{i:04d}" for i in range(n_genes)]
    
    metadata_list = []
    for i in range(n_cells):
        metadata_list.append({
            "cell_type": f"cell_type_{i}",
            "tissue": "test_tissue",
            "batch_condition": "test_batch",
            "organism": "human",
            "sex": "unknown"
        })
    
    # Create payload
    payload_dict = {
        "expression_data": expression_data.tolist(),
        "gene_names": gene_names,
        "metadata": metadata_list
    }
    
    # Create DataFrame with single 'payload' column
    df = pd.DataFrame([{"payload": json.dumps(payload_dict)}])
    
    # Format for Azure ML endpoint
    chunk_data = json.loads(df.to_json(orient="split"))
    request_payload = {"input_data": chunk_data}
    
    # Prepare headers
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {API_KEY}'
    }
    if DEPLOYMENT_NAME:
        headers['azureml-model-deployment'] = DEPLOYMENT_NAME
    
    logger.info("Sending request to endpoint...")
    try:
        response = requests.post(
            ENDPOINT_URL,
            headers=headers,
            data=json.dumps(request_payload),
            timeout=180
        )
        
        logger.info(f"Response status code: {response.status_code}")
        
        if response.status_code == 200:
            results = response.json()
            logger.info("SUCCESS! Endpoint responded successfully.")
            
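            # Parse the prediction (accept both response formats handled in run_cell2sentence_inference)
            if isinstance(results, list):
                prediction_string = results[0]["predictions"]
            else:
                prediction_string = results["predictions"][0]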
            output_data = json.loads(prediction_string)
            embeddings = np.array(output_data["cell_embeddings"])
            logger.info(f"Received embeddings with shape: {embeddings.shape}")
            logger.info(f"Embedding sample: {embeddings[0][:5].tolist()}")
            return True
        else:
            logger.error(f"FAILED! Status: {response.status_code}")
            logger.error(f"Response: {response.text}")
            return False
            
    except Exception as e:
        logger.error(f"Request failed: {e}")
        return False

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Cell2Sentence Inference Client')
    parser.add_argument('--input', '-i', type=str, help='Path to input h5ad file')
    parser.add_argument('--output', '-o', type=str, default="cell2sentence_embeddings.h5ad", help='Path to output h5ad file')
    parser.add_argument('--chunk-size', type=int, default=100, help='Number of cells per chunk (recommended: 50-100 for Azure ML)')
    parser.add_argument('--timeout', type=int, default=300, help='Request timeout in seconds')
    parser.add_argument('--test', action='store_true', help='Run simple test with synthetic data')
    
    args = parser.parse_args()
    
    if args.test:
        logger.info("Running simple test with synthetic data...")
        success = test_simple_endpoint()
        if success:
            logger.info("Simple test passed! Endpoint is working.")
        else:
            logger.error("Simple test failed! Check endpoint configuration.")
    
    elif args.input:
        if os.path.exists(args.input):
            logger.info(f"Running Cell2Sentence inference on: {args.input}")
            logger.info(f"Output will be saved to: {args.output}")
            logger.info(f"Using chunk size: {args.chunk_size}")
            
            try:
                embeddings = run_cell2sentence_inference(
                    data_path=args.input,
                    output_path=args.output,
                    chunk_size=args.chunk_size,
                    timeout=args.timeout
                )
                
                if embeddings is not None:
                    logger.info("Inference completed successfully!")
                    logger.info(f"Generated embeddings for {embeddings.shape[0]} cells")
                    logger.info(f"Results saved to: {args.output}")
                else:
                    logger.error("Inference failed!")
                    
            except Exception as e:
                logger.error(f"Error during inference: {str(e)}")
                sys.exit(1)
        else:
            logger.error(f"Data file not found: {args.input}")
            sys.exit(1)
    
    else:
        logger.info("Usage:")
        logger.info("  # Test with simple synthetic data:")
        logger.info("  python cell2sentence_inference.py --test")
        logger.info("")
        logger.info("  # Run inference on real data:")
        logger.info("  python cell2sentence_inference.py --input /path/to/your/data.h5ad --output results.h5ad")
        logger.info("")
        logger.info("  # Adjust chunk size for payload limits:")
        logger.info("  python cell2sentence_inference.py --input data.h5ad --chunk-size 50")

Usage Examples:

Test the endpoint with synthetic data

python cell2sentence_inference.py --test

Run inference on real data

python cell2sentence_inference.py -i /path/to/your/data.h5ad -o results_with_embeddings.h5ad

Adjust chunk size for very large datasets or network constraints

python cell2sentence_inference.py -i data.h5ad --chunk-size 100
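
One way to pick a chunk size is to measure how large a few serialized cells are and scale up to the request limit. The sketch below does that. The function name `estimate_chunk_size`, the `sample_cells` parameter, and the `safety` factor are hypothetical, and the 100 MB default simply mirrors the limit quoted in the notes above. The safety factor leaves headroom for metadata and for the extra escaping added when the payload is wrapped as a JSON string.

```python
import json
import numpy as np

def estimate_chunk_size(expression, gene_names, max_size_mb=100.0,
                        sample_cells=5, safety=0.5):
    """Estimate how many cells fit in one request under `max_size_mb`.

    Serializes a few cells to measure bytes per cell, subtracts the fixed
    cost of the gene name list, and applies a safety factor (hypothetical
    default of 0.5).
    """
    sample = np.asarray(expression[:sample_cells], dtype=np.float32)
    if sample.ndim != 2 or sample.shape[0] == 0:
        raise ValueError("expression must be a non-empty 2D matrix")
    fixed_bytes = len(json.dumps(list(gene_names)).encode("utf-8"))
    sample_bytes = len(json.dumps(sample.tolist()).encode("utf-8"))
    bytes_per_cell = sample_bytes / sample.shape[0]
    budget = max_size_mb * 1024 * 1024 * safety - fixed_bytes
    # Always return at least one cell per chunk.
    return max(1, int(budget // bytes_per_cell))

rng = np.random.default_rng(0)
expr = rng.random((20, 2000))
names = [f"GENE_{i:04d}" for i in range(2000)]
n = estimate_chunk_size(expr, names, max_size_mb=1.0)
print(f"Estimated cells per chunk under a 1 MB limit: {n}")
```

The estimate is only a starting point; request latency and endpoint timeouts may call for a smaller value than the size limit alone allows.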

Output Sample

The model returns a JSON object whose predictions entry is a JSON-encoded string containing the cell embeddings for each input cell.

{
  "predictions": [
    "{\"cell_embeddings\": [[-0.013, 0.045, 0.021, ...], [-0.025, 0.033, 0.019, ...], [0.012, -0.008, 0.035, ...]]}"
  ]
}
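
Because the embeddings arrive as a JSON string nested inside the JSON response, decoding takes two steps. The following sketch shows one way to do it; `parse_embeddings` is a hypothetical helper, and the two accepted response shapes are the ones the sample inference script checks for.

```python
import json
import numpy as np

def parse_embeddings(response_body):
    """Decode an endpoint response into an (n_cells, embedding_dim) array.

    Accepts either {"predictions": ["<json string>"]} or
    [{"predictions": "<json string>"}].
    """
    if isinstance(response_body, dict) and "predictions" in response_body:
        prediction_string = response_body["predictions"][0]
    elif (isinstance(response_body, list) and response_body
          and isinstance(response_body[0], dict)
          and "predictions" in response_body[0]):
        prediction_string = response_body[0]["predictions"]
    else:
        raise ValueError("unexpected response format")
    # The prediction is itself a JSON string, so it needs a second decode.
    output = json.loads(prediction_string)
    return np.array(output["cell_embeddings"], dtype=np.float32)

example = {
    "predictions": [
        json.dumps({"cell_embeddings": [[-0.013, 0.045], [0.012, -0.008]]})
    ]
}
embeddings = parse_embeddings(example)
print(embeddings.shape)  # (2, 2)
```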

Data and Resource Specification for Deployment

  • Supported Data Input Format
  1. Input Format: The model accepts single-cell RNA sequencing data, typically from .h5ad files. The user is responsible for extracting the required data arrays and lists before sending them to the endpoint.

  2. Input Schema Requirements: The input JSON object (the one serialized into the 'payload' field) must contain three keys:

    • expression_data: (2D array of floats) A matrix where rows are cells and columns are genes.
    • gene_names: (1D array of strings) A list of gene names corresponding to the columns of the expression matrix.
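    • metadata: (object) A dictionary where keys are metadata field names (e.g., "cell_type", "tissue") and values are lists of strings, with each list having the same length as the number of cells. Note that the sample inference script above instead sends metadata as a list of per-cell records (one dictionary per cell, built with to_dict(orient='records')). The following metadata columns are required: cell_type, tissue, batch_condition, organism, and sex.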
  3. Example Dataset:

    • Source: The model was tested with a subset of the Immune System tissue dataset from Domínguez Conde et al. (2022).
    • Link: The raw data can be downloaded from Google Drive.
    • Citation: Domínguez Conde, C., et al. "Cross-tissue immune cell analysis reveals tissue-specific features in humans." Science 376.6594 (2022): eabl5197.
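
The input schema requirements listed above can be checked on the client before a request is sent. The sketch below validates the three keys and wraps them in the request envelope used by the sample inference script. `build_request` and `REQUIRED_METADATA` are hypothetical names. Metadata is passed as per-cell records, following the sample script rather than the dictionary-of-lists form described in the schema. The envelope is written out by hand on the assumption that it matches what pandas produces for a one-row DataFrame via to_json(orient="split").

```python
import json
import numpy as np

REQUIRED_METADATA = ("cell_type", "tissue", "batch_condition", "organism", "sex")

def build_request(expression, gene_names, metadata_records):
    """Validate inputs and wrap them in an {"input_data": ...} envelope."""
    expression = np.asarray(expression, dtype=np.float32)
    if expression.ndim != 2:
        raise ValueError("expression_data must be a 2D cells x genes matrix")
    n_cells, n_genes = expression.shape
    if len(gene_names) != n_genes:
        raise ValueError(f"expected {n_genes} gene names, got {len(gene_names)}")
    if len(metadata_records) != n_cells:
        raise ValueError(f"expected {n_cells} metadata records, got {len(metadata_records)}")
    for idx, record in enumerate(metadata_records):
        missing = [key for key in REQUIRED_METADATA if key not in record]
        if missing:
            raise ValueError(f"cell {idx} is missing metadata fields: {missing}")
    payload = {
        "expression_data": expression.tolist(),
        "gene_names": list(gene_names),
        "metadata": list(metadata_records),
    }
    # One row, one 'payload' column, in pandas "split" orientation.
    return {
        "input_data": {
            "columns": ["payload"],
            "index": [0],
            "data": [[json.dumps(payload)]],
        }
    }

meta = {"cell_type": "unknown", "tissue": "blood", "batch_condition": "unknown",
        "organism": "human", "sex": "unknown"}
request = build_request([[1.0, 0.0], [0.5, 2.0]], ["CD3E", "LYZ"], [meta, meta])
inner = json.loads(request["input_data"]["data"][0][0])
print(sorted(inner))  # ['expression_data', 'gene_names', 'metadata']
```

Failing early on a shape or metadata mismatch is cheaper than discovering it from an endpoint error after a large upload.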

Version: 1

Tags

task : embedding

industry : health-and-life-sciences

Preview

licenseDescription : This model is provided under the License Terms available at <https://github.com/vandijklab/cell2sentence/blob/master/LICENSE>.

inference_supported_envs : ['hf']

license : cc-by-nc-nd-4.0

author : vandijklab

hiddenlayerscanned

SharedComputeCapacityEnabled

inference_compute_allow_list : ['Standard_NC4as_T4_v3', 'Standard_NC8as_T4_v3', 'Standard_NC16as_T4_v3', 'Standard_NC64as_T4_v3', 'Standard_NC6s_v3', 'Standard_NC12s_v3', 'Standard_NC24s_v3', 'Standard_NC24ads_A100_v4', 'Standard_NC48ads_A100_v4', 'Standard_NC96ads_A100_v4', 'Standard_ND96asr_v4', 'Standard_ND96amsr_A100_v4', 'Standard_ND40rs_v2', 'Standard_NC40ads_H100_v5', 'Standard_NC80adis_H100_v5', 'Standard_ND96isr_H100_v5']

View in Studio: https://ml.azure.com/registries/azureml/models/Cell2Sentence-Embedding/version/1

License: cc-by-nc-nd-4.0

Properties

inference-min-sku-spec: 4|1|28|64

inference-recommended-sku: Standard_NC4as_T4_v3, Standard_NC8as_T4_v3, Standard_NC16as_T4_v3, Standard_NC64as_T4_v3, Standard_NC6s_v3, Standard_NC12s_v3, Standard_NC24s_v3, Standard_NC24ads_A100_v4, Standard_NC48ads_A100_v4, Standard_NC96ads_A100_v4, Standard_ND96asr_v4, Standard_ND96amsr_A100_v4, Standard_ND40rs_v2, Standard_NC40ads_H100_v5, Standard_NC80adis_H100_v5, Standard_ND96isr_H100_v5

languages: en

SharedComputeCapacityEnabled: True
