生产者——拦截器 - 969251639/study GitHub Wiki

Kafka生产者的拦截器使用起来十分简单,在构造生产者时可以通过以下配置来设置该生产者的拦截器:

// Producer configuration: "interceptor.classes" lists the interceptor implementation class(es).
Properties props = new Properties();
...
props.put("interceptor.classes", "xxx");// "xxx" is a placeholder; separate multiple class names with commas
...
org.apache.kafka.clients.producer.KafkaProducer<Integer, String> producer = new org.apache.kafka.clients.producer.KafkaProducer<Integer, String>(props);

然后在生产者的构造方法中,根据该配置创建(实例化)所有拦截器:

// In the KafkaProducer constructor: instantiate every class listed under
// ProducerConfig.INTERCEPTOR_CLASSES_CONFIG as a ProducerInterceptor.
// The raw (List) cast is an unchecked conversion to the generic element type.
List<ProducerInterceptor<K, V>> interceptorList = (List) configWithClientId.getConfiguredInstances(
    ProducerConfig.INTERCEPTOR_CLASSES_CONFIG, ProducerInterceptor.class);

所以只需实现ProducerInterceptor接口及其方法,并把实现类配置到生产者中即可:
public interface ProducerInterceptor<K, V> extends Configurable { //发送前调用 public ProducerRecord<K, V> onSend(ProducerRecord<K, V> record); //服务器响应消息回调用户方法前或发送失败时 public void onAcknowledgement(RecordMetadata metadata, Exception exception); //拦截器关闭时调用 public void close(); }

接下来看ProducerInterceptor的三个调用时机

  1. 生产者调用send方法时,会先通过拦截器链的onSend方法处理消息,再把处理后的消息交给doSend发送
    // KafkaProducer.send: run the interceptor chain first, then send whatever record it returns.
    @Override
    public Future<RecordMetadata> send(ProducerRecord<K, V> record, Callback callback) {
        // intercept the record, which can be potentially modified; this method does not throw exceptions
        ProducerRecord<K, V> interceptedRecord = this.interceptors.onSend(record);
        return doSend(interceptedRecord, callback);
    }
    /**
     * ProducerInterceptors.onSend: passes the record through each configured interceptor in
     * order, giving each one the previous interceptor's output, and returns the final record.
     *
     * <p>An exception thrown by an interceptor is logged and swallowed. The chain continues with
     * the record as it was before that interceptor ran.
     */
    public ProducerRecord<K, V> onSend(ProducerRecord<K, V> record) {
        ProducerRecord<K, V> interceptRecord = record;
        for (ProducerInterceptor<K, V> interceptor : this.interceptors) {
            try {
                interceptRecord = interceptor.onSend(interceptRecord);// call this interceptor's onSend
            } catch (Exception e) {
                // do not propagate interceptor exception, log and continue calling other interceptors
                // be careful not to throw exception from here
                // Note: the log reports the original `record`, not the intercepted one.
                if (record != null)
                    log.warn("Error executing interceptor onSend callback for topic: {}, partition: {}", record.topic(), record.partition(), e);
                else
                    log.warn("Error executing interceptor onSend callback", e);
            }
        }
        return interceptRecord;
    }
  2. 在doSend中,生产者会把拦截器和用户的回调方法一起封装成InterceptorCallback,随消息一起追加到消息累加器(RecordAccumulator)中
// Excerpt of KafkaProducer.doSend (parts elided with ...). The user callback is wrapped together
// with the interceptors and the record's topic-partition in an InterceptorCallback. That wrapper,
// not the raw user callback, is what gets appended to the RecordAccumulator with the record.
private Future<RecordMetadata> doSend(ProducerRecord<K, V> record, Callback callback) {
    ...
    Callback interceptCallback = new InterceptorCallback<>(callback, this.interceptors, tp);
    ...
    RecordAccumulator.RecordAppendResult result = accumulator.append(tp, timestamp, serializedKey,
       serializedValue, headers, interceptCallback, remainingWaitMs);
    ...
}
    /**
     * Callback wrapper appended with each record. On completion it notifies the interceptors
     * (onAcknowledgement) first, and then the user's callback if one was supplied.
     */
    private static class InterceptorCallback<K, V> implements Callback {
        private final Callback userCallback;
        private final ProducerInterceptors<K, V> interceptors;
        private final TopicPartition tp;

        private InterceptorCallback(Callback userCallback, ProducerInterceptors<K, V> interceptors, TopicPartition tp) {
            this.userCallback = userCallback;
            this.interceptors = interceptors;
            this.tp = tp;
        }

