notes on djl embeddings ‐ simplified batch predictor with predictor pool - krickert/search-api GitHub Wiki

package com.krickert.search.vectorizer;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.engine.Engine;
import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar; // Import ProgressBar
import ai.djl.translate.TranslateException;
import io.micronaut.context.annotation.Value;
import io.micronaut.core.annotation.NonNull;
import io.micronaut.core.annotation.Nullable;
import io.micronaut.core.io.ResourceLoader;
import io.micronaut.core.util.StringUtils;
import jakarta.annotation.PreDestroy;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import org.apache.commons.io.IOUtils; // Keep for JAR extraction
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.*; // Keep BlockingQueue
import java.util.stream.Collectors;


/**
 * Simplified SentenceVectorizer class converts text inputs into vector embeddings using a pre-trained model.
 * It manages a pool of predictors for efficient instance reuse and processes inputs sequentially in batches
 * based on configurable max batch size.
 * Assumes input data (texts) are valid (non-null, non-empty) as per upstream guarantees.
 *
 * <p>Predictors are handed out through a {@link BlockingQueue}: a caller takes one, uses it, and offers it
 * back, so each predictor is used by one caller at a time. Callers block in {@code take()} when all
 * {@code poolSize} predictors are checked out.
 */
@Singleton
public class SentenceVectorizer implements Vectorizer {

    private static final Logger log = LoggerFactory.getLogger(SentenceVectorizer.class);

    /** Resolved model location (possibly a file URL of a model extracted from the application JAR). */
    private final String modelUrl;
    private final ResourceLoader resourceLoader;
    private final ZooModel<String, float[]> model;
    /** Idle predictors; a predictor is absent from the queue while a call is using it. */
    private final BlockingQueue<Predictor<String, float[]>> predictorPool;
    /** Number of predictors created for the pool (always >= 1). */
    private final int poolSize;
    /** Maximum number of texts passed to a single {@code batchPredict} call (always >= 1). */
    private final int maxBatchSize;

    /**
     * Constructs the SentenceVectorizer, loading the model and initializing the predictor pool.
     *
     * @param modelUrl       The URL or resource path to the model (.zip format).
     * @param tempDir        A temporary directory for extracting models from JARs.
     * @param poolSize       The number of predictors to create in the pool (for instance reuse); must be >= 1.
     * @param maxBatchSize   The maximum number of texts to process in a single batch prediction call; must be >= 1.
     * @param resourceLoader Micronaut resource loader for accessing model files.
     * @throws IllegalArgumentException If {@code poolSize} or {@code maxBatchSize} is less than 1, or the model URL is empty.
     * @throws ModelNotFoundException  If the model cannot be found at the specified URL.
     * @throws MalformedModelException If the model file is corrupted or invalid.
     * @throws IOException             If an I/O error occurs during model loading or extraction.
     */
    @Inject
    public SentenceVectorizer(@Value("${vectorizer.model.url}") String modelUrl,
                              @Value("${vectorizer.temp-dir}") String tempDir,
                              @Value("${vectorizer.pool-size:10}") int poolSize,
                              @Value("${vectorizer.max-batch-size:64}") int maxBatchSize,
                              ResourceLoader resourceLoader)
            throws ModelNotFoundException, MalformedModelException, IOException {

        // Validate sizing before doing any expensive work. A max batch size of 0 would make
        // batchEmbeddings spin forever (the batch cursor never advances), and an ArrayBlockingQueue
        // cannot be created with a capacity below 1.
        if (poolSize < 1) {
            throw new IllegalArgumentException("vectorizer.pool-size must be at least 1 but was " + poolSize);
        }
        if (maxBatchSize < 1) {
            throw new IllegalArgumentException("vectorizer.max-batch-size must be at least 1 but was " + maxBatchSize);
        }

        this.resourceLoader = resourceLoader;
        this.poolSize = poolSize;
        this.maxBatchSize = maxBatchSize;

        log.info("Initializing SentenceVectorizer with poolSize={}, maxBatchSize={}", poolSize, maxBatchSize);
        log.info("Attempting to load model from configured URL/path: {}", modelUrl);

        // --- Model Loading Logic ---
        this.modelUrl = initializeModelPath(modelUrl, tempDir);
        log.info("Resolved model path to: {}", this.modelUrl);
        this.model = loadModel(this.modelUrl);

        if (this.model == null) {
            throw new IllegalStateException("Model could not be loaded, cannot initialize SentenceVectorizer.");
        }
        log.info("Model loaded: {}", this.model.getName());

        // --- Predictor Pool Creation ---
        try {
            this.predictorPool = createPredictorPool(this.model, this.poolSize);
        } catch (RuntimeException e) {
            // If the constructor throws, no instance exists for close() to run on,
            // so release the already-loaded model here instead of leaking it.
            try {
                this.model.close();
            } catch (RuntimeException closeEx) {
                e.addSuppressed(closeEx);
            }
            throw e;
        }
        log.info("Predictor pool initialized.");

        log.info("SentenceVectorizer initialized successfully.");
    }

