18
18
19
19
import static com .google .common .base .Preconditions .checkNotNull ;
20
20
import static com .google .common .base .Throwables .throwIfUnchecked ;
21
+ import static java .lang .String .format ;
22
+ import static java .nio .charset .StandardCharsets .UTF_8 ;
21
23
import static java .util .Optional .ofNullable ;
24
+ import static java .util .logging .Logger .getLogger ;
25
+ import static org .openqa .selenium .remote .DriverCommand .NEW_SESSION ;
22
26
23
27
import com .google .common .base .Supplier ;
24
28
import com .google .common .base .Throwables ;
25
29
30
+ import com .google .common .io .CountingOutputStream ;
31
+ import com .google .common .io .FileBackedOutputStream ;
32
+
33
+ import org .openqa .selenium .Capabilities ;
34
+ import org .openqa .selenium .ImmutableCapabilities ;
35
+ import org .openqa .selenium .SessionNotCreatedException ;
26
36
import org .openqa .selenium .WebDriverException ;
27
37
import org .openqa .selenium .remote .Command ;
28
38
import org .openqa .selenium .remote .CommandCodec ;
29
39
import org .openqa .selenium .remote .CommandInfo ;
40
+ import org .openqa .selenium .remote .Dialect ;
30
41
import org .openqa .selenium .remote .DriverCommand ;
31
42
import org .openqa .selenium .remote .HttpCommandExecutor ;
43
+ import org .openqa .selenium .remote .ProtocolHandshake ;
32
44
import org .openqa .selenium .remote .Response ;
45
+ import org .openqa .selenium .remote .ResponseCodec ;
33
46
import org .openqa .selenium .remote .http .HttpClient ;
34
47
import org .openqa .selenium .remote .http .HttpRequest ;
48
+ import org .openqa .selenium .remote .http .HttpResponse ;
35
49
import org .openqa .selenium .remote .http .W3CHttpCommandCodec ;
36
50
import org .openqa .selenium .remote .service .DriverService ;
37
51
52
+ import java .io .BufferedInputStream ;
38
53
import java .io .IOException ;
54
+ import java .io .InputStream ;
55
+ import java .io .OutputStreamWriter ;
56
+ import java .io .Writer ;
39
57
import java .lang .reflect .Field ;
58
+ import java .lang .reflect .InvocationTargetException ;
59
+ import java .lang .reflect .Method ;
40
60
import java .net .ConnectException ;
41
61
import java .net .URL ;
42
62
import java .util .Map ;
@@ -111,6 +131,76 @@ private void setCommandCodec(CommandCodec<HttpRequest> newCodec) {
111
131
setPrivateFieldValue ("commandCodec" , newCodec );
112
132
}
113
133
134
+ private void setResponseCodec (ResponseCodec <HttpResponse > codec ) {
135
+ setPrivateFieldValue ("responseCodec" , codec );
136
+ }
137
+
138
+ private HttpClient getClient () {
139
+ //noinspection unchecked
140
+ return getPrivateFieldValue ("client" , HttpClient .class );
141
+ }
142
+
143
+ private Response createSession (Command command ) throws IOException {
144
+ if (getCommandCodec () != null ) {
145
+ throw new SessionNotCreatedException ("Session already exists" );
146
+ }
147
+ ProtocolHandshake handshake = new ProtocolHandshake () {
148
+ @ SuppressWarnings ("unchecked" )
149
+ public Result createSession (HttpClient client , Command command )
150
+ throws IOException {
151
+ Capabilities desiredCapabilities = (Capabilities ) command .getParameters ().get ("desiredCapabilities" );
152
+ Capabilities desired = desiredCapabilities == null ? new ImmutableCapabilities () : desiredCapabilities ;
153
+
154
+ //the number of bytes before the stream should switch to buffering to a file
155
+ int threshold = (int ) Math .min (Runtime .getRuntime ().freeMemory () / 10 , Integer .MAX_VALUE );
156
+ FileBackedOutputStream os = new FileBackedOutputStream (threshold );
157
+ try {
158
+
159
+ CountingOutputStream counter = new CountingOutputStream (os );
160
+ Writer writer = new OutputStreamWriter (counter , UTF_8 );
161
+ NewAppiumSessionPayload payload = NewAppiumSessionPayload .create (desired );
162
+ payload .writeTo (writer );
163
+
164
+ try (InputStream rawIn = os .asByteSource ().openBufferedStream ();
165
+ BufferedInputStream contentStream = new BufferedInputStream (rawIn )) {
166
+
167
+ Method createSessionMethod = this .getClass ().getSuperclass ()
168
+ .getDeclaredMethod ("createSession" , HttpClient .class , InputStream .class , long .class );
169
+ createSessionMethod .setAccessible (true );
170
+
171
+ Optional <Result > result = (Optional <Result >) createSessionMethod
172
+ .invoke (this , client , contentStream , counter .getCount ());
173
+
174
+ return result .map (result1 -> {
175
+ Result toReturn = result .get ();
176
+ getLogger (ProtocolHandshake .class .getName ())
177
+ .info (format ("Detected dialect: %s" , toReturn .getDialect ()));
178
+ return toReturn ;
179
+ }).orElseThrow (() -> new SessionNotCreatedException (
180
+ format ("Unable to create new remote session. desired capabilities = %s" , desired )));
181
+ } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e ) {
182
+ throw new WebDriverException (format ("It is impossible to create a new session "
183
+ + "because 'createSession' which takes %s, %s and %s was not found "
184
+ + "or it is not accessible" ,
185
+ HttpClient .class .getSimpleName (),
186
+ InputStream .class .getSimpleName (),
187
+ long .class .getSimpleName ()), e );
188
+ }
189
+ } finally {
190
+ os .reset ();
191
+ }
192
+ }
193
+ };
194
+
195
+ ProtocolHandshake .Result result = handshake
196
+ .createSession (getClient (), command );
197
+ Dialect dialect = result .getDialect ();
198
+ setCommandCodec (dialect .getCommandCodec ());
199
+ getAdditionalCommands ().forEach (this ::defineCommand );
200
+ setResponseCodec (dialect .getResponseCodec ());
201
+ return result .createResponse ();
202
+ }
203
+
114
204
@ Override
115
205
public Response execute (Command command ) throws WebDriverException {
116
206
if (DriverCommand .NEW_SESSION .equals (command .getName ())) {
@@ -125,7 +215,7 @@ public Response execute(Command command) throws WebDriverException {
125
215
126
216
Response response ;
127
217
try {
128
- response = super .execute (command );
218
+ response = NEW_SESSION . equals ( command . getName ()) ? createSession ( command ) : super .execute (command );
129
219
} catch (Throwable t ) {
130
220
Throwable rootCause = Throwables .getRootCause (t );
131
221
if (rootCause instanceof ConnectException
0 commit comments