Skip to content

Commit

Permalink
add single license scanner instance
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Goodman <wagoodman@users.noreply.github.com>
  • Loading branch information
wagoodman committed Oct 18, 2024
1 parent 56dbb34 commit 7431129
Show file tree
Hide file tree
Showing 16 changed files with 275 additions and 148 deletions.
18 changes: 18 additions & 0 deletions internal/licenses/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package licenses

import (
"context"
)

type licenseScannerKey struct{}

func SetContextLicenseScanner(ctx context.Context, s Scanner) context.Context {
return context.WithValue(ctx, licenseScannerKey{}, s)
}

func ContextLicenseScanner(ctx context.Context) Scanner {
if s, ok := ctx.Value(licenseScannerKey{}).(Scanner); ok {
return s
}
return NewDefaultScanner()
}
45 changes: 0 additions & 45 deletions internal/licenses/parser.go

This file was deleted.

66 changes: 66 additions & 0 deletions internal/licenses/scanner.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package licenses

import (
"context"
"io"

"github.com/google/licensecheck"

"github.com/anchore/syft/internal/log"
)

type Scanner interface {
IdentifyLicenseIDs(context.Context, io.Reader) ([]string, error)
}

var _ Scanner = (*scanner)(nil)

type scanner struct {
coverageThreshold float64 // between 0 and 100
scanner func([]byte) licensecheck.Coverage
}

// NewDefaultScanner returns a scanner that uses a new instance of the default licensecheck package scanner.
func NewDefaultScanner() Scanner {
s, err := licensecheck.NewScanner(licensecheck.BuiltinLicenses())
if err != nil {
log.WithFields("error", err).Trace("unable to create default license scanner")
s = nil
}
return &scanner{
coverageThreshold: 75, // determined by experimentation
scanner: s.Scan,
}
}

// StaticScanner returns a scanner that uses the built-in license scanner from the licensecheck package.
// THIS IS ONLY MEANT FOR TEST CODE, NOT PRODUCTION CODE.
func StaticScanner() Scanner {
return &scanner{
coverageThreshold: 75, // determined by experimentation
scanner: licensecheck.Scan,
}
}

func (s scanner) IdentifyLicenseIDs(_ context.Context, reader io.Reader) ([]string, error) {
if s.scanner == nil {
return nil, nil
}

content, err := io.ReadAll(reader)
if err != nil {
return nil, err
}

cov := s.scanner(content)
if cov.Percent < s.coverageThreshold {
// unknown or no licenses here?
return nil, nil
}

var ids []string
for _, m := range cov.Match {
ids = append(ids, m.ID)
}
return ids, nil
}
28 changes: 28 additions & 0 deletions internal/licenses/search.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package licenses

import (
"context"

"github.com/anchore/syft/syft/file"
"github.com/anchore/syft/syft/license"
"github.com/anchore/syft/syft/pkg"
)

// Search scans the contents of a license file to attempt to determine the type of license it is
func Search(ctx context.Context, scanner Scanner, reader file.LocationReadCloser) (licenses []pkg.License, err error) {
licenses = make([]pkg.License, 0)

ids, err := scanner.IdentifyLicenseIDs(ctx, reader)
if err != nil {
return nil, err
}

for _, id := range ids {
lic := pkg.NewLicenseFromLocations(id, reader.Location)
lic.Type = license.Concluded

licenses = append(licenses, lic)
}

return licenses, nil
}
4 changes: 4 additions & 0 deletions syft/create_sbom.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/scylladb/go-set/strset"

"github.com/anchore/syft/internal/bus"
"github.com/anchore/syft/internal/licenses"
"github.com/anchore/syft/internal/sbomsync"
"github.com/anchore/syft/internal/task"
"github.com/anchore/syft/syft/artifact"
Expand Down Expand Up @@ -60,6 +61,9 @@ func CreateSBOM(ctx context.Context, src source.Source, cfg *CreateSBOMConfig) (
},
}

