diff --git a/fusionauth-netcore-client/src/io/fusionauth/DefaultRESTClient.cs b/fusionauth-netcore-client/src/io/fusionauth/DefaultRESTClient.cs index aaa5e34a..3b414d73 100644 --- a/fusionauth-netcore-client/src/io/fusionauth/DefaultRESTClient.cs +++ b/fusionauth-netcore-client/src/io/fusionauth/DefaultRESTClient.cs @@ -15,6 +15,7 @@ */ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Net.Http; @@ -27,19 +28,15 @@ using Newtonsoft.Json.Serialization; namespace io.fusionauth { - class DefaultRESTClient : IRESTClient { - public HttpClient httpClient; - - public HttpContent content; - - public string method = "GET"; - - public String uri = ""; - - public List> parameters = new List>(); - - public Dictionary headers = new Dictionary(); - + internal class DefaultRESTClient : IRESTClient { + private readonly HttpClient _httpClient; + private readonly List> _parameters = new List>(); + private readonly Dictionary _headers = new Dictionary(); + + private HttpContent _content; + private string _method = "GET"; + private string _uri = ""; + private static readonly JsonSerializerSettings SerializerSettings = new JsonSerializerSettings { NullValueHandling = NullValueHandling.Ignore, @@ -51,9 +48,11 @@ class DefaultRESTClient : IRESTClient { }, ContractResolver = new DefaultContractResolver() }; + + private static readonly ConcurrentDictionary HttpClients = new ConcurrentDictionary(); public DefaultRESTClient(string host) { - httpClient = new HttpClient {BaseAddress = new Uri(host)}; + _httpClient = GetOrCreateHttpClient(host); } /** @@ -75,11 +74,11 @@ public override IRESTClient withUriSegment(string segment) { return this; } - if (uri[uri.Length - 1] != '/') { - uri += '/'; + if (_uri[_uri.Length - 1] != '/') { + _uri += '/'; } - uri = uri + segment; + _uri += segment; return this; } @@ -90,7 +89,7 @@ public override IRESTClient withUriSegment(string segment) { * @param value The value of the header. */ public override IRESTClient withHeader(string key, string value) { - headers[key] = value; + _headers[key] = value; return this; } @@ -101,7 +100,7 @@ public override IRESTClient withHeader(string key, string value) { */ public override IRESTClient withFormData(FormUrlEncodedContent body) { - content = body; + _content = body; return this; } @@ -111,7 +110,7 @@ public override IRESTClient withFormData(FormUrlEncodedContent body) * @param body The object to be written to the request body as JSON. */ public override IRESTClient withJSONBody(object body) { - content = new StringContent(JsonConvert.SerializeObject(body, SerializerSettings), Encoding.UTF8, + _content = new StringContent(JsonConvert.SerializeObject(body, SerializerSettings), Encoding.UTF8, "application/json"); return this; } @@ -121,7 +120,7 @@ public override IRESTClient withJSONBody(object body) { */ public override IRESTClient withMethod(string method) { if (method != null) { - this.method = method; + this._method = method; } return this; @@ -132,7 +131,7 @@ public override IRESTClient withMethod(string method) { */ public override IRESTClient withUri(string uri) { if (uri != null) { - this.uri = uri; + this._uri = uri; } return this; @@ -145,13 +144,13 @@ public override IRESTClient withUri(string uri) { * @param value The value of the parameter, may be a string, object or number. */ public override IRESTClient withParameter(string name, string value) { - parameters.Add(new KeyValuePair(name, value)); + _parameters.Add(new KeyValuePair(name, value)); return this; } - private string getFullUri() { + private string GetFullUri() { var paramString = "?"; - foreach (var (key, value) in parameters.Select(x => (x.Key, x.Value))) { + foreach (var (key, value) in _parameters.Select(x => (x.Key, x.Value))) { if (!paramString.EndsWith("?")) { paramString += "&"; } @@ -159,72 +158,84 @@ private string getFullUri() { paramString += key + "=" + value; } - return uri + paramString; + return _uri + paramString; } - private Task baseRequest() { - foreach (var (key, value) in headers.Select(x => (x.Key, x.Value))) { - // .Add performs additional validation on the 'value' that may fail if an API key contains a '=' character. - // - Bypass this additional validation for the Authorization header. If we find other edge cases, perhaps - // we should just always use TryAddWithoutValidation unless there is a security risk. - if (key == "Authorization") { - httpClient.DefaultRequestHeaders.TryAddWithoutValidation(key, value); - } else { - httpClient.DefaultRequestHeaders.Add(key, value); - } + private HttpRequestMessage BuildRequest() { + var requestUri = GetFullUri(); + + var request = new HttpRequestMessage(); + + request.RequestUri = new Uri(requestUri, UriKind.RelativeOrAbsolute); + + foreach (var (key, value) in _headers.Select(x => (x.Key, x.Value))) { + // .Add performs additional validation on the 'value' that may fail if an API key contains a '=' character. + // - Bypass this additional validation for the Authorization header. If we find other edge cases, perhaps + // we should just always use TryAddWithoutValidation unless there is a security risk. + if (key == "Authorization") { + request.Headers.TryAddWithoutValidation(key, value); + } else { + request.Headers.Add(key, value); + } + } + + if (_content != null) + { + request.Content = _content; } - var requestUri = getFullUri(); - switch (method.ToUpper()) { + switch (_method.ToUpper()) { case "GET": - return httpClient.GetAsync(requestUri); - case "DELETE": - if (content != null) { - var request = new HttpRequestMessage(HttpMethod.Delete, requestUri); - request.Content = content; - return httpClient.SendAsync(request); - } else { - return httpClient.DeleteAsync(requestUri); - } + request.Method = HttpMethod.Get; + break; + case "DELETE": + request.Method = HttpMethod.Delete; + break; case "PUT": - return httpClient.PutAsync(requestUri, content); + request.Method = HttpMethod.Put; + break; case "POST": - return httpClient.PostAsync(requestUri, content); + request.Method = HttpMethod.Post; + break; case "PATCH": - var patchRequest = new HttpRequestMessage(); - patchRequest.Method = new HttpMethod("PATCH"); - patchRequest.Content = content; - patchRequest.RequestUri = new Uri(requestUri, UriKind.RelativeOrAbsolute); - return httpClient.SendAsync(patchRequest); + request.Method = new HttpMethod("PATCH"); + break; default: throw new MissingMethodException("This REST client does not support that method. (yet?)"); } + + return request; } - public override Task> goAsync() { - return baseRequest() - .ContinueWith(task => { - var clientResponse = new ClientResponse(); - try - { - var result = task.Result; - clientResponse.statusCode = (int)result.StatusCode; - if (clientResponse.statusCode >= 300) { - clientResponse.errorResponse = - JsonConvert.DeserializeObject(result.Content.ReadAsStringAsync().Result, SerializerSettings); + public override async Task> goAsync() { + var clientResponse = new ClientResponse(); + + try + { + var request = BuildRequest(); + var result = await _httpClient.SendAsync(request).ConfigureAwait(false); + + clientResponse.statusCode = (int)result.StatusCode; + + var responseContent = await result.Content.ReadAsStringAsync().ConfigureAwait(false); + + if (clientResponse.statusCode >= 300) + { + clientResponse.errorResponse = JsonConvert.DeserializeObject(responseContent, SerializerSettings); + } + else + { + clientResponse.successResponse = JsonConvert.DeserializeObject(responseContent, SerializerSettings); + } } - else { - clientResponse.successResponse = - JsonConvert.DeserializeObject(result.Content.ReadAsStringAsync().Result, SerializerSettings); + catch (Exception e) + { + clientResponse.exception = e; } - } - catch (Exception e) - { - clientResponse.exception = e; - } - return clientResponse; - }); + return clientResponse; } + + private static HttpClient GetOrCreateHttpClient(string host) => HttpClients.GetOrAdd(host, (_) => new HttpClient { BaseAddress = new Uri(host) }); } } \ No newline at end of file