    /**
     * Generates vector embeddings for a single text input using a predictor borrowed from the pool.
     *
     * @param text the text to embed; an empty or null value yields an empty array (with a warning)
     * @return the embedding, or {@code null} if the thread was interrupted while waiting for a predictor
     *         (the interrupt flag is restored in that case)
     * @throws RuntimeException wrapping the {@link TranslateException} if prediction fails
     */
    @Override
    @Nullable
    public float[] embeddings(@NonNull String text) {
        if (StringUtils.isEmpty(text)) {
            log.warn("Received empty or null text in 'embeddings' method despite upstream guarantees. Returning empty array.");
            return new float[0];
        }
        log.debug("Requesting predictor for single text embedding.");
        Predictor<String, float[]> predictor = null;
        try {
            predictor = predictorPool.take();
            log.debug("Predictor acquired. Vectorizing text (first 50 chars): '{}...'",
                    text.substring(0, Math.min(text.length(), 50)));
            float[] response = predictor.predict(text);
            log.debug("Single text embedding generated successfully.");
            return response;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("Thread interrupted while waiting for a predictor from the pool.", e);
            return null;
        } catch (TranslateException e) {
            log.error("Error during single text prediction for text: '{}...'",
                    text.substring(0, Math.min(text.length(), 50)), e);
            throw new RuntimeException("Embedding generation failed for the provided text.", e);
        } finally {
            // Always hand the predictor back; if the queue refuses it, close it so it does not leak.
            if (predictor != null) {
                if (!predictorPool.offer(predictor)) {
                    log.warn("Failed to return predictor to the pool (pool might be full unexpectedly). Closing predictor.");
                    predictor.close();
                } else {
                    log.debug("Predictor returned to pool.");
                }
            }
        }
    }