// inject a single license scanner for all package cataloging tasks into context
ctx = licenses.SetContextLicenseScanner(ctx, licenses.NewDefaultScanner())

catalogingProgress := monitorCatalogingTask(src.ID(), taskGroups)
packageCatalogingProgress := monitorPackageCatalogingTask()

Expand Down
4 changes: 2 additions & 2 deletions syft/pkg/cataloger/generic/cataloger.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func (c *Cataloger) Catalog(ctx context.Context, resolver file.Resolver) ([]pkg.
var relationships []artifact.Relationship
var errs error

logger := log.Nested("cataloger", c.upstreamCataloger)
lgr := log.Nested("cataloger", c.upstreamCataloger)

env := Environment{
// TODO: consider passing into the cataloger, this would affect the cataloger interface (and all implementations). This can be deferred until later.
Expand All @@ -166,7 +166,7 @@ func (c *Cataloger) Catalog(ctx context.Context, resolver file.Resolver) ([]pkg.

log.WithFields("path", location.RealPath).Trace("parsing file contents")

discoveredPackages, discoveredRelationships, err := invokeParser(ctx, resolver, location, logger, parser, &env)
discoveredPackages, discoveredRelationships, err := invokeParser(ctx, resolver, location, lgr, parser, &env)
if err != nil {
// parsers may return errors and valid packages / relationships
errs = unknown.Append(errs, location, err)
Expand Down
30 changes: 16 additions & 14 deletions syft/pkg/cataloger/golang/licenses.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package golang
import (
"archive/zip"
"bytes"
"context"
"fmt"
"io"
"io/fs"
Expand Down Expand Up @@ -79,9 +80,9 @@ func remotesForModule(proxies []string, noProxy []string, module string) []strin
return proxies
}

func (c *goLicenseResolver) getLicenses(resolver file.Resolver, moduleName, moduleVersion string) ([]pkg.License, error) {
func (c *goLicenseResolver) getLicenses(ctx context.Context, scanner licenses.Scanner, resolver file.Resolver, moduleName, moduleVersion string) ([]pkg.License, error) {
// search the scan target first, ignoring local and remote sources
goLicenses, err := c.findLicensesInSource(resolver,
goLicenses, err := c.findLicensesInSource(ctx, scanner, resolver,
fmt.Sprintf(`**/go/pkg/mod/%s@%s/*`, processCaps(moduleName), moduleVersion),
)
if err != nil || len(goLicenses) > 0 {
Expand All @@ -90,21 +91,21 @@ func (c *goLicenseResolver) getLicenses(resolver file.Resolver, moduleName, modu

// look in the local host mod directory...
if c.opts.SearchLocalModCacheLicenses {
goLicenses, err = c.getLicensesFromLocal(moduleName, moduleVersion)
goLicenses, err = c.getLicensesFromLocal(ctx, scanner, moduleName, moduleVersion)
if err != nil || len(goLicenses) > 0 {
return toPkgLicenses(goLicenses), err
}
}

// download from remote sources
if c.opts.SearchRemoteLicenses {
goLicenses, err = c.getLicensesFromRemote(moduleName, moduleVersion)
goLicenses, err = c.getLicensesFromRemote(ctx, scanner, moduleName, moduleVersion)
}

return toPkgLicenses(goLicenses), err
}

func (c *goLicenseResolver) getLicensesFromLocal(moduleName, moduleVersion string) ([]goLicense, error) {
func (c *goLicenseResolver) getLicensesFromLocal(ctx context.Context, scanner licenses.Scanner, moduleName, moduleVersion string) ([]goLicense, error) {
if c.localModCacheDir == nil {
return nil, nil
}
Expand All @@ -120,10 +121,10 @@ func (c *goLicenseResolver) getLicensesFromLocal(moduleName, moduleVersion strin
// if we're running against a directory on the filesystem, it may not include the
// user's homedir / GOPATH, so we defer to using the localModCacheResolver
// we use $GOPATH/pkg/mod to avoid leaking information about the user's system
return c.findLicensesInFS("file://$GOPATH/pkg/mod/"+subdir+"/", dir)
return c.findLicensesInFS(ctx, scanner, "file://$GOPATH/pkg/mod/"+subdir+"/", dir)
}

func (c *goLicenseResolver) getLicensesFromRemote(moduleName, moduleVersion string) ([]goLicense, error) {
func (c *goLicenseResolver) getLicensesFromRemote(ctx context.Context, scanner licenses.Scanner, moduleName, moduleVersion string) ([]goLicense, error) {
return c.licenseCache.Resolve(fmt.Sprintf("%s/%s", moduleName, moduleVersion), func() ([]goLicense, error) {
proxies := remotesForModule(c.opts.Proxies, c.opts.NoProxy, moduleName)

Expand All @@ -132,11 +133,11 @@ func (c *goLicenseResolver) getLicensesFromRemote(moduleName, moduleVersion stri
return nil, err
}

return c.findLicensesInFS(urlPrefix, fsys)
return c.findLicensesInFS(ctx, scanner, urlPrefix, fsys)
})
}

func (c *goLicenseResolver) findLicensesInFS(urlPrefix string, fsys fs.FS) ([]goLicense, error) {
func (c *goLicenseResolver) findLicensesInFS(ctx context.Context, scanner licenses.Scanner, urlPrefix string, fsys fs.FS) ([]goLicense, error) {
var out []goLicense
err := fs.WalkDir(fsys, ".", func(filePath string, d fs.DirEntry, err error) error {
if err != nil {
Expand All @@ -156,7 +157,8 @@ func (c *goLicenseResolver) findLicensesInFS(urlPrefix string, fsys fs.FS) ([]go
return nil
}
defer internal.CloseAndLogError(rdr, filePath)
parsed, err := licenses.Parse(rdr, file.NewLocation(filePath))

parsed, err := licenses.Search(ctx, scanner, file.NewLocationReadCloser(file.NewLocation(filePath), rdr))
if err != nil {
log.Debugf("error parsing license file %s: %v", filePath, err)
return nil
Expand All @@ -174,15 +176,15 @@ func (c *goLicenseResolver) findLicensesInFS(urlPrefix string, fsys fs.FS) ([]go
return out, err
}

func (c *goLicenseResolver) findLicensesInSource(resolver file.Resolver, globMatch string) ([]goLicense, error) {
func (c *goLicenseResolver) findLicensesInSource(ctx context.Context, scanner licenses.Scanner, resolver file.Resolver, globMatch string) ([]goLicense, error) {
var out []goLicense
locations, err := resolver.FilesByGlob(globMatch)
if err != nil {
return nil, err
}

for _, l := range locations {
parsed, err := c.parseLicenseFromLocation(l, resolver)
parsed, err := c.parseLicenseFromLocation(ctx, scanner, l, resolver)
if err != nil {
return nil, err
}
Expand All @@ -200,7 +202,7 @@ func (c *goLicenseResolver) findLicensesInSource(resolver file.Resolver, globMat
return out, nil
}

func (c *goLicenseResolver) parseLicenseFromLocation(l file.Location, resolver file.Resolver) ([]goLicense, error) {
func (c *goLicenseResolver) parseLicenseFromLocation(ctx context.Context, scanner licenses.Scanner, l file.Location, resolver file.Resolver) ([]goLicense, error) {
var out []goLicense
fileName := path.Base(l.RealPath)
if c.lowerLicenseFileNames.Has(strings.ToLower(fileName)) {
Expand All @@ -209,7 +211,7 @@ func (c *goLicenseResolver) parseLicenseFromLocation(l file.Location, resolver f
return nil, err
}
defer internal.CloseAndLogError(contents, l.RealPath)
parsed, err := licenses.Parse(contents, l)
parsed, err := licenses.Search(ctx, scanner, file.NewLocationReadCloser(l, contents))
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 7431129

Please # to comment.