        public void onCompletion(RecordMetadata metadata, Exception exception) {
            // metadata may be null (e.g. on failure). Substitute a placeholder for this record's
            // topic-partition (-1 / NO_TIMESTAMP fields) so neither interceptors nor the user callback get null.
            metadata = metadata != null ? metadata : new RecordMetadata(tp, -1, -1, RecordBatch.NO_TIMESTAMP, Long.valueOf(-1L), -1, -1);
            this.interceptors.onAcknowledgement(metadata, exception);// interceptors are notified first
            if (this.userCallback != null)// only if the user supplied a callback
                this.userCallback.onCompletion(metadata, exception);// then invoke the user's callback
        }
    }

最终,这个InterceptorCallback会和消息对应的Future一起封装成Thunk,保存在该消息所在批次(ProducerBatch)的thunks列表中

// Excerpt of ProducerBatch.tryAppend. The callback passed in (in this flow, the InterceptorCallback
// from doSend) is stored together with the record's future as a Thunk in this batch's thunks list.
public FutureRecordMetadata tryAppend(long timestamp, byte[] key, byte[] value, Header[] headers, Callback callback, long now) {
    ...
    thunks.add(new Thunk(callback, future));
    ...
}

Sender发送生产者数据时(此时还没有进行实际的IO操作)会调用sendProduceRequest,为该请求注册一个完成回调:

// Excerpt of Sender.sendProduceRequest. Builds the ClientRequest for one destination node and
// attaches a RequestCompletionHandler. When the response arrives, that handler passes it to
// handleProduceResponse together with the batches sent (recordsByPartition).
private void sendProduceRequest(long now, int destination, short acks, int timeout, List<ProducerBatch> batches) {
    ...
    RequestCompletionHandler callback = new RequestCompletionHandler() {
        public void onComplete(ClientResponse response) {
            handleProduceResponse(response, recordsByPartition, time.milliseconds());
        }
    };
    String nodeId = Integer.toString(destination);
    // NOTE(review): `acks != 0` presumably tells the client whether a response is expected (none with acks=0) — confirm.
    ClientRequest clientRequest = client.newClientRequest(nodeId, requestBuilder, now, acks != 0,
            requestTimeoutMs, callback);
    ...
}

然后会将创建的请求信息ClientRequest放到inFlightRequests中

    // Excerpt of NetworkClient.doSend. The ClientRequest (which carries the completion callback)
    // is wrapped in an InFlightRequest and queued in inFlightRequests until its response is received.
    private void doSend(ClientRequest clientRequest, boolean isInternalRequest, long now, AbstractRequest request) {
        ...
        InFlightRequest inFlightRequest = new InFlightRequest(
                clientRequest,
                header,
                isInternalRequest,
                request,
                send,
                now);
        this.inFlightRequests.add(inFlightRequest);
        ...
    }

收到响应后,NetworkClient的poll方法会依次调用一系列handle方法进行处理:

// NetworkClient.poll: performs network I/O and converts completed sends, receives, disconnections
// and timed-out requests into ClientResponse objects. It then runs their callbacks
// (completeResponses) and returns them.
@Override
    public List<ClientResponse> poll(long timeout, long now) {
        ensureActive();

        if (!abortedSends.isEmpty()) {
            // If there are aborted sends because of unsupported version exceptions or disconnects,
            // handle them immediately without waiting for Selector#poll.
            List<ClientResponse> responses = new ArrayList<>();
            handleAbortedSends(responses);
            completeResponses(responses);
            return responses;
        }

        long metadataTimeout = metadataUpdater.maybeUpdate(now);
        try {
            // Wait for I/O no longer than the smallest of the caller's timeout, the metadata
            // timeout and the default request timeout.
            this.selector.poll(Utils.min(timeout, metadataTimeout, defaultRequestTimeoutMs));
        } catch (IOException e) {
            log.error("Unexpected error during I/O", e);
        }

        // process completed actions
        long updatedNow = this.time.milliseconds();
        List<ClientResponse> responses = new ArrayList<>();
        handleCompletedSends(responses, updatedNow);
        // Produce responses are turned into ClientResponses here (via InFlightRequest.completed).
        handleCompletedReceives(responses, updatedNow);
        handleDisconnections(responses, updatedNow);
        handleConnections();
        handleInitiateApiVersionRequests(updatedNow);
        handleTimedOutRequests(responses, updatedNow);
        // Fire each response's callback (e.g. the handler registered in sendProduceRequest).
        completeResponses(responses);

        return responses;
    }

其中,handleCompletedReceives方法中的req.completed(body, now)会把InFlightRequest中保存的回调封装到ClientResponse中:

