From c7847968b309eb71edbe1da6f6cfa4d1c0b21341 Mon Sep 17 00:00:00 2001 From: Chad Retz Date: Fri, 18 Aug 2023 12:11:49 -0500 Subject: [PATCH] HTTP API support (#4543) --- common/config/config.go | 8 +- common/headers/version_checker.go | 1 + common/metrics/metric_defs.go | 6 +- common/metrics/metricstest/capture_handler.go | 146 ++++++ .../rpc/encryption/fixedTLSConfigProvider.go | 79 +++ common/rpc/grpc.go | 4 + config/development-cass-archival.yaml | 1 + config/development-cass-es.yaml | 1 + config/development-cass-s3.yaml | 1 + config/development-cass.yaml | 1 + config/development-cluster-a.yaml | 1 + config/development-mysql-es.yaml | 1 + config/development-mysql.yaml | 1 + config/development-mysql8.yaml | 1 + config/development-postgres-es.yaml | 1 + config/development-postgres.yaml | 1 + config/development-postgres12.yaml | 1 + config/development-sqlite-file.yaml | 1 + config/development-sqlite.yaml | 1 + docker/config_template.yaml | 1 + go.mod | 11 +- go.sum | 25 +- service/frontend/fx.go | 54 +- service/frontend/http_api_server.go | 494 ++++++++++++++++++ service/frontend/service.go | 47 +- service/frontend/workflow_handler.go | 4 +- service/frontend/workflow_handler_test.go | 27 - tests/flag.go | 2 + tests/http_api_test.go | 261 +++++++++ tests/integrationbase.go | 3 + tests/onebox.go | 238 ++++++--- tests/test_cluster.go | 74 ++- .../clientintegrationtestcluster.yaml | 1 + .../tls_integration_test_cluster.yaml | 8 + tests/testutils/certificate.go | 6 +- tests/tls_test.go | 152 ++++++ 36 files changed, 1520 insertions(+), 145 deletions(-) create mode 100644 common/metrics/metricstest/capture_handler.go create mode 100644 common/rpc/encryption/fixedTLSConfigProvider.go create mode 100644 service/frontend/http_api_server.go create mode 100644 tests/http_api_test.go create mode 100644 tests/testdata/tls_integration_test_cluster.yaml create mode 100644 tests/tls_test.go diff --git a/common/config/config.go b/common/config/config.go index 0c0d0b1bb84..207bb35d0dc 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -85,7 +85,7 @@ type ( // RPC contains the rpc config items RPC struct { - // GRPCPort is the port on which gRPC will listen + // GRPCPort is the port on which gRPC will listen GRPCPort int `yaml:"grpcPort"` // Port used for membership listener MembershipPort int `yaml:"membershipPort"` @@ -95,6 +95,12 @@ type ( // check net.ParseIP for supported syntax, only IPv4 is supported, // mutually exclusive with `BindOnLocalHost` option BindOnIP string `yaml:"bindOnIP"` + // HTTPPort is the port on which HTTP will listen. If unset/0, HTTP will be + // disabled. This setting only applies to the frontend service. + HTTPPort int `yaml:"httpPort"` + // HTTPAdditionalForwardedHeaders adds additional headers to the default set + // forwarded from HTTP to gRPC. + HTTPAdditionalForwardedHeaders []string `yaml:"httpAdditionalForwardedHeaders"` } // Global contains config items that apply process-wide to all services diff --git a/common/headers/version_checker.go b/common/headers/version_checker.go index 7c09ad48519..240c960251c 100644 --- a/common/headers/version_checker.go +++ b/common/headers/version_checker.go @@ -38,6 +38,7 @@ import ( const ( ClientNameServer = "temporal-server" + ClientNameServerHTTP = "temporal-server-http" ClientNameGoSDK = "temporal-go" ClientNameJavaSDK = "temporal-java" ClientNamePHPSDK = "temporal-php" diff --git a/common/metrics/metric_defs.go b/common/metrics/metric_defs.go index 51ee1a42f73..41798a06354 100644 --- a/common/metrics/metric_defs.go +++ b/common/metrics/metric_defs.go @@ -1163,7 +1163,7 @@ const ( var ( ServiceRequests = NewCounterDef( "service_requests", - WithDescription("The number of gRPC requests received by the service."), + WithDescription("The number of RPC requests received by the service."), ) ServicePendingRequests = NewGaugeDef("service_pending_requests") ServiceFailures = NewCounterDef("service_errors") @@ -1231,6 +1231,10 @@ var ( VersionCheckFailedCount = NewCounterDef("version_check_failed") VersionCheckRequestFailedCount = NewCounterDef("version_check_request_failed") VersionCheckLatency = NewTimerDef("version_check_latency") + HTTPServiceRequests = NewCounterDef( + "http_service_requests", + WithDescription("The number of HTTP requests received by the service."), + ) // History CacheRequests = NewCounterDef("cache_requests") diff --git a/common/metrics/metricstest/capture_handler.go b/common/metrics/metricstest/capture_handler.go new file mode 100644 index 00000000000..107241062d8 --- /dev/null +++ b/common/metrics/metricstest/capture_handler.go @@ -0,0 +1,146 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package metricstest + +import ( + "sync" + "time" + + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/metrics" +) + +// CapturedRecording is a single recording. Fields here should not be mutated. +type CapturedRecording struct { + Value any + Tags map[string]string + Unit metrics.MetricUnit +} + +// Capture is a specific capture instance. +type Capture struct { + recordings map[string][]*CapturedRecording + recordingsLock sync.RWMutex +} + +// Snapshot returns a copy of all metrics recorded, keyed by name. +func (c *Capture) Snapshot() map[string][]*CapturedRecording { + c.recordingsLock.RLock() + defer c.recordingsLock.RUnlock() + ret := make(map[string][]*CapturedRecording, len(c.recordings)) + for k, v := range c.recordings { + recs := make([]*CapturedRecording, len(v)) + copy(recs, v) + ret[k] = recs + } + return ret +} + +func (c *Capture) record(name string, r *CapturedRecording) { + c.recordingsLock.Lock() + defer c.recordingsLock.Unlock() + c.recordings[name] = append(c.recordings[name], r) +} + +// CaptureHandler is a [metrics.Handler] that captures each metric recording. +type CaptureHandler struct { + tags []metrics.Tag + captures map[*Capture]struct{} + capturesLock *sync.RWMutex +} + +var _ metrics.Handler = (*CaptureHandler)(nil) + +// NewCaptureHandler creates a new [metrics.Handler] that captures. +func NewCaptureHandler() *CaptureHandler { + return &CaptureHandler{ + captures: map[*Capture]struct{}{}, + capturesLock: &sync.RWMutex{}, + } +} + +// StartCapture returns a started capture. StopCapture should be called on +// complete. +func (c *CaptureHandler) StartCapture() *Capture { + capture := &Capture{recordings: map[string][]*CapturedRecording{}} + c.capturesLock.Lock() + defer c.capturesLock.Unlock() + c.captures[capture] = struct{}{} + return capture +} + +// StopCapture stops capturing metrics for the given capture instance. +func (c *CaptureHandler) StopCapture(capture *Capture) { + c.capturesLock.Lock() + defer c.capturesLock.Unlock() + delete(c.captures, capture) +} + +// WithTags implements [metrics.Handler.WithTags]. +func (c *CaptureHandler) WithTags(tags ...metrics.Tag) metrics.Handler { + return &CaptureHandler{ + tags: append(append(make([]metrics.Tag, 0, len(c.tags)+len(tags)), c.tags...), tags...), + captures: c.captures, + capturesLock: c.capturesLock, + } +} + +func (c *CaptureHandler) record(name string, v any, unit metrics.MetricUnit, tags ...metrics.Tag) { + rec := &CapturedRecording{Value: v, Tags: make(map[string]string, len(c.tags)+len(tags)), Unit: unit} + for _, tag := range c.tags { + rec.Tags[tag.Key()] = tag.Value() + } + for _, tag := range tags { + rec.Tags[tag.Key()] = tag.Value() + } + c.capturesLock.RLock() + defer c.capturesLock.RUnlock() + for c := range c.captures { + c.record(name, rec) + } +} + +// Counter implements [metrics.Handler.Counter]. +func (c *CaptureHandler) Counter(name string) metrics.CounterIface { + return metrics.CounterFunc(func(v int64, tags ...metrics.Tag) { c.record(name, v, "", tags...) }) +} + +// Gauge implements [metrics.Handler.Gauge]. +func (c *CaptureHandler) Gauge(name string) metrics.GaugeIface { + return metrics.GaugeFunc(func(v float64, tags ...metrics.Tag) { c.record(name, v, "", tags...) }) +} + +// Timer implements [metrics.Handler.Timer]. +func (c *CaptureHandler) Timer(name string) metrics.TimerIface { + return metrics.TimerFunc(func(v time.Duration, tags ...metrics.Tag) { c.record(name, v, "", tags...) }) +} + +// Histogram implements [metrics.Handler.Histogram]. +func (c *CaptureHandler) Histogram(name string, unit metrics.MetricUnit) metrics.HistogramIface { + return metrics.HistogramFunc(func(v int64, tags ...metrics.Tag) { c.record(name, v, unit, tags...) }) +} + +// Stop implements [metrics.Handler.Stop]. +func (*CaptureHandler) Stop(log.Logger) {} diff --git a/common/rpc/encryption/fixedTLSConfigProvider.go b/common/rpc/encryption/fixedTLSConfigProvider.go new file mode 100644 index 00000000000..eef3b79298b --- /dev/null +++ b/common/rpc/encryption/fixedTLSConfigProvider.go @@ -0,0 +1,79 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package encryption + +import ( + "crypto/tls" + "time" +) + +// FixedTLSConfigProvider is a [TLSConfigProvider] that is for fixed sets of TLS +// configs. This is usually only used for testing. + +type FixedTLSConfigProvider struct { + InternodeServerConfig *tls.Config + InternodeClientConfig *tls.Config + FrontendServerConfig *tls.Config + FrontendClientConfig *tls.Config + RemoteClusterClientConfigs map[string]*tls.Config + CertExpirationChecker CertExpirationChecker +} + +var _ TLSConfigProvider = (*FixedTLSConfigProvider)(nil) + +// GetInternodeServerConfig implements [TLSConfigProvider.GetInternodeServerConfig]. +func (f *FixedTLSConfigProvider) GetInternodeServerConfig() (*tls.Config, error) { + return f.InternodeServerConfig, nil +} + +// GetInternodeClientConfig implements [TLSConfigProvider.GetInternodeClientConfig]. +func (f *FixedTLSConfigProvider) GetInternodeClientConfig() (*tls.Config, error) { + return f.InternodeClientConfig, nil +} + +// GetFrontendServerConfig implements [TLSConfigProvider.GetFrontendServerConfig]. +func (f *FixedTLSConfigProvider) GetFrontendServerConfig() (*tls.Config, error) { + return f.FrontendServerConfig, nil +} + +// GetFrontendClientConfig implements [TLSConfigProvider.GetFrontendClientConfig]. +func (f *FixedTLSConfigProvider) GetFrontendClientConfig() (*tls.Config, error) { + return f.FrontendClientConfig, nil +} + +// GetRemoteClusterClientConfig implements [TLSConfigProvider.GetRemoteClusterClientConfig]. +func (f *FixedTLSConfigProvider) GetRemoteClusterClientConfig(hostname string) (*tls.Config, error) { + return f.RemoteClusterClientConfigs[hostname], nil +} + +// GetExpiringCerts implements [TLSConfigProvider.GetExpiringCerts]. +func (f *FixedTLSConfigProvider) GetExpiringCerts( + timeWindow time.Duration, +) (expiring CertExpirationMap, expired CertExpirationMap, err error) { + if f.CertExpirationChecker != nil { + return f.CertExpirationChecker.GetExpiringCerts(timeWindow) + } + return nil, nil, nil +} diff --git a/common/rpc/grpc.go b/common/rpc/grpc.go index 9d97bf54014..8db9d490cd1 100644 --- a/common/rpc/grpc.go +++ b/common/rpc/grpc.go @@ -52,6 +52,10 @@ const ( // MaxBackoffDelay is a maximum interval between reconnect attempts. MaxBackoffDelay = 10 * time.Second + // MaxHTTPAPIRequestBytes is the maximum number of bytes an HTTP API request + // can have. This is currently set to the max gRPC request size. + MaxHTTPAPIRequestBytes = 4 * 1024 * 1024 + // minConnectTimeout is the minimum amount of time we are willing to give a connection to complete. minConnectTimeout = 20 * time.Second diff --git a/config/development-cass-archival.yaml b/config/development-cass-archival.yaml index 5e3d585b144..e015c4f1e78 100644 --- a/config/development-cass-archival.yaml +++ b/config/development-cass-archival.yaml @@ -35,6 +35,7 @@ services: grpcPort: 7233 membershipPort: 6933 bindOnLocalHost: true + httpPort: 7243 matching: rpc: diff --git a/config/development-cass-es.yaml b/config/development-cass-es.yaml index a1835ed4527..d51fbcf4808 100644 --- a/config/development-cass-es.yaml +++ b/config/development-cass-es.yaml @@ -42,6 +42,7 @@ services: grpcPort: 7233 membershipPort: 6933 bindOnLocalHost: true + httpPort: 7243 matching: rpc: diff --git a/config/development-cass-s3.yaml b/config/development-cass-s3.yaml index e840704478d..dbc709008ad 100644 --- a/config/development-cass-s3.yaml +++ b/config/development-cass-s3.yaml @@ -33,6 +33,7 @@ services: grpcPort: 7233 membershipPort: 6933 bindOnLocalHost: true + httpPort: 7243 matching: rpc: diff --git a/config/development-cass.yaml b/config/development-cass.yaml index 7498e8a9f73..c3e39ba6220 100644 --- a/config/development-cass.yaml +++ b/config/development-cass.yaml @@ -59,6 +59,7 @@ services: grpcPort: 7233 membershipPort: 6933 bindOnLocalHost: true + httpPort: 7243 matching: rpc: diff --git a/config/development-cluster-a.yaml b/config/development-cluster-a.yaml index 05dfbe53aff..4e08ba0fc0f 100644 --- a/config/development-cluster-a.yaml +++ b/config/development-cluster-a.yaml @@ -42,6 +42,7 @@ services: grpcPort: 7233 membershipPort: 6933 bindOnLocalHost: true + httpPort: 7243 matching: rpc: diff --git a/config/development-mysql-es.yaml b/config/development-mysql-es.yaml index 4d72f0c004a..82569f069e7 100644 --- a/config/development-mysql-es.yaml +++ b/config/development-mysql-es.yaml @@ -48,6 +48,7 @@ services: grpcPort: 7233 membershipPort: 6933 bindOnLocalHost: true + httpPort: 7243 matching: rpc: diff --git a/config/development-mysql.yaml b/config/development-mysql.yaml index 81a25d20c12..8572b893d2d 100644 --- a/config/development-mysql.yaml +++ b/config/development-mysql.yaml @@ -50,6 +50,7 @@ services: grpcPort: 7233 membershipPort: 6933 bindOnLocalHost: true + httpPort: 7243 matching: rpc: diff --git a/config/development-mysql8.yaml b/config/development-mysql8.yaml index 4056022b4d0..d61cdf0d146 100644 --- a/config/development-mysql8.yaml +++ b/config/development-mysql8.yaml @@ -50,6 +50,7 @@ services: grpcPort: 7233 membershipPort: 6933 bindOnLocalHost: true + httpPort: 7243 matching: rpc: diff --git a/config/development-postgres-es.yaml b/config/development-postgres-es.yaml index e9665866fad..b92259ec2cf 100644 --- a/config/development-postgres-es.yaml +++ b/config/development-postgres-es.yaml @@ -50,6 +50,7 @@ services: grpcPort: 7233 membershipPort: 6933 bindOnLocalHost: true + httpPort: 7243 matching: rpc: diff --git a/config/development-postgres.yaml b/config/development-postgres.yaml index 5f6ffcf8e66..1ad87606696 100644 --- a/config/development-postgres.yaml +++ b/config/development-postgres.yaml @@ -50,6 +50,7 @@ services: grpcPort: 7233 membershipPort: 6933 bindOnLocalHost: true + httpPort: 7243 matching: rpc: diff --git a/config/development-postgres12.yaml b/config/development-postgres12.yaml index feef286603e..64093eb85de 100644 --- a/config/development-postgres12.yaml +++ b/config/development-postgres12.yaml @@ -50,6 +50,7 @@ services: grpcPort: 7233 membershipPort: 6933 bindOnLocalHost: true + httpPort: 7243 matching: rpc: diff --git a/config/development-sqlite-file.yaml b/config/development-sqlite-file.yaml index 5b8de8a89ec..d50fc9dfd51 100644 --- a/config/development-sqlite-file.yaml +++ b/config/development-sqlite-file.yaml @@ -74,6 +74,7 @@ services: grpcPort: 7233 membershipPort: 6933 bindOnLocalHost: true + httpPort: 7243 matching: rpc: diff --git a/config/development-sqlite.yaml b/config/development-sqlite.yaml index 2014521ed8a..e73d6c19231 100644 --- a/config/development-sqlite.yaml +++ b/config/development-sqlite.yaml @@ -70,6 +70,7 @@ services: grpcPort: 7233 membershipPort: 6933 bindOnLocalHost: true + httpPort: 7243 matching: rpc: diff --git a/docker/config_template.yaml b/docker/config_template.yaml index 0d90c0c9746..76ae88f182d 100644 --- a/docker/config_template.yaml +++ b/docker/config_template.yaml @@ -304,6 +304,7 @@ services: grpcPort: {{ $temporalGrpcPort }} membershipPort: {{ default .Env.FRONTEND_MEMBERSHIP_PORT "6933" }} bindOnIP: {{ default .Env.BIND_ON_IP "127.0.0.1" }} + httpPort: {{ default .Env.FRONTEND_HTTP_PORT "7243" }} {{- if .Env.USE_INTERNAL_FRONTEND }} internal-frontend: diff --git a/go.mod b/go.mod index 870e588833e..f46955c069a 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( go.opentelemetry.io/otel/metric v1.16.0 go.opentelemetry.io/otel/sdk v1.16.0 go.opentelemetry.io/otel/sdk/metric v0.39.0 - go.temporal.io/api v1.23.1-0.20230809151511-056e78321730 + go.temporal.io/api v1.23.1-0.20230818163044-7f76d854ed02 go.temporal.io/sdk v1.23.1 go.temporal.io/version v0.3.0 go.uber.org/atomic v1.11.0 @@ -65,7 +65,7 @@ require ( ) require ( - cloud.google.com/go v0.110.6 // indirect + cloud.google.com/go v0.110.7 // indirect cloud.google.com/go/compute v1.23.0 // indirect cloud.google.com/go/compute/metadata v0.2.3 // indirect cloud.google.com/go/iam v1.1.1 // indirect @@ -90,6 +90,7 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.2.5 // indirect github.com/googleapis/gax-go/v2 v2.11.0 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect + github.com/grpc-ecosystem/grpc-gateway v1.16.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect @@ -131,9 +132,9 @@ require ( golang.org/x/tools v0.10.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20230803162519-f966b187b2e5 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20230726155614-23370e0ffb3e // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20230807174057-1744710a1577 // indirect + google.golang.org/genproto v0.0.0-20230815205213-6bfd019c3878 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20230815205213-6bfd019c3878 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20230815205213-6bfd019c3878 // indirect google.golang.org/protobuf v1.31.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect lukechampine.com/uint128 v1.3.0 // indirect diff --git a/go.sum b/go.sum index 97b4ad01fd0..401acd9df53 100644 --- a/go.sum +++ b/go.sum @@ -38,8 +38,9 @@ cloud.google.com/go v0.107.0/go.mod h1:wpc2eNrD7hXUTy8EKS10jkxpZBjASrORK7goS+3YX cloud.google.com/go v0.110.0/go.mod h1:SJnCLqQ0FCFGSZMUNUf84MV3Aia54kn7pi8st7tMzaY= cloud.google.com/go v0.110.2/go.mod h1:k04UEeEtb6ZBRTv3dZz4CeJC3jKGxyhl0sAiVVquxiw= cloud.google.com/go v0.110.4/go.mod h1:+EYjdK8e5RME/VY/qLCAtuyALQ9q67dvuum8i+H5xsI= -cloud.google.com/go v0.110.6 h1:8uYAkj3YHTP/1iwReuHPxLSbdcyc+dSBbzFMrVwDR6Q= cloud.google.com/go v0.110.6/go.mod h1:+EYjdK8e5RME/VY/qLCAtuyALQ9q67dvuum8i+H5xsI= +cloud.google.com/go v0.110.7 h1:rJyC7nWRg2jWGZ4wSJ5nY65GTdYJkg0cd/uXb+ACI6o= +cloud.google.com/go v0.110.7/go.mod h1:+EYjdK8e5RME/VY/qLCAtuyALQ9q67dvuum8i+H5xsI= cloud.google.com/go/accessapproval v1.4.0/go.mod h1:zybIuC3KpDOvotz59lFe5qxRZx6C75OtwbisN56xYB4= cloud.google.com/go/accessapproval v1.5.0/go.mod h1:HFy3tuiGvMdcd/u+Cu5b9NkO1pEICJ46IR82PoUdplw= cloud.google.com/go/accessapproval v1.6.0/go.mod h1:R0EiYnwV5fsRFiKZkPHr6mwyk2wxUJ30nL4j2pcFY2E= @@ -281,6 +282,7 @@ cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1 cloud.google.com/go/datastore v1.10.0/go.mod h1:PC5UzAmDEkAmkfaknstTYbNpgE49HAgW2J1gcgUfmdM= cloud.google.com/go/datastore v1.11.0/go.mod h1:TvGxBIHCS50u8jzG+AW/ppf87v1of8nwzFNgEZU1D3c= cloud.google.com/go/datastore v1.12.0/go.mod h1:KjdB88W897MRITkvWWJrg2OUtrR5XVj1EoLgSp6/N70= +cloud.google.com/go/datastore v1.12.1/go.mod h1:KjdB88W897MRITkvWWJrg2OUtrR5XVj1EoLgSp6/N70= cloud.google.com/go/datastore v1.13.0/go.mod h1:KjdB88W897MRITkvWWJrg2OUtrR5XVj1EoLgSp6/N70= cloud.google.com/go/datastream v1.2.0/go.mod h1:i/uTP8/fZwgATHS/XFu0TcNUhuA0twZxxQ3EyCUQMwo= cloud.google.com/go/datastream v1.3.0/go.mod h1:cqlOX8xlyYF/uxhiKn6Hbv6WjwPPuI9W2M9SAXwaLLQ= @@ -345,6 +347,7 @@ cloud.google.com/go/filestore v1.6.0/go.mod h1:di5unNuss/qfZTw2U9nhFqo8/ZDSc466d cloud.google.com/go/filestore v1.7.1/go.mod h1:y10jsorq40JJnjR/lQ8AfFbbcGlw3g+Dp8oN7i7FjV4= cloud.google.com/go/firestore v1.9.0/go.mod h1:HMkjKHNTtRyZNiMzu7YAsLr9K3X2udY2AMwDaMEQiiE= cloud.google.com/go/firestore v1.11.0/go.mod h1:b38dKhgzlmNNGTNZZwe7ZRFEuRab1Hay3/DBsIGKKy4= +cloud.google.com/go/firestore v1.12.0/go.mod h1:b38dKhgzlmNNGTNZZwe7ZRFEuRab1Hay3/DBsIGKKy4= cloud.google.com/go/functions v1.6.0/go.mod h1:3H1UA3qiIPRWD7PeZKLvHZ9SaQhR26XIJcC0A5GbvAk= cloud.google.com/go/functions v1.7.0/go.mod h1:+d+QBcWM+RsrgZfV9xo6KfA1GlzJfxcfZcRPEhDDfzg= cloud.google.com/go/functions v1.8.0/go.mod h1:RTZ4/HsQjIqIYP9a9YPbU+QFoQsAlYgrwOXJWHn1POY= @@ -446,6 +449,7 @@ cloud.google.com/go/managedidentities v1.6.1/go.mod h1:h/irGhTN2SkZ64F43tfGPMbHn cloud.google.com/go/maps v0.1.0/go.mod h1:BQM97WGyfw9FWEmQMpZ5T6cpovXXSd1cGmFma94eubI= cloud.google.com/go/maps v0.6.0/go.mod h1:o6DAMMfb+aINHz/p/jbcY+mYeXBoZoxTfdSQ8VAJaCw= cloud.google.com/go/maps v0.7.0/go.mod h1:3GnvVl3cqeSvgMcpRlQidXsPYuDGQ8naBis7MVzpXsY= +cloud.google.com/go/maps v1.3.0/go.mod h1:6mWTUv+WhnOwAgjVsSW2QPPECmW+s3PcRyOa9vgG/5s= cloud.google.com/go/maps v1.4.0/go.mod h1:6mWTUv+WhnOwAgjVsSW2QPPECmW+s3PcRyOa9vgG/5s= cloud.google.com/go/mediatranslation v0.5.0/go.mod h1:jGPUhGTybqsPQn91pNXw0xVHfuJ3leR1wj37oU3y1f4= cloud.google.com/go/mediatranslation v0.6.0/go.mod h1:hHdBCTYNigsBxshbznuIMFNe5QXEowAuNmmC7h8pu5w= @@ -1040,6 +1044,7 @@ github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0/go.mod h1:z0ButlSOZa5vEBq9m2m2hlwIgKw+rp3sdCBRoJY+30Y= github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 h1:UH//fgunKIs4JdUbpDl1VZCDaL56wXCB/5+wF6uHfaI= github.com/grpc-ecosystem/go-grpc-middleware v1.4.0/go.mod h1:g5qyo/la0ALbONm6Vbp88Yd8NsDy6rZz+RcrMPxvld8= +github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0/go.mod h1:hgWBS7lorOAVIJEQMi4ZsPv9hVvWI6+ch50m39Pf2Ks= github.com/grpc-ecosystem/grpc-gateway/v2 v2.11.3/go.mod h1:o//XUCC/F+yRGJoPO/VU0GSB0f8Nhgmxx0VIRUvaC0w= @@ -1301,8 +1306,8 @@ go.opentelemetry.io/proto/otlp v0.19.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI go.opentelemetry.io/proto/otlp v0.20.0 h1:BLOA1cZBAGSbRiNuGCCKiFrCdYB7deeHDeD1SueyOfA= go.opentelemetry.io/proto/otlp v0.20.0/go.mod h1:3QgjzPALBIv9pcknj2EXGPXjYPFdUh/RQfF8Lz3+Vnw= go.temporal.io/api v1.21.0/go.mod h1:xlsUEakkN2vU2/WV7e5NqMG4N93nfuNfvbXdaXUpU8w= -go.temporal.io/api v1.23.1-0.20230809151511-056e78321730 h1:Lr4NMirHUq1JHREc3nj7ZD2/Y8U1rziBcBBzdZ4DRMw= -go.temporal.io/api v1.23.1-0.20230809151511-056e78321730/go.mod h1:mnxpCCu4HOu5/G1DJCNGQ689dAoHyDQSiWsSf34rzSI= +go.temporal.io/api v1.23.1-0.20230818163044-7f76d854ed02 h1:SAAlC1ExGDnjNC5CT5y99cCavUDmpIjduHb7luPgdak= +go.temporal.io/api v1.23.1-0.20230818163044-7f76d854ed02/go.mod h1:4ackgCMjQHMpJYr1UQ6Tr/nknIqFkJ6dZ/SZsGv+St0= go.temporal.io/sdk v1.23.1 h1:HzOaw5+f6QgDW/HH1jzwgupII7nVz+fzxFPjmFJqKiQ= go.temporal.io/sdk v1.23.1/go.mod h1:S7vWxU01lGcCny0sWx03bkkYw4VtVrpzeqBTn2A6y+E= go.temporal.io/version v0.3.0 h1:dMrei9l9NyHt8nG6EB8vAwDLLTwx2SvRyucCSumAiig= @@ -1954,15 +1959,20 @@ google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54/go.mod h1:zqTuNwFl google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:xZnkP7mREFX5MORlOPEzLMr+90PPZQ2QWzrVTWfAq64= google.golang.org/genproto v0.0.0-20230629202037-9506855d4529/go.mod h1:xZnkP7mREFX5MORlOPEzLMr+90PPZQ2QWzrVTWfAq64= google.golang.org/genproto v0.0.0-20230706204954-ccb25ca9f130/go.mod h1:O9kGHb51iE/nOGvQaDUuadVYqovW56s5emA88lQnj6Y= -google.golang.org/genproto v0.0.0-20230803162519-f966b187b2e5 h1:L6iMMGrtzgHsWofoFcihmDEMYeDR9KN/ThbPWGrh++g= +google.golang.org/genproto v0.0.0-20230726155614-23370e0ffb3e/go.mod h1:0ggbjUrZYpy1q+ANUS30SEoGZ53cdfwtbuG7Ptgy108= google.golang.org/genproto v0.0.0-20230803162519-f966b187b2e5/go.mod h1:oH/ZOT02u4kWEp7oYBGYFFkCdKS/uYR9Z7+0/xuuFp8= +google.golang.org/genproto v0.0.0-20230815205213-6bfd019c3878 h1:Iveh6tGCJkHAjJgEqUQYGDGgbwmhjoAOz8kO/ajxefY= +google.golang.org/genproto v0.0.0-20230815205213-6bfd019c3878/go.mod h1:yZTlhN0tQnXo3h00fuXNCxJdLdIdnVFVBaRJ5LWBbw4= google.golang.org/genproto/googleapis/api v0.0.0-20230525234020-1aefcd67740a/go.mod h1:ts19tUU+Z0ZShN1y3aPyq2+O3d5FUNNgT6FtOzmrNn8= google.golang.org/genproto/googleapis/api v0.0.0-20230525234035-dd9d682886f9/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig= google.golang.org/genproto/googleapis/api v0.0.0-20230526203410-71b5a4ffd15e/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig= google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig= google.golang.org/genproto/googleapis/api v0.0.0-20230629202037-9506855d4529/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig= -google.golang.org/genproto/googleapis/api v0.0.0-20230726155614-23370e0ffb3e h1:z3vDksarJxsAKM5dmEGv0GHwE2hKJ096wZra71Vs4sw= +google.golang.org/genproto/googleapis/api v0.0.0-20230706204954-ccb25ca9f130/go.mod h1:mPBs5jNgx2GuQGvFwUvVKqtn6HsUw9nP64BedgvqEsQ= google.golang.org/genproto/googleapis/api v0.0.0-20230726155614-23370e0ffb3e/go.mod h1:rsr7RhLuwsDKL7RmgDDCUc6yaGr1iqceVb5Wv6f6YvQ= +google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5/go.mod h1:5DZzOUPCLYL3mNkQ0ms0F3EuUNZ7py1Bqeq6sxzI7/Q= +google.golang.org/genproto/googleapis/api v0.0.0-20230815205213-6bfd019c3878 h1:WGq4lvB/mlicysM/dUT3SBvijH4D3sm/Ny1A4wmt2CI= +google.golang.org/genproto/googleapis/api v0.0.0-20230815205213-6bfd019c3878/go.mod h1:KjSP20unUpOx5kyQUFa7k4OJg0qeJ7DEZflGDu2p6Bk= google.golang.org/genproto/googleapis/bytestream v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:ylj+BE99M198VPbBh6A8d9n3w8fChvyLK3wwBOjXBFA= google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234015-3fc162c6f38a/go.mod h1:xURIpW9ES5+/GZhnV6beoEtxQrnkRGIfP5VQG2tCBLc= google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= @@ -1971,8 +1981,9 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc/go. google.golang.org/genproto/googleapis/rpc v0.0.0-20230629202037-9506855d4529/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= google.golang.org/genproto/googleapis/rpc v0.0.0-20230706204954-ccb25ca9f130/go.mod h1:8mL13HKkDa+IuJ8yruA3ci0q+0vsUz4m//+ottjwS5o= google.golang.org/genproto/googleapis/rpc v0.0.0-20230731190214-cbb8c96f2d6d/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230807174057-1744710a1577 h1:wukfNtZmZUurLN/atp2hiIeTKn7QJWIQdHzqmsOnAOk= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230807174057-1744710a1577/go.mod h1:+Bk1OCOj40wS2hwAMA+aCW9ypzm63QTBBHp6lQ3p+9M= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230803162519-f966b187b2e5/go.mod h1:zBEcrKX2ZOcEkHWxBPAIvYUWOKKMIhYcmNiUIu2ji3I= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230815205213-6bfd019c3878 h1:lv6/DhyiFFGsmzxbsUUTOkN29II+zeWHxvT8Lpdxsv0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230815205213-6bfd019c3878/go.mod h1:+Bk1OCOj40wS2hwAMA+aCW9ypzm63QTBBHp6lQ3p+9M= google.golang.org/grpc v1.12.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= diff --git a/service/frontend/fx.go b/service/frontend/fx.go index 6d8b59c1ee2..d85fafb7a83 100644 --- a/service/frontend/fx.go +++ b/service/frontend/fx.go @@ -58,6 +58,7 @@ import ( "go.temporal.io/server/common/resolver" "go.temporal.io/server/common/resource" "go.temporal.io/server/common/rpc" + "go.temporal.io/server/common/rpc/encryption" "go.temporal.io/server/common/rpc/interceptor" "go.temporal.io/server/common/sdk" "go.temporal.io/server/common/searchattribute" @@ -87,12 +88,13 @@ var Module = fx.Options( fx.Provide(ThrottledLoggerRpsFnProvider), fx.Provide(PersistenceRateLimitingParamsProvider), fx.Provide(FEReplicatorNamespaceReplicationQueueProvider), - fx.Provide(func(so []grpc.ServerOption) *grpc.Server { return grpc.NewServer(so...) }), + fx.Provide(func(so GrpcServerOptions) *grpc.Server { return grpc.NewServer(so.Options...) }), fx.Provide(HandlerProvider), fx.Provide(AdminHandlerProvider), fx.Provide(OperatorHandlerProvider), fx.Provide(NewVersionChecker), fx.Provide(ServiceResolverProvider), + fx.Provide(HTTPAPIServerProvider), fx.Provide(NewServiceProvider), fx.Invoke(ServiceLifetimeHooks), ) @@ -101,6 +103,7 @@ func NewServiceProvider( serviceConfig *Config, server *grpc.Server, healthServer *health.Server, + httpAPIServer *HTTPAPIServer, handler Handler, adminHandler *AdminHandler, operatorHandler *OperatorHandlerImpl, @@ -116,6 +119,7 @@ func NewServiceProvider( serviceConfig, server, healthServer, + httpAPIServer, handler, adminHandler, operatorHandler, @@ -129,6 +133,13 @@ func NewServiceProvider( ) } +// GrpcServerOptions are the options to build the frontend gRPC server along +// with the interceptors that are already set in the options. +type GrpcServerOptions struct { + Options []grpc.ServerOption + UnaryInterceptors []grpc.UnaryServerInterceptor +} + func GrpcServerOptionsProvider( logger log.Logger, serviceConfig *Config, @@ -150,7 +161,7 @@ func GrpcServerOptionsProvider( audienceGetter authorization.JWTAudienceMapper, customInterceptors []grpc.UnaryServerInterceptor, metricsHandler metrics.Handler, -) []grpc.ServerOption { +) GrpcServerOptions { kep := keepalive.EnforcementPolicy{ MinTime: serviceConfig.KeepAliveMinTime(), PermitWithoutStream: serviceConfig.KeepAlivePermitWithoutStream(), @@ -209,13 +220,14 @@ func GrpcServerOptionsProvider( telemetryInterceptor.StreamIntercept, } - return append( + grpcServerOptions = append( grpcServerOptions, grpc.KeepaliveParams(kp), grpc.KeepaliveEnforcementPolicy(kep), grpc.ChainUnaryInterceptor(unaryInterceptors...), grpc.ChainStreamInterceptor(streamInterceptor...), ) + return GrpcServerOptions{Options: grpcServerOptions, UnaryInterceptors: unaryInterceptors} } func ConfigProvider( @@ -603,6 +615,42 @@ func HandlerProvider( return wfHandler } +// HTTPAPIServerProvider provides an HTTP API server if enabled or nil +// otherwise. +func HTTPAPIServerProvider( + cfg *config.Config, + serviceName primitives.ServiceName, + serviceConfig *Config, + grpcListener net.Listener, + tlsConfigProvider encryption.TLSConfigProvider, + handler Handler, + grpcServerOptions GrpcServerOptions, + metricsHandler metrics.Handler, + namespaceRegistry namespace.Registry, + logger log.Logger, +) (*HTTPAPIServer, error) { + // If the service is not the frontend service, HTTP API is disabled + if serviceName != primitives.FrontendService { + return nil, nil + } + // If HTTP API port is 0, it is disabled + rpcConfig := cfg.Services[string(serviceName)].RPC + if rpcConfig.HTTPPort == 0 { + return nil, nil + } + return NewHTTPAPIServer( + serviceConfig, + rpcConfig, + grpcListener, + tlsConfigProvider, + handler, + grpcServerOptions.UnaryInterceptors, + metricsHandler, + namespaceRegistry, + logger, + ) +} + func ServiceLifetimeHooks(lc fx.Lifecycle, svc *Service) { lc.Append(fx.StartStopHook(svc.Start, svc.Stop)) } diff --git a/service/frontend/http_api_server.go b/service/frontend/http_api_server.go new file mode 100644 index 00000000000..35407cc9e57 --- /dev/null +++ b/service/frontend/http_api_server.go @@ -0,0 +1,494 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package frontend + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + "reflect" + "time" + + "github.com/gogo/protobuf/proto" + "github.com/gogo/status" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "go.temporal.io/api/proxy" + "go.temporal.io/api/serviceerror" + "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/server/common/config" + "go.temporal.io/server/common/headers" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/common/metrics" + "go.temporal.io/server/common/namespace" + "go.temporal.io/server/common/rpc" + "go.temporal.io/server/common/rpc/encryption" + "go.temporal.io/server/common/rpc/interceptor" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" +) + +// HTTPAPIServer is an HTTP API server that forwards requests to gRPC via the +// gRPC interceptors. +type HTTPAPIServer struct { + server http.Server + listener net.Listener + logger log.Logger + serveMux *runtime.ServeMux + stopped chan struct{} + matchAdditionalHeaders map[string]bool +} + +var defaultForwardedHeaders = []string{ + "Authorization-Extras", + "X-Forwarded-For", + http.CanonicalHeaderKey(headers.ClientNameHeaderName), + http.CanonicalHeaderKey(headers.ClientVersionHeaderName), +} + +type httpRemoteAddrContextKey struct{} + +var ( + errHTTPGRPCListenerNotTCP = errors.New("must use TCP for gRPC listener to support HTTP API") + errHTTPGRPCStreamNotSupported = errors.New("stream not supported") +) + +// NewHTTPAPIServer creates an [HTTPAPIServer]. +func NewHTTPAPIServer( + serviceConfig *Config, + rpcConfig config.RPC, + grpcListener net.Listener, + tlsConfigProvider encryption.TLSConfigProvider, + handler Handler, + interceptors []grpc.UnaryServerInterceptor, + metricsHandler metrics.Handler, + namespaceRegistry namespace.Registry, + logger log.Logger, +) (*HTTPAPIServer, error) { + // Create a TCP listener the same as the frontend one but with different port + tcpAddrRef, _ := grpcListener.Addr().(*net.TCPAddr) + if tcpAddrRef == nil { + return nil, errHTTPGRPCListenerNotTCP + } + tcpAddr := *tcpAddrRef + tcpAddr.Port = rpcConfig.HTTPPort + var listener net.Listener + var err error + if listener, err = net.ListenTCP("tcp", &tcpAddr); err != nil { + return nil, fmt.Errorf("failed listening for HTTP API on %v: %w", &tcpAddr, err) + } + // Close the listener if anything else in this function fails + success := false + defer func() { + if !success { + _ = listener.Close() + } + }() + + // Wrap the listener in a TLS listener if there is any TLS config + if tlsConfigProvider != nil { + if tlsConfig, err := tlsConfigProvider.GetFrontendServerConfig(); err != nil { + return nil, fmt.Errorf("failed getting TLS config for HTTP API: %w", err) + } else if tlsConfig != nil { + listener = tls.NewListener(listener, tlsConfig) + } + } + + h := &HTTPAPIServer{ + listener: listener, + logger: logger, + stopped: make(chan struct{}), + } + + // Build 4 possible marshalers in order based on content type + opts := []runtime.ServeMuxOption{ + runtime.WithMarshalerOption("application/json+pretty+no-payload-shorthand", h.newMarshaler(" ", true)), + runtime.WithMarshalerOption("application/json+no-payload-shorthand", h.newMarshaler("", true)), + runtime.WithMarshalerOption("application/json+pretty", h.newMarshaler(" ", false)), + runtime.WithMarshalerOption(runtime.MIMEWildcard, h.newMarshaler("", false)), + } + + // Set Temporal service error handler + opts = append(opts, runtime.WithProtoErrorHandler(h.errorHandler)) + + // Match headers w/ default + h.matchAdditionalHeaders = map[string]bool{} + for _, v := range defaultForwardedHeaders { + h.matchAdditionalHeaders[v] = true + } + for _, v := range rpcConfig.HTTPAdditionalForwardedHeaders { + h.matchAdditionalHeaders[http.CanonicalHeaderKey(v)] = true + } + opts = append(opts, runtime.WithIncomingHeaderMatcher(h.incomingHeaderMatcher)) + + // Create inline client connection + clientConn := newInlineClientConn( + map[string]any{"temporal.api.workflowservice.v1.WorkflowService": handler}, + interceptors, + metricsHandler, + namespaceRegistry, + ) + + // Create serve mux + h.serveMux = runtime.NewServeMux(opts...) + err = workflowservice.RegisterWorkflowServiceHandlerClient( + context.Background(), + h.serveMux, + workflowservice.NewWorkflowServiceClient(clientConn), + ) + if err != nil { + return nil, fmt.Errorf("failed registering HTTP API handler: %w", err) + } + // Set the handler as our function that wraps serve mux + h.server.Handler = http.HandlerFunc(h.serveHTTP) + + // Put the remote address on the context + h.server.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + return context.WithValue(ctx, httpRemoteAddrContextKey{}, c) + } + + // We want to set ReadTimeout and WriteTimeout as max idle (and IdleTimeout + // defaults to ReadTimeout) to ensure that a connection cannot hang over that + // amount of time. + h.server.ReadTimeout = serviceConfig.KeepAliveMaxConnectionIdle() + h.server.WriteTimeout = serviceConfig.KeepAliveMaxConnectionIdle() + + success = true + return h, nil +} + +// Serve serves the HTTP API and does not return until there is a serve error or +// GracefulStop completes. Upon graceful stop, this will return nil. If an error +// is returned, the message is clear that it came from the HTTP API server. +func (h *HTTPAPIServer) Serve() error { + err := h.server.Serve(h.listener) + // If the error is for close, we have to wait for the shutdown to complete and + // we don't consider it an error + if errors.Is(err, http.ErrServerClosed) { + <-h.stopped + err = nil + } + // Wrap the error to be clearer it's from the HTTP API + if err != nil { + return fmt.Errorf("HTTP API serve failed: %w", err) + } + return nil +} + +// GracefulStop stops the HTTP server. This will first attempt a graceful stop +// with a drain time, then will hard-stop. This will not return until stopped. +func (h *HTTPAPIServer) GracefulStop(gracefulDrainTime time.Duration) { + // We try a graceful stop for the amount of time we can drain, then we do a + // hard stop + shutdownCtx, cancel := context.WithTimeout(context.Background(), gracefulDrainTime) + defer cancel() + // We intentionally ignore this error, we're gonna stop at this point no + // matter what. This closes the listener too. + _ = h.server.Shutdown(shutdownCtx) + _ = h.server.Close() + close(h.stopped) +} + +func (h *HTTPAPIServer) serveHTTP(w http.ResponseWriter, r *http.Request) { + // Limit the request body to max gRPC size. This is hardcoded to 4MB at the + // moment using gRPC's default at + // https://github.com/grpc/grpc-go/blob/0673105ebcb956e8bf50b96e28209ab7845a65ad/server.go#L58 + // which is what the constant is set as at the time of this comment. + r.Body = http.MaxBytesReader(w, r.Body, rpc.MaxHTTPAPIRequestBytes) + + h.logger.Debug( + "HTTP API call", + tag.NewStringTag("http-method", r.Method), + tag.NewAnyTag("http-url", r.URL), + ) + + // Need to change the accept header based on whether pretty and/or + // noPayloadShorthand are present + var acceptHeaderSuffix string + if _, ok := r.URL.Query()["pretty"]; ok { + acceptHeaderSuffix += "+pretty" + } + if _, ok := r.URL.Query()["noPayloadShorthand"]; ok { + acceptHeaderSuffix += "+no-payload-shorthand" + } + if acceptHeaderSuffix != "" { + r.Header.Set("Accept", "application/json"+acceptHeaderSuffix) + } + + // Put the TLS info on the peer context + if r.TLS != nil { + var addr net.Addr + if conn, _ := r.Context().Value(httpRemoteAddrContextKey{}).(net.Conn); conn != nil { + addr = conn.RemoteAddr() + } + r = r.WithContext(peer.NewContext(r.Context(), &peer.Peer{ + Addr: addr, + AuthInfo: credentials.TLSInfo{ + State: *r.TLS, + CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}, + }, + })) + } + + // Call gRPC gateway mux + h.serveMux.ServeHTTP(w, r) +} + +func (h *HTTPAPIServer) errorHandler( + ctx context.Context, + mux *runtime.ServeMux, + marshaler runtime.Marshaler, + w http.ResponseWriter, + r *http.Request, + err error, +) { + // Convert the error using serviceerror. The result does not conform to Google + // gRPC status directly (it conforms to gogo gRPC status), but Err() does + // based on internal code reading. However, Err() uses Google proto Any + // which our marshaler is not expecting. So instead we are embedding similar + // logic to runtime.DefaultHTTPProtoErrorHandler in here but with gogo + // support. We don't implement custom content type marshaler or trailers at + // this time. + + s := serviceerror.ToStatus(err) + w.Header().Set("Content-Type", marshaler.ContentType()) + + buf, merr := marshaler.Marshal(s.Proto()) + if merr != nil { + h.logger.Warn("Failed to marshal error message", tag.Error(merr)) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"code": 13, "message": "failed to marshal error message"}`)) + return + } + + w.WriteHeader(runtime.HTTPStatusFromCode(s.Code())) + _, _ = w.Write(buf) +} + +func (h *HTTPAPIServer) newMarshaler(indent string, disablePayloadShorthand bool) runtime.Marshaler { + marshalOpts := proxy.JSONPBMarshalerOptions{ + Indent: indent, + DisablePayloadShorthand: disablePayloadShorthand, + } + unmarshalOpts := proxy.JSONPBUnmarshalerOptions{DisablePayloadShorthand: disablePayloadShorthand} + if m, err := proxy.NewJSONPBMarshaler(marshalOpts); err != nil { + panic(err) + } else if u, err := proxy.NewJSONPBUnmarshaler(unmarshalOpts); err != nil { + panic(err) + } else { + return proxy.NewGRPCGatewayJSONPBMarshaler(m, u) + } +} + +func (h *HTTPAPIServer) incomingHeaderMatcher(headerName string) (string, bool) { + // Try ours before falling back to default + if h.matchAdditionalHeaders[headerName] { + return headerName, true + } + return runtime.DefaultHeaderMatcher(headerName) +} + +// inlineClientConn is a [grpc.ClientConnInterface] implementation that forwards +// requests directly to gRPC via interceptors. This implementation moves all +// outgoing metadata to incoming and takes resulting outgoing metadata and sets +// as header. But which headers to use and TLS peer context and such are +// expected to be handled by the caller. +type inlineClientConn struct { + methods map[string]*serviceMethod + interceptor grpc.UnaryServerInterceptor + requestsCounter metrics.CounterIface + namespaceRegistry namespace.Registry +} + +var _ grpc.ClientConnInterface = (*inlineClientConn)(nil) + +type serviceMethod struct { + info grpc.UnaryServerInfo + handler grpc.UnaryHandler +} + +var contextType = reflect.TypeOf((*context.Context)(nil)).Elem() +var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem() +var errorType = reflect.TypeOf((*error)(nil)).Elem() + +func newInlineClientConn( + servers map[string]any, + interceptors []grpc.UnaryServerInterceptor, + metricsHandler metrics.Handler, + namespaceRegistry namespace.Registry, +) *inlineClientConn { + // Create the set of methods via reflection. We currently accept the overhead + // of reflection compared to having to custom generate gateway code. + methods := map[string]*serviceMethod{} + for qualifiedServerName, server := range servers { + serverVal := reflect.ValueOf(server) + for i := 0; i < serverVal.Type().NumMethod(); i++ { + reflectMethod := serverVal.Type().Method(i) + // We intentionally look this up by name to not assume method indexes line + // up from type to value + methodVal := serverVal.MethodByName(reflectMethod.Name) + // We assume the methods we want only accept a context + request and only + // return a response + error. We also assume the method name matches the + // RPC name. + methodType := methodVal.Type() + validRPCMethod := methodType.Kind() == reflect.Func && + methodType.NumIn() == 2 && + methodType.NumOut() == 2 && + methodType.In(0) == contextType && + methodType.In(1).Implements(protoMessageType) && + methodType.Out(0).Implements(protoMessageType) && + methodType.Out(1) == errorType + if !validRPCMethod { + continue + } + fullMethod := "/" + qualifiedServerName + "/" + reflectMethod.Name + methods[fullMethod] = &serviceMethod{ + info: grpc.UnaryServerInfo{Server: server, FullMethod: fullMethod}, + handler: func(ctx context.Context, req interface{}) (interface{}, error) { + ret := methodVal.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(req)}) + err, _ := ret[1].Interface().(error) + return ret[0].Interface(), err + }, + } + } + } + + return &inlineClientConn{ + methods: methods, + interceptor: chainUnaryServerInterceptors(interceptors), + requestsCounter: metricsHandler.Counter(metrics.HTTPServiceRequests.GetMetricName()), + namespaceRegistry: namespaceRegistry, + } +} + +func (i *inlineClientConn) Invoke( + ctx context.Context, + method string, + args any, + reply any, + opts ...grpc.CallOption, +) error { + // Move outgoing metadata to incoming and set new outgoing metadata + md, _ := metadata.FromOutgoingContext(ctx) + // Set the client and version headers if not already set + if len(md[headers.ClientNameHeaderName]) == 0 { + md.Set(headers.ClientNameHeaderName, headers.ClientNameServerHTTP) + } + if len(md[headers.ClientVersionHeaderName]) == 0 { + md.Set(headers.ClientVersionHeaderName, headers.ServerVersion) + } + ctx = metadata.NewIncomingContext(ctx, md) + outgoingMD := metadata.MD{} + ctx = metadata.NewOutgoingContext(ctx, outgoingMD) + + // Get the method. Should never fail, but we check anyways + serviceMethod := i.methods[method] + if serviceMethod == nil { + return status.Error(codes.NotFound, "call not found") + } + + // Add metric + var namespaceTag metrics.Tag + if namespaceName := interceptor.MustGetNamespaceName(i.namespaceRegistry, args); namespaceName != "" { + namespaceTag = metrics.NamespaceTag(namespaceName.String()) + } else { + namespaceTag = metrics.NamespaceUnknownTag() + } + i.requestsCounter.Record(1, metrics.OperationTag(method), namespaceTag) + + // Invoke + var resp any + var err error + if i.interceptor == nil { + resp, err = serviceMethod.handler(ctx, args) + } else { + resp, err = i.interceptor(ctx, args, &serviceMethod.info, serviceMethod.handler) + } + + // Find the header call option and set response headers. We accept that if + // somewhere internally the metadata was replaced instead of appended to, this + // does not work. + for _, opt := range opts { + if callOpt, ok := opt.(grpc.HeaderCallOption); ok { + *callOpt.HeaderAddr = outgoingMD + } + } + + // Merge the response proto onto the wanted reply if non-nil + if respProto, _ := resp.(proto.Message); respProto != nil { + proto.Merge(reply.(proto.Message), respProto) + } + + return err +} + +func (*inlineClientConn) NewStream( + context.Context, + *grpc.StreamDesc, + string, + ...grpc.CallOption, +) (grpc.ClientStream, error) { + return nil, errHTTPGRPCStreamNotSupported +} + +// Mostly taken from https://github.com/grpc/grpc-go/blob/v1.56.1/server.go#L1124-L1158 +// with slight modifications. +func chainUnaryServerInterceptors(interceptors []grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { + switch len(interceptors) { + case 0: + return nil + case 1: + return interceptors[0] + default: + return chainUnaryInterceptors(interceptors) + } +} + +func chainUnaryInterceptors(interceptors []grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler)) + } +} + +func getChainUnaryHandler( + interceptors []grpc.UnaryServerInterceptor, + curr int, + info *grpc.UnaryServerInfo, + finalHandler grpc.UnaryHandler, +) grpc.UnaryHandler { + if curr == len(interceptors)-1 { + return finalHandler + } + return func(ctx context.Context, req interface{}) (interface{}, error) { + return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler)) + } +} diff --git a/service/frontend/service.go b/service/frontend/service.go index 07f5e16c3fa..ccb4b668f99 100644 --- a/service/frontend/service.go +++ b/service/frontend/service.go @@ -28,15 +28,11 @@ import ( "math/rand" "net" "os" + "sync" "time" "go.temporal.io/api/operatorservice/v1" "go.temporal.io/api/workflowservice/v1" - "google.golang.org/grpc" - "google.golang.org/grpc/health" - healthpb "google.golang.org/grpc/health/grpc_health_v1" - "google.golang.org/grpc/reflection" - "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/dynamicconfig" @@ -49,6 +45,10 @@ import ( "go.temporal.io/server/common/persistence/visibility" "go.temporal.io/server/common/persistence/visibility/manager" "go.temporal.io/server/common/util" + "google.golang.org/grpc" + "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/reflection" ) // Config represents configuration for frontend service @@ -288,6 +288,7 @@ type Service struct { versionChecker *VersionChecker visibilityManager manager.VisibilityManager server *grpc.Server + httpAPIServer *HTTPAPIServer logger log.Logger grpcListener net.Listener @@ -300,6 +301,7 @@ func NewService( serviceConfig *Config, server *grpc.Server, healthServer *health.Server, + httpAPIServer *HTTPAPIServer, handler Handler, adminHandler *AdminHandler, operatorHandler *OperatorHandlerImpl, @@ -315,6 +317,7 @@ func NewService( config: serviceConfig, server: server, healthServer: healthServer, + httpAPIServer: httpAPIServer, handler: handler, adminHandler: adminHandler, operatorHandler: operatorHandler, @@ -355,6 +358,14 @@ func (s *Service) Start() { } }() + if s.httpAPIServer != nil { + go func() { + if err := s.httpAPIServer.Serve(); err != nil { + s.logger.Fatal("Failed to serve HTTP API server", tag.Error(err)) + } + }() + } + go s.membershipMonitor.Start() } @@ -383,12 +394,26 @@ func (s *Service) Stop() { s.visibilityManager.Close() s.logger.Info("ShutdownHandler: Draining traffic") - t := time.AfterFunc(requestDrainTime, func() { - s.logger.Info("ShutdownHandler: Drain time expired, stopping all traffic") - s.server.Stop() - }) - s.server.GracefulStop() - t.Stop() + // Gracefully stop gRPC server and HTTP API server concurrently + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + t := time.AfterFunc(requestDrainTime, func() { + s.logger.Info("ShutdownHandler: Drain time expired, stopping all traffic") + s.server.Stop() + }) + s.server.GracefulStop() + t.Stop() + }() + if s.httpAPIServer != nil { + wg.Add(1) + go func() { + defer wg.Done() + s.httpAPIServer.GracefulStop(requestDrainTime) + }() + } + wg.Wait() if s.metricsHandler != nil { s.metricsHandler.Stop(s.logger) diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index 7462031e38e..19dbf0cc3df 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -369,7 +369,9 @@ func (wh *WorkflowHandler) StartWorkflowExecution(ctx context.Context, request * } if request.GetRequestId() == "" { - return nil, errRequestIDNotSet + // For easy direct API use, we default the request ID here but expect all + // SDKs and other auto-retrying clients to set it + request.RequestId = uuid.New() } if len(request.GetRequestId()) > wh.config.MaxIDLengthLimit() { diff --git a/service/frontend/workflow_handler_test.go b/service/frontend/workflow_handler_test.go index e1d6007094b..a1d3711ed32 100644 --- a/service/frontend/workflow_handler_test.go +++ b/service/frontend/workflow_handler_test.go @@ -361,33 +361,6 @@ func (s *workflowHandlerSuite) TestPollForTask_Failed_ContextTimeoutTooShort() { s.Equal(common.ErrContextTimeoutTooShort, err) } -func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_RequestIdNotSet() { - config := s.newConfig() - config.RPS = dc.GetIntPropertyFn(10) - wh := s.getWorkflowHandler(config) - - startWorkflowExecutionRequest := &workflowservice.StartWorkflowExecutionRequest{ - Namespace: "test-namespace", - WorkflowId: "workflow-id", - WorkflowType: &commonpb.WorkflowType{ - Name: "workflow-type", - }, - TaskQueue: &taskqueuepb.TaskQueue{ - Name: "task-queue", - }, - WorkflowTaskTimeout: timestamp.DurationPtr(1 * time.Second), - RetryPolicy: &commonpb.RetryPolicy{ - InitialInterval: timestamp.DurationPtr(1 * time.Second), - BackoffCoefficient: 2, - MaximumInterval: timestamp.DurationPtr(2 * time.Second), - MaximumAttempts: 1, - }, - } - _, err := wh.StartWorkflowExecution(context.Background(), startWorkflowExecutionRequest) - s.Error(err) - s.Equal(errRequestIDNotSet, err) -} - func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_StartRequestNotSet() { config := s.newConfig() config.RPS = dc.GetIntPropertyFn(10) diff --git a/tests/flag.go b/tests/flag.go index eca3ddb45bf..b34e5ef7f0d 100644 --- a/tests/flag.go +++ b/tests/flag.go @@ -29,6 +29,7 @@ import "flag" // TestFlags contains the feature flags for integration tests var TestFlags struct { FrontendAddr string + FrontendHTTPAddr string PersistenceType string PersistenceDriver string TestClusterConfigFile string @@ -37,6 +38,7 @@ var TestFlags struct { func init() { flag.StringVar(&TestFlags.FrontendAddr, "frontendAddress", "", "host:port for temporal frontend service") + flag.StringVar(&TestFlags.FrontendHTTPAddr, "frontendHttpAddress", "", "host:port for temporal frontend HTTP service (only applies when frontendAddress set)") flag.StringVar(&TestFlags.PersistenceType, "persistenceType", "sql", "type of persistence - [nosql or sql]") flag.StringVar(&TestFlags.PersistenceDriver, "persistenceDriver", "sqlite", "driver of nosql / sql- [cassandra, mysql, postgresql, sqlite]") flag.StringVar(&TestFlags.TestClusterConfigFile, "TestClusterConfigFile", "", "test cluster config file location") diff --git a/tests/http_api_test.go b/tests/http_api_test.go new file mode 100644 index 00000000000..c7823fa41ac --- /dev/null +++ b/tests/http_api_test.go @@ -0,0 +1,261 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package tests + +import ( + "context" + "encoding/json" + "io" + "net/http" + "strings" + "sync" + + "go.temporal.io/sdk/workflow" + "go.temporal.io/server/common/authorization" + "go.temporal.io/server/common/headers" + "go.temporal.io/server/common/metrics" + "google.golang.org/grpc/metadata" +) + +type SomeJSONStruct struct { + SomeField string `json:"someField"` +} + +func (s *clientIntegrationSuite) TestHTTPAPIBasics() { + if s.httpAPIAddress == "" { + s.T().Skip("HTTP API server not enabled") + } + // Create basic workflow that can answer queries, get signals, etc + workflowFn := func(ctx workflow.Context, arg *SomeJSONStruct) (*SomeJSONStruct, error) { + // Query that just returns query arg + err := workflow.SetQueryHandler(ctx, "some-query", func(queryArg *SomeJSONStruct) (*SomeJSONStruct, error) { + return queryArg, nil + }) + if err != nil { + return nil, err + } + // Wait for signal to complete + var done bool + sel := workflow.NewSelector(ctx) + sel.AddReceive(workflow.GetSignalChannel(ctx, "some-signal"), func(ch workflow.ReceiveChannel, _ bool) { + var signalArg SomeJSONStruct + ch.Receive(ctx, &signalArg) + if signalArg.SomeField != "signal-arg" { + panic("invalid signal arg") + } + done = true + }) + for !done { + sel.Select(ctx) + } + return arg, nil + } + s.worker.RegisterWorkflowWithOptions(workflowFn, workflow.RegisterOptions{Name: "http-basic-workflow"}) + + // Capture metrics + capture := s.testCluster.host.captureMetricsHandler.StartCapture() + defer s.testCluster.host.captureMetricsHandler.StopCapture(capture) + + // Start + workflowID := s.randomizeStr("wf") + _, respBody := s.httpPost(http.StatusOK, "/api/v1/namespaces/"+s.namespace+"/workflows/"+workflowID, `{ + "workflowType": { "name": "http-basic-workflow" }, + "taskQueue": { "name": "`+s.taskQueue+`" }, + "input": [{ "someField": "workflow-arg" }] + }`) + var startResp struct { + RunID string `json:"runId"` + } + s.Require().NoError(json.Unmarshal(respBody, &startResp)) + + // Check that there is a an HTTP call metric with the proper tags/value. We + // can't test overall counts because the metrics handler is shared across + // concurrently executing tests. + var found bool + for _, metric := range capture.Snapshot()[metrics.HTTPServiceRequests.GetMetricName()] { + found = + metric.Tags[metrics.OperationTagName] == "/temporal.api.workflowservice.v1.WorkflowService/StartWorkflowExecution" && + metric.Tags["namespace"] == s.namespace && + metric.Value == int64(1) + if found { + break + } + } + s.Require().True(found) + + // Confirm already exists error with details and proper code + _, respBody = s.httpPost(http.StatusConflict, "/api/v1/namespaces/"+s.namespace+"/workflows/"+workflowID, `{ + "workflowType": { "name": "http-basic-workflow" }, + "taskQueue": { "name": "`+s.taskQueue+`" }, + "input": [{ "someField": "workflow-arg" }], + "requestId": "`+s.randomizeStr("req")+`" + }`) + var errResp struct { + Message string `json:"message"` + Details []struct { + RunID string `json:"runId"` + } `json:"details"` + } + s.Require().NoError(json.Unmarshal(respBody, &errResp)) + s.Require().Contains(errResp.Message, "already running") + s.Require().Equal(startResp.RunID, errResp.Details[0].RunID) + + // Query + _, respBody = s.httpPost( + http.StatusOK, + "/api/v1/namespaces/"+s.namespace+"/workflows/"+workflowID+"/query/some-query", + `{ "query": { "queryArgs": [{ "someField": "query-arg" }] } }`, + ) + var queryResp struct { + QueryResult json.RawMessage `json:"queryResult"` + } + s.Require().NoError(json.Unmarshal(respBody, &queryResp)) + s.Require().JSONEq(`[{ "someField": "query-arg" }]`, string(queryResp.QueryResult)) + + // Signal which also completes the workflow + s.httpPost( + http.StatusOK, + "/api/v1/namespaces/"+s.namespace+"/workflows/"+workflowID+"/signal/some-signal", + `{ "input": [{ "someField": "signal-arg" }] }`, + ) + + // Confirm workflow complete + _, respBody = s.httpGet( + http.StatusOK, + // Our version of gRPC gateway only supports integer enums in queries :-( + "/api/v1/namespaces/"+s.namespace+"/workflows/"+workflowID+"/history?historyEventFilterType=2", + ) + var histResp struct { + History struct { + Events []struct { + EventType string `json:"eventType"` + WorkflowExecutionCompletedEventAttributes struct { + Result json.RawMessage `json:"result"` + } `json:"workflowExecutionCompletedEventAttributes"` + } `json:"events"` + } `json:"history"` + } + s.Require().NoError(json.Unmarshal(respBody, &histResp)) + s.Require().Equal("WorkflowExecutionCompleted", histResp.History.Events[0].EventType) + s.Require().JSONEq( + `[{ "someField": "workflow-arg" }]`, + string(histResp.History.Events[0].WorkflowExecutionCompletedEventAttributes.Result), + ) + +} + +func (s *clientIntegrationSuite) TestHTTPAPIHeaders() { + if s.httpAPIAddress == "" { + s.T().Skip("HTTP API server not enabled") + } + // Make a claim mapper and authorizer that capture info + var lastInfo *authorization.AuthInfo + var listWorkflowMetadata metadata.MD + var callbackLock sync.RWMutex + s.testCluster.host.SetOnGetClaims(func(info *authorization.AuthInfo) (*authorization.Claims, error) { + callbackLock.Lock() + defer callbackLock.Unlock() + if info != nil { + lastInfo = info + } + return &authorization.Claims{System: authorization.RoleAdmin}, nil + }) + s.testCluster.host.SetOnAuthorize(func( + ctx context.Context, + caller *authorization.Claims, + target *authorization.CallTarget, + ) (authorization.Result, error) { + callbackLock.Lock() + defer callbackLock.Unlock() + if target.APIName == "/temporal.api.workflowservice.v1.WorkflowService/ListWorkflowExecutions" { + listWorkflowMetadata, _ = metadata.FromIncomingContext(ctx) + } + return authorization.Result{Decision: authorization.DecisionAllow}, nil + }) + + // Make a simple list call that we don't care about the result + req, err := http.NewRequest("GET", "/api/v1/namespaces/"+s.namespace+"/workflows", nil) + s.Require().NoError(err) + req.Header.Set("Authorization", "my-auth-token") + req.Header.Set("X-Forwarded-For", "1.2.3.4:5678") + // The header is set to forward deep in the onebox config + req.Header.Set("This-Header-Forwarded", "some-value") + req.Header.Set("This-Header-Not-Forwarded", "some-value") + s.httpRequest(http.StatusOK, req) + + // Confirm the claims got my auth token + callbackLock.RLock() + defer callbackLock.RUnlock() + s.Require().Equal("my-auth-token", lastInfo.AuthToken) + + // Check headers + s.Require().Equal("my-auth-token", listWorkflowMetadata["authorization"][0]) + s.Require().Contains(listWorkflowMetadata["x-forwarded-for"][0], "1.2.3.4:5678") + s.Require().Equal("some-value", listWorkflowMetadata["this-header-forwarded"][0]) + s.Require().NotContains(listWorkflowMetadata, "this-header-not-forwarded") + s.Require().Equal(headers.ClientNameServerHTTP, listWorkflowMetadata[headers.ClientNameHeaderName][0]) + s.Require().Equal(headers.ServerVersion, listWorkflowMetadata[headers.ClientVersionHeaderName][0]) +} + +func (s *clientIntegrationSuite) TestHTTPAPIPretty() { + if s.httpAPIAddress == "" { + s.T().Skip("HTTP API server not enabled") + } + // Make a call to system info normal, confirm no newline, then ask for pretty + // and confirm newlines + _, b := s.httpGet(http.StatusOK, "/api/v1/system-info") + s.Require().NotContains(b, byte('\n')) + _, b = s.httpGet(http.StatusOK, "/api/v1/system-info?pretty") + s.Require().Contains(b, byte('\n')) +} + +func (s *clientIntegrationSuite) httpGet(expectedStatus int, url string) (*http.Response, []byte) { + req, err := http.NewRequest("GET", url, nil) + s.Require().NoError(err) + return s.httpRequest(expectedStatus, req) +} + +func (s *clientIntegrationSuite) httpPost(expectedStatus int, url string, jsonBody string) (*http.Response, []byte) { + req, err := http.NewRequest("POST", url, strings.NewReader(jsonBody)) + s.Require().NoError(err) + req.Header.Set("Content-Type", "application/json") + return s.httpRequest(expectedStatus, req) +} + +func (s *clientIntegrationSuite) httpRequest(expectedStatus int, req *http.Request) (*http.Response, []byte) { + if req.URL.Scheme == "" { + req.URL.Scheme = "http" + } + if req.URL.Host == "" { + req.URL.Host = s.httpAPIAddress + } + resp, err := http.DefaultClient.Do(req) + s.Require().NoError(err) + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + s.Require().NoError(err) + s.Require().Equal(expectedStatus, resp.StatusCode, "Bad status, body: %s", body) + return resp, body +} diff --git a/tests/integrationbase.go b/tests/integrationbase.go index dfd41f5de32..73a6fd71a82 100644 --- a/tests/integrationbase.go +++ b/tests/integrationbase.go @@ -69,6 +69,7 @@ type ( engine FrontendClient adminClient AdminClient operatorClient operatorservice.OperatorServiceClient + httpAPIAddress string Logger log.Logger namespace string foreignNamespace string @@ -96,6 +97,7 @@ func (s *IntegrationBase) setupSuite(defaultClusterConfigFile string) { s.engine = NewFrontendClient(connection) s.adminClient = NewAdminClient(connection) s.operatorClient = operatorservice.NewOperatorServiceClient(connection) + s.httpAPIAddress = TestFlags.FrontendHTTPAddr } else { s.Logger.Info("Running integration test against test cluster") cluster, err := NewCluster(clusterConfig, s.Logger) @@ -104,6 +106,7 @@ func (s *IntegrationBase) setupSuite(defaultClusterConfigFile string) { s.engine = s.testCluster.GetFrontendClient() s.adminClient = s.testCluster.GetAdminClient() s.operatorClient = s.testCluster.GetOperatorClient() + s.httpAPIAddress = cluster.host.FrontendHTTPAddress() } s.namespace = s.randomizeStr("integration-test-namespace") diff --git a/tests/onebox.go b/tests/onebox.go index 39e6cd21f04..fd039a44dac 100644 --- a/tests/onebox.go +++ b/tests/onebox.go @@ -26,9 +26,11 @@ package tests import ( "context" + "crypto/tls" "encoding/json" "fmt" "net" + "strconv" "sync" "time" @@ -57,6 +59,7 @@ import ( "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/membership" "go.temporal.io/server/common/metrics" + "go.temporal.io/server/common/metrics/metricstest" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence" persistenceClient "go.temporal.io/server/common/persistence/client" @@ -66,6 +69,7 @@ import ( "go.temporal.io/server/common/resolver" "go.temporal.io/server/common/resource" "go.temporal.io/server/common/rpc" + "go.temporal.io/server/common/rpc/encryption" "go.temporal.io/server/common/sdk" "go.temporal.io/server/common/searchattribute" "go.temporal.io/server/service/frontend" @@ -121,6 +125,12 @@ type ( mockAdminClient map[string]adminservice.AdminServiceClient namespaceReplicationTaskExecutor namespace.ReplicationTaskExecutor spanExporters []otelsdktrace.SpanExporter + tlsConfigProvider *encryption.FixedTLSConfigProvider + captureMetricsHandler *metricstest.CaptureHandler + + onGetClaims func(*authorization.AuthInfo) (*authorization.Claims, error) + onAuthorize func(context.Context, *authorization.Claims, *authorization.CallTarget) (authorization.Result, error) + callbackLock sync.RWMutex // Must be used for above callbacks } // HistoryConfig contains configs for history service @@ -158,6 +168,8 @@ type ( NamespaceReplicationTaskExecutor namespace.ReplicationTaskExecutor SpanExporters []otelsdktrace.SpanExporter DynamicConfigOverrides map[dynamicconfig.Key]interface{} + TLSConfigProvider *encryption.FixedTLSConfigProvider + CaptureMetricsHandler *metricstest.CaptureHandler } listenHostPort string @@ -190,6 +202,8 @@ func newTemporal(params *TemporalParams) *temporalImpl { mockAdminClient: params.MockAdminClient, namespaceReplicationTaskExecutor: params.NamespaceReplicationTaskExecutor, spanExporters: params.SpanExporters, + tlsConfigProvider: params.TLSConfigProvider, + captureMetricsHandler: params.CaptureMetricsHandler, dcClient: testDCClient, } impl.overrideHistoryDynamicConfig(testDCClient) @@ -277,6 +291,21 @@ func (c *temporalImpl) FrontendGRPCAddress() string { } } +func (c *temporalImpl) FrontendHTTPAddress() string { + host, port := c.FrontendHTTPHostPort() + return net.JoinHostPort(host, strconv.Itoa(port)) +} + +func (c *temporalImpl) FrontendHTTPHostPort() (string, int) { + if host, port, err := net.SplitHostPort(c.FrontendGRPCAddress()); err != nil { + panic(fmt.Errorf("Invalid gRPC frontend address: %w", err)) + } else if portNum, err := strconv.Atoi(port); err != nil { + panic(fmt.Errorf("Invalid gRPC frontend port: %w", err)) + } else { + return host, portNum + 10 + } +} + func (c *temporalImpl) HistoryServiceAddress() []string { var hosts []string var startPort int @@ -376,11 +405,12 @@ func (c *temporalImpl) startFrontend(hosts map[primitives.ServiceName][]string, persistenceConfig, serviceName, ), + fx.Provide(c.frontendConfigProvider), fx.Provide(func() listenHostPort { return listenHostPort(c.FrontendGRPCAddress()) }), fx.Provide(func() config.DCRedirectionPolicy { return config.DCRedirectionPolicy{} }), fx.Provide(func() log.ThrottledLogger { return c.logger }), fx.Provide(func() resource.NamespaceLogger { return c.logger }), - fx.Provide(newRPCFactoryImpl), + fx.Provide(c.newRPCFactory), fx.Provide(func() membership.Monitor { return newSimpleMonitor(hosts) }), @@ -391,10 +421,10 @@ func (c *temporalImpl) startFrontend(hosts map[primitives.ServiceName][]string, fx.Provide(func() carchiver.ArchivalMetadata { return c.archiverMetadata }), fx.Provide(func() provider.ArchiverProvider { return c.archiverProvider }), fx.Provide(sdkClientFactoryProvider), - fx.Provide(func() metrics.Handler { return metrics.NoopMetricsHandler }), + fx.Provide(c.GetMetricsHandler), fx.Provide(func() []grpc.UnaryServerInterceptor { return nil }), - fx.Provide(func() authorization.Authorizer { return nil }), - fx.Provide(func() authorization.ClaimMapper { return nil }), + fx.Provide(func() authorization.Authorizer { return c }), + fx.Provide(func() authorization.ClaimMapper { return c }), fx.Provide(func() authorization.JWTAudienceMapper { return nil }), fx.Provide(func() client.FactoryProvider { return client.NewFactoryProvider() }), fx.Provide(func() searchattribute.Mapper { return nil }), @@ -408,6 +438,7 @@ func (c *temporalImpl) startFrontend(hosts map[primitives.ServiceName][]string, fx.Provide(resource.DefaultSnTaggedLoggerProvider), fx.Provide(func() *esclient.Config { return c.esConfig }), fx.Provide(func() esclient.Client { return c.esClient }), + fx.Provide(c.GetTLSConfigProvider), fx.Supply(c.spanExporters), temporal.ServiceTracingModule, frontend.Module, @@ -470,11 +501,11 @@ func (c *temporalImpl) startHistory( persistenceConfig, serviceName, ), - fx.Provide(func() metrics.Handler { return metrics.NoopMetricsHandler }), + fx.Provide(c.GetMetricsHandler), fx.Provide(func() listenHostPort { return listenHostPort(grpcPort) }), fx.Provide(func() config.DCRedirectionPolicy { return config.DCRedirectionPolicy{} }), fx.Provide(func() log.ThrottledLogger { return c.logger }), - fx.Provide(newRPCFactoryImpl), + fx.Provide(c.newRPCFactory), fx.Provide(func() membership.Monitor { return newSimpleMonitor(hosts) }), @@ -497,6 +528,7 @@ func (c *temporalImpl) startHistory( fx.Provide(resource.DefaultSnTaggedLoggerProvider), fx.Provide(func() *esclient.Config { return c.esConfig }), fx.Provide(func() esclient.Client { return c.esClient }), + fx.Provide(c.GetTLSConfigProvider), fx.Provide(workflow.NewTaskGeneratorProvider), fx.Supply(c.spanExporters), temporal.ServiceTracingModule, @@ -566,10 +598,10 @@ func (c *temporalImpl) startMatching(hosts map[primitives.ServiceName][]string, persistenceConfig, serviceName, ), - fx.Provide(func() metrics.Handler { return metrics.NoopMetricsHandler }), + fx.Provide(c.GetMetricsHandler), fx.Provide(func() listenHostPort { return listenHostPort(c.MatchingGRPCServiceAddress()) }), fx.Provide(func() log.ThrottledLogger { return c.logger }), - fx.Provide(newRPCFactoryImpl), + fx.Provide(c.newRPCFactory), fx.Provide(func() membership.Monitor { return newSimpleMonitor(hosts) }), @@ -587,6 +619,7 @@ func (c *temporalImpl) startMatching(hosts map[primitives.ServiceName][]string, fx.Provide(func() dynamicconfig.Client { return c.dcClient }), fx.Provide(func() *esclient.Config { return c.esConfig }), fx.Provide(func() esclient.Client { return c.esClient }), + fx.Provide(c.GetTLSConfigProvider), fx.Provide(func() log.Logger { return c.logger }), fx.Provide(resource.DefaultSnTaggedLoggerProvider), fx.Supply(c.spanExporters), @@ -658,11 +691,11 @@ func (c *temporalImpl) startWorker(hosts map[primitives.ServiceName][]string, st persistenceConfig, serviceName, ), - fx.Provide(func() metrics.Handler { return metrics.NoopMetricsHandler }), + fx.Provide(c.GetMetricsHandler), fx.Provide(func() listenHostPort { return listenHostPort(c.WorkerGRPCServiceAddress()) }), fx.Provide(func() config.DCRedirectionPolicy { return config.DCRedirectionPolicy{} }), fx.Provide(func() log.ThrottledLogger { return c.logger }), - fx.Provide(newRPCFactoryImpl), + fx.Provide(c.newRPCFactory), fx.Provide(func() membership.Monitor { return newSimpleMonitor(hosts) }), @@ -683,6 +716,7 @@ func (c *temporalImpl) startWorker(hosts map[primitives.ServiceName][]string, st fx.Provide(resource.DefaultSnTaggedLoggerProvider), fx.Provide(func() esclient.Client { return c.esClient }), fx.Provide(func() *esclient.Config { return c.esConfig }), + fx.Provide(c.GetTLSConfigProvider), fx.Supply(c.spanExporters), temporal.ServiceTracingModule, worker.Module, @@ -718,6 +752,37 @@ func (c *temporalImpl) GetExecutionManager() persistence.ExecutionManager { return c.executionManager } +func (c *temporalImpl) GetTLSConfigProvider() encryption.TLSConfigProvider { + // If we just return this directly, the interface will be non-nil but the + // pointer will be nil + if c.tlsConfigProvider != nil { + return c.tlsConfigProvider + } + return nil +} + +func (c *temporalImpl) GetMetricsHandler() metrics.Handler { + if c.captureMetricsHandler != nil { + return c.captureMetricsHandler + } + return metrics.NoopMetricsHandler +} + +func (c *temporalImpl) frontendConfigProvider() *config.Config { + // Set HTTP port and a test HTTP forwarded header + _, httpPort := c.FrontendHTTPHostPort() + return &config.Config{ + Services: map[string]config.Service{ + string(primitives.FrontendService): { + RPC: config.RPC{ + HTTPPort: httpPort, + HTTPAdditionalForwardedHeaders: []string{"this-header-forwarded"}, + }, + }, + }, + } +} + func (c *temporalImpl) overrideHistoryDynamicConfig(client *dcClient) { client.OverrideValue(dynamicconfig.ReplicationTaskProcessorStartWait, time.Nanosecond) @@ -748,6 +813,76 @@ func (c *temporalImpl) overrideHistoryDynamicConfig(client *dcClient) { client.OverrideValue(dynamicconfig.VisibilityProcessorUpdateAckInterval, 1*time.Second) } +func (c *temporalImpl) newRPCFactory( + sn primitives.ServiceName, + grpcHostPort listenHostPort, + logger log.Logger, + grpcResolver membership.GRPCResolver, + tlsConfigProvider encryption.TLSConfigProvider, +) (common.RPCFactory, error) { + host, portStr, err := net.SplitHostPort(string(grpcHostPort)) + if err != nil { + return nil, fmt.Errorf("failed parsing host:port: %w", err) + } + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, fmt.Errorf("invalid port: %w", err) + } + var frontendTLSConfig *tls.Config + if tlsConfigProvider != nil { + if frontendTLSConfig, err = tlsConfigProvider.GetFrontendClientConfig(); err != nil { + return nil, fmt.Errorf("failed getting client TLS config: %w", err) + } + } + return rpc.NewFactory( + &config.RPC{BindOnIP: host, GRPCPort: port}, + sn, + logger, + tlsConfigProvider, + grpcResolver.MakeURL(primitives.FrontendService), + frontendTLSConfig, + nil, + ), nil +} + +func (c *temporalImpl) SetOnGetClaims(fn func(*authorization.AuthInfo) (*authorization.Claims, error)) { + c.callbackLock.Lock() + c.onGetClaims = fn + c.callbackLock.Unlock() +} + +func (c *temporalImpl) GetClaims(authInfo *authorization.AuthInfo) (*authorization.Claims, error) { + c.callbackLock.RLock() + onGetClaims := c.onGetClaims + c.callbackLock.RUnlock() + if onGetClaims != nil { + return onGetClaims(authInfo) + } + return &authorization.Claims{System: authorization.RoleAdmin}, nil +} + +func (c *temporalImpl) SetOnAuthorize( + fn func(context.Context, *authorization.Claims, *authorization.CallTarget) (authorization.Result, error), +) { + c.callbackLock.Lock() + c.onAuthorize = fn + c.callbackLock.Unlock() +} + +func (c *temporalImpl) Authorize( + ctx context.Context, + caller *authorization.Claims, + target *authorization.CallTarget, +) (authorization.Result, error) { + c.callbackLock.RLock() + onAuthorize := c.onAuthorize + c.callbackLock.RUnlock() + if onAuthorize != nil { + return onAuthorize(ctx, caller, target) + } + return authorization.Result{Decision: authorization.DecisionAllow}, nil +} + // copyPersistenceConfig makes a deepcopy of persistence config. // This is just a temp fix for the race condition of persistence config. // The race condition happens because all the services are using the same datastore map in the config. @@ -775,89 +910,24 @@ func sdkClientFactoryProvider( metricsHandler metrics.Handler, logger log.Logger, dc *dynamicconfig.Collection, + tlsConfigProvider encryption.TLSConfigProvider, ) sdk.ClientFactory { + var tlsConfig *tls.Config + if tlsConfigProvider != nil { + var err error + if tlsConfig, err = tlsConfigProvider.GetFrontendClientConfig(); err != nil { + panic(err) + } + } return sdk.NewClientFactory( resolver.MakeURL(primitives.FrontendService), - nil, + tlsConfig, metricsHandler, logger, dc.GetIntProperty(dynamicconfig.WorkerStickyCacheSize, 0), ) } -type rpcFactoryImpl struct { - serviceName primitives.ServiceName - grpcHostPort string - logger log.Logger - frontendURL string - - sync.RWMutex - listener net.Listener -} - -func (c *rpcFactoryImpl) GetFrontendGRPCServerOptions() ([]grpc.ServerOption, error) { - return nil, nil -} - -func (c *rpcFactoryImpl) GetInternodeGRPCServerOptions() ([]grpc.ServerOption, error) { - return nil, nil -} - -func (c *rpcFactoryImpl) CreateRemoteFrontendGRPCConnection(hostName string) *grpc.ClientConn { - return c.CreateGRPCConnection(hostName) -} - -func (c *rpcFactoryImpl) CreateLocalFrontendGRPCConnection() *grpc.ClientConn { - return c.CreateGRPCConnection(c.frontendURL) -} - -func (c *rpcFactoryImpl) CreateInternodeGRPCConnection(hostName string) *grpc.ClientConn { - return c.CreateGRPCConnection(hostName) -} - -func newRPCFactoryImpl(sn primitives.ServiceName, grpcHostPort listenHostPort, logger log.Logger, resolver membership.GRPCResolver) common.RPCFactory { - return &rpcFactoryImpl{ - serviceName: sn, - grpcHostPort: string(grpcHostPort), - logger: logger, - frontendURL: resolver.MakeURL(primitives.FrontendService), - } -} - -func (c *rpcFactoryImpl) GetGRPCListener() net.Listener { - c.RLock() - if c.listener != nil { - c.RUnlock() - return c.listener - } - c.RUnlock() - - c.Lock() - defer c.Unlock() - - if c.listener == nil { - var err error - c.listener, err = net.Listen("tcp", c.grpcHostPort) - if err != nil { - c.logger.Fatal("Failed create gRPC listener", tag.Error(err), tag.Service(c.serviceName), tag.Address(c.grpcHostPort)) - } - - c.logger.Info("Created gRPC listener", tag.Service(c.serviceName), tag.Address(c.grpcHostPort)) - } - - return c.listener -} - -// CreateGRPCConnection creates connection for gRPC calls -func (c *rpcFactoryImpl) CreateGRPCConnection(hostName string) *grpc.ClientConn { - connection, err := rpc.Dial(hostName, nil, c.logger) - if err != nil { - c.logger.Fatal("Failed to create gRPC connection", tag.Error(err)) - } - - return connection -} - func newSimpleHostInfoProvider(serviceName primitives.ServiceName, hosts map[primitives.ServiceName][]string) membership.HostInfoProvider { hostInfo := membership.NewHostInfoFromAddress(hosts[serviceName][0]) return membership.NewHostInfoProvider(hostInfo) diff --git a/tests/test_cluster.go b/tests/test_cluster.go index baa8fcf4e35..20e88515647 100644 --- a/tests/test_cluster.go +++ b/tests/test_cluster.go @@ -26,6 +26,9 @@ package tests import ( "context" + "crypto/tls" + "crypto/x509" + "errors" "fmt" "os" "path" @@ -45,6 +48,7 @@ import ( "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/common/metrics/metricstest" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence" persistencetests "go.temporal.io/server/common/persistence/persistence-tests" @@ -53,6 +57,7 @@ import ( "go.temporal.io/server/common/persistence/sql/sqlplugin/sqlite" esclient "go.temporal.io/server/common/persistence/visibility/store/elasticsearch/client" "go.temporal.io/server/common/pprof" + "go.temporal.io/server/common/rpc/encryption" "go.temporal.io/server/common/searchattribute" "go.temporal.io/server/tests/testutils" ) @@ -89,6 +94,8 @@ type ( MockAdminClient map[string]adminservice.AdminServiceClient FaultInjection config.FaultInjection `yaml:"faultinjection"` DynamicConfigOverrides map[dynamicconfig.Key]interface{} + GenerateMTLS bool + EnableMetricsCapture bool } // WorkerConfig is the config for enabling/disabling Temporal worker @@ -100,13 +107,13 @@ type ( ) const ( - defaultPageSize = 5 - pprofTestPort = 7000 + defaultPageSize = 5 + pprofTestPort = 7000 + tlsCertCommonName = "my-common-name" ) // NewCluster creates and sets up the test cluster func NewCluster(options *TestClusterConfig, logger log.Logger) (*TestCluster, error) { - clusterMetadataConfig := cluster.NewTestClusterMetadataConfig( options.ClusterMetadata.EnableGlobalNamespace, options.IsMasterCluster, @@ -224,6 +231,13 @@ func NewCluster(options *TestClusterConfig, logger log.Logger) (*TestCluster, er return nil, err } + var tlsConfigProvider *encryption.FixedTLSConfigProvider + if options.GenerateMTLS { + if tlsConfigProvider, err = createFixedTLSConfigProvider(); err != nil { + return nil, err + } + } + temporalParams := &TemporalParams{ ClusterMetadataConfig: clusterMetadataConfig, PersistenceConfig: pConfig, @@ -244,6 +258,11 @@ func NewCluster(options *TestClusterConfig, logger log.Logger) (*TestCluster, er MockAdminClient: options.MockAdminClient, NamespaceReplicationTaskExecutor: namespace.NewReplicationTaskExecutor(options.ClusterMetadata.CurrentClusterName, testBase.MetadataManager, logger), DynamicConfigOverrides: options.DynamicConfigOverrides, + TLSConfigProvider: tlsConfigProvider, + } + + if options.EnableMetricsCapture { + temporalParams.CaptureMetricsHandler = metricstest.NewCaptureHandler() } err = newPProfInitializerImpl(logger, pprofTestPort).Start() @@ -448,3 +467,52 @@ func (tc *TestCluster) GetExecutionManager() persistence.ExecutionManager { func (tc *TestCluster) GetHost() *temporalImpl { return tc.host } + +var errCannotAddCACertToPool = errors.New("failed adding CA to pool") + +func createFixedTLSConfigProvider() (*encryption.FixedTLSConfigProvider, error) { + // We use the existing cert generation utilities even though they use slow + // RSA and use disk unnecessarily + tempDir, err := os.MkdirTemp("", "") + if err != nil { + return nil, err + } + defer os.RemoveAll(tempDir) + + certChain, err := testutils.GenerateTestChain(tempDir, tlsCertCommonName) + if err != nil { + return nil, err + } + + // Due to how mTLS is built in the server, we have to reuse the CA for server + // and client, therefore we might as well reuse the cert too + + tlsCert, err := tls.LoadX509KeyPair(certChain.CertPubFile, certChain.CertKeyFile) + if err != nil { + return nil, err + } + caCertPool := x509.NewCertPool() + if caCertBytes, err := os.ReadFile(certChain.CaPubFile); err != nil { + return nil, err + } else if !caCertPool.AppendCertsFromPEM(caCertBytes) { + return nil, errCannotAddCACertToPool + } + + serverTLSConfig := &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + ClientCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + } + clientTLSConfig := &tls.Config{ + ServerName: tlsCertCommonName, + Certificates: []tls.Certificate{tlsCert}, + RootCAs: caCertPool, + } + + return &encryption.FixedTLSConfigProvider{ + InternodeServerConfig: serverTLSConfig, + InternodeClientConfig: clientTLSConfig, + FrontendServerConfig: serverTLSConfig, + FrontendClientConfig: clientTLSConfig, + }, nil +} diff --git a/tests/testdata/clientintegrationtestcluster.yaml b/tests/testdata/clientintegrationtestcluster.yaml index 37e33cd9829..f5015220892 100644 --- a/tests/testdata/clientintegrationtestcluster.yaml +++ b/tests/testdata/clientintegrationtestcluster.yaml @@ -1,4 +1,5 @@ enablearchival: false +enablemetricscapture: true clusterno: 0 historyconfig: numhistoryshards: 4 diff --git a/tests/testdata/tls_integration_test_cluster.yaml b/tests/testdata/tls_integration_test_cluster.yaml new file mode 100644 index 00000000000..0f5dd16b1d8 --- /dev/null +++ b/tests/testdata/tls_integration_test_cluster.yaml @@ -0,0 +1,8 @@ +historyconfig: + numhistoryshards: 4 + numhistoryhosts: 1 +workerconfig: + enablearchiver: false + enablereplicator: false + startworkeranyway: true +generatemtls: true \ No newline at end of file diff --git a/tests/testutils/certificate.go b/tests/testutils/certificate.go index 0189c1bebf4..1208959721b 100644 --- a/tests/testutils/certificate.go +++ b/tests/testutils/certificate.go @@ -63,11 +63,12 @@ func generateSelfSignedX509CA(commonName string, extUsage []x509.ExtKeyUsage, ke if ip.IsLoopback() { template.DNSNames = []string{"localhost"} } + } else { + template.DNSNames = []string{commonName} } if strings.ToLower(commonName) == "localhost" { template.IPAddresses = []net.IP{net.IPv6loopback, net.IPv4(127, 0, 0, 1)} - template.DNSNames = []string{"localhost"} } privateKey, err := rsa.GenerateKey(rand.Reader, keyLengthBits) @@ -116,11 +117,12 @@ func generateServerX509UsingCAAndSerialNumber(commonName string, serialNumber in if ip.IsLoopback() { template.DNSNames = []string{"localhost"} } + } else { + template.DNSNames = []string{commonName} } if strings.ToLower(commonName) == "localhost" { template.IPAddresses = []net.IP{net.IPv6loopback, net.IPv4(127, 0, 0, 1)} - template.DNSNames = []string{"localhost"} } privateKey, err := rsa.GenerateKey(rand.Reader, 4096) diff --git a/tests/tls_test.go b/tests/tls_test.go new file mode 100644 index 00000000000..da638ea7e32 --- /dev/null +++ b/tests/tls_test.go @@ -0,0 +1,152 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package tests + +import ( + "context" + "flag" + "net/http" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/suite" + "go.temporal.io/api/workflowservice/v1" + sdkclient "go.temporal.io/sdk/client" + "go.temporal.io/server/common/authorization" + "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/common/rpc" +) + +type tlsIntegrationSuite struct { + IntegrationBase + hostPort string + sdkClient sdkclient.Client +} + +func TestTLSIntegrationSuite(t *testing.T) { + flag.Parse() + suite.Run(t, new(tlsIntegrationSuite)) +} + +func (s *tlsIntegrationSuite) SetupSuite() { + s.setupSuite("testdata/tls_integration_test_cluster.yaml") + s.hostPort = "127.0.0.1:7134" + if TestFlags.FrontendAddr != "" { + s.hostPort = TestFlags.FrontendAddr + } +} + +func (s *tlsIntegrationSuite) TearDownSuite() { + s.tearDownSuite() +} + +func (s *tlsIntegrationSuite) SetupTest() { + var err error + s.sdkClient, err = sdkclient.Dial(sdkclient.Options{ + HostPort: s.hostPort, + Namespace: s.namespace, + ConnectionOptions: sdkclient.ConnectionOptions{ + TLS: s.testCluster.host.tlsConfigProvider.FrontendClientConfig, + }, + }) + if err != nil { + s.Logger.Fatal("Error when creating SDK client", tag.Error(err)) + } +} + +func (s *tlsIntegrationSuite) TearDownTest() { + s.sdkClient.Close() +} + +func (s *tlsIntegrationSuite) TestGRPCMTLS() { + ctx, cancel := rpc.NewContextWithTimeoutAndVersionHeaders(time.Minute) + defer cancel() + + // Track auth info + calls := s.trackAuthInfoByCall() + + // Make a list-open call + _, _ = s.sdkClient.ListOpenWorkflow(ctx, &workflowservice.ListOpenWorkflowExecutionsRequest{}) + + // Confirm auth info as expected + authInfo, ok := calls.Load("/temporal.api.workflowservice.v1.WorkflowService/ListOpenWorkflowExecutions") + s.Require().True(ok) + s.Require().Equal(tlsCertCommonName, authInfo.(*authorization.AuthInfo).TLSSubject.CommonName) +} + +func (s *tlsIntegrationSuite) TestHTTPMTLS() { + if s.httpAPIAddress == "" { + s.T().Skip("HTTP API server not enabled") + } + // Track auth info + calls := s.trackAuthInfoByCall() + + // Confirm non-HTTPS call is rejected with 400 + resp, err := http.Get("http://" + s.httpAPIAddress + "/api/v1/namespaces/" + s.namespace + "/workflows") + s.Require().NoError(err) + s.Require().Equal(http.StatusBadRequest, resp.StatusCode) + + // Create HTTP client with TLS config + httpClient := http.Client{ + Transport: &http.Transport{ + TLSClientConfig: s.testCluster.host.tlsConfigProvider.FrontendClientConfig, + }, + } + + // Make a list call + req, err := http.NewRequest("GET", "https://"+s.httpAPIAddress+"/api/v1/namespaces/"+s.namespace+"/workflows", nil) + s.Require().NoError(err) + resp, err = httpClient.Do(req) + s.Require().NoError(err) + s.Require().Equal(http.StatusOK, resp.StatusCode) + + // Confirm auth info as expected + authInfo, ok := calls.Load("/temporal.api.workflowservice.v1.WorkflowService/ListWorkflowExecutions") + s.Require().True(ok) + s.Require().Equal(tlsCertCommonName, authInfo.(*authorization.AuthInfo).TLSSubject.CommonName) +} + +func (s *tlsIntegrationSuite) trackAuthInfoByCall() *sync.Map { + var calls sync.Map + // Put auth info on claim, then use authorizer to set on the map by call + s.testCluster.host.SetOnGetClaims(func(authInfo *authorization.AuthInfo) (*authorization.Claims, error) { + return &authorization.Claims{ + System: authorization.RoleAdmin, + Extensions: authInfo, + }, nil + }) + s.testCluster.host.SetOnAuthorize(func( + ctx context.Context, + caller *authorization.Claims, + target *authorization.CallTarget, + ) (authorization.Result, error) { + if authInfo, _ := caller.Extensions.(*authorization.AuthInfo); authInfo != nil { + calls.Store(target.APIName, authInfo) + } + return authorization.Result{Decision: authorization.DecisionAllow}, nil + }) + return &calls +}