22
22
import com .mongodb .connection .SslSettings ;
23
23
import com .mongodb .internal .TimeoutContext ;
24
24
import com .mongodb .internal .TimeoutSettings ;
25
- import org .junit .jupiter .api .AfterEach ;
26
- import org .junit .jupiter .api .BeforeEach ;
27
25
import org .junit .jupiter .params .ParameterizedTest ;
28
26
import org .junit .jupiter .params .provider .ValueSource ;
29
27
import org .mockito .MockedStatic ;
37
35
import java .nio .channels .SocketChannel ;
38
36
import java .util .concurrent .TimeUnit ;
39
37
38
+ import static com .mongodb .internal .connection .OperationContext .simpleOperationContext ;
40
39
import static java .lang .String .format ;
41
- import static java .util .concurrent .TimeUnit .SECONDS ;
42
- import static org .junit .Assert .assertThrows ;
40
+ import static java .util .concurrent .TimeUnit .MILLISECONDS ;
43
41
import static org .junit .jupiter .api .Assertions .assertFalse ;
44
42
import static org .junit .jupiter .api .Assertions .assertInstanceOf ;
45
43
import static org .junit .jupiter .api .Assertions .assertNotNull ;
44
+ import static org .junit .jupiter .api .Assertions .assertThrows ;
46
45
import static org .junit .jupiter .api .Assertions .assertTrue ;
47
46
import static org .junit .jupiter .api .Assertions .fail ;
48
47
import static org .mockito .Mockito .atLeast ;
51
50
class TlsChannelStreamFunctionalTest {
52
51
private static final SslSettings SSL_SETTINGS = SslSettings .builder ().enabled (true ).build ();
53
52
private static final String UNREACHABLE_PRIVATE_IP_ADDRESS = "10.255.255.1" ;
54
- private ServerSocket serverSocket ;
55
- private int port ;
56
-
57
- @ BeforeEach
58
- void setUp () throws IOException {
59
- serverSocket = new ServerSocket (0 , 1 );
60
- port = serverSocket .getLocalPort ();
61
- }
62
-
63
- @ AfterEach
64
- @ SuppressWarnings ("try" )
65
- void cleanUp () throws IOException {
66
- try (ServerSocket ignored = serverSocket ) {
67
- //ignored
68
- }
69
- }
53
+ private static final int UNREACHABLE_PORT = 65333 ;
70
54
71
55
@ ParameterizedTest
72
56
@ ValueSource (ints = {500 , 1000 , 2000 })
73
- void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires (final int connectTimeout ) throws IOException {
57
+ void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires (final int connectTimeoutMs ) throws IOException {
74
58
//given
75
- try (TlsChannelStreamFactoryFactory factory = new TlsChannelStreamFactoryFactory (new DefaultInetAddressResolver ());
59
+ try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory (new DefaultInetAddressResolver ());
76
60
MockedStatic <SocketChannel > socketChannelMockedStatic = Mockito .mockStatic (SocketChannel .class )) {
77
61
SingleResultSpyCaptor <SocketChannel > singleResultSpyCaptor = new SingleResultSpyCaptor <>();
78
62
socketChannelMockedStatic .when (SocketChannel ::open ).thenAnswer (singleResultSpyCaptor );
79
63
80
- StreamFactory streamFactory = factory .create (SocketSettings .builder ()
81
- .connectTimeout (connectTimeout , TimeUnit .MILLISECONDS )
64
+ StreamFactory streamFactory = streamFactoryFactory .create (SocketSettings .builder ()
65
+ .connectTimeout (connectTimeoutMs , TimeUnit .MILLISECONDS )
82
66
.build (), SSL_SETTINGS );
83
67
84
- Stream stream = streamFactory .create (new ServerAddress (UNREACHABLE_PRIVATE_IP_ADDRESS , port ));
68
+ Stream stream = streamFactory .create (new ServerAddress (UNREACHABLE_PRIVATE_IP_ADDRESS , UNREACHABLE_PORT ));
85
69
long connectOpenStart = System .nanoTime ();
86
70
87
71
//when
72
+ OperationContext operationContext = createOperationContext (connectTimeoutMs );
88
73
MongoSocketOpenException mongoSocketOpenException = assertThrows (MongoSocketOpenException .class , () ->
89
- stream .open (OperationContext
90
- .simpleOperationContext (new TimeoutContext (TimeoutSettings .DEFAULT
91
- .withConnectTimeoutMS (connectTimeout )))));
74
+ stream .open (operationContext ));
92
75
93
76
//then
94
77
long elapsedMs = TimeUnit .NANOSECONDS .toMillis (System .nanoTime () - connectOpenStart );
95
78
// Allow for some timing imprecision due to test overhead.
96
79
int maximumAcceptableTimeoutOvershoot = 300 ;
97
80
98
- assertInstanceOf (InterruptedByTimeoutException .class , mongoSocketOpenException .getCause (),
99
- "Actual cause: " + mongoSocketOpenException .getCause ());
100
- assertFalse (connectTimeout > elapsedMs ,
101
- format ("Connection timed-out sooner than expected. ConnectTimeoutMS: %d, elapsedMs: %d" , connectTimeout , elapsedMs ));
102
- assertTrue (elapsedMs - connectTimeout < maximumAcceptableTimeoutOvershoot ,
103
- format ("Connection timeout overshoot time %d ms should be within %d ms" , elapsedMs - connectTimeout ,
81
+ assertInstanceOf (InterruptedByTimeoutException .class , mongoSocketOpenException .getCause ());
82
+ assertFalse (connectTimeoutMs > elapsedMs ,
83
+ format ("Connection timed-out sooner than expected. ConnectTimeoutMS: %d, elapsedMs: %d" , connectTimeoutMs , elapsedMs ));
84
+ assertTrue (elapsedMs - connectTimeoutMs < maximumAcceptableTimeoutOvershoot ,
85
+ format ("Connection timeout overshoot time %d ms should be within %d ms" , elapsedMs - connectTimeoutMs ,
104
86
maximumAcceptableTimeoutOvershoot ));
105
87
106
88
SocketChannel actualSpySocketChannel = singleResultSpyCaptor .getResult ();
@@ -111,30 +93,30 @@ void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires(final in
111
93
112
94
@ ParameterizedTest
113
95
@ ValueSource (ints = {0 , 500 , 1000 , 2000 })
114
- void shouldEstablishConnection (final int connectTimeout ) throws IOException , InterruptedException {
96
+ void shouldEstablishConnection (final int connectTimeoutMs ) throws IOException , InterruptedException {
115
97
//given
116
- try (TlsChannelStreamFactoryFactory factory = new TlsChannelStreamFactoryFactory (new DefaultInetAddressResolver ());
117
- MockedStatic <SocketChannel > socketChannelMockedStatic = Mockito .mockStatic (SocketChannel .class )) {
98
+ try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory (new DefaultInetAddressResolver ());
99
+ MockedStatic <SocketChannel > socketChannelMockedStatic = Mockito .mockStatic (SocketChannel .class );
100
+ ServerSocket serverSocket = new ServerSocket (0 , 1 )) {
118
101
SingleResultSpyCaptor <SocketChannel > singleResultSpyCaptor = new SingleResultSpyCaptor <>();
119
102
socketChannelMockedStatic .when (SocketChannel ::open ).thenAnswer (singleResultSpyCaptor );
120
103
121
- StreamFactory streamFactory = factory .create (SocketSettings .builder ()
122
- .connectTimeout (connectTimeout , TimeUnit .MILLISECONDS )
104
+ StreamFactory streamFactory = streamFactoryFactory .create (SocketSettings .builder ()
105
+ .connectTimeout (connectTimeoutMs , TimeUnit .MILLISECONDS )
123
106
.build (), SSL_SETTINGS );
124
107
125
- Stream stream = streamFactory .create (new ServerAddress (serverSocket .getInetAddress (), port ));
108
+ Stream stream = streamFactory .create (new ServerAddress (serverSocket .getInetAddress (), serverSocket . getLocalPort () ));
126
109
try {
127
110
//when
128
- stream .open (OperationContext .simpleOperationContext (
129
- new TimeoutContext (TimeoutSettings .DEFAULT .withConnectTimeoutMS (connectTimeout ))));
111
+ stream .open (createOperationContext (connectTimeoutMs ));
130
112
131
113
//then
132
114
SocketChannel actualSpySocketChannel = singleResultSpyCaptor .getResult ();
133
115
assertNotNull (actualSpySocketChannel , "SocketChannel was not opened" );
134
116
assertTrue (actualSpySocketChannel .isConnected ());
135
117
136
118
// Wait to verify that socket was not closed by timeout.
137
- SECONDS .sleep (3 );
119
+ MILLISECONDS .sleep (connectTimeoutMs * 2L );
138
120
assertTrue (actualSpySocketChannel .isConnected ());
139
121
assertFalse (stream .isClosed ());
140
122
} finally {
@@ -151,7 +133,7 @@ public T getResult() {
151
133
}
152
134
153
135
@ Override
154
- public T answer (InvocationOnMock invocationOnMock ) throws Throwable {
136
+ public T answer (final InvocationOnMock invocationOnMock ) throws Throwable {
155
137
if (result != null ) {
156
138
fail (invocationOnMock .getMethod ().getName () + " was called more then once" );
157
139
}
@@ -160,4 +142,8 @@ public T answer(InvocationOnMock invocationOnMock) throws Throwable {
160
142
return result ;
161
143
}
162
144
}
145
+
146
+ private static OperationContext createOperationContext (final int connectTimeoutMs ) {
147
+ return simpleOperationContext (new TimeoutContext (TimeoutSettings .DEFAULT .withConnectTimeoutMS (connectTimeoutMs )));
148
+ }
163
149
}
0 commit comments