package com.krickert.search.vectorizer;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.engine.Engine;
import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import io.micronaut.context.annotation.Value;
import io.micronaut.core.annotation.NonNull;
import io.micronaut.core.annotation.Nullable;
import io.micronaut.core.io.ResourceLoader;
import io.micronaut.core.util.StringUtils;
import jakarta.annotation.PreDestroy;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import org.apache.commons.io.IOUtils; // used for extracting model resources from a JAR
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
/**
* Converts text inputs into vector embeddings using a pre-trained model loaded via DJL.
* A fixed-size pool of predictors is maintained for instance reuse, and batch requests
* are processed sequentially in chunks bounded by a configurable maximum batch size.
* Assumes input texts are valid (non-null, non-empty) as per upstream guarantees.
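*
* <p>A minimal usage sketch (the injected field and sample texts are illustrative,
* not taken from this codebase):
* <pre>{@code
* @Inject
* SentenceVectorizer vectorizer;
*
* float[] single = vectorizer.embeddings("some text");
* List<float[]> batch = vectorizer.batchEmbeddings(List.of("first text", "second text"));
* }</pre>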
*/
@Singleton
public class SentenceVectorizer implements Vectorizer {
private static final Logger log = LoggerFactory.getLogger(SentenceVectorizer.class);
// Fields for model loading, predictor pooling, and configuration
private final String modelUrl;
private final ResourceLoader resourceLoader;
private final ZooModel<String, float[]> model;
private final BlockingQueue<Predictor<String, float[]>> predictorPool;
private final int poolSize; // number of predictors kept in the pool
private final int maxBatchSize;
/**
* Constructs the SentenceVectorizer, loading the model and initializing the predictor pool.
*
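* <p>Configuration is read from the {@code @Value} annotations below; for example
* (the model URL and temp-dir values are illustrative, the defaults are real):
* <pre>{@code
* vectorizer.model.url=classpath:models/my-model.zip
* vectorizer.temp-dir=/tmp/vectorizer-models
* vectorizer.pool-size=10
* vectorizer.max-batch-size=64
* }</pre>
*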
* @param modelUrl The URL or resource path to the model (.zip format).
* @param tempDir A temporary directory for extracting models from JARs.
* @param poolSize The number of predictors to create in the pool (for instance reuse).
* @param maxBatchSize The maximum number of texts to process in a single batch prediction call.
* @param resourceLoader Micronaut resource loader for accessing model files.
* @throws ModelNotFoundException If the model cannot be found at the specified URL.
* @throws MalformedModelException If the model file is corrupted or invalid.
* @throws IOException If an I/O error occurs during model loading or extraction.
*/
@Inject
public SentenceVectorizer(@Value("${vectorizer.model.url}") String modelUrl,
@Value("${vectorizer.temp-dir}") String tempDir,
@Value("${vectorizer.pool-size:10}") int poolSize,
@Value("${vectorizer.max-batch-size:64}") int maxBatchSize,
ResourceLoader resourceLoader)
throws ModelNotFoundException, MalformedModelException, IOException {
this.resourceLoader = resourceLoader;
this.poolSize = poolSize;
this.maxBatchSize = maxBatchSize;
log.info("Initializing SentenceVectorizer with poolSize={}, maxBatchSize={}", poolSize, maxBatchSize);
log.info("Attempting to load model from configured URL/path: {}", modelUrl);
// --- Model Loading Logic ---
this.modelUrl = initializeModelPath(modelUrl, tempDir);
log.info("Resolved model path to: {}", this.modelUrl);
this.model = loadModel(this.modelUrl);
if (this.model == null) {
throw new IllegalStateException("Model could not be loaded, cannot initialize SentenceVectorizer.");
}
log.info("Model loaded: {}", this.model.getName());
// --- Predictor Pool Creation ---
this.predictorPool = createPredictorPool(this.model, this.poolSize);
log.info("Predictor pool initialized.");
log.info("SentenceVectorizer initialized successfully.");
}
/**
* Generates a vector embedding for a single text input using a predictor borrowed
* from the pool.
*
* @param text The input text; assumed non-null and non-empty per upstream guarantees.
* @return The embedding vector; an empty array if the input is empty, or {@code null}
*         if the thread was interrupted while waiting for a predictor.
* @throws RuntimeException If prediction fails with a {@link TranslateException}.
*/
@Override
@Nullable
public float[] embeddings(@NonNull String text) {
if (StringUtils.isEmpty(text)) {
log.warn("Received empty or null text in 'embeddings' method despite upstream guarantees. Returning empty array.");
return new float[0];
}
log.debug("Requesting predictor for single text embedding.");
Predictor<String, float[]> predictor = null;
try {
predictor = predictorPool.take();
log.debug("Predictor acquired. Vectorizing text (first 50 chars): '{}...'",
text.substring(0, Math.min(text.length(), 50)));
float[] response = predictor.predict(text);
log.debug("Single text embedding generated successfully.");
return response;
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.error("Thread interrupted while waiting for a predictor from the pool.", e);
return null;
} catch (TranslateException e) {
log.error("Error during single text prediction for text: '{}...'",
text.substring(0, Math.min(text.length(), 50)), e);
throw new RuntimeException("Embedding generation failed for the provided text.", e);
} finally {
if (predictor != null) {
if (!predictorPool.offer(predictor)) {
log.warn("Failed to return predictor to the pool (pool might be full unexpectedly). Closing predictor.");
predictor.close();
} else {
log.debug("Predictor returned to pool.");
}
}
}
}
/**
* Generates vector embeddings for a batch of text inputs SEQUENTIALLY.
* Assumes input texts are valid (non-null, non-empty).
* Divides inputs into batches and processes them one by one using a predictor from the pool.
* If a batch fails prediction or the thread is interrupted, the entries for that batch
* (and, on interruption, all remaining entries) will be null.
*
* @param texts The list of input texts (assumed valid). Cannot be null.
* @return A list of float arrays representing the embeddings. The list will have the same size
* as the input list. Entries corresponding to batches that failed prediction will be null.
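*
* <p>A caller-side sketch (variable names are illustrative) showing how per-batch
* failures surface as {@code null} entries:
* <pre>{@code
* List<float[]> results = vectorizer.batchEmbeddings(texts);
* for (int i = 0; i < results.size(); i++) {
*     if (results.get(i) == null) {
*         // the batch containing texts.get(i) failed; retry or skip this text
*     }
* }
* }</pre>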
*/
@Override
@NonNull
public List<float[]> batchEmbeddings(@NonNull List<String> texts) {
if (texts.isEmpty()) {
log.info("Input text list is empty, returning empty list.");
return Collections.emptyList();
}
log.info("Starting SEQUENTIAL batch embedding generation for {} texts using model '{}'. Max batch size: {}",
texts.size(), model.getName(), maxBatchSize);
// Pre-allocate result array
final float[][] embeddings = new float[texts.size()][];
// Process the texts sequentially in chunks (batches)
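// For example (illustrative numbers): with maxBatchSize = 64 and 150 texts, the loop
// produces batches covering indices [0-63], [64-127], and [128-149].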
for (int i = 0; i < texts.size(); i += maxBatchSize) {
final int batchStartIndex = i;
final int batchEndIndex = Math.min(batchStartIndex + maxBatchSize, texts.size());
// Sublist view is efficient
final List<String> currentBatchTexts = texts.subList(batchStartIndex, batchEndIndex);
log.debug("Processing batch sequentially for indices [{} - {}] (size: {})", batchStartIndex, batchEndIndex - 1, currentBatchTexts.size());
Predictor<String, float[]> predictor = null;
try {
// Take a predictor from the pool, waiting if necessary
// This ensures we don't create/destroy predictors constantly
predictor = predictorPool.take();
log.debug("Predictor acquired for sequential batch [{} - {}]", batchStartIndex, batchEndIndex - 1);
// Directly perform batch prediction - THIS CALL BLOCKS until the batch is done
List<float[]> batchResult = predictor.batchPredict(currentBatchTexts);
// Copy results back into the main embeddings array at the correct offset
for (int j = 0; j < batchResult.size(); j++) {
embeddings[batchStartIndex + j] = batchResult.get(j);
}
log.debug("Sequential batch [{} - {}] processed successfully.", batchStartIndex, batchEndIndex - 1);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.error("Thread interrupted while waiting for predictor or during prediction for sequential batch [{} - {}]. Marking batch as failed.", batchStartIndex, batchEndIndex - 1, e);
// Mark results in this batch as null
Arrays.fill(embeddings, batchStartIndex, batchEndIndex, null);
// Optionally break the loop if interruption should stop all processing
// break;
} catch (TranslateException e) {
// Catch DJL translation/prediction errors for the batch
log.error("Batch prediction failed via TranslateException for sequential batch [{} - {}]. Marking batch as failed. Error: {}", batchStartIndex, batchEndIndex - 1, e.getMessage(), e);
// Mark results in this batch as null
Arrays.fill(embeddings, batchStartIndex, batchEndIndex, null);
// Continue to the next batch
} catch (Exception e) {
// Catch any other unexpected exceptions during batch processing
log.error("Unexpected error during sequential batch prediction for indices [{} - {}]. Marking batch as failed.", batchStartIndex, batchEndIndex - 1, e);
Arrays.fill(embeddings, batchStartIndex, batchEndIndex, null);
// Continue to the next batch
} finally {
// Return the predictor to the pool for the next iteration or another request
if (predictor != null) {
if (!predictorPool.offer(predictor)) {
log.warn("Failed to return predictor to the pool after sequential batch [{} - {}]. Closing predictor.", batchStartIndex, batchEndIndex - 1);
predictor.close();
} else {
log.debug("Predictor returned to pool after sequential batch [{} - {}].", batchStartIndex, batchEndIndex - 1);
}
}
}
}
log.info("Sequential batch embedding generation finished for {} texts.", texts.size());
// Expose the result array as a list; note Arrays.asList returns a fixed-size view
// backed by the array (entries are embeddings, or null for failed batches)
return Arrays.asList(embeddings);
}
/**
* Convenience method that returns the embedding for a single text as a
* {@link Collection} of boxed {@link Float} values.
*/
@Override
@NonNull
public Collection<Float> getEmbeddings(@NonNull String text) {
float[] embeddingArray = embeddings(text);
if (embeddingArray == null || embeddingArray.length == 0) {
return Collections.emptyList();
}
List<Float> embeddingList = new ArrayList<>(embeddingArray.length);
for (float value : embeddingArray) {
embeddingList.add(value);
}
return embeddingList;
}
// --- Helper Methods ---
/**
* Resolves the configured model path. Classpath ZIP resources packaged inside a JAR
* are extracted to the temp directory; other paths and URLs are returned unchanged.
*/
private String initializeModelPath(String configuredModelUrl, String tempDir) throws IOException {
log.debug("Initializing model path for: {}", configuredModelUrl);
if (StringUtils.isEmpty(configuredModelUrl)) {
throw new IllegalArgumentException("Model URL/path ('vectorizer.model.url') cannot be empty.");
}
if (configuredModelUrl.endsWith(".zip")) {
Optional<URL> modelOpt = resourceLoader.getResource(configuredModelUrl);
if (modelOpt.isPresent()) {
URL modelUrlFull = modelOpt.get();
log.debug("Found model resource via loader: {}", modelUrlFull);
if ("jar".equals(modelUrlFull.getProtocol())) {
log.info("Model resource is inside a JAR. Extracting '{}' to temporary directory '{}'", configuredModelUrl, tempDir);
return extractResourceFromJar(configuredModelUrl, tempDir).toString();
} else {
return modelUrlFull.toString();
}
} else {
log.warn("Model resource '{}' not found via ResourceLoader. Assuming it's a direct file path or URL.", configuredModelUrl);
return configuredModelUrl;
}
} else {
log.info("Model path '{}' does not end with .zip. Assuming it's a directory or pre-loaded model URL.", configuredModelUrl);
return configuredModelUrl;
}
}
/**
* Loads the DJL ZooModel based on the resolved model URL.
*/
private ZooModel<String, float[]> loadModel(String resolvedModelUrl) throws ModelNotFoundException, MalformedModelException, IOException {
Device device = selectDevice();
log.info("Attempting to load model from {} onto device {}", resolvedModelUrl, device);
Criteria<String, float[]> criteria = Criteria.builder()
.setTypes(String.class, float[].class)
.optModelUrls(resolvedModelUrl)
.optEngine("PyTorch")
.optDevice(device)
.optTranslatorFactory(new TextEmbeddingTranslatorFactory())
.optProgress(new ProgressBar())
.build();
log.info("Loading model via DJL criteria...");
return criteria.loadModel();
}
/**
* Selects the best available device (GPU if available, otherwise CPU).
*/
private Device selectDevice() {
if (Engine.getInstance().getGpuCount() > 0) {
log.info("GPU detected ({} available). Selecting GPU device.", Engine.getInstance().getGpuCount());
return Device.gpu();
}
log.info("No GPU detected. Selecting CPU device.");
return Device.cpu();
}
/**
* Creates a fixed-size pool of predictors backed by an {@link ArrayBlockingQueue}.
* Because the queue capacity equals the pool size, returning a borrowed predictor
* via {@code offer} is expected to always succeed.
*/
private BlockingQueue<Predictor<String, float[]>> createPredictorPool(ZooModel<String, float[]> model, int poolSize) {
BlockingQueue<Predictor<String, float[]>> pool = new ArrayBlockingQueue<>(poolSize);
log.info("Creating predictor pool with size: {}", poolSize);
try {
for (int i = 0; i < poolSize; i++) {
Predictor<String, float[]> predictor = model.newPredictor();
if (!pool.offer(predictor)) {
log.error("Failed to add predictor {} to the pool (unexpected). Closing predictor.", i);
predictor.close();
}
}
log.info("Predictor pool populated successfully.");
} catch (Exception e) {
log.error("Failed to create predictors for the pool.", e);
pool.forEach(Predictor::close);
pool.clear();
throw new RuntimeException("Failed to initialize predictor pool", e);
}
return pool;
}
/**
* Extracts a resource from within a JAR file to a specified target directory.
*/
private URL extractResourceFromJar(String resourcePath, String targetDirectory) throws IOException {
Optional<InputStream> resourceStreamOpt = resourceLoader.getResourceAsStream(resourcePath);
if (resourceStreamOpt.isEmpty()) {
throw new IOException("Resource not found inside JAR: " + resourcePath);
}
Path targetDirPath = Paths.get(targetDirectory);
if (!Files.exists(targetDirPath)) {
log.debug("Creating temporary directory for JAR extraction: {}", targetDirPath);
Files.createDirectories(targetDirPath);
} else if (!Files.isDirectory(targetDirPath)) {
throw new IOException("Target path for extraction exists but is not a directory: " + targetDirectory);
}
String fileName = Paths.get(resourcePath).getFileName().toString();
if (StringUtils.isEmpty(fileName)) {
throw new IOException("Could not determine filename from resource path: " + resourcePath);
}
Path targetFilePath = targetDirPath.resolve(fileName);
log.info("Extracting resource '{}' from JAR to '{}'", resourcePath, targetFilePath);
try (InputStream resourceStream = resourceStreamOpt.get();
FileOutputStream outputStream = new FileOutputStream(targetFilePath.toFile())) {
IOUtils.copy(resourceStream, outputStream);
} catch (IOException e) {
log.error("Failed to copy resource '{}' to '{}'", resourcePath, targetFilePath, e);
try { Files.deleteIfExists(targetFilePath); } catch (IOException cleanupEx) { log.warn("Failed to delete partially extracted file: {}", targetFilePath, cleanupEx); }
throw e;
}
log.info("Resource extracted successfully to: {}", targetFilePath);
return targetFilePath.toUri().toURL();
}
/**
* Cleans up resources on bean destruction: closes any predictors remaining in the
* pool, then closes the model.
*/
@PreDestroy
public void close() {
log.info("Closing SentenceVectorizer resources...");
// 1. Close predictors in the pool
if (predictorPool != null) {
int count = predictorPool.size();
log.debug("Closing {} predictors remaining in the pool...", count);
List<Predictor<String,float[]>> predictorsToClose = new ArrayList<>();
predictorPool.drainTo(predictorsToClose);
predictorsToClose.forEach(predictor -> {
try { predictor.close(); } catch (Exception e) { log.warn("Error closing a predictor instance.", e); }
});
predictorPool.clear();
log.debug("Predictor pool cleared and {} predictors closed.", predictorsToClose.size());
} else {
log.debug("Predictor pool is null.");
}
// 2. Close the model
if (model != null) {
try {
log.debug("Closing DJL model: {}", model.getName());
model.close();
log.info("Model closed successfully: {}", model.getName());
} catch (Exception e) {
log.error("Error closing the DJL model: {}", model.getName(), e);
}
} else {
log.debug("Model is null, nothing to close.");
}
log.info("SentenceVectorizer closed.");
}
}