diff --git a/binaries.py b/binaries.py index 0b817f63..d54220f1 100644 --- a/binaries.py +++ b/binaries.py @@ -8,15 +8,21 @@ import subprocess import sys from collections import OrderedDict +from typing import Generator, Tuple - -class environment: +class Environment: + """ + Environment class to handle the build and distribution process for different operating systems. + """ WIN = "win" LINUX = "linux" MACOS = "macos" - def __init__(self): + def __init__(self) -> None: + """ + Initialize the environment based on the BINARY_OS environment variable. + """ os_mapping = { "windows-latest": self.WIN, "ubuntu-20.04": self.LINUX, @@ -25,7 +31,13 @@ def __init__(self): self.os = os_mapping[os.getenv("BINARY_OS")] @property - def python(self): + def python(self) -> Generator[Tuple[int, str], None, None]: + """ + Generator to yield the architecture and corresponding Python executable path. + + Yields: + Generator[Tuple[int, str], None, None]: Architecture and Python executable path. + """ for arch, python in self.PYTHON_BINARIES[self.os].items(): yield arch, python @@ -49,11 +61,15 @@ def python(self): } } - def run(self, command): - """Runs the given command via subprocess.check_output. + def run(self, command: str) -> None: + """ + Runs the given command via subprocess.run. - Exits with -1 if the command wasn't successfull. + Args: + command (str): The command to run. + Exits: + Exits with -1 if the command wasn't successful. 
""" try: print(f"RUNNING: {command}") @@ -68,7 +84,7 @@ def run(self, command): print(e.output and e.output.decode('utf-8')) sys.exit(-1) - def install(self): + def install(self) -> None: """ Install required dependencies """ @@ -76,8 +92,10 @@ def install(self): self.run(f"{python} -m pip install pyinstaller") self.run(f"{python} -m pip install -r test_requirements.txt") - def dist(self): - """Runs Pyinstaller producing a binary for every platform arch.""" + def dist(self) -> None: + """ + Runs PyInstaller to produce a binary for every platform architecture. + """ for arch, python in self.python: # Build the binary @@ -102,9 +120,9 @@ def dist(self): else: self.run(f"cp {binary_path} {artifact_path}") - def test(self): + def test(self) -> None: """ - Runs tests for every available arch on the current platform. + Runs tests for every available architecture on the current platform. """ for arch, python in self.python: self.run(f"{python} -m pytest --log-level=DEBUG") @@ -116,7 +134,7 @@ def test(self): print("usage: binaries.py [install|test|dist]") sys.exit(-1) - env = environment() + env = Environment() # Runs the command in sys.argv[1] (install|test|dist) getattr(env, sys.argv[1])() diff --git a/safety/alerts/__init__.py b/safety/alerts/__init__.py index cf5d4f04..0304f182 100644 --- a/safety/alerts/__init__.py +++ b/safety/alerts/__init__.py @@ -1,7 +1,7 @@ import logging import sys import json -from typing import Any +from typing import Any, IO import click from dataclasses import dataclass @@ -16,6 +16,15 @@ @dataclass class Alert: + """ + Data class for storing alert details. + + Attributes: + report (Any): The report data. + key (str): The API key for the safetycli.com vulnerability database. + policy (Any): The policy data. + requirements_files (Any): The requirements files data. 
+ """ report: Any key: str policy: Any = None @@ -29,7 +38,16 @@ class Alert: @click.option("--policy-file", type=SafetyPolicyFile(), default='.safety-policy.yml', help="Define the policy file to be used") @click.pass_context -def alert(ctx, check_report, policy_file, key): +def alert(ctx: click.Context, check_report: IO[str], policy_file: SafetyPolicyFile, key: str) -> None: + """ + Command for processing the Safety Check JSON report. + + Args: + ctx (click.Context): The Click context object. + check_report (IO[str]): The file containing the JSON report. + policy_file (SafetyPolicyFile): The policy file to be used. + key (str): The API key for the safetycli.com vulnerability database. + """ LOG.info('alert started') LOG.info(f'check_report is using stdin: {check_report == sys.stdin}') @@ -48,5 +66,6 @@ def alert(ctx, check_report, policy_file, key): ctx.obj = Alert(report=safety_report, policy=policy_file if policy_file else {}, key=key) +# Adding subcommands for GitHub integration alert.add_command(github.github_pr) alert.add_command(github.github_issue) diff --git a/safety/alerts/github.py b/safety/alerts/github.py index ddf6b66d..e480df8e 100644 --- a/safety/alerts/github.py +++ b/safety/alerts/github.py @@ -2,6 +2,7 @@ import logging import re import sys +from typing import Any, Optional import click @@ -18,12 +19,27 @@ LOG = logging.getLogger(__name__) -def create_branch(repo, base_branch, new_branch): +def create_branch(repo: Any, base_branch: str, new_branch: str) -> None: + """ + Create a new branch in the given GitHub repository. + + Args: + repo (Any): The GitHub repository object. + base_branch (str): The name of the base branch. + new_branch (str): The name of the new branch to create. + """ ref = repo.get_git_ref("heads/" + base_branch) repo.create_git_ref(ref="refs/heads/" + new_branch, sha=ref.object.sha) -def delete_branch(repo, branch): +def delete_branch(repo: Any, branch: str) -> None: + """ + Delete a branch from the given GitHub repository. 
+ + Args: + repo (Any): The GitHub repository object. + branch (str): The name of the branch to delete. + """ ref = repo.get_git_ref(f"heads/{branch}") ref.delete() @@ -33,16 +49,23 @@ def delete_branch(repo, branch): @click.option('--base-url', help='Optional custom Base URL, if you\'re using GitHub enterprise', default=None) @click.pass_obj @utils.require_files_report -def github_pr(obj, repo, token, base_url): +def github_pr(obj: Any, repo: str, token: str, base_url: Optional[str]) -> None: """ Create a GitHub PR to fix any vulnerabilities using Safety's remediation data. - This is usally run by a GitHub action. If you're running this manually, ensure that your local repo is up to date and on HEAD - otherwise you'll see strange results. + This is usually run by a GitHub action. If you're running this manually, ensure that your local repo is up to date and on HEAD - otherwise you'll see strange results. + + Args: + obj (Any): The Click context object containing report data. + repo (str): The GitHub repository path. + token (str): The GitHub Access Token. + base_url (Optional[str]): Custom base URL for GitHub Enterprise, if applicable. """ if pygithub is None: click.secho("pygithub is not installed. Did you install Safety with GitHub support? 
Try pip install safety[github]", fg='red') sys.exit(1) + # Load alert configurations from the policy alert = obj.policy.get('alert', {}) or {} security = alert.get('security', {}) or {} config_pr = security.get('github-pr', {}) or {} @@ -55,6 +78,7 @@ def github_pr(obj, repo, token, base_url): ignore_cvss_severity_below = config_pr.get('ignore-cvss-severity-below', 0) ignore_cvss_unknown_severity = config_pr.get('ignore-cvss-unknown-severity', False) + # Authenticate with GitHub gh = pygithub.Github(token, **({"base_url": base_url} if base_url else {})) repo_name = repo repo = gh.get_repo(repo) @@ -65,8 +89,11 @@ def github_pr(obj, repo, token, base_url): # to assuming we're running under an action self_user = "web-flow" + # Collect all remediations from the report req_remediations = list(itertools.chain.from_iterable( rem.get('requirements', {}).values() for pkg_name, rem in obj.report['remediations'].items())) + + # Get all open pull requests for the repository pulls = repo.get_pulls(state='open', sort='created', base=repo.default_branch) pending_updates = set( [f"{canonicalize_name(req_rem['requirement']['name'])}{req_rem['requirement']['specifier']}" for req_rem in req_remediations]) @@ -74,6 +101,7 @@ def github_pr(obj, repo, token, base_url): created = 0 # TODO: Refactor this loop into a fn to iterate over remediations nicely + # Iterate over all requirements files and process each remediation for name, contents in obj.requirements_files.items(): raw_contents = contents contents = contents.decode('utf-8') # TODO - encoding? 
@@ -84,14 +112,17 @@ def github_pr(obj, repo, token, base_url): pkg_canonical_name: str = canonicalize_name(pkg) analyzed_spec: str = remediation['requirement']['specifier'] + # Skip remediations without a recommended version if remediation['recommended_version'] is None: LOG.debug(f"The GitHub PR alerter only currently supports remediations that have a recommended_version: {pkg}") continue # We have a single remediation that can have multiple vulnerabilities + # Find all vulnerabilities associated with the remediation vulns = [x for x in obj.report['vulnerabilities'] if x['package_name'] == pkg_canonical_name and x['analyzed_requirement']['specifier'] == analyzed_spec] + # Skip if all vulnerabilities have unknown severity and the ignore flag is set if ignore_cvss_unknown_severity and all(x['severity'] is None for x in vulns): LOG.debug("All vulnerabilities have unknown severity, and ignore_cvss_unknown_severity is set.") continue @@ -101,6 +132,7 @@ def github_pr(obj, repo, token, base_url): if vuln['severity'] is not None: highest_base_score = max(highest_base_score, (vuln['severity'].get('cvssv3', {}) or {}).get('base_score', 10)) + # Skip if none of the vulnerabilities meet the severity threshold if ignore_cvss_severity_below: at_least_one_match = False for vuln in vulns: @@ -116,6 +148,7 @@ def github_pr(obj, repo, token, base_url): for parsed_req in parsed_req_file.requirements: specs = SpecifierSet('>=0') if parsed_req.specs == SpecifierSet('') else parsed_req.specs + # Check if the requirement matches the remediation if canonicalize_name(parsed_req.name) == pkg_canonical_name and str(specs) == analyzed_spec: updated_contents = parsed_req.update_version(contents, remediation['recommended_version']) pending_updates.discard(f"{pkg_canonical_name}{analyzed_spec}") @@ -131,6 +164,7 @@ def github_pr(obj, repo, token, base_url): # 5. An existing PR exists, but it's not needed anymore (perhaps we've been updated to a later version) # 6. 
No existing PRs exist, but a branch does exist (perhaps the PR was closed but a stale branch left behind) # In any case, we only act if we've been the only committer to the branch. + # Handle various cases for existing pull requests for pr in pulls: if not pr.head.ref.startswith(branch_prefix): continue @@ -142,6 +176,7 @@ def github_pr(obj, repo, token, base_url): _, pr_pkg, pr_spec, pr_ver = pr.head.ref.split('/') except ValueError: # It's possible that something weird has manually been done, so skip that + # Skip invalid branch names LOG.debug('Found an invalid branch name on an open PR, that matches our prefix. Skipping.') continue @@ -150,7 +185,7 @@ def github_pr(obj, repo, token, base_url): if pr_pkg != pkg_canonical_name: continue - # Case 4 + # Case 4: An up-to-date PR exists if pr_pkg == pkg_canonical_name and pr_spec == analyzed_spec and pr_ver == \ remediation['recommended_version'] and pr.mergeable: LOG.debug(f"An up to date PR #{pr.number} for {pkg} was found, no action will be taken.") @@ -162,7 +197,7 @@ def github_pr(obj, repo, token, base_url): LOG.debug(f"There are other committers on the PR #{pr.number} for {pkg}. 
No further action will be taken.") continue - # Case 2 + # Case 2: An existing PR is out of date if pr_pkg == pkg_canonical_name and pr_spec == analyzed_spec and pr_ver != \ remediation['recommended_version']: LOG.debug(f"Closing stale PR #{pr.number} for {pkg} as a newer recommended version became") @@ -171,7 +206,7 @@ def github_pr(obj, repo, token, base_url): pr.edit(state='closed') delete_branch(repo, pr.head.ref) - # Case 3 + # Case 3: An existing PR is not mergeable if not pr.mergeable: LOG.debug(f"Closing PR #{pr.number} for {pkg} as it has become unmergable and we were the only committer") @@ -179,13 +214,16 @@ def github_pr(obj, repo, token, base_url): pr.edit(state='closed') delete_branch(repo, pr.head.ref) + # Skip if no changes were made if updated_contents == contents: LOG.debug(f"Couldn't update {pkg} to {remediation['recommended_version']}") continue + # Skip creation if indicated if skip_create: continue + # Create a new branch and commit the changes try: create_branch(repo, repo.default_branch, new_branch) except pygithub.GithubException as e: @@ -224,6 +262,7 @@ def github_pr(obj, repo, token, base_url): created += 1 + # Add assignees and labels to the PR for assignee in assignees: pr.add_to_assignees(assignee) @@ -250,11 +289,17 @@ def github_pr(obj, repo, token, base_url): @click.option('--base-url', help='Optional custom Base URL, if you\'re using GitHub enterprise', default=None) @click.pass_obj @utils.require_files_report # TODO: For now, it can be removed in the future to support env scans. -def github_issue(obj, repo, token, base_url): +def github_issue(obj: Any, repo: str, token: str, base_url: Optional[str]) -> None: """ Create a GitHub Issue for any vulnerabilities found using PyUp's remediation data. Normally, this is run by a GitHub action. If you're running this manually, ensure that your local repo is up to date and on HEAD - otherwise you'll see strange results. + + Args: + obj (Any): The Click context object containing report data. 
+ repo (str): The GitHub repository path. + token (str): The GitHub Access Token. + base_url (Optional[str]): Custom base URL for GitHub Enterprise, if applicable. """ LOG.info(f'github_issue') @@ -262,6 +307,7 @@ def github_issue(obj, repo, token, base_url): click.secho("pygithub is not installed. Did you install Safety with GitHub support? Try pip install safety[github]", fg='red') sys.exit(1) + # Load alert configurations from the policy alert = obj.policy.get('alert', {}) or {} security = alert.get('security', {}) or {} config_issue = security.get('github-issue', {}) or {} @@ -274,10 +320,12 @@ def github_issue(obj, repo, token, base_url): ignore_cvss_severity_below = config_issue.get('ignore-cvss-severity-below', 0) ignore_cvss_unknown_severity = config_issue.get('ignore-cvss-unknown-severity', False) + # Authenticate with GitHub gh = pygithub.Github(token, **({"base_url": base_url} if base_url else {})) repo_name = repo repo = gh.get_repo(repo) + # Get all open issues for the repository issues = list(repo.get_issues(state='open', sort='created')) ISSUE_TITLE_REGEX = re.escape(issue_prefix) + r"Security Vulnerability in (.+)" req_remediations = list(itertools.chain.from_iterable( @@ -285,6 +333,7 @@ def github_issue(obj, repo, token, base_url): created = 0 + # Iterate over all requirements files and process each remediation for name, contents in obj.requirements_files.items(): raw_contents = contents contents = contents.decode('utf-8') # TODO - encoding? 
@@ -295,13 +344,16 @@ def github_issue(obj, repo, token, base_url): pkg_canonical_name: str = canonicalize_name(pkg) analyzed_spec: str = remediation['requirement']['specifier'] + # Skip remediations without a recommended version if remediation['recommended_version'] is None: LOG.debug(f"The GitHub Issue alerter only currently supports remediations that have a recommended_version: {pkg}") continue # We have a single remediation that can have multiple vulnerabilities + # Find all vulnerabilities associated with the remediation vulns = [x for x in obj.report['vulnerabilities'] if x['package_name'] == pkg_canonical_name and x['analyzed_requirement']['specifier'] == analyzed_spec] + # Skip if all vulnerabilities have unknown severity and the ignore flag is set if ignore_cvss_unknown_severity and all(x['severity'] is None for x in vulns): LOG.debug("All vulnerabilities have unknown severity, and ignore_cvss_unknown_severity is set.") continue @@ -311,6 +363,7 @@ def github_issue(obj, repo, token, base_url): if vuln['severity'] is not None: highest_base_score = max(highest_base_score, (vuln['severity'].get('cvssv3', {}) or {}).get('base_score', 10)) + # Skip if none of the vulnerabilities meet the severity threshold if ignore_cvss_severity_below: at_least_one_match = False for vuln in vulns: @@ -337,15 +390,18 @@ def github_issue(obj, repo, token, base_url): break # For now, we just skip issues if they already exist - we don't try and update them. 
+ # Skip if an issue already exists for this remediation if skip: LOG.debug( f"An issue already exists for {pkg}{analyzed_spec} - skipping") continue + # Create a new GitHub issue pr = repo.create_issue(title=issue_prefix + utils.generate_issue_title(pkg, remediation), body=utils.generate_issue_body(pkg, remediation, vulns, api_key=obj.key)) created += 1 LOG.debug(f"Created issue to update {pkg}") + # Add assignees and labels to the issue for assignee in assignees: pr.add_to_assignees(assignee) diff --git a/safety/alerts/requirements.py b/safety/alerts/requirements.py index 400b7050..a622233c 100644 --- a/safety/alerts/requirements.py +++ b/safety/alerts/requirements.py @@ -3,6 +3,7 @@ from packaging.version import parse as parse_version from packaging.specifiers import SpecifierSet import requests +from typing import Any, Optional, Generator, Tuple, List from datetime import datetime from dparse import parse, parser, updater, filetypes @@ -11,7 +12,15 @@ class RequirementFile(object): - def __init__(self, path, content, sha=None): + """ + Class representing a requirements file with its content and metadata. + + Attributes: + path (str): The file path. + content (str): The content of the file. + sha (Optional[str]): The SHA of the file. + """ + def __init__(self, path: str, content: str, sha: Optional[str] = None): self.path = path self.content = content self.sha = sha @@ -22,7 +31,7 @@ def __init__(self, path, content, sha=None): self.is_pipfile_lock = False self.is_setup_cfg = False - def __str__(self): + def __str__(self) -> str: return "RequirementFile(path='{path}', sha='{sha}', content='{content}')".format( path=self.path, content=self.content[:30] + "[truncated]" if len(self.content) > 30 else self.content, @@ -30,52 +39,109 @@ def __str__(self): ) @property - def is_valid(self): + def is_valid(self) -> Optional[bool]: + """ + Checks if the requirements file is valid by parsing it. + + Returns: + bool: True if the file is valid, False otherwise. 
+ """ if self._is_valid is None: self._parse() return self._is_valid @property - def requirements(self): + def requirements(self) -> Optional[List]: + """ + Returns the list of requirements parsed from the file. + + Returns: + List: The list of requirements. + """ if not self._requirements: self._parse() return self._requirements @property - def other_files(self): + def other_files(self) -> Optional[List]: + """ + Returns the list of other files resolved from the requirements file. + + Returns: + List: The list of other files. + """ if not self._other_files: self._parse() return self._other_files @staticmethod - def parse_index_server(line): + def parse_index_server(line: str) -> Optional[str]: + """ + Parses the index server from a given line. + + Args: + line (str): The line to parse. + + Returns: + str: The parsed index server. + """ return parser.Parser.parse_index_server(line) - def _hash_parser(self, line): + def _hash_parser(self, line: str) -> Optional[Tuple[str, List[str]]]: + """ + Parses the hashes from a given line. + + Args: + line (str): The line to parse. + + Returns: + List: The list of parsed hashes. + """ return parser.Parser.parse_hashes(line) - def _parse_requirements_txt(self): + def _parse_requirements_txt(self) -> None: + """ + Parses the requirements.txt file format. + """ self.parse_dependencies(filetypes.requirements_txt) - def _parse_conda_yml(self): + def _parse_conda_yml(self) -> None: + """ + Parses the conda.yml file format. + """ self.parse_dependencies(filetypes.conda_yml) - def _parse_tox_ini(self): + def _parse_tox_ini(self) -> None: + """ + Parses the tox.ini file format. + """ self.parse_dependencies(filetypes.tox_ini) - def _parse_pipfile(self): + def _parse_pipfile(self) -> None: + """ + Parses the Pipfile format. + """ self.parse_dependencies(filetypes.pipfile) self.is_pipfile = True - def _parse_pipfile_lock(self): + def _parse_pipfile_lock(self) -> None: + """ + Parses the Pipfile.lock format. 
+ """ self.parse_dependencies(filetypes.pipfile_lock) self.is_pipfile_lock = True - def _parse_setup_cfg(self): + def _parse_setup_cfg(self) -> None: + """ + Parses the setup.cfg format. + """ self.parse_dependencies(filetypes.setup_cfg) self.is_setup_cfg = True - def _parse(self): + def _parse(self) -> None: + """ + Parses the requirements file to extract dependencies and other files. + """ self._requirements, self._other_files = [], [] if self.path.endswith('.yml') or self.path.endswith(".yaml"): self._parse_conda_yml() @@ -91,7 +157,13 @@ def _parse(self): self._parse_requirements_txt() self._is_valid = len(self._requirements) > 0 or len(self._other_files) > 0 - def parse_dependencies(self, file_type): + def parse_dependencies(self, file_type: str) -> None: + """ + Parses the dependencies from the content based on the file type. + + Args: + file_type (str): The type of the file. + """ result = parse( self.content, path=self.path, @@ -118,17 +190,47 @@ def parse_dependencies(self, file_type): self._requirements.append(req) self._other_files = result.resolved_files - def iter_lines(self, lineno=0): + def iter_lines(self, lineno: int = 0) -> Generator[str, None, None]: + """ + Iterates over lines in the content starting from a specific line number. + + Args: + lineno (int): The line number to start from. + + Yields: + str: The next line in the content. + """ for line in self.content.splitlines()[lineno:]: yield line @classmethod - def resolve_file(cls, file_path, line): + def resolve_file(cls, file_path: str, line: str) -> str: + """ + Resolves a file path from a given line. + + Args: + file_path (str): The file path to resolve. + line (str): The line containing the file path. + + Returns: + str: The resolved file path. + """ return parser.Parser.resolve_file(file_path, line) class Requirement(object): - def __init__(self, name, specs, line, lineno, extras, file_type): + """ + Class representing a single requirement. 
+ + Attributes: + name (str): The name of the requirement. + specs (SpecifierSet): The version specifiers for the requirement. + line (str): The line containing the requirement. + lineno (int): The line number of the requirement. + extras (List): The extras for the requirement. + file_type (str): The type of the file containing the requirement. + """ + def __init__(self, name: str, specs: SpecifierSet, line: str, lineno: int, extras: List, file_type: str): self.name = name self.key = name.lower() self.specs = specs @@ -149,6 +251,7 @@ def __init__(self, name, specs, line, lineno, extras, file_type): self._is_insecure = None self._changelog = None + # Convert compatible releases to a range of versions if len(self.specs._specs) == 1 and next(iter(self.specs._specs))._spec[0] == "~=": # convert compatible releases to something more easily consumed, # e.g. '~=1.2.3' is equivalent to '>=1.2.3,<1.3.0', while '~=1.2' @@ -161,43 +264,76 @@ def __init__(self, name, specs, line, lineno, extras, file_type): self.specs = SpecifierSet('>=%s,<%s' % (min_version, max_version)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return ( isinstance(other, Requirement) and self.hashCmp == other.hashCmp ) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def __str__(self): + def __str__(self) -> str: return "Requirement.parse({line}, {lineno})".format(line=self.line, lineno=self.lineno) - def __repr__(self): + def __repr__(self) -> str: return self.__str__() @property - def is_pinned(self): + def is_pinned(self) -> bool: + """ + Checks if the requirement is pinned to a specific version. + + Returns: + bool: True if pinned, False otherwise. + """ if len(self.specs._specs) == 1 and next(iter(self.specs._specs))._spec[0] == "==": return True return False @property - def is_open_ranged(self): + def is_open_ranged(self) -> bool: + """ + Checks if the requirement has an open range of versions. 
+ + Returns: + bool: True if open ranged, False otherwise. + """ if len(self.specs._specs) == 1 and next(iter(self.specs._specs))._spec[0] == ">=": return True return False @property - def is_ranged(self): + def is_ranged(self) -> bool: + """ + Checks if the requirement has a range of versions. + + Returns: + bool: True if ranged, False otherwise. + """ return len(self.specs._specs) >= 1 and not self.is_pinned @property - def is_loose(self): + def is_loose(self) -> bool: + """ + Checks if the requirement has no version specifiers. + + Returns: + bool: True if loose, False otherwise. + """ return len(self.specs._specs) == 0 @staticmethod - def convert_semver(version): + def convert_semver(version: str) -> dict: + """ + Converts a version string to a semantic version dictionary. + + Args: + version (str): The version string. + + Returns: + dict: The semantic version dictionary. + """ semver = {'major': 0, "minor": 0, "patch": 0} version = version.split(".") # don't be overly clever here. repitition makes it more readable and works exactly how @@ -211,7 +347,13 @@ def convert_semver(version): return semver @property - def can_update_semver(self): + def can_update_semver(self) -> bool: + """ + Checks if the requirement can be updated based on semantic versioning rules. + + Returns: + bool: True if it can be updated, False otherwise. + """ # return early if there's no update filter set if "pyup: update" not in self.line: return True @@ -229,6 +371,12 @@ def can_update_semver(self): @property def filter(self): + """ + Returns the filter for the requirement if specified. + + Returns: + Optional[SpecifierSet]: The filter specifier set, or None if not specified. + """ rqfilter = False if "rq.filter:" in self.line: rqfilter = self.line.split("rq.filter:")[1].strip().split("#")[0] @@ -255,7 +403,13 @@ def filter(self): return False @property - def version(self): + def version(self) -> Optional[str]: + """ + Returns the current version of the requirement. 
+ + Returns: + Optional[str]: The current version, or None if not pinned. + """ if self.is_pinned: return next(iter(self.specs._specs))._spec[1] @@ -270,7 +424,16 @@ def version(self): prereleases=self.prereleases ) - def get_hashes(self, version): + def get_hashes(self, version: str) -> List: + """ + Retrieves the hashes for a specific version from PyPI. + + Args: + version (str): The version to retrieve hashes for. + + Returns: + List: A list of hashes for the specified version. + """ r = requests.get('https://pypi.org/pypi/{name}/{version}/json'.format( name=self.key, version=version @@ -284,7 +447,18 @@ def get_hashes(self, version): hashes.append({"hash": sha256, "method": "sha256"}) return hashes - def update_version(self, content, version, update_hashes=True): + def update_version(self, content: str, version: str, update_hashes: bool = True) -> str: + """ + Updates the version of the requirement in the content. + + Args: + content (str): The original content. + version (str): The new version to update to. + update_hashes (bool): Whether to update the hashes as well. + + Returns: + str: The updated content. + """ if self.file_type == filetypes.tox_ini: updater_class = updater.ToxINIUpdater elif self.file_type == filetypes.conda_yml: @@ -322,7 +496,18 @@ def update_version(self, content, version, update_hashes=True): ) @classmethod - def parse(cls, s, lineno, file_type=filetypes.requirements_txt): + def parse(cls, s: str, lineno: int, file_type: str = filetypes.requirements_txt) -> 'Requirement': + """ + Parses a requirement from a line of text. + + Args: + s (str): The line of text. + lineno (int): The line number. + file_type (str): The type of the file containing the requirement. + + Returns: + Requirement: The parsed requirement. + """ # setuptools requires a space before the comment. If this isn't the case, add it. 
if "\t#" in s: parsed, = parse_requirements(s.replace("\t#", "\t #")) diff --git a/safety/alerts/utils.py b/safety/alerts/utils.py index 9fd8b15d..ccfbcc34 100644 --- a/safety/alerts/utils.py +++ b/safety/alerts/utils.py @@ -1,9 +1,8 @@ import hashlib import os import sys - from functools import wraps -from typing import Optional +from typing import Optional, List, Dict, Any from packaging.version import parse as parse_version from packaging.specifiers import SpecifierSet @@ -25,7 +24,16 @@ import requests -def highest_base_score(vulns): +def highest_base_score(vulns: List[Dict[str, Any]]) -> float: + """ + Calculates the highest CVSS base score from a list of vulnerabilities. + + Args: + vulns (List[Dict[str, Any]]): The list of vulnerabilities. + + Returns: + float: The highest CVSS base score. + """ highest_base_score = 0 for vuln in vulns: if vuln['severity'] is not None: @@ -34,15 +42,44 @@ def highest_base_score(vulns): return highest_base_score -def generate_branch_name(pkg: str, remediation): +def generate_branch_name(pkg: str, remediation: Dict[str, Any]) -> str: + """ + Generates a branch name for a given package and remediation. + + Args: + pkg (str): The package name. + remediation (Dict[str, Any]): The remediation data. + + Returns: + str: The generated branch name. + """ return f"{pkg}/{remediation['requirement']['specifier']}/{remediation['recommended_version']}" -def generate_issue_title(pkg, remediation): +def generate_issue_title(pkg: str, remediation: Dict[str, Any]) -> str: + """ + Generates an issue title for a given package and remediation. + + Args: + pkg (str): The package name. + remediation (Dict[str, Any]): The remediation data. + + Returns: + str: The generated issue title. + """ return f"Security Vulnerability in {pkg}{remediation['requirement']['specifier']}" -def get_hint(remediation): +def get_hint(remediation: Dict[str, Any]) -> str: + """ + Generates a hint for a given remediation. 
+ + Args: + remediation (Dict[str, Any]): The remediation data. + + Returns: + str: The generated hint. + """ pinned = is_pinned_requirement(SpecifierSet(remediation['requirement']['specifier'])) hint = '' @@ -54,13 +91,36 @@ def get_hint(remediation): return hint -def generate_title(pkg, remediation, vulns): +def generate_title(pkg: str, remediation: Dict[str, Any], vulns: List[Dict[str, Any]]) -> str: + """ + Generates a title for a pull request or issue. + + Args: + pkg (str): The package name. + remediation (Dict[str, Any]): The remediation data. + vulns (List[Dict[str, Any]]): The list of vulnerabilities. + + Returns: + str: The generated title. + """ suffix = "y" if len(vulns) == 1 else "ies" from_dependency = remediation['version'] if remediation['version'] else remediation['requirement']['specifier'] return f"Update {pkg} from {from_dependency} to {remediation['recommended_version']} to fix {len(vulns)} vulnerabilit{suffix}" -def generate_body(pkg, remediation, vulns, *, api_key): +def generate_body(pkg: str, remediation: Dict[str, Any], vulns: List[Dict[str, Any]], *, api_key: str) -> Optional[str]: + """ + Generates the body content for a pull request. + + Args: + pkg (str): The package name. + remediation (Dict[str, Any]): The remediation data. + vulns (List[Dict[str, Any]]): The list of vulnerabilities. + api_key (str): The API key for fetching changelog data. + + Returns: + str: The generated body content. + """ changelog = fetch_changelog(pkg, remediation['version'], remediation['recommended_version'], api_key=api_key, from_spec=remediation.get('requirement', {}).get('specifier', None)) @@ -84,7 +144,19 @@ def generate_body(pkg, remediation, vulns, *, api_key): return template.render(context) -def generate_issue_body(pkg, remediation, vulns, *, api_key): +def generate_issue_body(pkg: str, remediation: Dict[str, Any], vulns: List[Dict[str, Any]], *, api_key: str) -> Optional[str]: + """ + Generates the body content for an issue. 
+ + Args: + pkg (str): The package name. + remediation (Dict[str, Any]): The remediation data. + vulns (List[Dict[str, Any]]): The list of vulnerabilities. + api_key (str): The API key for fetching changelog data. + + Returns: + str: The generated body content. + """ changelog = fetch_changelog(pkg, remediation['version'], remediation['recommended_version'], api_key=api_key, from_spec=remediation.get('requirement', {}).get('specifier', None)) @@ -108,17 +180,49 @@ def generate_issue_body(pkg, remediation, vulns, *, api_key): return template.render(context) -def generate_commit_message(pkg, remediation): +def generate_commit_message(pkg: str, remediation: Dict[str, Any]) -> str: + """ + Generates a commit message for a given package and remediation. + + Args: + pkg (str): The package name. + remediation (Dict[str, Any]): The remediation data. + + Returns: + str: The generated commit message. + """ from_dependency = remediation['version'] if remediation['version'] else remediation['requirement']['specifier'] return f"Update {pkg} from {from_dependency} to {remediation['recommended_version']}" -def git_sha1(raw_contents): +def git_sha1(raw_contents: bytes) -> str: + """ + Calculates the SHA-1 hash of the given raw contents. + + Args: + raw_contents (bytes): The raw contents to hash. + + Returns: + str: The SHA-1 hash. + """ return hashlib.sha1(b"blob " + str(len(raw_contents)).encode('ascii') + b"\0" + raw_contents).hexdigest() -def fetch_changelog(package, from_version: Optional[str], to_version: str, *, api_key, from_spec=None): +def fetch_changelog(package: str, from_version: Optional[str], to_version: str, *, api_key: str, from_spec: Optional[str] = None) -> Dict[str, Any]: + """ + Fetches the changelog for a package from a specified version to another version. + + Args: + package (str): The package name. + from_version (Optional[str]): The starting version. + to_version (str): The ending version. + api_key (str): The API key for fetching changelog data. 
+ from_spec (Optional[str]): The specifier for the starting version. + + Returns: + Dict[str, Any]: The fetched changelog data. + """ to_version = parse_version(to_version) if from_version: @@ -154,6 +258,15 @@ def fetch_changelog(package, from_version: Optional[str], to_version: str, *, ap def cvss3_score_to_label(score: float) -> Optional[str]: + """ + Converts a CVSS v3 score to a severity label. + + Args: + score (float): The CVSS v3 score. + + Returns: + Optional[str]: The severity label. + """ if 0.1 <= score <= 3.9: return 'low' elif 4.0 <= score <= 6.9: @@ -168,7 +281,18 @@ def cvss3_score_to_label(score: float) -> Optional[str]: def require_files_report(func): @wraps(func) - def inner(obj, *args, **kwargs): + def inner(obj: Any, *args: Any, **kwargs: Any) -> Any: + """ + Decorator that ensures a report is generated against a file. + + Args: + obj (Any): The object containing the report. + *args (Any): Additional arguments. + **kwargs (Any): Additional keyword arguments. + + Returns: + Any: The result of the decorated function. + """ if obj.report['report_meta']['scan_target'] != "files": click.secho("This report was generated against an environment, but this alert command requires " "a scan report that was generated against a file. 
To learn more about the " diff --git a/safety/auth/cli.py b/safety/auth/cli.py index 187f4878..8a0a2c1a 100644 --- a/safety/auth/cli.py +++ b/safety/auth/cli.py @@ -40,26 +40,37 @@ CMD_LOGOUT_NAME = "logout" DEFAULT_CMD = CMD_LOGIN_NAME -@auth_app.callback(invoke_without_command=True, - cls=SafetyCLISubGroup, - help=CLI_AUTH_COMMAND_HELP, +@auth_app.callback(invoke_without_command=True, + cls=SafetyCLISubGroup, + help=CLI_AUTH_COMMAND_HELP, epilog=DEFAULT_EPILOG, - context_settings={"allow_extra_args": True, + context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) @pass_safety_cli_obj -def auth(ctx: typer.Context): +def auth(ctx: typer.Context) -> None: """ - Authenticate Safety CLI with your account + Authenticate Safety CLI with your account. + + Args: + ctx (typer.Context): The Typer context object. """ LOG.info('auth started') + # If no subcommand is invoked, forward to the default command if not ctx.invoked_subcommand: - default_command = get_command_for(name=DEFAULT_CMD, + default_command = get_command_for(name=DEFAULT_CMD, typer_instance=auth_app) return ctx.forward(default_command) -def fail_if_authenticated(ctx, with_msg: str): +def fail_if_authenticated(ctx: typer.Context, with_msg: str) -> None: + """ + Exits the command if the user is already authenticated. + + Args: + ctx (typer.Context): The Typer context object. + with_msg (str): The message to display if authenticated. + """ info = get_auth_info(ctx) if info: @@ -72,10 +83,26 @@ def fail_if_authenticated(ctx, with_msg: str): sys.exit(0) def render_email_note(auth: Auth) -> str: + """ + Renders a note indicating whether email verification is required. + + Args: + auth (Auth): The Auth object. + + Returns: + str: The rendered email note. 
+ """ return "" if auth.email_verified else "[red](email verification required)[/red]" def render_successful_login(auth: Auth, - organization: Optional[str] = None): + organization: Optional[str] = None) -> None: + """ + Renders a message indicating a successful login. + + Args: + auth (Auth): The Auth object. + organization (Optional[str]): The organization name. + """ DEFAULT = "--" name = auth.name if auth.name else DEFAULT email = auth.email if auth.email else DEFAULT @@ -88,46 +115,52 @@ def render_successful_login(auth: Auth, details = [f"[green][bold]Account:[/bold] {email}[/green] {email_note}"] if organization: - details.insert(0, + details.insert(0, "[green][bold]Organization:[/bold] " \ f"{organization}[green]") for msg in details: - console.print(Padding(msg, (0, 0, 0, 1)), emoji=True) + console.print(Padding(msg, (0, 0, 0, 1)), emoji=True) @auth_app.command(name=CMD_LOGIN_NAME, help=CLI_AUTH_LOGIN_HELP) -def login(ctx: typer.Context, headless: bool = False): +def login(ctx: typer.Context, headless: bool = False) -> None: """ Authenticate Safety CLI with your safetycli.com account using your default browser. + + Args: + ctx (typer.Context): The Typer context object. + headless (bool): Whether to run in headless mode. """ LOG.info('login started') + # Check if the user is already authenticated fail_if_authenticated(ctx, with_msg=MSG_FAIL_LOGIN_AUTHED) console.print() - + info = None - + brief_msg: str = "Redirecting your browser to log in; once authenticated, " \ "return here to start using Safety" - - if ctx.obj.auth.org: + + if ctx.obj.auth.org: console.print(f"Logging into [bold]{ctx.obj.auth.org.name}[/bold] " \ "organization.") - + if headless: brief_msg = "Running in headless mode. 
Please copy and open the following URL in a browser" - + # Get authorization data and generate the authorization URL uri, initial_state = get_authorization_data(client=ctx.obj.auth.client, code_verifier=ctx.obj.auth.code_verifier, organization=ctx.obj.auth.org, headless=headless) click.secho(brief_msg) click.echo() + # Process the browser callback to complete the authentication info = process_browser_callback(uri, initial_state=initial_state, ctx=ctx, headless=headless) - + if info: if info.get("email", None): @@ -161,21 +194,25 @@ def login(ctx: typer.Context, headless: bool = False): msg += " Please try again, or use [bold]`safety auth -–help`[/bold] " \ "for more information[/red]" - + console.print(msg, emoji=True) @auth_app.command(name=CMD_LOGOUT_NAME, help=CLI_AUTH_LOGOUT_HELP) -def logout(ctx: typer.Context): +def logout(ctx: typer.Context) -> None: """ Log out of your current session. + + Args: + ctx (typer.Context): The Typer context object. """ LOG.info('logout started') id_token = get_token('id_token') - + msg = MSG_NON_AUTHENTICATED - + if id_token: + # Clean the session if an ID token is found if clean_session(ctx.obj.auth.client): msg = MSG_LOGOUT_DONE else: @@ -190,10 +227,15 @@ def logout(ctx: typer.Context): "authentication is made.") @click.option("--login-timeout", "-w", type=int, default=600, help="Max time allowed to wait for an authentication.") -def status(ctx: typer.Context, ensure_auth: bool = False, - login_timeout: int = 600): +def status(ctx: typer.Context, ensure_auth: bool = False, + login_timeout: int = 600) -> None: """ Display Safety CLI's current authentication status. + + Args: + ctx (typer.Context): The Typer context object. + ensure_auth (bool): Whether to keep running until authentication is made. + login_timeout (int): Max time allowed to wait for authentication. 
""" LOG.info('status started') current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -211,21 +253,22 @@ def status(ctx: typer.Context, ensure_auth: bool = False, verified = is_email_verified(info) email_status = " [red](email not verified)[/red]" if not verified else "" - console.print(f'[green]Authenticated as {info["email"]}[/green]{email_status}') + console.print(f'[green]Authenticated as {info["email"]}[/green]{email_status}') elif ensure_auth: console.print('Safety is not authenticated. Launching default browser to log in') console.print() uri, initial_state = get_authorization_data(client=ctx.obj.auth.client, code_verifier=ctx.obj.auth.code_verifier, organization=ctx.obj.auth.org, ensure_auth=ensure_auth) - - info = process_browser_callback(uri, initial_state=initial_state, - timeout=login_timeout, ctx=ctx) - + + # Process the browser callback to complete the authentication + info = process_browser_callback(uri, initial_state=initial_state, + timeout=login_timeout, ctx=ctx) + if not info: console.print(f'[red]Timeout error ({login_timeout} seconds): not successfully authenticated without the timeout period.[/red]') sys.exit(1) - + organization = None if ctx.obj.auth.org and ctx.obj.auth.org.name: organization = ctx.obj.auth.org.name @@ -238,21 +281,27 @@ def status(ctx: typer.Context, ensure_auth: bool = False, @auth_app.command(name=CMD_REGISTER_NAME) -def register(ctx: typer.Context): +def register(ctx: typer.Context) -> None: """ Create a new user account for the safetycli.com service. + + Args: + ctx (typer.Context): The Typer context object. """ LOG.info('register started') + # Check if the user is already authenticated fail_if_authenticated(ctx, with_msg=MSG_FAIL_REGISTER_AUTHED) + # Get authorization data and generate the registration URL uri, initial_state = get_authorization_data(client=ctx.obj.auth.client, code_verifier=ctx.obj.auth.code_verifier, sign_up=True) - + console.print("Redirecting your browser to register for a free account. 
Once registered, return here to start using Safety.") console.print() + # Process the browser callback to complete the registration info = process_browser_callback(uri, initial_state=initial_state, ctx=ctx) @@ -261,4 +310,4 @@ def register(ctx: typer.Context): console.print() else: console.print('[red]Unable to register in this time, try again.[/red]') - + diff --git a/safety/auth/cli_utils.py b/safety/auth/cli_utils.py index 098bde30..cc0cca43 100644 --- a/safety/auth/cli_utils.py +++ b/safety/auth/cli_utils.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Optional +from typing import Dict, Optional, Tuple, Any, Callable import click @@ -21,9 +21,20 @@ LOG = logging.getLogger(__name__) -def build_client_session(api_key=None, proxies=None, headers=None): - kwargs = {} +def build_client_session(api_key: Optional[str] = None, proxies: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None) -> Tuple[SafetyAuthSession, Dict[str, Any]]: + """ + Builds and configures the client session for authentication. + Args: + api_key (Optional[str]): The API key for authentication. + proxies (Optional[Dict[str, str]]): Proxy configuration. + headers (Optional[Dict[str, str]]): Additional headers. + + Returns: + Tuple[SafetyAuthSession, Dict[str, Any]]: The configured client session and OpenID configuration. 
+ """ + + kwargs = {} target_proxies = proxies # Global proxy defined in the config.ini @@ -31,21 +42,21 @@ def build_client_session(api_key=None, proxies=None, headers=None): if not proxies: target_proxies = proxy_config - + def update_token(tokens, **kwargs): save_auth_config(access_token=tokens['access_token'], id_token=tokens['id_token'], refresh_token=tokens['refresh_token']) load_auth_session(click_ctx=click.get_current_context(silent=True)) - client_session = SafetyAuthSession(client_id=CLIENT_ID, + client_session = SafetyAuthSession(client_id=CLIENT_ID, code_challenge_method='S256', - redirect_uri=get_redirect_url(), + redirect_uri=get_redirect_url(), update_token=update_token, scope='openid email profile offline_access', **kwargs) - + client_session.mount("https://pyup.io/static-s3/", S3PresignedAdapter()) - + client_session.proxy_required = proxy_required client_session.proxy_timeout = proxy_timeout client_session.proxies = target_proxies @@ -57,7 +68,7 @@ def update_token(tokens, **kwargs): LOG.debug('Unable to load the openID config: %s', e) openid_config = {} - client_session.metadata["token_endpoint"] = openid_config.get("token_endpoint", + client_session.metadata["token_endpoint"] = openid_config.get("token_endpoint", None) if api_key: @@ -70,7 +81,13 @@ def update_token(tokens, **kwargs): return client_session, openid_config -def load_auth_session(click_ctx): +def load_auth_session(click_ctx: click.Context) -> None: + """ + Loads the authentication session from the context. + + Args: + click_ctx (click.Context): The Click context object. + """ if not click_ctx: LOG.warn("Click context is needed to be able to load the Auth data.") return @@ -94,63 +111,103 @@ def load_auth_session(click_ctx): print(e) clean_session(client) -def proxy_options(func): +def proxy_options(func: Callable) -> Callable: """ + Decorator that defines proxy options for Click commands. 
+ Options defined per command, this will override the proxy settings defined in the config.ini file. + + Args: + func (Callable): The Click command function. + + Returns: + Callable: The wrapped Click command function with proxy options. """ - func = click.option("--proxy-protocol", + func = click.option("--proxy-protocol", type=click.Choice(['http', 'https']), default='https', cls=DependentOption, required_options=['proxy_host'], help=CLI_PROXY_PROTOCOL_HELP)(func) - func = click.option("--proxy-port", multiple=False, type=int, default=80, - cls=DependentOption, required_options=['proxy_host'], + func = click.option("--proxy-port", multiple=False, type=int, default=80, + cls=DependentOption, required_options=['proxy_host'], help=CLI_PROXY_PORT_HELP)(func) - func = click.option("--proxy-host", multiple=False, type=str, default=None, + func = click.option("--proxy-host", multiple=False, type=str, default=None, help=CLI_PROXY_HOST_HELP)(func) return func -def auth_options(stage=True): +def auth_options(stage: bool = True) -> Callable: + """ + Decorator that defines authentication options for Click commands. - def decorator(func): + Args: + stage (bool): Whether to include the stage option. + + Returns: + Callable: The decorator function. + """ + def decorator(func: Callable) -> Callable: func = click.option("--key", default=None, envvar="SAFETY_API_KEY", help=CLI_KEY_HELP)(func) if stage: - func = click.option("--stage", default=None, envvar="SAFETY_STAGE", + func = click.option("--stage", default=None, envvar="SAFETY_STAGE", help=CLI_STAGE_HELP)(func) - + return func - + return decorator -def inject_session(func): +def inject_session(func: Callable) -> Callable: """ + Decorator that injects a session object into Click commands. + Builds the session object to be used in each command. + + Args: + func (Callable): The Click command function. + + Returns: + Callable: The wrapped Click command function with session injection. 
""" @wraps(func) - def inner(ctx, proxy_protocol: Optional[str] = None, + def inner(ctx: click.Context, proxy_protocol: Optional[str] = None, proxy_host: Optional[str] = None, proxy_port: Optional[str] = None, - key: Optional[str] = None, - stage: Optional[Stage] = None, *args, **kwargs): - + key: Optional[str] = None, + stage: Optional[Stage] = None, *args, **kwargs) -> Any: + """ + Inner function that performs the session injection. + + Args: + ctx (click.Context): The Click context object. + proxy_protocol (Optional[str]): The proxy protocol. + proxy_host (Optional[str]): The proxy host. + proxy_port (Optional[int]): The proxy port. + key (Optional[str]): The API key. + stage (Optional[Stage]): The stage. + *args (Any): Additional arguments. + **kwargs (Any): Additional keyword arguments. + + Returns: + Any: The result of the decorated function. + """ + if ctx.invoked_subcommand == "configure": return - + org: Optional[Organization] = get_organization() - + if not stage: host_stage = get_host_config(key_name="stage") stage = host_stage if host_stage else Stage.development - proxy_config: Optional[Dict[str, str]] = get_proxy_dict(proxy_protocol, + proxy_config: Optional[Dict[str, str]] = get_proxy_dict(proxy_protocol, proxy_host, proxy_port) - client_session, openid_config = build_client_session(api_key=key, + client_session, openid_config = build_client_session(api_key=key, proxies=proxy_config) keys = get_keys(client_session, openid_config) diff --git a/safety/auth/main.py b/safety/auth/main.py index c96542a2..fdb9c225 100644 --- a/safety/auth/main.py +++ b/safety/auth/main.py @@ -1,25 +1,36 @@ import configparser -import json from typing import Any, Dict, Optional, Tuple, Union -from urllib.parse import urlencode from authlib.oidc.core import CodeIDToken from authlib.jose import jwt from authlib.jose.errors import ExpiredTokenError from safety.auth.models import Organization -from safety.auth.constants import AUTH_SERVER_URL, CLI_AUTH_LOGOUT, CLI_CALLBACK, 
AUTH_CONFIG_USER, CLI_AUTH +from safety.auth.constants import CLI_AUTH_LOGOUT, CLI_CALLBACK, AUTH_CONFIG_USER, CLI_AUTH from safety.constants import CONFIG -from safety.errors import NotVerifiedEmailError from safety.scan.util import Stage from safety.util import get_proxy_dict def get_authorization_data(client, code_verifier: str, - organization: Optional[Organization] = None, + organization: Optional[Organization] = None, sign_up: bool = False, ensure_auth: bool = False, headless: bool = False) -> Tuple[str, str]: - + """ + Generate the authorization URL for the authentication process. + + Args: + client: The authentication client. + code_verifier (str): The code verifier for the PKCE flow. + organization (Optional[Organization]): The organization to authenticate with. + sign_up (bool): Whether the URL is for sign-up. + ensure_auth (bool): Whether to ensure authentication. + headless (bool): Whether to run in headless mode. + + Returns: + Tuple[str, str]: The authorization URL and initial state. + """ + kwargs = {'sign_up': sign_up, 'locale': 'en', 'ensure_auth': ensure_auth, 'headless': headless} if organization: kwargs['organization'] = organization.id @@ -29,12 +40,33 @@ def get_authorization_data(client, code_verifier: str, **kwargs) def get_logout_url(id_token: str) -> str: + """ + Generate the logout URL. + + Args: + id_token (str): The ID token. + + Returns: + str: The logout URL. + """ return f'{CLI_AUTH_LOGOUT}?id_token={id_token}' def get_redirect_url() -> str: + """ + Get the redirect URL for the authentication callback. + + Returns: + str: The redirect URL. + """ return CLI_CALLBACK def get_organization() -> Optional[Organization]: + """ + Retrieve the organization configuration. + + Returns: + Optional[Organization]: The organization object, or None if not configured. 
+ """ config = configparser.ConfigParser() config.read(CONFIG) @@ -53,9 +85,18 @@ def get_organization() -> Optional[Organization]: return org -def get_auth_info(ctx): +def get_auth_info(ctx) -> Optional[Dict]: + """ + Retrieve the authentication information. + + Args: + ctx: The context object containing authentication data. + + Returns: + Optional[Dict]: The authentication information, or None if not authenticated. + """ from safety.auth.utils import is_email_verified - + info = None if ctx.obj.auth.client.token: try: @@ -67,7 +108,7 @@ def get_auth_info(ctx): verified = is_email_verified(user_info) if verified: - # refresh only if needed + # refresh only if needed raise ExpiredTokenError except ExpiredTokenError as e: @@ -80,10 +121,21 @@ def get_auth_info(ctx): clean_session(ctx.obj.auth.client) except Exception as _g: clean_session(ctx.obj.auth.client) - + return info -def get_token_data(token, keys, silent_if_expired=False) -> Optional[Dict]: +def get_token_data(token: str, keys: Any, silent_if_expired: bool = False) -> Optional[Dict]: + """ + Decode and validate the token data. + + Args: + token (str): The token to decode. + keys (Any): The keys to use for decoding. + silent_if_expired (bool): Whether to silently ignore expired tokens. + + Returns: + Optional[Dict]: The decoded token data, or None if invalid. + """ claims = jwt.decode(token, keys, claims_cls=CodeIDToken) try: claims.validate() @@ -93,10 +145,18 @@ def get_token_data(token, keys, silent_if_expired=False) -> Optional[Dict]: return claims -def get_token(name='access_token') -> Optional[str]: +def get_token(name: str = 'access_token') -> Optional[str]: """" + Retrieve a token from the local authentication configuration. + This returns tokens saved in the local auth configuration. There are two types of tokens: access_token and id_token + + Args: + name (str): The name of the token to retrieve. + + Returns: + Optional[str]: The token value, or None if not found. 
""" config = configparser.ConfigParser() config.read(AUTH_CONFIG_USER) @@ -108,13 +168,22 @@ def get_token(name='access_token') -> Optional[str]: return None -def get_host_config(key_name) -> Optional[Any]: +def get_host_config(key_name: str) -> Optional[Any]: + """ + Retrieve a configuration value from the host configuration. + + Args: + key_name (str): The name of the configuration key. + + Returns: + Optional[Any]: The configuration value, or None if not found. + """ config = configparser.ConfigParser() config.read(CONFIG) if not config.has_section("host"): return None - + host_section = dict(config.items("host")) if key_name in host_section: @@ -128,8 +197,19 @@ def get_host_config(key_name) -> Optional[Any]: return None -def str_to_bool(s): - """Convert a string to a boolean value.""" +def str_to_bool(s: str) -> bool: + """ + Convert a string to a boolean value. + + Args: + s (str): The string to convert. + + Returns: + bool: The converted boolean value. + + Raises: + ValueError: If the string cannot be converted. + """ if s.lower() == 'true' or s == '1': return True elif s.lower() == 'false' or s == '0': @@ -137,7 +217,13 @@ def str_to_bool(s): else: raise ValueError(f"Cannot convert '{s}' to a boolean value.") -def get_proxy_config() -> Tuple[Dict[str, str], Optional[int], bool]: +def get_proxy_config() -> Tuple[Optional[Dict[str, str]], Optional[int], bool]: + """ + Retrieve the proxy configuration. + + Returns: + Tuple[Optional[Dict[str, str]], Optional[int], bool]: The proxy configuration, timeout, and whether it is required. 
+ """ config = configparser.ConfigParser() config.read(CONFIG) @@ -151,7 +237,7 @@ def get_proxy_config() -> Tuple[Dict[str, str], Optional[int], bool]: if proxy: try: - proxy_dictionary = get_proxy_dict(proxy['protocol'], proxy['host'], + proxy_dictionary = get_proxy_dict(proxy['protocol'], proxy['host'], proxy['port']) required = str_to_bool(proxy["required"]) timeout = proxy["timeout"] @@ -160,7 +246,16 @@ def get_proxy_config() -> Tuple[Dict[str, str], Optional[int], bool]: return proxy_dictionary, timeout, required -def clean_session(client): +def clean_session(client) -> bool: + """ + Clean the authentication session. + + Args: + client: The authentication client. + + Returns: + bool: Always returns True. + """ config = configparser.ConfigParser() config['auth'] = {'access_token': '', 'id_token': '', 'refresh_token':''} @@ -171,11 +266,19 @@ def clean_session(client): return True -def save_auth_config(access_token=None, id_token=None, refresh_token=None): +def save_auth_config(access_token: Optional[str] = None, id_token: Optional[str] = None, refresh_token: Optional[str] = None) -> None: + """ + Save the authentication configuration. + + Args: + access_token (Optional[str]): The access token. + id_token (Optional[str]): The ID token. + refresh_token (Optional[str]): The refresh token. 
+ """ config = configparser.ConfigParser() config.read(AUTH_CONFIG_USER) - config['auth'] = {'access_token': access_token, 'id_token': id_token, + config['auth'] = {'access_token': access_token, 'id_token': id_token, 'refresh_token': refresh_token} - + with open(AUTH_CONFIG_USER, 'w') as configfile: config.write(configfile) diff --git a/safety/auth/models.py b/safety/auth/models.py index 9966e90c..3312dedc 100644 --- a/safety/auth/models.py +++ b/safety/auth/models.py @@ -1,6 +1,6 @@ from dataclasses import dataclass import os -from typing import Any, Optional +from typing import Any, Optional, Dict from authlib.integrations.base_client import BaseOAuth @@ -11,7 +11,13 @@ class Organization: id: str name: str - def to_dict(self): + def to_dict(self) -> Dict: + """ + Convert the Organization instance to a dictionary. + + Returns: + dict: The dictionary representation of the organization. + """ return {'id': self.id, 'name': self.name} @dataclass @@ -27,6 +33,12 @@ class Auth: email_verified: bool = False def is_valid(self) -> bool: + """ + Check if the authentication information is valid. + + Returns: + bool: True if valid, False otherwise. + """ if os.getenv("SAFETY_DB_DIR"): return True @@ -38,7 +50,13 @@ def is_valid(self) -> bool: return bool(self.client.token and self.email_verified) - def refresh_from(self, info): + def refresh_from(self, info: Dict) -> None: + """ + Refresh the authentication information from the provided info. + + Args: + info (dict): The information to refresh from. + """ from safety.auth.utils import is_email_verified self.name = info.get("name") @@ -46,9 +64,24 @@ def refresh_from(self, info): self.email_verified = is_email_verified(info) class XAPIKeyAuth(BaseOAuth): - def __init__(self, api_key): + def __init__(self, api_key: str) -> None: + """ + Initialize the XAPIKeyAuth instance. + + Args: + api_key (str): The API key to use for authentication. 
+ """ self.api_key = api_key - def __call__(self, r): + def __call__(self, r: Any) -> Any: + """ + Add the API key to the request headers. + + Args: + r (Any): The request object. + + Returns: + Any: The modified request object. + """ r.headers['X-API-Key'] = self.api_key return r diff --git a/safety/auth/server.py b/safety/auth/server.py index 3559c6eb..9003df9a 100644 --- a/safety/auth/server.py +++ b/safety/auth/server.py @@ -4,7 +4,7 @@ import socket import sys import time -from typing import Any, Optional +from typing import Any, Optional, Dict, Tuple import urllib.parse import threading import click @@ -14,14 +14,18 @@ from safety.auth.constants import AUTH_SERVER_URL, CLI_AUTH_SUCCESS, CLI_LOGOUT_SUCCESS, HOST from safety.auth.main import save_auth_config -from authlib.integrations.base_client.errors import OAuthError from rich.prompt import Prompt LOG = logging.getLogger(__name__) -def find_available_port(): - """Find an available port on localhost""" +def find_available_port() -> Optional[int]: + """ + Find an available port on localhost within the dynamic port range (49152-65536). + + Returns: + Optional[int]: An available port number, or None if no ports are available. + """ # Dynamic ports IANA port_range = range(49152, 65536) @@ -36,7 +40,23 @@ def find_available_port(): return None -def auth_process(code: str, state: str, initial_state: str, code_verifier, client): +def auth_process(code: str, state: str, initial_state: str, code_verifier: str, client: Any) -> Any: + """ + Process the authentication callback and exchange the authorization code for tokens. + + Args: + code (str): The authorization code. + state (str): The state parameter from the callback. + initial_state (str): The initial state parameter. + code_verifier (str): The code verifier for PKCE. + client (Any): The OAuth client. + + Returns: + Any: The user information. + + Raises: + SystemExit: If there is an error during authentication. 
+ """ err = None if initial_state is None or initial_state != state: @@ -51,15 +71,15 @@ def auth_process(code: str, state: str, initial_state: str, code_verifier, clien if err: click.secho(f'Error: {err}', fg='red') sys.exit(1) - + try: tokens = client.fetch_token(url=f'{AUTH_SERVER_URL}/oauth/token', code_verifier=code_verifier, client_id=client.client_id, grant_type='authorization_code', code=code) - save_auth_config(access_token=tokens['access_token'], - id_token=tokens['id_token'], + save_auth_config(access_token=tokens['access_token'], + id_token=tokens['id_token'], refresh_token=tokens['refresh_token']) return client.fetch_user_info() @@ -68,20 +88,32 @@ def auth_process(code: str, state: str, initial_state: str, code_verifier, clien sys.exit(1) class CallbackHandler(http.server.BaseHTTPRequestHandler): - def auth(self, code: str, state: str, err, error_description): + def auth(self, code: str, state: str, err: str, error_description: str) -> None: + """ + Handle the authentication callback. + + Args: + code (str): The authorization code. + state (str): The state parameter. + err (str): The error message, if any. + error_description (str): The error description, if any. + """ initial_state = self.server.initial_state ctx = self.server.ctx - result = auth_process(code=code, - state=state, - initial_state=initial_state, + result = auth_process(code=code, + state=state, + initial_state=initial_state, code_verifier=ctx.obj.auth.code_verifier, client=ctx.obj.auth.client) - + self.server.callback = result self.do_redirect(location=CLI_AUTH_SUCCESS, params={}) - def logout(self): + def logout(self) -> None: + """ + Handle the logout callback. + """ ctx = self.server.ctx uri = CLI_LOGOUT_SUCCESS @@ -90,7 +122,10 @@ def logout(self): self.do_redirect(location=CLI_LOGOUT_SUCCESS, params={}) - def do_GET(self): + def do_GET(self) -> None: + """ + Handle GET requests. 
+ """ query = urllib.parse.urlparse(self.path).query params = urllib.parse.parse_qs(query) callback_type: Optional[str] = None @@ -111,22 +146,56 @@ def do_GET(self): state = params.get('state', [''])[0] err = params.get('error', [''])[0] error_description = params.get('error_description', [''])[0] - + self.auth(code=code, state=state, err=err, error_description=error_description) - def do_redirect(self, location, params): + def do_redirect(self, location: str, params: Dict) -> None: + """ + Redirect the client to the specified location. + + Args: + location (str): The URL to redirect to. + params (dict): Additional parameters for the redirection. + """ self.send_response(301) self.send_header('Location', location) self.end_headers() - def log_message(self, format, *args): + def log_message(self, format: str, *args: Any) -> None: + """ + Log an arbitrary message. + + Args: + format (str): The format string. + args (Any): Arguments for the format string. + """ LOG.info(format % args) -def process_browser_callback(uri, **kwargs) -> Any: +def process_browser_callback(uri: str, **kwargs: Any) -> Any: + """ + Process the browser callback for authentication. + + Args: + uri (str): The authorization URL. + **kwargs (Any): Additional keyword arguments. + + Returns: + Any: The user information. + + Raises: + SystemExit: If there is an error during the process. + """ class ThreadedHTTPServer(http.server.HTTPServer): - def __init__(self, server_address, RequestHandlerClass): + def __init__(self, server_address: Tuple, RequestHandlerClass: Any) -> None: + """ + Initialize the ThreadedHTTPServer. + + Args: + server_address (Tuple): The server address as a tuple (host, port). + RequestHandlerClass (Any): The request handler class. 
+ """ super().__init__(server_address, RequestHandlerClass) self.initial_state = None self.ctx = None @@ -134,6 +203,9 @@ def __init__(self, server_address, RequestHandlerClass): self.timeout_reached = False def handle_timeout(self) -> None: + """ + Handle server timeout. + """ self.timeout_reached = True return super().handle_timeout() @@ -142,7 +214,7 @@ def handle_timeout(self) -> None: if not PORT: click.secho("No available ports.") sys.exit(1) - + try: headless = kwargs.get("headless", False) initial_state = kwargs.get("initial_state", None) @@ -152,6 +224,7 @@ def handle_timeout(self) -> None: if not headless: + # Start a threaded HTTP server to handle the callback server = ThreadedHTTPServer((HOST, PORT), CallbackHandler) server.initial_state = initial_state server.timeout = kwargs.get("timeout", 600) @@ -159,14 +232,14 @@ def handle_timeout(self) -> None: server_thread = threading.Thread(target=server.handle_request) server_thread.start() message = f"If the browser does not automatically open in 5 seconds, " \ - "copy and paste this url into your browser:" + "copy and paste this url into your browser:" target = uri if headless else f"{uri}&port={PORT}" console.print(f"{message} [link={target}]{target}[/link]") console.print() if headless: - + # Handle the headless mode where user manually provides the response exchange_data = None while not exchange_data: auth_code_text = Prompt.ask("Paste the response here", default=None, console=console) @@ -175,17 +248,17 @@ def handle_timeout(self) -> None: state = exchange_data["state"] code = exchange_data["code"] except Exception as e: - code = state = None + code = state = None - return auth_process(code=code, - state=state, - initial_state=initial_state, + return auth_process(code=code, + state=state, + initial_state=initial_state, code_verifier=ctx.obj.auth.code_verifier, client=ctx.obj.auth.client) else: - + # Wait for the browser authentication in non-headless mode wait_msg = "waiting for browser authentication" 
- + with console.status(wait_msg, spinner="bouncingBar"): time.sleep(2) click.launch(target) diff --git a/safety/auth/utils.py b/safety/auth/utils.py index f0a93ec7..f41e8a58 100644 --- a/safety/auth/utils.py +++ b/safety/auth/utils.py @@ -1,10 +1,11 @@ import json import logging -from typing import Any, Optional +from typing import Any, Optional, Dict, Callable, Tuple from authlib.integrations.requests_client import OAuth2Session from authlib.integrations.base_client.errors import OAuthError import requests from requests.adapters import HTTPAdapter + from safety.auth.constants import AUTH_SERVER_URL, CLAIM_EMAIL_VERIFIED_API, \ CLAIM_EMAIL_VERIFIED_AUTH_SERVER from safety.auth.main import get_auth_info, get_token_data @@ -21,17 +22,45 @@ LOG = logging.getLogger(__name__) -def get_keys(client_session, openid_config): +def get_keys(client_session: OAuth2Session, openid_config: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Retrieve the keys from the OpenID configuration. + + Args: + client_session (OAuth2Session): The OAuth2 session. + openid_config (Dict[str, Any]): The OpenID configuration. + + Returns: + Optional[Dict[str, Any]]: The keys, if available. + """ if "jwks_uri" in openid_config: return client_session.get(url=openid_config["jwks_uri"], bearer=False).json() return None -def is_email_verified(info) -> bool: +def is_email_verified(info: Dict[str, Any]) -> Optional[bool]: + """ + Check if the email is verified. + + Args: + info (Dict[str, Any]): The user information. + + Returns: + bool: True if the email is verified, False otherwise. + """ return info.get(CLAIM_EMAIL_VERIFIED_API) or info.get(CLAIM_EMAIL_VERIFIED_AUTH_SERVER) -def parse_response(func): +def parse_response(func: Callable) -> Callable: + """ + Decorator to parse the response from an HTTP request. + + Args: + func (Callable): The function to wrap. + + Returns: + Callable: The wrapped function. 
+ """ def wrapper(*args, **kwargs): try: r = func(*args, **kwargs) @@ -46,12 +75,12 @@ def wrapper(*args, **kwargs): raise e if r.status_code == 403: - raise InvalidCredentialError(credential="Failed authentication.", + raise InvalidCredentialError(credential="Failed authentication.", reason=r.text) if r.status_code == 429: raise TooManyRequestsError(reason=r.text) - + if r.status_code >= 400 and r.status_code < 500: error_code = None try: @@ -72,43 +101,87 @@ def wrapper(*args, **kwargs): data = r.json() except json.JSONDecodeError as e: raise SafetyError(message=f"Bad JSON response: {e}") - + return data return wrapper class SafetyAuthSession(OAuth2Session): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: + """ + Initialize the SafetyAuthSession. + + Args: + *args (Any): Positional arguments for the parent class. + **kwargs (Any): Keyword arguments for the parent class. + """ super().__init__(*args, **kwargs) self.proxy_required: bool = False self.proxy_timeout: Optional[int] = None self.api_key = None def get_credential(self) -> Optional[str]: + """ + Get the current authentication credential. + + Returns: + Optional[str]: The API key, token, or None. + """ if self.api_key: return self.api_key - + if self.token: return SafetyContext().account - + return None - + def is_using_auth_credentials(self) -> bool: - """This does NOT check if the client is authenticated""" + """ + Check if the session is using authentication credentials. + + This does NOT check if the client is authenticated. + + Returns: + bool: True if using authentication credentials, False otherwise. + """ return self.get_authentication_type() != AuthenticationType.none def get_authentication_type(self) -> AuthenticationType: + """ + Get the type of authentication being used. + + Returns: + AuthenticationType: The type of authentication. 
+ """ if self.api_key: return AuthenticationType.api_key - + if self.token: return AuthenticationType.token - + return AuthenticationType.none - def request(self, method, url, withhold_token=False, auth=None, bearer=True, **kwargs): - """Use the right auth parameter for Safety supported auth types""" + def request(self, method: str, url: str, withhold_token: bool = False, auth: Optional[Tuple] = None, bearer: bool = True, **kwargs: Any) -> requests.Response: + """ + Make an HTTP request with the appropriate authentication. + + Use the right auth parameter for Safety supported auth types. + + Args: + method (str): The HTTP method. + url (str): The URL to request. + withhold_token (bool): Whether to withhold the token. + auth (Optional[Tuple]): The authentication tuple. + bearer (bool): Whether to use bearer authentication. + **kwargs (Any): Additional keyword arguments. + + Returns: + requests.Response: The HTTP response. + + Raises: + Exception: If the request fails. + """ # By default use the token_auth TIMEOUT_KEYWARD = "timeout" func_timeout = kwargs[TIMEOUT_KEYWARD] if TIMEOUT_KEYWARD in kwargs else REQUEST_TIMEOUT @@ -119,7 +192,7 @@ def request(self, method, url, withhold_token=False, auth=None, bearer=True, **k kwargs["headers"] = key_header else: kwargs["headers"]["X-Api-Key"] = self.api_key - + if not self.token or not bearer: # Fallback to no token auth auth = () @@ -128,7 +201,7 @@ def request(self, method, url, withhold_token=False, auth=None, bearer=True, **k # Override proxies if self.proxies: kwargs['proxies'] = self.proxies - + if self.proxy_timeout: kwargs['timeout'] = int(self.proxy_timeout) / 1000 @@ -140,7 +213,7 @@ def request(self, method, url, withhold_token=False, auth=None, bearer=True, **k 'method': method, 'url': url, 'withhold_token': withhold_token, - 'auth': auth, + 'auth': auth, } params.update(kwargs) @@ -160,13 +233,19 @@ def request(self, method, url, withhold_token=False, auth=None, bearer=True, **k LOG.debug(message) if message 
not in [a['message'] for a in SafetyContext.local_announcements]: SafetyContext.local_announcements.append({'message': message, 'type': 'warning', 'local': True}) - + return request_func(**params) - + raise e @parse_response def fetch_user_info(self) -> Any: + """ + Fetch user information from the authorization server. + + Returns: + Any: The user information. + """ USER_INFO_ENDPOINT = f"{AUTH_SERVER_URL}/userinfo" r = self.get( @@ -176,13 +255,26 @@ def fetch_user_info(self) -> Any: return r @parse_response - def check_project(self, scan_stage: str, safety_source: str, + def check_project(self, scan_stage: str, safety_source: str, project_slug: Optional[str] = None, git_origin: Optional[str] = None, project_slug_source: Optional[str] = None) -> Any: - - data = {"scan_stage": scan_stage, "safety_source": safety_source, - "project_slug": project_slug, - "project_slug_source": project_slug_source, + """ + Check project information. + + Args: + scan_stage (str): The scan stage. + safety_source (str): The safety source. + project_slug (Optional[str]): The project slug. + git_origin (Optional[str]): The git origin. + project_slug_source (Optional[str]): The project slug source. + + Returns: + Any: The project information. + """ + + data = {"scan_stage": scan_stage, "safety_source": safety_source, + "project_slug": project_slug, + "project_slug_source": project_slug_source, "git_origin": git_origin} r = self.post( @@ -191,80 +283,138 @@ def check_project(self, scan_stage: str, safety_source: str, ) return r - + @parse_response def project(self, project_id: str) -> Any: + """ + Get project information. + + Args: + project_id (str): The project ID. + + Returns: + Any: The project information. + """ data = {"project": project_id} - r = self.get( + return self.get( url=PLATFORM_API_PROJECT_ENDPOINT, params=data ) - return r - @parse_response def download_policy(self, project_id: Optional[str], stage: Stage, branch: Optional[str]) -> Any: + """ + Download the project policy. 
+ + Args: + project_id (Optional[str]): The project ID. + stage (Stage): The stage. + branch (Optional[str]): The branch. + + Returns: + Any: The policy data. + """ data = {"project": project_id, "stage": STAGE_ID_MAPPING[stage], "branch": branch} - r = self.get( + return self.get( url=PLATFORM_API_POLICY_ENDPOINT, params=data ) - return r - + @parse_response def project_scan_request(self, project_id: str) -> Any: + """ + Request a project scan. + + Args: + project_id (str): The project ID. + + Returns: + Any: The scan request result. + """ data = {"project_id": project_id} - r = self.post( + return self.post( url=PLATFORM_API_PROJECT_SCAN_REQUEST_ENDPOINT, json=data ) - return r - + @parse_response def upload_report(self, json_report: str) -> Any: + """ + Upload a scan report. + + Args: + json_report (str): The JSON report. + + Returns: + Any: The upload result. + """ headers = { "Content-Type": "application/json" - } + } - r = self.post( + return self.post( url=PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT, data=json_report, headers=headers ) - return r - + @parse_response - def check_updates(self, version: int, safety_version=None, - python_version=None, - os_type=None, - os_release=None, - os_description=None) -> Any: - data = {"version": version, + def check_updates(self, version: int, safety_version: Optional[str] = None, python_version: Optional[str] = None, os_type: Optional[str] = None, os_release: Optional[str] = None, os_description: Optional[str] = None) -> Any: + """ + Check for updates. + + Args: + version (int): The version. + safety_version (Optional[str]): The Safety version. + python_version (Optional[str]): The Python version. + os_type (Optional[str]): The OS type. + os_release (Optional[str]): The OS release. + os_description (Optional[str]): The OS description. + + Returns: + Any: The update check result. 
+ """ + data = {"version": version, "safety_version": safety_version, "python_version": python_version, "os_type": os_type, "os_release": os_release, "os_description": os_description} - r = self.get( + return self.get( url=PLATFORM_API_CHECK_UPDATES_ENDPOINT, params=data ) - return r @parse_response def initialize_scan(self) -> Any: + """ + Initialize a scan. + + Returns: + Any: The initialization result. + """ return self.get(url=PLATFORM_API_INITIALIZE_SCAN_ENDPOINT, timeout=2) class S3PresignedAdapter(HTTPAdapter): - def send(self, request, **kwargs): + def send(self, request: requests.PreparedRequest, **kwargs: Any) -> requests.Response: + """ + Send a request, removing the Authorization header. + + Args: + request (requests.PreparedRequest): The prepared request. + **kwargs (Any): Additional keyword arguments. + + Returns: + requests.Response: The response. + """ request.headers.pop("Authorization", None) return super().send(request, **kwargs) diff --git a/safety/cli_util.py b/safety/cli_util.py index 02692396..a478b819 100644 --- a/safety/cli_util.py +++ b/safety/cli_util.py @@ -19,15 +19,25 @@ LOG = logging.getLogger(__name__) -def get_command_for(name:str, typer_instance: typer.Typer): +def get_command_for(name: str, typer_instance: typer.Typer) -> click.Command: + """ + Retrieve a command by name from a Typer instance. + + Args: + name (str): The name of the command. + typer_instance (typer.Typer): The Typer instance. + + Returns: + click.Command: The found command. 
+ """ single_command = next( - (command - for command in typer_instance.registered_commands + (command + for command in typer_instance.registered_commands if command.name == name), None) if not single_command: raise ValueError("Unable to find the command name.") - + single_command.context_settings = typer_instance.info.context_settings click_command = typer.main.get_command_from_info( single_command, @@ -44,7 +54,7 @@ def get_command_for(name:str, typer_instance: typer.Typer): def pass_safety_cli_obj(func): """ - Make sure the SafetyCLI object exists for a command. + Decorator to ensure the SafetyCLI object exists for a command. """ @wraps(func) def inner(ctx, *args, **kwargs): @@ -58,8 +68,16 @@ def inner(ctx, *args, **kwargs): return inner -def pretty_format_help(obj: Union[click.Command, click.Group], +def pretty_format_help(obj: Union[click.Command, click.Group], ctx: click.Context, markup_mode: MarkupMode) -> None: + """ + Format and print help text in a pretty format. + + Args: + obj (Union[click.Command, click.Group]): The Click command or group. + ctx (click.Context): The Click context. + markup_mode (MarkupMode): The markup mode. 
+ """ from typer.rich_utils import _print_options_panel, _get_rich_console, \ _get_help_text, highlighter, STYLE_HELPTEXT, STYLE_USAGE_COMMAND, _print_commands_panel, \ _RICH_HELP_PANEL_NAME, ARGUMENTS_PANEL_TITLE, OPTIONS_PANEL_TITLE, \ @@ -77,7 +95,7 @@ def pretty_format_help(obj: Union[click.Command, click.Group], if obj.help: console.print() - # Print with some padding + # Print with some padding console.print( Padding( Align(_get_help_text(obj=obj, markup_mode=markup_mode), pad=False), @@ -118,7 +136,7 @@ def pretty_format_help(obj: Union[click.Command, click.Group], commands=commands, markup_mode=markup_mode, console=console, - ) + ) panel_to_arguments: DefaultDict[str, List[click.Argument]] = defaultdict(list) panel_to_options: DefaultDict[str, List[click.Option]] = defaultdict(list) @@ -189,7 +207,7 @@ def pretty_format_help(obj: Union[click.Command, click.Group], ctx=ctx.parent, markup_mode=markup_mode, console=console, - ) + ) # Epilogue if we have it if obj.epilog: @@ -205,7 +223,16 @@ def print_main_command_panels(*, name: str, commands: List[click.Command], markup_mode: MarkupMode, - console): + console) -> None: + """ + Print the main command panels. + + Args: + name (str): The name of the panel. + commands (List[click.Command]): List of commands to display. + markup_mode (MarkupMode): The markup mode. + console: The Rich console. 
+ """ from rich import box from rich.table import Table from rich.text import Text @@ -242,7 +269,7 @@ def print_main_command_panels(*, commands_table.add_column(style="bold cyan", no_wrap=True, width=column_width, max_width=column_width) commands_table.add_column(width=console_width - column_width) - + rows = [] for command in commands: @@ -274,6 +301,14 @@ def print_main_command_panels(*, # The help output for the main safety root command: `safety --help` def format_main_help(obj: Union[click.Command, click.Group], ctx: click.Context, markup_mode: MarkupMode) -> None: + """ + Format the main help output for the safety root command. + + Args: + obj (Union[click.Command, click.Group]): The Click command or group. + ctx (click.Context): The Click context. + markup_mode (MarkupMode): The markup mode. + """ from typer.rich_utils import _print_options_panel, _get_rich_console, \ _get_help_text, highlighter, STYLE_USAGE_COMMAND, _print_commands_panel, \ _RICH_HELP_PANEL_NAME, ARGUMENTS_PANEL_TITLE, OPTIONS_PANEL_TITLE, \ @@ -401,7 +436,15 @@ def format_main_help(obj: Union[click.Command, click.Group], console.print(Padding(Align(epilogue_text, pad=False), 1)) -def process_auth_status_not_ready(console, auth: Auth, ctx: typer.Context): +def process_auth_status_not_ready(console, auth: Auth, ctx: typer.Context) -> None: + """ + Handle the process when the authentication status is not ready. + + Args: + console: The Rich console. + auth (Auth): The Auth object. + ctx (typer.Context): The Typer context. 
+ """ from safety_schemas.models import Stage from rich.prompt import Confirm, Prompt @@ -410,10 +453,10 @@ def process_auth_status_not_ready(console, auth: Auth, ctx: typer.Context): if auth.stage is Stage.development: console.print() if auth.org: - confirmed = Confirm.ask(MSG_NO_AUTHD_DEV_STG_ORG_PROMPT, choices=["Y", "N", "y", "n"], - show_choices=False, show_default=False, + confirmed = Confirm.ask(MSG_NO_AUTHD_DEV_STG_ORG_PROMPT, choices=["Y", "N", "y", "n"], + show_choices=False, show_default=False, default=True, console=console) - + if not confirmed: sys.exit(0) @@ -425,10 +468,10 @@ def process_auth_status_not_ready(console, auth: Auth, ctx: typer.Context): console.print(MSG_NO_AUTHD_DEV_STG) console.print() choices = ["L", "R", "l", "r"] - next_command = Prompt.ask(MSG_NO_AUTHD_DEV_STG_PROMPT, default=None, - choices=choices, show_choices=False, + next_command = Prompt.ask(MSG_NO_AUTHD_DEV_STG_PROMPT, default=None, + choices=choices, show_choices=False, console=console) - + from safety.auth.cli import auth_app login_command = get_command_for(name='login', typer_instance=auth_app) @@ -436,8 +479,8 @@ def process_auth_status_not_ready(console, auth: Auth, ctx: typer.Context): typer_instance=auth_app) if next_command is None or next_command.lower() not in choices: sys.exit(0) - - console.print() + + console.print() if next_command.lower() == "r": ctx.invoke(register_command) else: @@ -448,12 +491,12 @@ def process_auth_status_not_ready(console, auth: Auth, ctx: typer.Context): else: if not auth.org: console.print(MSG_NO_AUTHD_CICD_PROD_STG_ORG.format(LOGIN_URL=CLI_AUTH)) - + else: console.print(MSG_NO_AUTHD_CICD_PROD_STG) console.print( MSG_NO_AUTHD_NOTE_CICD_PROD_STG_TPL.format( - LOGIN_URL=CLI_AUTH, + LOGIN_URL=CLI_AUTH, SIGNUP_URL=f"{CLI_AUTH}/?sign_up=True")) sys.exit(1) @@ -466,16 +509,43 @@ def process_auth_status_not_ready(console, auth: Auth, ctx: typer.Context): sys.exit(1) class UtilityCommandMixin: - def __init__(self, *args, **kwargs): + """ + 
Mixin to add utility command functionality. + """ + def __init__(self, *args: Any, **kwargs: Any) -> None: + """ + Initialize the UtilityCommandMixin. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ self.utility_command = kwargs.pop('utility_command', False) super().__init__(*args, **kwargs) class SafetyCLISubGroup(UtilityCommandMixin, TyperGroup): + """ + Custom TyperGroup with additional functionality for Safety CLI. + """ def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + """ + Format help message with rich formatting. + + Args: + ctx (click.Context): Click context. + formatter (click.HelpFormatter): Click help formatter. + """ pretty_format_help(self, ctx, markup_mode=self.rich_markup_mode) - def format_usage(self, ctx, formatter) -> None: + def format_usage(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + """ + Format usage message. + + Args: + ctx (click.Context): Click context. + formatter (click.HelpFormatter): Click help formatter. + """ command_path = ctx.command_path pieces = self.collect_usage_pieces(ctx) main_group = ctx.parent @@ -488,15 +558,41 @@ def command( self, *args: Any, **kwargs: Any, - ): + ) -> click.Command: + """ + Create a new command. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + click.Command: The created command. + """ super().command(*args, **kwargs) class SafetyCLICommand(UtilityCommandMixin, TyperCommand): - + """ + Custom TyperCommand with additional functionality for Safety CLI. + """ def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + """ + Format help message with rich formatting. + + Args: + ctx (click.Context): Click context. + formatter (click.HelpFormatter): Click help formatter. 
+ """ pretty_format_help(self, ctx, markup_mode=self.rich_markup_mode) - def format_usage(self, ctx, formatter) -> None: + def format_usage(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + """ + Format usage message. + + Args: + ctx (click.Context): Click context. + formatter (click.HelpFormatter): Click help formatter. + """ command_path = ctx.command_path pieces = self.collect_usage_pieces(ctx) main_group = ctx.parent @@ -507,13 +603,35 @@ def format_usage(self, ctx, formatter) -> None: class SafetyCLIUtilityCommand(TyperCommand): - def __init__(self, *args, **kwargs): + """ + Custom TyperCommand designated as a utility command. + """ + def __init__(self, *args: Any, **kwargs: Any) -> None: + """ + Initialize the SafetyCLIUtilityCommand. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ self.utility_command = True super().__init__(*args, **kwargs) class SafetyCLILegacyGroup(UtilityCommandMixin, click.Group): + """ + Custom Click Group to handle legacy command-line arguments. + """ def parse_legacy_args(self, args: List[str]) -> Tuple[Optional[Dict[str, str]], Optional[str]]: + """ + Parse legacy command-line arguments for proxy settings and keys. + + Args: + args (List[str]): List of command-line arguments. + + Returns: + Tuple[Optional[Dict[str, str]], Optional[str]]: Parsed proxy options and key. + """ options = { 'proxy_protocol': 'https', 'proxy_port': 80, @@ -534,7 +652,13 @@ def parse_legacy_args(self, args: List[str]) -> Tuple[Optional[Dict[str, str]], proxy = options if options['proxy_host'] else None return proxy, key - def invoke(self, ctx): + def invoke(self, ctx: click.Context) -> None: + """ + Invoke the command, handling legacy arguments. + + Args: + ctx (click.Context): Click context. 
+ """ args = ctx.args # Workaround for legacy check options, that now are global options @@ -544,33 +668,59 @@ def invoke(self, ctx): proxy_options, key = self.parse_legacy_args(args) if proxy_options: ctx.params.update(proxy_options) - + if key: ctx.params.update({"key": key}) # Now, invoke the original behavior super(SafetyCLILegacyGroup, self).invoke(ctx) - - def format_help(self, ctx, formatter) -> None: + + def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + """ + Format help message with rich formatting. + + Args: + ctx (click.Context): Click context. + formatter (click.HelpFormatter): Click help formatter. + """ # The main `safety --help` if self.name == "cli": format_main_help(self, ctx, markup_mode="rich") # All other help outputs - else: + else: pretty_format_help(self, ctx, markup_mode="rich") class SafetyCLILegacyCommand(UtilityCommandMixin, click.Command): - def format_help(self, ctx, formatter) -> None: + """ + Custom Click Command to handle legacy command-line arguments. + """ + def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + """ + Format help message with rich formatting. + + Args: + ctx (click.Context): Click context. + formatter (click.HelpFormatter): Click help formatter. + """ pretty_format_help(self, ctx, markup_mode="rich") def handle_cmd_exception(func): + """ + Decorator to handle exceptions in command functions. + + Args: + func: The command function to wrap. + + Returns: + The wrapped function. 
+ """ @wraps(func) def inner(ctx, output: Optional[ScanOutput], *args, **kwargs): if output: kwargs.update({"output": output}) - + if output is ScanOutput.NONE: return func(ctx, *args, **kwargs) @@ -584,6 +734,6 @@ def inner(ctx, output: Optional[ScanOutput], *args, **kwargs): except Exception as e: LOG.exception('Unexpected Exception happened: %s', e) exception = e if isinstance(e, SafetyException) else SafetyException(info=e) - output_exception(exception, exit_code_output=True) + output_exception(exception, exit_code_output=True) return inner \ No newline at end of file diff --git a/safety/constants.py b/safety/constants.py index 4afbd57b..4e234d9d 100644 --- a/safety/constants.py +++ b/safety/constants.py @@ -15,6 +15,12 @@ DIR_NAME = ".safety" def get_system_dir() -> Path: + """ + Get the system directory for the safety configuration. + + Returns: + Path: The system directory path. + """ import os import sys raw_dir = os.getenv("SAFETY_SYSTEM_CONFIG_PATH") @@ -34,6 +40,12 @@ def get_system_dir() -> Path: def get_user_dir() -> Path: + """ + Get the user directory for the safety configuration. + + Returns: + Path: The user directory path. + """ path = Path("~", DIR_NAME).expanduser() return path @@ -66,6 +78,15 @@ class URLSettings(Enum): def get_config_setting(name: str) -> Optional[str]: + """ + Get the configuration setting from the config file or defaults. + + Args: + name (str): The name of the setting to retrieve. + + Returns: + Optional[str]: The value of the setting if found, otherwise None. + """ config = configparser.ConfigParser() config.read(CONFIG) diff --git a/safety/errors.py b/safety/errors.py index c0ae314b..65a55652 100644 --- a/safety/errors.py +++ b/safety/errors.py @@ -14,83 +14,178 @@ class SafetyException(Exception): - - def __init__(self, message="Unhandled exception happened: {info}", info=""): + """ + Base exception for Safety CLI errors. + + Args: + message (str): The error message template. 
+ info (str): Additional information to include in the error message. + """ + def __init__(self, message: str = "Unhandled exception happened: {info}", info: str = ""): self.message = message.format(info=info) super().__init__(self.message) - def get_exit_code(self): + def get_exit_code(self) -> int: + """ + Get the exit code associated with this exception. + + Returns: + int: The exit code. + """ return EXIT_CODE_FAILURE class SafetyError(Exception): - - def __init__(self, message="Unhandled Safety generic error", error_code=None): + """ + Generic Safety CLI error. + + Args: + message (str): The error message. + error_code (Optional[int]): The error code. + """ + def __init__(self, message: str = "Unhandled Safety generic error", error_code: Optional[int] = None): self.message = message self.error_code = error_code super().__init__(self.message) - def get_exit_code(self): + def get_exit_code(self) -> int: + """ + Get the exit code associated with this error. + + Returns: + int: The exit code. + """ return EXIT_CODE_FAILURE class MalformedDatabase(SafetyError): - - def __init__(self, reason=None, fetched_from="server", - message="Sorry, something went wrong.\n" + - "Safety CLI can not read the data fetched from {fetched_from} because is malformed.\n"): + """ + Error raised when the vulnerability database is malformed. + + Args: + reason (Optional[str]): The reason for the error. + fetched_from (str): The source of the fetched data. + message (str): The error message template. 
+ """ + def __init__(self, reason: Optional[str] = None, fetched_from: str = "server", + message: str = "Sorry, something went wrong.\n" + "Safety CLI cannot read the data fetched from {fetched_from} because it is malformed.\n"): + info = f"Reason, {reason}" if reason else "" info = "Reason, {reason}".format(reason=reason) self.message = message.format(fetched_from=fetched_from) + (info if reason else "") super().__init__(self.message) - def get_exit_code(self): + def get_exit_code(self) -> int: + """ + Get the exit code associated with this error. + + Returns: + int: The exit code. + """ return EXIT_CODE_MALFORMED_DB class DatabaseFetchError(SafetyError): + """ + Error raised when the vulnerability database cannot be fetched. - def __init__(self, message="Unable to load vulnerability database"): + Args: + message (str): The error message. + """ + def __init__(self, message: str = "Unable to load vulnerability database"): self.message = message super().__init__(self.message) - def get_exit_code(self): + def get_exit_code(self) -> int: + """ + Get the exit code associated with this error. + + Returns: + int: The exit code. + """ return EXIT_CODE_UNABLE_TO_FETCH_VULNERABILITY_DB class InvalidProvidedReportError(SafetyError): - - def __init__(self, message="Unable to apply fix: the report needs to be generated from a file. " - "Environment isn't supported yet."): + """ + Error raised when the provided report is invalid for applying fixes. + + Args: + message (str): The error message. + """ + def __init__(self, message: str = "Unable to apply fix: the report needs to be generated from a file. " + "Environment isn't supported yet."): self.message = message super().__init__(self.message) - def get_exit_code(self): + def get_exit_code(self) -> int: + """ + Get the exit code associated with this error. + + Returns: + int: The exit code. 
+ """ return EXIT_CODE_INVALID_PROVIDED_REPORT class InvalidRequirementError(SafetyError): - def __init__(self, message="Unable to parse the requirement: {line}", line=""): + """ + Error raised when a requirement is invalid. + + Args: + message (str): The error message template. + line (str): The invalid requirement line. + """ + def __init__(self, message: str = "Unable to parse the requirement: {line}", line: str = ""): self.message = message.format(line=line) super().__init__(self.message) - def get_exit_code(self): + def get_exit_code(self) -> int: + """ + Get the exit code associated with this error. + + Returns: + int: The exit code. + """ return EXIT_CODE_INVALID_REQUIREMENT class DatabaseFileNotFoundError(DatabaseFetchError): - - def __init__(self, db=None, message="Unable to find vulnerability database in {db}"): + """ + Error raised when the vulnerability database file is not found. + + Args: + db (Optional[str]): The database file path. + message (str): The error message template. + """ + def __init__(self, db: Optional[str] = None, message: str = "Unable to find vulnerability database in {db}"): self.db = db self.message = message.format(db=db) super().__init__(self.message) - def get_exit_code(self): + def get_exit_code(self) -> int: + """ + Get the exit code associated with this error. + + Returns: + int: The exit code. + """ return EXIT_CODE_UNABLE_TO_LOAD_LOCAL_VULNERABILITY_DB class InvalidCredentialError(DatabaseFetchError): - - def __init__(self, credential: Optional[str] = None, message="Your authentication credential{credential}is invalid. See {link}.", reason=None): + """ + Error raised when authentication credentials are invalid. + + Args: + credential (Optional[str]): The invalid credential. + message (str): The error message template. + reason (Optional[str]): The reason for the error. + """ + + def __init__(self, credential: Optional[str] = None, + message: str = "Your authentication credential{credential}is invalid. 
See {link}.", + reason: Optional[str] = None): self.credential = credential self.link = 'https://docs.safetycli.com/safety-docs/support/invalid-api-key-error' self.message = message.format(credential=f" '{self.credential}' ", link=self.link) if self.credential else message.format(credential=' ', link=self.link) @@ -98,48 +193,96 @@ def __init__(self, credential: Optional[str] = None, message="Your authenticatio self.message = self.message + (info if reason else "") super().__init__(self.message) - def get_exit_code(self): + def get_exit_code(self) -> int: + """ + Get the exit code associated with this error. + + Returns: + int: The exit code. + """ return EXIT_CODE_INVALID_API_KEY class NotVerifiedEmailError(SafetyError): - def __init__(self, message="email is not verified"): + """ + Error raised when the user's email is not verified. + + Args: + message (str): The error message. + """ + def __init__(self, message: str = "email is not verified"): self.message = message super().__init__(self.message) - def get_exit_code(self): + def get_exit_code(self) -> int: + """ + Get the exit code associated with this error. + + Returns: + int: The exit code. + """ return EXIT_CODE_EMAIL_NOT_VERIFIED class TooManyRequestsError(DatabaseFetchError): - - def __init__(self, reason=None, - message="Too many requests."): + """ + Error raised when too many requests are made to the server. + + Args: + reason (Optional[str]): The reason for the error. + message (str): The error message template. + """ + def __init__(self, reason: Optional[str] = None, + message: str = "Too many requests."): info = f" Reason: {reason}" self.message = message + (info if reason else "") super().__init__(self.message) - def get_exit_code(self): + def get_exit_code(self) -> int: + """ + Get the exit code associated with this error. + + Returns: + int: The exit code. 
+ """ return EXIT_CODE_TOO_MANY_REQUESTS class NetworkConnectionError(DatabaseFetchError): + """ + Error raised when there is a network connection issue. - def __init__(self, message="Check your network connection, unable to reach the server."): + Args: + message (str): The error message. + """ + + def __init__(self, message: str = "Check your network connection, unable to reach the server."): self.message = message super().__init__(self.message) class RequestTimeoutError(DatabaseFetchError): + """ + Error raised when a request times out. - def __init__(self, message="Check your network connection, the request timed out."): + Args: + message (str): The error message. + """ + def __init__(self, message: str = "Check your network connection, the request timed out."): self.message = message super().__init__(self.message) class ServerError(DatabaseFetchError): - - def __init__(self, reason=None, - message="Sorry, something went wrong.\n" + "Safety CLI can not connect to the server.\n" + - "Our engineers are working quickly to resolve the issue."): + """ + Error raised when there is a server issue. + + Args: + reason (Optional[str]): The reason for the error. + message (str): The error message template. + """ + def __init__(self, reason: Optional[str] = None, + message: str = "Sorry, something went wrong.\n" + "Safety CLI cannot connect to the server.\n" + "Our engineers are working quickly to resolve the issue."): info = f" Reason: {reason}" self.message = message + (info if reason else "") super().__init__(self.message) diff --git a/safety/formatter.py b/safety/formatter.py index 9e39de9c..f70b4de0 100644 --- a/safety/formatter.py +++ b/safety/formatter.py @@ -1,5 +1,6 @@ import logging from abc import ABCMeta, abstractmethod +from typing import Any, Dict, List, Tuple, Union, Optional NOT_IMPLEMENTED = "You should implement this." 
@@ -8,45 +9,75 @@ class FormatterAPI: """ - Strategy Abstract class, with all the render methods that the concrete implementations should support + Strategy Abstract class, with all the render methods that the concrete implementations should support. """ __metaclass__ = ABCMeta - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: """ - Dummy + Dummy initializer for the FormatterAPI class. """ pass @abstractmethod - def render_vulnerabilities(self, announcements, vulnerabilities, remediations, full, packages, fixes=()): + def render_vulnerabilities(self, announcements: List[Dict[str, Any]], vulnerabilities: List[Dict[str, Any]], remediations: Dict[str, Any], full: bool, packages: List[Dict[str, Any]], fixes: Tuple = ()) -> Optional[str]: + """ + Render the vulnerabilities report. + + Args: + announcements (List[Dict[str, Any]]): List of announcements. + vulnerabilities (List[Dict[str, Any]]): List of vulnerabilities. + remediations (Dict[str, Any]): Dictionary of remediations. + full (bool): Whether to render a full report. + packages (List[Dict[str, Any]]): List of packages. + fixes (Tuple, optional): Tuple of fixes. Defaults to (). + + Returns: + Optional[str]: Rendered vulnerabilities report. + """ raise NotImplementedError(NOT_IMPLEMENTED) # pragma: no cover @abstractmethod - def render_licenses(self, announcements, licenses): - raise NotImplementedError(NOT_IMPLEMENTED) # pragma: no cover + def render_licenses(self, announcements: List[Dict[str, Any]], licenses: List[Dict[str, Any]]) -> Optional[str]: + """ + Render the licenses report. - @abstractmethod - def render_announcements(self, announcements): + Args: + announcements (List[Dict[str, Any]]): List of announcements. + licenses (List[Dict[str, Any]]): List of licenses. + + Returns: + Optional[str]: Rendered licenses report. 
+ """ raise NotImplementedError(NOT_IMPLEMENTED) # pragma: no cover + @abstractmethod + def render_announcements(self, announcements: List[Dict[str, Any]]) -> Optional[str]: + """ + Render the announcements. -class SafetyFormatter(FormatterAPI): + Args: + announcements (List[Dict[str, Any]]): List of announcements. - def render_vulnerabilities(self, announcements, vulnerabilities, remediations, full, packages, fixes=()): - LOG.info('Safety is going to render_vulnerabilities with format: %s', self.format) - return self.format.render_vulnerabilities(announcements, vulnerabilities, remediations, full, packages, fixes) + Returns: + Optional[str]: Rendered announcements. + """ + raise NotImplementedError(NOT_IMPLEMENTED) # pragma: no cover - def render_licenses(self, announcements, licenses): - LOG.info('Safety is going to render_licenses with format: %s', self.format) - return self.format.render_licenses(announcements, licenses) - def render_announcements(self, announcements): - LOG.info('Safety is going to render_announcements with format: %s', self.format) - return self.format.render_announcements(announcements) +class SafetyFormatter(FormatterAPI): + """ + Formatter class that implements the FormatterAPI to render reports in various formats. + """ + def __init__(self, output: str, **kwargs: Any) -> None: + """ + Initialize the SafetyFormatter with the specified output format. - def __init__(self, output, **kwargs): + Args: + output (str): The output format (e.g., 'json', 'html', 'bare', 'text'). + **kwargs: Additional keyword arguments. 
+ """ from safety.formatters.screen import ScreenReport from safety.formatters.text import TextReport from safety.formatters.json import JsonReport @@ -63,3 +94,48 @@ def __init__(self, output, **kwargs): self.format = BareReport(**kwargs) elif output == 'text': self.format = TextReport(**kwargs) + + def render_vulnerabilities(self, announcements: List[Dict[str, Any]], vulnerabilities: List[Dict[str, Any]], remediations: Dict[str, Any], full: bool, packages: List[Dict[str, Any]], fixes: Tuple = ()) -> Optional[str]: + """ + Render the vulnerabilities report. + + Args: + announcements (List[Dict[str, Any]]): List of announcements. + vulnerabilities (List[Dict[str, Any]]): List of vulnerabilities. + remediations (Dict[str, Any]): Dictionary of remediations. + full (bool): Whether to render a full report. + packages (List[Dict[str, Any]]): List of packages. + fixes (Tuple, optional): Tuple of fixes. Defaults to (). + + Returns: + Optional[str]: Rendered vulnerabilities report. + """ + LOG.info('Safety is going to render_vulnerabilities with format: %s', self.format) + return self.format.render_vulnerabilities(announcements, vulnerabilities, remediations, full, packages, fixes) + + def render_licenses(self, announcements: List[Dict[str, Any]], licenses: List[Dict[str, Any]]) -> Optional[str]: + """ + Render the licenses report. + + Args: + announcements (List[Dict[str, Any]]): List of announcements. + licenses (List[Dict[str, Any]]): List of licenses. + + Returns: + Optional[str]: Rendered licenses report. + """ + LOG.info('Safety is going to render_licenses with format: %s', self.format) + return self.format.render_licenses(announcements, licenses) + + def render_announcements(self, announcements: List[Dict[str, Any]]): + """ + Render the announcements. + + Args: + announcements (List[Dict[str, Any]]): List of announcements. + + Returns: + Optional[str]: Rendered announcements. 
+ """ + LOG.info('Safety is going to render_announcements with format: %s', self.format) + return self.format.render_announcements(announcements) diff --git a/safety/formatters/bare.py b/safety/formatters/bare.py index b1e38066..e1ba7b4b 100644 --- a/safety/formatters/bare.py +++ b/safety/formatters/bare.py @@ -1,15 +1,32 @@ from collections import namedtuple +from typing import List, Dict, Any, Optional, Tuple from safety.formatter import FormatterAPI from safety.util import get_basic_announcements class BareReport(FormatterAPI): - """Bare report, for command line tools""" + """ + Bare report, for command line tools. + """ - def render_vulnerabilities(self, announcements, vulnerabilities, remediations, full, packages, fixes=()): - parsed_announcements = [] + def render_vulnerabilities(self, announcements: List[Dict[str, Any]], vulnerabilities: List[Any], + remediations: Any, full: bool, packages: List[Any], fixes: Tuple = ()) -> str: + """ + Renders vulnerabilities in a bare format. + + Args: + announcements (List[Dict[str, Any]]): List of announcements. + vulnerabilities (List[Any]): List of vulnerabilities. + remediations (Any): Remediation data. + full (bool): Flag indicating full output. + packages (List[Any]): List of packages. + fixes (Tuple, optional): Tuple of fixes. + Returns: + str: Rendered vulnerabilities. + """ + parsed_announcements = [] Announcement = namedtuple("Announcement", ["name"]) for announcement in get_basic_announcements(announcements, include_local=False): @@ -21,7 +38,17 @@ def render_vulnerabilities(self, announcements, vulnerabilities, remediations, f return " ".join(announcements_to_render + affected_packages) - def render_licenses(self, announcements, packages_licenses): + def render_licenses(self, announcements: List[Dict[str, Any]], packages_licenses: List[Dict[str, Any]]) -> str: + """ + Renders licenses in a bare format. + + Args: + announcements (List[Dict[str, Any]]): List of announcements. 
+ packages_licenses (List[Dict[str, Any]]): List of package licenses. + + Returns: + str: Rendered licenses. + """ parsed_announcements = [] for announcement in get_basic_announcements(announcements): @@ -34,5 +61,11 @@ def render_licenses(self, announcements, packages_licenses): sorted_licenses = sorted(licenses) return " ".join(announcements_to_render + sorted_licenses) - def render_announcements(self, announcements): + def render_announcements(self, announcements: List[Dict[str, Any]]) -> None: + """ + Renders announcements in a bare format. + + Args: + announcements (List[Dict[str, Any]]): List of announcements. + """ print('render_announcements bare') diff --git a/safety/formatters/html.py b/safety/formatters/html.py index ade31056..bfefb1f4 100644 --- a/safety/formatters/html.py +++ b/safety/formatters/html.py @@ -1,4 +1,6 @@ import logging +from typing import List, Dict, Tuple, Optional + from safety.formatter import FormatterAPI from safety.formatters.json import build_json_report @@ -9,9 +11,26 @@ class HTMLReport(FormatterAPI): - """HTML report, for when the output is input for something else""" - - def render_vulnerabilities(self, announcements, vulnerabilities, remediations, full, packages, fixes=()): + """ + HTML report formatter for when the output is input for something else. + """ + + def render_vulnerabilities(self, announcements: List[Dict], vulnerabilities: List[Dict], remediations: Dict, + full: bool, packages: List[Dict], fixes: Tuple = ()) -> Optional[str]: + """ + Renders vulnerabilities in HTML format. + + Args: + announcements (List[Dict]): List of announcements. + vulnerabilities (List[Dict]): List of vulnerabilities. + remediations (Dict): Remediation data. + full (bool): Flag indicating full output. + packages (List[Dict]): List of packages. + fixes (Tuple, optional): Tuple of fixes. + + Returns: + str: Rendered HTML vulnerabilities report. 
+ """ LOG.debug( f'HTML Output, Rendering {len(vulnerabilities)} vulnerabilities, {len(remediations)} package ' f'remediations with full_report: {full}') @@ -19,8 +38,21 @@ def render_vulnerabilities(self, announcements, vulnerabilities, remediations, f return parse_html(kwargs={"json_data": report}) - def render_licenses(self, announcements, licenses): + def render_licenses(self, announcements: List[Dict], licenses: List[Dict]) -> None: + """ + Renders licenses in HTML format. + + Args: + announcements (List[Dict]): List of announcements. + licenses (List[Dict]): List of licenses. + """ pass - def render_announcements(self, announcements): + def render_announcements(self, announcements: List[Dict]) -> None: + """ + Renders announcements in HTML format. + + Args: + announcements (List[Dict]): List of announcements. + """ pass diff --git a/safety/formatters/json.py b/safety/formatters/json.py index 4e0edfc0..0b4ae8ba 100644 --- a/safety/formatters/json.py +++ b/safety/formatters/json.py @@ -1,10 +1,7 @@ import logging - import json as json_parser from collections import defaultdict -from typing import Iterable - -from requests.models import PreparedRequest +from typing import Iterable, List, Dict, Any from safety.formatter import FormatterAPI from safety.formatters.schemas import VulnerabilitySchemaV05 @@ -16,7 +13,20 @@ LOG = logging.getLogger(__name__) -def build_json_report(announcements, vulnerabilities, remediations, packages): +def build_json_report(announcements: List[Dict], vulnerabilities: List[Dict], remediations: Dict[str, Any], + packages: List[Any]) -> Dict[str, Any]: + """ + Build a JSON report for vulnerabilities, remediations, and packages. + + Args: + announcements (List[Dict]): List of announcements. + vulnerabilities (List[Dict]): List of vulnerabilities. + remediations (Dict[str, Any]): Remediation data. + packages (List[Any]): List of packages. + + Returns: + Dict[str, Any]: JSON report. 
+ """ vulns_ignored = [vuln.to_dict() for vuln in vulnerabilities if vuln.ignored] vulns = [vuln.to_dict() for vuln in vulnerabilities if not vuln.ignored] @@ -56,10 +66,31 @@ class JsonReport(FormatterAPI): VERSIONS = ("0.5", "1.1") def __init__(self, version="1.1", **kwargs): + """ + Initialize JsonReport with the specified version. + + Args: + version (str): Report version. + """ super().__init__(**kwargs) self.version: str = version if version in self.VERSIONS else "1.1" - def render_vulnerabilities(self, announcements, vulnerabilities, remediations, full, packages, fixes=()): + def render_vulnerabilities(self, announcements: List[Dict], vulnerabilities: List[Dict], + remediations: Dict[str, Any], full: bool, packages: List[Any], fixes: Iterable = ()) -> str: + """ + Render vulnerabilities in JSON format. + + Args: + announcements (List[Dict]): List of announcements. + vulnerabilities (List[Dict]): List of vulnerabilities. + remediations (Dict[str, Any]): Remediation data. + full (bool): Flag indicating full output. + packages (List[Any]): List of packages. + fixes (Iterable, optional): Iterable of fixes. + + Returns: + str: Rendered JSON vulnerabilities report. + """ if self.version == '0.5': return json_parser.dumps(VulnerabilitySchemaV05().dump(obj=vulnerabilities, many=True), indent=4) @@ -72,7 +103,17 @@ def render_vulnerabilities(self, announcements, vulnerabilities, remediations, f return json_parser.dumps(template, indent=4, cls=SafetyEncoder) - def render_licenses(self, announcements, licenses): + def render_licenses(self, announcements: List[Dict], licenses: List[Dict]) -> str: + """ + Render licenses in JSON format. + + Args: + announcements (List[Dict]): List of announcements. + licenses (List[Dict]): List of licenses. + + Returns: + str: Rendered JSON licenses report. 
+ """ unique_license_types = set([lic['license'] for lic in licenses]) report = get_report_brief_info(as_dict=True, report_type=2, licenses_found=len(unique_license_types)) @@ -84,10 +125,29 @@ def render_licenses(self, announcements, licenses): return json_parser.dumps(template, indent=4) - def render_announcements(self, announcements): + def render_announcements(self, announcements: List[Dict]) -> str: + """ + Render announcements in JSON format. + + Args: + announcements (List[Dict]): List of announcements. + + Returns: + str: Rendered JSON announcements. + """ return json_parser.dumps({"announcements": get_basic_announcements(announcements)}, indent=4) - def __render_fixes(self, scan_template, fixes: Iterable): + def __render_fixes(self, scan_template: Dict[str, Any], fixes: Iterable) -> Dict[str, Any]: + """ + Render fixes and update the scan template with remediations information. + + Args: + scan_template (Dict[str, Any]): Initial scan template. + fixes (Iterable): Iterable of fixes. + + Returns: + Dict[str, Any]: Updated scan template with remediations. + """ applied = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) skipped = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) diff --git a/safety/formatters/schemas/common.py b/safety/formatters/schemas/common.py index 5220323e..3c7e0aae 100644 --- a/safety/formatters/schemas/common.py +++ b/safety/formatters/schemas/common.py @@ -1,4 +1,4 @@ -from typing import Dict, Generic, TypeVar +from typing import Dict, Generic, TypeVar, Generator from pydantic import BaseModel as PydanticBaseModel from pydantic import Extra @@ -9,10 +9,13 @@ class BaseModel(PydanticBaseModel): + """ + Base model that extends Pydantic's BaseModel with additional configurations. 
+ """ class Config: - arbitrary_types_allowed = True - max_anystr_length = 50 - validate_assignment = True + arbitrary_types_allowed: bool = True + max_anystr_length: int = 50 + validate_assignment: bool = True extra = Extra.forbid @@ -21,15 +24,36 @@ class Config: class ConstrainedDict(Generic[KeyType, ValueType]): - def __init__(self, v: Dict[KeyType, ValueType]): + """ + A constrained dictionary that validates its length based on a specified limit. + """ + def __init__(self, v: Dict[KeyType, ValueType]) -> None: + """ + Initialize the ConstrainedDict. + + Args: + v (Dict[KeyType, ValueType]): The dictionary to constrain. + """ super().__init__() @classmethod - def __get_validators__(cls): + def __get_validators__(cls) -> Generator: yield cls.dict_length_validator @classmethod - def dict_length_validator(cls, v): + def dict_length_validator(cls, v: Dict[KeyType, ValueType]) -> Dict[KeyType, ValueType]: + """ + Validate the length of the dictionary. + + Args: + v (Dict[KeyType, ValueType]): The dictionary to validate. + + Returns: + Dict[KeyType, ValueType]: The validated dictionary. + + Raises: + DictMaxLengthError: If the dictionary exceeds the allowed length. + """ v = dict_validator(v) if len(v) > SCHEMA_DICT_ITEMS_COUNT_LIMIT: raise DictMaxLengthError(limit_value=SCHEMA_DICT_ITEMS_COUNT_LIMIT) diff --git a/safety/formatters/schemas/v3_0.py b/safety/formatters/schemas/v3_0.py index 87e0af94..3dfb3090 100644 --- a/safety/formatters/schemas/v3_0.py +++ b/safety/formatters/schemas/v3_0.py @@ -17,6 +17,21 @@ class Meta(BaseModel): + """ + Metadata for the scan report. + + Attributes: + scan_type (Literal["system-scan", "scan", "check"]): The type of scan. + scan_location (Path): The location of the scan. + logged_to_dashboard (bool): Whether the scan was logged to the dashboard. + authenticated (bool): Whether the scan was authenticated. + authentication_method (Literal["token", "api_key"]): The method of authentication. 
+ local_database_path (Optional[Path]): The path to the local database. + safety_version (str): The version of the Safety tool used. + timestamp (datetime): The timestamp of the scan. + telemetry (Telemetry): Telemetry data related to the scan. + schema_version (str): The version of the schema used. + """ scan_type: Literal["system-scan", "scan", "check"] scan_location: Path logged_to_dashboard: bool @@ -30,6 +45,17 @@ class Meta(BaseModel): class Package(BaseModel): + """ + Information about a package and its vulnerabilities. + + Attributes: + requirements (ConstrainedDict[str, RequirementInfo]): The package requirements. + current_version (Optional[str]): The current version of the package. + vulnerabilities_found (Optional[int]): The number of vulnerabilities found. + recommended_version (Optional[str]): The recommended version of the package. + other_recommended_versions (List[str]): Other recommended versions of the package. + more_info_url (Optional[HttpUrl]): URL for more information about the package. + """ requirements: ConstrainedDict[str, RequirementInfo] current_version: Optional[str] vulnerabilities_found: Optional[int] @@ -39,34 +65,81 @@ class Package(BaseModel): class OSVulnerabilities(BaseModel): + """ + Information about OS vulnerabilities. + + Attributes: + packages (ConstrainedDict[str, Package]): Packages with vulnerabilities. + vulnerabilities (List[Vulnerability]): List of vulnerabilities. + """ packages: ConstrainedDict[str, Package] vulnerabilities: List[Vulnerability] = Field(..., max_items=100, unique_items=True) class EnvironmentFindings(BaseModel): + """ + Findings related to the environment. + + Attributes: + configuration (ConstrainedDict): Configuration details. + packages (ConstrainedDict[str, Package]): Packages found in the environment. + os_vulnerabilities (OSVulnerabilities): OS vulnerabilities found. 
+ """ configuration: ConstrainedDict packages: ConstrainedDict[str, Package] os_vulnerabilities: OSVulnerabilities class Environment(BaseModel): + """ + Details about the environment being scanned. + + Attributes: + full_location (Path): The full path of the environment. + type (Literal["environment"]): The type of the environment. + findings (EnvironmentFindings): Findings related to the environment. + """ full_location: Path type: Literal["environment"] findings: EnvironmentFindings class DependencyVulnerabilities(BaseModel): + """ + Information about dependency vulnerabilities. + + Attributes: + packages (List[PackageShort]): List of packages with vulnerabilities. + vulnerabilities (List[Vulnerability]): List of vulnerabilities found. + """ packages: List[PackageShort] = Field(..., max_items=500, unique_items=True) vulnerabilities: List[Vulnerability] = Field(..., max_items=100, unique_items=True) class FileFindings(BaseModel): + """ + Findings related to a file. + + Attributes: + configuration (ConstrainedDict): Configuration details. + packages (List[PackageShort]): List of packages found in the file. + dependency_vulnerabilities (DependencyVulnerabilities): Dependency vulnerabilities found. + """ configuration: ConstrainedDict packages: List[PackageShort] = Field(..., max_items=500, unique_items=True) dependency_vulnerabilities: DependencyVulnerabilities class Remediations(BaseModel): + """ + Remediations for vulnerabilities. + + Attributes: + configuration (ConstrainedDict): Configuration details. + packages (ConstrainedDict[str, Package]): Packages with remediations. + dependency_vulnerabilities (ConstrainedDict[str, Package]): Dependency vulnerabilities with remediations. + remediations_results (RemediationsResults): Results of the remediations. 
+ """ configuration: ConstrainedDict packages: ConstrainedDict[str, Package] dependency_vulnerabilities: ConstrainedDict[str, Package] @@ -74,6 +147,17 @@ class Remediations(BaseModel): class File(BaseModel): + """ + Information about a scanned file. + + Attributes: + full_location (Path): The full path of the file. + type (str): The type of the file. + language (Literal["python"]): The programming language of the file. + format (str): The format of the file. + findings (FileFindings): Findings related to the file. + remediations (Remediations): Remediations for the file. + """ full_location: Path type: str language: Literal["python"] @@ -83,6 +167,13 @@ class File(BaseModel): class Results(BaseModel): + """ + The results of a scan. + + Attributes: + environments (List[ConstrainedDict[Path, Environment]]): List of environments scanned. + files (List[ConstrainedDict[str, File]]): List of files scanned. + """ environments: List[ConstrainedDict[Path, Environment]] = Field( [], max_items=100, unique_items=True ) @@ -92,6 +183,16 @@ class Results(BaseModel): class Project(Results): + """ + Information about a project being scanned. + + Attributes: + id (Optional[int]): The project ID. + location (Path): The location of the project. + policy (Optional[Path]): The policy file for the project. + policy_source (Optional[Literal["local", "cloud"]]): The source of the policy. + git (Union[GitInfo, NoGit]): Git information related to the project. + """ id: Optional[int] location: Path policy: Optional[Path] @@ -100,6 +201,14 @@ class Project(Results): class ScanReportV30(BaseModel): + """ + The scan report. + + Attributes: + meta (Meta): Metadata about the scan. + results (Union[Results, Dict]): The results of the scan. + projects (Union[Project, Dict]): Projects involved in the scan. 
+ """ meta: Meta results: Results | Dict = {} projects: Project | Dict = {} \ No newline at end of file diff --git a/safety/formatters/schemas/zero_five.py b/safety/formatters/schemas/zero_five.py index 578e89cf..d9bed276 100644 --- a/safety/formatters/schemas/zero_five.py +++ b/safety/formatters/schemas/zero_five.py @@ -1,9 +1,17 @@ -from typing import Optional +from typing import Optional, List, Any, Dict, Tuple from marshmallow import Schema, fields as fields_, post_dump class CVSSv2(Schema): + """ + Schema for CVSSv2 data. + + Attributes: + base_score (fields_.Int): Base score of the CVSSv2. + impact_score (fields_.Int): Impact score of the CVSSv2. + vector_string (fields_.Str): Vector string of the CVSSv2. + """ base_score = fields_.Int() impact_score = fields_.Int() vector_string = fields_.Str() @@ -13,6 +21,15 @@ class Meta: class CVSSv3(Schema): + """ + Schema for CVSSv3 data. + + Attributes: + base_score (fields_.Int): Base score of the CVSSv3. + base_severity (fields_.Str): Base severity of the CVSSv3. + impact_score (fields_.Int): Impact score of the CVSSv3. + vector_string (fields_.Str): Vector string of the CVSSv3. + """ base_score = fields_.Int() base_severity = fields_.Str() impact_score = fields_.Int() @@ -24,7 +41,16 @@ class Meta: class VulnerabilitySchemaV05(Schema): """ - Legacy JSON report used in Safety 1.10.3 + Legacy JSON report schema used in Safety 1.10.3. + + Attributes: + package_name (fields_.Str): Name of the vulnerable package. + vulnerable_spec (fields_.Str): Vulnerable specification of the package. + version (fields_.Str): Version of the package. + advisory (fields_.Str): Advisory details for the vulnerability. + vulnerability_id (fields_.Str): ID of the vulnerability. + cvssv2 (Optional[CVSSv2]): CVSSv2 details of the vulnerability. + cvssv3 (Optional[CVSSv3]): CVSSv3 details of the vulnerability. 
""" package_name = fields_.Str() @@ -39,6 +65,17 @@ class Meta: ordered = True @post_dump(pass_many=True) - def wrap_with_envelope(self, data, many, **kwargs): + def wrap_with_envelope(self, data: List[Dict[str, Any]], many: bool, **kwargs: Any) -> List[Tuple]: + """ + Wraps the dumped data with an envelope. + + Args: + data (List[Dict[str, Any]]): The data to be wrapped. + many (bool): Indicates if multiple objects are being dumped. + **kwargs (Any): Additional keyword arguments. + + Returns: + List[Tuple]: The wrapped data. + """ return [tuple(d.values()) for d in data] diff --git a/safety/formatters/screen.py b/safety/formatters/screen.py index 2d6a1848..61892e4f 100644 --- a/safety/formatters/screen.py +++ b/safety/formatters/screen.py @@ -7,7 +7,8 @@ build_primary_announcement, get_specifier_range_info, format_unpinned_vulnerabilities from safety.util import get_primary_announcement, get_basic_announcements, get_terminal_size, \ is_ignore_unpinned_mode - +from collections import defaultdict +from typing import List, Dict, Any, Tuple class ScreenReport(FormatterAPI): DIVIDER_SECTIONS = '+' + '=' * (get_terminal_size().columns - 2) + '+' @@ -29,7 +30,16 @@ class ScreenReport(FormatterAPI): ANNOUNCEMENTS_HEADING = format_long_text(click.style('ANNOUNCEMENTS', bold=True)) - def __build_announcements_section(self, announcements): + def __build_announcements_section(self, announcements: List[Dict]) -> List[str]: + """ + Build the announcements section of the report. + + Args: + announcements (List[Dict]): List of announcement dictionaries. + + Returns: + List[str]: Formatted announcements section. 
+ """ announcements_section = [] basic_announcements = get_basic_announcements(announcements) @@ -41,7 +51,22 @@ def __build_announcements_section(self, announcements): return announcements_section - def render_vulnerabilities(self, announcements, vulnerabilities, remediations, full, packages, fixes=()): + def render_vulnerabilities(self, announcements: List[Dict], vulnerabilities: List[Dict], remediations: Dict[str, Any], + full: bool, packages: List[Dict], fixes: Tuple = ()) -> str: + """ + Render the vulnerabilities section of the report. + + Args: + announcements (List[Dict]): List of announcement dictionaries. + vulnerabilities (List[Dict]): List of vulnerability dictionaries. + remediations (Dict[str, Any]): Remediation data. + full (bool): Flag indicating full report. + packages (List[Dict]): List of package dictionaries. + fixes (Tuple, optional): Iterable of fixes. + + Returns: + str: Rendered vulnerabilities report. + """ announcements_section = self.__build_announcements_section(announcements) primary_announcement = get_primary_announcement(announcements) remediation_section = build_remediation_section(remediations) @@ -56,7 +81,6 @@ def render_vulnerabilities(self, announcements, vulnerabilities, remediations, f ignored = {} total_ignored = 0 - from collections import defaultdict unpinned_packages = defaultdict(list) styled_vulns = [] @@ -110,7 +134,17 @@ def render_vulnerabilities(self, announcements, vulnerabilities, remediations, f end_content ) - def render_licenses(self, announcements, licenses): + def render_licenses(self, announcements: List[Dict], licenses: List[Dict]) -> str: + """ + Render the licenses section of the report. + + Args: + announcements (List[Dict]): List of announcement dictionaries. + licenses (List[Dict]): List of license dictionaries. + + Returns: + str: Rendered licenses report. 
+ """ unique_license_types = set([lic['license'] for lic in licenses]) report_brief_section = build_report_brief_section(primary_announcement=get_primary_announcement(announcements), @@ -149,7 +183,16 @@ def render_licenses(self, announcements, licenses): self.DIVIDER_SECTIONS] ) - def render_announcements(self, announcements): + def render_announcements(self, announcements: List[Dict]) -> List[str]: + """ + Render the announcements section of the report. + + Args: + announcements (List[Dict]): List of announcement dictionaries. + + Returns: + str: Rendered announcements section. + """ return self.__build_announcements_section(announcements) diff --git a/safety/formatters/text.py b/safety/formatters/text.py index 51625809..0d7a7ff6 100644 --- a/safety/formatters/text.py +++ b/safety/formatters/text.py @@ -8,6 +8,7 @@ build_primary_announcement, format_unpinned_vulnerabilities from safety.util import get_primary_announcement, get_basic_announcements, is_ignore_unpinned_mode, \ get_remediations_count +from typing import List, Dict, Tuple, Any class TextReport(FormatterAPI): @@ -30,7 +31,16 @@ class TextReport(FormatterAPI): """ + SMALL_DIVIDER_SECTIONS - def __build_announcements_section(self, announcements): + def __build_announcements_section(self, announcements: List[Dict]) -> List[str]: + """ + Build the announcements section of the report. + + Args: + announcements (List[Dict]): List of announcement dictionaries. + + Returns: + List[str]: Formatted announcements section. 
+ """ announcements_table = [] basic_announcements = get_basic_announcements(announcements) @@ -43,7 +53,25 @@ def __build_announcements_section(self, announcements): return announcements_table - def render_vulnerabilities(self, announcements, vulnerabilities, remediations, full, packages, fixes=()): + def render_vulnerabilities( + self, announcements: List[Dict], vulnerabilities: List[Dict], + remediations: Dict[str, Any], full: bool, packages: List[Dict], + fixes: Tuple = () + ) -> str: + """ + Render the vulnerabilities section of the report. + + Args: + announcements (List[Dict]): List of announcement dictionaries. + vulnerabilities (List[Dict]): List of vulnerability dictionaries. + remediations (Dict[str, Any]): Remediation data. + full (bool): Flag indicating full report. + packages (List[Dict]): List of package dictionaries. + fixes (Tuple, optional): Iterable of fixes. + + Returns: + str: Rendered vulnerabilities report. + """ primary_announcement = get_primary_announcement(announcements) remediation_section = [click.unstyle(rem) for rem in build_remediation_section(remediations, columns=80)] end_content = [] @@ -109,7 +137,17 @@ def render_vulnerabilities(self, announcements, vulnerabilities, remediations, f table ) - def render_licenses(self, announcements, licenses): + def render_licenses(self, announcements: List[Dict], licenses: List[Dict]) -> str: + """ + Render the licenses section of the report. + + Args: + announcements (List[Dict]): List of announcement dictionaries. + licenses (List[Dict]): List of license dictionaries. + + Returns: + str: Rendered licenses report. + """ unique_license_types = set([lic['license'] for lic in licenses]) report_brief_section = click.unstyle( @@ -145,7 +183,16 @@ def render_licenses(self, announcements, licenses): return "\n".join(table) - def render_announcements(self, announcements): + def render_announcements(self, announcements: List[Dict]) -> str: + """ + Render the announcements section of the report. 
+ + Args: + announcements (List[Dict]): List of announcement dictionaries. + + Returns: + str: Rendered announcements section. + """ rows = self.__build_announcements_section(announcements) rows.insert(0, self.SMALL_DIVIDER_SECTIONS) return '\n'.join(rows) diff --git a/safety/models.py b/safety/models.py index 64c1932d..20a81887 100644 --- a/safety/models.py +++ b/safety/models.py @@ -2,6 +2,7 @@ from collections import namedtuple from dataclasses import dataclass, field from datetime import datetime +from typing import Any, List, Optional, Set, Tuple, Union, Dict from dparse.dependencies import Dependency from dparse import parse, filetypes @@ -25,8 +26,11 @@ class DictConverter(object): + """ + A class to convert objects to dictionaries. + """ - def to_dict(self, **kwargs): + def to_dict(self, **kwargs: Any) -> Dict: pass @@ -45,7 +49,20 @@ def to_dict(self, **kwargs): class SafetyRequirement(Requirement): - def __init__(self, requirement: [str, Dependency], found: Optional[str] = None) -> None: + """ + A subclass of Requirement that includes additional attributes and methods for safety requirements. + """ + def __init__(self, requirement: Union[str, Dependency], found: Optional[str] = None) -> None: + """ + Initialize a SafetyRequirement. + + Args: + requirement (Union[str, Dependency]): The requirement as a string or Dependency object. + found (Optional[str], optional): Where the requirement was found. Defaults to None. + + Raises: + InvalidRequirementError: If the requirement cannot be parsed. + """ dep = requirement if isinstance(requirement, str): @@ -75,10 +92,19 @@ def __init__(self, requirement: [str, Dependency], found: Optional[str] = None) self.raw = raw_line self.found = found - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return str(self) == str(other) - - def to_dict(self, **kwargs): + + def to_dict(self, **kwargs: Any) -> Dict: + """ + Convert the requirement to a dictionary. + + Args: + **kwargs: Additional keyword arguments. 
+ + Returns: + dict: The dictionary representation of the requirement. + """ specifier_obj = self.specifier if not "specifier_obj" in kwargs: specifier_obj = str(self.specifier) @@ -95,6 +121,15 @@ def to_dict(self, **kwargs): def is_pinned_requirement(spec: SpecifierSet) -> bool: + """ + Check if a requirement is pinned. + + Args: + spec (SpecifierSet): The specifier set of the requirement. + + Returns: + bool: True if the requirement is pinned, False otherwise. + """ if not spec or len(spec) != 1: return False @@ -105,6 +140,9 @@ def is_pinned_requirement(spec: SpecifierSet) -> bool: @dataclass class Package(DictConverter): + """ + A class representing a package. + """ name: str version: Optional[str] requirements: List[SafetyRequirement] @@ -116,16 +154,37 @@ class Package(DictConverter): latest_version: Optional[str] = None more_info_url: Optional[str] = None - def has_unpinned_req(self): + def has_unpinned_req(self) -> bool: + """ + Check if the package has unpinned requirements. + + Returns: + bool: True if there are unpinned requirements, False otherwise. + """ for req in self.requirements: if not is_pinned_requirement(req.specifier): return True return False - def get_unpinned_req(self): + def get_unpinned_req(self) -> filter: + """ + Get the unpinned requirements. + + Returns: + filter: A filter object with the unpinned requirements. + """ return filter(lambda r: not is_pinned_requirement(r.specifier), self.requirements) - def filter_by_supported_versions(self, versions: [str]) -> [str]: + def filter_by_supported_versions(self, versions: List[str]) -> List[str]: + """ + Filter the versions by supported versions. + + Args: + versions (List[str]): The list of versions. + + Returns: + List[str]: The list of supported versions. 
+ """ allowed = [] for version in versions: @@ -137,13 +196,28 @@ def filter_by_supported_versions(self, versions: [str]) -> [str]: return allowed - def get_versions(self, db_full): + def get_versions(self, db_full: Dict) -> Set[str]: + """ + Get the versions from the database. + + Args: + db_full (Dict): The full database. + + Returns: + Set[str]: The set of versions. + """ pkg_meta = db_full.get('meta', {}).get('packages', {}).get(self.name, {}) versions = self.filter_by_supported_versions( pkg_meta.get("insecure_versions", []) + pkg_meta.get("secure_versions", [])) return set(versions) - def refresh_from(self, db_full): + def refresh_from(self, db_full: Dict) -> None: + """ + Refresh the package information from the database. + + Args: + db_full (Dict): The full database. + """ base_domain = db_full.get('meta', {}).get('base_domain') pkg_meta = db_full.get('meta', {}).get('packages', {}).get(canonicalize_name(self.name), {}) @@ -156,7 +230,16 @@ def refresh_from(self, db_full): self.update(kwargs) - def to_dict(self, **kwargs): + def to_dict(self, **kwargs: Any) -> Dict: + """ + Convert the package to a dictionary. + + Args: + **kwargs: Additional keyword arguments. + + Returns: + dict: The dictionary representation of the package. + """ if kwargs.get('short_version', False): return { 'name': self.name, @@ -175,19 +258,37 @@ def to_dict(self, **kwargs): 'more_info_url': self.more_info_url } - def update(self, new): + def update(self, new: Dict) -> None: + """ + Update the package attributes with new values. + + Args: + new (Dict): The new attribute values. + """ for key, value in new.items(): if hasattr(self, key): setattr(self, key, value) class Announcement(announcement_nmt): + """ + A class representing an announcement. + """ pass class Remediation(remediation_nmt, DictConverter): + """ + A class representing a remediation. + """ + + def to_dict(self) -> Dict: + """ + Convert the remediation to a dictionary. 
- def to_dict(self): + Returns: + Dict: The dictionary representation of the remediation. + """ return {'package': self.Package.name, 'closest_secure_version': self.closest_secure_version, 'secure_versions': self.secure_versions, @@ -197,10 +298,13 @@ def to_dict(self): @dataclass class Fix: + """ + A class representing a fix. + """ dependency: Any = None previous_version: Any = None previous_spec: Optional[str] = None - other_options: [str] = field(default_factory=lambda: []) + other_options: List[str] = field(default_factory=lambda: []) updated_version: Any = None update_type: str = "" package: str = "" @@ -211,13 +315,31 @@ class Fix: class CVE(cve_nmt, DictConverter): + """ + A class representing a CVE. + """ - def to_dict(self): + def to_dict(self) -> Dict: + """ + Convert the CVE to a dictionary. + + Returns: + Dict: The dictionary representation of the CVE. + """ return {'name': self.name, 'cvssv2': self.cvssv2, 'cvssv3': self.cvssv3} class Severity(severity_nmt, DictConverter): - def to_dict(self): + """ + A class representing the severity of a vulnerability. + """ + def to_dict(self) -> Dict: + """ + Convert the severity to a dictionary. + + Returns: + Dict: The dictionary representation of the severity. + """ result = {'severity': {'source': self.source}} result['severity']['cvssv2'] = self.cvssv2 @@ -227,7 +349,19 @@ def to_dict(self): class SafetyEncoder(json.JSONEncoder): - def default(self, value): + """ + A custom JSON encoder for Safety related objects. + """ + def default(self, value: Any) -> Any: + """ + Override the default method to handle custom objects. + + Args: + value (Any): The value to encode. + + Returns: + Any: The encoded value. + """ if isinstance(value, SafetyRequirement): return value.to_dict() elif isinstance(value, Version) or (legacyType and isinstance(value, legacyType)): @@ -237,8 +371,17 @@ def default(self, value): class Vulnerability(vulnerability_nmt): + """ + A class representing a vulnerability. 
+ """ + + def to_dict(self) -> Dict: + """ + Convert the vulnerability to a dictionary. - def to_dict(self): + Returns: + Dict: The dictionary representation of the vulnerability. + """ empty_list_if_none = ['fixed_versions', 'closest_versions_without_known_vulnerabilities', 'resources'] result = { } @@ -268,10 +411,22 @@ def to_dict(self): return result - def get_advisory(self): + def get_advisory(self) -> str: + """ + Get the advisory for the vulnerability. + + Returns: + str: The advisory text. + """ return self.advisory.replace('\r', '') if self.advisory else "No advisory found for this vulnerability." - - def to_model_dict(self): + + def to_model_dict(self) -> Dict: + """ + Convert the vulnerability to a dictionary for the model. + + Returns: + Dict: The dictionary representation of the vulnerability for the model. + """ try: affected_spec = next(iter(self.vulnerable_spec)) except Exception: @@ -285,7 +440,7 @@ def to_model_dict(self): } if self.ignored: - repr["ignored"] = {"reason": self.ignored_reason, + repr["ignored"] = {"reason": self.ignored_reason, "expires": self.ignored_expires} return repr @@ -293,6 +448,9 @@ def to_model_dict(self): @dataclass class Safety: + """ + A class representing Safety settings. + """ client: Any keys: Any @@ -302,6 +460,9 @@ class Safety: @dataclass class SafetyCLI: + """ + A class representing Safety CLI settings. 
+ """ auth: Optional[Auth] = None telemetry: Optional[TelemetryModel] = None metadata: Optional[MetadataModel] = None diff --git a/safety/output_utils.py b/safety/output_utils.py index 5ad25c8e..0ebb928d 100644 --- a/safety/output_utils.py +++ b/safety/output_utils.py @@ -4,7 +4,7 @@ import os import textwrap from datetime import datetime -from typing import List, Tuple, Dict, Optional +from typing import List, Tuple, Dict, Optional, Any, Union import click @@ -20,8 +20,19 @@ LOG = logging.getLogger(__name__) -def build_announcements_section_content(announcements, columns=get_terminal_size().columns, indent: str = ' ' * 2, - sub_indent: str = ' ' * 4): +def build_announcements_section_content(announcements: List[Dict[str, Any]], columns: int = get_terminal_size().columns, indent: str = ' ' * 2, sub_indent: str = ' ' * 4) -> str: + """ + Build the content for the announcements section. + + Args: + announcements (List[Dict[str, Any]]): List of announcements. + columns (int, optional): Number of columns for formatting. Defaults to terminal size. + indent (str, optional): Indentation for the text. Defaults to ' ' * 2. + sub_indent (str, optional): Sub-indentation for the text. Defaults to ' ' * 4. + + Returns: + str: Formatted announcements section content. + """ section = '' for i, announcement in enumerate(announcements): @@ -42,11 +53,30 @@ def build_announcements_section_content(announcements, columns=get_terminal_size return section -def add_empty_line(): +def add_empty_line() -> str: + """ + Add an empty line. + + Returns: + str: Empty line. + """ return format_long_text('') -def style_lines(lines, columns, pre_processed_text='', start_line=' ' * 4, end_line=' ' * 4): +def style_lines(lines: List[Dict[str, Any]], columns: int, pre_processed_text: str = '', start_line: str = ' ' * 4, end_line: str = ' ' * 4) -> str: + """ + Style the lines with the specified format. + + Args: + lines (List[Dict[str, Any]]): List of lines to style. 
+ columns (int): Number of columns for formatting. + pre_processed_text (str, optional): Pre-processed text. Defaults to ''. + start_line (str, optional): Starting line decorator. Defaults to ' ' * 4. + end_line (str, optional): Ending line decorator. Defaults to ' ' * 4. + + Returns: + str: Styled text. + """ styled_text = pre_processed_text for line in lines: @@ -74,7 +104,19 @@ def style_lines(lines, columns, pre_processed_text='', start_line=' ' * 4, end_l return styled_text -def format_vulnerability(vulnerability, full_mode, only_text=False, columns=get_terminal_size().columns): +def format_vulnerability(vulnerability: Any, full_mode: bool, only_text: bool = False, columns: int = get_terminal_size().columns) -> str: + """ + Format the vulnerability details. + + Args: + vulnerability (Any): The vulnerability object. + full_mode (bool): Whether to use full mode for formatting. + only_text (bool, optional): Whether to return only text without styling. Defaults to False. + columns (int, optional): Number of columns for formatting. Defaults to terminal size. + + Returns: + str: Formatted vulnerability details. + """ common_format = {'indent': 3, 'format': {'sub_indent': ' ' * 3, 'max_lines': None}} @@ -221,7 +263,18 @@ def format_vulnerability(vulnerability, full_mode, only_text=False, columns=get_ return click.unstyle(content) if only_text else content -def format_license(license, only_text=False, columns=get_terminal_size().columns): +def format_license(license: Dict[str, Any], only_text: bool = False, columns: int = get_terminal_size().columns) -> str: + """ + Format the license details. + + Args: + license (Dict[str, Any]): The license details. + only_text (bool, optional): Whether to return only text without styling. Defaults to False. + columns (int, optional): Number of columns for formatting. Defaults to terminal size. + + Returns: + str: Formatted license details. 
+ """ to_print = [ {'words': [{'style': {'bold': True}, 'value': license['package']}, {'value': ' version {0} found using license '.format(license['version'])}, @@ -235,7 +288,16 @@ def format_license(license, only_text=False, columns=get_terminal_size().columns return click.unstyle(content) if only_text else content -def get_fix_hint_for_unpinned(remediation): +def get_fix_hint_for_unpinned(remediation: Dict[str, Any]) -> str: + """ + Get the fix hint for unpinned dependencies. + + Args: + remediation (Dict[str, Any]): The remediation details. + + Returns: + str: The fix hint. + """ secure_options: List[str] = [str(fix) for fix in remediation.get('other_recommended_versions', [])] fixes_hint = f'Version {remediation.get("recommended_version")} has no known vulnerabilities and falls' \ f' within your current specifier range.' @@ -248,12 +310,31 @@ def get_fix_hint_for_unpinned(remediation): return fixes_hint -def get_unpinned_hint(pkg: str): +def get_unpinned_hint(pkg: str) -> str: + """ + Get the hint for unpinned packages. + + Args: + pkg (str): The package name. + + Returns: + str: The hint for unpinned packages. + """ return f"We recommend either pinning {pkg} to one of the versions above or updating your " \ f"install specifier to ensure a vulnerable version cannot be installed." def get_specifier_range_info(style: bool = True, pin_hint: bool = False) -> str: + """ + Get the specifier range information. + + Args: + style (bool, optional): Whether to apply styling. Defaults to True. + pin_hint (bool, optional): Whether to include a pin hint. Defaults to False. + + Returns: + str: The specifier range information. 
+ """ hint = '' if pin_hint: @@ -269,7 +350,18 @@ def get_specifier_range_info(style: bool = True, pin_hint: bool = False) -> str: return f'{msg} {link}' -def build_other_options_msg(fix_version: Optional[str], is_spec: bool, secure_options: [str]) -> str: +def build_other_options_msg(fix_version: Optional[str], is_spec: bool, secure_options: List[str]) -> str: + """ + Build the message for other secure options. + + Args: + fix_version (Optional[str]): The recommended fix version. + is_spec (bool): Whether the package is specified. + secure_options (List[str]): List of secure options. + + Returns: + str: The message for other secure options. + """ other_options_msg = '' raw_pre_other_options = '' outside = '' @@ -290,7 +382,19 @@ def build_other_options_msg(fix_version: Optional[str], is_spec: bool, secure_op return other_options_msg -def build_remediation_section(remediations, only_text=False, columns=get_terminal_size().columns, kwargs=None): +def build_remediation_section(remediations: Dict[str, Any], only_text: bool = False, columns: int = get_terminal_size().columns, kwargs: Optional[Dict[str, Any]] = None) -> List[str]: + """ + Build the remediation section content. + + Args: + remediations (Dict[str, Any]): The remediations details. + only_text (bool, optional): Whether to return only text without styling. Defaults to False. + columns (int, optional): Number of columns for formatting. Defaults to terminal size. + kwargs (Optional[Dict[str, Any]], optional): Additional arguments for formatting. Defaults to None. + + Returns: + List[str]: The remediation section content. 
+ """ columns -= 2 indent = ' ' * 3 @@ -315,94 +419,94 @@ def build_remediation_section(remediations, only_text=False, columns=get_termina spec = rem['requirement'] is_spec = not version and spec secure_options: List[str] = [str(fix) for fix in rem.get('other_recommended_versions', [])] - + fix_version = None new_line = '\n' spec_info = [] - + vuln_word = 'vulnerability' pronoun_word = 'this' - + if rem['vulnerabilities_found'] > 1: vuln_word = 'vulnerabilities' pronoun_word = 'these' - + if rem.get('recommended_version', None): fix_version = str(rem.get('recommended_version')) - + other_options_msg = build_other_options_msg(fix_version=fix_version, is_spec=is_spec, secure_options=secure_options) - + spec_hint = '' - - if secure_options or fix_version and is_spec: + + if secure_options or fix_version and is_spec: raw_spec_info = get_unpinned_hint(pkg) - + spec_hint = f"{click.style(raw_spec_info, bold=True, fg='green')}" \ f" {get_specifier_range_info()}" - + if fix_version: fix_v: str = click.style(fix_version, bold=True) closest_msg = f'The closest version with no known vulnerabilities is {fix_v}' - + if is_spec: closest_msg = f'Version {fix_v} has no known vulnerabilities and falls within your current specifier ' \ f'range' - + raw_recommendation = f"We recommend updating to version {fix_version} of {pkg}." - + remediation_styled = click.style(f'{raw_recommendation} {other_options_msg}', bold=True, fg='green') - + # Spec case if is_spec: closest_msg += f'. {other_options_msg}' remediation_styled = spec_hint - + remediation_content = [ closest_msg, new_line, remediation_styled ] - + else: no_known_fix_msg = f'There is no known fix for {pronoun_word} {vuln_word}.' - + if is_spec and secure_options: no_known_fix_msg = f'There is no known fix for {pronoun_word} {vuln_word} in the current specified ' \ f'range ({spec}).' 
- + no_fix_msg_styled = f"{click.style(no_known_fix_msg, bold=True, fg='yellow')} " \ f"{click.style(other_options_msg, bold=True, fg='green')}" - + remediation_content = [new_line, no_fix_msg_styled] - + if spec_hint: remediation_content.extend([new_line, spec_hint]) - + # Pinned raw_rem_title = f"-> {pkg} version {version} was found, " \ f"which has {rem['vulnerabilities_found']} {vuln_word}" - + # Range if is_spec: # Spec remediation copy raw_rem_title = f"-> {pkg} with install specifier {spec} was found, " \ f"which has {rem['vulnerabilities_found']} {vuln_word}" - + remediation_title = click.style(raw_rem_title, fg=RED, bold=True) content += new_line + format_long_text(remediation_title, **{**kwargs, **{'indent': '', 'sub_indent': ' ' * 3}}) + new_line - + pre_content = remediation_content + spec_info + [new_line, f"For more information about the {pkg} package and update " f"options, visit {rem['more_info_url']}", f'Always check for breaking changes when updating packages.', new_line] - + for i, element in enumerate(pre_content): content += format_long_text(element, **kwargs) - + if i + 1 < len(pre_content): content += '\n' @@ -429,7 +533,20 @@ def build_remediation_section(remediations, only_text=False, columns=get_termina return content -def get_final_brief(total_vulns_found, remediations, ignored, total_ignored, kwargs=None): +def get_final_brief(total_vulns_found: int, remediations: Dict[str, Any], ignored: Dict[str, Any], total_ignored: int, kwargs: Optional[Dict[str, Any]] = None) -> str: + """ + Get the final brief summary. + + Args: + total_vulns_found (int): Total vulnerabilities found. + remediations (Dict[str, Any]): Remediation details. + ignored (Dict[str, Any]): Ignored vulnerabilities details. + total_ignored (int): Total ignored vulnerabilities. + kwargs (Optional[Dict[str, Any]], optional): Additional arguments for formatting. Defaults to None. + + Returns: + str: Final brief summary. 
+ """ if not kwargs: kwargs = {} @@ -451,7 +568,17 @@ def get_final_brief(total_vulns_found, remediations, ignored, total_ignored, kwa return format_long_text(raw_brief, start_line_decorator=' ', **kwargs) -def get_final_brief_license(licenses, kwargs=None): +def get_final_brief_license(licenses: List[str], kwargs: Optional[Dict[str, Any]] = None) -> str: + """ + Get the final brief summary for licenses. + + Args: + licenses (List[str]): List of licenses. + kwargs (Optional[Dict[str, Any]], optional): Additional arguments for formatting. Defaults to None. + + Returns: + str: Final brief summary for licenses. + """ if not kwargs: kwargs = {} @@ -463,7 +590,24 @@ def get_final_brief_license(licenses, kwargs=None): return format_long_text("{0}".format(licenses_text), start_line_decorator=' ', **kwargs) -def format_long_text(text, color='', columns=get_terminal_size().columns, start_line_decorator=' ', end_line_decorator=' ', max_lines=None, styling=None, indent='', sub_indent=''): +def format_long_text(text: str, color: str = '', columns: int = get_terminal_size().columns, start_line_decorator: str = ' ', end_line_decorator: str = ' ', max_lines: Optional[int] = None, styling: Optional[Dict[str, Any]] = None, indent: str = '', sub_indent: str = '') -> str: + """ + Format long text with wrapping and styling. + + Args: + text (str): The text to format. + color (str, optional): Color for the text. Defaults to ''. + columns (int, optional): Number of columns for formatting. Defaults to terminal size. + start_line_decorator (str, optional): Starting line decorator. Defaults to ' '. + end_line_decorator (str, optional): Ending line decorator. Defaults to ' '. + max_lines (Optional[int], optional): Maximum number of lines. Defaults to None. + styling (Optional[Dict[str, Any]], optional): Additional styling options. Defaults to None. + indent (str, optional): Indentation for the text. Defaults to ''. + sub_indent (str, optional): Sub-indentation for the text. Defaults to ''. 
+ + Returns: + str: Formatted text. + """ if not styling: styling = {} @@ -492,7 +636,16 @@ def format_long_text(text, color='', columns=get_terminal_size().columns, start_ return "\n".join(formatted_lines) -def get_printable_list_of_scanned_items(scanning_target): +def get_printable_list_of_scanned_items(scanning_target: str) -> Tuple[List[Dict[str, Any]], List[str]]: + """ + Get a printable list of scanned items. + + Args: + scanning_target (str): The scanning target (environment, stdin, files, or file). + + Returns: + Tuple[List[Dict[str, Any]], List[str]]: Printable list of scanned items and scanned items data. + """ context = SafetyContext() result = [] @@ -538,7 +691,19 @@ def get_printable_list_of_scanned_items(scanning_target): REPORT_HEADING = format_long_text(click.style('REPORT', bold=True)) -def build_report_brief_section(columns=None, primary_announcement=None, report_type=1, **kwargs): +def build_report_brief_section(columns: Optional[int] = None, primary_announcement: Optional[Dict[str, Any]] = None, report_type: int = 1, **kwargs: Any) -> str: + """ + Build the brief section of the report. + + Args: + columns (Optional[int], optional): Number of columns for formatting. Defaults to None. + primary_announcement (Optional[Dict[str, Any]], optional): Primary announcement details. Defaults to None. + report_type (int, optional): Type of the report. Defaults to 1. + **kwargs: Additional arguments for formatting. + + Returns: + str: Brief section of the report. + """ if not columns: columns = get_terminal_size().columns @@ -571,7 +736,16 @@ def build_report_brief_section(columns=None, primary_announcement=None, report_t return "\n".join([add_empty_line(), REPORT_HEADING, add_empty_line(), '\n'.join(styled_brief_lines)]) -def build_report_for_review_vuln_report(as_dict=False): +def build_report_for_review_vuln_report(as_dict: bool = False) -> Union[Dict[str, Any], List[List[Dict[str, Any]]]]: + """ + Build the report for review vulnerability report. 
+ + Args: + as_dict (bool, optional): Whether to return as a dictionary. Defaults to False. + + Returns: + Union[Dict[str, Any], List[List[Dict[str, Any]]]]: Review vulnerability report. + """ ctx = SafetyContext() report_from_file = ctx.review packages = ctx.packages @@ -620,7 +794,18 @@ def build_report_for_review_vuln_report(as_dict=False): return brief_info -def build_using_sentence(account, key, db): +def build_using_sentence(account: Optional[str], key: Optional[str], db: Optional[str]) -> List[Dict[str, Any]]: + """ + Build the sentence for the used components. + + Args: + account (Optional[str]): The account details. + key (Optional[str]): The API key. + db (Optional[str]): The database details. + + Returns: + List[Dict[str, Any]]: Sentence for the used components. + """ key_sentence = [] custom_integration = os.environ.get('SAFETY_CUSTOM_INTEGRATION', 'false').lower() == 'true' @@ -647,7 +832,16 @@ def build_using_sentence(account, key, db): return [{'style': False, 'value': 'Using '}] + key_sentence + database_sentence -def build_scanned_count_sentence(packages): +def build_scanned_count_sentence(packages: List[Package]) -> List[Dict[str, Any]]: + """ + Build the sentence for the scanned count. + + Args: + packages (List[Package]): List of packages. + + Returns: + List[Dict[str, Any]]: Sentence for the scanned count. + """ scanned_count = 'No packages found' if len(packages) >= 1: scanned_count = 'Found and scanned {0} {1}'.format(len(packages), @@ -656,7 +850,13 @@ def build_scanned_count_sentence(packages): return [{'style': True, 'value': scanned_count}] -def add_warnings_if_needed(brief_info): +def add_warnings_if_needed(brief_info: List[List[Dict[str, Any]]]): + """ + Add warnings to the brief info if needed. + + Args: + brief_info (List[List[Dict[str, Any]]]): Brief info details. 
+ """ ctx = SafetyContext() warnings = [] @@ -673,7 +873,18 @@ def add_warnings_if_needed(brief_info): brief_info += [[{'style': False, 'value': ''}]] + warnings -def get_report_brief_info(as_dict=False, report_type=1, **kwargs): +def get_report_brief_info(as_dict: bool = False, report_type: int = 1, **kwargs: Any): + """ + Get the brief info of the report. + + Args: + as_dict (bool, optional): Whether to return as a dictionary. Defaults to False. + report_type (int, optional): Type of the report. Defaults to 1. + **kwargs: Additional arguments for the report. + + Returns: + Union[Dict[str, Any], List[List[Dict[str, Any]]]]: Brief info of the report. + """ LOG.info('get_report_brief_info: %s, %s, %s', as_dict, report_type, kwargs) context = SafetyContext() @@ -812,7 +1023,18 @@ def get_report_brief_info(as_dict=False, report_type=1, **kwargs): return brief_data if as_dict else brief_info -def build_primary_announcement(primary_announcement, columns=None, only_text=False): +def build_primary_announcement(primary_announcement, columns: Optional[int] = None, only_text: bool = False) -> str: + """ + Build the primary announcement section. + + Args: + primary_announcement (Dict[str, Any]): Primary announcement details. + columns (Optional[int], optional): Number of columns for formatting. Defaults to None. + only_text (bool, optional): Whether to return only text without styling. Defaults to False. + + Returns: + str: Primary announcement section. + """ lines = json.loads(primary_announcement.get('message')) for line in lines: @@ -829,15 +1051,37 @@ def build_primary_announcement(primary_announcement, columns=None, only_text=Fal return click.unstyle(message) if only_text else message -def is_using_api_key(): +def is_using_api_key() -> bool: + """ + Check if an API key is being used. + + Returns: + bool: True if using an API key, False otherwise. 
+ """ return bool(SafetyContext().key) or bool(SafetyContext().account) -def is_using_a_safety_policy_file(): +def is_using_a_safety_policy_file() -> bool: + """ + Check if a safety policy file is being used. + + Returns: + bool: True if using a safety policy file, False otherwise. + """ return bool(SafetyContext().params.get('policy_file', None)) -def should_add_nl(output, found_vulns): +def should_add_nl(output: str, found_vulns: bool) -> bool: + """ + Determine if a newline should be added. + + Args: + output (str): The output format. + found_vulns (bool): Whether vulnerabilities were found. + + Returns: + bool: True if a newline should be added, False otherwise. + """ if output == 'bare' and not found_vulns: return False @@ -845,6 +1089,15 @@ def should_add_nl(output, found_vulns): def get_skip_reason(fix: Fix) -> str: + """ + Get the reason for skipping a fix. + + Args: + fix (Fix): The fix details. + + Returns: + str: The reason for skipping the fix. + """ range_msg = '' if not fix.updated_version and fix.other_options: @@ -859,15 +1112,43 @@ def get_skip_reason(fix: Fix) -> str: return reasons.get(fix.status, 'unknown.') -def get_applied_msg(fix, mode="auto") -> str: +def get_applied_msg(fix: Fix, mode: str = "auto") -> str: + """ + Get the message for an applied fix. + + Args: + fix (Fix): The fix details. + mode (str, optional): The mode of the fix. Defaults to "auto". + + Returns: + str: The message for the applied fix. + """ return f"{fix.package}{fix.previous_spec} has a {fix.update_type} version fix available: {mode} updating to =={fix.updated_version}." -def get_skipped_msg(fix) -> str: +def get_skipped_msg(fix: Fix) -> str: + """ + Get the message for a skipped fix. + + Args: + fix (Fix): The fix details. + + Returns: + str: The message for the skipped fix. 
+ """ return f'{fix.package} remediation was skipped because {get_skip_reason(fix)}' -def get_fix_opt_used_msg(fix_options=None) -> str: +def get_fix_opt_used_msg(fix_options: Optional[List[str]] = None) -> str: + """ + Get the message for the fix options used. + + Args: + fix_options (Optional[List[str]], optional): The fix options. Defaults to None. + + Returns: + str: The message for the fix options used. + """ if not fix_options: fix_options = SafetyContext().params.get('auto_remediation_limit', []) @@ -883,7 +1164,18 @@ def get_fix_opt_used_msg(fix_options=None) -> str: return msg -def print_service(output: List[Tuple[str, Dict]], out_format: str, format_text: Optional[dict] = None): +def print_service(output: List[Tuple[str, Dict[str, Any]]], out_format: str, format_text: Optional[Dict[str, Any]] = None): + """ + Print the service output. + + Args: + output (List[Tuple[str, Dict[str, Any]]]): The output to print. + out_format (str): The output format. + format_text (Optional[Dict[str, Any]], optional): Additional text formatting options. Defaults to None. + + Raises: + ValueError: If the output format is not allowed. + """ formats = ['text', 'screen'] if out_format not in formats: @@ -904,7 +1196,22 @@ def print_service(output: List[Tuple[str, Dict]], out_format: str, format_text: click.echo(click.unstyle(line)) -def prompt_service(output: Tuple[str, Dict], out_format: str, format_text: Optional[dict] = None) -> bool: + +def prompt_service(output: Tuple[str, Dict[str, Any]], out_format: str, format_text: Optional[Dict[str, Any]] = None) -> bool: + """ + Prompt the user for input. + + Args: + output (Tuple[str, Dict[str, Any]]): The output to display. + out_format (str): The output format. + format_text (Optional[Dict[str, Any]], optional): Additional text formatting options. Defaults to None. + + Returns: + bool: The user response. + + Raises: + ValueError: If the output format is not allowed. 
+ """ formats = ['text', 'screen'] if out_format not in formats: @@ -924,14 +1231,34 @@ def prompt_service(output: Tuple[str, Dict], out_format: str, format_text: Optio return click.prompt(msg) -def parse_html(*, kwargs, template='index.html'): +def parse_html(*, kwargs: Dict[str, Any], template: str = 'index.html') -> str: + """ + Parse HTML using Jinja2 templates. + + Args: + kwargs (Dict[str, Any]): The template variables. + template (str, optional): The template name. Defaults to 'index.html'. + + Returns: + str: The rendered HTML. + """ file_loader = PackageLoader('safety', 'templates') env = Environment(loader=file_loader) template = env.get_template(template) return template.render(**kwargs) -def format_unpinned_vulnerabilities(unpinned_packages, columns=None): +def format_unpinned_vulnerabilities(unpinned_packages: Dict[str, List[Any]], columns: Optional[int] = None) -> List[str]: + """ + Format unpinned vulnerabilities. + + Args: + unpinned_packages (Dict[str, List[Any]]): Unpinned packages and their vulnerabilities. + columns (Optional[int], optional): Number of columns for formatting. Defaults to None. + + Returns: + List[str]: Formatted unpinned vulnerabilities. 
+ """ lines = [] if not unpinned_packages: diff --git a/safety/safety.py b/safety/safety.py index ae6d1fc5..7bb6ad7d 100644 --- a/safety/safety.py +++ b/safety/safety.py @@ -6,17 +6,15 @@ import logging import os from pathlib import Path -import random import sys import tempfile import time from collections import defaultdict from datetime import datetime -from typing import Dict, Optional, List, Any +from typing import Dict, Optional, List, Any, Union, Iterator import click import requests -from requests.models import PreparedRequest from packaging.specifiers import SpecifierSet from packaging.utils import canonicalize_name from packaging.version import parse as parse_version, Version @@ -43,6 +41,17 @@ def get_from_cache(db_name: str, cache_valid_seconds: int = 0, skip_time_verification: bool = False) -> Optional[Dict[str, Any]]: + """ + Retrieves the database from the cache if it is valid. + + Args: + db_name (str): The name of the database. + cache_valid_seconds (int): The validity period of the cache in seconds. + skip_time_verification (bool): Whether to skip time verification. + + Returns: + Optional[[Dict[str, Any]]: The cached database if available and valid, otherwise False. + """ cache_file_lock = f"{DB_CACHE_FILE}.lock" os.makedirs(os.path.dirname(cache_file_lock), exist_ok=True) lock = FileLock(cache_file_lock, timeout=10) @@ -76,7 +85,14 @@ def get_from_cache(db_name: str, cache_valid_seconds: int = 0, skip_time_verific return None -def write_to_cache(db_name, data): +def write_to_cache(db_name: str, data: Dict[str, Any]) -> None: + """ + Writes the database to the cache. + + Args: + db_name (str): The name of the database. + data (Dict[str, Any]): The database data to be cached. 
+ """ # cache is in: ~/safety/cache.json # and has the following form: # { @@ -122,8 +138,30 @@ def write_to_cache(db_name, data): LOG.debug('Safety updated the cache file for %s database.', db_name) -def fetch_database_url(session, mirror, db_name, cached, telemetry=True, - ecosystem: Ecosystem = Ecosystem.PYTHON, from_cache=True): +def fetch_database_url( + session: requests.Session, + mirror: str, + db_name: str, + cached: int, + telemetry: bool = True, + ecosystem: Ecosystem = Ecosystem.PYTHON, + from_cache: bool = True +) -> Dict[str, Any]: + """ + Fetches the database from a URL. + + Args: + session (requests.Session): The requests session. + mirror (str): The URL of the mirror. + db_name (str): The name of the database. + cached (int): The cache validity in seconds. + telemetry (bool): Whether to include telemetry data. + ecosystem (Ecosystem): The ecosystem. + from_cache (bool): Whether to fetch from cache. + + Returns: + Dict[str, Any]: The fetched database. + """ headers = {'schema-version': JSON_SCHEMA_VERSION, 'ecosystem': ecosystem.value} if cached and from_cache: @@ -169,7 +207,16 @@ def fetch_database_url(session, mirror, db_name, cached, telemetry=True, return data -def fetch_policy(session): +def fetch_policy(session: requests.Session) -> Dict[str, Any]: + """ + Fetches the policy from the server. + + Args: + session (requests.Session): The requests session. + + Returns: + Dict[str, Any]: The fetched policy. + """ url = f"{DATA_API_BASE_URL}policy/" try: @@ -183,7 +230,18 @@ def fetch_policy(session): return {"safety_policy": "", "audit_and_monitor": False} -def post_results(session, safety_json, policy_file): +def post_results(session: requests.Session, safety_json: str, policy_file: str) -> Dict[str, Any]: + """ + Posts the scan results to the server. + + Args: + session (requests.Session): The requests session. + safety_json (str): The scan results in JSON format. + policy_file (str): The policy file. 
+ + Returns: + Dict[str, Any]: The server response. + """ url = f"{DATA_API_BASE_URL}result/" # safety_json is in text form already. policy_file is a text YAML @@ -210,8 +268,19 @@ def post_results(session, safety_json, policy_file): return {} -def fetch_database_file(path: str, db_name: str, cached = 0, - ecosystem: Optional[Ecosystem] = None): +def fetch_database_file(path: str, db_name: str, cached: int = 0, ecosystem: Optional[Ecosystem] = None) -> Dict[str, Any]: + """ + Fetches the database from a local file. + + Args: + path (str): The path to the local file. + db_name (str): The name of the database. + cached (int): The cache validity in seconds. + ecosystem (Optional[Ecosystem]): The ecosystem. + + Returns: + Dict[str, Any]: The fetched database. + """ full_path = (Path(path) / (ecosystem.value if ecosystem else '') / db_name).expanduser().resolve() if not full_path.exists(): @@ -227,7 +296,16 @@ def fetch_database_file(path: str, db_name: str, cached = 0, return data -def is_valid_database(db) -> bool: +def is_valid_database(db: Dict[str, Any]) -> bool: + """ + Checks if the database is valid. + + Args: + db (Dict[str, Any]): The database. + + Returns: + bool: True if the database is valid, False otherwise. + """ try: if db['meta']['schema_version'] == JSON_SCHEMA_VERSION: return True @@ -237,9 +315,30 @@ def is_valid_database(db) -> bool: return False -def fetch_database(session, full=False, db=False, cached=0, telemetry=True, - ecosystem: Optional[Ecosystem] = None, from_cache=True): - +def fetch_database( + session: requests.Session, + full: bool = False, + db: Union[Optional[str], bool] = False, + cached: int = 0, + telemetry: bool = True, + ecosystem: Optional[Ecosystem] = None, + from_cache: bool = True +) -> Dict[str, Any]: + """ + Fetches the database from a mirror or a local file. + + Args: + session (requests.Session): The requests session. + full (bool): Whether to fetch the full database. 
+ db (Optional[str]): The path to the local database file. + cached (int): The cache validity in seconds. + telemetry (bool): Whether to include telemetry data. + ecosystem (Optional[Ecosystem]): The ecosystem. + from_cache (bool): Whether to fetch from cache. + + Returns: + Dict[str, Any]: The fetched database. + """ if session.is_using_auth_credentials(): mirrors = API_MIRRORS elif db: @@ -268,14 +367,52 @@ def fetch_database(session, full=False, db=False, cached=0, telemetry=True, raise DatabaseFetchError() -def get_vulnerabilities(pkg, spec, db): +def get_vulnerabilities(pkg: str, spec: str, db: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + """ + Retrieves vulnerabilities for a package from the database. + + Args: + pkg (str): The package name. + spec (str): The specifier set. + db (Dict[str, Any]): The database. + + Returns: + Iterator[Dict[str, Any]]: An iterator of vulnerabilities. + """ for entry in db['vulnerable_packages'][pkg]: for entry_spec in entry["specs"]: if entry_spec == spec: yield entry -def get_vulnerability_from(vuln_id, cve, data, specifier, db, name, pkg, ignore_vulns, affected): +def get_vulnerability_from( + vuln_id: str, + cve: Optional[CVE], + data: Dict[str, Any], + specifier: str, + db: Dict[str, Any], + name: str, + pkg: Package, + ignore_vulns: Dict[str, Any], + affected: SafetyRequirement +) -> Vulnerability: + """ + Constructs a Vulnerability object from the provided data. + + Args: + vuln_id (str): The vulnerability ID. + cve (Optional[CVE]): The CVE object. + data (Dict[str, Any]): The vulnerability data. + specifier (str): The specifier set. + db (Dict[str, Any]): The database. + name (str): The package name. + pkg (Package): The Package object. + ignore_vulns (Dict[str, Any]): The ignored vulnerabilities. + affected (SafetyRequirement): The affected requirement. + + Returns: + Vulnerability: The constructed Vulnerability object. 
+ """ base_domain = db.get('meta', {}).get('base_domain') unpinned_ignored = ignore_vulns.get(vuln_id, {}).get('requirements', None) should_ignore = not unpinned_ignored or str(affected.specifier) in unpinned_ignored @@ -319,7 +456,17 @@ def get_vulnerability_from(vuln_id, cve, data, specifier, db, name, pkg, ignore_ ) -def get_cve_from(data, db_full): +def get_cve_from(data: Dict[str, Any], db_full: Dict[str, Any]) -> Optional[CVE]: + """ + Retrieves the CVE object from the provided data. + + Args: + data (Dict[str, Any]): The vulnerability data. + db_full (Dict[str, Any]): The full database. + + Returns: + Optional[CVE]: The CVE object if found, otherwise None. + """ try: xve_id: str = str( next(filter(lambda i: i.get('type', None) in ['cve', 'pve'], data.get('ids', []))).get('id', '')) @@ -334,7 +481,25 @@ def get_cve_from(data, db_full): cvssv3=cve_meta.get("cvssv3", None)) -def ignore_vuln_if_needed(pkg: Package, vuln_id, cve, ignore_vulns, ignore_severity_rules, req): +def ignore_vuln_if_needed( + pkg: Package, + vuln_id: str, + cve: Optional[CVE], + ignore_vulns: Dict[str, Any], + ignore_severity_rules: Dict[str, Any], + req: SafetyRequirement +) -> None: + """ + Determines if a vulnerability should be ignored based on severity rules and updates the ignore_vulns dictionary. + + Args: + pkg (Package): The package. + vuln_id (str): The vulnerability ID. + cve (Optional[CVE]): The CVE object. + ignore_vulns (Dict[str, Any]): The ignored vulnerabilities. + ignore_severity_rules (Dict[str, Any]): The severity rules for ignoring vulnerabilities. + req (SafetyRequirement): The affected requirement. 
+ """ if not ignore_severity_rules: ignore_severity_rules = {} @@ -373,8 +538,18 @@ def ignore_vuln_if_needed(pkg: Package, vuln_id, cve, ignore_vulns, ignore_sever ignore_vulns[vuln_id] = {'reason': reason, 'expires': None, 'requirements': requirements} -def is_vulnerable(vulnerable_spec: SpecifierSet, requirement, package): +def is_vulnerable(vulnerable_spec: SpecifierSet, requirement: SafetyRequirement, package: Package) -> bool: + """ + Checks if a package version is vulnerable. + + Args: + vulnerable_spec (SpecifierSet): The specifier set for vulnerable versions. + requirement (SafetyRequirement): The package requirement. + package (Package): The package. + Returns: + bool: True if the package version is vulnerable, False otherwise. + """ if is_pinned_requirement(requirement.specifier): try: return vulnerable_spec.contains(next(iter(requirement.specifier)).version) @@ -392,8 +567,41 @@ def is_vulnerable(vulnerable_spec: SpecifierSet, requirement, package): @sync_safety_context -def check(*, session=None, packages=[], db_mirror=False, cached=0, ignore_vulns=None, ignore_severity_rules=None, proxy=None, - include_ignored=False, is_env_scan=True, telemetry=True, params=None, project=None): +def check( + *, + session: requests.Session, + packages: List[Package] = [], + db_mirror: Union[Optional[str], bool] = False, + cached: int = 0, + ignore_vulns: Optional[Dict[str, Any]] = None, + ignore_severity_rules: Optional[Dict[str, Any]] = None, + proxy: Optional[Dict[str, Any]] = None, + include_ignored: bool = False, + is_env_scan: bool = True, + telemetry: bool = True, + params: Optional[Dict[str, Any]] = None, + project: Optional[str] = None +) -> tuple: + """ + Performs a vulnerability check on the provided packages. + + Args: + session (requests.Session): The requests session. + packages (List[Package]): The list of packages to check. + db_mirror (Union[Optional[str], bool]): The database mirror. + cached (int): The cache validity in seconds. 
+ ignore_vulns (Optional[Dict[str, Any]]): The ignored vulnerabilities. + ignore_severity_rules (Optional[Dict[str, Any]]): The severity rules for ignoring vulnerabilities. + proxy (Optional[Dict[str, Any]]): The proxy settings. + include_ignored (bool): Whether to include ignored vulnerabilities. + is_env_scan (bool): Whether it is an environment scan. + telemetry (bool): Whether to include telemetry data. + params (Optional[Dict[str, Any]]): Additional parameters. + project (Optional[str]): The project name. + + Returns: + tuple: A tuple containing the list of vulnerabilities and the full database. + """ SafetyContext().command = 'check' db = fetch_database(session, db=db_mirror, cached=cached, telemetry=telemetry) db_full = None @@ -457,8 +665,21 @@ def check(*, session=None, packages=[], db_mirror=False, cached=0, ignore_vulns= return vulnerabilities, db_full -def precompute_remediations(remediations, packages, vulns, secure_vulns_by_user): +def precompute_remediations( + remediations: Dict[str, Dict[str, Any]], + packages: Dict[str, Package], + vulns: List[Vulnerability], + secure_vulns_by_user: set +) -> None: + """ + Precomputes the remediations for the given vulnerabilities. + Args: + remediations (Dict[str, Dict[str, Any]]): The remediations dictionary. + packages (Dict[str, Package]): The packages dictionary. + vulns (List[Vulnerability]): The list of vulnerabilities. + secure_vulns_by_user (set): The set of vulnerabilities secured by the user. 
+ """ for vuln in vulns: if vuln.ignored and vuln.ignored_reason != IGNORE_UNPINNED_REQ_REASON: @@ -489,8 +710,23 @@ def precompute_remediations(remediations, packages, vulns, secure_vulns_by_user) 'more_info_url': vuln.pkg.more_info_url} -def get_closest_ver(versions, version, spec: SpecifierSet): - results = {'upper': None, 'lower': None} +def get_closest_ver( + versions: List[str], + version: Optional[str], + spec: SpecifierSet +) -> Dict[str, Optional[Union[str, Version]]]: + """ + Retrieves the closest versions for the given version and specifier set. + + Args: + versions (List[str]): The list of versions. + version (Optional[str]): The current version. + spec (SpecifierSet): The specifier set. + + Returns: + Dict[str, Optional[Union[str, Version]]]: The closest versions. + """ + results: Dict[str, Optional[Union[str, Version]]] = {'upper': None, 'lower': None} if (not version and not spec) or not versions: return results @@ -529,7 +765,22 @@ def get_closest_ver(versions, version, spec: SpecifierSet): return results -def compute_sec_ver_for_user(package: Package, secure_vulns_by_user, db_full): +def compute_sec_ver_for_user( + package: Package, + secure_vulns_by_user: set, + db_full: Dict[str, Any] +) -> List[str]: + """ + Computes the secure versions for the user. + + Args: + package (Package): The package. + secure_vulns_by_user (set): The set of vulnerabilities secured by the user. + db_full (Dict[str, Any]): The full database. + + Returns: + List[str]: The list of secure versions. 
+ """ versions = package.get_versions(db_full) affected_versions = [] @@ -544,10 +795,22 @@ def compute_sec_ver_for_user(package: Package, secure_vulns_by_user, db_full): return sorted(sec_ver_for_user, key=lambda ver: parse_version(ver), reverse=True) -def compute_sec_ver(remediations, packages: Dict[str, Package], secure_vulns_by_user, db_full): +def compute_sec_ver( + remediations: Dict[str, Dict[str, Any]], + packages: Dict[str, Package], + secure_vulns_by_user: set, + db_full: Dict[str, Any] +) -> None: """ - Compute the secure_versions and the closest_secure_version for each remediation using the affected_versions - of each no ignored vulnerability of the same package, there is only a remediation for each package. + Computes the secure versions and the closest secure version for each remediation. + + Uses the affected_versions of each no ignored vulnerability of the same package, there is only a remediation for each package. + + Args: + remediations (Dict[str, Dict[str, Any]]): The remediations dictionary. + packages (Dict[str, Package]): The packages dictionary. + secure_vulns_by_user (set): The set of vulnerabilities secured by the user. + db_full (Dict[str, Any]): The full database. """ for pkg_name in remediations.keys(): pkg: Package = packages.get(pkg_name, None) @@ -597,7 +860,20 @@ def compute_sec_ver(remediations, packages: Dict[str, Package], secure_vulns_by_ target_version=recommended_version) -def calculate_remediations(vulns, db_full): +def calculate_remediations( + vulns: List[Vulnerability], + db_full: Dict[str, Any] +) -> Dict[str, Dict[str, Any]]: + """ + Calculates the remediations for the given vulnerabilities. + + Args: + vulns (List[Vulnerability]): The list of vulnerabilities. + db_full (Dict[str, Any]): The full database. + + Returns: + Dict[str, Dict[str, Any]]: The calculated remediations. 
+ """ remediations = defaultdict(dict) package_metadata = {} secure_vulns_by_user = set() @@ -611,7 +887,22 @@ def calculate_remediations(vulns, db_full): return remediations -def should_apply_auto_fix(from_ver: Optional[Version], to_ver, allowed_automatic): +def should_apply_auto_fix( + from_ver: Optional[Version], + to_ver: Version, + allowed_automatic: List[str] +) -> bool: + """ + Determines if an automatic fix should be applied. + + Args: + from_ver (Optional[Version]): The current version. + to_ver (Version): The target version. + allowed_automatic (List[str]): The allowed automatic update types. + + Returns: + bool: True if an automatic fix should be applied, False otherwise. + """ if not from_ver: return False @@ -636,8 +927,17 @@ def should_apply_auto_fix(from_ver: Optional[Version], to_ver, allowed_automatic return False -def get_update_type(from_ver: Optional[Version], to_ver: Version): +def get_update_type(from_ver: Optional[Version], to_ver: Version) -> str: + """ + Determines the update type. + + Args: + from_ver (Optional[Version]): The current version. + to_ver (Version): The target version. + Returns: + str: The update type. + """ if not from_ver or (to_ver.major - from_ver.major) != 0: return 'major' @@ -647,14 +947,56 @@ def get_update_type(from_ver: Optional[Version], to_ver: Version): return 'patch' -def process_fixes(files, remediations, auto_remediation_limit, output, no_output=True, prompt=False): +def process_fixes( + files: List[str], + remediations: Dict[str, Dict[str, Any]], + auto_remediation_limit: List[str], + output: str, + no_output: bool = True, + prompt: bool = False +) -> List[Fix]: + """ + Processes the fixes for the given files and remediations. + + Args: + files (List[str]): The list of files. + remediations (Dict[str, Dict[str, Any]]): The remediations dictionary. + auto_remediation_limit (List[str]): The automatic remediation limits. + output (str): The output format. + no_output (bool): Whether to suppress output. 
+ prompt (bool): Whether to prompt for confirmation. + + Returns: + List[Fix]: The list of applied fixes. + """ req_remediations = itertools.chain.from_iterable(rem.values() for pkg_name, rem in remediations.items()) requirements = compute_fixes_per_requirements(files, req_remediations, auto_remediation_limit, prompt=prompt) fixes = apply_fixes(requirements, output, no_output, prompt) return fixes -def process_fixes_scan(file_to_fix, to_fix_spec, auto_remediation_limit, output, no_output=True, prompt=False): +def process_fixes_scan( + file_to_fix: SafetyPolicyFile, + to_fix_spec: List[SafetyRequirement], + auto_remediation_limit: List[str], + output: str, + no_output: bool = True, + prompt: bool = False +) -> List[Fix]: + """ + Processes the fixes for the given file and specifications in scan mode. + + Args: + file_to_fix (SafetyPolicyFile): The file to fix. + to_fix_spec (List[SafetyRequirement]): The specifications to fix. + auto_remediation_limit (List[str]): The automatic remediation limits. + output (str): The output format. + no_output (bool): Whether to suppress output. + prompt (bool): Whether to prompt for confirmation. + + Returns: + List[Fix]: The list of applied fixes. + """ to_fix_remediations = [] def get_remmediation_from(spec): @@ -707,7 +1049,24 @@ def get_remmediation_from(spec): return fixes -def compute_fixes_per_requirements(files, req_remediations, auto_remediation_limit, prompt=False): +def compute_fixes_per_requirements( + files: List[str], + req_remediations: Iterator[Dict[str, Any]], + auto_remediation_limit: List[str], + prompt: bool = False +) -> Dict[str, Any]: + """ + Computes the fixes per requirements. + + Args: + files (List[str]): The list of files. + req_remediations (Iterator[Dict[str, Any]]): The remediations iterator. + auto_remediation_limit (List[str]): The automatic remediation limits. + prompt (bool): Whether to prompt for confirmation. + + Returns: + Dict[str, Any]: The computed requirements with fixes. 
+ """ requirements_files = get_requirements_content(files) from dparse.parser import parse, filetypes @@ -812,7 +1171,28 @@ def compute_fixes_per_requirements(files, req_remediations, auto_remediation_lim return requirements -def apply_fixes(requirements, out_type, no_output, prompt, scan_flow=False, auto_remediation_limit=None): +def apply_fixes( + requirements: Dict[str, Any], + out_type: str, + no_output: bool, + prompt: bool, + scan_flow: bool = False, + auto_remediation_limit: List[str] = None +) -> List[Fix]: + """ + Applies the fixes to the requirements. + + Args: + requirements (Dict[str, Any]): The requirements with fixes. + out_type (str): The output format. + no_output (bool): Whether to suppress output. + prompt (bool): Whether to prompt for confirmation. + scan_flow (bool): Whether it is in scan flow mode. + auto_remediation_limit (List[str]): The automatic remediation limits. + + Returns: + List[Fix]: The list of applied fixes. + """ from dparse.updater import RequirementsTXTUpdater @@ -932,7 +1312,20 @@ def apply_fixes(requirements, out_type, no_output, prompt, scan_flow=False, auto return skip + apply + confirm -def find_vulnerabilities_fixed(vulnerabilities: Dict, fixes) -> List[Vulnerability]: +def find_vulnerabilities_fixed( + vulnerabilities: Dict[str, Any], + fixes: List[Fix] +) -> List[Vulnerability]: + """ + Finds the vulnerabilities that have been fixed. + + Args: + vulnerabilities (Dict[str, Any]): The dictionary of vulnerabilities. + fixes (List[Fix]): The list of applied fixes. + + Returns: + List[Vulnerability]: The list of fixed vulnerabilities. 
+ """ fixed_specs = set(fix.previous_spec for fix in fixes) if not fixed_specs: @@ -943,7 +1336,21 @@ def find_vulnerabilities_fixed(vulnerabilities: Dict, fixes) -> List[Vulnerabili @sync_safety_context -def review(*, report=None, params=None): +def review( + *, + report: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None +) -> tuple: + """ + Reviews the report and returns the vulnerabilities and remediations. + + Args: + report (Optional[Dict[str, Any]]): The report. + params (Optional[Dict[str, Any]]): Additional parameters. + + Returns: + tuple: A tuple containing the list of vulnerabilities, the remediations, and the found packages. + """ SafetyContext().command = 'review' vulnerable = [] vulnerabilities = report.get('vulnerabilities', []) + report.get('ignored_vulnerabilities', []) @@ -1009,8 +1416,25 @@ def review(*, report=None, params=None): @sync_safety_context -def get_licenses(*, session=None, db_mirror=False, cached=0, telemetry=True): +def get_licenses( + *, + session: requests.Session, + db_mirror: Union[Optional[str], bool] = False, + cached: int = 0, + telemetry: bool = True +) -> Dict[str, Any]: + """ + Retrieves the licenses from the database. + + Args: + session (requests.Session): The requests session. + db_mirror (Union[Optional[str], bool]): The database mirror. + cached (int): The cache validity in seconds. + telemetry (bool): Whether to include telemetry data. + Returns: + Dict[str, Any]: The licenses dictionary. + """ if db_mirror: mirrors = [db_mirror] else: @@ -1030,10 +1454,22 @@ def get_licenses(*, session=None, db_mirror=False, cached=0, telemetry=True): raise DatabaseFetchError() -def add_local_notifications(packages: List[Package], - ignore_unpinned_requirements: Optional[bool]) -> List[Dict[str, str]]: +def add_local_notifications( + packages: List[Package], + ignore_unpinned_requirements: Optional[bool] +) -> List[Dict[str, str]]: + """ + Adds local notifications for unpinned packages. 
+ + Args: + packages (List[Package]): The list of packages. + ignore_unpinned_requirements (Optional[bool]): Whether to ignore unpinned requirements. + + Returns: + List[Dict[str, str]]: The list of notifications. + """ announcements = [] - unpinned_packages: [str] = [f"{pkg.name}" for pkg in packages if pkg.has_unpinned_req()] + unpinned_packages: List[str] = [f"{pkg.name}" for pkg in packages if pkg.has_unpinned_req()] if unpinned_packages and ignore_unpinned_requirements is not False: found = len(unpinned_packages) @@ -1060,7 +1496,22 @@ def add_local_notifications(packages: List[Package], return announcements -def get_announcements(session, telemetry=True, with_telemetry=None): +def get_announcements( + session: requests.Session, + telemetry: bool = True, + with_telemetry: Any = None +) -> List[Dict[str, str]]: + """ + Retrieves announcements from the server. + + Args: + session (requests.Session): The requests session. + telemetry (bool): Whether to include telemetry data. + with_telemetry (Optional[Dict[str, Any]]): The telemetry data. + + Returns: + List[Dict[str, str]]: The list of announcements. + """ LOG.info('Getting announcements') announcements = [] @@ -1112,8 +1563,18 @@ def get_announcements(session, telemetry=True, with_telemetry=None): return announcements -def get_packages(files=False, stdin=False): +def get_packages(files: Optional[List[str]] = None, stdin: bool = False) -> List[Package]: + """ + Retrieves the packages from the given files or standard input. + + Args: + files (Optional[List[str]]): The list of files. + stdin (bool): Whether to read from standard input. + + Returns: + List[Package]: The list of packages. 
+ """ if files: return list(itertools.chain.from_iterable(read_requirements(f, resolve=True) for f in files)) @@ -1122,7 +1583,7 @@ def get_packages(files=False, stdin=False): import pkg_resources - def allowed_version(pkg: str, version: str): + def allowed_version(pkg: str, version: str) -> bool: try: parse_version(version) except Exception: @@ -1148,7 +1609,16 @@ def allowed_version(pkg: str, version: str): ] -def read_vulnerabilities(fh): +def read_vulnerabilities(fh: Any) -> Dict[str, Any]: + """ + Reads vulnerabilities from a file handle. + + Args: + fh (Any): The file handle. + + Returns: + Dict[str, Any]: The vulnerabilities data. + """ try: data = json.load(fh) except json.JSONDecodeError as e: @@ -1159,7 +1629,22 @@ def read_vulnerabilities(fh): return data -def get_server_policies(session, policy_file, proxy_dictionary: Dict): +def get_server_policies( + session: requests.Session, + policy_file: SafetyPolicyFile, + proxy_dictionary: Dict[str, str] +) -> tuple: + """ + Retrieves the server policies. + + Args: + session (requests.Session): The requests session. + policy_file (SafetyPolicyFile): The policy file. + proxy_dictionary (Dict[str, str]): The proxy dictionary. + + Returns: + tuple: A tuple containing the policy file and the audit and monitor flag. + """ if session.api_key: server_policies = fetch_policy(session) server_audit_and_monitor = server_policies["audit_and_monitor"] @@ -1186,7 +1671,19 @@ def get_server_policies(session, policy_file, proxy_dictionary: Dict): return policy_file, server_audit_and_monitor -def save_report(path: str, default_name: str, report: str): +def save_report( + path: str, + default_name: str, + report: str +) -> None: + """ + Saves the report to a file. + + Args: + path (str): The path to save the report. + default_name (str): The default name of the report file. + report (str): The report content. 
+ """ if path: save_at = path diff --git a/safety/scan/command.py b/safety/scan/command.py index 433e8742..f272868b 100644 --- a/safety/scan/command.py +++ b/safety/scan/command.py @@ -1,5 +1,3 @@ - -from datetime import datetime from enum import Enum import itertools import logging @@ -45,16 +43,32 @@ class ScannableEcosystems(Enum): + """Enum representing scannable ecosystems.""" PYTHON = Ecosystem.PYTHON.value -def process_report(obj: Any, console: Console, report: ReportModel, output: str, - save_as: Optional[Tuple[str, Path]], **kwargs): - +def process_report( + obj: Any, console: Console, report: ReportModel, output: str, + save_as: Optional[Tuple[str, Path]], **kwargs +) -> Optional[str]: + """ + Processes and outputs the report based on the given parameters. + + Args: + obj (Any): The context object. + console (Console): The console object. + report (ReportModel): The report model. + output (str): The output format. + save_as (Optional[Tuple[str, Path]]): The save-as format and path. + kwargs: Additional keyword arguments. + + Returns: + Optional[str]: The URL of the report if uploaded, otherwise None. 
+ """ wait_msg = "Processing report" with console.status(wait_msg, spinner=DEFAULT_SPINNER) as status: json_format = report.as_v30().json() - + export_type, export_path = None, None if save_as: @@ -74,12 +88,12 @@ def process_report(obj: Any, console: Console, report: ReportModel, output: str, spdx_version = None if export_type: spdx_version = export_type.version if export_type.version and ScanExport.is_format(export_type, ScanExport.SPDX) else None - + if not spdx_version and output: spdx_version = output.version if output.version and ScanOutput.is_format(output, ScanOutput.SPDX) else None spdx_format = render_scan_spdx(report, obj, spdx_version=spdx_version) - + if export_type is ScanExport.HTML or output is ScanOutput.HTML: html_format = render_scan_html(report, obj) @@ -89,7 +103,7 @@ def process_report(obj: Any, console: Console, report: ReportModel, output: str, ScanExport.SPDX: spdx_format, ScanExport.SPDX_2_3: spdx_format, ScanExport.SPDX_2_2: spdx_format, - } + } output_format_mapping = { ScanOutput.JSON: json_format, @@ -106,7 +120,7 @@ def process_report(obj: Any, console: Console, report: ReportModel, output: str, msg = f"Saving {export_type} report at: {export_path}" status.update(msg) LOG.debug(msg) - save_report_as(report.metadata.scan_type, export_type, Path(export_path), + save_report_as(report.metadata.scan_type, export_type, Path(export_path), report_to_export) report_url = None @@ -131,7 +145,7 @@ def process_report(obj: Any, console: Console, report: ReportModel, output: str, f"[link]{project_url}[/link]") elif report.metadata.scan_type is ScanType.system_scan: lines.append(f"System scan report: [link]{report_url}[/link]") - + for line in lines: console.print(line, emoji=True) @@ -142,25 +156,30 @@ def process_report(obj: Any, console: Console, report: ReportModel, output: str, if output is ScanOutput.JSON: kwargs = {"json": report_to_output} else: - kwargs = {"data": report_to_output} + kwargs = {"data": report_to_output} 
console.print_json(**kwargs) else: console.print(report_to_output) console.quiet = True - + return report_url -def generate_updates_arguments() -> list: - """Generates a list of file types and update limits for apply fixes.""" +def generate_updates_arguments() -> List: + """ + Generates a list of file types and update limits for apply fixes. + + Returns: + List: A list of file types and update limits. + """ fixes = [] limit_type = SecurityUpdates.UpdateLevel.PATCH - DEFAULT_FILE_TYPES = [FileType.REQUIREMENTS_TXT, FileType.PIPENV_LOCK, + DEFAULT_FILE_TYPES = [FileType.REQUIREMENTS_TXT, FileType.PIPENV_LOCK, FileType.POETRY_LOCK, FileType.VIRTUAL_ENVIRONMENT] fixes.extend([(default_file_type, limit_type) for default_file_type in DEFAULT_FILE_TYPES]) - + return fixes @@ -197,7 +216,7 @@ def scan(ctx: typer.Context, ] = ScanOutput.SCREEN, detailed_output: Annotated[bool, typer.Option("--detailed-output", - help=SCAN_DETAILED_OUTPUT, + help=SCAN_DETAILED_OUTPUT, show_default=False) ] = False, save_as: Annotated[Optional[Tuple[ScanExport, Path]], @@ -221,7 +240,7 @@ def scan(ctx: typer.Context, )] = None, apply_updates: Annotated[bool, typer.Option("--apply-fixes", - help=SCAN_APPLY_FIXES, + help=SCAN_APPLY_FIXES, show_default=False) ] = False ): @@ -229,10 +248,12 @@ def scan(ctx: typer.Context, Scans a project (defaulted to the current directory) for supply-chain security and configuration issues """ + # Generate update arguments if apply updates option is enabled fixes_target = [] if apply_updates: fixes_target = generate_updates_arguments() + # Ensure save_as params are correctly set if not all(save_as): ctx.params["save_as"] = None @@ -240,19 +261,21 @@ def scan(ctx: typer.Context, ecosystems = [Ecosystem(member.value) for member in list(ScannableEcosystems)] to_include = {file_type: paths for file_type, paths in ctx.obj.config.scan.include_files.items() if file_type.ecosystem in ecosystems} - file_finder = FileFinder(target=target, ecosystems=ecosystems, + # 
Initialize file finder + file_finder = FileFinder(target=target, ecosystems=ecosystems, max_level=ctx.obj.config.scan.max_depth, - exclude=ctx.obj.config.scan.ignore, + exclude=ctx.obj.config.scan.ignore, include_files=to_include, console=console) + # Download necessary assets for each handler for handler in file_finder.handlers: if handler.ecosystem: wait_msg = "Fetching Safety's vulnerability database..." with console.status(wait_msg, spinner=DEFAULT_SPINNER): handler.download_required_assets(ctx.obj.auth.client) - + # Start scanning the project directory wait_msg = "Scanning project directory" path = None @@ -260,7 +283,7 @@ def scan(ctx: typer.Context, with console.status(wait_msg, spinner=DEFAULT_SPINNER): path, file_paths = file_finder.search() - print_detected_ecosystems_section(console, file_paths, + print_detected_ecosystems_section(console, file_paths, include_safety_prjs=True) target_ecosystems = ", ".join([member.value for member in ecosystems]) @@ -274,7 +297,7 @@ def scan(ctx: typer.Context, count = 0 ignored = set() - + affected_count = 0 dependency_vuln_detected = False @@ -287,18 +310,21 @@ def scan(ctx: typer.Context, requirements_txt_found = False display_apply_fix_suggestion = False + # Process each file for dependencies and vulnerabilities with console.status(wait_msg, spinner=DEFAULT_SPINNER) as status: - for path, analyzed_file in process_files(paths=file_paths, + for path, analyzed_file in process_files(paths=file_paths, config=config): count += len(analyzed_file.dependency_results.dependencies) + # Update exit code if vulnerabilities are found if exit_code == 0 and analyzed_file.dependency_results.failed: exit_code = EXIT_CODE_VULNERABILITIES_FOUND + # Handle ignored vulnerabilities for detailed output if detailed_output: vulns_ignored = analyzed_file.dependency_results.ignored_vulns_data \ .values() - ignored_vulns_data = itertools.chain(vulns_ignored, + ignored_vulns_data = itertools.chain(vulns_ignored, ignored_vulns_data) 
ignored.update(analyzed_file.dependency_results.ignored_vulns.keys()) @@ -309,7 +335,7 @@ def scan(ctx: typer.Context, def sort_vulns_by_score(vuln: Vulnerability) -> int: if vuln.severity and vuln.severity.cvssv3: return vuln.severity.cvssv3.get("base_score", 0) - + return 0 to_fix_spec = [] @@ -327,10 +353,10 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: for spec in affected_specifications: if file_matched_for_fix: to_fix_spec.append(spec) - + console.print() vulns_to_report = sorted( - [vuln for vuln in spec.vulnerabilities if not vuln.ignored], + [vuln for vuln in spec.vulnerabilities if not vuln.ignored], key=sort_vulns_by_score, reverse=True) @@ -346,14 +372,14 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: console.print(Padding(f"{msg}]", (0, 0, 0, 1)), emoji=True, overflow="crop") - + if detailed_output or vulns_found < 3: for vuln in vulns_to_report: - render_to_console(vuln, console, - rich_kwargs={"emoji": True, + render_to_console(vuln, console, + rich_kwargs={"emoji": True, "overflow": "crop"}, detailed_output=detailed_output) - + lines = [] # Put remediation here @@ -381,16 +407,16 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: console.print(Padding(line, (0, 0, 0, 1)), emoji=True) console.print( - Padding(f"Learn more: [link]{spec.remediation.more_info_url}[/link]", - (0, 0, 0, 1)), emoji=True) + Padding(f"Learn more: [link]{spec.remediation.more_info_url}[/link]", + (0, 0, 0, 1)), emoji=True) else: console.print() console.print(f":white_check_mark: [file_title]{path.relative_to(target)}: No issues found.[/file_title]", emoji=True) if(ctx.obj.auth.stage == Stage.development - and analyzed_file.ecosystem == Ecosystem.PYTHON - and analyzed_file.file_type == FileType.REQUIREMENTS_TXT + and analyzed_file.ecosystem == Ecosystem.PYTHON + and analyzed_file.file_type == FileType.REQUIREMENTS_TXT and any(affected_specifications) and not apply_updates): display_apply_fix_suggestion = True @@ -405,12 +431,12 @@ def 
sort_vulns_by_score(vuln: Vulnerability) -> int: if file_matched_for_fix: to_fix_files.append((file, to_fix_spec)) - files.append(file) + files.append(file) if display_apply_fix_suggestion: console.print() print_fixes_section(console, requirements_txt_found, detailed_output) - + console.print() print_brief(console, ctx.obj.project, count, affected_count, fixes_count) @@ -418,18 +444,18 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: is_detailed_output=detailed_output, ignored_vulns_data=ignored_vulns_data) - + version = ctx.obj.schema metadata = ctx.obj.metadata telemetry = ctx.obj.telemetry ctx.obj.project.files = files report = ReportModel(version=version, - metadata=metadata, + metadata=metadata, telemetry=telemetry, files=[], projects=[ctx.obj.project]) - + report_url = process_report(ctx.obj, console, report, **{**ctx.params}) project_url = f"{SAFETY_PLATFORM_URL}{ctx.obj.project.url_path}" @@ -440,7 +466,7 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: no_output = output is not ScanOutput.SCREEN prompt = output is ScanOutput.SCREEN - + # TODO: rename that 'no_output' confusing name if not no_output: console.print() @@ -462,11 +488,11 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: if any(policy_limits): update_limits = [policy_limit.value for policy_limit in policy_limits] - - fixes = process_fixes_scan(file_to_fix, + + fixes = process_fixes_scan(file_to_fix, specs_to_fix, update_limits, output, no_output=no_output, prompt=prompt) - + if not no_output: console.print("-" * console.size.width) @@ -484,7 +510,7 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: @scan_system_app.command( cls=SafetyCLICommand, help=CLI_SYSTEM_SCAN_COMMAND_HELP, - options_metavar="[COMMAND-OPTIONS]", + options_metavar="[COMMAND-OPTIONS]", name=CMD_SYSTEM_NAME, epilog=DEFAULT_EPILOG) @handle_cmd_exception @inject_metadata @@ -521,7 +547,7 @@ def system_scan(ctx: typer.Context, typer.Option( help=SYSTEM_SCAN_OUTPUT_HELP, show_default=False) - ] = 
SystemScanOutput.SCREEN, + ] = SystemScanOutput.SCREEN, save_as: Annotated[Optional[Tuple[SystemScanExport, Path]], typer.Option( help=SYSTEM_SCAN_SAVE_AS_HELP, @@ -575,9 +601,9 @@ def system_scan(ctx: typer.Context, for file_type, paths in target_paths.items(): current = file_paths.get(file_type, set()) current.update(paths) - file_paths[file_type] = current + file_paths[file_type] = current - scan_project_command = get_command_for(name=CMD_PROJECT_NAME, + scan_project_command = get_command_for(name=CMD_PROJECT_NAME, typer_instance=scan_project_app) projects_dirs = set() @@ -587,12 +613,12 @@ def system_scan(ctx: typer.Context, with console.status(":mag:", spinner=DEFAULT_SPINNER) as status: # Handle projects first if FileType.SAFETY_PROJECT.value in file_paths.keys(): - projects_file_paths = file_paths[FileType.SAFETY_PROJECT.value] + projects_file_paths = file_paths[FileType.SAFETY_PROJECT.value] basic_params = ctx.params.copy() basic_params.pop("targets", None) prjs_console = Console(quiet=True) - + for project_path in projects_file_paths: projects_dirs.add(project_path.parent) project_dir = str(project_path.parent) @@ -607,7 +633,7 @@ def system_scan(ctx: typer.Context, if not project or not project.id: LOG.warn(f"{project_path} parsed but project id is not defined or valid.") continue - + if not ctx.obj.platform_enabled: msg = f"project found and skipped, navigate to `{project.project_path}` and scan this project with ‘safety scan’" console.print(f"{project.id}: {msg}") @@ -615,8 +641,8 @@ def system_scan(ctx: typer.Context, msg = f"Existing project found at {project_dir}" console.print(f"{project.id}: {msg}") - project_data[project.id] = {"path": project_dir, - "report_url": None, + project_data[project.id] = {"path": project_dir, + "report_url": None, "project_url": None, "failed_exception": None} @@ -642,7 +668,7 @@ def system_scan(ctx: typer.Context, "save_as": (None, None), "upload_request_id": upload_request_id, "local_policy": local_policy_file, 
"console": prjs_console} try: - # TODO: Refactor to avoid calling invoke, also, launch + # TODO: Refactor to avoid calling invoke, also, launch # this on background. console.print( Padding(f"Running safety scan for {project.id} project", @@ -660,7 +686,7 @@ def system_scan(ctx: typer.Context, (0, 0, 0, 1)), emoji=True) LOG.exception(f"Failed to run scan on project {project.id}, " \ f"Upload request ID: {upload_request_id}. Reason {e}") - + console.print() file_paths.pop(FileType.SAFETY_PROJECT.value, None) @@ -670,18 +696,18 @@ def system_scan(ctx: typer.Context, status.update(":mag: Finishing projects processing.") for k, f_paths in file_paths.items(): - file_paths[k] = {fp for fp in f_paths - if not should_exclude(excludes=projects_dirs, + file_paths[k] = {fp for fp in f_paths + if not should_exclude(excludes=projects_dirs, to_analyze=fp)} - + pkgs_count = 0 file_count = 0 venv_count = 0 for path, analyzed_file in process_files(paths=file_paths, config=config): status.update(f":mag: {path}") - files.append(FileModel(location=path, - file_type=analyzed_file.file_type, + files.append(FileModel(location=path, + file_type=analyzed_file.file_type, results=analyzed_file.dependency_results)) file_pkg_count = len(analyzed_file.dependency_results.dependencies) @@ -718,7 +744,7 @@ def system_scan(ctx: typer.Context, pkgs_count += file_pkg_count console.print(f":package: {file_pkg_count} {msg} in {path}", emoji=True) - + if affected_pkgs_count <= 0: msg = "No vulnerabilities found" else: @@ -738,7 +764,7 @@ def system_scan(ctx: typer.Context, telemetry=telemetry, files=files, projects=projects) - + console.print() total_count = sum([finder.file_count for finder in file_finders], 0) console.print(f"Searched {total_count:,} files for dependency security issues") @@ -749,16 +775,16 @@ def system_scan(ctx: typer.Context, console.print() proccessed = dict(filter( - lambda item: item[1]["report_url"] and item[1]["project_url"], + lambda item: item[1]["report_url"] and 
item[1]["project_url"], project_data.items())) - + if proccessed: run_word = "runs" if len(proccessed) == 1 else "run" console.print(f"Project {pluralize('scan', len(proccessed))} {run_word} on {len(proccessed)} existing {pluralize('project', len(proccessed))}:") for prj, data in proccessed.items(): console.print(f"[bold]{prj}[/bold] at {data['path']}") - for detail in [f"{prj} dashboard: {data['project_url']}"]: + for detail in [f"{prj} dashboard: {data['project_url']}"]: console.print(Padding(detail, (0, 0, 0, 1)), emoji=True, overflow="crop") process_report(ctx.obj, console, report, **{**ctx.params}) diff --git a/safety/scan/decorators.py b/safety/scan/decorators.py index 2f41a7c8..29c9e7c9 100644 --- a/safety/scan/decorators.py +++ b/safety/scan/decorators.py @@ -4,7 +4,7 @@ from pathlib import Path from random import randint import sys -from typing import List, Optional +from typing import Any, List, Optional from rich.padding import Padding from safety_schemas.models import ConfigModel, ProjectModel @@ -29,7 +29,10 @@ LOG = logging.getLogger(__name__) -def initialize_scan(ctx, console): +def initialize_scan(ctx: Any, console: Console) -> None: + """ + Initializes the scan by setting platform_enabled based on the response from the server. + """ data = None try: @@ -48,7 +51,7 @@ def initialize_scan(ctx, console): def scan_project_command_init(func): """ - Make general verifications before each scan command. + Decorator to make general verifications before each project scan command. 
""" @wraps(func) def inner(ctx, policy_file_path: Optional[Path], target: Path, @@ -62,7 +65,7 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, console.quiet = True if not ctx.obj.auth.is_valid(): - process_auth_status_not_ready(console=console, + process_auth_status_not_ready(console=console, auth=ctx.obj.auth, ctx=ctx) upload_request_id = kwargs.pop("upload_request_id", None) @@ -109,12 +112,12 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, cloud_policy = None if ctx.obj.platform_enabled: - cloud_policy = print_wait_policy_download(console, (download_policy, - {"session": session, + cloud_policy = print_wait_policy_download(console, (download_policy, + {"session": session, "project_id": ctx.obj.project.id, "stage": stage, "branch": branch})) - + ctx.obj.project.policy = resolve_policy(local_policy, cloud_policy) config = ctx.obj.project.policy.config \ if ctx.obj.project.policy and ctx.obj.project.policy.config \ @@ -145,10 +148,10 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, details = {"Account": f"{content} {render_email_note(ctx.obj.auth)}"} else: details = {"Account": f"Offline - {os.getenv('SAFETY_DB_DIR')}"} - + if ctx.obj.project.id: details["Project"] = ctx.obj.project.id - + if ctx.obj.project.git: details[" Git branch"] = ctx.obj.project.git.branch @@ -156,7 +159,7 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, msg = "None, using Safety CLI default policies" - if ctx.obj.project.policy: + if ctx.obj.project.policy: if ctx.obj.project.policy.source is PolicySource.cloud: msg = f"fetched from Safety Platform, " \ "ignoring any local Safety CLI policy files" @@ -170,7 +173,7 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, for k,v in details.items(): console.print(f"[scan_meta_title]{k}[/scan_meta_title]: {v}") - + print_announcements(console=console, ctx=ctx) console.print() @@ -185,10 +188,10 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, def 
scan_system_command_init(func): """ - Make general verifications before each system scan command. + Decorator to make general verifications before each system scan command. """ @wraps(func) - def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], + def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], output: SystemScanOutput, console: Console = main_console, *args, **kwargs): ctx.obj.console = console @@ -198,8 +201,8 @@ def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], console.quiet = True if not ctx.obj.auth.is_valid(): - process_auth_status_not_ready(console=console, - auth=ctx.obj.auth, ctx=ctx) + process_auth_status_not_ready(console=console, + auth=ctx.obj.auth, ctx=ctx) initialize_scan(ctx, console) @@ -229,12 +232,12 @@ def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], ctx.obj.config = config - if not any(targets): + if not any(targets): if any(config.scan.system_targets): targets = [Path(t).expanduser().absolute() for t in config.scan.system_targets] else: targets = [Path("/")] - + ctx.obj.metadata.scan_locations = targets console.print() @@ -244,8 +247,8 @@ def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], details = {"Account": f"{ctx.obj.auth.name}, {ctx.obj.auth.email}", "Scan stage": ctx.obj.auth.stage} - - if ctx.obj.system_scan_policy: + + if ctx.obj.system_scan_policy: if ctx.obj.system_scan_policy.source is PolicySource.cloud: policy_type = "remote" else: @@ -259,9 +262,9 @@ def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], for k,v in details.items(): console.print(f"[bold]{k}[/bold]: {v}") - + if ctx.obj.system_scan_policy: - + dirs = [ign for ign in ctx.obj.config.scan.ignore if Path(ign).is_dir()] policy_details = [ @@ -273,17 +276,17 @@ def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], console.print( Padding(policy_detail, (0, 0, 0, 1)), emoji=True) - + print_announcements(console=console, ctx=ctx) console.print() 
- + kwargs.update({"targets": targets}) result = func(ctx, *args, **kwargs) return result - return inner - + return inner + def inject_metadata(func): """ @@ -304,7 +307,7 @@ def inner(ctx, *args, **kwargs): if not scan_type: raise SafetyException("Missing scan_type.") - + if scan_type is ScanType.scan: if not target: raise SafetyException("Missing target.") @@ -319,7 +322,7 @@ def inner(ctx, *args, **kwargs): telemetry=telemetry, schema_version=ReportSchemaVersion.v3_0 ) - + ctx.obj.schema = ReportSchemaVersion.v3_0 ctx.obj.metadata = metadata ctx.obj.telemetry = telemetry diff --git a/safety/scan/ecosystems/base.py b/safety/scan/ecosystems/base.py index 2d4e4820..4458040f 100644 --- a/safety/scan/ecosystems/base.py +++ b/safety/scan/ecosystems/base.py @@ -1,35 +1,58 @@ from abc import ABC, abstractmethod from typing import List -from safety_schemas.models import Ecosystem, FileType, ConfigModel, \ - DependencyResultModel +from safety_schemas.models import Ecosystem, FileType, ConfigModel, DependencyResultModel from typer import FileTextWrite NOT_IMPLEMENTED = "Not implemented funtion" class Inspectable(ABC): + """ + Abstract base class defining the interface for objects that can be inspected for dependencies. + """ @abstractmethod def inspect(self, config: ConfigModel) -> DependencyResultModel: + """ + Inspects the object and returns the result of the dependency analysis. + + Args: + config (ConfigModel): The configuration model for inspection. + + Returns: + DependencyResultModel: The result of the dependency inspection. + """ return NotImplementedError(NOT_IMPLEMENTED) - + class Remediable(ABC): + """ + Abstract base class defining the interface for objects that can be remediated. + """ @abstractmethod def remediate(self): + """ + Remediates the object to address any detected issues. + """ return NotImplementedError(NOT_IMPLEMENTED) - + class InspectableFile(Inspectable): - + """ + Represents an inspectable file within a specific ecosystem and file type. 
+ """ + def __init__(self, file: FileTextWrite): + """ + Initializes an InspectableFile instance. + + Args: + file (FileTextWrite): The file to be inspected. + """ self.file = file self.ecosystem: Ecosystem self.file_type: FileType self.dependency_results: DependencyResultModel = \ DependencyResultModel(dependencies=[]) - - - diff --git a/safety/scan/ecosystems/python/dependencies.py b/safety/scan/ecosystems/python/dependencies.py index 51dfccf5..9be67c67 100644 --- a/safety/scan/ecosystems/python/dependencies.py +++ b/safety/scan/ecosystems/python/dependencies.py @@ -13,7 +13,18 @@ from packaging.utils import canonicalize_name -def get_closest_ver(versions, version, spec: SpecifierSet): +def get_closest_ver(versions: List[str], version: Optional[str], spec: SpecifierSet) -> dict: + """ + Gets the closest version to the specified version within a list of versions. + + Args: + versions (List[str]): The list of versions. + version (Optional[str]): The target version. + spec (SpecifierSet): The version specifier set. + + Returns: + dict: A dictionary containing the upper and lower closest versions. + """ results = {'upper': None, 'lower': None} if (not version and not spec) or not versions: @@ -54,6 +65,15 @@ def get_closest_ver(versions, version, spec: SpecifierSet): def is_pinned_requirement(spec: SpecifierSet) -> bool: + """ + Checks if a requirement is pinned. + + Args: + spec (SpecifierSet): The version specifier set. + + Returns: + bool: True if the requirement is pinned, False otherwise. + """ if not spec or len(spec) != 1: return False @@ -63,7 +83,16 @@ def is_pinned_requirement(spec: SpecifierSet) -> bool: or specifier.operator == '===' -def find_version(requirements): +def find_version(requirements: List[PythonSpecification]) -> Optional[str]: + """ + Finds the version of a requirement. + + Args: + requirements (List[PythonSpecification]): The list of requirements. + + Returns: + Optional[str]: The version if found, otherwise None. 
+ """ ver = None if len(requirements) != 1: @@ -77,13 +106,32 @@ def find_version(requirements): return ver -def is_supported_by_parser(path): +def is_supported_by_parser(path: str) -> bool: + """ + Checks if the file path is supported by the parser. + + Args: + path (str): The file path. + + Returns: + bool: True if supported, False otherwise. + """ supported_types = (".txt", ".in", ".yml", ".ini", "Pipfile", "Pipfile.lock", "setup.cfg", "poetry.lock") return path.endswith(supported_types) -def parse_requirement(dep, found: Optional[str]) -> PythonSpecification: +def parse_requirement(dep: str, found: Optional[str]) -> PythonSpecification: + """ + Parses a requirement and creates a PythonSpecification object. + + Args: + dep (str): The dependency string. + found (Optional[str]): The found path. + + Returns: + PythonSpecification: The parsed requirement. + """ req = PythonSpecification(dep) req.found = Path(found).resolve() if found else None @@ -93,12 +141,16 @@ def parse_requirement(dep, found: Optional[str]) -> PythonSpecification: return req -def read_requirements(fh, resolve=True): +def read_requirements(fh, resolve: bool = True) -> Generator[PythonDependency, None, None]: """ - Reads requirements from a file like object and (optionally) from referenced files. - :param fh: file like object to read from - :param resolve: boolean. resolves referenced files. - :return: generator + Reads requirements from a file-like object and (optionally) from referenced files. + + Args: + fh: The file-like object to read from. + resolve (bool): Whether to resolve referenced files. + + Returns: + Generator[PythonDependency, None, None]: A generator of PythonDependency objects. 
""" is_temp_file = not hasattr(fh, 'name') path = None @@ -136,7 +188,17 @@ def read_requirements(fh, resolve=True): more_info_url=None) -def read_dependencies(fh, resolve=True): +def read_dependencies(fh, resolve: bool = True) -> Generator[PythonDependency, None, None]: + """ + Reads dependencies from a file-like object. + + Args: + fh: The file-like object to read from. + resolve (bool): Whether to resolve referenced files. + + Returns: + Generator[PythonDependency, None, None]: A generator of PythonDependency objects. + """ path = fh.name absolute_path = Path(path).resolve() found = absolute_path @@ -163,8 +225,16 @@ def read_dependencies(fh, resolve=True): latest_version_without_known_vulnerabilities=None, more_info_url=None) -def read_virtual_environment_dependencies(f: InspectableFile) \ - -> Generator[PythonDependency, None, None]: +def read_virtual_environment_dependencies(f: InspectableFile) -> Generator[PythonDependency, None, None]: + """ + Reads dependencies from a virtual environment. + + Args: + f (InspectableFile): The inspectable file representing the virtual environment. + + Returns: + Generator[PythonDependency, None, None]: A generator of PythonDependency objects. 
+ """ env_path = Path(f.file.name).resolve().parent @@ -181,7 +251,7 @@ def read_virtual_environment_dependencies(f: InspectableFile) \ if not site_pkgs_path.resolve().exists(): # Unable to find packages for foo env return - + dep_paths = site_pkgs_path.glob("*/METADATA") for path in dep_paths: @@ -193,22 +263,31 @@ def read_virtual_environment_dependencies(f: InspectableFile) \ yield PythonDependency(name=dep_name, version=dep_version, specifications=[ - PythonSpecification(f"{dep_name}=={dep_version}", - found=site_pkgs_path)], + PythonSpecification(f"{dep_name}=={dep_version}", + found=site_pkgs_path)], found=site_pkgs_path, insecure_versions=[], - secure_versions=[], latest_version=None, + secure_versions=[], latest_version=None, latest_version_without_known_vulnerabilities=None, more_info_url=None) def get_dependencies(f: InspectableFile) -> List[PythonDependency]: + """ + Gets the dependencies for the given inspectable file. + + Args: + f (InspectableFile): The inspectable file. + + Returns: + List[PythonDependency]: A list of PythonDependency objects. 
+ """ if not f.file_type: return [] - - if f.file_type in [FileType.REQUIREMENTS_TXT, FileType.POETRY_LOCK, + + if f.file_type in [FileType.REQUIREMENTS_TXT, FileType.POETRY_LOCK, FileType.PIPENV_LOCK]: return list(read_dependencies(f.file, resolve=True)) - + if f.file_type == FileType.VIRTUAL_ENVIRONMENT: return list(read_virtual_environment_dependencies(f)) diff --git a/safety/scan/ecosystems/python/main.py b/safety/scan/ecosystems/python/main.py index bd9353bf..b72f840a 100644 --- a/safety/scan/ecosystems/python/main.py +++ b/safety/scan/ecosystems/python/main.py @@ -29,11 +29,27 @@ LOG = logging.getLogger(__name__) -def ignore_vuln_if_needed(dependency: PythonDependency, file_type: FileType, - vuln_id: str, cve, ignore_vulns, - ignore_unpinned: bool, ignore_environment: bool, - specification: PythonSpecification, - ignore_severity: List[VulnerabilitySeverityLabels] = []): +def ignore_vuln_if_needed( + dependency: PythonDependency, file_type: FileType, + vuln_id: str, cve, ignore_vulns, + ignore_unpinned: bool, ignore_environment: bool, + specification: PythonSpecification, + ignore_severity: List[VulnerabilitySeverityLabels] = [] +) -> None: + """ + Ignores vulnerabilities based on the provided rules. + + Args: + dependency (PythonDependency): The Python dependency. + file_type (FileType): The type of the file. + vuln_id (str): The vulnerability ID. + cve: The CVE object. + ignore_vulns: The dictionary of ignored vulnerabilities. + ignore_unpinned (bool): Whether to ignore unpinned specifications. + ignore_environment (bool): Whether to ignore environment results. + specification (PythonSpecification): The specification. + ignore_severity (List[VulnerabilitySeverityLabels]): List of severity labels to ignore. 
+ """ vuln_ignored: bool = vuln_id in ignore_vulns @@ -80,6 +96,17 @@ def ignore_vuln_if_needed(dependency: PythonDependency, file_type: FileType, def should_fail(config: ConfigModel, vulnerability: Vulnerability) -> bool: + """ + Determines if a vulnerability should cause a failure based on the configuration. + + Args: + config (ConfigModel): The configuration model. + vulnerability (Vulnerability): The vulnerability. + + Returns: + bool: True if the vulnerability should cause a failure, False otherwise. + """ + if not config.depedendency_vulnerability.fail_on.enabled: return False @@ -119,10 +146,27 @@ def should_fail(config: ConfigModel, vulnerability: Vulnerability) -> bool: ) -def get_vulnerability(vuln_id: str, cve, - data, specifier, - db, name, ignore_vulns: IgnoredItems, - affected: PythonSpecification) -> Vulnerability: +def get_vulnerability( + vuln_id: str, cve, data, specifier, + db, name, ignore_vulns: IgnoredItems, + affected: PythonSpecification +) -> Vulnerability: + """ + Creates a Vulnerability object from the given data. + + Args: + vuln_id (str): The vulnerability ID. + cve: The CVE object. + data: The vulnerability data. + specifier: The specifier set. + db: The database. + name: The package name. + ignore_vulns (IgnoredItems): The ignored vulnerabilities. + affected (PythonSpecification): The affected specification. + + Returns: + Vulnerability: The created Vulnerability object. + """ base_domain = db.get('meta', {}).get('base_domain') unpinned_ignored = ignore_vulns[vuln_id].specifications \ if vuln_id in ignore_vulns.keys() else None @@ -175,14 +219,31 @@ def get_vulnerability(vuln_id: str, cve, ) class PythonFile(InspectableFile, Remediable): + """ + A class representing a Python file that can be inspected for vulnerabilities and remediated. + """ def __init__(self, file_type: FileType, file: FileTextWrite) -> None: + """ + Initializes the PythonFile instance. + + Args: + file_type (FileType): The type of the file. 
+ file (FileTextWrite): The file object. + """ super().__init__(file=file) self.ecosystem = file_type.ecosystem self.file_type = file_type def __find_dependency_vulnerabilities__(self, dependencies: List[PythonDependency], - config: ConfigModel): + config: ConfigModel) -> None: + """ + Finds vulnerabilities in the dependencies. + + Args: + dependencies (List[PythonDependency]): The list of dependencies. + config (ConfigModel): The configuration model. + """ ignored_vulns_data = {} ignore_vulns = {} \ if not config.depedendency_vulnerability.ignore_vulnerabilities \ @@ -287,7 +348,13 @@ def __find_dependency_vulnerabilities__(self, dependencies: List[PythonDependenc self.dependency_results.ignored_vulns = ignore_vulns self.dependency_results.ignored_vulns_data = ignored_vulns_data - def inspect(self, config: ConfigModel): + def inspect(self, config: ConfigModel) -> None: + """ + Inspects the file for vulnerabilities based on the given configuration. + + Args: + config (ConfigModel): The configuration model. + """ # We only support vulnerability checking for now dependencies = get_dependencies(self) @@ -299,7 +366,18 @@ def inspect(self, config: ConfigModel): config=config) def __get_secure_specifications_for_user__(self, dependency: PythonDependency, db_full, - secure_vulns_by_user=None): + secure_vulns_by_user=None) -> List[str]: + """ + Gets secure specifications for the user. + + Args: + dependency (PythonDependency): The Python dependency. + db_full: The full database. + secure_vulns_by_user: The set of secure vulnerabilities by user. + + Returns: + List[str]: The list of secure specifications. + """ if not db_full: return @@ -319,7 +397,10 @@ def __get_secure_specifications_for_user__(self, dependency: PythonDependency, d return sorted(sec_ver_for_user, key=lambda ver: parse_version(ver), reverse=True) - def remediate(self): + def remediate(self) -> None: + """ + Remediates the vulnerabilities in the file. 
+ """ db_full = get_from_cache(db_name="insecure_full.json", skip_time_verification=True) if not db_full: diff --git a/safety/scan/ecosystems/target.py b/safety/scan/ecosystems/target.py index 8a95d421..9bee71a9 100644 --- a/safety/scan/ecosystems/target.py +++ b/safety/scan/ecosystems/target.py @@ -8,32 +8,75 @@ class InspectableFileContext: - def __init__(self, file_path: Path, + """ + Context manager for handling the lifecycle of an inspectable file. + + This class ensures that the file is properly opened and closed, handling any + exceptions that may occur during the process. + """ + + def __init__(self, file_path: Path, file_type: FileType) -> None: + """ + Initializes the InspectableFileContext. + + Args: + file_path (Path): The path to the file. + file_type (FileType): The type of the file. + """ self.file_path = file_path self.inspectable_file = None self.file_type = file_type def __enter__(self): # TODO: Handle permission issue /Applications/... + """ + Enters the runtime context related to this object. + + Opens the file and creates the appropriate inspectable file object based on the file type. + + Returns: + The inspectable file object. + """ try: file: FileTextWrite = open(self.file_path, mode='r+') # type: ignore self.inspectable_file = TargetFile.create(file_type=self.file_type, file=file) except Exception as e: # TODO: Report this pass - + return self.inspectable_file def __exit__(self, exc_type, exc_value, traceback): + """ + Exits the runtime context related to this object. + + Ensures that the file is properly closed. + """ if self.inspectable_file: self.inspectable_file.file.close() class TargetFile(): + """ + Factory class for creating inspectable file objects based on the file type and ecosystem. + """ @classmethod def create(cls, file_type: FileType, file: FileTextWrite): + """ + Creates an inspectable file object based on the file type and ecosystem. + + Args: + file_type (FileType): The type of the file. 
+ file (FileTextWrite): The file object. + + Returns: + An instance of the appropriate inspectable file class. + + Raises: + ValueError: If the ecosystem or file type is unsupported. + """ if file_type.ecosystem == Ecosystem.PYTHON: return PythonFile(file=file, file_type=file_type) - + raise ValueError("Unsupported ecosystem or file type: " \ f"{file_type.ecosystem}:{file_type.value}") diff --git a/safety/scan/finder/file_finder.py b/safety/scan/finder/file_finder.py index aaacd7ea..1aab8400 100644 --- a/safety/scan/finder/file_finder.py +++ b/safety/scan/finder/file_finder.py @@ -13,6 +13,16 @@ LOG = logging.getLogger(__name__) def should_exclude(excludes: Set[Path], to_analyze: Path) -> bool: + """ + Determines whether a given path should be excluded based on the provided exclusion set. + + Args: + excludes (Set[Path]): Set of paths to exclude. + to_analyze (Path): The path to analyze. + + Returns: + bool: True if the path should be excluded, False otherwise. + """ if not to_analyze.is_absolute(): to_analyze = to_analyze.resolve() @@ -27,7 +37,7 @@ def should_exclude(excludes: Set[Path], to_analyze: Path) -> bool: return True except ValueError: pass - + return False @@ -37,25 +47,46 @@ class FileFinder(): find depending on the language type. """ - def __init__(self, max_level: int, ecosystems: List[Ecosystem], target: Path, - console, live_status=None, - exclude: Optional[List[str]] = None, - include_files: Optional[Dict[FileType, List[Path]]] = None, - handlers: Optional[Set[FileHandler]] = None) -> None: + def __init__( + self, + max_level: int, + ecosystems: List[Ecosystem], + target: Path, + console, + live_status=None, + exclude: Optional[List[str]] = None, + include_files: Optional[Dict[FileType, List[Path]]] = None, + handlers: Optional[Set[FileHandler]] = None + ) -> None: + """ + Initializes the FileFinder with the specified parameters. + + Args: + max_level (int): Maximum directory depth to search. 
+ ecosystems (List[Ecosystem]): List of ecosystems to consider. + target (Path): Target directory to search. + console: Console object for output. + live_status: Live status object for updates. + exclude (Optional[List[str]]): List of patterns to exclude from the search. + include_files (Optional[Dict[FileType, List[Path]]]): Dictionary of files to include in the search. + handlers (Optional[Set[FileHandler]]): Set of file handlers. + """ self.max_level = max_level self.target = target self.include_files = include_files + # If no handlers are provided, initialize them from the ecosystem mapping if not handlers: - handlers = set(ECOSYSTEM_HANDLER_MAPPING[ecosystem]() + handlers = set(ECOSYSTEM_HANDLER_MAPPING[ecosystem]() for ecosystem in ecosystems) - + self.handlers = handlers self.file_count = 0 self.exclude_dirs: Set[Path] = set() self.exclude_files: Set[Path] = set() exclude = [] if not exclude else exclude + # Populate the exclude_dirs and exclude_files sets based on the provided patterns for pattern in exclude: for path in Path(target).glob(pattern): if path.is_dir(): @@ -65,8 +96,18 @@ def __init__(self, max_level: int, ecosystems: List[Ecosystem], target: Path, self.console = console self.live_status = live_status - - def process_directory(self, dir_path, max_deep: Optional[int]=None) -> Tuple[str, Dict[str, Set[Path]]]: + + def process_directory(self, dir_path: str, max_deep: Optional[int] = None) -> Tuple[str, Dict[str, Set[Path]]]: + """ + Processes the specified directory to find files matching the handlers' criteria. + + Args: + dir_path (str): The directory path to process. + max_deep (Optional[int]): Maximum depth to search within the directory. + + Returns: + Tuple[str, Dict[str, Set[Path]]]: The directory path and a dictionary of file types and their corresponding paths. 
+ """ files: Dict[str, Set[Path]] = {} level : int = 0 initial_depth = len(Path(dir_path).parts) - 1 @@ -77,22 +118,24 @@ def process_directory(self, dir_path, max_deep: Optional[int]=None) -> Tuple[str root_path = Path(root) current_depth = len(root_path.parts) - initial_depth + # Filter directories based on exclusion criteria dirs[:] = [d for d in dirs if not should_exclude(excludes=self.exclude_dirs, to_analyze=(root_path / Path(d)))] - if dirs: LOG.info(f"Directories to inspect -> {', '.join(dirs)}") - + LOG.info(f"Current -> {root}") if self.live_status: self.live_status.update(f":mag: Scanning {root}") + # Stop descending into directories if the maximum depth is reached if max_deep is not None and current_depth > max_deep: # Don't go deeper del dirs[:] + # Filter filenames based on exclusion criteria filenames[:] = [f for f in filenames if not should_exclude( - excludes=self.exclude_files, + excludes=self.exclude_files, to_analyze=Path(f))] self.file_count += len(filenames) @@ -111,4 +154,10 @@ def process_directory(self, dir_path, max_deep: Optional[int]=None) -> Tuple[str return dir_path, files def search(self) -> Tuple[str, Dict[str, Set[Path]]]: + """ + Initiates the search for files within the target directory. + + Returns: + Tuple[str, Dict[str, Set[Path]]]: The target directory and a dictionary of file types and their corresponding paths. + """ return self.process_directory(self.target, self.max_level) diff --git a/safety/scan/finder/handlers.py b/safety/scan/finder/handlers.py index 4e2f6966..80a3db6d 100644 --- a/safety/scan/finder/handlers.py +++ b/safety/scan/finder/handlers.py @@ -2,7 +2,7 @@ import os from pathlib import Path from types import MappingProxyType -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Optional, Tuple from safety_schemas.models import Ecosystem, FileType @@ -10,11 +10,26 @@ NOT_IMPLEMENTED = "You should implement this." 
class FileHandler(ABC): - + """ + Abstract base class for file handlers that define how to handle specific types of files + within an ecosystem. + """ + def __init__(self) -> None: self.ecosystem: Optional[Ecosystem] = None def can_handle(self, root: str, file_name: str, include_files: Dict[FileType, List[Path]]) -> Optional[FileType]: + """ + Determines if the handler can handle the given file based on its type and inclusion criteria. + + Args: + root (str): The root directory of the file. + file_name (str): The name of the file. + include_files (Dict[FileType, List[Path]]): Dictionary of file types and their paths to include. + + Returns: + Optional[FileType]: The type of the file if it can be handled, otherwise None. + """ # Keeping it simple for now if not self.ecosystem: @@ -28,54 +43,79 @@ def can_handle(self, root: str, file_name: str, include_files: Dict[FileType, Li return f_type # Let's compare by name only for now - # We can put heavier logic here, but for speed reasons, + # We can put heavier logic here, but for speed reasons, # right now is very basic, we will improve this later. # Custom matching per File Type if file_name.lower().endswith(f_type.value.lower()): return f_type - + return None - + @abstractmethod def download_required_assets(self, session) -> Dict[str, str]: + """ + Abstract method to download required assets for handling files. Should be implemented + by subclasses. + + Args: + session: The session object for making network requests. + + Returns: + Dict[str, str]: A dictionary of downloaded assets. + """ return NotImplementedError(NOT_IMPLEMENTED) class PythonFileHandler(FileHandler): + """ + Handler for Python files within the Python ecosystem. 
+ """ # Example of a Python File Handler - + def __init__(self) -> None: super().__init__() self.ecosystem = Ecosystem.PYTHON - - def download_required_assets(self, session): + + def download_required_assets(self, session) -> None: + """ + Downloads the required assets for handling Python files, specifically the Safety database. + + Args: + session: The session object for making network requests. + """ from safety.safety import fetch_database - + SAFETY_DB_DIR = os.getenv("SAFETY_DB_DIR") db = False if SAFETY_DB_DIR is None else SAFETY_DB_DIR - + # Fetch both the full and partial Safety databases fetch_database(session=session, full=False, db=db, cached=True, - telemetry=True, ecosystem=Ecosystem.PYTHON, + telemetry=True, ecosystem=Ecosystem.PYTHON, from_cache=False) - + fetch_database(session=session, full=True, db=db, cached=True, - telemetry=True, ecosystem=Ecosystem.PYTHON, + telemetry=True, ecosystem=Ecosystem.PYTHON, from_cache=False) class SafetyProjectFileHandler(FileHandler): + """ + Handler for Safety project files within the Safety project ecosystem. + """ # Example of a Python File Handler - + def __init__(self) -> None: super().__init__() self.ecosystem = Ecosystem.SAFETY_PROJECT - - def download_required_assets(self, session): + + def download_required_assets(self, session) -> None: + """ + No required assets to download for Safety project files. 
+ """ pass - +# Mapping of ecosystems to their corresponding file handlers ECOSYSTEM_HANDLER_MAPPING = MappingProxyType({ Ecosystem.PYTHON: PythonFileHandler, Ecosystem.SAFETY_PROJECT: SafetyProjectFileHandler, diff --git a/safety/scan/main.py b/safety/scan/main.py index 52508276..3b6d2e71 100644 --- a/safety/scan/main.py +++ b/safety/scan/main.py @@ -25,11 +25,20 @@ PROJECT_CONFIG_NAME = "name" -def download_policy(session: SafetyAuthSession, - project_id: str, - stage: Stage, - branch: Optional[str]) -> Optional[PolicyFileModel]: - result = session.download_policy(project_id=project_id, stage=stage, +def download_policy(session: SafetyAuthSession, project_id: str, stage: Stage, branch: Optional[str]) -> Optional[PolicyFileModel]: + """ + Downloads the policy file from the cloud for the given project and stage. + + Args: + session (SafetyAuthSession): SafetyAuthSession object for authentication. + project_id (str): The ID of the project. + stage (Stage): The stage of the project. + branch (Optional[str]): The branch of the project (optional). + + Returns: + Optional[PolicyFileModel]: PolicyFileModel object if successful, otherwise None. + """ + result = session.download_policy(project_id=project_id, stage=stage, branch=branch) if result and "uuid" in result and result["uuid"]: @@ -62,28 +71,44 @@ def download_policy(session: SafetyAuthSession, source=PolicySource.cloud, location=None, config=config) - + return None def load_unverified_project_from_config(project_root: Path) -> UnverifiedProjectModel: + """ + Loads an unverified project from the configuration file located at the project root. + + Args: + project_root (Path): The root directory of the project. + + Returns: + UnverifiedProjectModel: An instance of UnverifiedProjectModel. 
+ """ config = configparser.ConfigParser() project_path = project_root / PROJECT_CONFIG config.read(project_path) id = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_ID, fallback=None) id = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_ID, fallback=None) url = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_URL, fallback=None) - name = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_NAME, fallback=None) + name = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_NAME, fallback=None) created = True if id: created = False - - return UnverifiedProjectModel(id=id, url_path=url, - name=name, project_path=project_path, + + return UnverifiedProjectModel(id=id, url_path=url, + name=name, project_path=project_path, created=created) -def save_project_info(project: ProjectModel, project_path: Path): +def save_project_info(project: ProjectModel, project_path: Path) -> None: + """ + Saves the project information to the configuration file. + + Args: + project (ProjectModel): The ProjectModel object containing project information. + project_path (Path): The path to the configuration file. + """ config = configparser.ConfigParser() config.read(project_path) @@ -95,12 +120,21 @@ def save_project_info(project: ProjectModel, project_path: Path): config[PROJECT_CONFIG_SECTION][PROJECT_CONFIG_URL] = project.url_path if project.name: config[PROJECT_CONFIG_SECTION][PROJECT_CONFIG_NAME] = project.name - + with open(project_path, 'w') as configfile: - config.write(configfile) + config.write(configfile) def load_policy_file(path: Path) -> Optional[PolicyFileModel]: + """ + Loads a policy file from the specified path. + + Args: + path (Path): The path to the policy file. + + Returns: + Optional[PolicyFileModel]: PolicyFileModel object if successful, otherwise None. 
+ """ config = None if not path or not path.exists(): @@ -118,13 +152,21 @@ def load_policy_file(path: Path) -> Optional[PolicyFileModel]: LOG.error(f"Wrong YML file for policy file {path}.", exc_info=True) raise SafetyError(f"{err}, details: {e}") - return PolicyFileModel(id=str(path), source=PolicySource.local, + return PolicyFileModel(id=str(path), source=PolicySource.local, location=path, config=config) -def resolve_policy(local_policy: Optional[PolicyFileModel], - cloud_policy: Optional[PolicyFileModel]) \ - -> Optional[PolicyFileModel]: +def resolve_policy(local_policy: Optional[PolicyFileModel], cloud_policy: Optional[PolicyFileModel]) -> Optional[PolicyFileModel]: + """ + Resolves the policy to be used, preferring cloud policy over local policy. + + Args: + local_policy (Optional[PolicyFileModel]): The local policy file model (optional). + cloud_policy (Optional[PolicyFileModel]): The cloud policy file model (optional). + + Returns: + Optional[PolicyFileModel]: The resolved PolicyFileModel object. + """ policy = None if cloud_policy: @@ -135,20 +177,37 @@ def resolve_policy(local_policy: Optional[PolicyFileModel], return policy -def save_report_as(scan_type: ScanType, export_type: ScanExport, at: Path, report: Any): - tag = int(time.time()) +def save_report_as(scan_type: ScanType, export_type: ScanExport, at: Path, report: Any) -> None: + """ + Saves the scan report to the specified location. + + Args: + scan_type (ScanType): The type of scan. + export_type (ScanExport): The type of export. + at (Path): The path to save the report. + report (Any): The report content. 
+ """ + tag = int(time.time()) + + if at.is_dir(): + at = at / Path( + f"{scan_type.value}-{export_type.get_default_file_name(tag=tag)}") + + with open(at, 'w+') as report_file: + report_file.write(report) - if at.is_dir(): - at = at / Path( - f"{scan_type.value}-{export_type.get_default_file_name(tag=tag)}") - with open(at, 'w+') as report_file: - report_file.write(report) +def process_files(paths: Dict[str, Set[Path]], config: Optional[ConfigModel] = None) -> Generator[Tuple[Path, InspectableFile], None, None]: + """ + Processes the files and yields each file path along with its inspectable file. + Args: + paths (Dict[str, Set[Path]]): A dictionary of file paths by file type. + config (Optional[ConfigModel]): The configuration model (optional). -def process_files(paths: Dict[str, Set[Path]], - config: Optional[ConfigModel] = None) -> \ - Generator[Tuple[Path, InspectableFile], None, None]: + Yields: + Tuple[Path, InspectableFile]: A tuple of file path and inspectable file. + """ if not config: config = ConfigModel() @@ -158,7 +217,7 @@ def process_files(paths: Dict[str, Set[Path]], continue for f_path in f_paths: with InspectableFileContext(f_path, file_type=file_type) as inspectable_file: - if inspectable_file and inspectable_file.file_type: + if inspectable_file and inspectable_file.file_type: inspectable_file.inspect(config=config) inspectable_file.remediate() yield f_path, inspectable_file diff --git a/safety/scan/models.py b/safety/scan/models.py index 54a32895..86ffc21a 100644 --- a/safety/scan/models.py +++ b/safety/scan/models.py @@ -5,10 +5,22 @@ from pydantic.dataclasses import dataclass class FormatMixin: + """ + Mixin class providing format-related utilities for Enum classes. + """ @classmethod - def is_format(cls, format_sub: Optional[Enum], format_instance: Enum): - """ Check if the value is a variant of the specified format. 
""" + def is_format(cls, format_sub: Optional[Enum], format_instance: Enum) -> bool: + """ + Check if the value is a variant of the specified format. + + Args: + format_sub (Optional[Enum]): The format to check. + format_instance (Enum): The instance of the format to compare against. + + Returns: + bool: True if the format matches, otherwise False. + """ if not format_sub: return False @@ -17,19 +29,27 @@ def is_format(cls, format_sub: Optional[Enum], format_instance: Enum): prefix = format_sub.value.split('@')[0] return prefix == format_instance.value - + @property - def version(self): - """ Return the version of the format. """ + def version(self) -> Optional[str]: + """ + Return the version of the format. + + Returns: + Optional[str]: The version of the format if available, otherwise None. + """ result = self.value.split('@') if len(result) == 2: return result[1] - + return None class ScanOutput(FormatMixin, str, Enum): + """ + Enum representing different scan output formats. + """ JSON = "json" SPDX = "spdx" SPDX_2_3 = "spdx@2.3" @@ -39,19 +59,36 @@ class ScanOutput(FormatMixin, str, Enum): SCREEN = "screen" NONE = "none" - def is_silent(self): + def is_silent(self) -> bool: + """ + Check if the output format is silent. + + Returns: + bool: True if the output format is silent, otherwise False. + """ return self in (ScanOutput.JSON, ScanOutput.SPDX, ScanOutput.SPDX_2_3, ScanOutput.SPDX_2_2, ScanOutput.HTML) class ScanExport(FormatMixin, str, Enum): + """ + Enum representing different scan export formats. + """ JSON = "json" SPDX = "spdx" SPDX_2_3 = "spdx@2.3" SPDX_2_2 = "spdx@2.2" - HTML = "html" + HTML = "html" + + def get_default_file_name(self, tag: int) -> str: + """ + Get the default file name for the export format. - def get_default_file_name(self, tag: int): - + Args: + tag (int): A unique tag to include in the file name. + + Returns: + str: The default file name. 
+ """ if self is ScanExport.JSON: return f"safety-report-{tag}.json" elif self in [ScanExport.SPDX, ScanExport.SPDX_2_3, ScanExport.SPDX_2_2]: @@ -63,19 +100,34 @@ def get_default_file_name(self, tag: int): class SystemScanOutput(str, Enum): + """ + Enum representing different system scan output formats. + """ JSON = "json" SCREEN = "screen" - def is_silent(self): - return self in (SystemScanOutput.JSON,) + def is_silent(self) -> bool: + """ + Check if the output format is silent. + + Returns: + bool: True if the output format is silent, otherwise False. + """ + return self in (SystemScanOutput.JSON,) class SystemScanExport(str, Enum): + """ + Enum representing different system scan export formats. + """ JSON = "json" @dataclass class UnverifiedProjectModel(): + """ + Data class representing an unverified project model. + """ id: Optional[str] project_path: Path created: bool name: Optional[str] = None - url_path: Optional[str] = None + url_path: Optional[str] = None diff --git a/safety/scan/render.py b/safety/scan/render.py index f5f1da2f..a7f4e277 100644 --- a/safety/scan/render.py +++ b/safety/scan/render.py @@ -5,7 +5,7 @@ import logging from pathlib import Path import time -from typing import Any, Dict, List, Optional, Set +from typing import Any, Dict, List, Optional, Set, Tuple from rich.prompt import Prompt from rich.text import Text from rich.console import Console @@ -28,6 +28,16 @@ import datetime def render_header(targets: List[Path], is_system_scan: bool) -> Text: + """ + Render the header text for the scan. + + Args: + targets (List[Path]): List of target paths for the scan. + is_system_scan (bool): Indicates if the scan is a system scan. + + Returns: + Text: Rendered header text. 
+ """ version = get_safety_version() scan_datetime = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z") @@ -38,14 +48,29 @@ def render_header(targets: List[Path], is_system_scan: bool) -> Text: return Text.from_markup( f"[bold]Safety[/bold] {version} {action}\n{scan_datetime}") -def print_header(console, targets: List[Path], is_system_scan: bool = False): +def print_header(console, targets: List[Path], is_system_scan: bool = False) -> None: + """ + Print the header for the scan. + + Args: + console (Console): The console for output. + targets (List[Path]): List of target paths for the scan. + is_system_scan (bool): Indicates if the scan is a system scan. + """ console.print(render_header(targets, is_system_scan), markup=True) -def print_announcements(console, ctx): +def print_announcements(console: Console, ctx: typer.Context) -> None: + """ + Print announcements from Safety. + + Args: + console (Console): The console for output. + ctx (typer.Context): The context of the Typer command. 
+ """ colors = {"error": "red", "warning": "yellow", "info": "default"} - announcements = safety.get_announcements(ctx.obj.auth.client, - telemetry=ctx.obj.config.telemetry_enabled, + announcements = safety.get_announcements(ctx.obj.auth.client, + telemetry=ctx.obj.config.telemetry_enabled, with_telemetry=ctx.obj.telemetry) basic_announcements = get_basic_announcements(announcements, False) @@ -53,12 +78,19 @@ def print_announcements(console, ctx): console.print() console.print("[bold]Safety Announcements:[/bold]") console.print() - for announcement in announcements: + for announcement in announcements: color = colors.get(announcement.get('type', "info"), "default") console.print(f"[{color}]* {announcement.get('message')}[/{color}]") -def print_detected_ecosystems_section(console, file_paths: Dict[str, Set[Path]], - include_safety_prjs: bool = True): +def print_detected_ecosystems_section(console: Console, file_paths: Dict[str, Set[Path]], include_safety_prjs: bool = True) -> None: + """ + Print detected ecosystems section. + + Args: + console (Console): The console for output. + file_paths (Dict[str, Set[Path]]): Dictionary of file paths by type. + include_safety_prjs (bool): Whether to include safety projects. + """ detected: Dict[Ecosystem, Dict[FileType, int]] = {} for file_type_key, f_paths in file_paths.items(): @@ -75,24 +107,33 @@ def print_detected_ecosystems_section(console, file_paths: Dict[str, Set[Path]], brief = "Found " file_types = [] - + for f_type, count in f_type_count.items(): file_types.append(f"{count} {f_type.human_name(plural=count>1)}") - + if len(file_types) > 1: brief += ", ".join(file_types[:-1]) + " and " + file_types[-1] else: brief += file_types[0] - + msg = f"{ecosystem.name.replace('_', ' ').title()} detected. 
{brief}" - + console.print(msg) -def print_brief(console, project: ProjectModel, dependencies_count: int = 0, - affected_count: int = 0, fixes_count: int = 0): +def print_brief(console: Console, project: ProjectModel, dependencies_count: int = 0, affected_count: int = 0, fixes_count: int = 0) -> None: + """ + Print a brief summary of the scan results. + + Args: + console (Console): The console for output. + project (ProjectModel): The project model. + dependencies_count (int): Number of dependencies tested. + affected_count (int): Number of security issues found. + fixes_count (int): Number of fixes suggested. + """ from ..util import pluralize - if project.policy: + if project.policy: if project.policy.source is PolicySource.cloud: policy_msg = f"policy fetched from Safety Platform" else: @@ -107,8 +148,16 @@ def print_brief(console, project: ProjectModel, dependencies_count: int = 0, f"issues using {policy_msg}") console.print( f"[number]{affected_count}[/number] security {pluralize('issue', affected_count)} found, [number]{fixes_count}[/number] {pluralize('fix', fixes_count)} suggested") - -def print_fixes_section(console, requirements_txt_found: bool = False, is_detailed_output: bool = False): + +def print_fixes_section(console: Console, requirements_txt_found: bool = False, is_detailed_output: bool = False) -> None: + """ + Print the section on applying fixes. + + Args: + console (Console): The console for output. + requirements_txt_found (bool): Indicates if a requirements.txt file was found. + is_detailed_output (bool): Indicates if detailed output is enabled. 
+ """ console.print("-" * console.size.width) console.print("Apply Fixes") console.print("-" * console.size.width) @@ -131,8 +180,17 @@ def print_fixes_section(console, requirements_txt_found: bool = False, is_detail console.print("-" * console.size.width) -def print_ignore_details(console, project: ProjectModel, ignored, - is_detailed_output: bool = False, ignored_vulns_data = None): +def print_ignore_details(console: Console, project: ProjectModel, ignored: Set[str], is_detailed_output: bool = False, ignored_vulns_data: Optional[Dict[str, Vulnerability]] = None) -> None: + """ + Print details about ignored vulnerabilities. + + Args: + console (Console): The console for output. + project (ProjectModel): The project model. + ignored (Set[str]): Set of ignored vulnerabilities. + is_detailed_output (bool): Indicates if detailed output is enabled. + ignored_vulns_data (Optional[Dict[str, Vulnerability]]): Data of ignored vulnerabilities. + """ from ..util import pluralize if is_detailed_output: @@ -146,7 +204,7 @@ def print_ignore_details(console, project: ProjectModel, ignored, unpinned_ignored = {} unpinned_ignored_pkgs = set() environment_ignored = {} - environment_ignored_pkgs = set() + environment_ignored_pkgs = set() for vuln_data in ignored_vulns_data: code = IgnoreCodes(vuln_data.ignored_code) @@ -160,7 +218,7 @@ def print_ignore_details(console, project: ProjectModel, ignored, unpinned_ignored_pkgs.add(vuln_data.package_name) elif code is IgnoreCodes.environment_dependency: environment_ignored[vuln_data.vulnerability_id] = vuln_data - environment_ignored_pkgs.add(vuln_data.package_name) + environment_ignored_pkgs.add(vuln_data.package_name) if manual_ignored: count = len(manual_ignored) @@ -168,7 +226,7 @@ def print_ignore_details(console, project: ProjectModel, ignored, f"[number]{count}[/number] were manually ignored due to the project policy:") for vuln in manual_ignored.values(): render_to_console(vuln, console, - rich_kwargs={"emoji": True, "overflow": 
"crop"}, + rich_kwargs={"emoji": True, "overflow": "crop"}, detailed_output=is_detailed_output) if cvss_severity_ignored: count = len(cvss_severity_ignored) @@ -197,7 +255,19 @@ def print_ignore_details(console, project: ProjectModel, ignored, "project policy)") -def print_wait_project_verification(console, project_id, closure, on_error_delay=1): +def print_wait_project_verification(console: Console, project_id: str, closure: Tuple[Any, Dict[str, Any]], on_error_delay: int = 1) -> Any: + """ + Print a waiting message while verifying a project. + + Args: + console (Console): The console for output. + project_id (str): The project ID. + closure (Tuple[Any, Dict[str, Any]]): The function and its arguments to call. + on_error_delay (int): Delay in seconds on error. + + Returns: + Any: The status of the project verification. + """ status = None wait_msg = f"Verifying project {project_id} with Safety Platform." @@ -215,10 +285,17 @@ def print_wait_project_verification(console, project_id, closure, on_error_delay if not status: wait_msg = f'Unable to verify "{project_id}". Starting again...' time.sleep(on_error_delay) - + return status -def print_project_info(console, project: ProjectModel): +def print_project_info(console: Console, project: ProjectModel) -> None: + """ + Print information about the project. + + Args: + console (Console): The console for output. + project (ProjectModel): The project model. + """ config_msg = "loaded without policies or custom configuration." if project.policy: @@ -229,11 +306,21 @@ def print_project_info(console, project: ProjectModel): else: config_msg = " policies fetched " \ "from Safety Platform." 
- + msg = f"[bold]{project.id} project found[/bold] - {config_msg}" console.print(msg) -def print_wait_policy_download(console, closure) -> Optional[PolicyFileModel]: +def print_wait_policy_download(console: Console, closure: Tuple[Any, Dict[str, Any]]) -> Optional[PolicyFileModel]: + """ + Print a waiting message while downloading a policy from the cloud. + + Args: + console (Console): The console for output. + closure (Tuple[Any, Dict[str, Any]]): The function and its arguments to call. + + Returns: + Optional[PolicyFileModel]: The downloaded policy file model. + """ policy = None wait_msg = "Looking for a policy from cloud..." @@ -253,9 +340,19 @@ def print_wait_policy_download(console, closure) -> Optional[PolicyFileModel]: return policy -def prompt_project_id(console, stage: Stage, - prj_root_name: Optional[str], - do_not_exit=True) -> str: +def prompt_project_id(console: Console, stage: Stage, prj_root_name: Optional[str], do_not_exit: bool = True) -> Optional[str]: + """ + Prompt the user to set a project ID for the scan. + + Args: + console (Console): The console for output. + stage (Stage): The current stage. + prj_root_name (Optional[str]): The root name of the project. + do_not_exit (bool): Indicates if the function should not exit on failure. + + Returns: + Optional[str]: The project ID. 
+ """ from safety.util import clean_project_id default_prj_id = clean_project_id(prj_root_name) if prj_root_name else None @@ -264,10 +361,10 @@ def prompt_project_id(console, stage: Stage, # Fail here console.print("The scan needs to be linked to a project.") raise typer.Exit(code=1) - + hint = "" - if default_prj_id: - hint = f" If empty Safety will use [bold]{default_prj_id}[/bold]" + if default_prj_id: + hint = f" If empty Safety will use [bold]{default_prj_id}[/bold]" prompt_text = f"Set a project id for this scan (no spaces).{hint}" def ask(): @@ -290,27 +387,55 @@ def ask(): return project_id -def prompt_link_project(console, prj_name: str, prj_admin_email: str) -> bool: +def prompt_link_project(console: Console, prj_name: str, prj_admin_email: str) -> bool: + """ + Prompt the user to link the scan with an existing project. + + Args: + console (Console): The console for output. + prj_name (str): The project name. + prj_admin_email (str): The project admin email. + + Returns: + bool: True if the user wants to link the scan, False otherwise. + """ console.print("[bold]Safety found an existing project with this name in your organization:[/bold]") - for detail in (f"[bold]Project name:[/bold] {prj_name}", + for detail in (f"[bold]Project name:[/bold] {prj_name}", f"[bold]Project admin:[/bold] {prj_admin_email}"): console.print(Padding(detail, (0, 0, 0, 2)), emoji=True) prompt_question = "Do you want to link this scan with this existing project?" - - answer = Prompt.ask(prompt=prompt_question, choices=["y", "n"], + + answer = Prompt.ask(prompt=prompt_question, choices=["y", "n"], default="y", show_default=True, console=console).lower() - + return answer == "y" -def render_to_console(cls: Vulnerability, console: Console, rich_kwargs, - detailed_output: bool = False): +def render_to_console(cls: Vulnerability, console: Console, rich_kwargs: Dict[str, Any], detailed_output: bool = False) -> None: + """ + Render a vulnerability to the console. 
+ + Args: + cls (Vulnerability): The vulnerability instance. + console (Console): The console for output. + rich_kwargs (Dict[str, Any]): Additional arguments for rendering. + detailed_output (bool): Indicates if detailed output is enabled. + """ cls.__render__(console, detailed_output, rich_kwargs) -def get_render_console(entity_type): +def get_render_console(entity_type: Any) -> Any: + """ + Get the render function for a specific entity type. + + Args: + entity_type (Any): The entity type. + + Returns: + Any: The render function. + """ if entity_type is Vulnerability: def __render__(self, console: Console, detailed_output: bool, rich_kwargs): @@ -330,12 +455,12 @@ def __render__(self, console: Console, detailed_output: bool, rich_kwargs): console.print( Padding( - f"->{pre} Vuln ID [vuln_id]{self.vulnerability_id}[/vuln_id]: {severity_detail if severity_detail else ''}", + f"->{pre} Vuln ID [vuln_id]{self.vulnerability_id}[/vuln_id]: {severity_detail if severity_detail else ''}", (0, 0, 0, 2) ), **rich_kwargs) console.print( Padding( - f"{self.advisory[:advisory_length]}{'...' if len(self.advisory) > advisory_length else ''}", + f"{self.advisory[:advisory_length]}{'...' if len(self.advisory) > advisory_length else ''}", (0, 0, 0, 5) ), **rich_kwargs) @@ -347,7 +472,17 @@ def __render__(self, console: Console, detailed_output: bool, rich_kwargs): return __render__ -def render_scan_html(report: ReportModel, obj) -> str: +def render_scan_html(report: ReportModel, obj: Any) -> str: + """ + Render the scan report to HTML. + + Args: + report (ReportModel): The scan report model. + obj (Any): The object containing additional settings. + + Returns: + str: The rendered HTML report. 
+ """ from safety.scan.command import ScannableEcosystems project = report.projects[0] if any(report.projects) else None @@ -376,30 +511,40 @@ def render_scan_html(report: ReportModel, obj) -> str: ignored_packages += len(file.results.ignored_vulns) # TODO: Get this information for the report model (?) - summary = {"scanned_packages": scanned_packages, - "affected_packages": affected_packages, + summary = {"scanned_packages": scanned_packages, + "affected_packages": affected_packages, "remediations_recommended": remediations_recommended, "ignored_vulnerabilities": ignored_vulnerabilities, "vulnerabilities": vulnerabilities} - + vulnerabilities = [] - - + + # TODO: This should be based on the configs per command ecosystems = [(f"{ecosystem.name.title()}", [file_type.human_name(plural=True) for file_type in ecosystem.file_types]) for ecosystem in [Ecosystem(member.value) for member in list(ScannableEcosystems)]] - + settings ={"audit_and_monitor": True, "platform_url": SAFETY_PLATFORM_URL, "ecosystems": ecosystems} - template_context = {"report": report, "summary": summary, "announcements": [], - "project": project, + template_context = {"report": report, "summary": summary, "announcements": [], + "project": project, "platform_enabled": obj.platform_enabled, "settings": settings, "vulns_per_file": vulns_per_file, "remed_per_file": remed_per_file} - + return parse_html(kwargs=template_context, template="scan/index.html") -def generate_spdx_creation_info(*, spdx_version: str, project_identifier: str) -> Any: +def generate_spdx_creation_info(spdx_version: str, project_identifier: str) -> Any: + """ + Generate SPDX creation information. + + Args: + spdx_version (str): The SPDX version. + project_identifier (str): The project identifier. + + Returns: + Any: The SPDX creation information. 
+ """ from spdx_tools.spdx.model import ( Actor, ActorType, @@ -439,7 +584,17 @@ def generate_spdx_creation_info(*, spdx_version: str, project_identifier: str) - return creation_info -def create_pkg_ext_ref(*, package: PythonDependency, version: Optional[str]): +def create_pkg_ext_ref(*, package: PythonDependency, version: Optional[str]) -> Any: + """ + Create an external package reference for SPDX. + + Args: + package (PythonDependency): The package dependency. + version (Optional[str]): The package version. + + Returns: + Any: The external package reference. + """ from spdx_tools.spdx.model import ( ExternalPackageRef, ExternalPackageRefCategory, @@ -455,11 +610,20 @@ def create_pkg_ext_ref(*, package: PythonDependency, version: Optional[str]): def create_packages(dependencies: List[PythonDependency]) -> List[Any]: + """ + Create a list of SPDX packages. + + Args: + dependencies (List[PythonDependency]): List of Python dependencies. + + Returns: + List[Any]: List of SPDX packages. + """ from spdx_tools.spdx.model.spdx_no_assertion import SpdxNoAssertion from spdx_tools.spdx.model import ( Package, - ) + ) doc_pkgs = [] pkgs_added = set([]) @@ -471,7 +635,7 @@ def create_packages(dependencies: List[PythonDependency]) -> List[Any]: if pkg_id in pkgs_added: continue pkg_ref = create_pkg_ext_ref(package=dependency, version=pkg_version) - + pkg = Package( spdx_id=pkg_id, name=f"pip:{dep_name}", @@ -491,6 +655,16 @@ def create_packages(dependencies: List[PythonDependency]) -> List[Any]: def create_spdx_document(*, report: ReportModel, spdx_version: str) -> Optional[Any]: + """ + Create an SPDX document. + + Args: + report (ReportModel): The scan report model. + spdx_version (str): The SPDX version. + + Returns: + Optional[Any]: The SPDX document. 
+ """ from spdx_tools.spdx.model import ( Document, Relationship, @@ -501,13 +675,13 @@ def create_spdx_document(*, report: ReportModel, spdx_version: str) -> Optional[ if not project: return None - + prj_id = project.id - + if not prj_id: parent_name = project.project_path.parent.name prj_id = parent_name if parent_name else str(int(time.time())) - + creation_info = generate_spdx_creation_info(spdx_version=spdx_version, project_identifier=prj_id) depedencies = iter([]) @@ -534,12 +708,23 @@ def create_spdx_document(*, report: ReportModel, spdx_version: str) -> Optional[ return spdx_doc -def render_scan_spdx(report: ReportModel, obj, spdx_version: Optional[str]) -> Optional[Any]: +def render_scan_spdx(report: ReportModel, obj: Any, spdx_version: Optional[str]) -> Optional[Any]: + """ + Render the scan report to SPDX format. + + Args: + report (ReportModel): The scan report model. + obj (Any): The object containing additional settings. + spdx_version (Optional[str]): The SPDX version. + + Returns: + Optional[Any]: The rendered SPDX document in JSON format. + """ from spdx_tools.spdx.writer.write_utils import ( convert, validate_and_deduplicate ) - + # Set to latest supported if a version is not specified if not spdx_version: spdx_version = "2.3" diff --git a/safety/scan/util.py b/safety/scan/util.py index 388b3f98..3fea1a5c 100644 --- a/safety/scan/util.py +++ b/safety/scan/util.py @@ -13,11 +13,20 @@ LOG = logging.getLogger(__name__) class Language(str, Enum): + """ + Enum representing supported programming languages. + """ python = "python" javascript = "javascript" safety_project = "safety_project" def handler(self) -> FileHandler: + """ + Get the appropriate file handler for the language. + + Returns: + FileHandler: The file handler for the language. 
+ """ if self is Language.python: return PythonFileHandler() if self is Language.safety_project: @@ -26,20 +35,35 @@ def handler(self) -> FileHandler: return PythonFileHandler() class Output(Enum): + """ + Enum representing output formats. + """ json = "json" class AuthenticationType(str, Enum): + """ + Enum representing authentication types. + """ token = "token" api_key = "api_key" none = "unauthenticated" def is_allowed_in(self, stage: Stage = Stage.development) -> bool: + """ + Check if the authentication type is allowed in the given stage. + + Args: + stage (Stage): The current stage. + + Returns: + bool: True if the authentication type is allowed, otherwise False. + """ if self is AuthenticationType.none: - return False - + return False + if stage == Stage.development and self is AuthenticationType.api_key: return False - + if (not stage == Stage.development) and self is AuthenticationType.token: return False @@ -47,64 +71,137 @@ def is_allowed_in(self, stage: Stage = Stage.development) -> bool: class GIT: + """ + Class representing Git operations. + """ ORIGIN_CMD: Tuple[str, ...] = ("remote", "get-url", "origin") BRANCH_CMD: Tuple[str, ...] = ("symbolic-ref", "--short", "-q", "HEAD") TAG_CMD: Tuple[str, ...] = ("describe", "--tags", "--exact-match") - DESCRIBE_CMD: Tuple[str, ...] = ("describe", '--match=""', '--always', + DESCRIBE_CMD: Tuple[str, ...] = ("describe", '--match=""', '--always', '--abbrev=40', '--dirty') GIT_CHECK_CMD: Tuple[str, ...] = ("rev-parse", "--is-inside-work-tree") - + def __init__(self, root: Path = Path(".")) -> None: + """ + Initialize the GIT class with the given root directory. + + Args: + root (Path): The root directory for Git operations. + """ self.git = ("git", "-C", root.resolve()) def __run__(self, cmd: Tuple[str, ...], env_var: Optional[str] = None) -> Optional[str]: + """ + Run a Git command. + + Args: + cmd (Tuple[str, ...]): The Git command to run. 
+ env_var (Optional[str]): An optional environment variable to check for the command result. + + Returns: + Optional[str]: The result of the Git command, or None if an error occurred. + """ if env_var and os.environ.get(env_var): return os.environ.get(env_var) try: - return subprocess.run(self.git + cmd, stdout=subprocess.PIPE, + return subprocess.run(self.git + cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL).stdout.decode('utf-8').strip() except Exception as e: LOG.exception(e) - + return None def origin(self) -> Optional[str]: + """ + Get the Git origin URL. + + Returns: + Optional[str]: The Git origin URL, or None if an error occurred. + """ return self.__run__(self.ORIGIN_CMD, env_var="SAFETY_GIT_ORIGIN") - + def branch(self) -> Optional[str]: + """ + Get the current Git branch. + + Returns: + Optional[str]: The current Git branch, or None if an error occurred. + """ return self.__run__(self.BRANCH_CMD, env_var="SAFETY_GIT_BRANCH") def tag(self) -> Optional[str]: + """ + Get the current Git tag. + + Returns: + Optional[str]: The current Git tag, or None if an error occurred. + """ return self.__run__(self.TAG_CMD, env_var="SAFETY_GIT_TAG") - + def describe(self) -> Optional[str]: + """ + Get the Git describe output. + + Returns: + Optional[str]: The Git describe output, or None if an error occurred. + """ return self.__run__(self.DESCRIBE_CMD) - + def dirty(self, raw_describe: str) -> bool: + """ + Check if the working directory is dirty. + + Args: + raw_describe (str): The raw describe output. + + Returns: + bool: True if the working directory is dirty, otherwise False. + """ if os.environ.get("SAFETY_GIT_DIRTY") in ["0", "1"]: return bool(int(os.environ.get("SAFETY_GIT_DIRTY"))) - + return raw_describe.endswith('-dirty') def commit(self, raw_describe: str) -> Optional[str]: + """ + Get the current Git commit hash. + + Args: + raw_describe (str): The raw describe output. 
+ + Returns: + Optional[str]: The current Git commit hash, or None if an error occurred. + """ if os.environ.get("SAFETY_GIT_COMMIT"): return os.environ.get("SAFETY_GIT_COMMIT") - try: + try: return raw_describe.split("-dirty")[0] except Exception: pass def is_git(self) -> bool: + """ + Check if the current directory is a Git repository. + + Returns: + bool: True if the current directory is a Git repository, otherwise False. + """ result = self.__run__(self.GIT_CHECK_CMD) if result == "true": return True - + return False def build_git_data(self): + """ + Build a GITModel object with Git data. + + Returns: + Optional[GITModel]: The GITModel object with Git data, or None if the directory is not a Git repository. + """ from safety_schemas.models import GITModel if self.is_git(): @@ -114,8 +211,8 @@ def build_git_data(self): if raw_describe: commit = self.commit(raw_describe) dirty = self.dirty(raw_describe) - return GITModel(branch=self.branch(), - tag=self.tag(), commit=commit, dirty=dirty, + return GITModel(branch=self.branch(), + tag=self.tag(), commit=commit, dirty=dirty, origin=self.origin()) - + return None diff --git a/safety/scan/validators.py b/safety/scan/validators.py index 12aa777f..0bfb81ea 100644 --- a/safety/scan/validators.py +++ b/safety/scan/validators.py @@ -8,19 +8,32 @@ from safety.scan.render import print_wait_project_verification, prompt_project_id, prompt_link_project from safety_schemas.models import AuthenticationType, ProjectModel, Stage +from safety.auth.utils import SafetyAuthSession MISSING_SPDX_EXTENSION_MSG = "spdx extra is not installed, please install it with: pip install safety[spdx]" -def raise_if_not_spdx_extension_installed(): +def raise_if_not_spdx_extension_installed() -> None: + """ + Raises an error if the spdx extension is not installed. 
+ """ try: import spdx_tools.spdx except Exception as e: - raise typer.BadParameter(MISSING_SPDX_EXTENSION_MSG) + raise typer.BadParameter(MISSING_SPDX_EXTENSION_MSG) -def save_as_callback(save_as: Optional[Tuple[ScanExport, Path]]): +def save_as_callback(save_as: Optional[Tuple[ScanExport, Path]]) -> Tuple[Optional[str], Optional[Path]]: + """ + Callback function to handle save_as parameter and validate if spdx extension is installed. + + Args: + save_as (Optional[Tuple[ScanExport, Path]]): The export type and path. + + Returns: + Tuple[Optional[str], Optional[Path]]: The validated export type and path. + """ export_type, export_path = save_as if save_as else (None, None) if ScanExport.is_format(export_type, ScanExport.SPDX): @@ -28,18 +41,32 @@ def save_as_callback(save_as: Optional[Tuple[ScanExport, Path]]): return (export_type.value, export_path) if export_type and export_path else (export_type, export_path) -def output_callback(output: ScanOutput): +def output_callback(output: ScanOutput) -> str: + """ + Callback function to handle output parameter and validate if spdx extension is installed. + + Args: + output (ScanOutput): The output format. + Returns: + str: The validated output format. + """ if ScanOutput.is_format(output, ScanExport.SPDX): raise_if_not_spdx_extension_installed() - + return output.value def fail_if_not_allowed_stage(ctx: typer.Context): + """ + Fail the command if the authentication type is not allowed in the current stage. + + Args: + ctx (typer.Context): The context of the Typer command. 
+ """ if ctx.resilient_parsing: return - + stage = ctx.obj.auth.stage auth_type: AuthenticationType = ctx.obj.auth.client.get_authentication_type() @@ -51,7 +78,17 @@ def fail_if_not_allowed_stage(ctx: typer.Context): f"the '{stage}' stage.") -def save_verified_project(ctx, slug, name, project_path, url_path): +def save_verified_project(ctx: typer.Context, slug: str, name: Optional[str], project_path: Path, url_path: Optional[str]): + """ + Save the verified project information to the context and project info file. + + Args: + ctx (typer.Context): The context of the Typer command. + slug (str): The project slug. + name (Optional[str]): The project name. + project_path (Path): The project path. + url_path (Optional[str]): The project URL path. + """ ctx.obj.project = ProjectModel( id=slug, name=name, @@ -59,14 +96,28 @@ def save_verified_project(ctx, slug, name, project_path, url_path): url_path=url_path ) if ctx.obj.auth.stage is Stage.development: - save_project_info(project=ctx.obj.project, + save_project_info(project=ctx.obj.project, project_path=project_path) -def check_project(console, ctx, session, - unverified_project: UnverifiedProjectModel, - stage, - git_origin, ask_project_id=False): +def check_project(console, ctx: typer.Context, session: SafetyAuthSession, + unverified_project: UnverifiedProjectModel, stage: Stage, + git_origin: Optional[str], ask_project_id: bool = False) -> dict: + """ + Check the project against the session and stage, verifying the project if necessary. + + Args: + console: The console for output. + ctx (typer.Context): The context of the Typer command. + session (SafetyAuthSession): The authentication session. + unverified_project (UnverifiedProjectModel): The unverified project model. + stage (Stage): The current stage. + git_origin (Optional[str]): The Git origin URL. + ask_project_id (bool): Whether to prompt for the project ID. + + Returns: + dict: The result of the project check. 
+ """ stage = ctx.obj.auth.stage source = ctx.obj.telemetry.safety_source if ctx.obj.telemetry else None data = {"scan_stage": stage, "safety_source": source} @@ -91,17 +142,27 @@ def check_project(console, ctx, session, data[PRJ_SLUG_KEY] = unverified_project.id data[PRJ_SLUG_SOURCE_KEY] = "user" - status = print_wait_project_verification(console, data[PRJ_SLUG_KEY] if data.get(PRJ_SLUG_KEY, None) else "-", + status = print_wait_project_verification(console, data[PRJ_SLUG_KEY] if data.get(PRJ_SLUG_KEY, None) else "-", (session.check_project, data), on_error_delay=1) return status -def verify_project(console, ctx, session, - unverified_project: UnverifiedProjectModel, - stage, - git_origin): - +def verify_project(console, ctx: typer.Context, session: SafetyAuthSession, + unverified_project: UnverifiedProjectModel, stage: Stage, + git_origin: Optional[str]): + """ + Verify the project, linking it if necessary and saving the verified project information. + + Args: + console: The console for output. + ctx (typer.Context): The context of the Typer command. + session (SafetyAuthSession): The authentication session. + unverified_project (UnverifiedProjectModel): The unverified project model. + stage (Stage): The current stage. + git_origin (Optional[str]): The Git origin URL. 
+ """ + verified_prj = False link_prj = True @@ -122,17 +183,17 @@ def verify_project(console, ctx, session, link_prj = prompt_link_project(prj_name=prj_name, prj_admin_email=prj_admin_email, console=console) - + if not link_prj: continue verified_prj = print_wait_project_verification( - console, unverified_slug, (session.project, + console, unverified_slug, (session.project, {"project_id": unverified_slug}), on_error_delay=1) - + if verified_prj and isinstance(verified_prj, dict) and verified_prj.get("slug", None): - save_verified_project(ctx, verified_prj["slug"], verified_prj.get("name", None), + save_verified_project(ctx, verified_prj["slug"], verified_prj.get("name", None), unverified_project.project_path, verified_prj.get("url", None)) else: verified_prj = False diff --git a/safety/util.py b/safety/util.py index feef3747..420eb13a 100644 --- a/safety/util.py +++ b/safety/util.py @@ -7,7 +7,7 @@ from datetime import datetime from difflib import SequenceMatcher from threading import Lock -from typing import List, Optional +from typing import List, Optional, Dict, Generator, Tuple, Union, Any import click from click import BadParameter @@ -27,17 +27,45 @@ LOG = logging.getLogger(__name__) -def is_a_remote_mirror(mirror): +def is_a_remote_mirror(mirror: str) -> bool: + """ + Check if a mirror URL is remote. + + Args: + mirror (str): The mirror URL. + + Returns: + bool: True if the mirror URL is remote, False otherwise. + """ return mirror.startswith("http://") or mirror.startswith("https://") -def is_supported_by_parser(path): +def is_supported_by_parser(path: str) -> bool: + """ + Check if the file path is supported by the parser. + + Args: + path (str): The file path. + + Returns: + bool: True if the file path is supported, False otherwise. 
+ """ supported_types = (".txt", ".in", ".yml", ".ini", "Pipfile", "Pipfile.lock", "setup.cfg", "poetry.lock") return path.endswith(supported_types) -def parse_requirement(dep, found): +def parse_requirement(dep: Any, found: str) -> SafetyRequirement: + """ + Parse a requirement. + + Args: + dep (Any): The dependency. + found (str): The location where the dependency was found. + + Returns: + SafetyRequirement: The parsed requirement. + """ req = SafetyRequirement(dep) req.found = found @@ -47,7 +75,16 @@ def parse_requirement(dep, found): return req -def find_version(requirements): +def find_version(requirements: List[SafetyRequirement]) -> Optional[str]: + """ + Find the version of a requirement. + + Args: + requirements (List[SafetyRequirement]): The list of requirements. + + Returns: + Optional[str]: The version if found, None otherwise. + """ ver = None if len(requirements) != 1: @@ -61,12 +98,16 @@ def find_version(requirements): return ver -def read_requirements(fh, resolve=True): +def read_requirements(fh: Any, resolve: bool = True) -> Generator[Package, None, None]: """ - Reads requirements from a file like object and (optionally) from referenced files. - :param fh: file like object to read from - :param resolve: boolean. resolves referenced files. - :return: generator + Reads requirements from a file-like object and (optionally) from referenced files. + + Args: + fh (Any): The file-like object to read from. + resolve (bool): Resolves referenced files. + + Returns: + Generator: Yields Package objects. """ is_temp_file = not hasattr(fh, 'name') path = None @@ -111,14 +152,35 @@ def read_requirements(fh, resolve=True): more_info_url=None) -def get_proxy_dict(proxy_protocol, proxy_host, proxy_port): +def get_proxy_dict(proxy_protocol: str, proxy_host: str, proxy_port: int) -> Optional[Dict[str, str]]: + """ + Get the proxy dictionary for requests. + + Args: + proxy_protocol (str): The proxy protocol. + proxy_host (str): The proxy host. 
+ proxy_port (int): The proxy port. + + Returns: + Optional[Dict[str, str]]: The proxy dictionary if all parameters are provided, None otherwise. + """ if proxy_protocol and proxy_host and proxy_port: # Safety only uses https request, so only https dict will be passed to requests return {'https': f"{proxy_protocol}://{proxy_host}:{str(proxy_port)}"} return None -def get_license_name_by_id(license_id, db): +def get_license_name_by_id(license_id: int, db: Dict[str, Any]) -> Optional[str]: + """ + Get the license name by its ID. + + Args: + license_id (int): The license ID. + db (Dict[str, Any]): The database containing license information. + + Returns: + Optional[str]: The license name if found, None otherwise. + """ licenses = db.get('licenses', []) for name, id in licenses.items(): if id == license_id: @@ -126,7 +188,13 @@ def get_license_name_by_id(license_id, db): return None -def get_flags_from_context(): +def get_flags_from_context() -> Dict[str, str]: + """ + Get the flags from the current click context. + + Returns: + Dict[str, str]: A dictionary of flags and their corresponding option names. + """ flags = {} context = click.get_current_context(silent=True) @@ -139,7 +207,13 @@ def get_flags_from_context(): return flags -def get_used_options(): +def get_used_options() -> Dict[str, Dict[str, int]]: + """ + Get the used options from the command-line arguments. + + Returns: + Dict[str, Dict[str, int]]: A dictionary of used options and their counts. + """ flags = get_flags_from_context() used_options = {} @@ -156,12 +230,27 @@ def get_used_options(): return used_options -def get_safety_version(): +def get_safety_version() -> str: + """ + Get the version of Safety. + + Returns: + str: The Safety version. + """ from safety import VERSION return VERSION -def get_primary_announcement(announcements): +def get_primary_announcement(announcements: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """ + Get the primary announcement from a list of announcements. 
+ + Args: + announcements (List[Dict[str, Any]]): The list of announcements. + + Returns: + Optional[Dict[str, Any]]: The primary announcement if found, None otherwise. + """ for announcement in announcements: if announcement.get('type', '').lower() == 'primary_announcement': try: @@ -176,20 +265,50 @@ def get_primary_announcement(announcements): return None -def get_basic_announcements(announcements, include_local: bool = True): +def get_basic_announcements(announcements: List[Dict[str, Any]], include_local: bool = True) -> List[Dict[str, Any]]: + """ + Get the basic announcements from a list of announcements. + + Args: + announcements (List[Dict[str, Any]]): The list of announcements. + include_local (bool): Whether to include local announcements. + + Returns: + List[Dict[str, Any]]: The list of basic announcements. + """ return [announcement for announcement in announcements if announcement.get('type', '').lower() != 'primary_announcement' and not announcement.get('local', False) or (announcement.get('local', False) and include_local)] -def filter_announcements(announcements, by_type='error'): +def filter_announcements(announcements: List[Dict[str, Any]], by_type: str = 'error') -> List[Dict[str, Any]]: + """ + Filter announcements by type. + + Args: + announcements (List[Dict[str, Any]]): The list of announcements. + by_type (str): The type of announcements to filter by. + + Returns: + List[Dict[str, Any]]: The filtered announcements. + """ return [announcement for announcement in announcements if announcement.get('type', '').lower() == by_type] -def build_telemetry_data(telemetry = True, - command: Optional[str] = None, +def build_telemetry_data(telemetry: bool = True, + command: Optional[str] = None, subcommand: Optional[str] = None) -> TelemetryModel: + """Build telemetry data for the Safety context. + + Args: + telemetry (bool): Whether telemetry is enabled. + command (Optional[str]): The command. + subcommand (Optional[str]): The subcommand. 
+ + Returns: + TelemetryModel: The telemetry data model. + """ context = SafetyContext() body = { @@ -212,10 +331,15 @@ def build_telemetry_data(telemetry = True, return TelemetryModel(**body) -def build_git_data(): +def build_git_data() -> Dict[str, Any]: + """Build git data for the repository. + + Returns: + Dict[str, str]: The git data. + """ import subprocess - def git_command(commandline): + def git_command(commandline: List[str]) -> str: return subprocess.run(commandline, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL).stdout.decode('utf-8').strip() try: @@ -251,7 +375,17 @@ def git_command(commandline): } -def output_exception(exception, exit_code_output=True): +def output_exception(exception: Exception, exit_code_output: bool = True) -> None: + """ + Output an exception message to the console and exit. + + Args: + exception (Exception): The exception to output. + exit_code_output (bool): Whether to output the exit code. + + Exits: + Exits the program with the appropriate exit code. + """ click.secho(str(exception), fg="red", file=sys.stderr) if exit_code_output: @@ -264,7 +398,19 @@ def output_exception(exception, exit_code_output=True): sys.exit(exit_code) def build_remediation_info_url(base_url: str, version: Optional[str], spec: str, - target_version: Optional[str] = ''): + target_version: Optional[str] = '') -> str: + """ + Build the remediation info URL. + + Args: + base_url (str): The base URL. + version (Optional[str]): The current version. + spec (str): The specification. + target_version (Optional[str]): The target version. + + Returns: + str: The remediation info URL. 
+ """ params = {'from': version, 'to': target_version} @@ -277,8 +423,23 @@ def build_remediation_info_url(base_url: str, version: Optional[str], spec: str, return req.url -def get_processed_options(policy_file, ignore, ignore_severity_rules, exit_code, ignore_unpinned_requirements=None, - project=None): +def get_processed_options(policy_file: Dict[str, Any], ignore: Dict[str, Any], ignore_severity_rules: Dict[str, Any], + exit_code: bool, ignore_unpinned_requirements: Optional[bool] = None, + project: Optional[str] = None) -> Tuple[Dict[str, Any], Dict[str, Any], bool, Optional[bool], Optional[str]]: + """ + Get processed options from the policy file. + + Args: + policy_file (Dict[str, Any]): The policy file. + ignore (Dict[str, Any]): The ignore settings. + ignore_severity_rules (Dict[str, Any]): The ignore severity rules. + exit_code (bool): The exit code setting. + ignore_unpinned_requirements (Optional[bool]): The ignore unpinned requirements setting. + project (Optional[str]): The project setting. + + Returns: + Tuple[Dict[str, Any], Dict[str, Any], bool, Optional[bool], Optional[str]]: The processed options. + """ if policy_file: project_config = policy_file.get('project', {}) security = policy_file.get('security', {}) @@ -306,7 +467,17 @@ def get_processed_options(policy_file, ignore, ignore_severity_rules, exit_code, return ignore, ignore_severity_rules, exit_code, ignore_unpinned_requirements, project -def get_fix_options(policy_file, auto_remediation_limit): +def get_fix_options(policy_file: Dict[str, Any], auto_remediation_limit: int) -> int: + """ + Get fix options from the policy file. + + Args: + policy_file (Dict[str, Any]): The policy file. + auto_remediation_limit (int): The auto remediation limit. + + Returns: + int: The auto remediation limit. 
+ """ auto_fix = [] source = click.get_current_context().get_parameter_source("auto_remediation_limit") @@ -323,6 +494,10 @@ def get_fix_options(policy_file, auto_remediation_limit): class MutuallyExclusiveOption(click.Option): + """ + A click option that is mutually exclusive with other options. + """ + def __init__(self, *args, **kwargs): self.mutually_exclusive = set(kwargs.pop('mutually_exclusive', [])) self.with_values = kwargs.pop('with_values', {}) @@ -335,7 +510,18 @@ def __init__(self, *args, **kwargs): ) super(MutuallyExclusiveOption, self).__init__(*args, **kwargs) - def handle_parse_result(self, ctx, opts, args): + def handle_parse_result(self, ctx: click.Context, opts: Dict[str, Any], args: List[str]) -> Tuple[Any, List[str]]: + """ + Handle the parse result for mutually exclusive options. + + Args: + ctx (click.Context): The click context. + opts (Dict[str, Any]): The options dictionary. + args (List[str]): The arguments list. + + Returns: + Tuple[Any, List[str]]: The result and remaining arguments. + """ m_exclusive_used = self.mutually_exclusive.intersection(opts) option_used = m_exclusive_used and self.name in opts @@ -363,6 +549,9 @@ def handle_parse_result(self, ctx, opts, args): class DependentOption(click.Option): + """ + A click option that depends on other options. + """ def __init__(self, *args, **kwargs): self.required_options = set(kwargs.pop('required_options', [])) help = kwargs.get('help', '') @@ -373,7 +562,18 @@ def __init__(self, *args, **kwargs): ) super(DependentOption, self).__init__(*args, **kwargs) - def handle_parse_result(self, ctx, opts, args): + def handle_parse_result(self, ctx: click.Context, opts: Dict[str, Any], args: List[str]) -> Tuple[Any, List[str]]: + """ + Handle the parse result for dependent options. + + Args: + ctx (click.Context): The click context. + opts (Dict[str, Any]): The options dictionary. + args (List[str]): The arguments list. + + Returns: + Tuple[Any, List[str]]: The result and remaining arguments. 
+ """ missing_required_arguments = None if self.name in opts: @@ -395,7 +595,18 @@ def handle_parse_result(self, ctx, opts, args): ) -def transform_ignore(ctx, param, value): +def transform_ignore(ctx: click.Context, param: click.Parameter, value: Tuple[str]) -> Dict[str, Dict[str, Optional[str]]]: + """ + Transform ignore parameters into a dictionary. + + Args: + ctx (click.Context): The click context. + param (click.Parameter): The click parameter. + value (Tuple[str]): The parameter value. + + Returns: + Dict[str, Dict[str, Optional[str]]]: The transformed ignore parameters. + """ ignored_default_dict = {'reason': '', 'expires': None} if isinstance(value, tuple) and any(value): # Following code is required to support the 2 ways of providing 'ignore' @@ -409,7 +620,18 @@ def transform_ignore(ctx, param, value): return {} -def active_color_if_needed(ctx, param, value): +def active_color_if_needed(ctx: click.Context, param: click.Parameter, value: str) -> str: + """ + Activate color if needed based on the context and environment variables. + + Args: + ctx (click.Context): The click context. + param (click.Parameter): The click parameter. + value (str): The parameter value. + + Returns: + str: The parameter value. + """ if value == 'screen': ctx.color = True @@ -426,24 +648,63 @@ def active_color_if_needed(ctx, param, value): return value -def json_alias(ctx, param, value): +def json_alias(ctx: click.Context, param: click.Parameter, value: bool) -> Optional[bool]: + """ + Set the SAFETY_OUTPUT environment variable to 'json' if the parameter is used. + + Args: + ctx (click.Context): The click context. + param (click.Parameter): The click parameter. + value (bool): The parameter value. + + Returns: + bool: The parameter value. 
+ """ if value: os.environ['SAFETY_OUTPUT'] = 'json' return value -def html_alias(ctx, param, value): +def html_alias(ctx: click.Context, param: click.Parameter, value: bool) -> Optional[bool]: + """ + Set the SAFETY_OUTPUT environment variable to 'html' if the parameter is used. + + Args: + ctx (click.Context): The click context. + param (click.Parameter): The click parameter. + value (bool): The parameter value. + + Returns: + bool: The parameter value. + """ if value: os.environ['SAFETY_OUTPUT'] = 'html' return value -def bare_alias(ctx, param, value): +def bare_alias(ctx: click.Context, param: click.Parameter, value: bool) -> Optional[bool]: + """ + Set the SAFETY_OUTPUT environment variable to 'bare' if the parameter is used. + + Args: + ctx (click.Context): The click context. + param (click.Parameter): The click parameter. + value (bool): The parameter value. + + Returns: + bool: The parameter value. + """ if value: os.environ['SAFETY_OUTPUT'] = 'bare' return value -def get_terminal_size(): +def get_terminal_size() -> os.terminal_size: + """ + Get the terminal size. + + Returns: + os.terminal_size: The terminal size. + """ from shutil import get_terminal_size as t_size # get_terminal_size can report 0, 0 if run from pseudo-terminal prior Python 3.11 versions @@ -453,7 +714,16 @@ def get_terminal_size(): return os.terminal_size((columns, lines)) -def clean_project_id(input_string): +def clean_project_id(input_string: str) -> str: + """ + Clean a project ID by removing non-alphanumeric characters and normalizing the string. + + Args: + input_string (str): The input string. + + Returns: + str: The cleaned project ID. 
+ """ input_string = re.sub(r'[^a-zA-Z0-9]+', '-', input_string) input_string = input_string.strip('-') input_string = input_string.lower() @@ -461,7 +731,16 @@ def clean_project_id(input_string): return input_string -def validate_expiration_date(expiration_date): +def validate_expiration_date(expiration_date: str) -> Optional[datetime]: + """ + Validate an expiration date string. + + Args: + expiration_date (str): The expiration date string. + + Returns: + Optional[datetime]: The validated expiration date if valid, None otherwise. + """ d = None if expiration_date: @@ -480,7 +759,7 @@ def validate_expiration_date(expiration_date): class SafetyPolicyFile(click.ParamType): """ - Custom Safety Policy file to hold validations + Custom Safety Policy file to hold validations. """ name = "filename" @@ -489,7 +768,7 @@ class SafetyPolicyFile(click.ParamType): def __init__( self, mode: str = "r", - encoding: str = None, + encoding: Optional[str] = None, errors: str = "strict", pure: bool = os.environ.get('SAFETY_PURE_YAML', 'false').lower() == 'true' ) -> None: @@ -499,12 +778,33 @@ def __init__( self.basic_msg = '\n' + click.style('Unable to load the Safety Policy file "{name}".', fg='red') self.pure = pure - def to_info_dict(self): + def to_info_dict(self) -> Dict[str, Any]: + """ + Convert the object to an info dictionary. + + Returns: + Dict[str, Any]: The info dictionary. + """ info_dict = super().to_info_dict() info_dict.update(mode=self.mode, encoding=self.encoding) return info_dict - def fail_if_unrecognized_keys(self, used_keys, valid_keys, param=None, ctx=None, msg='{hint}', context_hint=''): + def fail_if_unrecognized_keys(self, used_keys: List[str], valid_keys: List[str], param: Optional[click.Parameter] = None, + ctx: Optional[click.Context] = None, msg: str = '{hint}', context_hint: str = '') -> None: + """ + Fail if unrecognized keys are found in the policy file. + + Args: + used_keys (List[str]): The used keys. + valid_keys (List[str]): The valid keys. 
+ param (Optional[click.Parameter]): The click parameter. + ctx (Optional[click.Context]): The click context. + msg (str): The error message template. + context_hint (str): The context hint for the error message. + + Raises: + click.UsageError: If unrecognized keys are found. + """ for keyword in used_keys: if keyword not in valid_keys: match = None @@ -521,31 +821,60 @@ def fail_if_unrecognized_keys(self, used_keys, valid_keys, param=None, ctx=None, self.fail(msg.format(hint=f'{context_hint}"{keyword}" is not a valid keyword.{maybe_msg}'), param, ctx) - def fail_if_wrong_bool_value(self, keyword, value, msg='{hint}'): + def fail_if_wrong_bool_value(self, keyword: str, value: Any, msg: str = '{hint}') -> None: + """ + Fail if a boolean value is invalid. + + Args: + keyword (str): The keyword. + value (Any): The value. + msg (str): The error message template. + + Raises: + click.UsageError: If the boolean value is invalid. + """ if value is not None and not isinstance(value, bool): self.fail(msg.format(hint=f"'{keyword}' value needs to be a boolean. " "You can use True, False, TRUE, FALSE, true or false")) - def convert(self, value, param, ctx): - try: + def convert(self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context]) -> Any: + """ + Convert the parameter value to a Safety policy file. + + Args: + value (Any): The parameter value. + param (Optional[click.Parameter]): The click parameter. + ctx (Optional[click.Context]): The click context. + + Returns: + Any: The converted policy file. + Raises: + click.UsageError: If the policy file is invalid. 
+ """ + try: + # Check if the value is already a file-like object if hasattr(value, "read") or hasattr(value, "write"): return value + # Prepare the error message template msg = self.basic_msg.format(name=value) + '\n' + click.style('HINT:', fg='yellow') + ' {hint}' + # Open the file stream f, _ = click.types.open_stream( value, self.mode, self.encoding, self.errors, atomic=False ) filename = '' try: + # Read the content of the file raw = f.read() yaml = YAML(typ='safe', pure=self.pure) safety_policy = yaml.load(raw) filename = f.name f.close() except Exception as e: + # Handle YAML parsing errors show_parsed_hint = isinstance(e, MarkedYAMLError) hint = str(e) if show_parsed_hint: @@ -553,6 +882,7 @@ def convert(self, value, param, ctx): self.fail(msg.format(name=value, hint=hint), param, ctx) + # Validate the structure of the safety policy if not safety_policy or not isinstance(safety_policy, dict) or not safety_policy.get('security', None): hint = "you are missing the security root tag" try: @@ -566,33 +896,34 @@ def convert(self, value, param, ctx): self.fail( msg.format(hint=hint), param, ctx) + # Validate 'security' section keys security_config = safety_policy.get('security', {}) security_keys = ['ignore-cvss-severity-below', 'ignore-cvss-unknown-severity', 'ignore-vulnerabilities', 'continue-on-vulnerability-error', 'ignore-unpinned-requirements'] self.fail_if_unrecognized_keys(security_config.keys(), security_keys, param=param, ctx=ctx, msg=msg, context_hint='"security" -> ') + # Validate 'ignore-cvss-severity-below' value ignore_cvss_security_below = security_config.get('ignore-cvss-severity-below', None) - if ignore_cvss_security_below: limit = 0.0 - try: limit = float(ignore_cvss_security_below) except ValueError as e: self.fail(msg.format(hint="'ignore-cvss-severity-below' value needs to be an integer or float.")) - if limit < 0 or limit > 10: self.fail(msg.format(hint="'ignore-cvss-severity-below' needs to be a value between 0 and 10")) + # Validate 
'continue-on-vulnerability-error' value continue_on_vulnerability_error = security_config.get('continue-on-vulnerability-error', None) self.fail_if_wrong_bool_value('continue-on-vulnerability-error', continue_on_vulnerability_error, msg) + # Validate 'ignore-cvss-unknown-severity' value ignore_cvss_unknown_severity = security_config.get('ignore-cvss-unknown-severity', None) self.fail_if_wrong_bool_value('ignore-cvss-unknown-severity', ignore_cvss_unknown_severity, msg) + # Validate 'ignore-vulnerabilities' section ignore_vulns = safety_policy.get('security', {}).get('ignore-vulnerabilities', {}) - if ignore_vulns: if not isinstance(ignore_vulns, dict): self.fail(msg.format(hint="Vulnerability IDs under the 'ignore-vulnerabilities' key, need to " @@ -626,7 +957,7 @@ def convert(self, value, param, ctx): f"be a positive integer") ) - # Validate expires + # Validate expires date d = validate_expiration_date(expires) if expires and not d: @@ -644,9 +975,9 @@ def convert(self, value, param, ctx): else: safety_policy['security']['ignore-vulnerabilities'] = {} + # Validate 'fix' section keys fix_config = safety_policy.get('fix', {}) - self.fail_if_unrecognized_keys(fix_config.keys(), ['auto-security-updates-limit'], param=param, ctx=ctx, msg=msg, - context_hint='"fix" -> ') + self.fail_if_unrecognized_keys(fix_config.keys(), ['auto-security-updates-limit'], param=param, ctx=ctx, msg=msg, context_hint='"fix" -> ') auto_remediation_limit = fix_config.get('auto-security-updates-limit', None) if auto_remediation_limit: @@ -658,7 +989,7 @@ def convert(self, value, param, ctx): except BadParameter as expected_e: raise expected_e except Exception as e: - # Don't fail in the default case + # Handle file not found errors gracefully, don't fail in the default case if ctx and isinstance(e, OSError): default = ctx.get_parameter_source source = default("policy_file") if default("policy_file") else default("policy_file_path") @@ -670,14 +1001,19 @@ def convert(self, value, param, ctx): 
self.fail(f"{problem}\n{hint}", param, ctx) def shell_complete( - self, ctx: "Context", param: "Parameter", incomplete: str + self, ctx: click.Context, param: click.Parameter, incomplete: str ): - """Return a special completion marker that tells the completion + """ + Return a special completion marker that tells the completion system to use the shell to provide file path completions. - :param ctx: Invocation context for this command. - :param param: The parameter that is requesting completion. - :param incomplete: Value being completed. May be empty. + Args: + ctx (click.Context): The click context. + param (click.Parameter): The click parameter. + incomplete (str): The value being completed. May be empty. + + Returns: + List[click.shell_completion.CompletionItem]: The completion items. .. versionadded:: 8.0 """ @@ -687,12 +1023,15 @@ def shell_complete( class SingletonMeta(type): + """ + A metaclass for singleton classes. + """ - _instances = {} + _instances: Dict[type, Any] = {} - _lock = Lock() + _lock: Lock = Lock() - def __call__(cls, *args, **kwargs): + def __call__(cls, *args: Any, **kwargs: Any) -> Any: with cls._lock: if cls not in cls._instances: instance = super().__call__(*args, **kwargs) @@ -701,6 +1040,9 @@ def __call__(cls, *args, **kwargs): class SafetyContext(metaclass=SingletonMeta): + """ + A singleton class to hold the Safety context. + """ packages = [] key = False db_mirror = False @@ -725,6 +1067,9 @@ class SafetyContext(metaclass=SingletonMeta): def sync_safety_context(f): + """ + Decorator to sync the Safety context with the function arguments. + """ def new_func(*args, **kwargs): ctx = SafetyContext() @@ -746,12 +1091,16 @@ def new_func(*args, **kwargs): @sync_safety_context -def get_packages_licenses(*, packages=None, licenses_db=None): - """Get the licenses for the specified packages based on their version. 
+def get_packages_licenses(*, packages: Optional[List[Package]] = None, licenses_db: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: + """ + Get the licenses for the specified packages based on their version. + + Args: + packages (Optional[List[Package]]): The list of packages. + licenses_db (Optional[Dict[str, Any]]): The licenses database. - :param packages: packages list - :param licenses_db: the licenses db in the raw form. - :return: list of objects with the packages and their respectives licenses. + Returns: + List[Dict[str, Any]]: The list of packages and their licenses. """ SafetyContext().command = 'license' @@ -776,7 +1125,7 @@ def get_packages_licenses(*, packages=None, licenses_db=None): if is_pinned_requirement(req.specifier): pkg.version = next(iter(req.specifier)).version break - + if not pkg.version: continue version_requested = parse_version(pkg.version) @@ -806,7 +1155,19 @@ def get_packages_licenses(*, packages=None, licenses_db=None): return filtered_packages_licenses -def get_requirements_content(files): +def get_requirements_content(files: List[click.File]) -> Dict[str, str]: + """ + Get the content of the requirements files. + + Args: + files (List[click.File]): The list of requirement files. + + Returns: + Dict[str, str]: The content of the requirement files. + + Raises: + InvalidProvidedReportError: If a file cannot be read. + """ requirements_files = {} for f in files: @@ -820,16 +1181,43 @@ def get_requirements_content(files): return requirements_files -def is_ignore_unpinned_mode(version): +def is_ignore_unpinned_mode(version: str) -> bool: + """ + Check if unpinned mode is enabled based on the version. + + Args: + version (str): The version string. + + Returns: + bool: True if unpinned mode is enabled, False otherwise. 
+ """ ignore = SafetyContext().params.get('ignore_unpinned_requirements') return (ignore is None or ignore) and not version -def get_remediations_count(remediations): +def get_remediations_count(remediations: Dict[str, Any]) -> int: + """ + Get the count of remediations. + + Args: + remediations (Dict[str, Any]): The remediations dictionary. + + Returns: + int: The count of remediations. + """ return sum((len(rem.keys()) for pkg, rem in remediations.items())) -def get_hashes(dependency): +def get_hashes(dependency: Any) -> List[Dict[str, str]]: + """ + Get the hashes for a dependency. + + Args: + dependency (Any): The dependency. + + Returns: + List[Dict[str, str]]: The list of hashes. + """ pattern = re.compile(HASH_REGEX_GROUPS) return [{'method': method, 'hash': hsh} for method, hsh in @@ -837,9 +1225,19 @@ def get_hashes(dependency): def pluralize(word: str, count: int = 0) -> str: + """ + Pluralize a word based on the count. + + Args: + word (str): The word to pluralize. + count (int): The count. + + Returns: + str: The pluralized word. + """ if count == 1: return word - + default = {"was": "were", "this": "these", "has": "have"} if word in default: @@ -858,7 +1256,10 @@ def pluralize(word: str, count: int = 0) -> str: return word + "s" -def initializate_config_dirs(): +def initializate_config_dirs() -> None: + """ + Initialize the configuration directories. + """ USER_CONFIG_DIR.mkdir(parents=True, exist_ok=True) try: