diff --git a/internal/auth/sasl.go b/internal/auth/sasl.go index 06417ce0..2103ab7e 100644 --- a/internal/auth/sasl.go +++ b/internal/auth/sasl.go @@ -81,6 +81,9 @@ func (s *SASLAuth) CreateSASL(mech string, remoteAddr net.Addr, successCb func(i if identity == "" { identity = username } + if identity != username { + return ErrInvalidAuthCred + } err := s.AuthPlain(username, password) if err != nil { diff --git a/internal/auth/sasl_test.go b/internal/auth/sasl_test.go index 625f0bb1..fcc193aa 100644 --- a/internal/auth/sasl_test.go +++ b/internal/auth/sasl_test.go @@ -75,13 +75,13 @@ func TestCreateSASL(t *testing.T) { t.Run("PLAIN with authorization identity", func(t *testing.T) { srv := a.CreateSASL("PLAIN", &net.TCPAddr{}, func(id string) error { - if id != "user1a" { + if id != "user1" { t.Fatal("Wrong authorization identity passed:", id) } return nil }) - _, _, err := srv.Next([]byte("user1a\x00user1\x00aa")) + _, _, err := srv.Next([]byte("user1\x00user1\x00aa")) if err != nil { t.Error("Unexpected error:", err) }