From 5c381cfbf3e4b7ce040ed8511a1fae1a78a0014b Mon Sep 17 00:00:00 2001 From: Eleftheria Stein-Kousathana Date: Tue, 2 Apr 2024 00:28:12 +0200 Subject: [PATCH] Make provider an optional filter (#2871) Ref #2522 --- cmd/cli/app/artifact/artifact.go | 3 +- cmd/cli/app/repo/repo.go | 3 +- cmd/server/app/webhook_update.go | 5 +- database/query/repositories.sql | 10 +- internal/controlplane/common.go | 15 +- internal/controlplane/handlers_artifacts.go | 17 +- internal/controlplane/handlers_oauth.go | 25 ++- internal/controlplane/handlers_profile.go | 4 +- .../controlplane/handlers_repositories.go | 162 +++++++++++------ .../handlers_repositories_test.go | 61 ++++--- internal/db/repositories.sql.go | 38 ++-- internal/db/repositories_test.go | 10 +- internal/providers/github/common.go | 13 ++ internal/reconcilers/run_profile.go | 5 +- internal/repositories/github/mock/service.go | 44 +++-- internal/repositories/github/service.go | 73 ++++---- internal/repositories/github/service_test.go | 166 +++++++++++------- 17 files changed, 409 insertions(+), 245 deletions(-) diff --git a/cmd/cli/app/artifact/artifact.go b/cmd/cli/app/artifact/artifact.go index 521c0a414b..44034af403 100644 --- a/cmd/cli/app/artifact/artifact.go +++ b/cmd/cli/app/artifact/artifact.go @@ -20,7 +20,6 @@ import ( "github.com/spf13/cobra" "github.com/stacklok/minder/cmd/cli/app" - ghclient "github.com/stacklok/minder/internal/providers/github/oauth" ) // ArtifactCmd is the artifact subcommand @@ -36,6 +35,6 @@ var ArtifactCmd = &cobra.Command{ func init() { app.RootCmd.AddCommand(ArtifactCmd) // Flags for all subcommands - ArtifactCmd.PersistentFlags().StringP("provider", "p", ghclient.Github, "Name of the provider, i.e. github") + ArtifactCmd.PersistentFlags().StringP("provider", "p", "", "Name of the provider, i.e. github") ArtifactCmd.PersistentFlags().StringP("project", "j", "", "ID of the project") } diff --git a/cmd/cli/app/repo/repo.go b/cmd/cli/app/repo/repo.go index 74b0d4c334..1a87398c9f 100644 --- a/cmd/cli/app/repo/repo.go +++ b/cmd/cli/app/repo/repo.go @@ -20,7 +20,6 @@ import ( "github.com/spf13/cobra" "github.com/stacklok/minder/cmd/cli/app" - ghclient "github.com/stacklok/minder/internal/providers/github/oauth" ) // RepoCmd is the root command for the repo subcommands @@ -36,6 +35,6 @@ var RepoCmd = &cobra.Command{ func init() { app.RootCmd.AddCommand(RepoCmd) // Flags for all subcommands - RepoCmd.PersistentFlags().StringP("provider", "p", ghclient.Github, "Name of the provider, i.e. github") + RepoCmd.PersistentFlags().StringP("provider", "p", "", "Name of the provider, i.e. github") RepoCmd.PersistentFlags().StringP("project", "j", "", "ID of the project") } diff --git a/cmd/server/app/webhook_update.go b/cmd/server/app/webhook_update.go index 5c249fa832..66ad38ccf3 100644 --- a/cmd/server/app/webhook_update.go +++ b/cmd/server/app/webhook_update.go @@ -141,7 +141,10 @@ func updateGithubWebhooks( ) error { repos, err := store.ListRegisteredRepositoriesByProjectIDAndProvider(ctx, db.ListRegisteredRepositoriesByProjectIDAndProviderParams{ - Provider: provider.Name, + Provider: sql.NullString{ + String: provider.Name, + Valid: true, + }, ProjectID: provider.ProjectID, }) if err != nil { diff --git a/database/query/repositories.sql b/database/query/repositories.sql index e47655ff7f..91f28dc1a9 100644 --- a/database/query/repositories.sql +++ b/database/query/repositories.sql @@ -20,7 +20,9 @@ INSERT INTO repositories ( SELECT * FROM repositories WHERE repo_id = $1; -- name: GetRepositoryByRepoName :one -SELECT * FROM repositories WHERE provider = $1 AND repo_owner = $2 AND repo_name = $3 AND project_id = $4; +SELECT * FROM repositories + WHERE repo_owner = $1 AND repo_name = $2 AND project_id = $3 + AND lower(provider) = lower(sqlc.narg('provider')::text) OR sqlc.narg('provider')::text IS NULL; -- avoid using this, where possible use GetRepositoryByIDAndProject instead -- name: GetRepositoryByID :one @@ -31,14 +33,16 @@ SELECT * FROM repositories WHERE id = $1 AND project_id = $2; -- name: ListRepositoriesByProjectID :many SELECT * FROM repositories -WHERE provider = $1 AND project_id = $2 +WHERE project_id = $1 AND (repo_id >= sqlc.narg('repo_id') OR sqlc.narg('repo_id') IS NULL) + AND lower(provider) = lower(COALESCE(sqlc.narg('provider'), provider)::text) ORDER BY project_id, provider, repo_id LIMIT sqlc.narg('limit')::bigint; -- name: ListRegisteredRepositoriesByProjectIDAndProvider :many SELECT * FROM repositories -WHERE provider = $1 AND project_id = $2 AND webhook_id IS NOT NULL +WHERE project_id = $1 AND webhook_id IS NOT NULL + AND lower(provider) = lower(sqlc.narg('provider')::text) OR sqlc.narg('provider')::text IS NULL ORDER BY repo_name; -- name: DeleteRepository :exec diff --git a/internal/controlplane/common.go b/internal/controlplane/common.go index cf4a5b354f..e5e2a450cf 100644 --- a/internal/controlplane/common.go +++ b/internal/controlplane/common.go @@ -27,10 +27,13 @@ import ( "google.golang.org/grpc/status" "github.com/stacklok/minder/internal/db" + "github.com/stacklok/minder/internal/providers/github/oauth" "github.com/stacklok/minder/internal/util" pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1" ) +const defaultProvider = oauth.Github + // HasProtoContext is an interface that can be implemented by a request type HasProtoContext interface { GetContext() *pb.Context @@ -60,10 +63,10 @@ func filteredResultNotFoundError(name sql.NullString, trait db.NullProviderType) func getProviderFromRequestOrDefault( ctx context.Context, store db.Store, - in HasProtoContext, + providerName string, projectId uuid.UUID, ) (db.Provider, error) { - name := getNameFilterParam(in.GetContext()) + name := getNameFilterParam(providerName) providers, err := findProvider(ctx, name, db.NullProviderType{}, projectId, store) if err != nil { return db.Provider{}, err @@ -79,7 +82,7 @@ func getProvidersByTrait( projectId uuid.UUID, trait db.ProviderType, ) ([]db.Provider, error) { - name := getNameFilterParam(in.GetContext()) + name := getNameFilterParam(in.GetContext().GetProvider()) t := db.NullProviderType{ProviderType: trait, Valid: true} providers, err := findProvider(ctx, name, t, projectId, store) if err != nil { @@ -120,10 +123,10 @@ func findProvider( } // getNameFilterParam allows us to build a name filter for our provider queries -func getNameFilterParam(in *pb.Context) sql.NullString { +func getNameFilterParam(providerName string) sql.NullString { return sql.NullString{ - String: in.GetProvider(), - Valid: in.GetProvider() != "", + String: providerName, + Valid: providerName != "", } } diff --git a/internal/controlplane/handlers_artifacts.go b/internal/controlplane/handlers_artifacts.go index 41301a6b6d..edf349544c 100644 --- a/internal/controlplane/handlers_artifacts.go +++ b/internal/controlplane/handlers_artifacts.go @@ -39,24 +39,20 @@ import ( func (s *Server) ListArtifacts(ctx context.Context, in *pb.ListArtifactsRequest) (*pb.ListArtifactsResponse, error) { entityCtx := engine.EntityFromContext(ctx) projectID := entityCtx.Project.ID - - provider, err := getProviderFromRequestOrDefault(ctx, s.store, in, projectID) - if err != nil { - return nil, providerError(err) - } + providerName := entityCtx.Provider.Name artifactFilter, err := parseArtifactListFrom(s.store, in.From) if err != nil { return nil, fmt.Errorf("failed to parse artifact list from: %w", err) } - results, err := artifactFilter.listArtifacts(ctx, provider.Name, projectID) + results, err := artifactFilter.listArtifacts(ctx, providerName, projectID) if err != nil { return nil, fmt.Errorf("failed to list artifacts: %w", err) } // Telemetry logging - logger.BusinessRecord(ctx).Provider = provider.Name + logger.BusinessRecord(ctx).Provider = providerName logger.BusinessRecord(ctx).Project = projectID return &pb.ListArtifactsResponse{Results: results}, nil @@ -73,9 +69,10 @@ func (s *Server) GetArtifactByName(ctx context.Context, in *pb.GetArtifactByName entityCtx := engine.EntityFromContext(ctx) projectID := entityCtx.Project.ID + providerFilter := getNameFilterParam(entityCtx.Provider.Name) repo, err := s.store.GetRepositoryByRepoName(ctx, db.GetRepositoryByRepoNameParams{ - Provider: in.GetContext().GetProvider(), + Provider: providerFilter, RepoOwner: nameParts[0], RepoName: nameParts[1], ProjectID: projectID, @@ -241,8 +238,10 @@ func (filter *artifactListFilter) listArtifacts(ctx context.Context, provider st func artifactListRepoFilter( ctx context.Context, store db.Store, provider string, projectID uuid.UUID, repoSlubList []string, ) ([]*db.Repository, error) { + providerFilter := getNameFilterParam(provider) + repositories, err := store.ListRegisteredRepositoriesByProjectIDAndProvider(ctx, - db.ListRegisteredRepositoriesByProjectIDAndProviderParams{Provider: provider, ProjectID: projectID}) + db.ListRegisteredRepositoriesByProjectIDAndProviderParams{Provider: providerFilter, ProjectID: projectID}) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, status.Errorf(codes.NotFound, "repositories not found") diff --git a/internal/controlplane/handlers_oauth.go b/internal/controlplane/handlers_oauth.go index 1f3c3a1961..891e84b1a2 100644 --- a/internal/controlplane/handlers_oauth.go +++ b/internal/controlplane/handlers_oauth.go @@ -41,13 +41,10 @@ import ( "github.com/stacklok/minder/internal/engine" "github.com/stacklok/minder/internal/logger" "github.com/stacklok/minder/internal/providers" - "github.com/stacklok/minder/internal/providers/github/oauth" "github.com/stacklok/minder/internal/util" pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1" ) -const defaultProvider = oauth.Github - // GetAuthorizationURL returns the URL to redirect the user to for authorization // and the state to be used for the callback. It accepts a provider string // and a boolean indicating whether the client is a CLI or web client @@ -384,8 +381,17 @@ func (s *Server) StoreProviderToken(ctx context.Context, in *pb.StoreProviderTokenRequest) (*pb.StoreProviderTokenResponse, error) { entityCtx := engine.EntityFromContext(ctx) projectID := entityCtx.Project.ID + providerName := entityCtx.Provider.Name + + if providerName == "" { + return nil, status.Errorf(codes.InvalidArgument, "provider name is required") + } - provider, err := getProviderFromRequestOrDefault(ctx, s.store, in, projectID) + provider, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{ + // We don't check parent projects here because subprojects should not update the credentials of their parent + Projects: []uuid.UUID{projectID}, + Name: providerName, + }) if err != nil { return nil, providerError(err) } @@ -450,8 +456,17 @@ func (s *Server) VerifyProviderTokenFrom(ctx context.Context, in *pb.VerifyProviderTokenFromRequest) (*pb.VerifyProviderTokenFromResponse, error) { entityCtx := engine.EntityFromContext(ctx) projectID := entityCtx.Project.ID + providerName := entityCtx.Provider.Name - provider, err := getProviderFromRequestOrDefault(ctx, s.store, in, projectID) + if providerName == "" { + return nil, status.Errorf(codes.InvalidArgument, "provider name is required") + } + + provider, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{ + // We don't check parent projects here because subprojects should not check the credentials of their parent + Projects: []uuid.UUID{projectID}, + Name: providerName, + }) if err != nil { return nil, providerError(err) } diff --git a/internal/controlplane/handlers_profile.go b/internal/controlplane/handlers_profile.go index 02ce236483..9e01fc0cc1 100644 --- a/internal/controlplane/handlers_profile.go +++ b/internal/controlplane/handlers_profile.go @@ -58,7 +58,7 @@ func (s *Server) CreateProfile(ctx context.Context, } // TODO: This will be removed once we decouple providers from profiles - provider, err := getProviderFromRequestOrDefault(ctx, s.store, in, entityCtx.Project.ID) + provider, err := getProviderFromRequestOrDefault(ctx, s.store, entityCtx.Provider.Name, entityCtx.Project.ID) if err != nil { return nil, providerError(err) } @@ -634,7 +634,7 @@ func (s *Server) UpdateProfile(ctx context.Context, } // TODO: This will be removed once we decouple providers from profiles - provider, err := getProviderFromRequestOrDefault(ctx, s.store, in, entityCtx.Project.ID) + provider, err := getProviderFromRequestOrDefault(ctx, s.store, entityCtx.Provider.Name, entityCtx.Project.ID) if err != nil { return nil, providerError(err) } diff --git a/internal/controlplane/handlers_repositories.go b/internal/controlplane/handlers_repositories.go index 0611a3a185..0bc1f53ff6 100644 --- a/internal/controlplane/handlers_repositories.go +++ b/internal/controlplane/handlers_repositories.go @@ -18,6 +18,8 @@ import ( "context" "database/sql" "errors" + "fmt" + "slices" "strings" "github.com/google/uuid" @@ -51,10 +53,7 @@ func (s *Server) RegisterRepository( in *pb.RegisterRepositoryRequest, ) (*pb.RegisterRepositoryResponse, error) { projectID := getProjectID(ctx) - provider, client, err := s.getProviderAndClient(ctx, projectID, in) - if err != nil { - return nil, err - } + providerName := getProviderName(ctx) // Validate that the Repository struct in the request githubRepo := in.GetRepository() @@ -72,7 +71,17 @@ func (s *Server) RegisterRepository( Logger() ctx = l.WithContext(ctx) - newRepo, err := s.repos.CreateRepository(ctx, client, provider, projectID, githubRepo.GetOwner(), githubRepo.GetName()) + provider, err := s.inferProviderByOwner(ctx, githubRepo.GetOwner(), projectID, providerName) + if err != nil { + return nil, status.Errorf(codes.Internal, "cannot get provider: %v", err) + } + + client, err := s.getClientForProvider(ctx, provider) + if err != nil { + return nil, err + } + + newRepo, err := s.repos.CreateRepository(ctx, client, &provider, projectID, githubRepo.GetOwner(), githubRepo.GetName()) if err != nil { if errors.Is(err, ghrepo.ErrPrivateRepoForbidden) { return nil, util.UserVisibleError(codes.InvalidArgument, err.Error()) @@ -97,11 +106,9 @@ func (s *Server) ListRepositories(ctx context.Context, in *pb.ListRepositoriesRequest) (*pb.ListRepositoriesResponse, error) { entityCtx := engine.EntityFromContext(ctx) projectID := entityCtx.Project.ID + providerName := entityCtx.Provider.Name - provider, err := getProviderFromRequestOrDefault(ctx, s.store, in, projectID) - if err != nil { - return nil, providerError(err) - } + providerFilter := getNameFilterParam(providerName) reqRepoCursor, err := cursorutil.NewRepoCursor(in.GetCursor()) if err != nil { @@ -109,7 +116,7 @@ func (s *Server) ListRepositories(ctx context.Context, } repoId := sql.NullInt64{} - if reqRepoCursor.ProjectId == projectID.String() && reqRepoCursor.Provider == provider.Name { + if reqRepoCursor.ProjectId == projectID.String() && reqRepoCursor.Provider == providerName { repoId = sql.NullInt64{Valid: true, Int64: reqRepoCursor.RepoId} } @@ -123,7 +130,7 @@ func (s *Server) ListRepositories(ctx context.Context, } repos, err := s.store.ListRepositoriesByProjectID(ctx, db.ListRepositoriesByProjectIDParams{ - Provider: provider.Name, + Provider: providerFilter, ProjectID: projectID, RepoID: repoId, Limit: limit, @@ -153,7 +160,7 @@ func (s *Server) ListRepositories(ctx context.Context, lastRepo := repos[len(repos)-1] respRepoCursor = &cursorutil.RepoCursor{ ProjectId: projectID.String(), - Provider: provider.Name, + Provider: providerName, RepoId: lastRepo.RepoID, } @@ -165,7 +172,7 @@ func (s *Server) ListRepositories(ctx context.Context, resp.Cursor = respRepoCursor.String() // Telemetry logging - logger.BusinessRecord(ctx).Provider = provider.Name + logger.BusinessRecord(ctx).Provider = providerName logger.BusinessRecord(ctx).Project = projectID return &resp, nil @@ -220,14 +227,10 @@ func (s *Server) GetRepositoryByName(ctx context.Context, entityCtx := engine.EntityFromContext(ctx) projectID := entityCtx.Project.ID - - provider, err := getProviderFromRequestOrDefault(ctx, s.store, in, projectID) - if err != nil { - return nil, providerError(err) - } + providerFilter := getNameFilterParam(entityCtx.Provider.Name) repo, err := s.store.GetRepositoryByRepoName(ctx, db.GetRepositoryByRepoNameParams{ - Provider: provider.Name, + Provider: providerFilter, RepoOwner: fragments[0], RepoName: fragments[1], ProjectID: projectID, @@ -264,8 +267,10 @@ func (s *Server) DeleteRepositoryById( return nil, util.UserVisibleError(codes.InvalidArgument, "invalid repository ID") } - err = s.deleteRepository(ctx, in, func(client v1.GitHub, projectID uuid.UUID, _ string) error { - return s.repos.DeleteRepositoryByID(ctx, client, projectID, parsedRepositoryID) + projectID := getProjectID(ctx) + + err = s.deleteRepository(ctx, func() (db.Repository, error) { + return s.repos.GetRepositoryById(ctx, parsedRepositoryID, projectID) }) if err != nil { return nil, err @@ -288,8 +293,11 @@ func (s *Server) DeleteRepositoryByName( return nil, util.UserVisibleError(codes.InvalidArgument, "invalid repository name, needs to have the format: owner/name") } - err := s.deleteRepository(ctx, in, func(client v1.GitHub, projectID uuid.UUID, providerName string) error { - return s.repos.DeleteRepositoryByName(ctx, client, projectID, providerName, fragments[0], fragments[1]) + projectID := getProjectID(ctx) + providerName := getProviderName(ctx) + + err := s.deleteRepository(ctx, func() (db.Repository, error) { + return s.repos.GetRepositoryByName(ctx, fragments[0], fragments[1], projectID, providerName) }) if err != nil { return nil, err @@ -401,19 +409,72 @@ func (s *Server) listRemoteRepositoriesForProvider( return results, nil } -// TODO: this probably can probably be used elsewhere -// returns project ID and github client - this flow of code exists in multiple -// places -func (s *Server) getProviderAndClient( +// covers the common logic for the two varieties of repo deletion +func (s *Server) deleteRepository( ctx context.Context, - projectID uuid.UUID, - request HasProtoContext, -) (*db.Provider, v1.GitHub, error) { - provider, err := getProviderFromRequestOrDefault(ctx, s.store, request, projectID) + repoQueryMethod func() (db.Repository, error), +) error { + projectID := getProjectID(ctx) + + repo, err := repoQueryMethod() + if errors.Is(err, sql.ErrNoRows) { + return status.Errorf(codes.NotFound, "repository not found") + } else if err != nil { + return status.Errorf(codes.Internal, "unexpected error fetching repo: %v", err) + } + + provider, err := s.findProviderByName(ctx, repo.Provider, projectID) if err != nil { - return nil, nil, providerError(err) + return status.Errorf(codes.Internal, "cannot get provider: %v", err) } + client, err := s.getClientForProvider(ctx, provider) + if err != nil { + return status.Errorf(codes.Internal, "cannot get client for provider: %v", err) + } + + err = s.repos.DeleteRepository(ctx, client, &repo) + if err != nil { + return status.Errorf(codes.Internal, "unexpected error deleting repo: %v", err) + } + return nil +} + +// inferProviderByOwner returns the provider to use for a given repo owner +func (s *Server) inferProviderByOwner(ctx context.Context, owner string, projectID uuid.UUID, providerName string, +) (db.Provider, error) { + if providerName != "" { + return s.findProviderByName(ctx, providerName, projectID) + } + opts, err := getProvidersByTrait(ctx, s.store, nil, projectID, db.ProviderTypeGithub) + if err != nil { + return db.Provider{}, fmt.Errorf("error getting providers: %v", err) + } + + slices.SortFunc(opts, func(a, b db.Provider) int { + // Sort GitHub OAuth provider after all GitHub App providers + if a.Class.ProviderClass == db.ProviderClassGithub && b.Class.ProviderClass == db.ProviderClassGithubApp { + return 1 + } + if a.Class.ProviderClass == db.ProviderClassGithubApp && b.Class.ProviderClass == db.ProviderClassGithub { + return -1 + } + return 0 + }) + + for _, prov := range opts { + if github.CanHandleOwner(ctx, prov, owner) { + return prov, nil + } + } + + return db.Provider{}, fmt.Errorf("no providers can handle repo owned by %s", owner) +} + +func (s *Server) getClientForProvider( + ctx context.Context, + provider db.Provider, +) (v1.GitHub, error) { pbOpts := []providers.ProviderBuilderOption{ providers.WithProviderMetrics(s.provMt), providers.WithRestClientCache(s.restClientCache), @@ -421,40 +482,39 @@ func (s *Server) getProviderAndClient( p, err := providers.GetProviderBuilder(ctx, provider, s.store, s.cryptoEngine, &s.cfg.Provider, pbOpts...) if err != nil { - return nil, nil, status.Errorf(codes.Internal, "cannot get provider builder: %v", err) + return nil, status.Errorf(codes.Internal, "cannot get provider builder: %v", err) } client, err := p.GetGitHub() if err != nil { - return nil, nil, status.Errorf(codes.Internal, "error creating github provider: %v", err) + return nil, status.Errorf(codes.Internal, "error creating github provider: %v", err) } - return &provider, client, nil + return client, nil } -// covers the common logic for the two varieties of repo deletion -func (s *Server) deleteRepository( - ctx context.Context, - request HasProtoContext, - deletionMethod func(v1.GitHub, uuid.UUID, string) error, -) error { - projectID := getProjectID(ctx) - provider, client, err := s.getProviderAndClient(ctx, projectID, request) +func (s *Server) findProviderByName(ctx context.Context, providerName string, projectID uuid.UUID) (db.Provider, error) { + ph, err := s.store.GetParentProjects(ctx, projectID) if err != nil { - return err + return db.Provider{}, fmt.Errorf("error getting project hierarchy: %v", err) } - err = deletionMethod(client, projectID, provider.Name) - if errors.Is(err, sql.ErrNoRows) { - return status.Errorf(codes.NotFound, "repository not found") - } else if err != nil { - return status.Errorf(codes.Internal, "unexpected error deleting repo: %v", err) + provider, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{ + Name: providerName, + Projects: ph, + }) + if err != nil { + return db.Provider{}, fmt.Errorf("error getting provider: %v", err) } - - return nil + return provider, nil } func getProjectID(ctx context.Context) uuid.UUID { entityCtx := engine.EntityFromContext(ctx) return entityCtx.Project.ID } + +func getProviderName(ctx context.Context) string { + entityCtx := engine.EntityFromContext(ctx) + return entityCtx.Provider.Name +} diff --git a/internal/controlplane/handlers_repositories_test.go b/internal/controlplane/handlers_repositories_test.go index efd72bc8dc..60f8b5be21 100644 --- a/internal/controlplane/handlers_repositories_test.go +++ b/internal/controlplane/handlers_repositories_test.go @@ -55,7 +55,7 @@ func TestServer_RegisterRepository(t *testing.T) { RepoOwner: repoOwner, RepoName: repoName, ProviderFails: true, - ExpectedError: "cannot retrieve providers", + ExpectedError: "error getting provider", }, { Name: "Repo creation fails when repo name is missing", @@ -145,10 +145,11 @@ func TestServer_DeleteRepository(t *testing.T) { ExpectedError string }{ { - Name: "deletion fails when provider cannot be found", - RepoName: repoOwnerAndName, - ProviderFails: true, - ExpectedError: "cannot retrieve providers", + Name: "deletion fails when provider cannot be found", + RepoName: repoOwnerAndName, + RepoServiceSetup: newRepoService(withSuccessfulGetRepoByName), + ProviderFails: true, + ExpectedError: "error getting provider", }, { Name: "delete by name fails when name is malformed", @@ -163,30 +164,30 @@ func TestServer_DeleteRepository(t *testing.T) { { Name: "deletion fails when repo is not found", RepoName: repoOwnerAndName, - RepoServiceSetup: newRepoService(withFailedDeleteByName(sql.ErrNoRows)), + RepoServiceSetup: newRepoService(withFailedGetRepoByName(sql.ErrNoRows)), ExpectedError: "repository not found", }, { Name: "deletion fails when repo service returns error", RepoName: repoOwnerAndName, - RepoServiceSetup: newRepoService(withFailedDeleteByName(errDefault)), + RepoServiceSetup: newRepoService(withSuccessfulGetRepoByName, withFailedDelete(errDefault)), ExpectedError: "unexpected error deleting repo", }, { Name: "delete by ID fails when repo service returns error", RepoID: repoID, - RepoServiceSetup: newRepoService(withFailedDeleteByID(errDefault)), + RepoServiceSetup: newRepoService(withSuccessfulGetRepoById, withFailedDelete(errDefault)), ExpectedError: "unexpected error deleting repo", }, { Name: "delete by name succeeds", RepoName: repoOwnerAndName, - RepoServiceSetup: newRepoService(withSuccessfulDeleteByName), + RepoServiceSetup: newRepoService(withSuccessfulGetRepoByName, withSuccessfulDelete), }, { Name: "delete by ID succeeds", RepoID: repoID, - RepoServiceSetup: newRepoService(withSuccessfulDeleteByID), + RepoServiceSetup: newRepoService(withSuccessfulGetRepoById, withSuccessfulDelete), }, } @@ -291,35 +292,39 @@ func withFailedCreate(err error) func(repoServiceMock) { } } -func withSuccessfulDeleteByName(mock repoServiceMock) { - withFailedDeleteByName(nil)(mock) +func withSuccessfulDelete(mock repoServiceMock) { + withFailedDelete(nil)(mock) } -func withFailedDeleteByName(err error) func(repoServiceMock) { +func withFailedDelete(err error) func(repoServiceMock) { return func(mock repoServiceMock) { mock.EXPECT(). - DeleteRepositoryByName(gomock.Any(), gomock.Any(), projectID, gomock.Any(), repoOwner, repoName). + DeleteRepository(gomock.Any(), gomock.Any(), gomock.Any()). Return(err) } } -func withSuccessfulDeleteByID(mock repoServiceMock) { - withFailedDeleteByID(nil)(mock) +func withSuccessfulGetRepoById(mock repoServiceMock) { + mock.EXPECT(). + GetRepositoryById(gomock.Any(), gomock.Any(), gomock.Any()). + Return(db.Repository{}, nil) } -func withFailedDeleteByID(err error) func(repoServiceMock) { +func withSuccessfulGetRepoByName(mock repoServiceMock) { + mock.EXPECT(). + GetRepositoryByName(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(db.Repository{}, nil) +} + +func withFailedGetRepoByName(err error) func(repoServiceMock) { return func(mock repoServiceMock) { mock.EXPECT(). - DeleteRepositoryByID(gomock.Any(), gomock.Any(), projectID, gomock.Any()). - Return(err) + GetRepositoryByName(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(db.Repository{}, err) } } -func createServer( - ctrl *gomock.Controller, - repoServiceSetup repoMockBuilder, - providerFails bool, -) *Server { +func createServer(ctrl *gomock.Controller, repoServiceSetup repoMockBuilder, providerFails bool) *Server { var svc ghrepo.RepositoryService if repoServiceSetup != nil { svc = repoServiceSetup(ctrl) @@ -348,12 +353,12 @@ func createServer( if providerFails { store.EXPECT(). - FindProviders(gomock.Any(), gomock.Any()). - Return(nil, errDefault) + GetProviderByName(gomock.Any(), gomock.Any()). + Return(db.Provider{}, errDefault) } else { store.EXPECT(). - FindProviders(gomock.Any(), gomock.Any()). - Return([]db.Provider{provider}, nil).AnyTimes() + GetProviderByName(gomock.Any(), gomock.Any()). + Return(provider, nil).AnyTimes() store.EXPECT(). GetAccessTokenByProjectID(gomock.Any(), gomock.Any()). Return(db.ProviderAccessToken{ diff --git a/internal/db/repositories.sql.go b/internal/db/repositories.sql.go index 3c63d9d051..703f2f2c71 100644 --- a/internal/db/repositories.sql.go +++ b/internal/db/repositories.sql.go @@ -203,22 +203,24 @@ func (q *Queries) GetRepositoryByRepoID(ctx context.Context, repoID int64) (Repo } const getRepositoryByRepoName = `-- name: GetRepositoryByRepoName :one -SELECT id, provider, project_id, repo_owner, repo_name, repo_id, is_private, is_fork, webhook_id, webhook_url, deploy_url, clone_url, created_at, updated_at, default_branch, license, provider_id FROM repositories WHERE provider = $1 AND repo_owner = $2 AND repo_name = $3 AND project_id = $4 +SELECT id, provider, project_id, repo_owner, repo_name, repo_id, is_private, is_fork, webhook_id, webhook_url, deploy_url, clone_url, created_at, updated_at, default_branch, license, provider_id FROM repositories + WHERE repo_owner = $1 AND repo_name = $2 AND project_id = $3 + AND lower(provider) = lower($4::text) OR $4::text IS NULL ` type GetRepositoryByRepoNameParams struct { - Provider string `json:"provider"` - RepoOwner string `json:"repo_owner"` - RepoName string `json:"repo_name"` - ProjectID uuid.UUID `json:"project_id"` + RepoOwner string `json:"repo_owner"` + RepoName string `json:"repo_name"` + ProjectID uuid.UUID `json:"project_id"` + Provider sql.NullString `json:"provider"` } func (q *Queries) GetRepositoryByRepoName(ctx context.Context, arg GetRepositoryByRepoNameParams) (Repository, error) { row := q.db.QueryRowContext(ctx, getRepositoryByRepoName, - arg.Provider, arg.RepoOwner, arg.RepoName, arg.ProjectID, + arg.Provider, ) var i Repository err := row.Scan( @@ -245,17 +247,18 @@ func (q *Queries) GetRepositoryByRepoName(ctx context.Context, arg GetRepository const listRegisteredRepositoriesByProjectIDAndProvider = `-- name: ListRegisteredRepositoriesByProjectIDAndProvider :many SELECT id, provider, project_id, repo_owner, repo_name, repo_id, is_private, is_fork, webhook_id, webhook_url, deploy_url, clone_url, created_at, updated_at, default_branch, license, provider_id FROM repositories -WHERE provider = $1 AND project_id = $2 AND webhook_id IS NOT NULL +WHERE project_id = $1 AND webhook_id IS NOT NULL + AND lower(provider) = lower($2::text) OR $2::text IS NULL ORDER BY repo_name ` type ListRegisteredRepositoriesByProjectIDAndProviderParams struct { - Provider string `json:"provider"` - ProjectID uuid.UUID `json:"project_id"` + ProjectID uuid.UUID `json:"project_id"` + Provider sql.NullString `json:"provider"` } func (q *Queries) ListRegisteredRepositoriesByProjectIDAndProvider(ctx context.Context, arg ListRegisteredRepositoriesByProjectIDAndProviderParams) ([]Repository, error) { - rows, err := q.db.QueryContext(ctx, listRegisteredRepositoriesByProjectIDAndProvider, arg.Provider, arg.ProjectID) + rows, err := q.db.QueryContext(ctx, listRegisteredRepositoriesByProjectIDAndProvider, arg.ProjectID, arg.Provider) if err != nil { return nil, err } @@ -297,24 +300,25 @@ func (q *Queries) ListRegisteredRepositoriesByProjectIDAndProvider(ctx context.C const listRepositoriesByProjectID = `-- name: ListRepositoriesByProjectID :many SELECT id, provider, project_id, repo_owner, repo_name, repo_id, is_private, is_fork, webhook_id, webhook_url, deploy_url, clone_url, created_at, updated_at, default_branch, license, provider_id FROM repositories -WHERE provider = $1 AND project_id = $2 - AND (repo_id >= $3 OR $3 IS NULL) +WHERE project_id = $1 + AND (repo_id >= $2 OR $2 IS NULL) + AND lower(provider) = lower(COALESCE($3, provider)::text) ORDER BY project_id, provider, repo_id LIMIT $4::bigint ` type ListRepositoriesByProjectIDParams struct { - Provider string `json:"provider"` - ProjectID uuid.UUID `json:"project_id"` - RepoID sql.NullInt64 `json:"repo_id"` - Limit sql.NullInt64 `json:"limit"` + ProjectID uuid.UUID `json:"project_id"` + RepoID sql.NullInt64 `json:"repo_id"` + Provider sql.NullString `json:"provider"` + Limit sql.NullInt64 `json:"limit"` } func (q *Queries) ListRepositoriesByProjectID(ctx context.Context, arg ListRepositoriesByProjectIDParams) ([]Repository, error) { rows, err := q.db.QueryContext(ctx, listRepositoriesByProjectID, - arg.Provider, arg.ProjectID, arg.RepoID, + arg.Provider, arg.Limit, ) if err != nil { diff --git a/internal/db/repositories_test.go b/internal/db/repositories_test.go index 6cc2ff55cc..c5c8464f4a 100644 --- a/internal/db/repositories_test.go +++ b/internal/db/repositories_test.go @@ -130,7 +130,10 @@ func TestGetRepositoryByRepoName(t *testing.T) { repo1 := createRandomRepository(t, project.ID, prov) repo2, err := testQueries.GetRepositoryByRepoName(context.Background(), GetRepositoryByRepoNameParams{ - Provider: repo1.Provider, + Provider: sql.NullString{ + String: repo1.Provider, + Valid: true, + }, RepoOwner: repo1.RepoOwner, RepoName: repo1.RepoName, ProjectID: project.ID, @@ -168,7 +171,10 @@ func TestListRepositoriesByProjectID(t *testing.T) { } arg := ListRepositoriesByProjectIDParams{ - Provider: prov.Name, + Provider: sql.NullString{ + String: prov.Name, + Valid: true, + }, ProjectID: project.ID, } diff --git a/internal/providers/github/common.go b/internal/providers/github/common.go index 49a7a52cee..396bc4a7e8 100644 --- a/internal/providers/github/common.go +++ b/internal/providers/github/common.go @@ -814,3 +814,16 @@ func IsMinderHook(hook *github.Hook, hostURL string) (bool, error) { return false, nil } + +// CanHandleOwner checks if the GitHub provider has the right credentials to handle the owner +func CanHandleOwner(_ context.Context, prov db.Provider, owner string) bool { + // TODO: this is fragile and does not handle organization renames, in the future we can make sure the credential + // has admin permissions on the owner + if prov.Name == fmt.Sprintf("%s-%s", db.ProviderClassGithubApp, owner) { + return true + } + if prov.Class.ProviderClass == db.ProviderClassGithub { + return true + } + return false +} diff --git a/internal/reconcilers/run_profile.go b/internal/reconcilers/run_profile.go index de3a15d195..b3b8c2215f 100644 --- a/internal/reconcilers/run_profile.go +++ b/internal/reconcilers/run_profile.go @@ -127,7 +127,10 @@ func (r *Reconciler) publishProfileInitEvents( ) error { dbrepos, err := r.store.ListRegisteredRepositoriesByProjectIDAndProvider(ctx, db.ListRegisteredRepositoriesByProjectIDAndProviderParams{ - Provider: ectx.Provider.Name, + Provider: sql.NullString{ + String: ectx.Provider.Name, + Valid: ectx.Provider.Name != "", + }, ProjectID: ectx.Project.ID, }) if err != nil { diff --git a/internal/repositories/github/mock/service.go b/internal/repositories/github/mock/service.go index 345e52082a..da45274a98 100644 --- a/internal/repositories/github/mock/service.go +++ b/internal/repositories/github/mock/service.go @@ -58,30 +58,46 @@ func (mr *MockRepositoryServiceMockRecorder) CreateRepository(arg0, arg1, arg2, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRepository", reflect.TypeOf((*MockRepositoryService)(nil).CreateRepository), arg0, arg1, arg2, arg3, arg4, arg5) } -// DeleteRepositoryByID mocks base method. -func (m *MockRepositoryService) DeleteRepositoryByID(arg0 context.Context, arg1 clients.GitHubRepoClient, arg2, arg3 uuid.UUID) error { +// DeleteRepository mocks base method. +func (m *MockRepositoryService) DeleteRepository(arg0 context.Context, arg1 clients.GitHubRepoClient, arg2 *db.Repository) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteRepositoryByID", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "DeleteRepository", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } -// DeleteRepositoryByID indicates an expected call of DeleteRepositoryByID. -func (mr *MockRepositoryServiceMockRecorder) DeleteRepositoryByID(arg0, arg1, arg2, arg3 any) *gomock.Call { +// DeleteRepository indicates an expected call of DeleteRepository. +func (mr *MockRepositoryServiceMockRecorder) DeleteRepository(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRepositoryByID", reflect.TypeOf((*MockRepositoryService)(nil).DeleteRepositoryByID), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRepository", reflect.TypeOf((*MockRepositoryService)(nil).DeleteRepository), arg0, arg1, arg2) } -// DeleteRepositoryByName mocks base method. -func (m *MockRepositoryService) DeleteRepositoryByName(arg0 context.Context, arg1 clients.GitHubRepoClient, arg2 uuid.UUID, arg3, arg4, arg5 string) error { +// GetRepositoryById mocks base method. +func (m *MockRepositoryService) GetRepositoryById(arg0 context.Context, arg1, arg2 uuid.UUID) (db.Repository, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteRepositoryByName", arg0, arg1, arg2, arg3, arg4, arg5) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "GetRepositoryById", arg0, arg1, arg2) + ret0, _ := ret[0].(db.Repository) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRepositoryById indicates an expected call of GetRepositoryById. +func (mr *MockRepositoryServiceMockRecorder) GetRepositoryById(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRepositoryById", reflect.TypeOf((*MockRepositoryService)(nil).GetRepositoryById), arg0, arg1, arg2) +} + +// GetRepositoryByName mocks base method. +func (m *MockRepositoryService) GetRepositoryByName(arg0 context.Context, arg1, arg2 string, arg3 uuid.UUID, arg4 string) (db.Repository, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRepositoryByName", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(db.Repository) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// DeleteRepositoryByName indicates an expected call of DeleteRepositoryByName. -func (mr *MockRepositoryServiceMockRecorder) DeleteRepositoryByName(arg0, arg1, arg2, arg3, arg4, arg5 any) *gomock.Call { +// GetRepositoryByName indicates an expected call of GetRepositoryByName. +func (mr *MockRepositoryServiceMockRecorder) GetRepositoryByName(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRepositoryByName", reflect.TypeOf((*MockRepositoryService)(nil).DeleteRepositoryByName), arg0, arg1, arg2, arg3, arg4, arg5) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRepositoryByName", reflect.TypeOf((*MockRepositoryService)(nil).GetRepositoryByName), arg0, arg1, arg2, arg3, arg4) } diff --git a/internal/repositories/github/service.go b/internal/repositories/github/service.go index ccb2571e3d..dab648ac55 100644 --- a/internal/repositories/github/service.go +++ b/internal/repositories/github/service.go @@ -49,24 +49,22 @@ type RepositoryService interface { repoName string, repoOwner string, ) (*pb.Repository, error) - // DeleteRepositoryByName removes the webhook and deletes the repo from the - // database. The repo is identified by its name and project. - DeleteRepositoryByName( + // DeleteRepository removes the webhook and deletes the repo from the database. + DeleteRepository( ctx context.Context, client ghclient.GitHubRepoClient, - projectID uuid.UUID, - providerName string, - repoOwner string, - repoName string, + repo *db.Repository, ) error - // DeleteRepositoryByID removes the webhook and deletes the repo from the - // database. The repo is identified by its database ID and project. - DeleteRepositoryByID( + // GetRepositoryById retrieves a repository by its ID and project. + GetRepositoryById(ctx context.Context, repositoryID uuid.UUID, projectID uuid.UUID) (db.Repository, error) + // GetRepositoryByName retrieves a repository by its name, owner, project and provider (if specified). + GetRepositoryByName( ctx context.Context, - client ghclient.GitHubRepoClient, + repoOwner string, + repoName string, projectID uuid.UUID, - repoID uuid.UUID, - ) error + providerName string, + ) (db.Repository, error) } var ( @@ -155,45 +153,38 @@ func (r *repositoryService) CreateRepository( return pbRepo, nil } -func (r *repositoryService) DeleteRepositoryByName( +func (r *repositoryService) GetRepositoryById( ctx context.Context, - client ghclient.GitHubRepoClient, + repositoryID uuid.UUID, projectID uuid.UUID, - providerName string, - repoOwner string, - repoName string, -) error { - params := db.GetRepositoryByRepoNameParams{ - Provider: providerName, - RepoOwner: repoOwner, - RepoName: repoName, +) (db.Repository, error) { + return r.store.GetRepositoryByIDAndProject(ctx, db.GetRepositoryByIDAndProjectParams{ + ID: repositoryID, ProjectID: projectID, - } - repo, err := r.store.GetRepositoryByRepoName(ctx, params) - if err != nil { - return err - } - return r.deleteRepository(ctx, client, &repo) + }) } -func (r *repositoryService) DeleteRepositoryByID( +func (r *repositoryService) GetRepositoryByName( ctx context.Context, - client ghclient.GitHubRepoClient, + repoOwner string, + repoName string, projectID uuid.UUID, - repoID uuid.UUID, -) error { - repo, err := r.store.GetRepositoryByIDAndProject(ctx, db.GetRepositoryByIDAndProjectParams{ - ID: repoID, + providerName string, +) (db.Repository, error) { + providerFilter := sql.NullString{ + String: providerName, + Valid: providerName != "", + } + params := db.GetRepositoryByRepoNameParams{ + Provider: providerFilter, + RepoOwner: repoOwner, + RepoName: repoName, ProjectID: projectID, - }) - if err != nil { - log.Printf("error is %v", err) - return err } - return r.deleteRepository(ctx, client, &repo) + return r.store.GetRepositoryByRepoName(ctx, params) } -func (r *repositoryService) deleteRepository(ctx context.Context, client ghclient.GitHubRepoClient, repo *db.Repository) error { +func (r *repositoryService) DeleteRepository(ctx context.Context, client ghclient.GitHubRepoClient, repo *db.Repository) error { var err error // Cleanup any webhook we created for this project diff --git a/internal/repositories/github/service_test.go b/internal/repositories/github/service_test.go index 7c334822e6..6bbd6ce613 100644 --- a/internal/repositories/github/service_test.go +++ b/internal/repositories/github/service_test.go @@ -142,51 +142,62 @@ func TestRepositoryService_DeleteRepository(t *testing.T) { ShouldSucceed bool }{ { - Name: "Delete by ID fails when repo does not exist in DB", - DBSetup: newDBMock(withNotFoundByID), - DeleteType: ByID, - }, - { - Name: "Delete by Name fails when repo does not exist in DB", - DBSetup: newDBMock(withNotFoundByName), - DeleteType: ByName, - }, - { - Name: "Delete by ID fails when webhook cannot be deleted", - DBSetup: newDBMock(withRepoByID), - WebhookSetup: newWebhookMock(withFailedWebhookDelete), - DeleteType: ByID, - }, - { - Name: "Delete by Name fails when webhook cannot be deleted", - DBSetup: newDBMock(withRepoByName), + Name: "Delete fails when webhook cannot be deleted", WebhookSetup: newWebhookMock(withFailedWebhookDelete), - DeleteType: ByName, }, { Name: "Delete by ID fails when repo cannot be deleted from DB", - DBSetup: newDBMock(withRepoByID, withFailedDelete), + DBSetup: newDBMock(withFailedDelete), WebhookSetup: newWebhookMock(withSuccessfulWebhookDelete), - DeleteType: ByID, }, { - Name: "Delete by Name fails when repo cannot be deleted from DB", - DBSetup: newDBMock(withRepoByName, withFailedDelete), - WebhookSetup: newWebhookMock(withSuccessfulWebhookDelete), - DeleteType: ByName, - }, - { - Name: "Delete by ID succeeds", - DBSetup: newDBMock(withRepoByID, withSuccessfulDelete), + Name: "Delete succeeds", + DBSetup: newDBMock(withSuccessfulDelete), WebhookSetup: newWebhookMock(withSuccessfulWebhookDelete), - DeleteType: ByID, ShouldSucceed: true, }, + } + + for i := range scenarios { + scenario := scenarios[i] + t.Run(scenario.Name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ctx := context.Background() + + // For the purposes of this test, we do not need to attach any + // mock behaviour to the client. We can leave it as a nil pointer. + var ghClient clients.GitHubRepoClient + + svc := createService(ctrl, scenario.WebhookSetup, scenario.DBSetup, false) + err := svc.DeleteRepository(ctx, ghClient, &dbRepo) + + if scenario.ShouldSucceed { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} + +func TestRepositoryService_GetRepositoryById(t *testing.T) { + t.Parallel() + + scenarios := []struct { + Name string + DBSetup dbMockBuilder + DeleteType DeleteCallType + ShouldSucceed bool + }{ + { + Name: "Get by ID fails when DB call fails", + DBSetup: newDBMock(withFailedGetById), + }, { - Name: "Delete by Name succeeds", - DBSetup: newDBMock(withRepoByName, withSuccessfulDelete), - WebhookSetup: newWebhookMock(withSuccessfulWebhookDelete), - DeleteType: ByName, + Name: "Get by ID succeeds", + DBSetup: newDBMock(withSuccessfulGetById), ShouldSucceed: true, }, } @@ -201,18 +212,52 @@ func TestRepositoryService_DeleteRepository(t *testing.T) { // For the purposes of this test, we do not need to attach any // mock behaviour to the client. We can leave it as a nil pointer. - var ghClient clients.GitHubRepoClient - svc := createService(ctrl, scenario.WebhookSetup, scenario.DBSetup, false) - var err error - switch scenario.DeleteType { - case ByID: - err = svc.DeleteRepositoryByID(ctx, ghClient, projectID, repoID) - case ByName: - err = svc.DeleteRepositoryByName(ctx, ghClient, projectID, providerName, repoOwner, repoName) - default: - t.Fatalf("Unknown deletion type: %d", scenario.DeleteType) + svc := createService(ctrl, newWebhookMock(), scenario.DBSetup, false) + _, err := svc.GetRepositoryById(ctx, repoID, projectID) + + if scenario.ShouldSucceed { + require.NoError(t, err) + } else { + require.Error(t, err) } + }) + } +} + +func TestRepositoryService_GetRepositoryByName(t *testing.T) { + t.Parallel() + + scenarios := []struct { + Name string + DBSetup dbMockBuilder + DeleteType DeleteCallType + ShouldSucceed bool + }{ + { + Name: "Get by name fails when DB call fails", + DBSetup: newDBMock(withFailedGetByName), + }, + { + Name: "Get by name succeeds", + DBSetup: newDBMock(withSuccessfulGetByName), + ShouldSucceed: true, + }, + } + + for i := range scenarios { + scenario := scenarios[i] + t.Run(scenario.Name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ctx := context.Background() + + // For the purposes of this test, we do not need to attach any + // mock behaviour to the client. We can leave it as a nil pointer. + + svc := createService(ctrl, newWebhookMock(), scenario.DBSetup, false) + _, err := svc.GetRepositoryByName(ctx, repoOwner, repoName, projectID, providerName) if scenario.ShouldSucceed { require.NoError(t, err) @@ -346,41 +391,40 @@ func withFailedWebhookDelete(mock whMock) { Return(errDefault) } -func withNotFoundByID(mock dbMock) { +func withFailedDelete(mock dbMock) { mock.EXPECT(). - GetRepositoryByIDAndProject(gomock.Any(), gomock.Any()). - Return(db.Repository{}, errDefault) + DeleteRepository(gomock.Any(), gomock.Eq(repoID)). + Return(errDefault) } -func withNotFoundByName(mock dbMock) { +func withSuccessfulDelete(mock dbMock) { mock.EXPECT(). - // TODO: do we want more strict assertions here? - GetRepositoryByRepoName(gomock.Any(), gomock.Any()). - Return(db.Repository{}, errDefault) + DeleteRepository(gomock.Any(), gomock.Eq(repoID)). + Return(nil) } -func withRepoByID(mock dbMock) { +func withFailedGetById(mock dbMock) { mock.EXPECT(). GetRepositoryByIDAndProject(gomock.Any(), gomock.Any()). - Return(dbRepo, nil) + Return(db.Repository{}, errDefault) } -func withRepoByName(mock dbMock) { +func withSuccessfulGetById(mock dbMock) { mock.EXPECT(). - GetRepositoryByRepoName(gomock.Any(), gomock.Any()). + GetRepositoryByIDAndProject(gomock.Any(), gomock.Any()). Return(dbRepo, nil) } -func withFailedDelete(mock dbMock) { +func withFailedGetByName(mock dbMock) { mock.EXPECT(). - DeleteRepository(gomock.Any(), gomock.Eq(repoID)). - Return(errDefault) + GetRepositoryByRepoName(gomock.Any(), gomock.Any()). + Return(db.Repository{}, errDefault) } -func withSuccessfulDelete(mock dbMock) { +func withSuccessfulGetByName(mock dbMock) { mock.EXPECT(). - DeleteRepository(gomock.Any(), gomock.Eq(repoID)). - Return(nil) + GetRepositoryByRepoName(gomock.Any(), gomock.Any()). + Return(dbRepo, nil) } func withFailedCreate(mock dbMock) {