    /**
     * Generates vector embeddings for a batch of text inputs SEQUENTIALLY.
     * Assumes input texts are valid (non-null, non-empty).
     * Divides inputs into chunks of at most {@code maxBatchSize} and processes them one by one, each with a
     * predictor borrowed from the pool.
     *
     * <p>Failure handling:
     * <ul>
     *   <li>If a chunk fails prediction (a {@link TranslateException}, any other exception, or a result count that
     *       does not match the chunk size), the entries for that chunk are {@code null} and processing continues.</li>
     *   <li>If the thread is interrupted while waiting for a predictor, the interrupt flag is restored, processing
     *       stops, and the entries for that chunk and all later chunks are {@code null}.</li>
     * </ul>
     *
     * @param texts The list of input texts (assumed valid). Cannot be null.
     * @return A fixed-size list of float arrays representing the embeddings, with the same size and order as the
     *         input list. Entries corresponding to failed or unprocessed chunks are null.
     */
    @Override
    @NonNull
    public List<float[]> batchEmbeddings(@NonNull List<String> texts) {
        if (texts.isEmpty()) {
            log.info("Input text list is empty, returning empty list.");
            return Collections.emptyList();
        }

        log.info("Starting SEQUENTIAL batch embedding generation for {} texts using model '{}'. Max batch size: {}",
                texts.size(), model.getName(), maxBatchSize);

        final int totalTexts = texts.size();
        // Slots stay null unless their chunk is predicted successfully.
        final float[][] embeddings = new float[totalTexts][];

        int batchStartIndex = 0;
        while (batchStartIndex < totalTexts) {
            // Overflow-safe: never computes batchStartIndex + maxBatchSize directly.
            final int batchEndIndex = batchStartIndex + Math.min(maxBatchSize, totalTexts - batchStartIndex);
            final List<String> currentBatchTexts = texts.subList(batchStartIndex, batchEndIndex);

            log.debug("Processing batch sequentially for indices [{} - {}] (size: {})", batchStartIndex, batchEndIndex - 1, currentBatchTexts.size());

            Predictor<String, float[]> predictor = null;
            try {
                // Blocks until a pooled predictor is free, so predictors are reused rather than recreated.
                predictor = predictorPool.take();
                log.debug("Predictor acquired for sequential batch [{} - {}]", batchStartIndex, batchEndIndex - 1);

                List<float[]> batchResult = predictor.batchPredict(currentBatchTexts);

                // A result list of the wrong length would either spill into the next chunk's slots
                // (or past the array end) or silently leave trailing slots null while reporting success.
                if (batchResult.size() != currentBatchTexts.size()) {
                    throw new IllegalStateException("Predictor returned " + batchResult.size()
                            + " embeddings for " + currentBatchTexts.size() + " inputs");
                }
                for (int j = 0; j < batchResult.size(); j++) {
                    embeddings[batchStartIndex + j] = batchResult.get(j);
                }
                log.debug("Sequential batch [{} - {}] processed successfully.", batchStartIndex, batchEndIndex - 1);

            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                log.error("Thread interrupted while waiting for predictor or during prediction for sequential batch [{} - {}]. Marking batch as failed.", batchStartIndex, batchEndIndex - 1, e);
                Arrays.fill(embeddings, batchStartIndex, batchEndIndex, null);
                // With the interrupt flag restored, every further take() would fail immediately; stop here.
                // Slots of the remaining chunks are still null from allocation. (finally still runs.)
                break;
            } catch (TranslateException e) {
                log.error("Batch prediction failed via TranslateException for sequential batch [{} - {}]. Marking batch as failed. Error: {}", batchStartIndex, batchEndIndex - 1, e.getMessage(), e);
                Arrays.fill(embeddings, batchStartIndex, batchEndIndex, null);
            } catch (Exception e) {
                log.error("Unexpected error during sequential batch prediction for indices [{} - {}]. Marking batch as failed.", batchStartIndex, batchEndIndex - 1, e);
                Arrays.fill(embeddings, batchStartIndex, batchEndIndex, null);
            } finally {
                // Return the predictor for the next chunk or another caller; close it if the queue refuses it.
                if (predictor != null) {
                    if (!predictorPool.offer(predictor)) {
                        log.warn("Failed to return predictor to the pool after sequential batch [{} - {}]. Closing predictor.", batchStartIndex, batchEndIndex - 1);
                        predictor.close();
                    } else {
                        log.debug("Predictor returned to pool after sequential batch [{} - {}].", batchStartIndex, batchEndIndex - 1);
                    }
                }
            }
            batchStartIndex = batchEndIndex;
        }

        log.info("Sequential batch embedding generation finished for {} texts.", texts.size());
        return Arrays.asList(embeddings);
    }

    /**
     * Convenience method returning the embedding of {@code text} as boxed floats.
     *
     * @param text the text to embed
     * @return the embedding values in order, or an empty list if {@link #embeddings(String)} returned
     *         {@code null} (interrupted) or an empty array
     */
    @Override
    @NonNull
    public Collection<Float> getEmbeddings(@NonNull String text) {
        float[] embeddingArray = embeddings(text);
        if (embeddingArray == null || embeddingArray.length == 0) {
            return Collections.emptyList();
        }
        List<Float> embeddingList = new ArrayList<>(embeddingArray.length);
        for (float value : embeddingArray) {
            embeddingList.add(value);
        }
        return embeddingList;
    }

    // --- Helper Methods ---

