From ddc40fd7c1315e289d4f7a7e1ad8167f645c06f0 Mon Sep 17 00:00:00 2001 From: Rodney Osodo Date: Tue, 11 Feb 2025 15:04:16 +0300 Subject: [PATCH] fix(websocket): authorization was not working for clients Signed-off-by: Rodney Osodo --- ws/adapter.go | 4 ++ ws/handler.go | 115 ++++++++++++++------------------------------------ 2 files changed, 36 insertions(+), 83 deletions(-) diff --git a/ws/adapter.go b/ws/adapter.go index 02c4cfe39e..0517cbbe83 100644 --- a/ws/adapter.go +++ b/ws/adapter.go @@ -6,6 +6,7 @@ package ws import ( "context" "fmt" + "strings" grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" @@ -93,6 +94,9 @@ func (svc *adapterService) authorize(ctx context.Context, clientKey, chanID stri authnReq := &grpcClientsV1.AuthnReq{ ClientSecret: clientKey, } + if strings.HasPrefix(clientKey, "Client") { + authnReq.ClientSecret = extractClientSecret(clientKey) + } authnRes, err := svc.clients.Authenticate(ctx, authnReq) if err != nil { return "", errors.Wrap(svcerr.ErrAuthentication, err) diff --git a/ws/handler.go b/ws/handler.go index 238011b28d..422e6fec4f 100644 --- a/ws/handler.go +++ b/ws/handler.go @@ -97,7 +97,9 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt token = string(s.Password) } - return h.authAccess(ctx, token, *topic, connections.Publish) + _, _, err := h.authAccess(ctx, token, *topic, connections.Publish) + + return err } // AuthSubscribe is called on device publish, @@ -111,16 +113,8 @@ func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error { return errMissingTopicSub } - var token string - switch { - case strings.HasPrefix(string(s.Password), "Client"): - token = strings.ReplaceAll(string(s.Password), "Client ", "") - default: - token = string(s.Password) - } - for _, topic := range *topics { - if err := h.authAccess(ctx, token, topic, connections.Subscribe); err != nil { + if _, _, err := h.authAccess(ctx, string(s.Password), topic, connections.Subscribe); err != nil { return err } } @@ -139,7 +133,6 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e if !ok { return errors.Wrap(errFailedPublish, errClientNotInitialized) } - h.logger.Info(fmt.Sprintf(LogInfoPublished, s.ID, *topic)) if len(*payload) == 0 { return errFailedMessagePublish @@ -160,41 +153,9 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e return errors.Wrap(errFailedParseSubtopic, err) } - var clientID, clientType string - switch { - case strings.HasPrefix(string(s.Password), "Client"): - clientKey := extractClientSecret(string(s.Password)) - authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ClientSecret: clientKey}) - if err != nil { - return errors.Wrap(svcerr.ErrAuthentication, err) - } - if !authnRes.Authenticated { - return svcerr.ErrAuthentication - } - clientType = policies.ClientType - clientID = authnRes.GetId() - default: - token := string(s.Password) - authnSession, err := h.authn.Authenticate(ctx, extractBearerToken(token)) - if err != nil { - return err - } - clientType = policies.UserType - clientID = authnSession.DomainUserID - } - - ar := &grpcChannelsV1.AuthzReq{ - Type: uint32(connections.Publish), - ClientId: clientID, - ClientType: clientType, - ChannelId: chanID, - } - res, err := h.channels.Authorize(ctx, ar) + clientID, clientType, err := h.authAccess(ctx, string(s.Password), *topic, connections.Publish) if err != nil { - return err - } - if !res.GetAuthorized() { - return svcerr.ErrAuthorization + return errors.Wrap(errFailedPublish, err) } msg := messaging.Message{ @@ -213,6 +174,8 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e return errors.Wrap(errFailedPublishToMsgBroker, err) } + h.logger.Info(fmt.Sprintf(LogInfoPublished, s.ID, *topic)) + return nil } @@ -242,38 +205,33 @@ func (h *handler) Disconnect(ctx context.Context) error { return nil } -func (h *handler) authAccess(ctx context.Context, token, topic string, msgType connections.ConnType) error { - var clientID, clientType string - switch { - case strings.HasPrefix(token, "Client"): - clientKey := extractClientSecret(token) - authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ClientSecret: clientKey}) - if err != nil { - return errors.Wrap(svcerr.ErrAuthentication, err) - } - if !authnRes.Authenticated { - return svcerr.ErrAuthentication - } - clientType = policies.ClientType - clientID = authnRes.GetId() - default: - authnSession, err := h.authn.Authenticate(ctx, extractBearerToken(token)) - if err != nil { - return err - } - clientType = policies.UserType - clientID = authnSession.DomainUserID +func (h *handler) authAccess(ctx context.Context, token, topic string, msgType connections.ConnType) (string, string, error) { + authnReq := &grpcClientsV1.AuthnReq{ + ClientSecret: token, } + if strings.HasPrefix(token, "Client") { + authnReq.ClientSecret = extractClientSecret(token) + } + + authnRes, err := h.clients.Authenticate(ctx, authnReq) + if err != nil { + return "", "", errors.Wrap(svcerr.ErrAuthentication, err) + } + if !authnRes.GetAuthenticated() { + return "", "", svcerr.ErrAuthentication + } + clientType := policies.ClientType + clientID := authnRes.GetId() // Topics are in the format: // channels//messages//.../ct/ if !channelRegExp.MatchString(topic) { - return errMalformedTopic + return "", "", errMalformedTopic } channelParts := channelRegExp.FindStringSubmatch(topic) if len(channelParts) < 1 { - return errMalformedTopic + return "", "", errMalformedTopic } chanID := channelParts[1] @@ -286,13 +244,13 @@ func (h *handler) authAccess(ctx context.Context, token, topic string, msgType c } res, err := h.channels.Authorize(ctx, ar) if err != nil { - return errors.Wrap(svcerr.ErrAuthorization, err) + return "", "", errors.Wrap(svcerr.ErrAuthorization, err) } if !res.GetAuthorized() { - return errors.Wrap(svcerr.ErrAuthorization, err) + return "", "", errors.Wrap(svcerr.ErrAuthorization, err) } - return nil + return clientID, clientType, nil } func parseSubtopic(subtopic string) (string, error) { @@ -325,19 +283,10 @@ func parseSubtopic(subtopic string) (string, error) { } // extractClientSecret returns value of the client secret. If there is no client key - an empty value is returned. -func extractClientSecret(topic string) string { - if !strings.HasPrefix(topic, apiutil.ClientPrefix) { - return "" - } - - return strings.TrimPrefix(topic, apiutil.ClientPrefix) -} - -// extractBearerToken returns value of the bearer token. If there is no bearer token - an empty value is returned. -func extractBearerToken(token string) string { - if !strings.HasPrefix(token, apiutil.BearerPrefix) { +func extractClientSecret(token string) string { + if !strings.HasPrefix(token, apiutil.ClientPrefix) { return "" } - return strings.TrimPrefix(token, apiutil.BearerPrefix) + return strings.TrimPrefix(token, apiutil.ClientPrefix) }