Skip to content

Commit

Permalink
fix(websocket): authorization was not working for clients
Browse files Browse the repository at this point in the history
Signed-off-by: Rodney Osodo <socials@rodneyosodo.com>
  • Loading branch information
rodneyosodo committed Feb 14, 2025
1 parent d12c628 commit ddc40fd
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 83 deletions.
4 changes: 4 additions & 0 deletions ws/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package ws
import (
"context"
"fmt"
"strings"

grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
Expand Down Expand Up @@ -93,6 +94,9 @@ func (svc *adapterService) authorize(ctx context.Context, clientKey, chanID stri
authnReq := &grpcClientsV1.AuthnReq{
ClientSecret: clientKey,
}
if strings.HasPrefix(clientKey, "Client") {
authnReq.ClientSecret = extractClientSecret(clientKey)
}
authnRes, err := svc.clients.Authenticate(ctx, authnReq)
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
Expand Down
115 changes: 32 additions & 83 deletions ws/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt
token = string(s.Password)
}

return h.authAccess(ctx, token, *topic, connections.Publish)
_, _, err := h.authAccess(ctx, token, *topic, connections.Publish)

return err
}

// AuthSubscribe is called on device publish,
Expand All @@ -111,16 +113,8 @@ func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error {
return errMissingTopicSub
}

var token string
switch {
case strings.HasPrefix(string(s.Password), "Client"):
token = strings.ReplaceAll(string(s.Password), "Client ", "")
default:
token = string(s.Password)
}

for _, topic := range *topics {
if err := h.authAccess(ctx, token, topic, connections.Subscribe); err != nil {
if _, _, err := h.authAccess(ctx, string(s.Password), topic, connections.Subscribe); err != nil {
return err
}
}
Expand All @@ -139,7 +133,6 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
if !ok {
return errors.Wrap(errFailedPublish, errClientNotInitialized)
}
h.logger.Info(fmt.Sprintf(LogInfoPublished, s.ID, *topic))

if len(*payload) == 0 {
return errFailedMessagePublish
Expand All @@ -160,41 +153,9 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
return errors.Wrap(errFailedParseSubtopic, err)
}

var clientID, clientType string
switch {
case strings.HasPrefix(string(s.Password), "Client"):
clientKey := extractClientSecret(string(s.Password))
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ClientSecret: clientKey})
if err != nil {
return errors.Wrap(svcerr.ErrAuthentication, err)
}
if !authnRes.Authenticated {
return svcerr.ErrAuthentication
}
clientType = policies.ClientType
clientID = authnRes.GetId()
default:
token := string(s.Password)
authnSession, err := h.authn.Authenticate(ctx, extractBearerToken(token))
if err != nil {
return err
}
clientType = policies.UserType
clientID = authnSession.DomainUserID
}

ar := &grpcChannelsV1.AuthzReq{
Type: uint32(connections.Publish),
ClientId: clientID,
ClientType: clientType,
ChannelId: chanID,
}
res, err := h.channels.Authorize(ctx, ar)
clientID, clientType, err := h.authAccess(ctx, string(s.Password), *topic, connections.Publish)
if err != nil {
return err
}
if !res.GetAuthorized() {
return svcerr.ErrAuthorization
return errors.Wrap(errFailedPublish, err)
}

msg := messaging.Message{
Expand All @@ -213,6 +174,8 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
return errors.Wrap(errFailedPublishToMsgBroker, err)
}

h.logger.Info(fmt.Sprintf(LogInfoPublished, s.ID, *topic))

return nil
}

Expand Down Expand Up @@ -242,38 +205,33 @@ func (h *handler) Disconnect(ctx context.Context) error {
return nil
}

func (h *handler) authAccess(ctx context.Context, token, topic string, msgType connections.ConnType) error {
var clientID, clientType string
switch {
case strings.HasPrefix(token, "Client"):
clientKey := extractClientSecret(token)
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ClientSecret: clientKey})
if err != nil {
return errors.Wrap(svcerr.ErrAuthentication, err)
}
if !authnRes.Authenticated {
return svcerr.ErrAuthentication
}
clientType = policies.ClientType
clientID = authnRes.GetId()
default:
authnSession, err := h.authn.Authenticate(ctx, extractBearerToken(token))
if err != nil {
return err
}
clientType = policies.UserType
clientID = authnSession.DomainUserID
func (h *handler) authAccess(ctx context.Context, token, topic string, msgType connections.ConnType) (string, string, error) {
authnReq := &grpcClientsV1.AuthnReq{
ClientSecret: token,
}
if strings.HasPrefix(token, "Client") {
authnReq.ClientSecret = extractClientSecret(token)
}

authnRes, err := h.clients.Authenticate(ctx, authnReq)
if err != nil {
return "", "", errors.Wrap(svcerr.ErrAuthentication, err)
}
if !authnRes.GetAuthenticated() {
return "", "", svcerr.ErrAuthentication
}
clientType := policies.ClientType
clientID := authnRes.GetId()

// Topics are in the format:
// channels/<channel_id>/messages/<subtopic>/.../ct/<content_type>
if !channelRegExp.MatchString(topic) {
return errMalformedTopic
return "", "", errMalformedTopic
}

channelParts := channelRegExp.FindStringSubmatch(topic)
if len(channelParts) < 1 {
return errMalformedTopic
return "", "", errMalformedTopic
}

chanID := channelParts[1]
Expand All @@ -286,13 +244,13 @@ func (h *handler) authAccess(ctx context.Context, token, topic string, msgType c
}
res, err := h.channels.Authorize(ctx, ar)
if err != nil {
return errors.Wrap(svcerr.ErrAuthorization, err)
return "", "", errors.Wrap(svcerr.ErrAuthorization, err)
}
if !res.GetAuthorized() {
return errors.Wrap(svcerr.ErrAuthorization, err)
return "", "", errors.Wrap(svcerr.ErrAuthorization, err)
}

return nil
return clientID, clientType, nil
}

func parseSubtopic(subtopic string) (string, error) {
Expand Down Expand Up @@ -325,19 +283,10 @@ func parseSubtopic(subtopic string) (string, error) {
}

// extractClientSecret returns value of the client secret. If there is no client key - an empty value is returned.
func extractClientSecret(topic string) string {
if !strings.HasPrefix(topic, apiutil.ClientPrefix) {
return ""
}

return strings.TrimPrefix(topic, apiutil.ClientPrefix)
}

// extractBearerToken returns value of the bearer token. If there is no bearer token - an empty value is returned.
func extractBearerToken(token string) string {
if !strings.HasPrefix(token, apiutil.BearerPrefix) {
func extractClientSecret(token string) string {
if !strings.HasPrefix(token, apiutil.ClientPrefix) {
return ""
}

return strings.TrimPrefix(token, apiutil.BearerPrefix)
return strings.TrimPrefix(token, apiutil.ClientPrefix)
}

0 comments on commit ddc40fd

Please # to comment.