Skip to content

Re-use HttpClient per host #61

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 86 additions & 75 deletions fusionauth-netcore-client/src/io/fusionauth/DefaultRESTClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
Expand All @@ -27,19 +28,15 @@
using Newtonsoft.Json.Serialization;

namespace io.fusionauth {
class DefaultRESTClient : IRESTClient {
public HttpClient httpClient;

public HttpContent content;

public string method = "GET";

public String uri = "";

public List<KeyValuePair<string, string>> parameters = new List<KeyValuePair<string, string>>();

public Dictionary<string, string> headers = new Dictionary<string, string>();

internal class DefaultRESTClient : IRESTClient {
private readonly HttpClient _httpClient;
private readonly List<KeyValuePair<string, string>> _parameters = new List<KeyValuePair<string, string>>();
private readonly Dictionary<string, string> _headers = new Dictionary<string, string>();

private HttpContent _content;
private string _method = "GET";
private string _uri = "";

private static readonly JsonSerializerSettings SerializerSettings = new JsonSerializerSettings
{
NullValueHandling = NullValueHandling.Ignore,
Expand All @@ -51,9 +48,11 @@ class DefaultRESTClient : IRESTClient {
},
ContractResolver = new DefaultContractResolver()
};

private static readonly ConcurrentDictionary<string, HttpClient> HttpClients = new ConcurrentDictionary<string, HttpClient>();

public DefaultRESTClient(string host) {
httpClient = new HttpClient {BaseAddress = new Uri(host)};
_httpClient = GetOrCreateHttpClient(host);
}

/**
Expand All @@ -75,11 +74,11 @@ public override IRESTClient withUriSegment(string segment) {
return this;
}

if (uri[uri.Length - 1] != '/') {
uri += '/';
if (_uri[_uri.Length - 1] != '/') {
_uri += '/';
}

uri = uri + segment;
_uri += segment;
return this;
}

Expand All @@ -90,7 +89,7 @@ public override IRESTClient withUriSegment(string segment) {
* @param value The value of the header.
*/
public override IRESTClient withHeader(string key, string value) {
headers[key] = value;
_headers[key] = value;
return this;
}

Expand All @@ -101,7 +100,7 @@ public override IRESTClient withHeader(string key, string value) {
*/
public override IRESTClient withFormData(FormUrlEncodedContent body)
{
content = body;
_content = body;
return this;
}

Expand All @@ -111,7 +110,7 @@ public override IRESTClient withFormData(FormUrlEncodedContent body)
* @param body The object to be written to the request body as JSON.
*/
public override IRESTClient withJSONBody(object body) {
content = new StringContent(JsonConvert.SerializeObject(body, SerializerSettings), Encoding.UTF8,
_content = new StringContent(JsonConvert.SerializeObject(body, SerializerSettings), Encoding.UTF8,
"application/json");
return this;
}
Expand All @@ -121,7 +120,7 @@ public override IRESTClient withJSONBody(object body) {
*/
public override IRESTClient withMethod(string method) {
if (method != null) {
this.method = method;
this._method = method;
}

return this;
Expand All @@ -132,7 +131,7 @@ public override IRESTClient withMethod(string method) {
*/
public override IRESTClient withUri(string uri) {
if (uri != null) {
this.uri = uri;
this._uri = uri;
}

return this;
Expand All @@ -145,86 +144,98 @@ public override IRESTClient withUri(string uri) {
* @param value The value of the parameter, may be a string, object or number.
*/
public override IRESTClient withParameter(string name, string value) {
parameters.Add(new KeyValuePair<string, string>(name, value));
_parameters.Add(new KeyValuePair<string, string>(name, value));
return this;
}

private string getFullUri() {
private string GetFullUri() {
var paramString = "?";
foreach (var (key, value) in parameters.Select(x => (x.Key, x.Value))) {
foreach (var (key, value) in _parameters.Select(x => (x.Key, x.Value))) {
if (!paramString.EndsWith("?")) {
paramString += "&";
}

paramString += key + "=" + value;
}

return uri + paramString;
return _uri + paramString;
}

private Task<HttpResponseMessage> baseRequest() {
foreach (var (key, value) in headers.Select(x => (x.Key, x.Value))) {
// .Add performs additional validation on the 'value' that may fail if an API key contains a '=' character.
// - Bypass this additional validation for the Authorization header. If we find other edge cases, perhaps
// we should just always use TryAddWithoutValidation unless there is a security risk.
if (key == "Authorization") {
httpClient.DefaultRequestHeaders.TryAddWithoutValidation(key, value);
} else {
httpClient.DefaultRequestHeaders.Add(key, value);
}
private HttpRequestMessage BuildRequest() {
var requestUri = GetFullUri();

var request = new HttpRequestMessage();

request.RequestUri = new Uri(requestUri, UriKind.RelativeOrAbsolute);

foreach (var (key, value) in _headers.Select(x => (x.Key, x.Value))) {
// .Add performs additional validation on the 'value' that may fail if an API key contains a '=' character.
// - Bypass this additional validation for the Authorization header. If we find other edge cases, perhaps
// we should just always use TryAddWithoutValidation unless there is a security risk.
if (key == "Authorization") {
request.Headers.TryAddWithoutValidation(key, value);
} else {
request.Headers.Add(key, value);
}
}

if (_content != null)
{
request.Content = _content;
}

var requestUri = getFullUri();
switch (method.ToUpper()) {
switch (_method.ToUpper()) {
case "GET":
return httpClient.GetAsync(requestUri);
case "DELETE":
if (content != null) {
var request = new HttpRequestMessage(HttpMethod.Delete, requestUri);
request.Content = content;
return httpClient.SendAsync(request);
} else {
return httpClient.DeleteAsync(requestUri);
}
request.Method = HttpMethod.Get;
break;
case "DELETE":
request.Method = HttpMethod.Delete;
break;
case "PUT":
return httpClient.PutAsync(requestUri, content);
request.Method = HttpMethod.Put;
break;
case "POST":
return httpClient.PostAsync(requestUri, content);
request.Method = HttpMethod.Post;
break;
case "PATCH":
var patchRequest = new HttpRequestMessage();
patchRequest.Method = new HttpMethod("PATCH");
patchRequest.Content = content;
patchRequest.RequestUri = new Uri(requestUri, UriKind.RelativeOrAbsolute);
return httpClient.SendAsync(patchRequest);
request.Method = new HttpMethod("PATCH");
break;
default:
throw new MissingMethodException("This REST client does not support that method. (yet?)");
}

return request;
}

public override Task<ClientResponse<T>> goAsync<T>() {
return baseRequest()
.ContinueWith(task => {
var clientResponse = new ClientResponse<T>();
try
{
var result = task.Result;
clientResponse.statusCode = (int)result.StatusCode;
if (clientResponse.statusCode >= 300) {
clientResponse.errorResponse =
JsonConvert.DeserializeObject<Errors>(result.Content.ReadAsStringAsync().Result, SerializerSettings);
public override async Task<ClientResponse<T>> goAsync<T>() {
var clientResponse = new ClientResponse<T>();

try
{
var request = BuildRequest();
var result = await _httpClient.SendAsync(request).ConfigureAwait(false);

clientResponse.statusCode = (int)result.StatusCode;

var responseContent = await result.Content.ReadAsStringAsync().ConfigureAwait(false);

if (clientResponse.statusCode >= 300)
{
clientResponse.errorResponse = JsonConvert.DeserializeObject<Errors>(responseContent, SerializerSettings);
}
else
{
clientResponse.successResponse = JsonConvert.DeserializeObject<T>(responseContent, SerializerSettings);
}
}
else {
clientResponse.successResponse =
JsonConvert.DeserializeObject<T>(result.Content.ReadAsStringAsync().Result, SerializerSettings);
catch (Exception e)
{
clientResponse.exception = e;
}
}
catch (Exception e)
{
clientResponse.exception = e;
}

return clientResponse;
});
return clientResponse;
}

private static HttpClient GetOrCreateHttpClient(string host) => HttpClients.GetOrAdd(host, (_) => new HttpClient { BaseAddress = new Uri(host) });
}
}