diff --git a/src/main/java/com/databricks/zerobus/ZerobusSdk.java b/src/main/java/com/databricks/zerobus/ZerobusSdk.java index f086137..a8b5429 100644 --- a/src/main/java/com/databricks/zerobus/ZerobusSdk.java +++ b/src/main/java/com/databricks/zerobus/ZerobusSdk.java @@ -10,6 +10,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -160,7 +161,7 @@ public CompletableFuture> logger.debug("Creating stream for table: " + tableProperties.getTableName()); // Create a token supplier that generates a fresh token for each gRPC request - java.util.function.Supplier tokenSupplier = + Supplier tokenSupplier = () -> { try { return TokenFactory.getZerobusToken( @@ -174,14 +175,15 @@ public CompletableFuture> } }; - // Create gRPC stub once with token supplier - it will fetch fresh tokens as needed - ZerobusGrpc.ZerobusStub stub = - stubFactory.createStubWithTokenSupplier( - serverEndpoint, tableProperties.getTableName(), tokenSupplier); + // Create a stub supplier that generates a fresh stub with token supplier each time + Supplier stubSupplier = + () -> + stubFactory.createStubWithTokenSupplier( + serverEndpoint, tableProperties.getTableName(), tokenSupplier); ZerobusStream stream = new ZerobusStream<>( - stub, + stubSupplier, tableProperties, stubFactory, serverEndpoint, diff --git a/src/main/java/com/databricks/zerobus/ZerobusStream.java b/src/main/java/com/databricks/zerobus/ZerobusStream.java index 8abdd39..1ba47f5 100644 --- a/src/main/java/com/databricks/zerobus/ZerobusStream.java +++ b/src/main/java/com/databricks/zerobus/ZerobusStream.java @@ -18,6 +18,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -126,6 +127,7 @@ public class ZerobusStream { private static final int CREATE_STREAM_TIMEOUT_MS = 15000; private ZerobusStub stub; + private final Supplier stubSupplier; final TableProperties tableProperties; private final ZerobusSdkStubFactory stubFactory; private final String serverEndpoint; @@ -352,10 +354,10 @@ private CompletableFuture createStream() { () -> { CompletableFuture createStreamTry = new CompletableFuture<>(); - // The stub was created once with a token supplier, so we don't recreate it here - // The token supplier will provide a fresh token for each gRPC request + // Get a fresh stub from the supplier + stub = stubSupplier.get(); - // Create the gRPC stream with the existing stub + // Create the gRPC stream with the fresh stub streamCreatedEvent = Optional.of(new CompletableFuture<>()); stream = Optional.of( @@ -500,6 +502,9 @@ private void closeStream(boolean hardFailure, Optional excepti try { if (stream.isPresent()) { stream.get().onCompleted(); + if (hardFailure) { + stream.get().cancel("Stream closed", null); + } } } catch (Exception e) { // Ignore errors during stream cleanup - stream may already be closed @@ -528,6 +533,7 @@ private void closeStream(boolean hardFailure, Optional excepti stream = Optional.empty(); streamCreatedEvent = Optional.empty(); streamId = Optional.empty(); + stub = null; this.notifyAll(); } @@ -1073,6 +1079,7 @@ public void onNext(EphemeralStreamResponse response) { String.format( "Server will close the stream in %.3fms. Triggering stream recovery.", durationMs)); + streamFailureInfo.resetFailure(StreamFailureType.SERVER_CLOSED_STREAM); handleStreamFailed(StreamFailureType.SERVER_CLOSED_STREAM, Optional.empty()); } break; @@ -1085,6 +1092,13 @@ public void onNext(EphemeralStreamResponse response) { @Override public void onError(Throwable t) { + synchronized (ZerobusStream.this) { + if (state == StreamState.CLOSED && !stream.isPresent()) { + logger.debug("Ignoring error on already closed stream: " + t.getMessage()); + return; + } + } + Optional error = Optional.of(t); if (t instanceof StatusRuntimeException) { @@ -1336,7 +1350,7 @@ public void close() throws ZerobusException { } public ZerobusStream( - ZerobusStub stub, + Supplier stubSupplier, TableProperties tableProperties, ZerobusSdkStubFactory stubFactory, String serverEndpoint, @@ -1347,7 +1361,8 @@ public ZerobusStream( StreamConfigurationOptions options, ExecutorService zerobusStreamExecutor, ExecutorService ec) { - this.stub = stub; + this.stub = null; + this.stubSupplier = stubSupplier; this.tableProperties = tableProperties; this.stubFactory = stubFactory; this.serverEndpoint = serverEndpoint; diff --git a/src/test/java/com/databricks/zerobus/ZerobusSdkTest.java b/src/test/java/com/databricks/zerobus/ZerobusSdkTest.java index d184405..6a0ec76 100644 --- a/src/test/java/com/databricks/zerobus/ZerobusSdkTest.java +++ b/src/test/java/com/databricks/zerobus/ZerobusSdkTest.java @@ -38,6 +38,7 @@ public class ZerobusSdkTest { private ZerobusSdk zerobusSdk; private ZerobusSdkStubFactory zerobusSdkStubFactory; private org.mockito.MockedStatic tokenFactoryMock; + private io.grpc.stub.ClientCallStreamObserver spiedStream; @BeforeEach public void setUp() { @@ -76,7 +77,10 @@ public void setUp() { (StreamObserver) invocation.getArgument(0); mockedGrpcServer.initialize(ackSender); - return mockedGrpcServer.getMessageReceiver(); + + // Spy on the message receiver to verify cancel() is called + spiedStream = spy(mockedGrpcServer.getMessageReceiver()); + return spiedStream; }) .when(zerobusStub) .ephemeralStream(any()); @@ -378,4 +382,40 @@ public void testCallbackExceptionHandling() throws Exception { stream.close(); assertEquals(StreamState.CLOSED, stream.getState()); } + + @Test + public void testGrpcStreamIsCancelledOnClose() throws Exception { + // Test that the underlying gRPC stream is properly cancelled when stream.close() is called + mockedGrpcServer.injectAckRecord(0); + + TableProperties tableProperties = + new TableProperties<>("test-table", CityPopulationTableRow.getDefaultInstance()); + StreamConfigurationOptions options = + StreamConfigurationOptions.builder().setRecovery(false).build(); + + ZerobusStream stream = + zerobusSdk.createStream(tableProperties, "client-id", "client-secret", options).get(); + + assertEquals(StreamState.OPENED, stream.getState()); + + // Ingest one record + CompletableFuture writeCompleted = + stream.ingestRecord( + CityPopulationTableRow.newBuilder() + .setCityName("test-city") + .setPopulation(1000) + .build()); + + writeCompleted.get(5, TimeUnit.SECONDS); + + // Close the stream + stream.close(); + assertEquals(StreamState.CLOSED, stream.getState()); + + // Verify that cancel() was called on the gRPC stream + verify(spiedStream, times(1)).cancel(anyString(), any()); + + // Also verify onCompleted() was called + verify(spiedStream, times(1)).onCompleted(); + } }