    /**
     * Resolves the model path. For a {@code .zip} path found through the {@link ResourceLoader}, a resource
     * inside a JAR is extracted to {@code tempDir}; any other resource URL is returned as-is. Paths not found
     * by the loader, or not ending in {@code .zip}, are returned unchanged.
     *
     * @throws IllegalArgumentException if the configured URL is empty
     * @throws IOException              if JAR extraction fails
     */
    private String initializeModelPath(String configuredModelUrl, String tempDir) throws IOException {
        log.debug("Initializing model path for: {}", configuredModelUrl);
        if (StringUtils.isEmpty(configuredModelUrl)) {
            throw new IllegalArgumentException("Model URL/path ('vectorizer.model.url') cannot be empty.");
        }
        if (!configuredModelUrl.endsWith(".zip")) {
            log.info("Model path '{}' does not end with .zip. Assuming it's a directory or pre-loaded model URL.", configuredModelUrl);
            return configuredModelUrl;
        }
        Optional<URL> modelOpt = resourceLoader.getResource(configuredModelUrl);
        if (modelOpt.isEmpty()) {
            log.warn("Model resource '{}' not found via ResourceLoader. Assuming it's a direct file path or URL.", configuredModelUrl);
            return configuredModelUrl;
        }
        URL modelUrlFull = modelOpt.get();
        log.debug("Found model resource via loader: {}", modelUrlFull);
        if ("jar".equals(modelUrlFull.getProtocol())) {
            log.info("Model resource is inside a JAR. Extracting '{}' to temporary directory '{}'", configuredModelUrl, tempDir);
            return extractResourceFromJar(configuredModelUrl, tempDir).toString();
        }
        return modelUrlFull.toString();
    }

    /**
     * Loads the DJL ZooModel (PyTorch engine, text-embedding translator) from the resolved model URL onto the
     * device chosen by {@link #selectDevice()}.
     */
    private ZooModel<String, float[]> loadModel(String resolvedModelUrl) throws ModelNotFoundException, MalformedModelException, IOException {
        Device device = selectDevice();
        log.info("Attempting to load model from {} onto device {}", resolvedModelUrl, device);
        Criteria<String, float[]> criteria = Criteria.builder()
                .setTypes(String.class, float[].class)
                .optModelUrls(resolvedModelUrl)
                .optEngine("PyTorch")
                .optDevice(device)
                .optTranslatorFactory(new TextEmbeddingTranslatorFactory())
                .optProgress(new ProgressBar())
                .build();
        log.info("Loading model via DJL criteria...");
        return criteria.loadModel();
    }

    /**
     * Selects a GPU if the default DJL engine reports at least one, otherwise the CPU.
     * NOTE(review): this queries the default engine while loadModel pins "PyTorch"; confirm they match.
     */
    private Device selectDevice() {
        int gpuCount = Engine.getInstance().getGpuCount();
        if (gpuCount > 0) {
            log.info("GPU detected ({} available). Selecting GPU device.", gpuCount);
            return Device.gpu();
        }
        log.info("No GPU detected. Selecting CPU device.");
        return Device.cpu();
    }

    /**
     * Creates a bounded queue holding {@code poolSize} predictors from {@code model}.
     * If any predictor cannot be created, the ones already created are closed and a RuntimeException is thrown.
     */
    private BlockingQueue<Predictor<String, float[]>> createPredictorPool(ZooModel<String, float[]> model, int poolSize) {
        BlockingQueue<Predictor<String, float[]>> pool = new ArrayBlockingQueue<>(poolSize);
        log.info("Creating predictor pool with size: {}", poolSize);
        try {
            for (int i = 0; i < poolSize; i++) {
                Predictor<String, float[]> predictor = model.newPredictor();
                if (!pool.offer(predictor)) {
                    log.error("Failed to add predictor {} to the pool (unexpected). Closing predictor.", i);
                    predictor.close();
                }
            }
            log.info("Predictor pool populated successfully.");
        } catch (Exception e) {
            log.error("Failed to create predictors for the pool.", e);
            pool.forEach(Predictor::close);
            pool.clear();
            throw new RuntimeException("Failed to initialize predictor pool", e);
        }
        return pool;
    }

