Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Issue #11398 - allow frames to be demanded in WebSocket onOpen #11402

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ private enum DemandState
private final Flusher flusher;
private final Random random;
private DemandState demand = DemandState.NOT_DEMANDING;
private boolean fillingAndParsing;
private boolean fillingAndParsing = true;
private final LongAdder messagesIn = new LongAdder();
private final LongAdder bytesIn = new LongAdder();
// Read / Parse variables
Expand Down Expand Up @@ -199,11 +199,6 @@ public void setUseOutputDirectByteBuffers(boolean useOutputDirectByteBuffers)
this.useOutputDirectByteBuffers = useOutputDirectByteBuffers;
}

/**
* Physical connection disconnect.
* <p>
* Not related to WebSocket close handshake.
*/
@Override
public void onClose(Throwable cause)
{
Expand Down Expand Up @@ -236,11 +231,6 @@ public boolean onIdleExpired(TimeoutException timeoutException)
return true;
}

/**
* Event for no activity on connection (read or write)
*
* @return true to signal that the endpoint must be closed, false to keep the endpoint open
*/
@Override
protected boolean onReadTimeout(TimeoutException timeout)
{
Expand Down Expand Up @@ -394,7 +384,7 @@ public boolean moreDemand()
case NOT_DEMANDING ->
{
fillingAndParsing = false;
if (!networkBuffer.hasRemaining())
if (networkBuffer != null && !networkBuffer.hasRemaining())
releaseNetworkBuffer();
return false;
}
Expand Down Expand Up @@ -530,9 +520,6 @@ protected void setInitialBuffer(ByteBuffer initialBuffer)
BufferUtil.flipToFlush(buffer, 0);
}

/**
* Physical connection Open.
*/
@Override
public void onOpen()
{
Expand All @@ -542,6 +529,8 @@ public void onOpen()
// Open Session
super.onOpen();
coreSession.onOpen();
if (moreDemand())
fillAndParse();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,31 @@
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Frame;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.core.CloseStatus;
import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.server.WebSocketUpgradeHandler;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.awaitility.Awaitility.await;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -74,6 +81,37 @@ public void onWebSocketFrame(Frame frame, Callback callback)
}
}

@WebSocket(autoDemand = false)
public static class OnOpenSocket implements Session.Listener
{
CountDownLatch onOpen = new CountDownLatch(1);
BlockingQueue<String> textMessages = new BlockingArrayQueue<>();
Session session;

@Override
public void onWebSocketOpen(Session session)
{
try
{
this.session = session;
session.demand();
onOpen.await();
}
catch (InterruptedException e)
{
throw new RuntimeException(e);
}
}

@Override
public void onWebSocketFrame(Frame frame, Callback callback)
{
if (frame.getOpCode() == OpCode.TEXT)
textMessages.add(BufferUtil.toString(frame.getPayload()));
callback.succeed();
}
}

@WebSocket(autoDemand = false)
public static class PingSocket extends ListenerSocket
{
Expand All @@ -99,6 +137,7 @@ public void onWebSocketFrame(Frame frame, Callback callback)
private final WebSocketClient client = new WebSocketClient();
private final SuspendSocket serverSocket = new SuspendSocket();
private final ListenerSocket listenerSocket = new ListenerSocket();
private final OnOpenSocket onOpenSocket = new OnOpenSocket();
private final PingSocket pingSocket = new PingSocket();
private ServerConnector connector;

Expand All @@ -113,6 +152,7 @@ public void start() throws Exception
container.addMapping("/suspend", (rq, rs, cb) -> serverSocket);
container.addMapping("/listenerSocket", (rq, rs, cb) -> listenerSocket);
container.addMapping("/ping", (rq, rs, cb) -> pingSocket);
container.addMapping("/onOpen", (rq, rs, cb) -> onOpenSocket);
});

server.setHandler(wsHandler);
Expand Down Expand Up @@ -213,4 +253,27 @@ public void testServerPing() throws Exception
frame = pingSocket.frames.get(2);
assertThat(frame.getType(), is(Frame.Type.CLOSE));
}

@Test
public void testDemandInOnOpen() throws Exception
{
URI uri = new URI("ws://localhost:" + connector.getLocalPort() + "/onOpen");
EventSocket clientSocket = new EventSocket();

Future<Session> connect = client.connect(clientSocket, uri);
Session session = connect.get(5, TimeUnit.SECONDS);
session.sendText("test-text", Callback.NOOP);

// We cannot receive messages while in onOpen, even if we have demanded.
assertNull(onOpenSocket.textMessages.poll(1, TimeUnit.SECONDS));

// Once we leave onOpen we receive the message.
onOpenSocket.onOpen.countDown();
String received = onOpenSocket.textMessages.poll(5, TimeUnit.SECONDS);
assertThat(received, equalTo("test-text"));

session.close();
assertTrue(clientSocket.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(clientSocket.closeCode, equalTo(CloseStatus.NORMAL));
}
}
Loading