Skip to content

Commit

Permalink
Refactor actuator middleware: move parsing of request (body) into sep…
Browse files Browse the repository at this point in the history
…arate virtual method, so that handler invocation is agnostic of HttpContext
  • Loading branch information
bart-vmware committed Feb 14, 2025
1 parent 06ba5be commit 586a3f8
Show file tree
Hide file tree
Showing 16 changed files with 158 additions and 156 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,22 @@ internal sealed class CloudFoundryEndpointMiddleware(
ICloudFoundryEndpointHandler endpointHandler, IOptionsMonitor<ManagementOptions> managementOptionsMonitor, ILoggerFactory loggerFactory)
: EndpointMiddleware<string, Links>(endpointHandler, managementOptionsMonitor, loggerFactory)
{
private readonly ILogger _logger = loggerFactory.CreateLogger<CloudFoundryEndpointMiddleware>();

protected override async Task<Links> InvokeEndpointHandlerAsync(HttpContext context, CancellationToken cancellationToken)
protected override Task<string?> ParseRequestAsync(HttpContext httpContext, CancellationToken cancellationToken)
{
ArgumentNullException.ThrowIfNull(context);
ArgumentNullException.ThrowIfNull(httpContext);

_logger.LogDebug("InvokeAsync({Method}, {Path})", context.Request.Method, context.Request.Path.Value);
string uri = GetRequestUri(context);
return await EndpointHandler.InvokeAsync(uri, cancellationToken);
string scheme = httpContext.Request.Headers.TryGetValue("X-Forwarded-Proto", out StringValues headerScheme)
? headerScheme.ToString()
: httpContext.Request.Scheme;

string uri = $"{scheme}://{httpContext.Request.Host}{httpContext.Request.PathBase}{httpContext.Request.Path}";
return Task.FromResult<string?>(uri);
}

private string GetRequestUri(HttpContext context)
protected override async Task<Links> InvokeEndpointHandlerAsync(string? uri, CancellationToken cancellationToken)
{
HttpRequest request = context.Request;
string scheme = request.Scheme;
ArgumentNullException.ThrowIfNull(uri);

if (request.Headers.TryGetValue("X-Forwarded-Proto", out StringValues headerScheme))
{
scheme = headerScheme.ToString();
}

return $"{scheme}://{request.Host}{request.PathBase}{request.Path}";
return await EndpointHandler.InvokeAsync(uri, cancellationToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Steeltoe.Management.Endpoint.Configuration;
Expand All @@ -14,9 +13,8 @@ internal sealed class DbMigrationsEndpointMiddleware(
IDbMigrationsEndpointHandler endpointHandler, IOptionsMonitor<ManagementOptions> managementOptionsMonitor, ILoggerFactory loggerFactory)
: EndpointMiddleware<object?, Dictionary<string, DbMigrationsDescriptor>>(endpointHandler, managementOptionsMonitor, loggerFactory)
{
protected override async Task<Dictionary<string, DbMigrationsDescriptor>> InvokeEndpointHandlerAsync(HttpContext context,
CancellationToken cancellationToken)
protected override async Task<Dictionary<string, DbMigrationsDescriptor>> InvokeEndpointHandlerAsync(object? request, CancellationToken cancellationToken)
{
return await EndpointHandler.InvokeAsync(null, cancellationToken);
return await EndpointHandler.InvokeAsync(request, cancellationToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Steeltoe.Management.Endpoint.Configuration;
Expand All @@ -14,8 +13,8 @@ internal sealed class EnvironmentEndpointMiddleware(
IEnvironmentEndpointHandler endpointHandler, IOptionsMonitor<ManagementOptions> managementOptionsMonitor, ILoggerFactory loggerFactory)
: EndpointMiddleware<object?, EnvironmentResponse>(endpointHandler, managementOptionsMonitor, loggerFactory)
{
protected override async Task<EnvironmentResponse> InvokeEndpointHandlerAsync(HttpContext context, CancellationToken cancellationToken)
protected override async Task<EnvironmentResponse> InvokeEndpointHandlerAsync(object? request, CancellationToken cancellationToken)
{
return await EndpointHandler.InvokeAsync(null, context.RequestAborted);
return await EndpointHandler.InvokeAsync(request, cancellationToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,40 +32,37 @@ public override ActuatorMetadataProvider GetMetadataProvider()
return new HealthActuatorMetadataProvider(ContentType);
}

protected override async Task<HealthEndpointResponse> InvokeEndpointHandlerAsync(HttpContext context, CancellationToken cancellationToken)
protected override Task<HealthEndpointRequest?> ParseRequestAsync(HttpContext httpContext, CancellationToken cancellationToken)
{
HealthEndpointOptions currentEndpointOptions = _endpointOptionsMonitor.CurrentValue;
string groupName = GetRequestedHealthGroup(context.Request.Path, currentEndpointOptions, _logger);
ArgumentNullException.ThrowIfNull(httpContext);

if (!IsValidGroup(groupName, currentEndpointOptions))
HealthEndpointRequest? request = null;
HealthEndpointOptions options = _endpointOptionsMonitor.CurrentValue;
string groupName = GetRequestedHealthGroup(httpContext.Request.Path, options);

if (IsValidGroup(groupName, options))
{
return new HealthEndpointResponse
{
Exists = false
};
bool hasClaim = GetHasClaim(httpContext, options);
request = new HealthEndpointRequest(groupName, hasClaim);
}

bool hasClaim = GetHasClaim(context, currentEndpointOptions);

var request = new HealthEndpointRequest(groupName, hasClaim);
return await EndpointHandler.InvokeAsync(request, context.RequestAborted);
return Task.FromResult(request);
}

/// <summary>
/// Returns the last segment of the HTTP request path, which is expected to be the name of a configured health group.
/// </summary>
private static string GetRequestedHealthGroup(PathString requestPath, HealthEndpointOptions endpointOptions, ILogger<HealthEndpointMiddleware> logger)
private string GetRequestedHealthGroup(PathString requestPath, HealthEndpointOptions endpointOptions)
{
string[] requestComponents = requestPath.Value?.Split('/') ?? [];

if (requestComponents.Length > 0 && requestComponents[^1] != endpointOptions.Id)
{
logger.LogTrace("Found group '{HealthGroup}' in the request path.", requestComponents[^1]);
_logger.LogTrace("Found group '{HealthGroup}' in the request path.", requestComponents[^1]);
return requestComponents[^1];
}

logger.LogTrace("Did not find a health group in the request path.");

_logger.LogTrace("Did not find a health group in the request path.");
return string.Empty;
}

Expand All @@ -80,20 +77,33 @@ private static bool GetHasClaim(HttpContext context, HealthEndpointOptions endpo
return claim is { Type: not null, Value: not null } && context.User.HasClaim(claim.Type, claim.Value);
}

protected override async Task WriteResponseAsync(HealthEndpointResponse result, HttpContext context, CancellationToken cancellationToken)
protected override async Task<HealthEndpointResponse> InvokeEndpointHandlerAsync(HealthEndpointRequest? request, CancellationToken cancellationToken)
{
if (request == null)
{
return new HealthEndpointResponse
{
Exists = false
};
}

return await EndpointHandler.InvokeAsync(request, cancellationToken);
}

protected override async Task WriteResponseAsync(HealthEndpointResponse response, HttpContext httpContext, CancellationToken cancellationToken)
{
if (!result.Exists)
if (!response.Exists)
{
context.Response.StatusCode = (int)HttpStatusCode.NotFound;
httpContext.Response.StatusCode = (int)HttpStatusCode.NotFound;
return;
}

if (ManagementOptionsMonitor.CurrentValue.UseStatusCodeFromResponse || UseStatusCodeFromResponseInHeader(context.Request.Headers))
if (ManagementOptionsMonitor.CurrentValue.UseStatusCodeFromResponse || UseStatusCodeFromResponseInHeader(httpContext.Request.Headers))
{
context.Response.StatusCode = ((HealthEndpointHandler)EndpointHandler).GetStatusCode(result);
httpContext.Response.StatusCode = ((HealthEndpointHandler)EndpointHandler).GetStatusCode(response);
}

await base.WriteResponseAsync(result, context, cancellationToken);
await base.WriteResponseAsync(response, httpContext, cancellationToken);
}

private static bool UseStatusCodeFromResponseInHeader(IHeaderDictionary requestHeaders)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,34 @@ internal sealed class HeapDumpEndpointMiddleware(
: EndpointMiddleware<object?, string?>(endpointHandler, managementOptionsMonitor, loggerFactory)
{
private readonly ILogger<HeapDumpEndpointMiddleware> _logger = loggerFactory.CreateLogger<HeapDumpEndpointMiddleware>();
private protected override string ContentType { get; } = "application/octet-stream";

protected override async Task<string?> InvokeEndpointHandlerAsync(HttpContext context, CancellationToken cancellationToken)
private protected override string ContentType => "application/octet-stream";

protected override async Task<string?> InvokeEndpointHandlerAsync(object? request, CancellationToken cancellationToken)
{
return await EndpointHandler.InvokeAsync(null, context.RequestAborted);
return await EndpointHandler.InvokeAsync(request, cancellationToken);
}

protected override async Task WriteResponseAsync(string? fileName, HttpContext context, CancellationToken cancellationToken)
protected override async Task WriteResponseAsync(string? fileName, HttpContext httpContext, CancellationToken cancellationToken)
{
ArgumentNullException.ThrowIfNull(httpContext);

_logger.LogDebug("Returning: {FileName}", fileName);

if (!File.Exists(fileName))
{
context.Response.StatusCode = StatusCodes.Status404NotFound;
httpContext.Response.StatusCode = StatusCodes.Status404NotFound;
return;
}

context.Response.ContentType = ContentType;
context.Response.Headers.Append("Content-Disposition", $"attachment; filename=\"{Path.GetFileName(fileName)}.gz\"");
context.Response.StatusCode = StatusCodes.Status200OK;
httpContext.Response.ContentType = ContentType;
httpContext.Response.Headers.Append("Content-Disposition", $"attachment; filename=\"{Path.GetFileName(fileName)}.gz\"");
httpContext.Response.StatusCode = StatusCodes.Status200OK;

try
{
await using var inputStream = new FileStream(fileName, FileMode.Open);
await using var outputStream = new GZipStream(context.Response.Body, CompressionLevel.Fastest, true);
await using var outputStream = new GZipStream(httpContext.Response.Body, CompressionLevel.Fastest, true);
await inputStream.CopyToAsync(outputStream, cancellationToken);
}
finally
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Steeltoe.Management.Endpoint.Configuration;
Expand All @@ -14,8 +13,8 @@ internal sealed class HttpExchangesEndpointMiddleware(
IHttpExchangesEndpointHandler endpointHandler, IOptionsMonitor<ManagementOptions> managementOptionsMonitor, ILoggerFactory loggerFactory)
: EndpointMiddleware<object?, HttpExchangesResult>(endpointHandler, managementOptionsMonitor, loggerFactory)
{
protected override async Task<HttpExchangesResult> InvokeEndpointHandlerAsync(HttpContext context, CancellationToken cancellationToken)
protected override async Task<HttpExchangesResult> InvokeEndpointHandlerAsync(object? request, CancellationToken cancellationToken)
{
return await EndpointHandler.InvokeAsync(null, cancellationToken);
return await EndpointHandler.InvokeAsync(request, cancellationToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,27 @@ internal sealed class HypermediaEndpointMiddleware(
IActuatorEndpointHandler endpointHandler, IOptionsMonitor<ManagementOptions> managementOptionsMonitor, ILoggerFactory loggerFactory)
: EndpointMiddleware<string, Links>(endpointHandler, managementOptionsMonitor, loggerFactory)
{
private readonly ILogger _logger = loggerFactory.CreateLogger<HypermediaEndpointMiddleware>();

protected override async Task<Links> InvokeEndpointHandlerAsync(HttpContext context, CancellationToken cancellationToken)
{
ArgumentNullException.ThrowIfNull(context);

_logger.LogDebug("InvokeAsync({Method}, {Path})", context.Request.Method, context.Request.Path.Value);
string requestUri = GetRequestUri(context.Request);
return await EndpointHandler.InvokeAsync(requestUri, cancellationToken);
}

private static string GetRequestUri(HttpRequest request)
protected override Task<string?> ParseRequestAsync(HttpContext httpContext, CancellationToken cancellationToken)
{
string scheme = request.Scheme;
ArgumentNullException.ThrowIfNull(httpContext);

if (request.Headers.TryGetValue("X-Forwarded-Proto", out StringValues headerScheme))
{
scheme = headerScheme.ToString();
}
string scheme = httpContext.Request.Headers.TryGetValue("X-Forwarded-Proto", out StringValues headerScheme)
? headerScheme.ToString()
: httpContext.Request.Scheme;

// request.Host automatically includes or excludes the port based on whether it is standard for the scheme
// ... except when we manually change the scheme to match the X-Forwarded-Proto
if (scheme == "https" && request.Host.Port == 443)
{
return $"{scheme}://{request.Host.Host}{request.PathBase}{request.Path}";
}
string requestUri = scheme == "https" && httpContext.Request.Host.Port == 443
? $"{scheme}://{httpContext.Request.Host.Host}{httpContext.Request.PathBase}{httpContext.Request.Path}"
: $"{scheme}://{httpContext.Request.Host}{httpContext.Request.PathBase}{httpContext.Request.Path}";

return $"{scheme}://{request.Host}{request.PathBase}{request.Path}";
return Task.FromResult<string?>(requestUri);
}

protected override async Task<Links> InvokeEndpointHandlerAsync(string? requestUri, CancellationToken cancellationToken)
{
ArgumentNullException.ThrowIfNull(requestUri);

return await EndpointHandler.InvokeAsync(requestUri, cancellationToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Steeltoe.Management.Endpoint.Configuration;
Expand All @@ -14,8 +13,8 @@ internal sealed class InfoEndpointMiddleware(
IInfoEndpointHandler endpointHandler, IOptionsMonitor<ManagementOptions> managementOptionsMonitor, ILoggerFactory loggerFactory)
: EndpointMiddleware<object?, IDictionary<string, object>>(endpointHandler, managementOptionsMonitor, loggerFactory)
{
protected override async Task<IDictionary<string, object>> InvokeEndpointHandlerAsync(HttpContext context, CancellationToken cancellationToken)
protected override async Task<IDictionary<string, object>> InvokeEndpointHandlerAsync(object? request, CancellationToken cancellationToken)
{
return await EndpointHandler.InvokeAsync(null, cancellationToken);
return await EndpointHandler.InvokeAsync(request, cancellationToken);
}
}
Loading

0 comments on commit 586a3f8

Please # to comment.