    /**
     * Extracts a resource from within a JAR file into {@code targetDirectory}, creating the directory if needed
     * and overwriting any existing file with the same name.
     *
     * @return a file URL of the extracted copy
     * @throws IOException if the resource is missing, the target is not a directory, no file name can be derived,
     *                     or copying fails (a partially written file is deleted on a best-effort basis)
     */
    private URL extractResourceFromJar(String resourcePath, String targetDirectory) throws IOException {
        Optional<InputStream> resourceStreamOpt = resourceLoader.getResourceAsStream(resourcePath);
        if (resourceStreamOpt.isEmpty()) {
            throw new IOException("Resource not found inside JAR: " + resourcePath);
        }
        // Take ownership of the stream immediately so it is closed on every exit path,
        // including the directory and file-name validation failures below.
        try (InputStream resourceStream = resourceStreamOpt.get()) {
            Path targetDirPath = Paths.get(targetDirectory);
            if (!Files.exists(targetDirPath)) {
                log.debug("Creating temporary directory for JAR extraction: {}", targetDirPath);
                Files.createDirectories(targetDirPath);
            } else if (!Files.isDirectory(targetDirPath)) {
                throw new IOException("Target path for extraction exists but is not a directory: " + targetDirectory);
            }
            // getFileName() is null for a root-only path; treat that like an empty name.
            Path fileNamePath = Paths.get(resourcePath).getFileName();
            String fileName = fileNamePath == null ? null : fileNamePath.toString();
            if (StringUtils.isEmpty(fileName)) {
                throw new IOException("Could not determine filename from resource path: " + resourcePath);
            }
            Path targetFilePath = targetDirPath.resolve(fileName);
            log.info("Extracting resource '{}' from JAR to '{}'", resourcePath, targetFilePath);
            try (FileOutputStream outputStream = new FileOutputStream(targetFilePath.toFile())) {
                IOUtils.copy(resourceStream, outputStream);
            } catch (IOException e) {
                log.error("Failed to copy resource '{}' to '{}'", resourcePath, targetFilePath, e);
                try {
                    Files.deleteIfExists(targetFilePath);
                } catch (IOException cleanupEx) {
                    log.warn("Failed to delete partially extracted file: {}", targetFilePath, cleanupEx);
                }
                throw e;
            }
            log.info("Resource extracted successfully to: {}", targetFilePath);
            return targetFilePath.toUri().toURL();
        }
    }

    /**
     * Cleans up resources on bean destruction: closes the predictors currently idle in the pool, then the model.
     * Predictors checked out by an in-flight call are not in the queue at this point and are not closed here.
     */
    @PreDestroy
    public void close() {
        log.info("Closing SentenceVectorizer resources...");

        // 1. Close predictors in the pool. drainTo moves them out atomically per element; no clear()
        //    afterwards, since that could drop (without closing) a predictor offered back concurrently.
        if (predictorPool != null) {
            int count = predictorPool.size();
            log.debug("Closing {} predictors remaining in the pool...", count);
            List<Predictor<String, float[]>> predictorsToClose = new ArrayList<>(count);
            predictorPool.drainTo(predictorsToClose);
            for (Predictor<String, float[]> predictor : predictorsToClose) {
                try {
                    predictor.close();
                } catch (Exception e) {
                    log.warn("Error closing a predictor instance.", e);
                }
            }
            log.debug("Predictor pool cleared and {} predictors closed.", predictorsToClose.size());
        } else {
            log.debug("Predictor pool is null.");
        }

        // 2. Close the model. Capture the name first so logging never touches a closed model.
        if (model != null) {
            String modelName = model.getName();
            try {
                log.debug("Closing DJL model: {}", modelName);
                model.close();
                log.info("Model closed successfully: {}", modelName);
            } catch (Exception e) {
                log.error("Error closing the DJL model: {}", modelName, e);
            }
        } else {
            log.debug("Model is null, nothing to close.");
        }
        log.info("SentenceVectorizer closed.");
    }
}
⚠️ **GitHub.com Fallback** ⚠️