From 07fc391f4e6b69de7a8cf7234e1f9becb3085dbe Mon Sep 17 00:00:00 2001 From: Joshua Hoskins Date: Fri, 18 Sep 2020 12:14:42 -0700 Subject: [PATCH 1/4] amundsen-gremlin: first main commit. Add bulk loader, script translator, streams. Tests. Add commit/push hooks for style and to prevent credential leakage. Add github checks and workflow. Use amundsen-gremlin-committers as default reviewers instead of amundsen-committers Signed-off-by: Joshua Hoskins --- .github/CODEOWNERS | 4 +- .gitignore | 3 + .gitmodules | 3 + .hooks/pre-commit | 30 + .hooks/pre-push | 40 + amundsen_gremlin/__init__.py | 2 + amundsen_gremlin/config.py | 49 ++ amundsen_gremlin/gremlin_model.py | 492 ++++++++++++ amundsen_gremlin/gremlin_shared.py | 122 +++ .../neptune_bulk_loader/__init__.py | 2 + amundsen_gremlin/neptune_bulk_loader/api.py | 533 +++++++++++++ .../gremlin_model_converter.py | 721 ++++++++++++++++++ amundsen_gremlin/py.typed | 0 amundsen_gremlin/script_translator.py | 176 +++++ .../test_and_development_shard.py | 82 ++ amundsen_gremlin/utils/__init__.py | 2 + amundsen_gremlin/utils/streams.py | 390 ++++++++++ for_requests/__init__.py | 2 + for_requests/assume_role_aws4auth.py | 57 ++ for_requests/aws4auth_compatible.py | 24 + for_requests/host_header_ssl.py | 54 ++ requirements.txt | 16 + setup.cfg | 48 ++ setup.py | 27 + ssl_override_server_hostname/__init__.py | 2 + ssl_override_server_hostname/ssl_context.py | 21 + tests/__init__.py | 2 + tests/conftest.py | 27 + tests/unit/__init__.py | 2 + tests/unit/neptune_bulk_loader/__init__.py | 2 + tests/unit/neptune_bulk_loader/test_api.py | 80 ++ .../test_gremlin_model_converter.py | 439 +++++++++++ tests/unit/test_gremlin_model.py | 182 +++++ tests/unit/test_gremlin_shared.py | 53 ++ tests/unit/test_script_translator.py | 89 +++ tests/unit/test_test_and_development_shard.py | 89 +++ tests/unit/utils/__init__.py | 2 + tests/unit/utils/test_streams.py | 205 +++++ 38 files changed, 4072 insertions(+), 2 deletions(-) create mode 100644 .gitmodules create mode 100755 .hooks/pre-commit create mode 100755 .hooks/pre-push create mode 100644 amundsen_gremlin/__init__.py create mode 100644 amundsen_gremlin/config.py create mode 100644 amundsen_gremlin/gremlin_model.py create mode 100644 amundsen_gremlin/gremlin_shared.py create mode 100644 amundsen_gremlin/neptune_bulk_loader/__init__.py create mode 100644 amundsen_gremlin/neptune_bulk_loader/api.py create mode 100644 amundsen_gremlin/neptune_bulk_loader/gremlin_model_converter.py create mode 100644 amundsen_gremlin/py.typed create mode 100644 amundsen_gremlin/script_translator.py create mode 100644 amundsen_gremlin/test_and_development_shard.py create mode 100644 amundsen_gremlin/utils/__init__.py create mode 100644 amundsen_gremlin/utils/streams.py create mode 100644 for_requests/__init__.py create mode 100644 for_requests/assume_role_aws4auth.py create mode 100644 for_requests/aws4auth_compatible.py create mode 100644 for_requests/host_header_ssl.py create mode 100644 requirements.txt create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 ssl_override_server_hostname/__init__.py create mode 100644 ssl_override_server_hostname/ssl_context.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/neptune_bulk_loader/__init__.py create mode 100644 tests/unit/neptune_bulk_loader/test_api.py create mode 100644 tests/unit/neptune_bulk_loader/test_gremlin_model_converter.py create mode 100644 
tests/unit/test_gremlin_model.py create mode 100644 tests/unit/test_gremlin_shared.py create mode 100644 tests/unit/test_script_translator.py create mode 100644 tests/unit/test_test_and_development_shard.py create mode 100644 tests/unit/utils/__init__.py create mode 100644 tests/unit/utils/test_streams.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index aa80f70..f0c2ea7 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -6,8 +6,8 @@ # These owners will be the default owners for everything in # the repo. Unless a later match takes precedence, -# @amundsen-io/amundsen-committerswill be requested for +# @amundsen-io/amundsen-gremlin-committers will be requested for # review when someone opens a pull request. -* @amundsen-io/amundsen-committers +* @amundsen-io/amundsen-gremlin-committers *.py @alran @cpu @dsimms @friendtocephalopods @kathawthorne @worldwise001 diff --git a/.gitignore b/.gitignore index b6e4761..d6ac92a 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ + +# Vscode project settings +.vscode/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..700c661 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "amazon-neptune-tools"] + path = amazon-neptune-tools + url = https://github.com/awslabs/amazon-neptune-tools.git diff --git a/.hooks/pre-commit b/.hooks/pre-commit new file mode 100755 index 0000000..a966562 --- /dev/null +++ b/.hooks/pre-commit @@ -0,0 +1,30 @@ +#!/bin/sh + +red='\033[0;31m' +green='\033[0;32m' +NC='\033[0m' + +# Sort imports +echo -e "${green}[Isort]: Checking Sorting${NC}" +venv/bin/isort -c +if [ $? -ne 0 ] +then + venv/bin/isort --apply + echo -e "${red}Sorted imports; recommit${NC}" + exit 1 +fi + +# Check Neptune config + +echo -e "${green}[Config]: Checking for secrets${NC}" + +# Grep will return 0 if and only if it finds a changed amazonaws url. This suggests you may accidentally be publicly +# committing your config! +git diff HEAD amundsen_gremlin/config.py | grep -q -c 'amazonaws.com' +if [ $? -eq 0 ] +then + echo -e "${red}Did you remember to remove your AWS config? If this is a false alarm, recommit with --no-verify${NC}" + exit 1 +else + exit 0 +fi \ No newline at end of file diff --git a/.hooks/pre-push b/.hooks/pre-push new file mode 100755 index 0000000..d8e9c51 --- /dev/null +++ b/.hooks/pre-push @@ -0,0 +1,40 @@ +#!/bin/sh + +red='\033[0;31m' +green='\033[0;32m' +NC='\033[0m' + +set -e + +# Get only the files different on this branch +BASE_SHA="$(git merge-base master HEAD)" +FILES="$(git diff --name-only --diff-filter=AMC $BASE_SHA HEAD | grep "\.py$" | tr '\n' ' ')" + +echo -e "${green}[Python Checks][Info]: Checking Python Style, Types${NC}" + +if [ -n "$FILES" ] +then + + echo -e "${green}[Python Checks][Info]: ${FILES}${NC}" + + # Run flake8 + flake8 . + + if [ $? -ne 0 ]; then + echo -e "${red}[Python Style][Error]: Fix the issues and commit again (or commit with --no-verify if you are sure)${NC}" + exit 1 + fi + + # Run mypy + mypy . + + if [ $? 
-ne 0 ]; then + echo -e "${red}[Python Type Checks][Error]: Fix the issues and commit again (or commit with --no-verify if you are sure)${NC}" + exit 1 + fi + +else + echo -e "${green}[Python Checks][Info]: No files to check${NC}" +fi + +exit 0 diff --git a/amundsen_gremlin/__init__.py b/amundsen_gremlin/__init__.py new file mode 100644 index 0000000..f3145d7 --- /dev/null +++ b/amundsen_gremlin/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/amundsen_gremlin/config.py b/amundsen_gremlin/config.py new file mode 100644 index 0000000..6d7c631 --- /dev/null +++ b/amundsen_gremlin/config.py @@ -0,0 +1,49 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Any, Mapping, Optional, Union + + +class Config: + pass + + +NEPTUNE_URLS_BY_USER: Mapping[str, Mapping[str, Any]] = { + "nobody": { + "neptune_endpoint": "nowhere.amazonaws.com", + "neptune_port": 8182, + "uri": "nowhere.amazonaws.com:8182/gremlin" + }, +} + + +def neptune_url_for_development(*, user: Optional[str] = None) -> Optional[Mapping[str, Any]]: + # Hello! If you get here and your user is not above, ask one of them to borrow theirs. Or add your username + # to development_instance_users in terraform/deployments/development/main.tf and terraform apply + return NEPTUNE_URLS_BY_USER[os.getenv('USER', 'nobody')] + + +class TestGremlinConfig(Config): + NEPTUNE_BULK_LOADER_S3_BUCKET_NAME = 'amundsen-gremlin-development-bulk-loader' + NEPTUNE_URL = 'something.amazonaws.com:8182/gremlin' + # TODO: populate a session here + NEPTUNE_SESSION = None + + +class LocalGremlinConfig(Config): + LOG_LEVEL = 'DEBUG' + NEPTUNE_BULK_LOADER_S3_BUCKET_NAME = 'amundsen-gremlin-development-bulk-loader' + # The appropriate AWS region for your neptune setup + # ex: AWS_REGION_NAME = 'us-west-2' + AWS_REGION_NAME = None + # NB: Session should be shaped like: + # NEPTUNE_SESSION = boto3.session.Session(profile_name='youruserprofilehere', + # region_name=AWS_REGION_NAME) + # Unfortunately this will always blow up without a legit profile name + NEPTUNE_SESSION = None + # NB: NEPTUNE_URL should be shaped like: + # NEPTUNE_URL = neptune_url_for_development() + # Unfortunately, this will blow up if your user is not present + NEPTUNE_URL = None + PROXY_HOST: Union[str, Mapping[str, Any], None] = NEPTUNE_URL diff --git a/amundsen_gremlin/gremlin_model.py b/amundsen_gremlin/gremlin_model.py new file mode 100644 index 0000000..fad19e3 --- /dev/null +++ b/amundsen_gremlin/gremlin_model.py @@ -0,0 +1,492 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import datetime +from abc import ABC, abstractmethod +from enum import Enum, unique +from functools import lru_cache +from typing import ( + Any, Dict, FrozenSet, Hashable, List, Mapping, NamedTuple, Optional, + Sequence, Set, Tuple, Type, TypeVar +) + +from gremlin_python.process.traversal import Cardinality +from overrides import overrides + +from amundsen_gremlin.test_and_development_shard import get_shard + + +class GremlinTyper(ABC): + @abstractmethod + def is_allowed(self, value: Any) -> None: + pass + + def format(self, value: Any) -> str: + self.is_allowed(value) + # default is to use the built-in format + return str(value) + + +class GremlinBooleanTyper(GremlinTyper): + def is_allowed(self, value: Any) -> None: + assert isinstance(value, bool), f'expected bool, not {type(value)} {value}' + + +class GremlinByteTyper(GremlinTyper): + def is_allowed(self, value: Any) -> None: + assert isinstance(value, int) and value >= -(2**7) and value < (2**7), \ + f'expected int in [-2**7, 2**7), not {type(value)} {value}' + + +class GremlinShortTyper(GremlinTyper): + def is_allowed(self, value: Any) -> None: + assert isinstance(value, int) and value >= -(2**15) and value < (2**15), \ + f'expected int in [-2**15, 2**15), not {type(value)} {value}' + + +class GremlinIntTyper(GremlinTyper): + def is_allowed(self, value: Any) -> None: + assert isinstance(value, int) and value >= -(2**31) and value < (2**31), \ + f'expected int in [-2**31, 2**31), not {type(value)} {value}' + + +class GremlinLongTyper(GremlinTyper): + def is_allowed(self, value: Any) -> None: + assert isinstance(value, int) and value >= -(2**63) and value < (2**63), \ + f'expected int in [-2**63, 2**63), not {type(value)} {value}' + + +class GremlinFloatTyper(GremlinTyper): + """ + The Neptune Bulk Loader loads any precision floating point and rounds the mantissa. It'd be nice to avoid the + possible surprise by checking the precision here but feels too strict. + """ + def is_allowed(self, value: Any) -> None: + assert isinstance(value, float), f'expected float, not {type(value)} {value}' + + +class GremlinStringTyper(GremlinTyper): + def is_allowed(self, value: Any) -> None: + assert isinstance(value, str), f'expected str, not {type(value)} {value}' + + +class GremlinDateTyper(GremlinTyper): + @overrides + def is_allowed(self, value: Any) -> None: + assert (isinstance(value, (datetime.datetime, datetime.date)) + and (value.tzinfo is None if isinstance(value, (datetime.datetime)) else True)), \ + f'expected datetime.datetime (without tz) or datetime.date, not {type(value)} {value}' + + @overrides + def format(self, value: Any) -> str: + # datetime.datetime first otherwise isinstance(datetime.date) will catch it + if isinstance(value, datetime.datetime): + # already asserted no tz but double check + assert value.tzinfo is None, f'wat? already checked there was no tzinfo {value}' + return value.isoformat(timespec='seconds') + elif isinstance(value, datetime.date): + return value.isoformat() + else: + raise AssertionError(f'wat? already checked value was datetime or date: {value}') + + +class GremlinType(Enum): + """ + https://docs.aws.amazon.com/neptune/latest/userguide/bulk-load-tutorial-format-gremlin.html + Note I really wish the enum *values* had types too, but Guido didn't seem to like that at all. 
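+    Each member's value is the GremlinTyper used to validate and format values of that type, e.g. (illustrative):
+        >>> GremlinType.Date.value.format(datetime.date(2020, 9, 18))
+        '2020-09-18'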
+ """ + # 'Bool', 'Boolean' + Boolean = GremlinBooleanTyper() + # 'Byte', 'Short', 'Int', 'Long' + Byte = GremlinByteTyper() + Short = GremlinShortTyper() + Int = GremlinIntTyper() + Long = GremlinLongTyper() + # 'Float', 'Double': the loader loads any precision floating point and rounds the mantissa + Float = GremlinFloatTyper() + Double = GremlinFloatTyper() + # 'String' + String = GremlinStringTyper() + # 'Date' in YYYY-MM-DD, YYYY-MM-DDTHH:mm, YYYY-MM-DDTHH:mm:SS, YYYY-MM-DDTHH:mm:SSZ + Date = GremlinDateTyper() + + +class GremlinCardinality(Enum): + single = Cardinality.single + set = Cardinality.set_ + # list is not supported by Neptune + list = Cardinality.list_ + + def gremlin_python_cardinality(self) -> Cardinality: + return self.value + + +class Property(NamedTuple): + name: str + type: GremlinType + # For edge properties, omit (always single). For vertex properties, we assume single (unlike Neptune's gremlin + # which would assume set). + cardinality: Optional[GremlinCardinality] = None + multi_valued: bool = False + required: bool = False + comment: Optional[str] = None + default: Optional[Any] = None + + def signature(self, default_cardinality: Optional[GremlinCardinality]) -> 'Property': + # this isn't foolproof but works for Property and MagicProperty at least + return type(self)(name=self.name, type=self.type, cardinality=self.cardinality or default_cardinality) + + def format(self, value: Any) -> str: + return self.type.value.format(value) + + def header(self) -> str: + formatted = f'{self.name}:{self.type.name}' + if self.cardinality: + formatted += f'({self.cardinality.name})' + if self.multi_valued: + formatted += '[]' + return formatted + + +class MagicProperty(Property): + @overrides + def header(self) -> str: + return self.name + + +@unique +class MagicProperties(Enum): + """ + when writing out the header for these, they don't get the cardinality (they're single) or multi-valued + (they're single) + """ + ID = MagicProperty(name='~id', type=GremlinType.String, required=True) + LABEL = MagicProperty(name='~label', type=GremlinType.String, required=True) + FROM = MagicProperty(name='~from', type=GremlinType.String, required=True) + TO = MagicProperty(name='~to', type=GremlinType.String, required=True) + + +@unique +class WellKnownProperties(Enum): + # we expect that key is unique for a label + Key = Property(name='key', type=GremlinType.String, required=True) + + Created = Property(name='created', type=GremlinType.Date, required=True) + Expired = Property(name='expired', type=GremlinType.Date) + + TestShard = Property(name='shard', type=GremlinType.String, required=True, comment=''' + Only present in development and testing. Separates different instances sharing a datastore (so is sort of the + opposite of how one might usually use the word shard). + ''') + + +# TODO: move this someplace shared +def _discover_parameters(format_string: str) -> FrozenSet[str]: + """ + use this to discover what the parameters to a format string + """ + parameters: FrozenSet[str] = frozenset() + while True: + try: + format_string.format(**dict((k, '') for k in parameters)) + return parameters + except KeyError as e: + updated = parameters.union(set(e.args)) + assert updated != parameters + parameters = updated + + +V = TypeVar('V') + + +class VertexTypeIdFormats(Enum): + # note, these aren't f-strings, but their formatted values are used + DEFAULT = '{~label}:{key}' + + +class VertexType(NamedTuple): + label: str + properties: Tuple[Property, ...] 
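+    # 'not used' below is only the raw NamedTuple default; construct_type() always supplies the real id_format
+    # (and prepends '{shard}:' to it when a test/development shard is active)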
+ id_format: str = 'not used' + defaults: Tuple[Tuple[str, Hashable], ...] = () + + # let's make this simpler and say no positional args (which the NamedTuple constructor would allow), also we can't + # override __new__ + @classmethod + def construct_type(cls: Type["VertexType"], **kwargs: Any) -> "VertexType": + defaults: Dict[str, Hashable] = dict() + properties: Set[Property] = set(kwargs.pop('properties', [])) + properties.update({MagicProperties.LABEL.value, MagicProperties.ID.value, WellKnownProperties.Key.value}) + + # (magically) insinuate the shard identifier into the vertex id format + shard = get_shard() + id_format = kwargs.pop('id_format', VertexTypeIdFormats.DEFAULT.value) + if shard: + properties.update({WellKnownProperties.TestShard.value}) + defaults.update({WellKnownProperties.TestShard.value.name: shard}) + # prepend if it's not in the format already (which would be pretty weird, but anyway) + if '{shard}' not in id_format: + id_format = '{shard}:' + id_format + + parameters = _discover_parameters(id_format) + properties_names = set([p.name for p in properties]) + assert all(p in properties_names for p in parameters), \ + f'id_format: {id_format} has parameters: {parameters} not found in our properties {properties_names}' + + return cls(properties=tuple(properties), defaults=tuple(defaults.items()), id_format=id_format, **kwargs) + + @lru_cache() + def properties_as_map(self) -> Mapping[str, Property]: + mapping = dict([(p.name, p) for p in self.properties]) + assert len(mapping) == len(self.properties), f'are property names not unique? {self.properties}' + return mapping + + def id(self, **entity: Any) -> str: + for name, value in (self.defaults or ()): + if name not in entity: + entity[name] = value + # format them if they're not already. 
(the isinstance(v, str) feels wrong here tho) + values = dict([(n, (self.properties_as_map()[n].format(v) if v is not None and not isinstance(v, str) else v)) + for n, v in entity.items()]) + values.update({'~label': self.label}) + return self.id_format.format(**values) + + def create(self, **properties: Any) -> Mapping[str, Any]: + if MagicProperties.ID.value in self.properties and MagicProperties.ID.value.name not in properties: + properties[MagicProperties.ID.value.name] = self.id(**properties) + if MagicProperties.LABEL.value in self.properties and MagicProperties.LABEL.value.name not in properties: + properties[MagicProperties.LABEL.value.name] = self.label + for name, value in (self.defaults or ()): + if name not in properties: + properties[name] = value + # remove missing values + for k in [k for k, v in properties.items() if v is None]: + del properties[k] + properties.update([(k, v) for k, v in self.properties_as_map().items() + if v.default is not None and k not in properties]) + property_names = set(self.properties_as_map().keys()) + assert set(properties.keys()).issubset(property_names), \ + f'unexpected properties: properties: {properties}, expected names: {property_names}' + required_property_names = set([k for k, v in self.properties_as_map().items() if v.required]) + assert set(properties.keys()).issuperset(required_property_names), \ + f'expected required properties: properties: {properties}, expected names: {required_property_names}' + return properties + + +class EdgeTypeIdFormats(Enum): + # note, these aren't f-strings, but their formatted values are used + DEFAULT = '{~label}:{~from}->{~to}' + EXPIRABLE = '{~label}:{created}:{~from}->{~to}' + + +class EdgeType(NamedTuple): + label: str + properties: Tuple[Property, ...] + # TODO: fill these out + from_labels: Tuple[str, ...] = () + to_labels: Tuple[str, ...] = () + id_format: str = 'not used' + + # let's make this simpler and say no positional args (which the NamedTuple constructor would allow), also we can't + # override __new__ + @classmethod + def construct_type(cls: Type["EdgeType"], *, expirable: bool = True, **kwargs: Any) -> "EdgeType": + properties: Set[Property] = set(kwargs.pop('properties', [])) + properties.update({MagicProperties.LABEL.value, MagicProperties.ID.value, MagicProperties.FROM.value, + MagicProperties.TO.value, WellKnownProperties.Created.value}) + + # NB: Some edge types may not make sense to soft-expire + if expirable: + properties.update({WellKnownProperties.Expired.value}) + id_format = kwargs.pop('id_format', EdgeTypeIdFormats.EXPIRABLE.value) + else: + id_format = kwargs.pop('id_format', EdgeTypeIdFormats.DEFAULT.value) + + parameters = _discover_parameters(id_format) + properties_names = set([p.name for p in properties]) + assert all(p in properties_names for p in parameters), \ + f'id_format: {id_format} has parameters: {parameters} not found in our properties {properties_names}' + + return cls(properties=tuple(properties), id_format=id_format, **kwargs) + + @lru_cache() + def properties_as_map(self) -> Mapping[str, Property]: + mapping = dict([(p.name, p) for p in self.properties]) + assert len(mapping) == len(self.properties), f'are property names not unique? {self.properties}' + return mapping + + def id(self, **entity: Any) -> str: + # format them if they're not already. 
(the isinstance(v, str) feels wrong here tho) + values = dict([(n, (self.properties_as_map()[n].format(v) if v is not None and not isinstance(v, str) else v)) + for n, v in entity.items()]) + values.update({'~label': self.label}) + return self.id_format.format(**values) + + def create(self, **properties: Any) -> Mapping[str, Any]: + properties = dict(properties) + if MagicProperties.ID.value in self.properties and MagicProperties.ID.value.name not in properties: + properties[MagicProperties.ID.value.name] = self.id(**properties) + if MagicProperties.LABEL.value in self.properties and MagicProperties.LABEL.value.name not in properties: + properties[MagicProperties.LABEL.value.name] = self.label + # remove missing values + for k in [k for k, v in properties.items() if v is None]: + del properties[k] + properties.update([(k, v) for k, v in self.properties_as_map().items() + if v.default is not None and k not in properties]) + property_names = set(self.properties_as_map().keys()) + assert set(properties.keys()).issubset(property_names), \ + f'unexpected properties: properties: {properties}, expected names: {property_names}' + required_property_names = set([k for k, v in self.properties_as_map().items() if v.required]) + assert set(properties.keys()).issuperset(required_property_names), \ + f'expected required properties: properties: {properties}, expected names: {required_property_names}' + return properties + + +class VertexTypes(Enum): + """ + In general, you will need to reload all your data: 1. change label, 2. if you change the type of a property, + 3. change the effective id_format + """ + @classmethod + @lru_cache() + def by_label(cls) -> Mapping[str, "VertexTypes"]: + constants: List[VertexTypes] = list(cls) + mapping = dict([(c.value.label, c) for c in constants]) + assert len(mapping) == len(constants), f'are label names not unique? 
{constants}' + return mapping + + Application = VertexType.construct_type( + label='Application', + properties=[ + # there's a kind property we don't care about so much, so are ignoring and assuming all Application with + # the same id (but different kind) have the same identity + Property(name='id', type=GremlinType.String, required=True), + Property(name='name', type=GremlinType.String), + Property(name='description', type=GremlinType.String), + # except: we get different application_url per kind so keep those (but if a kind's url changes, we'll keep + # the old one around so the model isn't perfect) + Property(name='application_url', type=GremlinType.String, cardinality=GremlinCardinality.set)]) + Column = VertexType.construct_type( + label='Column', + properties=[ + Property(name='name', type=GremlinType.String, required=True), + Property(name='sort_order', type=GremlinType.Int), + Property(name='col_type', type=GremlinType.String)]) + Cluster = VertexType.construct_type( + label='Cluster', + properties=[Property(name='name', type=GremlinType.String, required=True)]) + Database = VertexType.construct_type( + label='Database', + properties=[ + Property(name='name', type=GremlinType.String, required=True)]) + Description = VertexType.construct_type( + label='Description', + properties=[ + Property(name='description', type=GremlinType.String, required=True), + Property(name='source', type=GremlinType.String, required=True, comment='effectively an enum')]) + Schema = VertexType.construct_type( + label='Schema', + properties=[Property(name='name', type=GremlinType.String, required=True)]) + Source = VertexType.construct_type( + label='Source', + properties=[]) + Stat = VertexType.construct_type( + label='Stat', + properties=[ + Property(name='stat_val', type=GremlinType.String), + Property(name='stat_type', type=GremlinType.String, comment='effectively an enum'), + Property(name='start_epoch', type=GremlinType.Date), + Property(name='end_epoch', type=GremlinType.Date)]) + Table = VertexType.construct_type( + label='Table', + properties=[ + Property(name='name', type=GremlinType.String, required=True), + Property(name='is_view', type=GremlinType.Boolean), + Property(name='display_name', type=GremlinType.String)]) + Tag = VertexType.construct_type( + label='Tag', + properties=[ + Property(name='tag_name', type=GremlinType.String, required=True), + Property(name='tag_type', type=GremlinType.String, required=True, default='default', + comment='effectively an enum, usually default')]) + Timestamp = VertexType.construct_type( + label='Timestamp', + properties=[]) + Updatedtimestamp = VertexType.construct_type( + label='Updatedtimestamp', + properties=[ + Property(name='latest_timestamp', type=GremlinType.Date, required=True)]) + User = VertexType.construct_type( + label='User', + properties=[ + Property(name='user_id', type=GremlinType.String, required=True), + Property(name='email', type=GremlinType.String), + Property(name='full_name', type=GremlinType.String), + Property(name='first_name', type=GremlinType.String), + Property(name='last_name', type=GremlinType.String), + Property(name='display_name', type=GremlinType.String), + Property(name='team_name', type=GremlinType.String), + Property(name='employee_type', type=GremlinType.String, comment='this is effectively an enum'), + Property(name='is_active', type=GremlinType.Boolean), + Property(name='profile_url', type=GremlinType.String), + Property(name='role_name', type=GremlinType.String), + Property(name='slack_id', type=GremlinType.String), 
+ Property(name='github_username', type=GremlinType.String), + Property(name='manager_fullname', type=GremlinType.String), + Property(name='manager_email', type=GremlinType.String), + Property(name='manager_id', type=GremlinType.String, + comment='the key/user_id of another User who is the manager for this User'), + Property(name='is_robot', type=GremlinType.Boolean) + ]) + Watermark = VertexType.construct_type( + label='Watermark', + properties=[]) + + +class EdgeTypes(Enum): + """ + In general, you will need to reload all your data: 1. change label, 2. if you change the type of a property, + 3. change the effective id_format (e.g. change required) + """ + @classmethod + @lru_cache() + def by_label(cls) -> Mapping[str, "EdgeTypes"]: + constants: List[EdgeTypes] = list(cls) + mapping = dict([(c.value.label, c) for c in constants]) + assert len(mapping) == len(constants), f'are label names not unique? {constants}' + return mapping + + @classmethod + @lru_cache() + def expirable(cls: Type["EdgeTypes"]) -> Sequence["EdgeTypes"]: + return tuple(t for t in cls + if WellKnownProperties.Expired.value.name in t.value.properties_as_map()) + + Admin = EdgeType.construct_type(label='ADMIN') + BelongToTable = EdgeType.construct_type(label='BELONG_TO_TABLE') + Cluster = EdgeType.construct_type(label='CLUSTER') + Column = EdgeType.construct_type(label='COLUMN') + Database = EdgeType.construct_type(label='DATABASE') + Description = EdgeType.construct_type(label='DESCRIPTION') + Grant = EdgeType.construct_type(label='GRANT') + Follow = EdgeType.construct_type(label='FOLLOW') + Generates = EdgeType.construct_type(label='GENERATES') + LastUpdatedAt = EdgeType.construct_type(label='LAST_UPDATED_AT') + Member = EdgeType.construct_type(label='MEMBER') + ManagedBy = EdgeType.construct_type(label='MANAGED_BY') + Owner = EdgeType.construct_type(label='OWNER') + Read = EdgeType.construct_type( + label='READ', + # need to do something safer with that date (so it doesn't end up as datetime ever) + id_format='{~label}:{date}:{~from}->{~to}', + properties=[ + Property(name='date', type=GremlinType.Date, required=True), + Property(name='read_count', type=GremlinType.Long, required=True)]) + ReadWrite = EdgeType.construct_type(label='READ_WRITE') + ReadOnly = EdgeType.construct_type(label='READ_ONLY') + RequiresAccessTo = EdgeType.construct_type(label='REQUIRES_ACCESS_TO') + Schema = EdgeType.construct_type(label='SCHEMA') + Source = EdgeType.construct_type(label='SOURCE') + Stat = EdgeType.construct_type(label='STAT') + Table = EdgeType.construct_type(label='TABLE') + Tag = EdgeType.construct_type(label='TAG') diff --git a/amundsen_gremlin/gremlin_shared.py b/amundsen_gremlin/gremlin_shared.py new file mode 100644 index 0000000..912c96c --- /dev/null +++ b/amundsen_gremlin/gremlin_shared.py @@ -0,0 +1,122 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, no_type_check, overload + +from gremlin_python.process.graph_traversal import GraphTraversal +from gremlin_python.process.traversal import Bytecode, Traversal + + +@no_type_check +def rsubstringstartingwith(sub: str, s: str) -> Optional[str]: + """ + >>> rsubstringstartingwith('://', 'database://foo') + 'foo' + >>> rsubstringstartingwith('://', 'database://foo://bar') + 'bar' + >>> rsubstringstartingwith('://', 'foo') + None + """ + try: + return s[s.rindex(sub) + len(sub):] + except ValueError: + return None + + +def make_database_uri(*, database_name: str) -> str: + return f'database://{database_name}' + + +def get_database_name_from_uri(*, database_uri: str) -> str: + if not database_uri.startswith('database://'): + raise RuntimeError(f'database_uri is malformed! {database_uri}') + database_name = rsubstringstartingwith('://', database_uri) + if database_name is None: + raise RuntimeError(f'database_uri is malformed! {database_uri}') + return database_name + + +@overload +def make_cluster_uri(*, database_uri: str, cluster_name: str) -> str: ... + + +@overload +def make_cluster_uri(*, database_name: str, cluster_name: str) -> str: ... + + +def make_cluster_uri(*, cluster_name: str, database_uri: Optional[str] = None, + database_name: Optional[str] = None) -> str: + if database_name is None: + assert database_uri is not None + database_name = get_database_name_from_uri(database_uri=database_uri) + assert database_name is not None + return f'{database_name}://{cluster_name}' + + +@overload +def make_schema_uri(*, cluster_uri: str, schema_name: str) -> str: ... + + +@overload +def make_schema_uri(*, database_name: str, cluster_name: str, schema_name: str) -> str: ... + + +def make_schema_uri(*, schema_name: str, cluster_uri: Optional[str] = None, database_name: Optional[str] = None, + cluster_name: Optional[str] = None) -> str: + if cluster_uri is None: + assert cluster_name is not None and database_name is not None + cluster_uri = make_cluster_uri(cluster_name=cluster_name, database_name=database_name) + assert cluster_uri is not None + return f'{cluster_uri}.{schema_name}' + + +@overload +def make_table_uri(*, schema_uri: str, table_name: str) -> str: ... + + +@overload +def make_table_uri(*, database_name: str, cluster_name: str, schema_name: str, table_name: str) -> str: ... 
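+# For illustration (the names here are made up), these make_*_uri helpers compose into hierarchical URIs:
+#   make_table_uri(database_name='hive', cluster_name='gold', schema_name='core', table_name='users')
+#   == 'hive://gold.core/users'
+#   make_column_uri(table_uri='hive://gold.core/users', column_name='user_id')
+#   == 'hive://gold.core/users/user_id'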
+ + +def make_table_uri(*, table_name: str, schema_uri: Optional[str] = None, database_name: Optional[str] = None, + cluster_name: Optional[str] = None, schema_name: Optional[str] = None) -> str: + if schema_uri is None: + assert database_name is not None and cluster_name is not None and schema_name is not None + schema_uri = make_schema_uri(schema_name=schema_name, cluster_name=cluster_name, database_name=database_name) + assert schema_uri is not None + return f'{schema_uri}/{table_name}' + + +def make_message_uri(*, name: str, package: str) -> str: + return f'message/{package}.{name}' + + +def make_shard_uri(*, table_uri: str, shard_name: str) -> str: + return f'{table_uri}/shards/{shard_name}' + + +def make_description_uri(*, subject_uri: str, source: str) -> str: + return f'{subject_uri}/{source}/_description' + + +def make_column_uri(*, table_uri: str, column_name: str) -> str: + return f'{table_uri}/{column_name}' + + +def make_column_statistic_uri(*, column_uri: str, statistic_type: str) -> str: + return f'{column_uri}/stat/{statistic_type}' + + +def append_traversal(g: Traversal, *traversals: Optional[Traversal]) -> GraphTraversal: + """ + copy the traversal, and append the traversals to it. (It's a little magic, but common-ish in the gremlin world + forums.) + """ + bytecode = Bytecode(bytecode=g.bytecode) + for t in [t for t in traversals if t is not None]: + assert t.graph is None, f'traversal has a graph source! should be an anonymous traversal: {t}' + for source_name, *source_args in t.bytecode.source_instructions: + bytecode.add_source(source_name, *source_args) + for step_name, *step_args in t.bytecode.step_instructions: + bytecode.add_step(step_name, *step_args) + return GraphTraversal(graph=g.graph, traversal_strategies=g.traversal_strategies, bytecode=bytecode) diff --git a/amundsen_gremlin/neptune_bulk_loader/__init__.py b/amundsen_gremlin/neptune_bulk_loader/__init__.py new file mode 100644 index 0000000..f3145d7 --- /dev/null +++ b/amundsen_gremlin/neptune_bulk_loader/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/amundsen_gremlin/neptune_bulk_loader/api.py b/amundsen_gremlin/neptune_bulk_loader/api.py new file mode 100644 index 0000000..0cdf3ca --- /dev/null +++ b/amundsen_gremlin/neptune_bulk_loader/api.py @@ -0,0 +1,533 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import csv +import datetime +import json +import logging +import time +from collections import defaultdict +from enum import Enum, auto +from io import BytesIO, StringIO +from typing import ( + IO, Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, + Sequence, Set, Tuple, Type, TypeVar, Union, cast +) +from urllib.parse import SplitResult, urlencode, urlsplit, urlunsplit + +import boto3 +import requests +from boto3.s3.transfer import TransferConfig +from flask import Config +from gremlin_python.driver.driver_remote_connection import ( + DriverRemoteConnection +) +from gremlin_python.process.anonymous_traversal import traversal +from gremlin_python.process.graph_traversal import GraphTraversalSource +from neptune_python_utils.endpoints import Endpoints, RequestParameters +from requests_aws4auth import AWS4Auth +from tornado import httpclient +from typing_extensions import TypedDict # is in typing in 3.8 + +from amundsen_gremlin.gremlin_model import ( + EdgeType, GremlinCardinality, Property, VertexType +) +from amundsen_gremlin.test_and_development_shard import get_shard +from for_requests.assume_role_aws4auth import AssumeRoleAWS4Auth +from for_requests.aws4auth_compatible import to_aws4_request_compatible_host +from for_requests.host_header_ssl import HostHeaderSSLAdapter +from ssl_override_server_hostname.ssl_context import ( + OverrideServerHostnameSSLContext +) + +LOGGER = logging.getLogger(__name__) + +GraphEntity = Mapping[str, Any] +FormattedEntity = Mapping[str, str] +GraphEntityType = Union[VertexType, EdgeType] +PropertiesMap = Mapping[str, Property] +GraphEntities = Mapping[GraphEntityType, List[GraphEntity]] +FormattedEntities = Sequence[Mapping[str, str]] + + +def get_neptune_graph_traversal_source_factory_from_config(config: Config) -> Callable[[], GraphTraversalSource]: + session = config.get('NEPTUNE_SESSION') + assert session is not None + + neptune_url = config.get('NEPTUNE_URL') + assert neptune_url is not None + + return get_neptune_graph_traversal_source_factory(neptune_url=neptune_url, session=session) + + +def get_neptune_graph_traversal_source_factory(*, neptune_url: Union[str, Mapping[str, Any]], + session: boto3.session.Session) -> Callable[[], GraphTraversalSource]: + + endpoints: Endpoints + override_uri: Optional[str] + if isinstance(neptune_url, str): + uri = urlsplit(neptune_url) + assert uri.scheme in ('wss', 'ws') and uri.path == '/gremlin' and not uri.query and not uri.fragment, \ + f'expected Neptune URL not {neptune_url}' + endpoints = Endpoints(neptune_endpoint=uri.hostname, neptune_port=uri.port, + region_name=session.region_name, credentials=session.get_credentials()) + override_uri = None + elif isinstance(neptune_url, Mapping): + endpoints = Endpoints(neptune_endpoint=neptune_url['neptune_endpoint'], + neptune_port=neptune_url['neptune_port'], region_name=session.region_name, + credentials=session.get_credentials()) + override_uri = neptune_url['uri'] + assert override_uri is None or isinstance(override_uri, str) + else: + raise AssertionError(f'what is NEPTUNE_URL? 
{neptune_url}') + + def create_graph_traversal_source(**kwargs: Any) -> GraphTraversalSource: + assert all(e not in kwargs for e in ('url', 'traversal_source')), \ + f'do not pass traversal_source or url in {kwargs}' + prepared_request = override_prepared_request_parameters( + endpoints.gremlin_endpoint().prepare_request(), override_uri=override_uri) + kwargs['traversal_source'] = 'g' + remote_connection = DriverRemoteConnection(url=prepared_request, **kwargs) + return traversal().withRemote(remote_connection) + return create_graph_traversal_source + + +def override_prepared_request_parameters( + request_parameters: RequestParameters, *, override_uri: Optional[Union[str, SplitResult]] = None, + method: Optional[str] = None, data: Optional[str] = None) -> httpclient.HTTPRequest: + """ + use like: + endpoints = Endpoints(neptune_endpoint=host_name, neptune_port=port_number, + region_name=session.region_name, credentials=session.get_credentials()) + override_prepared_request(endpoints.gremlin_endpoint().prepare_request(), override_uri=host_to_actually_connect_to) + + but note if you are not GETing (or have a payload), prepare_request doesn't *actually* generate sufficient headers + (despite the fact that it accepts a method) + """ + http_request_param: Dict[str, Any] = dict(url=request_parameters.uri, headers=request_parameters.headers) + if method is not None: + http_request_param['method'] = method + if data is not None: + http_request_param['body'] = data + if override_uri is not None: + # we override the URI slightly (because the instance thinks it's a different host than we're connecting to) + if isinstance(override_uri, str): + override_uri = urlsplit(override_uri) + assert isinstance(override_uri, SplitResult) + uri = urlsplit(request_parameters.uri) + http_request_param['headers'] = dict(request_parameters.headers) + http_request_param['headers']['Host'] = uri.netloc + http_request_param['ssl_options'] = OverrideServerHostnameSSLContext(server_hostname=uri.hostname) + http_request_param['url'] = urlunsplit( + (uri.scheme, override_uri.netloc, uri.path, uri.query, uri.fragment)) + return httpclient.HTTPRequest(**http_request_param) + + +def _urlsplit_if_not_already(uri: Union[str, SplitResult]) -> SplitResult: + if isinstance(uri, str): + return urlsplit(uri) + elif isinstance(uri, SplitResult): + return uri + raise AssertionError(f'what is uri? {uri}') + + +def request_with_override(*, uri: Union[str, SplitResult], override_uri: Optional[Union[str, SplitResult]] = None, + method: str = 'GET', headers: Dict[str, str] = {}, **kwargs: Any) -> Any: + # why not use endpoints? Despite the fact that it accepts a method and payload, it doesn't *actually* generate + # sufficient headers so we'll use requests for these since we can + if isinstance(uri, str): + uri = urlsplit(uri) + elif isinstance(uri, SplitResult): + pass + else: + raise AssertionError(f'what is uri? 
{uri}') + + # don't always need this, but it doesn't hurt + if 'Host' not in headers: + headers = dict(headers) + headers['Host'] = to_aws4_request_compatible_host(uri) + s = requests.Session() + if override_uri: + override_uri = _urlsplit_if_not_already(override_uri) + uri = urlunsplit((uri.scheme, override_uri.netloc, uri.path, uri.query, uri.fragment)) + s.mount('https://', HostHeaderSSLAdapter()) + else: + uri = urlunsplit(uri) + return s.request(method=method, url=uri, headers=headers, **kwargs) + + +def response_as_json(response: Any) -> Mapping[str, Any]: + return json.loads(response.content.decode('utf-8')) + + +class NeptuneBulkLoaderLoadStatusOverallStatus(TypedDict): + fullUri: str + runNumber: int + retryNumber: int + status: str # Literal['LOAD_FAILED', 'LOAD_COMPLETED', ...] + totalTimeSpent: int + startTime: int + totalRecords: int + totalDuplicates: int + parsingErrors: int + datatypeMismatchErrors: int + insertErrors: int + + +class NeptuneBulkLoaderLoadStatusErrorLogEntry(TypedDict): + errorCode: str + errorMessage: str + fileName: str + recordNum: int + + +class NeptuneBulkLoaderLoadStatusErrors(TypedDict): + startIndex: int + endIndex: int + loadId: str + # depends on the errors_per_page and errors_page + errorLogs: List[NeptuneBulkLoaderLoadStatusErrorLogEntry] + + +class NeptuneBulkLoaderLoadStatusPayload(TypedDict): + # those string keys are like the status enums + feedCount: List[Dict[str, int]] + overallStatus: NeptuneBulkLoaderLoadStatusOverallStatus + # optional, only if errors is true in the request + errors: NeptuneBulkLoaderLoadStatusErrors + + +class NeptuneBulkLoaderLoadStatus(TypedDict): + """ + see https://docs.aws.amazon.com/neptune/latest/userguide/load-api-reference-status.html + """ + status: str + payload: NeptuneBulkLoaderLoadStatusPayload + + +class BulkLoaderParallelism(Enum): + """ + Literal might be better for this in 3.8? 
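+    The member names are passed through verbatim as the bulk loader's parallelism parameter (see load() below).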
+ """ + LOW = auto() + MEDIUM = auto() + HIGH = auto() + OVERSUBSCRIBE = auto() + + +class BulkLoaderFormat(Enum): + """ + See https://docs.aws.amazon.com/neptune/latest/userguide/bulk-load-tutorial-format.html + """ + # see https://docs.aws.amazon.com/neptune/latest/userguide/bulk-load-tutorial-format-gremlin.html + CSV = 'csv' + # see https://docs.aws.amazon.com/neptune/latest/userguide/bulk-load-tutorial-format-rdf.html + N_TRIPLES = 'ntriples' + N_QUADS = 'nquads' + RDF_XML = 'rdfxml' + TURTLE = 'turtle' # no, this is totally a coincidence + + +class NeptuneBulkLoaderApi: + def __init__(self, *, session: boto3.session.Session, endpoint_uri: Union[str, SplitResult], + override_uri: Optional[Union[str, SplitResult]] = None, + iam_role_name: str = 'NeptuneLoadFromS3', s3_bucket_name: str) -> None: + self.session = session + self.endpoint_uri = _urlsplit_if_not_already(endpoint_uri) + assert self.endpoint_uri.path == '/gremlin' and self.endpoint_uri.scheme in ('ws', 'wss') and \ + not self.endpoint_uri.query, f'expected gremlin uri: {endpoint_uri}' + self.override_uri = _urlsplit_if_not_already(override_uri) if override_uri is not None else None + account_id = self.session.client('sts').get_caller_identity()['Account'] + self.iam_role_arn = f'arn:aws:iam::{account_id}:role/{iam_role_name}' + self.s3_bucket_name = s3_bucket_name + # See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3.html#using-the-transfer-manager + self.s3_transfer_config = TransferConfig( + multipart_threshold=100 * (2 ** 20), # 100MB + max_concurrency=5) + + @classmethod + def create_from_config(cls: Type["NeptuneBulkLoaderApi"], config: Mapping[Any, Any]) -> "NeptuneBulkLoaderApi": + neptune_url = config.get('NEPTUNE_URL') + endpoint_uri: Optional[str] + override_uri: Optional[str] + if isinstance(neptune_url, str): + endpoint_uri = neptune_url + override_uri = None + elif isinstance(neptune_url, Mapping): + endpoint_uri = f"wss://{neptune_url['neptune_endpoint']}:{neptune_url['neptune_port']}/gremlin" + override_uri = neptune_url['uri'] + else: + raise AssertionError(f'expected NEPTUNE_URL to be a str or dict: {neptune_url}') + s3_bucket_name = config.get('NEPTUNE_BULK_LOADER_S3_BUCKET_NAME') + assert s3_bucket_name is not None and isinstance(s3_bucket_name, str) + return cls(endpoint_uri=endpoint_uri, override_uri=override_uri, session=config.get('NEPTUNE_SESSION'), + s3_bucket_name=s3_bucket_name) + + def _get_aws4auth(self, service_name: str) -> AWS4Auth: + return AssumeRoleAWS4Auth(self.session.get_credentials(), self.session.region_name, service_name) + + def upload(self, *, f: IO[bytes], s3_object_key: str) -> None: + """ + e.g. 
+ + with BytesIO(vertex_csv) as f: + api.upload(f=f, s3_object_key=f'{object_prefix}/vertex.csv') + """ + s3_client = self.session.client('s3') + return s3_client.upload_fileobj(f, self.s3_bucket_name, s3_object_key, Config=self.s3_transfer_config) + + def load(self, *, s3_object_key: str, dependencies: List[str] = [], + parallelism: BulkLoaderParallelism = BulkLoaderParallelism.HIGH, failOnError: bool = False, + updateSingleCardinalityProperties: bool = True, queueRequest: bool = True, **kwargs: Any) -> Any: + uri = urlunsplit(('https' if self.endpoint_uri.scheme == 'wss' else 'http', self.endpoint_uri.netloc, 'loader', + self.endpoint_uri.query, self.endpoint_uri.fragment)) + response = request_with_override( + method='POST', uri=uri, override_uri=self.override_uri, auth=self._get_aws4auth('neptune-db'), + proxies=dict(http=None, https=None), + data=dict(source=f's3://{self.s3_bucket_name}/{s3_object_key}', + format=BulkLoaderFormat.CSV.value, + iamRoleArn=self.iam_role_arn, + region=self.session.region_name, + failOnError=failOnError, + parallelism=parallelism.name, + updateSingleCardinalityProperties=updateSingleCardinalityProperties, + queueRequest=queueRequest, + dependencies=dependencies), + **kwargs) + return response_as_json(response) + + def load_status(self, *, load_id: str = '', errors: bool = False, errors_per_page: int = 10, errors_page: int = 1, + **kwargs: Any) -> NeptuneBulkLoaderLoadStatus: + """ + See https://docs.aws.amazon.com/neptune/latest/userguide/load-api-reference-status.html + """ + query_parameters = dict() + if errors: + query_parameters.update(dict(errors=errors, errorsPerPage=errors_per_page, page=errors_page)) + uri = urlunsplit(('https' if self.endpoint_uri.scheme == 'wss' else 'http', self.endpoint_uri.netloc, + f'loader/{load_id}', urlencode(query_parameters), self.endpoint_uri.fragment)) + response = request_with_override(uri=uri, override_uri=self.override_uri, auth=self._get_aws4auth('neptune-db'), + proxies=dict(http=None, https=None)) + return cast(NeptuneBulkLoaderLoadStatus, response_as_json(response)) + + def bulk_load_entities( + self, *, entities: Mapping[GraphEntityType, Mapping[str, GraphEntity]], object_prefix: Optional[str] = None, + polling_period: int = 10, raise_if_failed: bool = False) -> Mapping[str, Mapping[str, Any]]: + """ + :param entities: The entities being bulk loaded. They will be partitioned at least by vertex vs edge, but + possibly by conflicting property type (though the latter is unusual), and written to files in S3, then loaded + by Neptune. + :param object_prefix: (optional) The string is treated like a format string, and 'now' and 'shard' (if + get_shard() is truthy) are the available parameters. Defaults to '{now}/{shard}' or '{now}'. + :param polling_period: (optional) defaults to 10 (seconds). The period at which the status will be polled. + :param raise_if_failed: (optional) defaults to False. If True, will raise if any of the loads failed, otherwise + log a warning and return the status. True would be useful for testing or other situations where you would + always expect the load to succeed. 
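+        e.g. (a sketch; the entities mapping is shaped like the ENTITIES built in gremlin_model_converter):
+            api.bulk_load_entities(entities={VertexTypes.Table.value: {table_id: table_entity}},
+                                   raise_if_failed=True)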
+ :return: + """ + format_args = dict( + now=datetime.datetime.now().isoformat(timespec='milliseconds').replace(':', '-').replace('.', '-')) + shard = get_shard() + if shard: + format_args.update(shard=shard) + if not object_prefix: + object_prefix = '{now}/{shard}' if 'shard' in format_args else '{now}' + object_prefix = object_prefix.format(**format_args) + + assert isinstance(object_prefix, str) and all(c not in object_prefix for c in ':'), \ + f'object_prefix is going to break S3 {object_prefix}' + + vertexes, edges = group_by_class(entities) + + # TODO: write these to tmp? stream them in? + vertex_csvs: List[bytes] = [] + for types in partition_properties(vertexes.keys()): + with StringIO() as w: + write_entities_as_csv(w, dict((t, vertexes[t]) for t in types)) + vertex_csvs.append(w.getvalue().encode('utf-8')) + + edge_csvs: List[bytes] = [] + for types in partition_properties(edges.keys()): + with StringIO() as w: + write_entities_as_csv(w, dict((t, edges[t]) for t in types)) + edge_csvs.append(w.getvalue().encode('utf-8')) + + csvs: List[Tuple[str, bytes]] = [(f'{object_prefix}/vertex{i}.csv', v) for i, v in enumerate(vertex_csvs)] + [ + (f'{object_prefix}/edge{i}.csv', v) for i, v in enumerate(edge_csvs)] + + todo: List[str] = [] + for s3_object_key, v in csvs: + # upload to s3 + with BytesIO(v) as r: + self.upload(f=r, s3_object_key=s3_object_key) + + # now poke Neptune and tell it to load that file + # TODO: dependencies? endpoint doesn't seem to like the way we pass these + response = self.load(s3_object_key=s3_object_key) + + # TODO: retry? + assert 'payload' in response and 'loadId' in response['payload'], \ + f'failed to submit load for vertex.csv: {response}' + todo.append(response['payload']['loadId']) + + status_by_load_id: Dict[str, Mapping[str, Any]] = dict() + + while todo: + status_by_load_id.update( + [(id, self.load_status(load_id=id, errors=True, errors_per_page=30)['payload']) + for id in todo]) + todo = [load_id for load_id, overall_status in status_by_load_id.items() + if overall_status['overallStatus']['status'] not in ('LOAD_COMPLETED', 'LOAD_FAILED')] + time.sleep(polling_period) + + # TODO: timeout and parse errors + assert not todo + failed = dict([(load_id, overall_status) for load_id, overall_status in status_by_load_id.items() + if overall_status['overallStatus']['status'] != 'LOAD_COMPLETED']) + if failed: + LOGGER.warning(f'some loads failed: {failed.keys()}: bulk_loader_details={failed}') + if raise_if_failed: + raise AssertionError(f'some loads failed: {failed.keys()}') + + return status_by_load_id + + +def group_by_class(entities: Mapping[GraphEntityType, Mapping[str, GraphEntity]]) -> \ + Tuple[Mapping[GraphEntityType, Iterable[GraphEntity]], Mapping[GraphEntityType, Iterable[GraphEntity]]]: + vertex_types: Mapping[GraphEntityType, List[GraphEntity]] = defaultdict(list) + edge_types: Mapping[GraphEntityType, List[GraphEntity]] = defaultdict(list) + for t, es in entities.items(): + # there was a slicker (10 lines shorter) map based implementation in pasta.py but it defeated the type checker + if isinstance(t, VertexType): + assert not isinstance(t, EdgeType) + vertex_types[t].extend(es.values()) + elif isinstance(t, EdgeType): + assert not isinstance(t, VertexType) + edge_types[t].extend(es.values()) + else: + raise AssertionError(f'expected type {t} to be a VertexType or an EdgeType not a {type(t)}') + return vertex_types, edge_types + + +def partition_properties(types: Collection[GraphEntityType]) -> Iterable[Iterable[GraphEntityType]]: + 
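+    """
+    Split the given types into groups whose Property signatures do not conflict, so that each group can be
+    written to a single bulk-load CSV with one shared header.
+    """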
default_cardinality: Optional[GremlinCardinality] + if all(isinstance(t, VertexType) for t in types): + default_cardinality = GremlinCardinality.single + elif all(isinstance(t, EdgeType) for t in types): + default_cardinality = None + else: + raise AssertionError(f'some are not VertexType or EdgeType? {types}') + + partitioned: List[Collection[GraphEntityType]] = list() + + def _try(_types: Collection[GraphEntityType]) -> None: + by_name: Dict[str, Set[Property]] = defaultdict(set) + by_signature: Dict[Property, Set[GraphEntityType]] = defaultdict(set) + for t in _types: + for p in t.properties: + signature = p.signature(default_cardinality) + by_name[signature.name].add(signature) + by_signature[signature].add(t) + + overlapping_properties = [(k, v) for k, v in by_name.items() if len(v) != 1] + if not overlapping_properties: + partitioned.append(tuple(_types)) + return + + # this could be smarter if there are a lot of overlaps, or if you'd like to make the partitions equally sized + ignored, signatures = overlapping_properties[0] + overlapping_types = sorted([by_signature[signature] for signature in signatures], key=len) + assert all(a.isdisjoint(b) for i, a in enumerate(overlapping_types) for b in overlapping_types[i + 1:]), \ + f'expected to not overlap: {overlapping_types}' + + all_overlapping_types: Set[GraphEntityType] = set().union(*overlapping_types) # type: ignore + other_types = set(_types).difference(all_overlapping_types) + assert all(other_types.isdisjoint(e) for e in overlapping_types), \ + f'expected to not overlap: {other_types} and {overlapping_types}' + overlapping_types[0] = overlapping_types[0].union(other_types) + + # let's make sure we're making progress + assert sum(len(e) for e in overlapping_types) == len(_types) and all(len(e) > 0 for e in overlapping_types) + for e in overlapping_types: + _try(e) + + _try(types) + + # did we cover them all? + assert sum(len(e) for e in partitioned) == len(types) + # do they overlap? + assert all(set(a).isdisjoint(set(b)) for i, a in enumerate(partitioned) for b in partitioned[i + 1:]) + return partitioned + + +GET = TypeVar('GET', bound=GraphEntityType) + + +def write_entities_as_csv(file: IO[str], entities: Mapping[GET, Iterable[GraphEntity]]) -> None: + # TODO: also explodes if there is incompatible overlap in names, which we could avoid by partitioning + properties = merge_properties(entities.keys()) + formatted: List[Mapping[str, str]] = [format_entity(t, e) for t, es in entities.items() for e in es] + write_csv(file, properties, formatted) + + +def merge_properties(types: Iterable[GraphEntityType]) -> PropertiesMap: + default_cardinality: Optional[GremlinCardinality] + if all(isinstance(t, VertexType) for t in types): + default_cardinality = GremlinCardinality.single + elif all(isinstance(t, EdgeType) for t in types): + default_cardinality = None + else: + raise AssertionError(f'some are not VertexType or EdgeType? 
{types}') + + by_name: Dict[str, Set[Property]] = defaultdict(set) + for t in types: + for p in t.properties: + signature = p.signature(default_cardinality) + by_name[signature.name].add(signature) + + overlapping_types = [(k, v) for k, v in by_name.items() if len(v) != 1] + assert not overlapping_types, f'some Property have incompatible signatures: {overlapping_types}' + return dict([(k, v) for k, (v,) in by_name.items()]) + + +def format_entity(entity_type: GraphEntityType, entity: GraphEntity) -> FormattedEntity: + properties = entity_type.properties_as_map() + # the use would naturally explode if property name isn't in the type, but this is better + assert set(entity.keys()).issubset(set(properties.keys())), \ + f'some properties in the entity are not in the entity type? entity: {entity}, type: {entity_type}' + assert set(entity.keys()).issuperset(set([k for k, p in properties.items() if p.required])), \ + f'some required properties in the entity are not present? entity: {entity}, type: {entity_type}' + return dict([(n, properties[n].format(v)) for n, v in entity.items()]) + + +# these seem to match the format that neptune-export produces (which doesn't use csv.writer) +csv_kwargs = dict(dialect='excel', delimiter=',', quotechar='"', doublequote=True) + + +def write_csv(file: IO[str], properties: PropertiesMap, entities: FormattedEntities) -> None: + """ + entities is the already formatted values + """ + # eliminate the not-present properties + property_names_set: Set[str] = set() + for e in entities: + property_names_set.update(k for k, v in e.items() if v is not None) + assert property_names_set.issubset(properties.keys()), f'wat? entities have property names not in the Properties?' + property_names: List[str] = sorted(list(property_names_set)) + + # only really shows up in testing + if not property_names: + return + + # no, it's not a context manager + w = csv.writer(file, **csv_kwargs) + + # don't writeheader, instead write header from properties + w.writerow([properties[n].header() for n in property_names]) + + for e in entities: + w.writerow([e.get(n, '') for n in property_names]) + + +def new_entities() -> Dict[GraphEntityType, List[GraphEntity]]: + return defaultdict(list) diff --git a/amundsen_gremlin/neptune_bulk_loader/gremlin_model_converter.py b/amundsen_gremlin/neptune_bulk_loader/gremlin_model_converter.py new file mode 100644 index 0000000..116b0c8 --- /dev/null +++ b/amundsen_gremlin/neptune_bulk_loader/gremlin_model_converter.py @@ -0,0 +1,721 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import datetime +import logging +from collections import defaultdict +from functools import lru_cache +from typing import ( + Any, Callable, FrozenSet, Iterable, List, Mapping, MutableMapping, + NamedTuple, NewType, Optional, Sequence, Tuple, Union, cast +) + +from amundsen_common.models.table import Application, Column, Table +from amundsen_common.models.user import User +from gremlin_python.process.graph_traversal import GraphTraversalSource, __ +from gremlin_python.process.traversal import Column as MapColumn +from gremlin_python.process.traversal import T + +from amundsen_gremlin.gremlin_model import ( + EdgeType, EdgeTypes, GremlinCardinality, MagicProperties, Property, + VertexType, VertexTypes, WellKnownProperties +) +from amundsen_gremlin.gremlin_shared import ( # noqa: F401 + append_traversal, make_cluster_uri, make_column_statistic_uri, + make_column_uri, make_database_uri, make_description_uri, make_schema_uri, + make_table_uri +) +from amundsen_gremlin.utils.streams import chunk + +LOGGER = logging.getLogger(__name__) + +EXISTING_KEY = FrozenSet[Tuple[str, str]] +EXISTING = NewType('EXISTING', + Mapping[Union[VertexType, EdgeType], MutableMapping[EXISTING_KEY, Mapping[str, Any]]]) +ENTITIES = NewType('ENTITIES', Mapping[Union[VertexType, EdgeType], MutableMapping[str, Mapping[str, Any]]]) + + +def new_entities() -> ENTITIES: + return cast(ENTITIES, defaultdict(dict)) + + +def new_existing() -> EXISTING: + return cast(EXISTING, defaultdict(dict)) + + +def _get_existing_key_from_entity(_entity: Mapping[str, Any]) -> EXISTING_KEY: + """ + Used for testing. + """ + _entity = dict(_entity) + label: str = _entity.pop(MagicProperties.LABEL.value.name) + assert isinstance(label, str) # appease the types + return _get_existing_key(_type=label, **_entity) + + +def _get_existing_key(_type: Union[VertexType, EdgeType, VertexTypes, EdgeTypes, str], **_entity: Any) -> EXISTING_KEY: + """ + Maybe this should be a part of EdgeType and VertexType. But, this function certainly shouldn't be used + away from EXISTING (or testing) + """ + if isinstance(_type, str): + vertex_type = VertexTypes.by_label().get(_type) + edge_type = EdgeTypes.by_label().get(_type) + assert bool(vertex_type) != bool(edge_type), \ + f'expected exactly one of VertexTypes or EdgeTypes to match {_type}' + vertex_or_edge_type = vertex_type or edge_type + assert vertex_or_edge_type is not None # appease mypy + _type = vertex_or_edge_type + assert isinstance(_type, (VertexTypes, EdgeTypes)) + if isinstance(_type, (VertexTypes, EdgeTypes)): + _type = _type.value + assert isinstance(_type, (VertexType, EdgeType)) + key_properties = _get_key_properties(_type) + # this eliding (the created and label fields) feels icky, but makes sense. 
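+ # (presumably: Created is stamped at write time so it would never match a previously written edge, and
+ # TestShard is constant within an environment, so neither helps identify an already-existing entity)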
+ if isinstance(_type, EdgeType): + key_properties = key_properties.difference({WellKnownProperties.Created.value}) + if isinstance(_type, VertexType): + key_properties = key_properties.difference({WellKnownProperties.TestShard.value}) + key_properties = key_properties.difference({MagicProperties.LABEL.value}) + assert WellKnownProperties.Created.value not in key_properties + assert MagicProperties.LABEL.value not in key_properties + return frozenset([(p.name, p.format(_entity.get(p.name))) for p in key_properties]) + + +@lru_cache(maxsize=len(list(VertexTypes)) + len(list(EdgeTypes)) + 100) +def _get_key_properties(_type: Union[VertexType, EdgeType]) -> FrozenSet[Property]: + assert isinstance(_type, (EdgeType, VertexType)) + return frozenset([_type.properties_as_map()[n] for n in _discover_parameters(_type.id_format)]) + + +def _discover_parameters(format_string: str) -> FrozenSet[str]: + """ + use this to discover what the parameters to a format string + """ + parameters: FrozenSet[str] = frozenset() + while True: + try: + format_string.format(**dict((k, '') for k in parameters)) + return parameters + except KeyError as e: + updated = parameters.union(set(e.args)) + assert updated != parameters + parameters = updated + + +def date_string_to_date(a_date: str) -> datetime.date: + return datetime.datetime.strptime(a_date, '%Y-%m-%d').date() + + +class TableUris(NamedTuple): + database: str + cluster: str + schema: str + table: str + + @staticmethod + def get(*, database: str, cluster: str, schema: str, table: str) -> "TableUris": + database_uri = make_database_uri(database_name=database) + cluster_uri = make_cluster_uri(database_uri=database_uri, cluster_name=cluster) + schema_uri = make_schema_uri(cluster_uri=cluster_uri, schema_name=schema) + table_uri = make_table_uri(schema_uri=schema_uri, table_name=table) + return TableUris(database=database_uri, cluster=cluster_uri, schema=schema_uri, table=table_uri) + + +HISTORICAL_APP_PREFIX = 'app-' +ENVIRONMENT_APP_SUFFIXES = frozenset(['-development', '-devel', '-staging', '-stage', '-production', '-prod']) + + +def possible_application_names_application_key(app_key: str) -> Iterable[str]: + # get both the app- and not + app_keys = [app_key] + if app_key.startswith(HISTORICAL_APP_PREFIX): + app_keys.append(app_key[len(HISTORICAL_APP_PREFIX):]) + else: + app_keys.append(f'{HISTORICAL_APP_PREFIX}{app_key}') + + for suffix in ENVIRONMENT_APP_SUFFIXES: + if app_key.endswith(suffix): + without = [_[:-len(suffix)] for _ in app_keys] + app_keys.extend(without) + break + + return tuple(app_keys) + + +def possible_existing_keys_for_application_key(*app_keys: str) -> FrozenSet[EXISTING_KEY]: + return frozenset([_get_existing_key(VertexTypes.Application, key=key) + for app_key in app_keys for key in possible_application_names_application_key(app_key)]) + + +def possible_vertex_ids_for_application_key(*app_keys: str) -> FrozenSet[str]: + return frozenset([ + VertexTypes.Application.value.id(**dict(key)) for key in possible_existing_keys_for_application_key(*app_keys)]) + + +def ensure_edge_type(edge_type: Union[str, EdgeTypes, EdgeType]) -> EdgeType: + if isinstance(edge_type, str): + edge_type = EdgeTypes.by_label()[edge_type].value + if isinstance(edge_type, EdgeTypes): + edge_type = edge_type.value + assert isinstance(edge_type, EdgeType) + return edge_type + + +def ensure_vertex_type(vertex_type: Union[str, VertexTypes, VertexType]) -> VertexType: + if isinstance(vertex_type, str): + vertex_type = VertexTypes.by_label()[vertex_type].value + if 
isinstance(vertex_type, VertexTypes): + vertex_type = vertex_type.value + assert isinstance(vertex_type, VertexType) + return vertex_type + + +class _FetchExisting: + @classmethod + def _fake_into_existing_edges_for_testing(cls, _existing: EXISTING, _type: Union[EdgeType, EdgeTypes], _from: str, + _to: str, **entity: Any) -> Mapping[str, Any]: + _type = ensure_edge_type(_type) + _entity = _type.create(**entity, **{ + MagicProperties.LABEL.value.name: _type.label, + MagicProperties.FROM.value.name: _from, + MagicProperties.TO.value.name: _to, + }) + _key = _get_existing_key(_type=_type, **_entity) + assert _key not in _existing[_type] + _existing[_type][_key] = _entity + return _entity + + @classmethod + def _fake_into_existing_vertexes_for_testing(cls, _existing: EXISTING, _type: Union[VertexType, VertexTypes], + **entity: Any) -> Mapping[str, Any]: + _type = ensure_vertex_type(_type) + _entity = _type.create(** entity, **{ + MagicProperties.LABEL.value.name: _type.label, + }) + _key = _get_existing_key(_type=_type, **_entity) + assert _key not in _existing[_type] + _existing[_type][_key] = _entity + return _entity + + @classmethod # noqa: C901 + def _honor_cardinality_once(cls, _property: Property, value: Any) -> Any: + # use the types to figure out if we should take the element instead + if _property.cardinality == GremlinCardinality.single or _property.cardinality is None: + # is this the most general type? + if isinstance(value, Sequence): + assert len(value) <= 1, f'single cardinality property has more than one value! {value}' + value = value[0] if value else None + if value is not None: + _property.type.value.is_allowed(value) + return value + elif _property.cardinality == GremlinCardinality.list: + # is this the most general type? + if value is None: + value = () + elif isinstance(value, Iterable) and not isinstance(value, tuple): + value = tuple(value) + for e in value: + _property.type.value.is_allowed(e) + return value + elif _property.cardinality == GremlinCardinality.set: + # is this the most general type? + if value is None: + value = frozenset() + elif isinstance(value, Iterable) and not isinstance(value, FrozenSet): + value = frozenset(value) + for e in value: + _property.type.value.is_allowed(e) + return value + raise AssertionError('never') + + @classmethod + def _honor_cardinality(cls, _type: Union[VertexType, EdgeType], **entity: Any) -> Mapping[str, Any]: + _properties = _type.properties_as_map() + result = dict() + for k, v in entity.items(): + if not _properties.get(k): + LOGGER.error(f'Trying to honor cardinality for property {k} which isnt allowed for {_type.label}') + continue + result[k] = cls._honor_cardinality_once(_properties[k], v) + return result + + @classmethod + def _into_existing(cls, value_maps: Sequence[Union[Mapping[Any, Any], Sequence[Any]]], existing: EXISTING) -> None: + """ + value_map for an edge should be the result of .union(__.outV().id(), __.valueMap(True), __.inV().id()).fold() + value_map for a vertex should be the result of valueMap(True) + """ + assert all(isinstance(e, (Mapping, Sequence)) for e in value_maps) + edge_value_maps = [e for e in value_maps if isinstance(e, Sequence)] + vertex_value_maps = [e for e in value_maps if isinstance(e, Mapping)] + assert len(value_maps) == len(edge_value_maps) + len(vertex_value_maps) + + for _from, entity, _to in edge_value_maps: + entity = dict(entity) + _type = EdgeTypes.by_label()[entity.pop(T.label)].value + _id = entity.pop(T.id) + # clear out the other special values. 
eventually we'll be able to ask for just the id and label, but that's + # not supported in Neptune (you can only do valueMap(True)) + for v in iter(T): + entity.pop(v, None) + _entity = _type.create(**entity, **{ + MagicProperties.LABEL.value.name: _type.label, + MagicProperties.ID.value.name: _id, + MagicProperties.FROM.value.name: _from, + MagicProperties.TO.value.name: _to, + }) + _key = _get_existing_key(_type=_type, **_entity) + # should we expect only one? things like the CLUSTER, and SCHEMA will duplicate + if _key in existing[_type]: + if existing[_type][_key] != _entity: + LOGGER.info(f'we already have a type: {_type.label}, id={_id} that is different: ' + f'{existing[_type][_key]} != {_entity}') + else: + # should the magic properties go in here too? It might be nicer to not, but is convenient + existing[_type][_key] = _entity + + for entity in vertex_value_maps: + entity = dict(entity) + _type = VertexTypes.by_label()[entity.pop(T.label)].value + _id = entity.pop(T.id) + # clear out the other special values. eventually we'll be able to ask for just the id and label, but that's + # not supported in Neptune (you can only do valueMap(True)) + for v in iter(T): + entity.pop(v, None) + _entity = _type.create(**cls._honor_cardinality(_type, **entity), **{ + MagicProperties.LABEL.value.name: _type.label, + MagicProperties.ID.value.name: _id, + }) + _key = _get_existing_key(_type=_type, **_entity) + # should we expect only one? things like the CLUSTER, and SCHEMA will duplicate + if _key in existing[_type]: + if existing[_type][_key] != _entity: + LOGGER.error(f'we already have a type: {_type.label}, id={_id} that is different: ' + f'{existing[_type][_key]} != {_entity}') + else: + # should the magic properties go in here too? It might be nicer to not, but is convenient + existing[_type][_key] = _entity + + @classmethod + def table_entities(cls, *, _g: GraphTraversalSource, table_data: List[Table], existing: EXISTING) -> None: + + all_tables_ids = list(set([ + VertexTypes.Table.value.id(key=TableUris.get( + database=t.database, cluster=t.cluster, schema=t.schema, table=t.name).table) + for t in table_data])) + + all_owner_ids = list(set([VertexTypes.User.value.id(key=key) + for key in [t.table_writer.id for t in table_data if t.table_writer is not None]])) + all_application_ids = list(set(list(possible_vertex_ids_for_application_key( + *[t.table_writer.id for t in table_data if t.table_writer is not None])))) + + # chunk these since 100,000s seems to choke + for tables_ids in chunk(all_tables_ids, 1000): + LOGGER.info(f'fetching for tables: {tables_ids}') + # fetch database -> cluster -> schema -> table links + g = _g.V(tuple(tables_ids)).as_('tables') + g = g.coalesce(__.inE(EdgeTypes.Table.value.label).dedup().fold()).as_(EdgeTypes.Table.name) + g = g.coalesce(__.unfold().outV().hasLabel(VertexTypes.Schema.value.label). + inE(EdgeTypes.Schema.value.label).dedup(). + fold()).as_(EdgeTypes.Schema.name) + g = g.coalesce(__.unfold().outV().hasLabel(VertexTypes.Cluster.value.label). + inE(EdgeTypes.Cluster.value.label).dedup(). 
+ fold()).as_(EdgeTypes.Cluster.name) + + # fetch table <- links + for t in (EdgeTypes.BelongToTable, EdgeTypes.Generates, EdgeTypes.Tag): + g = g.coalesce( + __.select('tables').inE(t.value.label).fold()).as_(t.name) + + # fetch table -> column et al links + for t in (EdgeTypes.Column, EdgeTypes.Description, EdgeTypes.LastUpdatedAt, + EdgeTypes.Source, EdgeTypes.Stat): + g = g.coalesce( + __.select('tables').outE(t.value.label).fold()).as_(t.name) + + # TODO: add owners, watermarks, last timestamp existing, source + aliases = set([t.name for t in ( + EdgeTypes.Table, EdgeTypes.Schema, EdgeTypes.Cluster, EdgeTypes.BelongToTable, EdgeTypes.Generates, + EdgeTypes.Tag, EdgeTypes.Column, EdgeTypes.Description, EdgeTypes.LastUpdatedAt, + EdgeTypes.Source, EdgeTypes.Stat)]) + g = g.select(*aliases).unfold().select(MapColumn.values).unfold() + g = g.local(__.union(__.outV().id(), __.valueMap(True), __.inV().id()).fold()) + cls._into_existing(g.toList(), existing) + + cls._column_entities(_g=_g, tables_ids=tables_ids, existing=existing) + + # fetch Application, User + for ids in chunk(list(set(all_application_ids + all_owner_ids)), 5000): + LOGGER.info(f'fetching for application/owners: {ids}') + g = _g.V(ids).valueMap(True) + cls._into_existing(g.toList(), existing) + + @classmethod + def _column_entities(cls, *, _g: GraphTraversalSource, tables_ids: Iterable[str], existing: EXISTING) -> None: + # fetch database -> cluster -> schema -> table links + g = _g.V(tuple(tables_ids)) + g = g.outE(EdgeTypes.Column.value.label) + g = g.inV().hasLabel(VertexTypes.Column.value.label).as_('columns') + + # fetch column -> links (no Stat) + for t in [EdgeTypes.Description]: + g = g.coalesce(__.select('columns').outE(t.value.label).fold()).as_(t.name) + + g = g.select(EdgeTypes.Description.name).unfold() + g = g.local(__.union(__.outV().id(), __.valueMap(True), __.inV().id()).fold()) + cls._into_existing(g.toList(), existing) + + @classmethod + def expire_connections_for_other(cls, *, _g: GraphTraversalSource, vertex_type: VertexType, keys: FrozenSet[str], + existing: EXISTING) -> None: + # V().has(label, 'key', P.without(keys)) is more intuitive but doesn't scale, so instead just find all those + g = _g.V().hasLabel(vertex_type.label).where(__.bothE()) + g = g.values(WellKnownProperties.Key.value.name) + all_to_expire_keys = set(g.toList()).difference(keys) + + # TODO: when any vertex ids that need something besides key + all_to_expire = set(vertex_type.id(key=key) for key in all_to_expire_keys) + + for to_expire in chunk(all_to_expire, 1000): + g = _g.V(tuple(to_expire)).bothE() + g = g.local(__.union(__.outV().id(), __.valueMap(True), __.inV().id()).fold()) + cls._into_existing(g.toList(), existing) + + +class _GetGraph: + @classmethod + def expire_previously_existing(cls, *, edge_types: Sequence[Union[EdgeTypes, EdgeType]], entities: ENTITIES, + existing: EXISTING) -> None: + _edge_types = [e.value if isinstance(e, EdgeTypes) else e for e in edge_types] + assert all(isinstance(e, EdgeType) for e in _edge_types), \ + f'expected all EdgeTypes or EdgeType: {edge_types}' + + for edge_type in _edge_types: + for entity in existing[edge_type].values(): + + entity_id = entity[MagicProperties.ID.value.name] + if entity_id in entities[edge_type]: + continue + + del entities[edge_type][entity_id] + + @classmethod + def _create(cls, _type: Union[VertexTypes, VertexType, EdgeTypes, EdgeType], _entities: ENTITIES, + _existing: EXISTING, **_kwargs: Any) -> Mapping[str, Any]: + if isinstance(_type, (VertexTypes, 
EdgeTypes)): + _type = _type.value + assert isinstance(_type, (VertexType, EdgeType)) + + # Let's prefer the new properties unless it's part of the the id properties (e.g. Created) + _existing_key = _get_existing_key(_type, **_kwargs) + if _existing_key in _existing[_type]: + names = frozenset(p.name for p in _get_key_properties(_type)) + _kwargs.update((k, v) for k, v in _existing[_type][_existing_key].items() if k in names) + # need to do this after that update, otherwise we'll miss out on crucial properties when generating ~id + _entity = _type.create(**_kwargs) + else: + _entity = _type.create(**_kwargs) + # also put this in _existing. Say, we're creating a Column or Table and a subsequence Description expects + # to find it. (TODO: This isn't perfect, it will miss tables_by_app, and neighbors_by_capability) + _existing[_type][_existing_key] = _entity + + _id = _entity.get(MagicProperties.ID.value.name, None) + if _id in _entities[_type]: + # it'd be nice to assert _id not in _entities[_type], but we generate duplicates (e.g. Database, Cluster, + # Schema, and their links) so let's at least ensure we're not going to be surprised with a different result + # TODO: reenable this after we figure out why these conflict + # assert _entities[_type][_id] == _entity, \ + if _entities[_type][_id] != _entity: + LOGGER.info(f'we already have a type: {_type.label}, id={_id} that is different: ' + f'{_entities[_type][_id]} != {_entity}') + else: + _entities[_type][_id] = _entity + return _entities[_type][_id] + + @classmethod + def table_metric(cls, table: Table) -> int: + """ + :returns a number like the number of vertexes that would be added due to this table + """ + return sum((2, 1 if table.description is not None else 0, + len(table.programmatic_descriptions or ()), len(table.programmatic_descriptions or ()), + len(table.tags or ()), sum(map(cls._column_metric, table.columns)))) + + @classmethod + def table_entities(cls, *, table_data: List[Table], entities: ENTITIES, existing: EXISTING, # noqa: C901 + created_at: datetime.datetime) -> None: + """ + existing: must cover exactly the set of data. 
(previously existing edges will be expired herein, and possibly + otherwise duplicate edges will be created) + """ + + for table in table_data: + uris = TableUris.get(database=table.database, cluster=table.cluster, schema=table.schema, table=table.name) + + database = cls._create( + VertexTypes.Database, entities, existing, name=table.database, key=uris.database) + + cluster = cls._create(VertexTypes.Cluster, entities, existing, name=table.cluster, key=uris.cluster) + cls._create(EdgeTypes.Cluster, entities, existing, created=created_at, **{ + MagicProperties.FROM.value.name: database[MagicProperties.ID.value.name], + MagicProperties.TO.value.name: cluster[MagicProperties.ID.value.name]}) + + schema = cls._create(VertexTypes.Schema, entities, existing, name=table.schema, key=uris.schema) + cls._create(EdgeTypes.Schema, entities, existing, created=created_at, **{ + MagicProperties.FROM.value.name: cluster[MagicProperties.ID.value.name], + MagicProperties.TO.value.name: schema[MagicProperties.ID.value.name]}) + + table_vertex = cls._create(VertexTypes.Table, entities, existing, name=table.name, key=uris.table, + is_view=table.is_view) + cls._create(EdgeTypes.Table, entities, existing, created=created_at, **{ + MagicProperties.FROM.value.name: schema[MagicProperties.ID.value.name], + MagicProperties.TO.value.name: table_vertex[MagicProperties.ID.value.name]}) + + if table.table_writer: + cls._application_entities(app_key=table.table_writer.id, table=table_vertex, entities=entities, + existing=existing, created_at=created_at) + + if table.description is not None: + cls._description_entities( + subject_uri=table_vertex['key'], to_vertex_id=table_vertex[MagicProperties.ID.value.name], + source='user', entities=entities, existing=existing, created_at=created_at, + description=table.description) + + for description in table.programmatic_descriptions: + cls._description_entities( + subject_uri=table_vertex['key'], to_vertex_id=table_vertex[MagicProperties.ID.value.name], + source=description.source, entities=entities, existing=existing, created_at=created_at, + description=description.text) + # TODO: need to call expire source != 'user' description links after + + # create tags + for tag in table.tags: + vertex = cls._create(VertexTypes.Tag, entities, existing, key=tag.tag_name, **vars(tag)) + cls._create(EdgeTypes.Tag, entities, existing, created=created_at, **{ + MagicProperties.FROM.value.name: vertex[MagicProperties.ID.value.name], + MagicProperties.TO.value.name: table_vertex[MagicProperties.ID.value.name]}) + # since users can tag these, we shouldn't expire any of them (unlike Description where source + # distinguishes) + + # update timestamp + vertex = cls._create(VertexTypes.Updatedtimestamp, entities, existing, key=table_vertex['key'], + latest_timestamp=created_at) + cls._create(EdgeTypes.LastUpdatedAt, entities, existing, created=created_at, **{ + MagicProperties.FROM.value.name: table_vertex[MagicProperties.ID.value.name], + MagicProperties.TO.value.name: vertex[MagicProperties.ID.value.name]}) + + cls._column_entities(table_vertex=table_vertex, column_data=table.columns, entities=entities, + existing=existing, created_at=created_at) + + @classmethod + def _application_entities(cls, *, app_key: str, table: Mapping[str, Mapping[str, Any]], entities: ENTITIES, + existing: EXISTING, created_at: datetime.datetime) -> None: + # use existing to find what Application really exists, which is a bit different than how it's used for edges + actual_keys = dict([ + 
(VertexTypes.Application.value.id(**dict(v)), v) + for v in possible_existing_keys_for_application_key(app_key)]) + actual_keys = dict([(k, v) for k, v in actual_keys.items() if v in existing[VertexTypes.Application.value]]) + if actual_keys: + vertex_id = list(actual_keys.items())[0][0] + cls._create(EdgeTypes.Generates, entities, existing, created=created_at, **{ + MagicProperties.FROM.value.name: vertex_id, + MagicProperties.TO.value.name: table[MagicProperties.ID.value.name]}) + return + + # if app isn't found, the owner may be a user + actual_keys = dict([(VertexTypes.User.value.id(key=app_key), _get_existing_key(VertexTypes.User, key=app_key))]) + actual_keys = dict([(k, v) for k, v in actual_keys.items() if v in existing[VertexTypes.User.value]]) + if actual_keys: + vertex_id = list(actual_keys.items())[0][0] + LOGGER.debug(f'{app_key} is not a real app but it was marked as owner: {table["key"]}') + cls._create(EdgeTypes.Owner, entities, existing, created=created_at, **{ + MagicProperties.FROM.value.name: table[MagicProperties.ID.value.name], + MagicProperties.TO.value.name: vertex_id}) + return + + LOGGER.info(f'{app_key} is not a real Application, nor can we find a User to be an Owner for {table["key"]}') + + @classmethod + def _description_entities(cls, *, description: str, source: str, subject_uri: str, + to_vertex_id: str, entities: ENTITIES, existing: EXISTING, + created_at: datetime.datetime) -> None: + vertex = cls._create(VertexTypes.Description, entities, existing, + key=make_description_uri(subject_uri=subject_uri, source=source), + description=description, source=source) + cls._create(EdgeTypes.Description, entities, existing, created=created_at, **{ + MagicProperties.FROM.value.name: to_vertex_id, + MagicProperties.TO.value.name: vertex[MagicProperties.ID.value.name]}) + + @classmethod + def _column_metric(cls, column: Column) -> int: + """ + :returns a number like the number of vertexes that would be added due to this column + """ + return sum((1, 1 if column.description is not None else 0, len(column.stats or ()))) + + @classmethod + def _column_entities(cls, *, table_vertex: Mapping[str, str], column_data: Sequence[Column], entities: ENTITIES, + existing: EXISTING, created_at: datetime.datetime) -> None: + + for column in column_data: + column_vertex = cls._create(VertexTypes.Column, entities, existing, name=column.name, + key=make_column_uri(table_uri=table_vertex['key'], column_name=column.name), + col_type=column.col_type, sort_order=column.sort_order) + cls._create(EdgeTypes.Column.value, entities, existing, created=created_at, **{ + MagicProperties.FROM.value.name: table_vertex[MagicProperties.ID.value.name], + MagicProperties.TO.value.name: column_vertex[MagicProperties.ID.value.name]}) + + # Add the description if present + if column.description is not None: + cls._description_entities( + subject_uri=column_vertex['key'], to_vertex_id=column_vertex[MagicProperties.ID.value.name], + source='user', entities=entities, existing=existing, created_at=created_at, + description=column.description) + + # Add stats if present + if column.stats: + for stat in column.stats: + vertex = cls._create( + VertexTypes.Stat, entities, existing, + key=make_column_statistic_uri(column_uri=column_vertex['key'], statistic_type=stat.stat_type), + # stat.stat_val is a str, but some callers seem to put ints in there + stat_val=(None if stat.stat_val is None else str(stat.stat_val)), + **dict([(k, v) for k, v in vars(stat).items() if k != 'stat_val'])) + cls._create(EdgeTypes.Stat, 
entities, existing, created=created_at, **{ + MagicProperties.FROM.value.name: column_vertex[MagicProperties.ID.value.name], + MagicProperties.TO.value.name: vertex[MagicProperties.ID.value.name]}) + + @classmethod + def user_entities(cls, *, user_data: List[User], entities: ENTITIES, existing: EXISTING, + created_at: datetime.datetime) -> None: + for user in user_data: + # TODO: handle this properly + cls._create(VertexTypes.User, entities, existing, key=user.user_id, + **dict([(k, v) for k, v in vars(user).items() if k != 'other_key_values'])) + + @classmethod + def app_entities(cls, *, app_data: List[Application], entities: ENTITIES, existing: EXISTING, + created_at: datetime.datetime) -> None: + for app in app_data: + cls._create(VertexTypes.Application, entities, existing, key=app.id, + **dict((k, v) for k, v in vars(app).items())) + + @classmethod + def _expire_other_edges( + cls, *, edge_type: Union[EdgeTypes, EdgeType], vertex_id: str, to_or_from_vertex: MagicProperties, + entities: ENTITIES, existing: EXISTING, created_at: datetime.datetime) -> None: + """ + Use this in lieu of expire_previously_existing. + + :param edge_type: + :param vertex_id: + :param to_or_from_vertex: + :param entities: + :param existing: + :param created_at: + :return: + """ + assert to_or_from_vertex in (MagicProperties.FROM, MagicProperties.TO), \ + f'only FROM or TO allowed for {to_or_from_vertex}' + edge_type = ensure_edge_type(edge_type) + # edges of that type.... + edges = tuple(e for e in existing.get(edge_type, {}).values() + # to/from the vertex + if e[to_or_from_vertex.value.name] == vertex_id + # edges that aren't recreated + and e[MagicProperties.ID.value.name] not in entities.get(edge_type, {})) + # expire those: + for entity in edges: + del entities[edge_type][entity[MagicProperties.ID.value.name]] + + +class GetGraph: + def __init__(self, *, g: GraphTraversalSource, created_at: Optional[datetime.datetime] = None) -> None: + self.g = g + self.created_at = datetime.datetime.now() if created_at is None else created_at + self.existing = new_existing() + self.entities = new_entities() + self._expire_previously_existing_callables: List[Callable[[], None]] = list() + + @staticmethod + def table_metric(table: Table) -> int: + return _GetGraph.table_metric(table) + + def add_table_entities(self, table_data: List[Table]) -> "GetGraph": + _FetchExisting.table_entities(table_data=table_data, _g=self.g, existing=self.existing) + _GetGraph.table_entities( + table_data=table_data, entities=self.entities, existing=self.existing, created_at=self.created_at) + self._expire_previously_existing_callables.append(self._expire_previously_existing_table_entities) + return self + + def _expire_previously_existing_table_entities(self) -> None: + _GetGraph.expire_previously_existing( + edge_types=(EdgeTypes.Column, EdgeTypes.Generates, EdgeTypes.Owner), + entities=self.entities, existing=self.existing) + + def add_user_entities(self, user_data: List[User]) -> "GetGraph": + _GetGraph.user_entities( + user_data=user_data, entities=self.entities, existing=self.existing, created_at=self.created_at) + self._expire_previously_existing_callables.append(self._expire_previously_existing_user_entities) + return self + + def _expire_previously_existing_user_entities(self) -> None: + pass + + def add_app_entities(self, app_data: List[Application]) -> "GetGraph": + _GetGraph.app_entities( + app_data=app_data, entities=self.entities, existing=self.existing, created_at=self.created_at) + 
self._expire_previously_existing_callables.append(self._expire_previously_existing_app_entities) + return self + + def _expire_previously_existing_app_entities(self) -> None: + pass + + def complete(self) -> ENTITIES: + for c in self._expire_previously_existing_callables: + c() + entities = self.entities + del self.entities + del self.existing + return entities + + @classmethod + def default_created_at(cls, created_at: Optional[datetime.datetime]) -> datetime.datetime: + return datetime.datetime.now() if created_at is None else created_at + + @classmethod + def table_entities(cls, *, table_data: List[Table], g: GraphTraversalSource, + created_at: Optional[datetime.datetime] = None) -> ENTITIES: + return GetGraph(g=g, created_at=created_at).add_table_entities(table_data).complete() + + @classmethod + def user_entities(cls, *, user_data: List[User], g: GraphTraversalSource, + created_at: Optional[datetime.datetime] = None) -> ENTITIES: + return GetGraph(g=g, created_at=created_at).add_user_entities(user_data).complete() + + @classmethod + def app_entities(cls, *, app_data: List[Application], g: GraphTraversalSource, + created_at: Optional[datetime.datetime] = None) -> ENTITIES: + return GetGraph(g=g, created_at=created_at).add_app_entities(app_data).complete() + + @classmethod + def expire_connections_for_other( + cls, *, vertex_type: Union[VertexTypes, VertexType], keys: Iterable[str], g: GraphTraversalSource, + created_at: Optional[datetime.datetime] = None) -> ENTITIES: + """ + There's no builder style for this since the expiration implementation is presumptive. + """ + if created_at is None: + created_at = datetime.datetime.now() + assert created_at is not None + if not isinstance(keys, frozenset): + keys = frozenset(keys) + assert isinstance(keys, frozenset) + vertex_type = ensure_vertex_type(vertex_type) + existing = new_existing() + entities = new_entities() + _FetchExisting.expire_connections_for_other(vertex_type=vertex_type, keys=keys, existing=existing, _g=g) + _GetGraph.expire_previously_existing(edge_types=tuple(t for t in EdgeTypes), entities=entities, + existing=existing) + return entities diff --git a/amundsen_gremlin/py.typed b/amundsen_gremlin/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/amundsen_gremlin/script_translator.py b/amundsen_gremlin/script_translator.py new file mode 100644 index 0000000..ee9286f --- /dev/null +++ b/amundsen_gremlin/script_translator.py @@ -0,0 +1,176 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +""" +largely lifted from tinkerpop/gremlin-python/src/main/java/org/apache/tinkerpop/gremlin/python/jsr223/PythonTranslator.java +and tinkerpop/gremlin-groovy/src/main/java/org/apache/tinkerpop/gremlin/groovy/jsr223/GroovyTranslator.java +all credit to its author, all blame leave here. 
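+
+In short: translateT()/translateB() turn a gremlin-python Traversal/Bytecode into a Gremlin-Groovy script string, so
+traversals built with the fluent API can be submitted as text scripts; the subclasses below differ only in how they
+render datetime literals for JanusGraph vs Neptune.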
+""" # noqa: E501 + +import datetime +from abc import ABCMeta, abstractmethod +from itertools import starmap +from typing import ( + Any, Dict, Iterable, Iterator, List, Mapping, Sequence, Set, Union +) + +from gremlin_python.driver.remote_connection import RemoteStrategy +from gremlin_python.process.traversal import ( + Barrier, Binding, Bytecode, Cardinality, Column, Direction, + GraphSONVersion, GryoVersion, Operator, Order, P, Pick, Pop, Scope, T, + Traversal +) +from gremlin_python.structure.graph import ( + Edge, Element, Vertex, VertexProperty +) +from overrides import overrides + + +class ScriptTranslator(metaclass=ABCMeta): + @classmethod + def translateB(cls, traversal_source: str, bytecode: Bytecode) -> str: + return cls._internal_translate(traversal_source, bytecode) + + @classmethod + def translateT(cls, traversal: Traversal) -> str: + return cls._internal_translate(cls.get_traversal_source_name(traversal), traversal.bytecode) + + @classmethod + def _internal_translate(cls, traversal_source: str, thing: Union[Traversal, Bytecode]) -> str: + """ + Translates bytecode into a gremlin-groovy script. + """ + if isinstance(thing, Traversal): + bytecode = thing.bytecode + else: + bytecode = thing + + assert isinstance(bytecode, Bytecode), f'object is not supported!: {type(thing)} {thing}' + return f'''{traversal_source}.{'.'.join(starmap(cls._translate_instruction, bytecode.step_instructions))}''' + + @classmethod + def _translate_instruction(cls, step_name: str, *step_args: Any) -> str: + assert isinstance(step_name, str), f'step_name is not a string? {step_name}' + return f'''{step_name}({','.join(map(cls._convert_to_string, step_args))})''' + + @classmethod # noqa: C901 + def _convert_to_string(cls, thing: Any) -> str: + if thing is None: + return "null" + + if isinstance(thing, bool): + # TODO: this is java/groovy specific + # also, did you know that isinstance(True, int) == True? so do this ahead of the int/float branch below + return repr(thing).lower() + + if isinstance(thing, (int, float)): + # TODO: do we need the f, L, d suffixes? 
+ return repr(thing) + + if isinstance(thing, str): + return cls._escape_java_style(thing) + + if isinstance(thing, (Dict, Mapping)): + return f'''[{','.join(f'({cls._convert_to_string(k)}):({cls._convert_to_string(v)})' for k, v in thing.items())}]''' # noqa: E501 + + if isinstance(thing, Set): + return f'''[{','.join(cls._convert_to_string(i) for i in thing)}] as Set''' + + if isinstance(thing, (List, Sequence)): + return f'''[{','.join(cls._convert_to_string(i) for i in thing)}]''' + + if isinstance(thing, Binding): + binding: Binding = thing + return cls._convert_to_string(binding.value) + + if isinstance(thing, Bytecode): + return cls._internal_translate("__", thing) + + if isinstance(thing, Traversal): + return cls._internal_translate("__", thing) + + if isinstance(thing, Element): + if isinstance(thing, (Edge, Vertex, VertexProperty)): + # returns like f'v[{thing.id}]' which seems right + return repr(thing) + + raise AssertionError(f'thing is not supported!: {thing}') + + if isinstance(thing, P): + p: P = thing + # TODO: if isinstance(p, ConnectiveP) + return f'{cls._qualify(type(p))}{p.operator}({cls._convert_to_string(p.value)})' + + if isinstance(thing, (Barrier, Cardinality, Column, Direction, GraphSONVersion, GryoVersion, Order, Pick, Pop, + Scope, Operator, T)): + return f'{cls._qualify(type(thing))}{thing.name}' + + if isinstance(thing, (datetime.datetime, datetime.date)): + return cls._date_to_string(thing) + + # TODO: Class, UUID?, Lambda, TraversalStrategyProxy, TraversalStrategy + + raise AssertionError(f'thing is not supported!: {thing}') + + @classmethod + @abstractmethod + def _date_to_string(cls, thing: Union[datetime.datetime, datetime.date]) -> str: + raise RuntimeError('Not implemented') + + @classmethod + def _qualify(cls, t: type) -> str: + # TODO: for Neptune we would like to not qualify most things (e.g. except T, Order, Scope, all of which accept + # the unqualified so there's no point) + return '' + + @classmethod + def get_traversal_source_name(cls, t: Traversal) -> str: + if t.traversal_strategies is not None: + if t.traversal_strategies.traversal_strategies is not None: + if isinstance(t.traversal_strategies.traversal_strategies[0], RemoteStrategy): + return t.traversal_strategies.traversal_strategies[0].remote_connection.traversal_source + # TODO: more of these + raise AssertionError(f'no idea what to do with {t}') + + CHAR_MAPPINGS = dict([(ord(v), f'\\{c}') for v, c in zip('\b\n\t\f\r', 'bntfr')] + + [(ord(s), f'\\{s}') for s in '\'"\\']) + + @classmethod + def _escape_java_style_chars(cls, chars: Iterable[int]) -> Iterator[str]: + for ch in chars: + # handle unicode + if ch in cls.CHAR_MAPPINGS: + yield cls.CHAR_MAPPINGS[ch] + elif ch < 0x7f and ch >= 32: + yield chr(ch) + else: + # handle unicode and control characters + yield f'''\\u{hex(ch)[2:].rjust(4, '0')}''' + + @classmethod + def _escape_java_style(cls, value: str) -> str: + return f'''"{''.join(cls._escape_java_style_chars(map(ord, value)))}"''' + + +class ScriptTranslatorTargetJanusgraph(ScriptTranslator): + @classmethod + @overrides + def _date_to_string(cls, thing: Union[datetime.datetime, datetime.date]) -> str: + if isinstance(thing, datetime.datetime): + # datetime.datetime.isoformat(timespec='auto') does something a little surprising (that goes back into + # antiquity). If milliseconds/microseconds == 0, then it OMITS them (as if yyyy-MM-dd'T'HH:mm:ss), which + # is usually fine. Except here where we're passing it into java.text.SimpleDateFormat, which is strict. 
+ # so timespec='microseconds' to get those every time. + return f'''new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSSSS").parse("{thing.isoformat(timespec='microseconds')}")''' # noqa: E501 + elif isinstance(thing, datetime.date): + # so timespec is not a thing for datetime.date though, so use the date only format produced there. + return f'''new java.text.SimpleDateFormat("yyyy-MM-dd").parse("{thing.isoformat()}")''' + else: + raise AssertionError(f'thing is not supported!: {thing}') + + +class ScriptTranslatorTargetNeptune(ScriptTranslator): + @classmethod + @overrides + def _date_to_string(cls, thing: Union[datetime.datetime, datetime.date]) -> str: + return f'datetime("{thing.isoformat()}")' diff --git a/amundsen_gremlin/test_and_development_shard.py b/amundsen_gremlin/test_and_development_shard.py new file mode 100644 index 0000000..b0ede79 --- /dev/null +++ b/amundsen_gremlin/test_and_development_shard.py @@ -0,0 +1,82 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import os +from threading import Lock +from typing import Optional + +from gremlin_python.process.graph_traversal import GraphTraversalSource + +# use _get_shard to retrieve this +_shard: Optional[str] +_shard_lock = Lock() +_shard_used = False + + +def _shard_default() -> Optional[str]: + if os.environ.get('CI'): + # TODO: support CI-specific env variables in config? + # BUILD_PART_ID: identifies a part-build (doesn't change when you click rebuild, but also not shared across + # builds) + build_part_id = os.environ.get('BUILD_PART_ID') + # TBD: can we easily shard on github? + if build_part_id is None: + return 'OneShardToRuleThemAll' + assert build_part_id, f'Expected BUILD_PART_ID environment variable to be set' + + # e.g. gw0 if -n or main if -n0 + xdist_worker = os.environ.get('PYTEST_XDIST_WORKER') + + if xdist_worker: + return f'{build_part_id}_{xdist_worker}' + else: + return build_part_id + elif os.environ.get('DATACENTER', 'local') == 'local': + # this replaces the NEPTUNE_URLS_BY_USER et al in Development. + user = os.environ.get('USER') + assert user is not None, f'Expected USER environment variable to be set' + + # e.g. gw0 if -n or main if -n0 + xdist_worker = os.environ.get('PYTEST_XDIST_WORKER') + + if xdist_worker: + return f'{user}_{xdist_worker}' + else: + return user + else: + return None + + +def shard_set_explicitly(shard: Optional[str]) -> None: + global _shard + with _shard_lock: + assert not _shard_used, 'can only shard_set_explicitly if it has not been used yet. (sorry)' + _shard = shard + + +def get_shard() -> Optional[str]: + global _shard, _shard_used + # lock free path first + if _shard_used: + return _shard + with _shard_lock: + _shard_used = True + return _shard + + +def _reset_for_testing_only() -> None: + global _shard, _shard_used, get_shard + _shard = _shard_default() + _shard_used = False + + +# or just the once here +_reset_for_testing_only() + + +def delete_graph_for_shard_only(g: GraphTraversalSource) -> None: + shard = get_shard() + assert shard, f'expected shard to exist! Surely you are only using this in development or test?' 
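+ # drop() on the matched vertices also removes their incident edges, so this clears the shard's entire subgraph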
+ # TODO: do something better than not using WellKnownProperties.TestShard here (since that makes a circular + # dependency) + g.V().has('shard', shard).drop().iterate() diff --git a/amundsen_gremlin/utils/__init__.py b/amundsen_gremlin/utils/__init__.py new file mode 100644 index 0000000..f3145d7 --- /dev/null +++ b/amundsen_gremlin/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/amundsen_gremlin/utils/streams.py b/amundsen_gremlin/utils/streams.py new file mode 100644 index 0000000..1b42fd1 --- /dev/null +++ b/amundsen_gremlin/utils/streams.py @@ -0,0 +1,390 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import threading +from typing import ( + Any, AsyncIterator, Callable, Collection, Iterable, Iterator, List, + Optional, Tuple, TypeVar, Union +) + +from typing_extensions import Final, final + +LOGGER = logging.getLogger(__name__) + + +V = TypeVar('V') +R = TypeVar('R') + + +def one(ignored: Any) -> int: + return 1 + + +class PeekingIterator(Iterator[V]): + """ + Like Iterator, but with peek(), peek_default(), and take_peeked() + """ + def __init__(self, iterable: Iterable[V]): + self.it: Final[Iterator[V]] = iterable if isinstance(iterable, Iterator) else iter(iterable) + self.has_peeked_value = False + self.peeked_value: Optional[V] = None + # RLock could make sense, but it would be just weird for the same thread to try to peek from same blocking + # iterator + self.lock: Final[threading.Lock] = threading.Lock() + + @final + # @overrides Iterator but @overrides doesn't like + def __next__(self) -> V: + """ + :return: the previously peeked value or the next + :raises StopIteration if there is no more values + """ + with self.lock: + value: V + if self.has_peeked_value: + value = self.peeked_value # type: ignore + self.peeked_value = None + self.has_peeked_value = False + else: + value = next(self.it) + assert not self.has_peeked_value + return value + + @final + def peek(self) -> V: + """ + :return: the previously peeked value or the next + :raises StopIteration if there is no more values + """ + with self.lock: + if not self.has_peeked_value: + self.peeked_value = next(self.it) + self.has_peeked_value = True + assert self.has_peeked_value + return self.peeked_value # type: ignore + + @final + def peek_default(self, default: Optional[V]) -> Optional[V]: + """ + :return: the previously peeked value or the next, or default if no more values + """ + try: + return self.peek() + except StopIteration: + return default + + @final + def take_peeked(self, value: V) -> None: + with self.lock: + assert self.has_peeked_value, f'expected to find a peaked value' + assert self.peeked_value is value, f'expected the peaked value to be the same' + self.peeked_value = None + self.has_peeked_value = False + + @final + def has_more(self) -> bool: + try: + self.peek() + return True + except StopIteration: + return False + + +class PeekingAsyncIterator(AsyncIterator[V]): + """ + Like AsyncIterator, but with peek(), peek_default(), and take_peeked() + """ + def __init__(self, iterable: AsyncIterator[V]): + self.it: Final[AsyncIterator[V]] = iterable + self.has_peeked_value = False + self.peeked_value: Optional[V] = None + # RLock could make sense, but it would be just weird for the same thread to try to peek from same blocking + # iterator + self.lock: Final[threading.Lock] = threading.Lock() + + @final + # @overrides AsyncIterator but @overrides doesn't 
like + async def __anext__(self) -> V: + """ + :return: the previously peeked value or the next + :raises StopAsyncIteration if there is no more values + """ + with self.lock: + value: V + if self.has_peeked_value: + value = self.peeked_value # type: ignore + self.peeked_value = None + self.has_peeked_value = False + else: + value = await self.__anext__() + assert not self.has_peeked_value + return value + + @final + async def peek(self) -> V: + """ + :return: the previously peeked value or the next + :raises StopAsyncIteration if there is no more values + """ + with self.lock: + if not self.has_peeked_value: + self.peeked_value = await self.it.__anext__() + self.has_peeked_value = True + assert self.has_peeked_value + return self.peeked_value # type: ignore + + @final + async def peek_default(self, default: Optional[V]) -> Optional[V]: + """ + :return: the previously peeked value or the next, or default if no more values + """ + try: + return await self.peek() + except StopAsyncIteration: + return default + + @final + def take_peeked(self, value: V) -> None: + with self.lock: + assert self.has_peeked_value, f'expected to find a peaked value' + assert self.peeked_value is value, f'expected the peaked value to be the same' + self.peeked_value = None + self.has_peeked_value = False + + @final + async def has_more(self) -> bool: + try: + await self.peek() + return True + except StopAsyncIteration: + return False + + +def one_chunk(*, it: PeekingIterator[V], n: int, metric: Callable[[V], int]) -> Tuple[Iterable[V], bool]: + """ + :param it: stream of values as a PeekingIterator (or regular iterable if you are only going to take the first chunk + and don't care about the peeked value being consumed) + :param n: consume stream until n is reached. if n is 0, process whole stream as one chunk. + :param metric: the callable that returns positive metric for a value + :returns the chunk + """ + items: List[V] = [] + items_metric: int = 0 + try: + while True: + item = it.peek() + item_metric = metric(item) + # negative would be insane, let's say positive + assert item_metric > 0, \ + f'expected metric to be positive! item_metric={item_metric}, metric={metric}, item={item}' + if not items and item_metric > n: + # should we assert instead? it's probably a surprise to the caller too, and might fail for whatever + # limit they were trying to avoid, but let's give them a shot at least. + LOGGER.error(f"expected a single item's metric to be less than the chunk limit! {item_metric} > {n}, " + f"but returning to make progress") + items.append(item) + it.take_peeked(item) + items_metric += item_metric + break + elif items_metric + item_metric <= n: + items.append(item) + it.take_peeked(item) + items_metric += item_metric + if items_metric >= n: + # we're full + break + # else keep accumulating + else: + assert items_metric + item_metric > n + # we're full + break + # don't catch exception, let that be a concern for callers + except StopIteration: + pass + + has_more = it.has_more() + return tuple(items), has_more + + +def chunk(it: Union[Iterable[V], PeekingIterator[V]], n: int, metric: Callable[[V], int] = one + ) -> Iterable[Iterable[V]]: + """ + :param it: stream of values as a PeekingIterator (or regular iterable if you are only going to take the first chunk + and don't care about the peeked value being consumed) + :param n: consume stream until n is reached. if n is 0, process whole stream as one chunk. 
+ :param metric: the callable that returns positive metric for a value + :returns the Iterable (generator) of chunks + """ + if not isinstance(it, PeekingIterator): + it = PeekingIterator(it) + assert isinstance(it, PeekingIterator) + has_more: bool = True + while has_more: + items, has_more = one_chunk(it=it, n=n, metric=metric) + if items or has_more: + yield items + + +async def async_one_chunk( + it: PeekingAsyncIterator[V], n: int, metric: Callable[[V], int] = one) -> Tuple[Iterable[V], bool]: + """ + :param it: stream of values as a PeekingAsyncIterator + :param n: consume stream until n is reached. if n is 0, process whole stream as one chunk. + :param metric: the callable that returns positive metric for a value + :returns the chunk and if there are more items + """ + items: List[V] = [] + items_metric: int = 0 + if not isinstance(it, PeekingAsyncIterator): + it = PeekingAsyncIterator(it) + assert isinstance(it, PeekingAsyncIterator) + try: + while True: + item = await it.peek() + item_metric = metric(item) + # negative would be insane, let's say positive + assert item_metric > 0, \ + f'expected metric to be positive! item_metric={item_metric}, metric={metric}, item={item}' + if not items and item_metric > n: + # should we assert instead? it's probably a surprise to the caller too, and might fail for whatever + # limit they were trying to avoid, but let's give them a shot at least. + LOGGER.error(f"expected a single item's metric to be less than the chunk limit! {item_metric} > {n}, " + f"but returning to make progress") + items.append(item) + it.take_peeked(item) + items_metric += item_metric + break + elif items_metric + item_metric <= n: + items.append(item) + it.take_peeked(item) + items_metric += item_metric + if items_metric >= n: + # we're full + break + # else keep accumulating + else: + assert items_metric + item_metric > n + # we're full + break + # don't catch exception, let that be a concern for callers + except StopAsyncIteration: + pass + + has_more = await it.has_more() + return tuple(items), has_more + + +async def async_chunk(*, it: Union[AsyncIterator[V], PeekingAsyncIterator[V]], n: int, metric: Callable[[V], int] + ) -> AsyncIterator[Iterable[V]]: + """ + :param it: stream of values as a PeekingAsyncIterator + :param n: consume stream until n is reached. if n is 0, process whole stream as one chunk. + :param metric: the callable that returns positive metric for a value + :returns the chunk and if there are more items + """ + if not isinstance(it, PeekingAsyncIterator): + it = PeekingAsyncIterator(it) + assert isinstance(it, PeekingAsyncIterator) + has_more: bool = True + while has_more: + items, has_more = await async_one_chunk(it=it, n=n, metric=metric) + if items or has_more: + yield items + + +def reduce_in_chunks(*, stream: Iterable[V], n: int, initial: R, + consumer: Callable[[Iterable[V], R], R], metric: Callable[[V], int] = one) -> R: + """ + :param stream: stream of values + :param n: consume stream until n is reached. if n is 0, process whole stream as one chunk. 
+ :param metric: the callable that returns positive metric for a value + :param initial: the initial state + :param consumer: the callable to handle the chunk + :returns the final state + """ + if n > 0: + it = PeekingIterator(stream) + state = initial + for items in chunk(it=it, n=n, metric=metric): + state = consumer(items, state) + return state + else: + return consumer(stream, initial) + + +async def async_reduce_in_chunks(*, stream: AsyncIterator[V], n: int, metric: Callable[[V], int], initial: R, + consumer: Callable[[Iterable[V], R], R]) -> R: + """ + :param stream: + :param n: if n is 0, process whole stream as one chunk + :param metric: the callable that returns positive metric for a value + :param initial: the initial state + :param consumer: the callable to handle the chunk + :returns the final state + """ + if n > 0: + it = PeekingAsyncIterator(stream) + state = initial + async for items in async_chunk(it=it, n=n, metric=metric): + state = consumer(items, state) + return state + else: + return consumer(tuple([_ async for _ in stream]), initial) + + +def consume_in_chunks(*, stream: Iterable[V], n: int, consumer: Callable[[Iterable[V]], None], + metric: Callable[[V], int] = one) -> int: + """ + :param stream: + :param n: consume stream until n is reached if n is 0, process whole stream as one chunk + :param metric: the callable that returns positive metric for a value + :param consumer: the callable to handle the chunk + :return: + """ + _actual_state: int = 0 + + def _consumer(things: Iterable[V], ignored: None) -> None: + nonlocal _actual_state + things = _assure_collection(things) + assert isinstance(things, Collection) # appease the types + _actual_state += len(things) + consumer(things) + reduce_in_chunks(stream=stream, n=n, initial=None, consumer=_consumer, metric=metric) + return _actual_state + + +async def async_consume_in_chunks(*, stream: AsyncIterator[V], n: int, consumer: Callable[[Iterable[V]], None], + metric: Callable[[V], int] = one) -> int: + _actual_state: int = 0 + + def _consumer(things: Iterable[V], ignored: None) -> None: + nonlocal _actual_state + things = _assure_collection(things) + assert isinstance(things, Collection) # appease the types + _actual_state += len(things) + consumer(things) + await async_reduce_in_chunks(stream=stream, n=n, initial=None, consumer=_consumer, metric=metric) + return _actual_state + + +def consume_in_chunks_with_state(*, stream: Iterable[V], n: int, consumer: Callable[[Iterable[V]], None], + state: Callable[[V], R], metric: Callable[[V], int] = one) -> Iterable[R]: + _actual_state: List[R] = list() + + def _consumer(things: Iterable[V], ignored: None) -> None: + nonlocal _actual_state + things = _assure_collection(things) + assert isinstance(things, Collection) # appease the types + _actual_state.extend(map(state, things)) + consumer(things) + + reduce_in_chunks(stream=stream, n=n, initial=None, consumer=_consumer, metric=metric) + return tuple(_actual_state) + + +def _assure_collection(iterable: Iterable[V]) -> Collection[V]: + if isinstance(iterable, Collection): + return iterable + else: + return tuple(iterable) diff --git a/for_requests/__init__.py b/for_requests/__init__.py new file mode 100644 index 0000000..f3145d7 --- /dev/null +++ b/for_requests/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/for_requests/assume_role_aws4auth.py b/for_requests/assume_role_aws4auth.py new file mode 100644 index 0000000..94a056f --- /dev/null +++ b/for_requests/assume_role_aws4auth.py @@ -0,0 +1,57 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +""" +From https://gist.github.com/zen4ever/5103a4091de28f2f53be2ab8de2ae905 +""" +from typing import Any + +import botocore +from requests_aws4auth import AWS4Auth + + +class AssumeRoleAWS4Auth(AWS4Auth): + """ + Subclass of AWS4Auth which accepts botocore credentials as its first argument + Which allows us to handle assumed role sessions transparently + """ + def __init__(self, credentials: botocore.credentials.Credentials, region: str, service: str, **kwargs: Any): + self.credentials = credentials + + frozen_credentials = self.get_credentials() + + super(AssumeRoleAWS4Auth, self).__init__( + frozen_credentials.access_key, + frozen_credentials.secret_key, + region, + service, + session_token=frozen_credentials.token, + **kwargs + ) + + def get_credentials(self) -> botocore.credentials.Credentials: + if hasattr(self.credentials, 'get_frozen_credentials'): + return self.credentials.get_frozen_credentials() + return self.credentials + + def __call__(self, req: Any) -> Any: + if hasattr(self.credentials, 'refresh_needed') and self.credentials.refresh_needed(): + frozen_credentials = self.get_credentials() + + self.access_id = frozen_credentials.access_key + self.session_token = frozen_credentials.token + self.regenerate_signing_key(secret_key=frozen_credentials.secret_key) + return super(AssumeRoleAWS4Auth, self).__call__(req) + + def handle_date_mismatch(self, req: Any) -> None: + req_datetime = self.get_request_date(req) + new_key_date = req_datetime.strftime('%Y%m%d') + + frozen_credentials = self.get_credentials() + + self.access_id = frozen_credentials.access_key + self.session_token = frozen_credentials.token + self.regenerate_signing_key( + date=new_key_date, + secret_key=frozen_credentials.secret_key + ) diff --git a/for_requests/aws4auth_compatible.py b/for_requests/aws4auth_compatible.py new file mode 100644 index 0000000..816806b --- /dev/null +++ b/for_requests/aws4auth_compatible.py @@ -0,0 +1,24 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union +from urllib.parse import SplitResult, urlsplit + + +def to_aws4_request_compatible_host(url: Union[str, SplitResult]) -> str: + """ + Why do this? well, requests-aws4auth quietly pretends a Host header exists by parsing the request URL. Why? In + the python stack, requests defers adding of the host header to a lower layer library (http.client) which comes + after the auth objects like AWS4Auth get run ...so it guesses. However, its guess is just the host part, which + works if you're using https and port 443 or http and port 80, but not so much if you're using https and port + 8182 for example. 
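+
+ For example: 'https://example.com:8182/gremlin' yields 'example.com:8182', while 'https://example.com:443/gremlin'
+ and 'https://example.com/gremlin' both yield just 'example.com'.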
+ """ + if isinstance(url, str): + result = urlsplit(url) + elif isinstance(url, SplitResult): + result = url + # we have to canonicalize the URL as the server would (so omit if https and port 443 or http and port 80) + if (result.scheme == 'https' and result.port == 443) or (result.scheme == 'http' and result.port == 80): + return result.netloc.split(':')[0] + else: + return result.netloc diff --git a/for_requests/host_header_ssl.py b/for_requests/host_header_ssl.py new file mode 100644 index 0000000..1a2a625 --- /dev/null +++ b/for_requests/host_header_ssl.py @@ -0,0 +1,54 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +""" +like requests_toolbelt.adapters.host_header_ssl, but with our fix + + https://github.com/requests/toolbelt/pull/289 + +..for: + + https://github.com/requests/toolbelt/issues/288 +""" + +from typing import Any + +import requests + + +class HostHeaderSSLAdapter(requests.adapters.HTTPAdapter): + """ + A HTTPS Adapter for Python Requests that sets the hostname for certificate + verification based on the Host header. + + This allows requesting the IP address directly via HTTPS without getting + a "hostname doesn't match" exception. + + Example usage: + + >>> s.mount('https://', HostHeaderSSLAdapter()) + >>> s.get("https://93.184.216.34", headers={"Host": "example.org"}) + + """ + + def send(self, request: requests.PreparedRequest, *args: Any, **kwargs: Any) -> requests.Response: + # HTTP headers are case-insensitive (RFC 7230) + host_header = None + for header in request.headers: + if header.lower() == "host": + host_header = request.headers[header] + break + + connection_pool_kwargs = self.poolmanager.connection_pool_kw + + if host_header: + # host header can include port, but we should not include it in the + # assert_hostname + host_header = host_header.split(':')[0] + + connection_pool_kwargs["assert_hostname"] = host_header + elif "assert_hostname" in connection_pool_kwargs: + # an assert_hostname from a previous request may have been left + connection_pool_kwargs.pop("assert_hostname", None) + + return super(HostHeaderSSLAdapter, self).send(request, *args, **kwargs) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4efb812 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +amundsen-common==0.5.1 +boto3==1.9.205 +flake8==3.7.8 +gremlinpython>=3.4.3 +isort==4.3.21 +marshmallow==2.20.5 +marshmallow-annotations==2.4.0 +mypy==0.761 +overrides==2.8.0 +pytest==5.0.1 +pytest-cov==2.5.1 +pytest-mock==1.1 +PyYAML==5.1.2 +pytz +requests==2.23.0 +requests-aws4auth==0.9 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..fd4fe15 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,48 @@ +[flake8] +format = pylint +exclude = .svc,CVS,.bzr,.hg,.git,__pycache__,venv,build/*,amazon-neptune-tools +max-complexity = 11 +max-line-length = 120 +# TODO: wrap these lines +per-file-ignores=tests/unit/neptune_bulk_loader/*.py:E501 + +# flake8-tidy-imports rules +banned-modules = + dateutil.parser = Use `ciso8601` instead + flask.ext.restful = Use `flask_restful` + flask.ext.script = Use `flask_script` + flask_restful.reqparse = Use `marshmallow` for request/response validation + haversine = Use `from fast_distance import haversine` + py.test = Use `pytest` + python-s3file = Use `boto` + +[pycodestyle] +max-line-length = 120 + +[isort] +multi_line_output=5 +skip=venv,amazon-neptune-tools + +[tool:pytest] +# NB: coverage is 80 with RT tests +addopts = --cov=amundsen_gremlin --cov-fail-under=60 
--cov-report=term-missing:skip-covered --cov-report=xml --cov-report=html -vvv + +[coverage:run] +branch = True + +[coverage:xml] +output = build/coverage.xml + +[coverage:html] +directory = build/coverage_html + +[mypy] +check_untyped_defs = true +disallow_any_generics = true +disallow_incomplete_defs = true +disallow_untyped_defs = true +ignore_missing_imports = True +no_implicit_optional = true + +[mypy-tests.*] +disallow_untyped_defs = false diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..2fba70f --- /dev/null +++ b/setup.py @@ -0,0 +1,27 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from setuptools import find_packages, setup + +# gross, but git submodule is the simplest way to install these (there's no egg), so add them to our structure +neptune_python_utils_package_names = find_packages('amazon-neptune-tools/neptune-python-utils') +neptune_python_utils_package_directories = dict((name, f'amazon-neptune-tools/neptune-python-utils/{name}') + for name in neptune_python_utils_package_names) + +setup( + name='amundsen-gremlin', + version='0.0.1', + description='Gremlin code library for Amundsen', + long_description=open('README.md').read(), + url='https://github.com/amundsen-io/amundsengremlin', + maintainer='Linux Foundation', + maintainer_email='amundsen-dev@lyft.com', + packages=find_packages(exclude=['tests*']) + neptune_python_utils_package_names, + package_dir=neptune_python_utils_package_directories, + zip_safe=False, + dependency_links=[], + include_package_data=True, + install_requires=[], + python_requires=">=3.7", + package_data={'amundsen_gremlin': ['py.typed']}, +) diff --git a/ssl_override_server_hostname/__init__.py b/ssl_override_server_hostname/__init__.py new file mode 100644 index 0000000..f3145d7 --- /dev/null +++ b/ssl_override_server_hostname/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/ssl_override_server_hostname/ssl_context.py b/ssl_override_server_hostname/ssl_context.py new file mode 100644 index 0000000..4711e1a --- /dev/null +++ b/ssl_override_server_hostname/ssl_context.py @@ -0,0 +1,21 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +""" +Credit to https://github.com/dwfreed in https://github.com/requests/toolbelt/issues/159 +""" +import ssl +from typing import Any + + +class OverrideServerHostnameSSLContext(ssl.SSLContext): + def __init__(self, *args: Any, server_hostname: str, **kwargs: Any) -> None: + super(OverrideServerHostnameSSLContext, self).__init__(*args, **kwargs) + self.override_server_hostname = server_hostname + + def change_server_hostname(self, server_hostname: str) -> None: + self.override_server_hostname = server_hostname + + def wrap_socket(self, *args: Any, **kwargs: Any) -> Any: + kwargs['server_hostname'] = self.override_server_hostname + return super(OverrideServerHostnameSSLContext, self).wrap_socket(*args, **kwargs) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..f3145d7 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d36d23b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,27 @@ +# Copyright Contributors to the Amundsen project. 
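One plausible way the for_requests and ssl_override_server_hostname helpers above get wired together for a SigV4-signed call to a Neptune-style endpoint on a non-standard port (a hedged sketch: the endpoint URL, region, and the 'neptune-db' service name are assumptions, and AWS credentials are presumed to be configured for boto3):

import boto3
import requests

from for_requests.assume_role_aws4auth import AssumeRoleAWS4Auth
from for_requests.aws4auth_compatible import to_aws4_request_compatible_host
from for_requests.host_header_ssl import HostHeaderSSLAdapter

url = 'https://example-neptune.cluster-abc.us-east-1.neptune.amazonaws.com:8182/status'  # assumed endpoint
auth = AssumeRoleAWS4Auth(boto3.Session().get_credentials(), 'us-east-1', 'neptune-db')

session = requests.Session()
session.mount('https://', HostHeaderSSLAdapter())
# set Host explicitly so the signature and the server agree even on port 8182
response = session.get(url, auth=auth, headers={'Host': to_aws4_request_compatible_host(url)})
print(response.status_code)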
+# SPDX-License-Identifier: Apache-2.0 + +import pytest + +# This file configures the roundtrip pytest option and skips roundtrip tests without it + + +def pytest_addoption(parser): + parser.addoption( + "--roundtrip", action="store_true", default=False, help="Run roundtrip tests. These tests are slow and require \ + a configured neptune instance." + ) + + +def pytest_configure(config): + config.addinivalue_line("markers", "roundtrip: mark test as roundtrip") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--roundtrip"): + # --roundtrip given in cli: do not skip roundtrip tests + return + skip_roundtrip = pytest.mark.skip(reason="need --roundtrip option to run") + for item in items: + if "roundtrip" in item.keywords: + item.add_marker(skip_roundtrip) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..f3145d7 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/unit/neptune_bulk_loader/__init__.py b/tests/unit/neptune_bulk_loader/__init__.py new file mode 100644 index 0000000..f3145d7 --- /dev/null +++ b/tests/unit/neptune_bulk_loader/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/unit/neptune_bulk_loader/test_api.py b/tests/unit/neptune_bulk_loader/test_api.py new file mode 100644 index 0000000..c2f9c40 --- /dev/null +++ b/tests/unit/neptune_bulk_loader/test_api.py @@ -0,0 +1,80 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import datetime +import unittest + +import pytest +from flask import Flask + +from amundsen_gremlin.config import LocalGremlinConfig +from amundsen_gremlin.gremlin_model import ( + EdgeTypes, MagicProperties, VertexTypes +) +from amundsen_gremlin.neptune_bulk_loader.api import ( + NeptuneBulkLoaderApi, + get_neptune_graph_traversal_source_factory_from_config +) +from amundsen_gremlin.neptune_bulk_loader.gremlin_model_converter import ( + _GetGraph, new_entities, new_existing +) +from amundsen_gremlin.test_and_development_shard import ( + delete_graph_for_shard_only +) + + +@pytest.mark.roundtrip +class TestBulkLoader(unittest.TestCase): + def setUp(self) -> None: + self.maxDiff = None + self.app = Flask(__name__) + self.app_context = self.app.app_context() + self.app.config.from_object(LocalGremlinConfig()) + self.app_context.push() + self.bulk_loader = NeptuneBulkLoaderApi.create_from_config(self.app.config) + self.neptune_graph_traversal_source_factory = get_neptune_graph_traversal_source_factory_from_config(self.app.config) + self._drop_almost_everything() + + def _drop_almost_everything(self) -> None: + delete_graph_for_shard_only(self.neptune_graph_traversal_source_factory()) + + def tearDown(self) -> None: + delete_graph_for_shard_only(self.neptune_graph_traversal_source_factory()) + self.app_context.pop() + + def test_failed_load_logs(self) -> None: + created_at = datetime.datetime(2020, 5, 27, 10, 50, 50) + entities = new_entities() + existing = new_existing() + _GetGraph._create(VertexTypes.Database, entities, existing, key='foo', name='foo') + _GetGraph._create(EdgeTypes.Cluster, entities, existing, created=created_at, **{ + MagicProperties.FROM.value.name: VertexTypes.Database.value.id(key='foo'), + MagicProperties.TO.value.name: VertexTypes.Database.value.id(key='bar'), + }) + with 
self.assertLogs(logger='amundsen_gremlin.neptune_bulk_loader.api', level='WARNING') as logs: + status = self.bulk_loader.bulk_load_entities(entities=entities) + failed = [load_id for load_id, overall_status in status.items() + if overall_status['overallStatus']['status'] != 'LOAD_COMPLETED'] + self.assertEqual(2, len(status), f'expected 2 status = {status}') + self.assertEqual(1, len(failed), f'expected 1 failed in status = {status}') + self.assertEqual(1, len(logs.output), f'expected 1 output: {logs.output}') + self.assertTrue(all(line.startswith('WARNING:amundsen_gremlin.neptune_bulk_loader.api:some loads failed:') + for line in logs.output), + f'expected output to start with some loads failed: {logs.output}') + + def test_failed_load_raises(self) -> None: + created_at = datetime.datetime(2020, 5, 27, 10, 50, 50) + entities = new_entities() + existing = new_existing() + _GetGraph._create(VertexTypes.Database, entities, existing, key='foo', name='foo') + _GetGraph._create(EdgeTypes.Cluster, entities, existing, created=created_at, **{ + MagicProperties.FROM.value.name: VertexTypes.Database.value.id(key='foo'), + MagicProperties.TO.value.name: VertexTypes.Database.value.id(key='bar'), + }) + with self.assertLogs(logger='amundsen_gremlin.neptune_bulk_loader.api', level='WARNING') as logs: + with self.assertRaisesRegex(AssertionError, 'some loads failed'): + self.bulk_loader.bulk_load_entities(entities=entities, raise_if_failed=True) + self.assertEqual(1, len(logs.output), f'expected 1 output: {logs.output}') + self.assertTrue(all(line.startswith('WARNING:amundsen_gremlin.neptune_bulk_loader.api:some loads failed:') + for line in logs.output), + f'expected output to start with some loads failed: {logs.output}') diff --git a/tests/unit/neptune_bulk_loader/test_gremlin_model_converter.py b/tests/unit/neptune_bulk_loader/test_gremlin_model_converter.py new file mode 100644 index 0000000..10d3a01 --- /dev/null +++ b/tests/unit/neptune_bulk_loader/test_gremlin_model_converter.py @@ -0,0 +1,439 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import datetime +import unittest +from operator import attrgetter +from typing import ( + Any, Callable, Dict, Hashable, Iterable, Mapping, Optional, Tuple, TypeVar, + Union +) +from unittest import mock + +import pytest +from amundsen_common.models.table import ( + Application, Column, ProgrammaticDescription, Table, Tag +) +from amundsen_common.models.user import User +from amundsen_common.tests.fixtures import Fixtures +from flask import Flask + +from amundsen_gremlin.config import LocalGremlinConfig +from amundsen_gremlin.gremlin_model import ( + EdgeType, EdgeTypes, MagicProperties, VertexType, VertexTypes +) +from amundsen_gremlin.neptune_bulk_loader.api import ( + GraphEntity, GraphEntityType, NeptuneBulkLoaderApi, + get_neptune_graph_traversal_source_factory_from_config +) +from amundsen_gremlin.neptune_bulk_loader.gremlin_model_converter import ( + ENTITIES, EXISTING, GetGraph, _FetchExisting, + possible_application_names_application_key +) +from amundsen_gremlin.test_and_development_shard import ( + delete_graph_for_shard_only, get_shard +) + +# TODO: add fetch test existing tests + + +def _create_one_expected(_type: Union[VertexType, EdgeType], **properties: Any) -> Tuple[str, Mapping[str, Any]]: + if isinstance(_type, EdgeType): + for property in (MagicProperties.FROM.value, MagicProperties.TO.value): + value = properties.get(property.name) + # as a convenience construct the id from the properties (since it will usually have all kinds of test + # shards) + if isinstance(value, Mapping) and MagicProperties.LABEL.value.name in value: + id = VertexTypes.by_label()[value[MagicProperties.LABEL.value.name]].value.id(**value) + properties[property.name] = id + property.type.value.is_allowed(properties.get(property.name)) + + entity = _type.create(**properties) + return entity[MagicProperties.ID.value.name], entity + + +def _create_expected(expected: Mapping[Union[VertexType, EdgeType], Iterable[Mapping[str, Any]]]) -> ENTITIES: + return {_type: dict(_create_one_expected(_type, **properties) for properties in entities) # type: ignore + for _type, entities in expected.items()} + + +@mock.patch('amundsen_gremlin.neptune_bulk_loader.gremlin_model_converter._FetchExisting') +class TestGetGraph(unittest.TestCase): + def setUp(self) -> None: + self.maxDiff = None + + def test_table_entities(self, fetch_existing: Any) -> None: + table1 = Table( + database='Snowflake', cluster='production', schema='esikmo', name='igloo', + description='''it's cool''', + programmatic_descriptions=[ProgrammaticDescription(text='super cool', source='other')], + table_writer=Application(id='eskimo'), + columns=[ + Column(name='block1', col_type='ice', sort_order=1, description='won'), + Column(name='block2', col_type='ice', sort_order=2), + ], + tags=[Tag(tag_name='Kewl', tag_type='default')], + ) + table2 = Table(database=table1.database, cluster=table1.cluster, schema=table1.schema, name='floes', columns=[], + table_writer=table1.table_writer) + table_data = [table1, table2] + + def side_effect(*args: Any, existing: EXISTING, **kwargs: Any) -> None: + _FetchExisting._fake_into_existing_vertexes_for_testing( + _existing=existing, _type=VertexTypes.Application, id='eskimo', key='eskimo') + + fetch_existing.table_entities.side_effect = side_effect + + created_at = datetime.datetime(2020, 5, 27, 10, 50, 50, 924185) + expected = _create_expected({ + VertexTypes.Database.value: [ + {'key': 'database://Snowflake', 'name': 'Snowflake'}], + EdgeTypes.Cluster.value: [ + 
{'created': created_at, '~from': {'~label': 'Database', 'key': 'database://Snowflake'}, + '~to': {'~label': 'Cluster', 'key': 'Snowflake://production'}}], + VertexTypes.Cluster.value: [ + {'key': 'Snowflake://production', 'name': 'production'}], + EdgeTypes.Schema.value: [ + {'created': created_at, '~from': {'~label': 'Cluster', 'key': 'Snowflake://production'}, + '~to': {'~label': 'Schema', 'key': 'Snowflake://production.esikmo'}}], + VertexTypes.Schema.value: [ + {'key': 'Snowflake://production.esikmo', 'name': 'esikmo'}], + EdgeTypes.Table.value: [ + {'created': created_at, '~from': {'~label': 'Schema', 'key': 'Snowflake://production.esikmo'}, + '~to': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/igloo'}}, + {'created': created_at, '~from': {'~label': 'Schema', 'key': 'Snowflake://production.esikmo'}, + '~to': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/floes'}}], + VertexTypes.Table.value: [ + {'is_view': False, 'key': 'Snowflake://production.esikmo/igloo', 'name': 'igloo'}, + {'is_view': False, 'key': 'Snowflake://production.esikmo/floes', 'name': 'floes'}], + EdgeTypes.Tag.value: [ + {'created': created_at, '~from': {'~label': 'Tag', 'key': 'Kewl'}, + '~to': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/igloo'}}], + VertexTypes.Tag.value: [ + {'key': 'Kewl', 'tag_name': 'Kewl', 'tag_type': 'default'}], + EdgeTypes.LastUpdatedAt.value: [ + {'created': created_at, '~from': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/igloo'}, + '~to': {'~label': 'Updatedtimestamp', 'key': 'Snowflake://production.esikmo/igloo'}}, + {'created': created_at, '~from': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/floes'}, + '~to': {'~label': 'Updatedtimestamp', 'key': 'Snowflake://production.esikmo/floes'}}], + VertexTypes.Updatedtimestamp.value: [ + {'key': 'Snowflake://production.esikmo/igloo', 'latest_timestamp': created_at}, + {'key': 'Snowflake://production.esikmo/floes', 'latest_timestamp': created_at}], + EdgeTypes.Column.value: [ + {'created': created_at, '~from': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/igloo'}, + '~to': {'~label': 'Column', 'key': 'Snowflake://production.esikmo/igloo/block1'}}, + {'created': created_at, '~from': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/igloo'}, + '~to': {'~label': 'Column', 'key': 'Snowflake://production.esikmo/igloo/block2'}}], + VertexTypes.Column.value: [ + {'col_type': 'ice', 'key': 'Snowflake://production.esikmo/igloo/block1', + 'name': 'block1', 'sort_order': 1}, + {'col_type': 'ice', 'key': 'Snowflake://production.esikmo/igloo/block2', + 'name': 'block2', 'sort_order': 2}], + EdgeTypes.Description.value: [ + {'created': created_at, '~from': f'{get_shard()}:Column:Snowflake://production.esikmo/igloo/block1', + '~to': f'{get_shard()}:Description:Snowflake://production.esikmo/igloo/block1/user/_description'}, + {'created': created_at, '~from': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/igloo'}, + '~to': {'~label': 'Description', 'key': 'Snowflake://production.esikmo/igloo/other/_description'}}, + {'created': created_at, '~from': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/igloo'}, + '~to': {'~label': 'Description', 'key': 'Snowflake://production.esikmo/igloo/user/_description'}}], + VertexTypes.Description.value: [ + {'description': 'won', 'key': 'Snowflake://production.esikmo/igloo/block1/user/_description', + 'source': 'user'}, + {'description': 'super ' 'cool', 'key': 'Snowflake://production.esikmo/igloo/other/_description', + 
'source': 'other'}, + {'description': "it's " 'cool', 'key': 'Snowflake://production.esikmo/igloo/user/_description', + 'source': 'user'}], + EdgeTypes.Generates.value: [ + {'created': created_at, '~from': {'~label': 'Application', 'key': 'eskimo'}, + '~to': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/igloo'}}, + {'created': created_at, '~from': {'~label': 'Application', 'key': 'eskimo'}, + '~to': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/floes'}}], + }) + actual = GetGraph.table_entities( + table_data=table_data, created_at=created_at, g=None) + # make the diff a little better + self.assertDictEqual(_transform_dict(expected, transform_key=attrgetter('label')), + _transform_dict(actual, transform_key=attrgetter('label'))) + + def test_table_entities_app_prefix(self, fetch_existing: Any) -> None: + table_data = [Table(database='Snowflake', cluster='production', schema='esikmo', name='igloo', columns=[], + table_writer=Application(id='eskimo'))] + + def side_effect(*args: Any, existing: EXISTING, **kwargs: Any) -> None: + _FetchExisting._fake_into_existing_vertexes_for_testing( + _existing=existing, _type=VertexTypes.Application, id='app-eskimo', key='app-eskimo') + + fetch_existing.table_entities.side_effect = side_effect + + created_at = datetime.datetime(2020, 5, 27, 10, 50, 50, 924185) + expected = _create_expected({ + EdgeTypes.Generates.value: [ + {'created': created_at, '~from': {'~label': 'Application', 'key': 'app-eskimo'}, + '~to': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/igloo'}}, + ], + }) + actual = GetGraph.table_entities(table_data=table_data, created_at=created_at, g=None) + # make the diff a little better, and only look at the expected ones + self.assertDictEqual(dict((k.label, v) for k, v in expected.items()), + dict((k.label, actual[k]) for k, v in expected.items())) + + def test_table_entities_app_owner(self, fetch_existing: Any) -> None: + table_data = [Table( + database='Snowflake', cluster='production', schema='esikmo', name='igloo', columns=[], + table_writer=Application(id='eskimo'), + )] + + def side_effect(*args: Any, existing: EXISTING, **kwargs: Any) -> None: + _FetchExisting._fake_into_existing_vertexes_for_testing( + _existing=existing, _type=VertexTypes.User, user_id='eskimo', key='eskimo') + + fetch_existing.table_entities.side_effect = side_effect + + created_at = datetime.datetime(2020, 5, 27, 10, 50, 50, 924185) + expected = _create_expected({ + EdgeTypes.Owner.value: [ + {'created': created_at, '~from': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/igloo'}, + '~to': {'~label': 'User', 'key': 'eskimo'}}], + }) + actual = GetGraph.table_entities( + table_data=table_data, created_at=created_at, g=None) + # make the diff a little better, and only look at the expected ones + self.assertDictEqual(dict((k.label, v) for k, v in expected.items()), + dict((k.label, actual[k]) for k, v in expected.items())) + + def test_user_entities(self, fetch_existing: Any) -> None: + user_data = [ + User(email='test1@test.com', user_id='test1', first_name="first_name", last_name="last_name", + full_name="full_name", display_name="display_name", is_active=True, + github_username="github_username", team_name="team_name", slack_id="slack_id", + employee_type="employee_type", manager_fullname="manager_fullname", manager_email="manager_email", + manager_id="manager_id", role_name="role_name", profile_url="profile_url")] + expected = _create_expected({ + VertexTypes.User.value: [ + {'display_name': 'display_name', 'email': 
'test1@test.com', 'employee_type': 'employee_type', + 'first_name': 'first_name', 'full_name': 'full_name', 'github_username': 'github_username', + 'is_active': True, 'key': 'test1', 'last_name': 'last_name', 'manager_email': 'manager_email', + 'manager_fullname': 'manager_fullname', 'manager_id': 'manager_id', 'profile_url': 'profile_url', + 'role_name': 'role_name', 'slack_id': 'slack_id', 'team_name': 'team_name', 'user_id': 'test1'}, + ], + }) + + actual = GetGraph.user_entities(user_data=user_data, g=None) + # make the diff a little better + self.assertDictEqual(dict((k.label, v) for k, v in expected.items()), + dict((k.label, v) for k, v in actual.items())) + + def test_app_entities(self, fetch_existing: Any) -> None: + app_data = [Application(application_url="wais://", description="description", id="college", name="essay")] + expected = _create_expected({ + VertexTypes.Application.value: [ + {'application_url': 'wais://', 'description': 'description', 'id': 'college', 'key': 'college', + 'name': 'essay'}, + ], + }) + actual = GetGraph.app_entities(app_data=app_data, g=None) + # make the diff a little better + self.assertDictEqual(dict((k.label, v) for k, v in expected.items()), + dict((k.label, v) for k, v in actual.items())) + + def test_duplicates_ok(self, fetch_existing: Any) -> None: + table_data = [ + Table(database='Snowflake', cluster='production', schema='esikmo', name='igloo', columns=[]), + Table(database='Snowflake', cluster='production', schema='esikmo', name='electric-bugaloo', columns=[])] + + created_at = datetime.datetime(2020, 5, 27, 10, 50, 50, 924185) + expected = _create_expected({ + # only one of these, from here: + VertexTypes.Database.value: [ + {'key': 'database://Snowflake', 'name': 'Snowflake'}, + ], + EdgeTypes.Cluster.value: [ + {'created': created_at, '~from': {'~label': 'Database', 'key': 'database://Snowflake'}, + '~to': {'~label': 'Cluster', 'key': 'Snowflake://production'}}, + ], + VertexTypes.Cluster.value: [ + {'key': 'Snowflake://production', 'name': 'production'}, + ], + EdgeTypes.Schema.value: [ + {'created': created_at, '~from': {'~label': 'Cluster', 'key': 'Snowflake://production'}, + '~to': {'~label': 'Schema', 'key': 'Snowflake://production.esikmo'}}, + ], + VertexTypes.Schema.value: [ + {'key': 'Snowflake://production.esikmo', 'name': 'esikmo'}, + ], + # ...to here. 
+ EdgeTypes.Table.value: [ + {'created': created_at, '~from': {'~label': 'Schema', 'key': 'Snowflake://production.esikmo'}, + '~to': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/electric-bugaloo'}}, + {'created': created_at, '~from': {'~label': 'Schema', 'key': 'Snowflake://production.esikmo'}, + '~to': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/igloo'}}, + ], + VertexTypes.Table.value: [ + {'is_view': False, 'key': 'Snowflake://production.esikmo/electric-bugaloo', 'name': 'electric-bugaloo'}, + {'is_view': False, 'key': 'Snowflake://production.esikmo/igloo', 'name': 'igloo'}, + ], + EdgeTypes.LastUpdatedAt.value: [ + {'created': created_at, + '~from': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/electric-bugaloo'}, + '~to': {'~label': 'Updatedtimestamp', 'key': 'Snowflake://production.esikmo/electric-bugaloo'}}, + {'created': created_at, '~from': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/igloo'}, + '~to': {'~label': 'Updatedtimestamp', 'key': 'Snowflake://production.esikmo/igloo'}}, + ], + VertexTypes.Updatedtimestamp.value: [ + {'key': 'Snowflake://production.esikmo/electric-bugaloo', 'latest_timestamp': created_at}, + {'key': 'Snowflake://production.esikmo/igloo', 'latest_timestamp': created_at}, + ], + }) + + actual = GetGraph.table_entities(table_data=table_data, created_at=created_at, g=None) + # make the diff a little better + self.assertDictEqual(dict((k.label, v) for k, v in expected.items()), + dict((k.label, v) for k, v in actual.items())) + + def test_duplicates_explode(self, fetch_existing: Any) -> None: + user_data = [Fixtures.next_user(user_id='u'), Fixtures.next_user(user_id='u'), Fixtures.next_user(user_id='u')] + # with self.assertRaisesRegex(AssertionError, 'we already have a .*id=User:u that is different: '): + with self.assertLogs('amundsen_gremlin.neptune_bulk_loader.gremlin_model_converter', level='INFO') as cm: + GetGraph.user_entities(user_data=user_data, g=None) + self.assertTrue( + len(cm.output) == 2 + and all(f'we already have a type: User, id={get_shard()}:User:u that is different' in line for line in cm.output), + f'expected message in {cm.output}') + + +@pytest.mark.roundtrip +class TestGetGraphRoundTrip(unittest.TestCase): + def setUp(self) -> None: + self.maxDiff = None + self.app = Flask(__name__) + self.app_context = self.app.app_context() + self.app.config.from_object(LocalGremlinConfig()) + self.app_context.push() + self.bulk_loader = NeptuneBulkLoaderApi.create_from_config(self.app.config) + self.neptune_graph_traversal_source_factory = get_neptune_graph_traversal_source_factory_from_config(self.app.config) + self._drop_almost_everything() + + def _drop_almost_everything(self) -> None: + delete_graph_for_shard_only(self.neptune_graph_traversal_source_factory()) + + def tearDown(self) -> None: + delete_graph_for_shard_only(self.neptune_graph_traversal_source_factory()) + self.app_context.pop() + + def _bulk_load_entities_successfully( + self, *, entities: Mapping[GraphEntityType, Mapping[str, GraphEntity]], **kwargs: Any) -> None: + self.bulk_loader.bulk_load_entities( + entities=entities, raise_if_failed=True, object_prefix=f'{{now}}/{get_shard()}', **kwargs) + + def test_table_entities(self) -> None: + app_data = [Application(id='eskimo')] + table_data = [Table( + database='Snowflake', cluster='production', schema='esikmo', name='igloo', + description='''it's cool''', + programmatic_descriptions=[ProgrammaticDescription(text='super cool', source='other')], + 
table_writer=Application(id='eskimo'), + columns=[ + Column(name='block1', col_type='ice', sort_order=1, description='won'), + Column(name='block2', col_type='ice', sort_order=2), + ], + tags=[Tag(tag_name='Kewl', tag_type='default')], + )] + + created_at = datetime.datetime(2020, 5, 27, 10, 50, 50) # omit the precision here, we're roundtripping + entities1 = GetGraph(created_at=created_at, g=self.neptune_graph_traversal_source_factory()).\ + add_app_entities(app_data).add_table_entities(table_data).complete() + self._bulk_load_entities_successfully(entities=entities1) + + created_at = datetime.datetime(2020, 5, 27, 10, 50, 50, 924185) + datetime.timedelta(seconds=10, days=1) + entities2 = GetGraph(created_at=created_at, g=self.neptune_graph_traversal_source_factory()).\ + add_app_entities(app_data).add_table_entities(table_data).complete() + + # Updatedtimestamp vertex should change + self.assertDictEqual(dict((k.label, v) for k, v in entities1.items() if k.label != 'Updatedtimestamp'), + dict((k.label, v) for k, v in entities2.items() if k.label != 'Updatedtimestamp')) + +# TODO: redo this test + def test_expire_others(self) -> None: + pass + + +class TestGetGraphMisc(unittest.TestCase): + def test_possible_application_names_application_key(self) -> None: + self.assertSetEqual({'app-foo', 'foo'}, set(possible_application_names_application_key('foo'))) + self.assertSetEqual({'app-foo', 'foo'}, set(possible_application_names_application_key('app-foo'))) + self.assertSetEqual({'app-foo', 'foo', 'app-foo-devel', 'foo-devel'}, + set(possible_application_names_application_key('foo-devel'))) + self.assertSetEqual({'app-foo', 'foo', 'app-foo-development', 'foo-development'}, + set(possible_application_names_application_key('foo-development'))) + self.assertSetEqual({'app-foo', 'foo', 'app-foo-stage', 'foo-stage'}, + set(possible_application_names_application_key('foo-stage'))) + self.assertSetEqual({'app-foo', 'foo', 'app-foo-staging', 'foo-staging'}, + set(possible_application_names_application_key('foo-staging'))) + self.assertSetEqual({'app-foo', 'foo', 'app-foo-prod', 'foo-prod'}, + set(possible_application_names_application_key('foo-prod'))) + self.assertSetEqual({'app-foo', 'foo', 'app-foo-production', 'foo-production'}, + set(possible_application_names_application_key('foo-production'))) + + +K = TypeVar('K', bound=Hashable) +V = TypeVar('V') +K2 = TypeVar('K2', bound=Hashable) +V2 = TypeVar('V2') + + +def _transform_dict( # noqa: C901 + mapping: Mapping[K, V], *, if_key: Optional[Callable[[K], bool]] = None, + if_value: Optional[Callable[[V], bool]] = None, if_item: Optional[Callable[[K, V], bool]] = None, + transform_key: Optional[Callable[[K], K2]] = None, transform_value: Optional[Callable[[V], V2]] = None, + transform_item: Optional[Callable[[K, V], Tuple[K2, V2]]] = None) -> Union[Dict[K, V], Dict[K2, V], Dict[K, V2], Dict[K2, V2]]: + assert len([c for c in (if_key, if_value, if_item) if c is not None]) <= 1, \ + f'expected exactly at most one of if_key, if_value, or if_item' + assert len([c for c in (transform_key, transform_value, transform_item) if c is not None]) <= 1, \ + f'expected exactly at most one of transform_key, transform_value, or transform_item' + + # I couldn't make mypy like the overloading + if transform_key is not None: + if if_key is not None: + return dict([(transform_key(k), v) for k, v in mapping.items() if if_key(k)]) + elif if_value is not None: + return dict([(transform_key(k), v) for k, v in mapping.items() if if_value(v)]) + elif if_item is not None: + 
return dict([(transform_key(k), v) for k, v in mapping.items() if if_item(k, v)]) + else: + return dict([(transform_key(k), v) for k, v in mapping.items()]) + elif transform_value is not None: + if if_key is not None: + return dict([(k, transform_value(v)) for k, v in mapping.items() if if_key(k)]) + elif if_value is not None: + return dict([(k, transform_value(v)) for k, v in mapping.items() if if_value(v)]) + elif if_item is not None: + return dict([(k, transform_value(v)) for k, v in mapping.items() if if_item(k, v)]) + else: + return dict([(k, transform_value(v)) for k, v in mapping.items()]) + elif transform_item is not None: + if if_key is not None: + return dict([transform_item(k, v) for k, v in mapping.items() if if_key(k)]) + elif if_value is not None: + return dict([transform_item(k, v) for k, v in mapping.items() if if_value(v)]) + elif if_item is not None: + return dict([transform_item(k, v) for k, v in mapping.items() if if_item(k, v)]) + else: + return dict([transform_item(k, v) for k, v in mapping.items()]) + else: + if if_key is not None: + return dict([(k, v) for k, v in mapping.items() if if_key(k)]) + elif if_value is not None: + return dict([(k, v) for k, v in mapping.items() if if_value(v)]) + elif if_item is not None: + return dict([(k, v) for k, v in mapping.items() if if_item(k, v)]) + else: + return dict([(k, v) for k, v in mapping.items()]) + + +VERTEX_OR_EDGE_TYPE = TypeVar('VERTEX_OR_EDGE_TYPE', bound=Union[VertexType, EdgeType]) + + +def _label_not_in(*labels: str) -> Callable[[VERTEX_OR_EDGE_TYPE], bool]: + def not_in(_type: VERTEX_OR_EDGE_TYPE) -> bool: + return _type is not None and _type.label not in labels + return not_in diff --git a/tests/unit/test_gremlin_model.py b/tests/unit/test_gremlin_model.py new file mode 100644 index 0000000..ea5524f --- /dev/null +++ b/tests/unit/test_gremlin_model.py @@ -0,0 +1,182 @@ +# Copyright Contributors to the Amundsen project. 
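Taken together, the converter and bulk loader tests above suggest an ingest flow roughly like the following (a sketch only, under assumptions: LocalGremlinConfig points at a reachable Neptune instance, and the Application/Table fixtures are placeholders):

import datetime

from amundsen_common.models.table import Application, Table
from flask import Flask

from amundsen_gremlin.config import LocalGremlinConfig
from amundsen_gremlin.neptune_bulk_loader.api import (
    NeptuneBulkLoaderApi, get_neptune_graph_traversal_source_factory_from_config
)
from amundsen_gremlin.neptune_bulk_loader.gremlin_model_converter import GetGraph

app = Flask(__name__)
app.config.from_object(LocalGremlinConfig())

app_data = [Application(id='example_app')]  # placeholder models
table_data = [Table(database='Snowflake', cluster='production', schema='example', name='example_table', columns=[])]

with app.app_context():
    bulk_loader = NeptuneBulkLoaderApi.create_from_config(app.config)
    g_factory = get_neptune_graph_traversal_source_factory_from_config(app.config)
    # convert the amundsen-common models into bulk-loadable vertices and edges, then load them
    entities = GetGraph(created_at=datetime.datetime.now(), g=g_factory()) \
        .add_app_entities(app_data) \
        .add_table_entities(table_data) \
        .complete()
    bulk_loader.bulk_load_entities(entities=entities, raise_if_failed=True)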
+# SPDX-License-Identifier: Apache-2.0 + +import datetime +import unittest + +import pytz +from gremlin_python.process.traversal import Cardinality + +from amundsen_gremlin.gremlin_model import ( + EdgeType, EdgeTypes, GremlinCardinality, GremlinType, MagicProperties, + Property, VertexType, VertexTypes, WellKnownProperties +) +from amundsen_gremlin.test_and_development_shard import get_shard + + +class TestGremlinEnums(unittest.TestCase): + def test_enum_unique_labels(self) -> None: + self.assertIsInstance(VertexTypes.by_label(), dict) + self.assertIsInstance(EdgeTypes.by_label(), dict) + + def test_cardinality(self) -> None: + self.assertEqual(Cardinality.set_, GremlinCardinality.set.gremlin_python_cardinality()) + + +class TestGremlinTyper(unittest.TestCase): + def test_boolean_type(self) -> None: + self.assertEqual('True', GremlinType.Boolean.value.format(True)) + self.assertEqual('False', GremlinType.Boolean.value.format(False)) + with self.assertRaisesRegex(AssertionError, 'expected bool'): + GremlinType.Boolean.value.is_allowed('hi') + with self.assertRaisesRegex(AssertionError, 'expected bool'): + GremlinType.Boolean.value.is_allowed('True') + + def test_byte_type(self) -> None: + a_byte = 2**7 - 1 + self.assertEqual('127', GremlinType.Byte.value.format(a_byte)) + GremlinType.Byte.value.is_allowed(a_byte) + with self.assertRaisesRegex(AssertionError, 'expected int in [[]-2[*][*]7, 2[*][*]7[)]'): + GremlinType.Byte.value.is_allowed('hi') + with self.assertRaisesRegex(AssertionError, 'expected int in [[]-2[*][*]7, 2[*][*]7[)]'): + GremlinType.Byte.value.is_allowed(2**7) + with self.assertRaisesRegex(AssertionError, 'expected int in [[]-2[*][*]7, 2[*][*]7[)]'): + GremlinType.Byte.value.is_allowed(-(2**7+1)) + + def test_short_type(self) -> None: + a_short = 2**7 + 1 + self.assertEqual('129', GremlinType.Short.value.format(a_short)) + GremlinType.Short.value.is_allowed(a_short) + with self.assertRaisesRegex(AssertionError, 'expected int in [[]-2[*][*]15, 2[*][*]15[)]'): + GremlinType.Short.value.is_allowed('hi') + with self.assertRaisesRegex(AssertionError, 'expected int in [[]-2[*][*]15, 2[*][*]15[)]'): + GremlinType.Short.value.is_allowed(2**15) + with self.assertRaisesRegex(AssertionError, 'expected int in [[]-2[*][*]15, 2[*][*]15[)]'): + GremlinType.Short.value.is_allowed(-(2**15+1)) + + def test_int_type(self) -> None: + a_int = 2**15 + 1 + self.assertEqual('32769', GremlinType.Int.value.format(a_int)) + GremlinType.Int.value.is_allowed(a_int) + with self.assertRaisesRegex(AssertionError, 'expected int in [[]-2[*][*]31, 2[*][*]31[)]'): + GremlinType.Int.value.is_allowed('hi') + with self.assertRaisesRegex(AssertionError, 'expected int in [[]-2[*][*]31, 2[*][*]31[)]'): + GremlinType.Int.value.is_allowed(2**31) + with self.assertRaisesRegex(AssertionError, 'expected int in [[]-2[*][*]31, 2[*][*]31[)]'): + GremlinType.Int.value.is_allowed(-(2**31+1)) + + def test_long_type(self) -> None: + a_long = 2**31 + 1 + self.assertEqual('2147483649', GremlinType.Long.value.format(a_long)) + GremlinType.Long.value.is_allowed(a_long) + with self.assertRaisesRegex(AssertionError, 'expected int in [[]-2[*][*]63, 2[*][*]63[)]'): + GremlinType.Long.value.is_allowed('hi') + with self.assertRaisesRegex(AssertionError, 'expected int in [[]-2[*][*]63, 2[*][*]63[)]'): + GremlinType.Long.value.is_allowed(2**63) + with self.assertRaisesRegex(AssertionError, 'expected int in [[]-2[*][*]63, 2[*][*]63[)]'): + GremlinType.Long.value.is_allowed(-(2**63+1)) + + def test_float_type(self) -> None: + a_float = 
float(4/3) + self.assertEqual('1.3333333333333333', GremlinType.Float.value.format(a_float)) + with self.assertRaisesRegex(AssertionError, 'expected float'): + GremlinType.Float.value.format('hi') + with self.assertRaisesRegex(AssertionError, 'expected float.'): + GremlinType.Float.value.is_allowed('hi') + GremlinType.Float.value.is_allowed(a_float) + + def test_string_type(self) -> None: + a_str = 'hi' + self.assertEqual('hi', GremlinType.String.value.format(a_str)) + with self.assertRaisesRegex(AssertionError, 'expected str'): + GremlinType.String.value.format(10) + with self.assertRaisesRegex(AssertionError, 'expected str'): + GremlinType.String.value.is_allowed(10) + GremlinType.String.value.is_allowed(a_str) + + def test_date_type(self) -> None: + a_datetime = datetime.datetime(2020, 5, 27, 10, 50, 50, 924185) + self.assertEqual('2020-05-27T10:50:50', GremlinType.Date.value.format(a_datetime)) + self.assertEqual('2020-05-27', GremlinType.Date.value.format(a_datetime.date())) + with self.assertRaisesRegex(AssertionError, 'wat?'): + GremlinType.Date.value.format('2020-05-27') + with self.assertRaisesRegex(AssertionError, 'expected datetime.'): + GremlinType.Date.value.is_allowed('2020-05-27') + with self.assertRaisesRegex(AssertionError, 'expected datetime.'): + GremlinType.Date.value.is_allowed(a_datetime.astimezone(pytz.utc)) + GremlinType.Date.value.is_allowed(a_datetime) + GremlinType.Date.value.is_allowed(a_datetime.date()) + + +class TestVertexType(unittest.TestCase): + def test_as_map(self) -> None: + self.assertIsInstance(VertexTypes.Column.value.properties_as_map(), dict) + + def test_create(self) -> None: + actual = VertexTypes.Column.value.create(key='column_key', name='name', col_type=None) + self.assertSetEqual(set(actual.keys()), + {MagicProperties.LABEL.value.name, MagicProperties.ID.value.name, + WellKnownProperties.TestShard.value.name, 'key', 'name'}) + self.assertEqual(actual.get(MagicProperties.LABEL.value.name), 'Column') + self.assertEqual(actual.get(MagicProperties.ID.value.name), f'{get_shard()}:Column:column_key') + self.assertEqual(actual.get('key'), 'column_key') + + def test_create_type_explodes_if_id_format(self) -> None: + with self.assertRaisesRegex(AssertionError, 'id_format: {shard}:{foo}:bar has parameters:'): + VertexType.construct_type(id_format='{foo}:bar') + + +class TestEdgeType(unittest.TestCase): + def test_as_map(self) -> None: + self.assertIsInstance(EdgeTypes.Column.value.properties_as_map(), dict) + + def test_create(self) -> None: + created_at = datetime.datetime(2020, 5, 27, 10, 50, 50, 924185) + actual = EdgeTypes.Column.value.create(created=created_at, expired=None, **{ + MagicProperties.FROM.value.name: VertexTypes.Column.value.id(key='column_key'), + MagicProperties.TO.value.name: VertexTypes.Table.value.id(key='table_key'), + }) + self.assertSetEqual( + set(actual.keys()), + set([e.value.name for e in (MagicProperties.LABEL, MagicProperties.ID, MagicProperties.FROM, + MagicProperties.TO, WellKnownProperties.Created)])) + self.assertEqual(actual.get(MagicProperties.LABEL.value.name), 'COLUMN') + self.assertEqual(actual.get(MagicProperties.ID.value.name), + f'COLUMN:2020-05-27T10:50:50:{get_shard()}:Column:column_key->{get_shard()}:Table:table_key') + + def test_create_type_explodes_if_id_format(self) -> None: + with self.assertRaisesRegex(AssertionError, 'id_format: {foo}:bar has parameters:'): + EdgeType.construct_type(id_format='{foo}:bar') + + +class TestProperty(unittest.TestCase): + def test_signature(self) -> None: + expected = 
Property(name='foo', type=GremlinType.String, cardinality=GremlinCardinality.list) + actual = Property(name='foo', type=GremlinType.String, comment='a bar').signature(GremlinCardinality.list) + self.assertEqual(expected, actual) + + def test_format(self) -> None: + a_datetime = datetime.datetime(2020, 5, 27, 10, 50, 50, 924185) + a_property = Property(name='date', type=GremlinType.Date) + self.assertEqual('2020-05-27T10:50:50', a_property.format(a_datetime)) + + def test_header(self) -> None: + self.assertEqual('foo:Date', Property(name='foo', type=GremlinType.Date).header()) + self.assertEqual('foo:Date(single)', + Property(name='foo', type=GremlinType.Date, cardinality=GremlinCardinality.single).header()) + self.assertEqual('foo:Date(set)', + Property(name='foo', type=GremlinType.Date, cardinality=GremlinCardinality.set).header()) + self.assertEqual('foo:Date(list)', + Property(name='foo', type=GremlinType.Date, cardinality=GremlinCardinality.list).header()) + self.assertEqual('foo:Date(single)[]', Property( + name='foo', type=GremlinType.Date, cardinality=GremlinCardinality.single, multi_valued=True).header()) + self.assertEqual('foo:Date(set)[]', Property( + name='foo', type=GremlinType.Date, cardinality=GremlinCardinality.set, multi_valued=True).header()) + self.assertEqual('foo:Date(list)[]', Property( + name='foo', type=GremlinType.Date, cardinality=GremlinCardinality.list, multi_valued=True).header()) + + def test_magic_header(self) -> None: + self.assertEqual('~label', MagicProperties.LABEL.value.header()) + self.assertEqual('~id', MagicProperties.ID.value.header()) + self.assertEqual('~from', MagicProperties.FROM.value.header()) + self.assertEqual('~to', MagicProperties.TO.value.header()) diff --git a/tests/unit/test_gremlin_shared.py b/tests/unit/test_gremlin_shared.py new file mode 100644 index 0000000..99947c6 --- /dev/null +++ b/tests/unit/test_gremlin_shared.py @@ -0,0 +1,53 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from gremlin_python.process.graph_traversal import __ + +from amundsen_gremlin.gremlin_shared import ( + append_traversal, get_database_name_from_uri, make_cluster_uri, + make_column_uri, make_database_uri, make_description_uri, make_schema_uri, + make_table_uri, rsubstringstartingwith +) + + +class TestGremlinShared(unittest.TestCase): + def test_make_database_uri_et_al(self) -> None: + self.assertEqual('database://BigQuery', make_database_uri(database_name='BigQuery')) + self.assertEqual('BigQuery://neverland', + make_cluster_uri(cluster_name='neverland', database_uri='database://BigQuery')) + self.assertEqual('BigQuery://neverland', + make_cluster_uri(cluster_name='neverland', database_name='BigQuery')) + self.assertEqual('BigQuery://neverland.production', + make_schema_uri(schema_name='production', cluster_uri='BigQuery://neverland')) + self.assertEqual('BigQuery://neverland.production', + make_schema_uri(schema_name='production', cluster_name='neverland', database_name='BigQuery')) + self.assertEqual('BigQuery://neverland.production/lost_boys', + make_table_uri(table_name='lost_boys', schema_uri='BigQuery://neverland.production')) + self.assertEqual('BigQuery://neverland.production/lost_boys', + make_table_uri(table_name='lost_boys', schema_name='production', cluster_name='neverland', + database_name='BigQuery')) + self.assertEqual('BigQuery://neverland.production/lost_boys/peter_pan', + make_column_uri(column_name='peter_pan', + table_uri='BigQuery://neverland.production/lost_boys')) + self.assertEqual('BigQuery://neverland.production/lost_boys/user/_description', + make_description_uri(source='user', + subject_uri='BigQuery://neverland.production/lost_boys')) + self.assertEqual('BigQuery://neverland.production/lost_boys/peter_pan/user/_description', + make_description_uri( + source='user', + subject_uri='BigQuery://neverland.production/lost_boys/peter_pan')) + + def test_get_database_name_from_uri_exceptional(self) -> None: + self.assertEqual(None, rsubstringstartingwith('://', 'foo')) + self.assertEqual('foo', rsubstringstartingwith('database://', 'database://foo')) + with self.assertRaisesRegex(RuntimeError, 'database_uri is malformed! foo'): + get_database_name_from_uri(database_uri='foo') + + def test_append_traversal(self) -> None: + g = __.V().hasLabel('Foo') + w = __.where(__.inE().outV().hasLabel('Bar')) + actual = append_traversal(g, w) + expected = __.V().hasLabel('Foo').where(__.inE().outV().hasLabel('Bar')) + self.assertEqual(actual, expected) diff --git a/tests/unit/test_script_translator.py b/tests/unit/test_script_translator.py new file mode 100644 index 0000000..d2ffadd --- /dev/null +++ b/tests/unit/test_script_translator.py @@ -0,0 +1,89 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import datetime +import unittest + +from gremlin_python.process.graph_traversal import __, addV, unfold +from gremlin_python.process.traversal import Cardinality + +from amundsen_gremlin.script_translator import ( + ScriptTranslator, ScriptTranslatorTargetJanusgraph, + ScriptTranslatorTargetNeptune +) + + +class ScriptTranslatorTest(unittest.TestCase): + def test_upsert(self) -> None: + g = __.V().has('User', 'key', 'jack').fold().coalesce( + unfold(), + addV('User').property(Cardinality.single, 'key', 'jack')). \ + coalesce(__.has('email', 'jack@example.com'), + __.property(Cardinality.single, 'email', 'jack@example.com')). 
\ + coalesce(__.has('url', 'https://twitter.com/jack'), + __.property(Cardinality.single, 'url', 'https://twitter.com/jack')) + actual = ScriptTranslator.translateB('g', g) + self.assertEqual(actual, '''g.V().has("User","key","jack").fold().coalesce(__.unfold(),__.addV("User").property(single,"key","jack")).coalesce(__.has("email","jack@example.com"),__.property(single,"email","jack@example.com")).coalesce(__.has("url","https://twitter.com/jack"),__.property(single,"url","https://twitter.com/jack"))''') # noqa: E501 + + def test_string_null(self) -> None: + self.assertEqual(ScriptTranslator._convert_to_string(None), 'null') + + def test_string_bool(self) -> None: + self.assertEqual(ScriptTranslator._convert_to_string(True), 'true') + self.assertEqual(ScriptTranslator._convert_to_string(False), 'false') + + def test_string_char(self) -> None: + # the printables minus the double quote + for c in tuple(' !#$%&()*+,-.0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[]^_`abcdefghijklmnopqrstuvwxyz{|}~'): + actual = ScriptTranslator._convert_to_string(c) + self.assertEqual(actual, f'"{c}"') + + def test_string_escaping_char(self) -> None: + for c in '\'\\"': + actual = ScriptTranslator._convert_to_string(c) + self.assertEqual(actual, f'"\\{c}"') + + def test_string_escaping_control_or_unicode(self) -> None: + for input, escaped in zip('\x00\x01\x02\x03\x04\x05\x06\x07\b\t\n\x0b\f\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16' + + '\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89' + + '\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95', + ('\\u0000', '\\u0001', '\\u0002', '\\u0003', '\\u0004', '\\u0005', '\\u0006', + '\\u0007', '\\b', '\\t', '\\n', '\\u000b', '\\f', '\\r', '\\u000e', '\\u000f', + '\\u0010', '\\u0011', '\\u0012', '\\u0013', '\\u0014', '\\u0015', '\\u0016', + '\\u0017', '\\u0018', '\\u0019', '\\u001a', '\\u001b', '\\u001c', '\\u001d', + '\\u001e', '\\u001f', '\\u007f', '\\u0080', '\\u0081', '\\u0082', '\\u0083', + '\\u0084', '\\u0085', '\\u0086', '\\u0087', '\\u0088', '\\u0089', '\\u008a', + '\\u008b', '\\u008c', '\\u008d', '\\u008e', '\\u008f', '\\u0090', '\\u0091', + '\\u0092', '\\u0093', '\\u0094', '\\u0095')): + actual = ScriptTranslator._convert_to_string(input) + self.assertEqual(actual, f'"{escaped}"') + + def test_string_datetime_zero_millis_janusgraph(self) -> None: + g = __.property(Cardinality.single, 'created', datetime.datetime(2010, 8, 31, 19, 55, 10)) + actual = ScriptTranslatorTargetJanusgraph.translateB('g', g) + self.assertEqual(actual, '''g.property(single,"created",new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSSSS").parse("2010-08-31T19:55:10.000000"))''') # noqa: E501 + + def test_string_datetime_some_millis_janusgraph(self) -> None: + g = __.property(Cardinality.single, 'created', datetime.datetime(2010, 8, 31, 19, 55, 10, 123)) + actual = ScriptTranslatorTargetJanusgraph.translateB('g', g) + self.assertEqual(actual, '''g.property(single,"created",new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSSSS").parse("2010-08-31T19:55:10.000123"))''') # noqa: E501 + + def test_string_date_janusgraph(self) -> None: + g = __.property(Cardinality.single, 'created', datetime.date(2010, 8, 31)) + actual = ScriptTranslatorTargetJanusgraph.translateB('g', g) + self.assertEqual(actual, '''g.property(single,"created",new java.text.SimpleDateFormat("yyyy-MM-dd").parse("2010-08-31"))''') # noqa: E501 + + def test_string_datetime_zero_millis_neptune(self) -> None: + g = __.property(Cardinality.single, 'created', datetime.datetime(2010, 8, 31, 19, 
55, 10)) + actual = ScriptTranslatorTargetNeptune.translateB('g', g) + self.assertEqual(actual, '''g.property(single,"created",datetime("2010-08-31T19:55:10"))''') # noqa: E501 + + def test_string_datetime_some_millis_neptune(self) -> None: + g = __.property(Cardinality.single, 'created', datetime.datetime(2010, 8, 31, 19, 55, 10, 123)) + actual = ScriptTranslatorTargetNeptune.translateB('g', g) + self.assertEqual(actual, '''g.property(single,"created",datetime("2010-08-31T19:55:10.000123"))''') # noqa: E501 + + def test_string_date_neptune(self) -> None: + g = __.property(Cardinality.single, 'created', datetime.date(2010, 8, 31)) + actual = ScriptTranslatorTargetNeptune.translateB('g', g) + self.assertEqual(actual, '''g.property(single,"created",datetime("2010-08-31"))''') # noqa: E501 diff --git a/tests/unit/test_test_and_development_shard.py b/tests/unit/test_test_and_development_shard.py new file mode 100644 index 0000000..88ba254 --- /dev/null +++ b/tests/unit/test_test_and_development_shard.py @@ -0,0 +1,89 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import os +import unittest +from unittest import mock + +from amundsen_common.tests.fixtures import Fixtures + +from amundsen_gremlin.test_and_development_shard import ( + _reset_for_testing_only, _shard_default, get_shard, shard_set_explicitly +) + + +class TestTestShard(unittest.TestCase): + def setUp(self) -> None: + _reset_for_testing_only() + + def tearDown(self) -> None: + _reset_for_testing_only() + + def test_set_shard_works(self) -> None: + expected = Fixtures.next_string() + shard_set_explicitly(expected) + actual = get_shard() + self.assertEqual(expected, actual) + + def test_set_shard_explodes(self) -> None: + expected = get_shard() + with self.assertRaisesRegex(AssertionError, 'can only shard_set_explicitly if it has not been used yet.'): + shard_set_explicitly('x') + actual = get_shard() + self.assertEqual(expected, actual) + + def test_shard_default_ci(self) -> None: + with mock.patch.dict(os.environ): + os.environ['CI'] = 'x' + os.environ['BUILD_PART_ID'] = '12345' + os.environ.pop('DATACENTER', None) + os.environ['USER'] = 'jack' + actual = _shard_default() + self.assertEqual('12345', actual) + + def test_shard_default_local(self) -> None: + with mock.patch.dict(os.environ): + os.environ.pop('CI', None) + os.environ.pop('BUILD_PART_ID', None) + os.environ.pop('DATACENTER', None) + os.environ['USER'] = 'jack' + actual = _shard_default() + self.assertEqual('jack', actual) + + def test_shard_default_also_local(self) -> None: + with mock.patch.dict(os.environ): + os.environ.pop('CI', None) + os.environ.pop('BUILD_PART_ID', None) + os.environ.pop('DATACENTER', 'local') + os.environ['USER'] = 'jack' + actual = _shard_default() + self.assertEqual('jack', actual) + + def test_shard_default_local_explodes(self) -> None: + with mock.patch.dict(os.environ): + os.environ.pop('CI', None) + os.environ.pop('BUILD_PART_ID', None) + os.environ.pop('DATACENTER', None) + os.environ.pop('USER', None) + with self.assertRaisesRegex(AssertionError, 'Expected USER environment variable to be set'): + _shard_default() + + def test_shard_default_environment_production(self) -> None: + with mock.patch.dict(os.environ): + os.environ.pop('CI', None) + os.environ.pop('BUILD_PART_ID', None) + os.environ['DATACENTER'] = 'x' + os.environ['USER'] = 'jack' + os.environ['ENVIRONMENT'] = 'production' + actual = _shard_default() + self.assertIsNone(actual) + + def test_shard_default_environment_staging(self) 
-> None: + with mock.patch.dict(os.environ): + os.environ.pop('CI', None) + os.environ.pop('BUILD_PART_ID', None) + os.environ['DATACENTER'] = 'x' + os.environ['USER'] = 'jack' + os.environ['ENVIRONMENT'] = 'staging' + actual = _shard_default() + self.assertIsNone(actual) diff --git a/tests/unit/utils/__init__.py b/tests/unit/utils/__init__.py new file mode 100644 index 0000000..f3145d7 --- /dev/null +++ b/tests/unit/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/unit/utils/test_streams.py b/tests/unit/utils/test_streams.py new file mode 100644 index 0000000..e41f221 --- /dev/null +++ b/tests/unit/utils/test_streams.py @@ -0,0 +1,205 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import logging +import sys +import unittest +from typing import AsyncIterator, Iterable +from unittest.mock import Mock, call + +import pytest + +from amundsen_gremlin.utils.streams import ( + PeekingIterator, _assure_collection, async_consume_in_chunks, + consume_in_chunks, consume_in_chunks_with_state, one_chunk, + reduce_in_chunks +) + + +class TestConsumer(unittest.TestCase): + def test_consume_in_chunks(self) -> None: + values = Mock() + values.side_effect = list(range(5)) + consumer = Mock() + parent = Mock() + parent.values = values + parent.consumer = consumer + + def stream() -> Iterable[int]: + for _ in range(5): + yield values() + + count = consume_in_chunks(stream=stream(), n=2, consumer=consumer) + self.assertEqual(count, 5) + # this might look at little weird, but PeekingIterator is why + self.assertSequenceEqual([call.values(), call.values(), call.values(), call.consumer((0, 1)), + call.values(), call.values(), call.consumer((2, 3)), call.consumer((4,))], + parent.mock_calls) + + def test_consume_in_chunks_with_exception(self) -> None: + consumer = Mock() + + def stream() -> Iterable[int]: + yield from range(10) + raise KeyError('hi') + + with self.assertRaisesRegex(KeyError, 'hi'): + consume_in_chunks(stream=stream(), n=4, consumer=consumer) + self.assertSequenceEqual([call.consumer((0, 1, 2, 3)), call.consumer((4, 5, 6, 7))], consumer.mock_calls) + + def test_consume_in_chunks_with_state(self) -> None: + values = Mock() + values.side_effect = list(range(5)) + consumer = Mock() + consumer.side_effect = list(range(1, 4)) + state = Mock() + state.side_effect = lambda x: x * 10 + parent = Mock() + parent.values = values + parent.consumer = consumer + parent.state = state + + def stream() -> Iterable[int]: + for _ in range(5): + yield values() + + result = consume_in_chunks_with_state(stream=stream(), n=2, consumer=consumer, state=state) + self.assertSequenceEqual(tuple(result), (0, 10, 20, 30, 40)) + # this might look at little weird, but PeekingIterator is why + self.assertSequenceEqual([call.values(), call.values(), call.values(), call.state(0), call.state(1), + call.consumer((0, 1)), call.values(), call.values(), call.state(2), call.state(3), + call.consumer((2, 3)), call.state(4), call.consumer((4,))], + parent.mock_calls) + + def test_consume_in_chunks_no_batch(self) -> None: + consumer = Mock() + count = consume_in_chunks(stream=range(100000000), n=-1, consumer=consumer) + self.assertEqual(100000000, count) + consumer.assert_called_once() + + def test_reduce_in_chunks(self) -> None: + values = Mock() + values.side_effect = list(range(5)) + consumer = Mock() + consumer.side_effect = list(range(1, 4)) + parent = Mock() + parent.values = 
values + parent.consumer = consumer + + def stream() -> Iterable[int]: + for _ in range(5): + yield values() + + result = reduce_in_chunks(stream=stream(), n=2, initial=0, consumer=consumer) + self.assertEqual(result, 3) + # this might look a little weird, but PeekingIterator is why + self.assertSequenceEqual([call.values(), call.values(), call.values(), call.consumer((0, 1), 0), + call.values(), call.values(), call.consumer((2, 3), 1), call.consumer((4,), 2)], + parent.mock_calls) + + @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") + def test_async_consume_in_chunks(self) -> None: + values = Mock() + values.side_effect = list(range(5)) + consumer = Mock() + parent = Mock() + parent.values = values + parent.consumer = consumer + + async def stream() -> AsyncIterator[int]: + for i in range(5): + yield values() + + count = asyncio.run(async_consume_in_chunks(stream=stream(), n=2, consumer=consumer)) + self.assertEqual(5, count, 'count') + # this might look a little weird, but PeekingIterator is why + self.assertSequenceEqual([call.values(), call.values(), call.values(), call.consumer((0, 1)), + call.values(), call.values(), call.consumer((2, 3)), call.consumer((4,))], + parent.mock_calls) + + def test_one_chunk_logging(self) -> None: + it = PeekingIterator(range(1, 4)) + actual, has_more = one_chunk(it=it, n=2, metric=lambda x: x) + self.assertSequenceEqual([1], tuple(actual)) + self.assertTrue(has_more) + + actual, has_more = one_chunk(it=it, n=2, metric=lambda x: x) + self.assertSequenceEqual([2], tuple(actual)) + self.assertTrue(has_more) + + with self.assertLogs(logger='amundsen_gremlin.utils.streams', level=logging.ERROR): + actual, has_more = one_chunk(it=it, n=2, metric=lambda x: x) + self.assertSequenceEqual([3], tuple(actual)) + self.assertFalse(has_more) + + def test_assure_collection(self) -> None: + actual = _assure_collection(iter(range(2))) + self.assertIsInstance(actual, tuple) + self.assertEqual((0, 1), actual) + actual = _assure_collection(list(range(2))) + self.assertIsInstance(actual, list) + self.assertEqual([0, 1], actual) + actual = _assure_collection(set(range(2))) + self.assertIsInstance(actual, set) + self.assertEqual({0, 1}, actual) + actual = _assure_collection(frozenset(range(2))) + self.assertIsInstance(actual, frozenset) + self.assertEqual(frozenset({0, 1}), actual) + + +class TestPeekingIterator(unittest.TestCase): + # TODO: it'd be good to test the locking + def test_no_peek(self) -> None: + it = PeekingIterator(range(3)) + self.assertEqual(0, next(it)) + self.assertEqual(1, next(it)) + self.assertEqual(2, next(it)) + with self.assertRaises(StopIteration): + next(it) + + def test_peek_is_next(self) -> None: + it = PeekingIterator(range(2)) + self.assertEqual(0, it.peek()) + self.assertTrue(it.has_more()) + self.assertEqual(0, next(it)) + self.assertTrue(it.has_more()) + self.assertEqual(1, next(it)) + self.assertFalse(it.has_more()) + with self.assertRaises(StopIteration): + next(it) + + def test_peek_repeats(self) -> None: + it = PeekingIterator(range(2)) + for _ in range(100): + self.assertEqual(0, it.peek()) + self.assertEqual(0, next(it)) + self.assertEqual(1, next(it)) + + def test_peek_after_exhaustion(self) -> None: + it = PeekingIterator(range(2)) + self.assertEqual(0, next(it)) + self.assertEqual(1, next(it)) + with self.assertRaises(StopIteration): + next(it) + with self.assertRaises(StopIteration): + it.peek() + self.assertEqual(999, it.peek_default(999)) + + def test_take_peeked(self) -> None: + it = 
PeekingIterator(range(2)) + self.assertEqual(0, it.peek()) + it.take_peeked(0) + self.assertEqual(1, next(it)) + with self.assertRaises(StopIteration): + next(it) + + def test_take_peeked_wrong_value(self) -> None: + it = PeekingIterator(range(2)) + self.assertEqual(0, it.peek()) + with self.assertRaisesRegex(AssertionError, 'expected the peaked value to be the same'): + it.take_peeked(1) + it.take_peeked(0) + self.assertEqual(1, next(it)) + +# TODO: test PeekingAsyncIterator directly From 55cd69d9d1587b7684caa7d3c763b6ae820019f9 Mon Sep 17 00:00:00 2001 From: friendtocephalopods <52580251+friendtocephalopods@users.noreply.github.com> Date: Tue, 22 Sep 2020 14:50:11 -0700 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: Daniel McCarney Signed-off-by: Joshua Hoskins --- .hooks/pre-commit | 4 ++-- amundsen_gremlin/gremlin_model.py | 4 ++-- amundsen_gremlin/neptune_bulk_loader/api.py | 2 +- .../neptune_bulk_loader/gremlin_model_converter.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.hooks/pre-commit b/.hooks/pre-commit index a966562..b7b1ea0 100755 --- a/.hooks/pre-commit +++ b/.hooks/pre-commit @@ -23,8 +23,8 @@ echo -e "${green}[Config]: Checking for secrets${NC}" git diff HEAD amundsen_gremlin/config.py | grep -q -c 'amazonaws.com' if [ $? -eq 0 ] then - echo -e "${red}Did you remember to remove your AWS config? If this is s a false alarm, recommit with --no-verify${NC}" + echo -e "${red}Did you remember to remove your AWS config? If this is a false alarm, recommit with --no-verify${NC}" exit 1 else exit 0 -fi \ No newline at end of file +fi diff --git a/amundsen_gremlin/gremlin_model.py b/amundsen_gremlin/gremlin_model.py index fad19e3..23cd33a 100644 --- a/amundsen_gremlin/gremlin_model.py +++ b/amundsen_gremlin/gremlin_model.py @@ -344,7 +344,7 @@ def create(self, **properties: Any) -> Mapping[str, Any]: class VertexTypes(Enum): """ - In general, you will need to reload all your data: 1. change label, 2. if you change the type of a property, + In general, you will need to reload all your data if you: 1. change label, 2. if you change the type of a property, 3. change the effective id_format """ @classmethod @@ -445,7 +445,7 @@ def by_label(cls) -> Mapping[str, "VertexTypes"]: class EdgeTypes(Enum): """ - In general, you will need to reload all your data: 1. change label, 2. if you change the type of a property, + In general, you will need to reload all your data if you: 1. change label, 2. if you change the type of a property, 3. change the effective id_format (e.g. 
change required) """ @classmethod diff --git a/amundsen_gremlin/neptune_bulk_loader/api.py b/amundsen_gremlin/neptune_bulk_loader/api.py index 0cdf3ca..fdef834 100644 --- a/amundsen_gremlin/neptune_bulk_loader/api.py +++ b/amundsen_gremlin/neptune_bulk_loader/api.py @@ -101,7 +101,7 @@ def override_prepared_request_parameters( region_name=session.region_name, credentials=session.get_credentials()) override_prepared_request(endpoints.gremlin_endpoint().prepare_request(), override_uri=host_to_actually_connect_to) - but note if you are not GETing (or have a payload), perpare_request doesn't *actually* generate sufficient headers + but note if you are not GETing (or have a payload), prepare_request doesn't *actually* generate sufficient headers (despite the fact that it accepts a method) """ http_request_param: Dict[str, Any] = dict(url=request_parameters.uri, headers=request_parameters.headers) diff --git a/amundsen_gremlin/neptune_bulk_loader/gremlin_model_converter.py b/amundsen_gremlin/neptune_bulk_loader/gremlin_model_converter.py index 116b0c8..552f7d3 100644 --- a/amundsen_gremlin/neptune_bulk_loader/gremlin_model_converter.py +++ b/amundsen_gremlin/neptune_bulk_loader/gremlin_model_converter.py @@ -90,7 +90,7 @@ def _get_key_properties(_type: Union[VertexType, EdgeType]) -> FrozenSet[Propert def _discover_parameters(format_string: str) -> FrozenSet[str]: """ - use this to discover what the parameters to a format string + use this to discover what the parameters to a format string are """ parameters: FrozenSet[str] = frozenset() while True: From a17ed5207f6449369a24d99b3e42340175dbb529 Mon Sep 17 00:00:00 2001 From: Joshua Hoskins Date: Tue, 22 Sep 2020 14:56:44 -0700 Subject: [PATCH 3/4] feedback 2 Signed-off-by: Joshua Hoskins --- .gitignore | 2 +- .hooks/pre-commit | 6 ++---- .hooks/pre-push | 8 +++++--- .vscode/settings.json | 5 ----- README.md | 16 ++++++++++++++++ amundsen_gremlin/config.py | 1 + amundsen_gremlin/gremlin_model.py | 5 +---- amundsen_gremlin/neptune_bulk_loader/api.py | 7 ++----- 8 files changed, 28 insertions(+), 22 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index d6ac92a..29b2738 100644 --- a/.gitignore +++ b/.gitignore @@ -129,4 +129,4 @@ dmypy.json .pyre/ # Vscode project settings -.vscode/ \ No newline at end of file +.vscode/ diff --git a/.hooks/pre-commit b/.hooks/pre-commit index b7b1ea0..e279cc4 100755 --- a/.hooks/pre-commit +++ b/.hooks/pre-commit @@ -6,10 +6,10 @@ NC='\033[0m' # Sort imports echo -e "${green}[Isort]: Checking Sorting${NC}" -venv/bin/isort -c +isort -c if [ $? -ne 0 ] then - venv/bin/isort --apply + isort --apply echo -e "${red}Sorted imports; recommit${NC}" exit 1 fi @@ -25,6 +25,4 @@ if [ $? -eq 0 ] then echo -e "${red}Did you remember to remove your AWS config? If this is a false alarm, recommit with --no-verify${NC}" exit 1 -else - exit 0 fi diff --git a/.hooks/pre-push b/.hooks/pre-push index d8e9c51..3739a76 100755 --- a/.hooks/pre-push +++ b/.hooks/pre-push @@ -4,7 +4,7 @@ red='\033[0;31m' green='\033[0;32m' NC='\033[0m' -set -e +set -e -o pipefail # Get only the files different on this branch BASE_SHA="$(git merge-base master HEAD)" @@ -20,7 +20,8 @@ then # Run flake8 flake8 . - if [ $? -ne 0 ]; then + if [ $? -ne 0 ] + then echo -e "${red}[Python Style][Error]: Fix the issues and commit again (or commit with --no-verify if you are sure)${NC}" exit 1 fi @@ -28,7 +29,8 @@ then # Run mypy mypy . - if [ $? -ne 0 ]; then + if [ $? 
-ne 0 ] + then echo -e "${red}[Python Type Checks][Error]: Fix the issues and commit again (or commit with --no-verify if you are sure)${NC}" exit 1 fi diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 9160909..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "python.pythonPath": "venv/bin/python", - "python.linting.mypyEnabled": true, - "python.linting.flake8Enabled": true -} diff --git a/README.md b/README.md index 35f2c15..7584d0a 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,18 @@ # amundsengremlin Amundsen Gremlin + +## Instructions to configure venv +Virtual environments for python are convenient for avoiding dependency conflicts. +The `venv` module built into python3 is recommended for ease of use, but any managed virtual environment will do. +If you'd like to set up venv in this repo: +```bash +$ venv_path=[path_for_virtual_environment] +$ python3 -m venv $venv_path +$ source $venv_path/bin/activate +$ pip install -r requirements.txt +``` + +If something goes wrong, you can always: +```bash +$ rm -rf $venv_path +``` \ No newline at end of file diff --git a/amundsen_gremlin/config.py b/amundsen_gremlin/config.py index 6d7c631..cce9c3e 100644 --- a/amundsen_gremlin/config.py +++ b/amundsen_gremlin/config.py @@ -21,6 +21,7 @@ class Config: def neptune_url_for_development(*, user: Optional[str] = None) -> Optional[Mapping[str, Any]]: # Hello! If you get here and and your user is not above, ask one of them to borrow theirs. Or add your username # to development_instance_users in terraform/deployments/development/main.tf and terraform apply + # TODO: add terraform files. One stopgap is to manually set up neptune instances for each dev user. return NEPTUNE_URLS_BY_USER[os.getenv('USER', 'nobody')] diff --git a/amundsen_gremlin/gremlin_model.py b/amundsen_gremlin/gremlin_model.py index 23cd33a..0001332 100644 --- a/amundsen_gremlin/gremlin_model.py +++ b/amundsen_gremlin/gremlin_model.py @@ -183,7 +183,7 @@ class WellKnownProperties(Enum): # TODO: move this someplace shared def _discover_parameters(format_string: str) -> FrozenSet[str]: """ - use this to discover what the parameters to a format string + use this to discover what the parameters are to a format string (e.g. 
what parameters we need for a vertex id) """ parameters: FrozenSet[str] = frozenset() while True: @@ -468,11 +468,9 @@ def expirable(cls: Type["EdgeTypes"]) -> Sequence["EdgeTypes"]: Column = EdgeType.construct_type(label='COLUMN') Database = EdgeType.construct_type(label='DATABASE') Description = EdgeType.construct_type(label='DESCRIPTION') - Grant = EdgeType.construct_type(label='GRANT') Follow = EdgeType.construct_type(label='FOLLOW') Generates = EdgeType.construct_type(label='GENERATES') LastUpdatedAt = EdgeType.construct_type(label='LAST_UPDATED_AT') - Member = EdgeType.construct_type(label='MEMBER') ManagedBy = EdgeType.construct_type(label='MANAGED_BY') Owner = EdgeType.construct_type(label='OWNER') Read = EdgeType.construct_type( @@ -484,7 +482,6 @@ def expirable(cls: Type["EdgeTypes"]) -> Sequence["EdgeTypes"]: Property(name='read_count', type=GremlinType.Long, required=True)]) ReadWrite = EdgeType.construct_type(label='READ_WRITE') ReadOnly = EdgeType.construct_type(label='READ_ONLY') - RequiresAccessTo = EdgeType.construct_type(label='REQUIRES_ACCESS_TO') Schema = EdgeType.construct_type(label='SCHEMA') Source = EdgeType.construct_type(label='SOURCE') Stat = EdgeType.construct_type(label='STAT') diff --git a/amundsen_gremlin/neptune_bulk_loader/api.py b/amundsen_gremlin/neptune_bulk_loader/api.py index fdef834..c5425a3 100644 --- a/amundsen_gremlin/neptune_bulk_loader/api.py +++ b/amundsen_gremlin/neptune_bulk_loader/api.py @@ -206,9 +206,7 @@ class NeptuneBulkLoaderLoadStatus(TypedDict): class BulkLoaderParallelism(Enum): - """ - Literal might be better for this in 3.8? - """ + # TODO: Literal might be better for this in 3.8? LOW = auto() MEDIUM = auto() HIGH = auto() @@ -225,7 +223,7 @@ class BulkLoaderFormat(Enum): N_TRIPLES = 'ntriples' N_QUADS = 'nquads' RDF_XML = 'rdfxml' - TURTLE = 'turtle' # no, this is totaly a coincidence + TURTLE = 'turtle' class NeptuneBulkLoaderApi: @@ -397,7 +395,6 @@ def group_by_class(entities: Mapping[GraphEntityType, Mapping[str, GraphEntity]] vertex_types: Mapping[GraphEntityType, List[GraphEntity]] = defaultdict(list) edge_types: Mapping[GraphEntityType, List[GraphEntity]] = defaultdict(list) for t, es in entities.items(): - # there was a slicker (10 lines shorter) map based implementation in pasta.py but it defeated the type checker if isinstance(t, VertexType): assert not isinstance(t, EdgeType) vertex_types[t].extend(es.values()) From fee660d5dcc674fb0cff15a76b3c7310b927f894 Mon Sep 17 00:00:00 2001 From: Joshua Hoskins Date: Wed, 23 Sep 2020 09:27:56 -0700 Subject: [PATCH 4/4] add global timestamp Signed-off-by: Joshua Hoskins --- .../neptune_bulk_loader/gremlin_model_converter.py | 4 ++++ .../unit/neptune_bulk_loader/test_gremlin_model_converter.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/amundsen_gremlin/neptune_bulk_loader/gremlin_model_converter.py b/amundsen_gremlin/neptune_bulk_loader/gremlin_model_converter.py index 552f7d3..759e02d 100644 --- a/amundsen_gremlin/neptune_bulk_loader/gremlin_model_converter.py +++ b/amundsen_gremlin/neptune_bulk_loader/gremlin_model_converter.py @@ -499,6 +499,10 @@ def table_entities(cls, *, table_data: List[Table], entities: ENTITIES, existing # distinguishes) # update timestamp + # Amundsen global timestamp + cls._create(VertexTypes.Updatedtimestamp, entities, existing, key='amundsen_updated_timestamp', + latest_timestamp=created_at) + # Table-specific timestamp vertex = cls._create(VertexTypes.Updatedtimestamp, entities, existing, key=table_vertex['key'], latest_timestamp=created_at) 
cls._create(EdgeTypes.LastUpdatedAt, entities, existing, created=created_at, **{ diff --git a/tests/unit/neptune_bulk_loader/test_gremlin_model_converter.py b/tests/unit/neptune_bulk_loader/test_gremlin_model_converter.py index 10d3a01..c3af3c3 100644 --- a/tests/unit/neptune_bulk_loader/test_gremlin_model_converter.py +++ b/tests/unit/neptune_bulk_loader/test_gremlin_model_converter.py @@ -117,6 +117,7 @@ def side_effect(*args: Any, existing: EXISTING, **kwargs: Any) -> None: {'created': created_at, '~from': {'~label': 'Table', 'key': 'Snowflake://production.esikmo/floes'}, '~to': {'~label': 'Updatedtimestamp', 'key': 'Snowflake://production.esikmo/floes'}}], VertexTypes.Updatedtimestamp.value: [ + {'key': 'amundsen_updated_timestamp', 'latest_timestamp': created_at}, {'key': 'Snowflake://production.esikmo/igloo', 'latest_timestamp': created_at}, {'key': 'Snowflake://production.esikmo/floes', 'latest_timestamp': created_at}], EdgeTypes.Column.value: [ @@ -280,6 +281,7 @@ def test_duplicates_ok(self, fetch_existing: Any) -> None: '~to': {'~label': 'Updatedtimestamp', 'key': 'Snowflake://production.esikmo/igloo'}}, ], VertexTypes.Updatedtimestamp.value: [ + {'key': 'amundsen_updated_timestamp', 'latest_timestamp': created_at}, {'key': 'Snowflake://production.esikmo/electric-bugaloo', 'latest_timestamp': created_at}, {'key': 'Snowflake://production.esikmo/igloo', 'latest_timestamp': created_at}, ],
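
For reference, a minimal usage sketch of the chunked-consumption helpers exercised in the streams tests above. Only `consume_in_chunks` and its `stream=`, `n=`, and `consumer=` keyword arguments come from `amundsen_gremlin.utils.streams`; the record producer and batch writer below are hypothetical stand-ins, not part of this patch.

```python
# Hypothetical sketch: fetch_records and write_batch are placeholders for
# whatever produces entities and persists each chunk (e.g. bulk-loader CSVs).
from typing import Iterable, Mapping, Sequence

from amundsen_gremlin.utils.streams import consume_in_chunks


def fetch_records() -> Iterable[Mapping[str, str]]:
    # stand-in producer: stream entity records one at a time
    for i in range(10):
        yield {'key': f'table_{i}'}


def write_batch(batch: Sequence[Mapping[str, str]]) -> None:
    # stand-in consumer: called once per chunk of up to n records
    print(f'writing {len(batch)} records')


if __name__ == '__main__':
    # consume_in_chunks returns the total number of items consumed
    total = consume_in_chunks(stream=fetch_records(), n=4, consumer=write_batch)
    print(f'consumed {total} records')
```

As the `test_consume_in_chunks_with_exception` case above shows, chunks assembled before a mid-stream error are still delivered to the consumer and the exception then propagates to the caller.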