    // NetworkClient.handleCompletedReceives. For each received payload: take the matching in-flight
    // request for that node and parse the response. Then either handle it internally (metadata /
    // ApiVersions responses to internal requests) or add a ClientResponse built by req.completed(...).
    private void handleCompletedReceives(List<ClientResponse> responses, long now) {
        for (NetworkReceive receive : this.selector.completedReceives()) {
            String source = receive.source();
            // NOTE(review): completeNext presumably removes the oldest outstanding request for this node — confirm.
            InFlightRequest req = inFlightRequests.completeNext(source);
            Struct responseStruct = parseStructMaybeUpdateThrottleTimeMetrics(receive.payload(), req.header,
                throttleTimeSensor, now);
            if (log.isTraceEnabled()) {
                log.trace("Completed receive from node {} for {} with correlation id {}, received {}", req.destination,
                    req.header.apiKey(), req.header.correlationId(), responseStruct);
            }
            // If the received response includes a throttle delay, throttle the connection.
            AbstractResponse body = AbstractResponse.parseResponse(req.header.apiKey(), responseStruct);
            maybeThrottle(body, req.header.apiVersion(), req.destination, now);
            if (req.isInternalRequest && body instanceof MetadataResponse)
                metadataUpdater.handleCompletedMetadataResponse(req.header, now, (MetadataResponse) body);
            else if (req.isInternalRequest && body instanceof ApiVersionsResponse)
                handleApiVersionsResponse(responses, req, now, (ApiVersionsResponse) body);
            else
                // Everything else (e.g. produce responses) carries the request's callback forward.
                responses.add(req.completed(body, now));
        }
    }

        // InFlightRequest.completed: wraps the parsed response in a ClientResponse that carries
        // this request's header and completion callback.
        public ClientResponse completed(AbstractResponse response, long timeMs) {
            return new ClientResponse(header, callback, destination, createdTimeMs, timeMs,
                    false, null, null, response);
        }

最后,poll会调用completeResponses,依次触发这些ClientResponse中的回调:

    // NetworkClient.completeResponses: invokes each response's completion callback. An exception
    // thrown by one callback is logged and does not stop the remaining callbacks.
    private void completeResponses(List<ClientResponse> responses) {
        for (ClientResponse response : responses) {
            try {
                response.onComplete();
            } catch (Exception e) {
                log.error("Uncaught error in request completion:", e);
            }
        }
    }
    // ClientResponse.onComplete: forwards to the request's RequestCompletionHandler, if one was set.
    public void onComplete() {
        if (callback != null)
            callback.onComplete(this);
    }

这里的callback就是sendProduceRequest方法中注册的回调,它的onComplete会依次调用handleProduceResponse → completeBatch → ProducerBatch.done:

        // The handler registered in sendProduceRequest: routes the produce response to
        // handleProduceResponse along with the batches sent in that request.
        RequestCompletionHandler callback = new RequestCompletionHandler() {
            public void onComplete(ClientResponse response) {
                handleProduceResponse(response, recordsByPartition, time.milliseconds());
            }
        };
// Excerpt of Sender.handleProduceResponse. Completes each batch with its partition response;
// `batch` and `partResp` presumably come from an elided per-partition loop — confirm.
private void handleProduceResponse(ClientResponse response, Map<TopicPartition, ProducerBatch> batches, long now) {
    ...
    completeBatch(batch, partResp, correlationId, now, receivedTimeMs + produceResponse.throttleTimeMs());
    ...
}
// Excerpt of Sender.completeBatch (5-argument overload). On the path shown it delegates to the
// 2-argument completeBatch; other branches are elided.
private void completeBatch(ProducerBatch batch, ProduceResponse.PartitionResponse response, long correlationId,
                               long now, long throttleUntilTimeMs) {
    ...
    completeBatch(batch, response);
    ...
}
// Excerpt of Sender.completeBatch (2-argument overload). Marks the batch done with the broker's
// base offset and log-append time and no exception. If this call finalized the batch, it is removed
// from the in-flight batches and deallocated in the accumulator.
private void completeBatch(ProducerBatch batch, ProduceResponse.PartitionResponse response) {
    ...

    if (batch.done(response.baseOffset, response.logAppendTime, null)) {
        maybeRemoveFromInflightBatches(batch);
        this.accumulator.deallocate(batch);
    }
    ...
}
// Excerpt of ProducerBatch.done. Only the caller whose compareAndSet moves finalState from null
// completes the future, fires the callbacks and returns true. This keeps callbacks from firing twice.
public boolean done(long baseOffset, long logAppendTime, RuntimeException exception) {
    ...
    if (this.finalState.compareAndSet(null, tryFinalState)) {
        completeFutureAndFireCallbacks(baseOffset, logAppendTime, exception);
        return true;
    }
    ...
}
// Excerpt of ProducerBatch.completeFutureAndFireCallbacks. Runs every Thunk's callback (in this flow,
// the InterceptorCallback). On success it passes the metadata from the thunk's future; on failure it
// passes null metadata plus the exception. A throwing callback is logged and does not stop the loop.
private void completeFutureAndFireCallbacks(long baseOffset, long logAppendTime, RuntimeException exception) {
    ...
    // execute callbacks
    for (Thunk thunk : thunks) {
        try {
            if (exception == null) {
                RecordMetadata metadata = thunk.future.value();
                if (thunk.callback != null)
                    thunk.callback.onCompletion(metadata, null);
            } else {
                if (thunk.callback != null)
                    thunk.callback.onCompletion(null, exception);
            }
        } catch (Exception e) {
            log.error("Error executing user-provided callback on message for topic-partition '{}'", topicPartition, e);
        }
    }

    ...
}

最终执行的是thunk中回调的onCompletion方法,也就是InterceptorCallback的onCompletion:它先调用拦截器的onAcknowledgement,再调用用户注册的回调

另外,如果doSend在发送过程中抛出异常(发送失败),也会通过interceptors.onSendError调用拦截器的onAcknowledgement方法:

// Excerpt of KafkaProducer.doSend's error handling. Every exception caught here is reported to the
// interceptors through onSendError (which calls onAcknowledgement). An ApiException is also passed
// to the user callback and returned as a failed future; the other exception types are rethrown.
private Future<RecordMetadata> doSend(ProducerRecord<K, V> record, Callback callback) {
    try {
        ...
    }catch (ApiException e) {
        log.debug("Exception occurred during message send:", e);
        // Note the order here: the user callback runs before the interceptors are notified.
        if (callback != null)
            callback.onCompletion(null, e);
        this.errors.record();
        this.interceptors.onSendError(record, tp, e);
        return new FutureFailure(e);
    } catch (InterruptedException e) {
        this.errors.record();
        this.interceptors.onSendError(record, tp, e);
        // Rethrown wrapped in Kafka's InterruptException.
        throw new InterruptException(e);
    } catch (BufferExhaustedException e) {
        this.errors.record();
        this.metrics.sensor("buffer-exhausted-records").record();
        this.interceptors.onSendError(record, tp, e);
        throw e;
    } catch (KafkaException e) {
        this.errors.record();
        this.interceptors.onSendError(record, tp, e);
        throw e;
    } catch (Exception e) {
        // we notify interceptor about all exceptions, since onSend is called before anything else in this method
        this.interceptors.onSendError(record, tp, e);
        throw e;
    }
}

    /**
     * ProducerInterceptors.onSendError: notifies every interceptor of a send failure through
     * onAcknowledgement.
     *
     * <p>If both record and topic-partition are null, the interceptor receives null metadata.
     * Otherwise it receives a placeholder RecordMetadata (-1 / NO_TIMESTAMP fields). When the
     * topic-partition is null, it is derived from the record, using UNKNOWN_PARTITION if the record
     * has no partition. Interceptor exceptions are logged, not propagated.
     */
    public void onSendError(ProducerRecord<K, V> record, TopicPartition interceptTopicPartition, Exception exception) {
        for (ProducerInterceptor<K, V> interceptor : this.interceptors) {
            try {
                if (record == null && interceptTopicPartition == null) {
                    interceptor.onAcknowledgement(null, exception);
                } else {
                    // Derived once; the reassigned parameter is reused for later interceptors.
                    if (interceptTopicPartition == null) {
                        interceptTopicPartition = new TopicPartition(record.topic(),
                                record.partition() == null ? RecordMetadata.UNKNOWN_PARTITION : record.partition());
                    }
                    interceptor.onAcknowledgement(new RecordMetadata(interceptTopicPartition, -1, -1,
                                    RecordBatch.NO_TIMESTAMP, Long.valueOf(-1L), -1, -1), exception);
                }
            } catch (Exception e) {
                // do not propagate interceptor exceptions, just log
                log.warn("Error executing interceptor onAcknowledgement callback", e);
            }
        }
    }
⚠️ **GitHub.com Fallback** ⚠️