From 6e64f472968d6e91e28a20fd5fa23e47c3e773fd Mon Sep 17 00:00:00 2001 From: tmadlener Date: Tue, 23 Jan 2024 11:43:44 +0100 Subject: [PATCH] [format] Apply black formatting to all python sources --- .flake8 | 2 +- doc/conf.py | 16 +- python/podio/__init__.py | 18 +- python/podio/base_reader.py | 140 +-- python/podio/base_writer.py | 73 +- python/podio/frame.py | 615 +++++------ python/podio/frame_iterator.py | 94 +- python/podio/reading.py | 113 +- python/podio/root_io.py | 115 +- python/podio/sio_io.py | 66 +- python/podio/test_Frame.py | 309 +++--- python/podio/test_Reader.py | 192 ++-- python/podio/test_ReaderRoot.py | 24 +- python/podio/test_ReaderSio.py | 30 +- python/podio/test_utils.py | 22 +- python/podio_class_generator.py | 263 +++-- python/podio_gen/cpp_generator.py | 988 ++++++++++-------- python/podio_gen/generator_base.py | 487 ++++----- python/podio_gen/generator_utils.py | 528 +++++----- python/podio_gen/julia_generator.py | 267 ++--- python/podio_gen/podio_config_reader.py | 987 +++++++++-------- .../test_ClassDefinitionValidator.py | 928 +++++++++------- python/podio_gen/test_DataModelJSONEncoder.py | 108 +- python/podio_gen/test_MemberParser.py | 591 ++++++----- python/podio_schema_evolution.py | 725 +++++++------ tests/write_frame.py | 82 +- tools/podio-dump | 296 +++--- tools/podio-ttree-to-rntuple | 27 +- tools/podio-vis | 203 ++-- 29 files changed, 4485 insertions(+), 3824 deletions(-) diff --git a/.flake8 b/.flake8 index 705c2dc40..58db8675a 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] -max-line-length = 88 +max-line-length = 99 extend-ignore = E203 per-file-ignores = diff --git a/doc/conf.py b/doc/conf.py index c123ccd4b..97710f354 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -49,7 +49,7 @@ "myst_parser", "breathe", "sphinx_copybutton", - ] +] source_suffix = {".rst": "restructuredtext", ".md": "markdown"} @@ -89,7 +89,7 @@ "collapse_navigation": False, "navigation_depth": 4, "prev_next_buttons_location": None, # no navigation buttons - } +} # -- Doxygen integration with Breathe ----------------------------------------- @@ -110,9 +110,7 @@ os.makedirs("_build/cpp", exist_ok=True) -subprocess.check_call( - ["doxygen", "Doxyfile"], stdout=subprocess.PIPE, cwd=doc_dir, env=env - ) +subprocess.check_call(["doxygen", "Doxyfile"], stdout=subprocess.PIPE, cwd=doc_dir, env=env) cpp_api_index_target = doc_dir / "cpp_api/api.rst" @@ -122,10 +120,10 @@ stdout=subprocess.PIPE, cwd=doc_dir, env=env, - ) +) if not cpp_api_index_target.exists(): - shutil.copyfile(doc_dir / "cpp_api.rst", cpp_api_index_target) + shutil.copyfile(doc_dir / "cpp_api.rst", cpp_api_index_target) print("Done with c++ API doc generation") @@ -148,7 +146,7 @@ "../python", "../*/*test_*.py", # exclude tests "../python/podio_version.py", # exclude convenience module - ] - ) + ] +) print("Done with python API doc generation") diff --git a/python/podio/__init__.py b/python/podio/__init__.py index 4c68b7a58..893758510 100644 --- a/python/podio/__init__.py +++ b/python/podio/__init__.py @@ -4,21 +4,21 @@ # Try to load podio, this is equivalent to trying to load libpodio.so and will # error if libpodio.so is not found but work if it's found try: - from ROOT import podio # noqa: F401 + from ROOT import podio # noqa: F401 except ImportError: - print("Unable to load podio, make sure that libpodio.so is in LD_LIBRARY_PATH") - raise + print("Unable to load podio, make sure that libpodio.so is in LD_LIBRARY_PATH") + raise from .frame import Frame from . 
import root_io, reading

 try:
-  # We try to import the sio bindings which may fail if ROOT is not able to
-  # load the dictionary. In this case they have most likely not been built and
-  # we just move on
-  from . import sio_io
+    # We try to import the sio bindings which may fail if ROOT is not able to
+    # load the dictionary. In this case they have most likely not been built and
+    # we just move on
+    from . import sio_io
 except ImportError:
-  pass
+    pass

 __all__ = [
     "__version__",
@@ -26,4 +26,4 @@
     "root_io",
     "sio_io",
     "reading",
-  ]
+]
diff --git a/python/podio/base_reader.py b/python/podio/base_reader.py
index 88d3acc3e..45f0ae45a 100644
--- a/python/podio/base_reader.py
+++ b/python/podio/base_reader.py
@@ -7,76 +7,76 @@


 class BaseReaderMixin:
-  """Mixin class that defines the base interface of the readers.
+    """Mixin class that defines the base interface of the readers.

-  The backend specific readers inherit from here and have to initialize the
-  following members:
-    - _reader: The actual reader that is able to read frames
-  """
-
-  def __init__(self):
-    """Initialize common members.
-
-    In inheriting classes this needs to be called **after** the _reader has been
-    set up.
-    """
-    self._categories = tuple(s.data() for s in self._reader.getAvailableCategories())
-    if hasattr(self, '_is_legacy'):
-      self._is_legacy = getattr(self, '_is_legacy')
-    else:
-      self._is_legacy = False  # by default assume we are not legacy
-
-  @property
-  def categories(self):
-    """Get the available categories from this reader.
-
-    Returns:
-      tuple(str): The names of the available categories from this reader
-    """
-    return self._categories
-
-  def get(self, category):
-    """Get an iterator with access functionality for a given category.
-
-    Args:
-      category (str): The name of the desired category
-
-    Returns:
-      FrameCategoryIterator: The iterator granting access to all Frames of the
-        desired category
-    """
-    return FrameCategoryIterator(self._reader, category)
-
-  @property
-  def is_legacy(self):
-    """Whether this is a legacy file reader or not.
-
-    Returns:
-      bool: True if this is a legacy file reader
-    """
-    return self._is_legacy
-
-  @property
-  def datamodel_definitions(self):
-    """Get the available datamodel definitions from this reader.
-
-    Returns:
-      tuple(str): The names of the available datamodel definitions
-    """
-    if self._is_legacy:
-      return ()
-    return tuple(n.c_str() for n in self._reader.getAvailableDatamodels())
-
-  def get_datamodel_definition(self, edm_name):
-    """Get the datamodel definition as JSON string.
-    Args:
-      str: The name of the datamodel
-
-    Returns:
-      str: The complete model definition in JSON format. Use, e.g. json.loads
-        to convert it into a python dictionary.
-    """
-    if self._is_legacy:
-      return ""
-    return self._reader.getDatamodelDefinition(edm_name).data()
+    The backend specific readers inherit from here and have to initialize the
+    following members:
+      - _reader: The actual reader that is able to read frames
+    """
+
+    def __init__(self):
+        """Initialize common members.
+
+        In inheriting classes this needs to be called **after** the _reader has been
+        set up.
+        """
+        self._categories = tuple(s.data() for s in self._reader.getAvailableCategories())
+        if hasattr(self, "_is_legacy"):
+            self._is_legacy = getattr(self, "_is_legacy")
+        else:
+            self._is_legacy = False  # by default assume we are not legacy
+
+    @property
+    def categories(self):
+        """Get the available categories from this reader. 
+ + Returns: + tuple(str): The names of the available categories from this reader + """ + return self._categories + + def get(self, category): + """Get an iterator with access functionality for a given category. + + Args: + category (str): The name of the desired category + + Returns: + FrameCategoryIterator: The iterator granting access to all Frames of the + desired category + """ + return FrameCategoryIterator(self._reader, category) + + @property + def is_legacy(self): + """Whether this is a legacy file reader or not. + + Returns: + bool: True if this is a legacy file reader + """ + return self._is_legacy + + @property + def datamodel_definitions(self): + """Get the available datamodel definitions from this reader. + + Returns: + tuple(str): The names of the available datamodel definitions + """ + if self._is_legacy: + return () + return tuple(n.c_str() for n in self._reader.getAvailableDatamodels()) + + def get_datamodel_definition(self, edm_name): + """Get the datamodel definition as JSON string. + + Args: + str: The name of the datamodel + + Returns: + str: The complete model definition in JSON format. Use, e.g. json.loads + to convert it into a python dictionary. + """ + if self._is_legacy: + return "" + return self._reader.getDatamodelDefinition(edm_name).data() diff --git a/python/podio/base_writer.py b/python/podio/base_writer.py index 116066907..7b49ca4b5 100644 --- a/python/podio/base_writer.py +++ b/python/podio/base_writer.py @@ -6,22 +6,23 @@ class AllWriters: - """Class to manage all writers in the program - so that they can be properly finished at the end of the program - """ - writers = [] + """Class to manage all writers in the program + so that they can be properly finished at the end of the program + """ + + writers = [] - def add(self, writer): - """Add a writer to the list of managed writers""" - self.writers.append(writer) + def add(self, writer): + """Add a writer to the list of managed writers""" + self.writers.append(writer) - def finish(self): - """Finish all managed writers""" - for writer in self.writers: - try: - writer._writer.finish() # pylint: disable=protected-access - except AttributeError: - pass + def finish(self): + """Finish all managed writers""" + for writer in self.writers: + try: + writer._writer.finish() # pylint: disable=protected-access + except AttributeError: + pass _all_writers = AllWriters() @@ -29,27 +30,29 @@ def finish(self): class BaseWriterMixin: - """Mixin class that defines the base interface of the writers. - - The backend specific writers inherit from here and have to initialize the - following members: + """Mixin class that defines the base interface of the writers. - - _writer: The actual writer that is able to write frames - """ + The backend specific writers inherit from here and have to initialize the + following members: - def __init__(self): - """Initialize the writer""" - _all_writers.add(self) - - def write_frame(self, frame, category, collections=None): - """Write the given frame under the passed category, optionally limiting the - collections that are written. - - Args: - frame (podio.frame.Frame): The Frame to write - category (str): The category name - collections (optional, default=None): The subset of collections to - write. 
If None, all collections are written + - _writer: The actual writer that is able to write frames """ - # pylint: disable-next=protected-access - self._writer.writeFrame(frame._frame, category, collections or frame.getAvailableCollections()) + + def __init__(self): + """Initialize the writer""" + _all_writers.add(self) + + def write_frame(self, frame, category, collections=None): + """Write the given frame under the passed category, optionally limiting the + collections that are written. + + Args: + frame (podio.frame.Frame): The Frame to write + category (str): The category name + collections (optional, default=None): The subset of collections to + write. If None, all collections are written + """ + # pylint: disable=protected-access + self._writer.writeFrame( + frame._frame, category, collections or frame.getAvailableCollections() + ) diff --git a/python/podio/frame.py b/python/podio/frame.py index 3886fe28a..094611c82 100644 --- a/python/podio/frame.py +++ b/python/podio/frame.py @@ -10,332 +10,343 @@ # # We check whether we can actually load the header to not break python bindings # in environments with *ancient* podio versions -if ROOT.gInterpreter.LoadFile('podio/Frame.h') == 0: # noqa: E402 - from ROOT import podio # noqa: E402 # pylint: disable=wrong-import-position - _FRAME_HEADER_AVAILABLE = True +if ROOT.gInterpreter.LoadFile("podio/Frame.h") == 0: # noqa: E402 + from ROOT import podio # noqa: E402 # pylint: disable=wrong-import-position + + _FRAME_HEADER_AVAILABLE = True else: - _FRAME_HEADER_AVAILABLE = False + _FRAME_HEADER_AVAILABLE = False def _determine_supported_parameter_types(): - """Determine the supported types for the parameters. - - Returns: - tuple(tuple(str, str)): the tuple with the string representation of all - c++ and their corresponding python types that are supported - """ - types_tuple = podio.SupportedGenericDataTypes() - n_types = cppyy.gbl.std.tuple_size[podio.SupportedGenericDataTypes].value - - # Get the python types with the help of cppyy and the STL - py_types = [type(cppyy.gbl.std.get[i](types_tuple)).__name__ for i in range(n_types)] + """Determine the supported types for the parameters. - def _determine_cpp_type(idx_and_type): - """Determine the actual c++ type from the python type name. - - Mainly maps 'str' to 'std::string', and also determines whether a python - 'float' is actually a 'double' or a 'float' in c++. The latter is necessary - since python only has float (corresponding to double in c++) and we - need the exact c++ type + Returns: + tuple(tuple(str, str)): the tuple with the string representation of all + c++ and their corresponding python types that are supported """ - idx, typename = idx_and_type - if typename == 'float': - cpp_type = cppyy.gbl.std.tuple_element[idx, podio.SupportedGenericDataTypes].type - if cppyy.typeid(cpp_type).name() == 'd': - return 'double' - return 'float' - if typename == 'str': - return 'std::string' - return typename - - cpp_types = list(map(_determine_cpp_type, enumerate(py_types))) - return tuple(zip(cpp_types, py_types)) + types_tuple = podio.SupportedGenericDataTypes() + n_types = cppyy.gbl.std.tuple_size[podio.SupportedGenericDataTypes].value + + # Get the python types with the help of cppyy and the STL + py_types = [type(cppyy.gbl.std.get[i](types_tuple)).__name__ for i in range(n_types)] + + def _determine_cpp_type(idx_and_type): + """Determine the actual c++ type from the python type name. 
+ + Mainly maps 'str' to 'std::string', and also determines whether a python + 'float' is actually a 'double' or a 'float' in c++. The latter is necessary + since python only has float (corresponding to double in c++) and we + need the exact c++ type + """ + idx, typename = idx_and_type + if typename == "float": + cpp_type = cppyy.gbl.std.tuple_element[idx, podio.SupportedGenericDataTypes].type + if cppyy.typeid(cpp_type).name() == "d": + return "double" + return "float" + if typename == "str": + return "std::string" + return typename + + cpp_types = list(map(_determine_cpp_type, enumerate(py_types))) + return tuple(zip(cpp_types, py_types)) if _FRAME_HEADER_AVAILABLE: - SUPPORTED_PARAMETER_TYPES = _determine_supported_parameter_types() + SUPPORTED_PARAMETER_TYPES = _determine_supported_parameter_types() def _get_cpp_types(type_str): - """Get all possible c++ types from the passed py_type string.""" - types = list(filter(lambda t: type_str in t, SUPPORTED_PARAMETER_TYPES)) - if not types: - raise ValueError(f'{type_str} cannot be mapped to a valid parameter type') + """Get all possible c++ types from the passed py_type string.""" + types = list(filter(lambda t: type_str in t, SUPPORTED_PARAMETER_TYPES)) + if not types: + raise ValueError(f"{type_str} cannot be mapped to a valid parameter type") - return types + return types def _get_cpp_vector_types(type_str): - """Get the possible std::vector from the passed py_type string.""" - # Gather a list of all types that match the type_str (c++ or python) - types = _get_cpp_types(type_str) - return [f'std::vector<{t}>' for t in map(lambda x: x[0], types)] + """Get the possible std::vector from the passed py_type string.""" + # Gather a list of all types that match the type_str (c++ or python) + types = _get_cpp_types(type_str) + return [f"std::vector<{t}>" for t in map(lambda x: x[0], types)] def _is_collection_base(thing): - """Check whether the passed thing is a podio::CollectionBase - - Args: - thing (any): any object - - Returns: - bool: True if thing is a base of podio::CollectionBase, False otherwise - """ - # Make sure to only instantiate the template with things that cppyy - # understands - if "cppyy" in repr(thing): - return cppyy.gbl.std.is_base_of[cppyy.gbl.podio.CollectionBase, type(thing)].value - return False - - -class Frame: - """Frame class that serves as a container of collection and meta data.""" - - # cppyy implicitly converts empty collections to False in boolean contexts. To - # distinguish between empty and non-existant collection create a nullptr here - # with the correct type that we can compare against - _coll_nullptr = cppyy.bind_object(cppyy.nullptr, 'podio::CollectionBase') - - def __init__(self, data=None): - """Create a Frame. + """Check whether the passed thing is a podio::CollectionBase Args: - data (FrameData, optional): Almost arbitrary FrameData, e.g. from file - """ - # Explicitly check for None here, to not return empty Frames on nullptr data - if data is not None: - self._frame = podio.Frame(data) - else: - self._frame = podio.Frame() - - self._param_key_types = self._get_param_keys_types() - - def getAvailableCollections(self): - """Get the currently available collection (names) from this Frame. + thing (any): any object Returns: - tuple(str): The names of the available collections from this Frame. 
+ bool: True if thing is a base of podio::CollectionBase, False otherwise """ - return tuple(str(s) for s in self._frame.getAvailableCollections()) + # Make sure to only instantiate the template with things that cppyy + # understands + if "cppyy" in repr(thing): + return cppyy.gbl.std.is_base_of[cppyy.gbl.podio.CollectionBase, type(thing)].value + return False - @property - def collections(self): - """Get the currently available collection (names) from this Frame. - Returns: - tuple(str): The names of the available collections from this Frame. - """ - warnings.warn('WARNING: collections is deprecated, use getAvailableCollections()' - ' (like in C++) instead', FutureWarning) - return self.getAvailableCollections() - - def get(self, name): - """Get a collection from the Frame by name. - - Args: - name (str): The name of the desired collection - - Returns: - collection (podio.CollectionBase): The collection stored in the Frame - - Raises: - KeyError: If the collection with the name is not available - """ - collection = self._frame.get(name) - if collection == self._coll_nullptr: - raise KeyError(f"Collection '{name}' is not available") - return collection - - def put(self, collection, name): - """Put the collection into the frame - - The passed collection is "moved" into the Frame, i.e. it cannot be used any - longer after a call to this function. This also means that only objects that - were in the collection at the time of calling this function will be - available afterwards. - - Args: - collection (podio.CollectionBase): The collection to put into the Frame - name (str): The name of the collection - - Returns: - podio.CollectionBase: The reference to the collection that has been put - into the Frame. NOTE: That mutating this collection is not allowed. - - Raises: - ValueError: If collection is not actually a podio.CollectionBase - """ - if not _is_collection_base(collection): - raise ValueError("Can only put podio collections into a Frame") - return self._frame.put(cppyy.gbl.std.move(collection), name) - - @property - def parameters(self): - """Get the currently available parameter names from this Frame. - - Returns: - tuple (str): The names of the available parameters from this Frame. - """ - return tuple(self._param_key_types.keys()) - - def get_parameter(self, name, as_type=None): - """Get the parameter stored under the given name. - - Args: - name (str): The name of the parameter - as_type (str, optional): Type specifier to disambiguate between - parameters with the same name but different types. If there is only - one parameter with a given name, this argument is ignored - - Returns: - int, float, str or list of those: The value of the stored parameter - - Raises: - KeyError: If no parameter is stored under the given name - ValueError: If there are multiple parameters with the same name, but - multiple types and no type specifier to disambiguate between them - has been passed. 
- """ - def _get_param_value(par_type, name): - par_value = self._frame.getParameter[par_type](name) - if len(par_value) == 1: - return par_value[0] - return list(par_value) - - # This access already raises the KeyError if there is no such parameter - par_type = self._param_key_types[name] - # Exactly one parameter, nothing more to do here - if len(par_type) == 1: - return _get_param_value(par_type[0], name) - - if as_type is None: - raise ValueError(f'{name} parameter has {len(par_type)} different types available, ' - 'but no as_type argument to disambiguate') - - # Get all possible c++ vector types and see if we can unambiguously map them - # to the available types for this parameter - vec_types = _get_cpp_vector_types(as_type) - vec_types = [t for t in vec_types if t in par_type] - if len(vec_types) == 0: - raise ValueError(f'{name} parameter is not available as type {as_type}') - - if len(vec_types) > 1: - raise ValueError(f'{name} parameter cannot be unambiguously mapped to a c++ type with ' - f'{as_type=}. Consider passing in the c++ type instead of the python type') - - return _get_param_value(vec_types[0], name) - - def put_parameter(self, key, value, as_type=None): - """Put a parameter into the Frame. - - Puts a parameter into the Frame after doing some (incomplete) type checks. - If a list is passed the parameter type is determined from looking at the - first element of the list only. Additionally, since python doesn't - differentiate between floats and doubles, floats will always be stored as - doubles by default, use the as_type argument to change this if necessary. - - Args: - key (str): The name of the parameter - value (int, float, str or list of these): The parameter value - as_type (str, optional): Explicitly specify the type that should be used - to put the parameter into the Frame. Python types (e.g. "str") will - be converted to c++ types. This will override any automatic type - deduction that happens otherwise. Note that this will be taken at - pretty much face-value and there are only limited checks for this. - - Raises: - ValueError: If a non-supported parameter type is passed - """ - # For lists we determine the c++ vector type and use that to call the - # correct template overload explicitly - if isinstance(value, (list, tuple)): - type_name = as_type or type(value[0]).__name__ - vec_types = _get_cpp_vector_types(type_name) - if len(vec_types) == 0: - raise ValueError(f"Cannot put a parameter of type {type_name} into a Frame") - - par_type = vec_types[0] - if isinstance(value[0], float): - # Always store floats as doubles from the python side - par_type = par_type.replace("float", "double") - - self._frame.putParameter[par_type](key, value) - else: - if as_type is not None: - cpp_types = _get_cpp_types(as_type) - if len(cpp_types) == 0: - raise ValueError(f"Cannot put a parameter of type {as_type} into a Frame") - self._frame.putParameter[cpp_types[0]](key, value) - - # If we have a single integer, a std::string overload kicks in with higher - # priority than the template for some reason. So we explicitly select the - # correct template here - elif isinstance(value, int): - self._frame.putParameter["int"](key, value) - else: - self._frame.putParameter(key, value) - - self._param_key_types = self._get_param_keys_types() # refresh the cache - - def get_parameters(self): - """Get the complete podio::GenericParameters object stored in this Frame. 
- - NOTE: This is mainly intended for dumping things, for actually obtaining - parameters please use get_parameter - - Returns: - podio.GenericParameters: The stored generic parameters - """ - # Going via the not entirely intended way here - return self._frame.getParameters() - - def get_param_info(self, name): - """Get the parameter type information stored under the given name. - - Args: - name (str): The parameter name - - Returns: - dict (str: int): The c++-type(s) of the stored parameter and the number of - parameters - - Raise: - KeyError: If no parameter is stored under the given name - """ - # This raises the KeyError if the name is not present - par_types = [t.replace('std::vector<', '').replace('>', '') for t in self._param_key_types[name]] - # Assume that we have one parameter and update the dictionary below in case - # there are more - par_infos = {t: 1 for t in par_types} - for par_type in par_types: - par_value = self.get_parameter(name, as_type=par_type) - if isinstance(par_value, list): - par_infos[par_type] = len(par_value) - - return par_infos - - def _get_param_keys_types(self): - """Initialize the param keys dict for easier lookup of the available parameters. - - Returns: - dict: A dictionary mapping each key to the corresponding c++ type - """ - params = self._frame.getParameters() - keys_dict = {} - for par_type, _ in SUPPORTED_PARAMETER_TYPES: - keys = params.getKeys[par_type]() - for key in keys: - # Make sure to convert to a python string here to not have a dangling - # reference here for the key. - key = str(key) - # In order to support the use case of having the same key for multiple - # types create a list of available types for the key, so that we can - # disambiguate later. Storing a vector here, and check later how - # many elements there actually are to decide whether to return a single - # value or a list - if key not in keys_dict: - keys_dict[key] = [f'std::vector<{par_type}>'] +class Frame: + """Frame class that serves as a container of collection and meta data.""" + + # cppyy implicitly converts empty collections to False in boolean contexts. To + # distinguish between empty and non-existant collection create a nullptr here + # with the correct type that we can compare against + _coll_nullptr = cppyy.bind_object(cppyy.nullptr, "podio::CollectionBase") + + def __init__(self, data=None): + """Create a Frame. + + Args: + data (FrameData, optional): Almost arbitrary FrameData, e.g. from file + """ + # Explicitly check for None here, to not return empty Frames on nullptr data + if data is not None: + self._frame = podio.Frame(data) else: - keys_dict[key].append(f'std::vector<{par_type}>') - - return keys_dict + self._frame = podio.Frame() + + self._param_key_types = self._get_param_keys_types() + + def getAvailableCollections(self): + """Get the currently available collection (names) from this Frame. + + Returns: + tuple(str): The names of the available collections from this Frame. + """ + return tuple(str(s) for s in self._frame.getAvailableCollections()) + + @property + def collections(self): + """Get the currently available collection (names) from this Frame. + + Returns: + tuple(str): The names of the available collections from this Frame. + """ + warnings.warn( + "WARNING: collections is deprecated, use getAvailableCollections()" + " (like in C++) instead", + FutureWarning, + ) + return self.getAvailableCollections() + + def get(self, name): + """Get a collection from the Frame by name. 
+ + Args: + name (str): The name of the desired collection + + Returns: + collection (podio.CollectionBase): The collection stored in the Frame + + Raises: + KeyError: If the collection with the name is not available + """ + collection = self._frame.get(name) + if collection == self._coll_nullptr: + raise KeyError(f"Collection '{name}' is not available") + return collection + + def put(self, collection, name): + """Put the collection into the frame + + The passed collection is "moved" into the Frame, i.e. it cannot be used any + longer after a call to this function. This also means that only objects that + were in the collection at the time of calling this function will be + available afterwards. + + Args: + collection (podio.CollectionBase): The collection to put into the Frame + name (str): The name of the collection + + Returns: + podio.CollectionBase: The reference to the collection that has been put + into the Frame. NOTE: That mutating this collection is not allowed. + + Raises: + ValueError: If collection is not actually a podio.CollectionBase + """ + if not _is_collection_base(collection): + raise ValueError("Can only put podio collections into a Frame") + return self._frame.put(cppyy.gbl.std.move(collection), name) + + @property + def parameters(self): + """Get the currently available parameter names from this Frame. + + Returns: + tuple (str): The names of the available parameters from this Frame. + """ + return tuple(self._param_key_types.keys()) + + def get_parameter(self, name, as_type=None): + """Get the parameter stored under the given name. + + Args: + name (str): The name of the parameter + as_type (str, optional): Type specifier to disambiguate between + parameters with the same name but different types. If there is only + one parameter with a given name, this argument is ignored + + Returns: + int, float, str or list of those: The value of the stored parameter + + Raises: + KeyError: If no parameter is stored under the given name + ValueError: If there are multiple parameters with the same name, but + multiple types and no type specifier to disambiguate between them + has been passed. + """ + + def _get_param_value(par_type, name): + par_value = self._frame.getParameter[par_type](name) + if len(par_value) == 1: + return par_value[0] + return list(par_value) + + # This access already raises the KeyError if there is no such parameter + par_type = self._param_key_types[name] + # Exactly one parameter, nothing more to do here + if len(par_type) == 1: + return _get_param_value(par_type[0], name) + + if as_type is None: + raise ValueError( + f"{name} parameter has {len(par_type)} different types available, " + "but no as_type argument to disambiguate" + ) + + # Get all possible c++ vector types and see if we can unambiguously map them + # to the available types for this parameter + vec_types = _get_cpp_vector_types(as_type) + vec_types = [t for t in vec_types if t in par_type] + if len(vec_types) == 0: + raise ValueError(f"{name} parameter is not available as type {as_type}") + + if len(vec_types) > 1: + raise ValueError( + f"{name} parameter cannot be unambiguously mapped to a c++ type with " + f"{as_type=}. Consider passing in the c++ type instead of the python type" + ) + + return _get_param_value(vec_types[0], name) + + def put_parameter(self, key, value, as_type=None): + """Put a parameter into the Frame. + + Puts a parameter into the Frame after doing some (incomplete) type checks. 
+ If a list is passed the parameter type is determined from looking at the + first element of the list only. Additionally, since python doesn't + differentiate between floats and doubles, floats will always be stored as + doubles by default, use the as_type argument to change this if necessary. + + Args: + key (str): The name of the parameter + value (int, float, str or list of these): The parameter value + as_type (str, optional): Explicitly specify the type that should be used + to put the parameter into the Frame. Python types (e.g. "str") will + be converted to c++ types. This will override any automatic type + deduction that happens otherwise. Note that this will be taken at + pretty much face-value and there are only limited checks for this. + + Raises: + ValueError: If a non-supported parameter type is passed + """ + # For lists we determine the c++ vector type and use that to call the + # correct template overload explicitly + if isinstance(value, (list, tuple)): + type_name = as_type or type(value[0]).__name__ + vec_types = _get_cpp_vector_types(type_name) + if len(vec_types) == 0: + raise ValueError(f"Cannot put a parameter of type {type_name} into a Frame") + + par_type = vec_types[0] + if isinstance(value[0], float): + # Always store floats as doubles from the python side + par_type = par_type.replace("float", "double") + + self._frame.putParameter[par_type](key, value) + else: + if as_type is not None: + cpp_types = _get_cpp_types(as_type) + if len(cpp_types) == 0: + raise ValueError(f"Cannot put a parameter of type {as_type} into a Frame") + self._frame.putParameter[cpp_types[0]](key, value) + + # If we have a single integer, a std::string overload kicks in with higher + # priority than the template for some reason. So we explicitly select the + # correct template here + elif isinstance(value, int): + self._frame.putParameter["int"](key, value) + else: + self._frame.putParameter(key, value) + + self._param_key_types = self._get_param_keys_types() # refresh the cache + + def get_parameters(self): + """Get the complete podio::GenericParameters object stored in this Frame. + + NOTE: This is mainly intended for dumping things, for actually obtaining + parameters please use get_parameter + + Returns: + podio.GenericParameters: The stored generic parameters + """ + # Going via the not entirely intended way here + return self._frame.getParameters() + + def get_param_info(self, name): + """Get the parameter type information stored under the given name. + + Args: + name (str): The parameter name + + Returns: + dict (str: int): The c++-type(s) of the stored parameter and the number of + parameters + + Raise: + KeyError: If no parameter is stored under the given name + """ + # This raises the KeyError if the name is not present + par_types = [ + t.replace("std::vector<", "").replace(">", "") for t in self._param_key_types[name] + ] + # Assume that we have one parameter and update the dictionary below in case + # there are more + par_infos = {t: 1 for t in par_types} + for par_type in par_types: + par_value = self.get_parameter(name, as_type=par_type) + if isinstance(par_value, list): + par_infos[par_type] = len(par_value) + + return par_infos + + def _get_param_keys_types(self): + """Initialize the param keys dict for easier lookup of the available parameters. 
+ + Returns: + dict: A dictionary mapping each key to the corresponding c++ type + """ + params = self._frame.getParameters() + keys_dict = {} + for par_type, _ in SUPPORTED_PARAMETER_TYPES: + keys = params.getKeys[par_type]() + for key in keys: + # Make sure to convert to a python string here to not have a dangling + # reference here for the key. + key = str(key) + # In order to support the use case of having the same key for multiple + # types create a list of available types for the key, so that we can + # disambiguate later. Storing a vector here, and check later how + # many elements there actually are to decide whether to return a single + # value or a list + if key not in keys_dict: + keys_dict[key] = [f"std::vector<{par_type}>"] + else: + keys_dict[key].append(f"std::vector<{par_type}>") + + return keys_dict diff --git a/python/podio/frame_iterator.py b/python/podio/frame_iterator.py index d82ad2d70..1fa97d828 100644 --- a/python/podio/frame_iterator.py +++ b/python/podio/frame_iterator.py @@ -7,59 +7,63 @@ class FrameCategoryIterator: - """Iterator for iterating over all Frames of a given category available from a - reader as well as accessing specific entries - """ + """Iterator for iterating over all Frames of a given category available from a + reader as well as accessing specific entries + """ - def __init__(self, reader, category): - """Construct the iterator from the reader and the category. + def __init__(self, reader, category): + """Construct the iterator from the reader and the category. - Args: - reader (Reader): Any podio reader offering access to Frames - category (str): The category name of the Frames to be iterated over - """ - self._reader = reader - self._category = category + Args: + reader (Reader): Any podio reader offering access to Frames + category (str): The category name of the Frames to be iterated over + """ + self._reader = reader + self._category = category - def __iter__(self): - """The trivial implementation for the iterator protocol.""" - return self + def __iter__(self): + """The trivial implementation for the iterator protocol.""" + return self - def __next__(self): - """Get the next available Frame or stop.""" - frame_data = self._reader.readNextEntry(self._category) - if frame_data: - return Frame(std.move(frame_data)) + def __next__(self): + """Get the next available Frame or stop.""" + frame_data = self._reader.readNextEntry(self._category) + if frame_data: + return Frame(std.move(frame_data)) - raise StopIteration + raise StopIteration - def __len__(self): - """Get the number of available Frames for the passed category.""" - return self._reader.getEntries(self._category) + def __len__(self): + """Get the number of available Frames for the passed category.""" + return self._reader.getEntries(self._category) - def __getitem__(self, entry): - """Get a specific entry. + def __getitem__(self, entry): + """Get a specific entry. 
- Args: - entry (int): The entry to access - """ - # Handle python negative indexing to start from the end - if entry < 0: - entry = self._reader.getEntries(self._category) + entry + Args: + entry (int): The entry to access + """ + # Handle python negative indexing to start from the end + if entry < 0: + entry = self._reader.getEntries(self._category) + entry + + if entry < 0: + # If we are below 0 now, we do not have enough entries to serve the request + raise IndexError - if entry < 0: - # If we are below 0 now, we do not have enough entries to serve the request - raise IndexError + try: + frame_data = self._reader.readEntry(self._category, entry) + except std.bad_function_call: + print( + "Error: Unable to read an entry of the input file. This can " + "happen when the ROOT model dictionaries are not in " + "LD_LIBRARY_PATH. Make sure that LD_LIBRARY_PATH points to the " + "library folder of the installation of podio and also to the" + "library folder with your data model\n" + ) + raise - try: - frame_data = self._reader.readEntry(self._category, entry) - except std.bad_function_call: - print('Error: Unable to read an entry of the input file. This can happen when the ' - 'ROOT model dictionaries are not in LD_LIBRARY_PATH. Make sure that LD_LIBRARY_PATH ' - 'points to the library folder of the installation of podio and also to the library ' - 'folder with your data model\n') - raise - if frame_data: - return Frame(std.move(frame_data)) + if frame_data: + return Frame(std.move(frame_data)) - raise IndexError + raise IndexError diff --git a/python/podio/reading.py b/python/podio/reading.py index 357e37151..0ae43facf 100644 --- a/python/podio/reading.py +++ b/python/podio/reading.py @@ -4,73 +4,78 @@ from ROOT import TFile from podio import root_io + try: - from podio import sio_io + from podio import sio_io - def _is_frame_sio_file(filename): - """Peek into the sio file to determine whether this is a legacy file or not.""" - with open(filename, 'rb') as sio_file: - first_line = str(sio_file.readline()) - # The SIO Frame writer writes a podio_header_info at the beginning of the - # file - return first_line.find('podio_header_info') > 0 + def _is_frame_sio_file(filename): + """Peek into the sio file to determine whether this is a legacy file or not.""" + with open(filename, "rb") as sio_file: + first_line = str(sio_file.readline()) + # The SIO Frame writer writes a podio_header_info at the beginning of the + # file + return first_line.find("podio_header_info") > 0 except ImportError: - def _is_frame_sio_file(filename): - """Stub raising a ValueError""" - raise ValueError('podio has not been built with SIO support, ' - 'which is necessary to read this file, ' - 'or there is a version mismatch') + + def _is_frame_sio_file(filename): + """Stub raising a ValueError""" + raise ValueError( + "podio has not been built with SIO support, " + "which is necessary to read this file, " + "or there is a version mismatch" + ) class RootFileFormat: - """Enum to specify the ROOT file format""" - TTREE = 0 # Non-legacy TTree based file - RNTUPLE = 1 # RNTuple based file - LEGACY = 2 # Legacy TTree based file + """Enum to specify the ROOT file format""" + + TTREE = 0 # Non-legacy TTree based file + RNTUPLE = 1 # RNTuple based file + LEGACY = 2 # Legacy TTree based file def _determine_root_format(filename): - """Peek into the root file to determine which flavor we have at hand.""" - file = TFile.Open(filename) + """Peek into the root file to determine which flavor we have at hand.""" + file = 
TFile.Open(filename) - metadata = file.Get("podio_metadata") - if not metadata: - return RootFileFormat.LEGACY + metadata = file.Get("podio_metadata") + if not metadata: + return RootFileFormat.LEGACY - md_class = metadata.IsA().GetName() - if "TTree" in md_class: - return RootFileFormat.TTREE + md_class = metadata.IsA().GetName() + if "TTree" in md_class: + return RootFileFormat.TTREE - return RootFileFormat.RNTUPLE + return RootFileFormat.RNTUPLE def get_reader(filename): - """Get an appropriate reader for the passed file. - - Args: - filename (str): The input file - - Returns: - root_io.[Legacy]Reader, sio_io.[Legacy]Reader: an initialized reader that - is able to process the input file. - - Raises: - ValueError: If the file cannot be recognized, or if podio has not been - built with the necessary backend I/O support - """ - if filename.endswith('.sio'): - if _is_frame_sio_file(filename): - return sio_io.Reader(filename) - return sio_io.LegacyReader(filename) - - if filename.endswith('.root'): - root_flavor = _determine_root_format(filename) - if root_flavor == RootFileFormat.TTREE: - return root_io.Reader(filename) - if root_flavor == RootFileFormat.RNTUPLE: - return root_io.RNTupleReader(filename) - if root_flavor == RootFileFormat.LEGACY: - return root_io.LegacyReader(filename) - - raise ValueError('file must end on .root or .sio') + """Get an appropriate reader for the passed file. + + Args: + filename (str): The input file + + Returns: + root_io.[Legacy]Reader, sio_io.[Legacy]Reader: an initialized reader that + is able to process the input file. + + Raises: + ValueError: If the file cannot be recognized, or if podio has not been + built with the necessary backend I/O support + """ + if filename.endswith(".sio"): + if _is_frame_sio_file(filename): + return sio_io.Reader(filename) + return sio_io.LegacyReader(filename) + + if filename.endswith(".root"): + root_flavor = _determine_root_format(filename) + if root_flavor == RootFileFormat.TTREE: + return root_io.Reader(filename) + if root_flavor == RootFileFormat.RNTUPLE: + return root_io.RNTupleReader(filename) + if root_flavor == RootFileFormat.LEGACY: + return root_io.LegacyReader(filename) + + raise ValueError("file must end on .root or .sio") diff --git a/python/podio/root_io.py b/python/podio/root_io.py index 0fcc75c83..51416427d 100644 --- a/python/podio/root_io.py +++ b/python/podio/root_io.py @@ -2,91 +2,94 @@ """Python module for reading root files containing podio Frames""" from ROOT import gSystem -gSystem.Load('libpodioRootIO') # noqa: E402 + +gSystem.Load("libpodioRootIO") # noqa: E402 from ROOT import podio # noqa: E402 # pylint: disable=wrong-import-position -from podio.base_reader import BaseReaderMixin # pylint: disable=wrong-import-position -from podio.base_writer import BaseWriterMixin # pylint: disable=wrong-import-position +from podio.base_reader import BaseReaderMixin # pylint: disable=wrong-import-position # noqa: E402 +from podio.base_writer import BaseWriterMixin # pylint: disable=wrong-import-position # noqa: E402 class Reader(BaseReaderMixin): - """Reader class for reading podio root files.""" + """Reader class for reading podio root files.""" - def __init__(self, filenames): - """Create a reader that reads from the passed file(s). + def __init__(self, filenames): + """Create a reader that reads from the passed file(s). 
- Args: - filenames (str or list[str]): file(s) to open and read data from - """ - if isinstance(filenames, str): - filenames = (filenames,) + Args: + filenames (str or list[str]): file(s) to open and read data from + """ + if isinstance(filenames, str): + filenames = (filenames,) - self._reader = podio.ROOTFrameReader() - self._reader.openFiles(filenames) + self._reader = podio.ROOTFrameReader() + self._reader.openFiles(filenames) - super().__init__() + super().__init__() class RNTupleReader(BaseReaderMixin): - """Reader class for reading podio RNTuple root files.""" + """Reader class for reading podio RNTuple root files.""" - def __init__(self, filenames): - """Create an RNTuple reader that reads from the passed file(s). + def __init__(self, filenames): + """Create an RNTuple reader that reads from the passed file(s). - Args: - filenames (str or list[str]): file(s) to open and read data from - """ - if isinstance(filenames, str): - filenames = (filenames,) + Args: + filenames (str or list[str]): file(s) to open and read data from + """ + if isinstance(filenames, str): + filenames = (filenames,) - self._reader = podio.RNTupleReader() - self._reader.openFiles(filenames) + self._reader = podio.RNTupleReader() + self._reader.openFiles(filenames) - super().__init__() + super().__init__() class LegacyReader(BaseReaderMixin): - """Reader class for reading legacy podio root files. + """Reader class for reading legacy podio root files. - This reader can be used to read files that have not yet been written using - Frame based I/O into Frames for a more seamless transition. - """ + This reader can be used to read files that have not yet been written using + Frame based I/O into Frames for a more seamless transition. + """ - def __init__(self, filenames): - """Create a reader that reads from the passed file(s). + def __init__(self, filenames): + """Create a reader that reads from the passed file(s). 
- Args: - filenames (str or list[str]): file(s) to open and read data from - """ - if isinstance(filenames, str): - filenames = (filenames,) + Args: + filenames (str or list[str]): file(s) to open and read data from + """ + if isinstance(filenames, str): + filenames = (filenames,) - self._reader = podio.ROOTLegacyReader() - self._reader.openFiles(filenames) - self._is_legacy = True + self._reader = podio.ROOTLegacyReader() + self._reader.openFiles(filenames) + self._is_legacy = True - super().__init__() + super().__init__() class Writer(BaseWriterMixin): - """Writer class for writing podio root files""" - def __init__(self, filename): - """Create a writer for writing files + """Writer class for writing podio root files""" - Args: - filename (str): The name of the output file - """ - self._writer = podio.ROOTFrameWriter(filename) - super().__init__() + def __init__(self, filename): + """Create a writer for writing files + + Args: + filename (str): The name of the output file + """ + self._writer = podio.ROOTFrameWriter(filename) + super().__init__() class RNTupleWriter(BaseWriterMixin): - """Writer class for writing podio root files""" - def __init__(self, filename): - """Create a writer for writing files + """Writer class for writing podio root files""" - Args: - filename (str): The name of the output file - """ - self._writer = podio.RNTupleWriter(filename) - super().__init__() + def __init__(self, filename): + """Create a writer for writing files + + Args: + filename (str): The name of the output file + """ + self._writer = podio.RNTupleWriter(filename) + super().__init__() diff --git a/python/podio/sio_io.py b/python/podio/sio_io.py index c0caf63ae..f876b0475 100644 --- a/python/podio/sio_io.py +++ b/python/podio/sio_io.py @@ -2,10 +2,11 @@ """Python module for reading sio files containing podio Frames""" from ROOT import gSystem + if gSystem.DynamicPathName("libpodioSioIO.so", True): - gSystem.Load('libpodioSioIO') # noqa: 402 + gSystem.Load("libpodioSioIO") # noqa: 402 else: - raise ImportError('Error when importing libpodioSioIO') + raise ImportError("Error when importing libpodioSioIO") from ROOT import podio # noqa: 402 # pylint: disable=wrong-import-position from podio.base_reader import BaseReaderMixin # pylint: disable=wrong-import-position @@ -13,47 +14,48 @@ class Reader(BaseReaderMixin): - """Reader class for reading podio SIO files.""" + """Reader class for reading podio SIO files.""" - def __init__(self, filename): - """Create a reader that reads from the passed file. + def __init__(self, filename): + """Create a reader that reads from the passed file. - Args: - filename (str): File to open and read data from - """ - self._reader = podio.SIOFrameReader() - self._reader.openFile(filename) + Args: + filename (str): File to open and read data from + """ + self._reader = podio.SIOFrameReader() + self._reader.openFile(filename) - super().__init__() + super().__init__() class LegacyReader(BaseReaderMixin): - """Reader class for reading legacy podio sio files. + """Reader class for reading legacy podio sio files. - This reader can be used to read files that have not yet been written using the - Frame based I/O into Frames for a more seamless transition. - """ + This reader can be used to read files that have not yet been written using the + Frame based I/O into Frames for a more seamless transition. + """ - def __init__(self, filename): - """Create a reader that reads from the passed file. + def __init__(self, filename): + """Create a reader that reads from the passed file. 
- Args: - filename (str): File to open and read data from - """ - self._reader = podio.SIOLegacyReader() - self._reader.openFile(filename) - self._is_legacy = True + Args: + filename (str): File to open and read data from + """ + self._reader = podio.SIOLegacyReader() + self._reader.openFile(filename) + self._is_legacy = True - super().__init__() + super().__init__() class Writer(BaseWriterMixin): - """Writer class for writing podio root files""" - def __init__(self, filename): - """Create a writer for writing files + """Writer class for writing podio root files""" - Args: - filename (str): The name of the output file - """ - self._writer = podio.SIOFrameWriter(filename) - super().__init__() + def __init__(self, filename): + """Create a writer for writing files + + Args: + filename (str): The name of the output file + """ + self._writer = podio.SIOFrameWriter(filename) + super().__init__() diff --git a/python/podio/test_Frame.py b/python/podio/test_Frame.py index e9a3c1709..08b98baea 100644 --- a/python/podio/test_Frame.py +++ b/python/podio/test_Frame.py @@ -7,160 +7,191 @@ from ROOT import ExampleHitCollection from podio.frame import Frame + # using root_io as that should always be present regardless of which backends are built from podio.root_io import Reader # The expected collections in each frame EXPECTED_COLL_NAMES = { - 'arrays', 'WithVectorMember', 'info', 'fixedWidthInts', 'mcparticles', - 'moreMCs', 'mcParticleRefs', 'hits', 'hitRefs', 'clusters', 'refs', 'refs2', - 'OneRelation', 'userInts', 'userDoubles', 'WithNamespaceMember', - 'WithNamespaceRelation', 'WithNamespaceRelationCopy', - 'emptyCollection', 'emptySubsetColl' - } + "arrays", + "WithVectorMember", + "info", + "fixedWidthInts", + "mcparticles", + "moreMCs", + "mcParticleRefs", + "hits", + "hitRefs", + "clusters", + "refs", + "refs2", + "OneRelation", + "userInts", + "userDoubles", + "WithNamespaceMember", + "WithNamespaceRelation", + "WithNamespaceRelationCopy", + "emptyCollection", + "emptySubsetColl", +} # The expected collections from the extension (only present in the other_events category) EXPECTED_EXTENSION_COLL_NAMES = { - "extension_Contained", "extension_ExternalComponent", "extension_ExternalRelation", - "VectorMemberSubsetColl" - } + "extension_Contained", + "extension_ExternalComponent", + "extension_ExternalRelation", + "VectorMemberSubsetColl", +} # The expected parameter names in each frame -EXPECTED_PARAM_NAMES = {'anInt', 'UserEventWeight', 'UserEventName', 'SomeVectorData', 'SomeValue'} +EXPECTED_PARAM_NAMES = { + "anInt", + "UserEventWeight", + "UserEventName", + "SomeVectorData", + "SomeValue", +} class FrameTest(unittest.TestCase): - """General unittests for for python bindings of the Frame""" - def test_frame_invalid_access(self): - """Check that the advertised exceptions are raised on invalid access.""" - # Create an empty Frame here - frame = Frame() - with self.assertRaises(KeyError): - _ = frame.get('NonExistantCollection') - - with self.assertRaises(KeyError): - _ = frame.get_parameter('NonExistantParameter') - - with self.assertRaises(ValueError): - collection = [1, 2, 4] - _ = frame.put(collection, "invalid_collection_type") - - def test_frame_put_collection(self): - """Check that putting a collection works as expected""" - frame = Frame() - self.assertEqual(frame.getAvailableCollections(), tuple()) - - hits = ExampleHitCollection() - hits.create() - hits2 = frame.put(hits, "hits_from_python") - self.assertEqual(frame.getAvailableCollections(), tuple(["hits_from_python"])) - # The 
original collection is gone at this point, and ideally just leaves an - # empty shell - self.assertEqual(len(hits), 0) - # On the other hand the return value of put has the original content - self.assertEqual(len(hits2), 1) - - def test_frame_put_parameters(self): - """Check that putting a parameter works as expected""" - frame = Frame() - self.assertEqual(frame.parameters, tuple()) - - frame.put_parameter("a_string_param", "a string") - self.assertEqual(frame.parameters, tuple(["a_string_param"])) - self.assertEqual(frame.get_parameter("a_string_param"), "a string") - - frame.put_parameter("float_param", 3.14) - self.assertEqual(frame.get_parameter("float_param"), 3.14) - - frame.put_parameter("int", 42) - self.assertEqual(frame.get_parameter("int"), 42) - - frame.put_parameter("string_vec", ["a", "b", "cd"]) - str_vec = frame.get_parameter("string_vec") - self.assertEqual(len(str_vec), 3) - self.assertEqual(str_vec, ["a", "b", "cd"]) - - frame.put_parameter("more_ints", [1, 2345]) - int_vec = frame.get_parameter("more_ints") - self.assertEqual(len(int_vec), 2) - self.assertEqual(int_vec, [1, 2345]) - - frame.put_parameter("float_vec", [1.23, 4.56, 7.89]) - vec = frame.get_parameter("float_vec", as_type="double") - self.assertEqual(len(vec), 3) - self.assertEqual(vec, [1.23, 4.56, 7.89]) - - frame.put_parameter("real_float_vec", [1.23, 4.56, 7.89], as_type="float") - f_vec = frame.get_parameter("real_float_vec", as_type="float") - self.assertEqual(len(f_vec), 3) - self.assertEqual(vec, [1.23, 4.56, 7.89]) - - frame.put_parameter("float_as_float", 3.14, as_type="float") - self.assertAlmostEqual(frame.get_parameter("float_as_float"), 3.14, places=5) + """General unittests for for python bindings of the Frame""" + + def test_frame_invalid_access(self): + """Check that the advertised exceptions are raised on invalid access.""" + # Create an empty Frame here + frame = Frame() + with self.assertRaises(KeyError): + _ = frame.get("NonExistantCollection") + + with self.assertRaises(KeyError): + _ = frame.get_parameter("NonExistantParameter") + + with self.assertRaises(ValueError): + collection = [1, 2, 4] + _ = frame.put(collection, "invalid_collection_type") + + def test_frame_put_collection(self): + """Check that putting a collection works as expected""" + frame = Frame() + self.assertEqual(frame.getAvailableCollections(), tuple()) + + hits = ExampleHitCollection() + hits.create() + hits2 = frame.put(hits, "hits_from_python") + self.assertEqual(frame.getAvailableCollections(), tuple(["hits_from_python"])) + # The original collection is gone at this point, and ideally just leaves an + # empty shell + self.assertEqual(len(hits), 0) + # On the other hand the return value of put has the original content + self.assertEqual(len(hits2), 1) + + def test_frame_put_parameters(self): + """Check that putting a parameter works as expected""" + frame = Frame() + self.assertEqual(frame.parameters, tuple()) + + frame.put_parameter("a_string_param", "a string") + self.assertEqual(frame.parameters, tuple(["a_string_param"])) + self.assertEqual(frame.get_parameter("a_string_param"), "a string") + + frame.put_parameter("float_param", 3.14) + self.assertEqual(frame.get_parameter("float_param"), 3.14) + + frame.put_parameter("int", 42) + self.assertEqual(frame.get_parameter("int"), 42) + + frame.put_parameter("string_vec", ["a", "b", "cd"]) + str_vec = frame.get_parameter("string_vec") + self.assertEqual(len(str_vec), 3) + self.assertEqual(str_vec, ["a", "b", "cd"]) + + frame.put_parameter("more_ints", [1, 2345]) + 
int_vec = frame.get_parameter("more_ints") + self.assertEqual(len(int_vec), 2) + self.assertEqual(int_vec, [1, 2345]) + + frame.put_parameter("float_vec", [1.23, 4.56, 7.89]) + vec = frame.get_parameter("float_vec", as_type="double") + self.assertEqual(len(vec), 3) + self.assertEqual(vec, [1.23, 4.56, 7.89]) + + frame.put_parameter("real_float_vec", [1.23, 4.56, 7.89], as_type="float") + f_vec = frame.get_parameter("real_float_vec", as_type="float") + self.assertEqual(len(f_vec), 3) + self.assertEqual(vec, [1.23, 4.56, 7.89]) + + frame.put_parameter("float_as_float", 3.14, as_type="float") + self.assertAlmostEqual(frame.get_parameter("float_as_float"), 3.14, places=5) class FrameReadTest(unittest.TestCase): - """Unit tests for the Frame python bindings for Frames read from file. - - NOTE: The assumption is that the Frame has been written by tests/write_frame.h - """ - def setUp(self): - """Open the file and read in the first frame internally. + """Unit tests for the Frame python bindings for Frames read from file. - Reading only one event/Frame of each category here as looping and other - basic checks are already handled by the Reader tests + NOTE: The assumption is that the Frame has been written by tests/write_frame.h """ - reader = Reader('root_io/example_frame.root') - self.event = reader.get('events')[0] - self.other_event = reader.get('other_events')[7] - - def test_frame_collections(self): - """Check that all expected collections are available.""" - self.assertEqual(set(self.event.getAvailableCollections()), EXPECTED_COLL_NAMES) - self.assertEqual(set(self.other_event.getAvailableCollections()), - EXPECTED_COLL_NAMES.union(EXPECTED_EXTENSION_COLL_NAMES)) - - # Not going over all collections here, as that should all be covered by the - # c++ test cases; Simply picking a few and doing some basic tests - mc_particles = self.event.get('mcparticles') - self.assertEqual(mc_particles.getValueTypeName().data(), 'ExampleMC') - self.assertEqual(len(mc_particles), 10) - self.assertEqual(len(mc_particles[0].daughters()), 4) - - mc_particle_refs = self.event.get('mcParticleRefs') - self.assertTrue(mc_particle_refs.isSubsetCollection()) - self.assertEqual(len(mc_particle_refs), 10) - - fixed_w_ints = self.event.get('fixedWidthInts') - self.assertEqual(len(fixed_w_ints), 3) - # Python has no concept of fixed width integers... - max_vals = fixed_w_ints[0] - self.assertEqual(max_vals.fixedInteger64(), 2**63 - 1) - self.assertEqual(max_vals.fixedU64(), 2**64 - 1) - - def test_frame_parameters(self): - """Check that all expected parameters are available.""" - self.assertEqual(set(self.event.parameters), EXPECTED_PARAM_NAMES) - self.assertEqual(set(self.other_event.parameters), EXPECTED_PARAM_NAMES) - - self.assertEqual(self.event.get_parameter('anInt'), 42) - self.assertEqual(self.other_event.get_parameter('anInt'), 42 + 107) - - self.assertEqual(self.event.get_parameter('UserEventWeight'), 0) - self.assertEqual(self.other_event.get_parameter('UserEventWeight'), 100. 
* 107) - - self.assertEqual(self.event.get_parameter('UserEventName'), ' event_number_0') - self.assertEqual(self.other_event.get_parameter('UserEventName'), ' event_number_107') - - with self.assertRaises(ValueError): - # Parameter name is available with multiple types - _ = self.event.get_parameter('SomeVectorData') - - with self.assertRaises(ValueError): - # Parameter not available as float (only int and string) - _ = self.event.get_parameter('SomeValue', as_type='float') - - self.assertEqual(self.event.get_parameter('SomeVectorData', as_type='int'), [1, 2, 3, 4]) - self.assertEqual(self.event.get_parameter('SomeVectorData', as_type='str'), ["just", "some", "strings"]) - # as_type='float' will also retrieve double values (if the name is unambiguous) - self.assertEqual(self.event.get_parameter('SomeVectorData', as_type='float'), [0.0, 0.0]) + + def setUp(self): + """Open the file and read in the first frame internally. + + Reading only one event/Frame of each category here as looping and other + basic checks are already handled by the Reader tests + """ + reader = Reader("root_io/example_frame.root") + self.event = reader.get("events")[0] + self.other_event = reader.get("other_events")[7] + + def test_frame_collections(self): + """Check that all expected collections are available.""" + self.assertEqual(set(self.event.getAvailableCollections()), EXPECTED_COLL_NAMES) + self.assertEqual( + set(self.other_event.getAvailableCollections()), + EXPECTED_COLL_NAMES.union(EXPECTED_EXTENSION_COLL_NAMES), + ) + + # Not going over all collections here, as that should all be covered by the + # c++ test cases; Simply picking a few and doing some basic tests + mc_particles = self.event.get("mcparticles") + self.assertEqual(mc_particles.getValueTypeName().data(), "ExampleMC") + self.assertEqual(len(mc_particles), 10) + self.assertEqual(len(mc_particles[0].daughters()), 4) + + mc_particle_refs = self.event.get("mcParticleRefs") + self.assertTrue(mc_particle_refs.isSubsetCollection()) + self.assertEqual(len(mc_particle_refs), 10) + + fixed_w_ints = self.event.get("fixedWidthInts") + self.assertEqual(len(fixed_w_ints), 3) + # Python has no concept of fixed width integers... 
+ max_vals = fixed_w_ints[0] + self.assertEqual(max_vals.fixedInteger64(), 2**63 - 1) + self.assertEqual(max_vals.fixedU64(), 2**64 - 1) + + def test_frame_parameters(self): + """Check that all expected parameters are available.""" + self.assertEqual(set(self.event.parameters), EXPECTED_PARAM_NAMES) + self.assertEqual(set(self.other_event.parameters), EXPECTED_PARAM_NAMES) + + self.assertEqual(self.event.get_parameter("anInt"), 42) + self.assertEqual(self.other_event.get_parameter("anInt"), 42 + 107) + + self.assertEqual(self.event.get_parameter("UserEventWeight"), 0) + self.assertEqual(self.other_event.get_parameter("UserEventWeight"), 100.0 * 107) + + self.assertEqual(self.event.get_parameter("UserEventName"), " event_number_0") + self.assertEqual(self.other_event.get_parameter("UserEventName"), " event_number_107") + + with self.assertRaises(ValueError): + # Parameter name is available with multiple types + _ = self.event.get_parameter("SomeVectorData") + + with self.assertRaises(ValueError): + # Parameter not available as float (only int and string) + _ = self.event.get_parameter("SomeValue", as_type="float") + + self.assertEqual(self.event.get_parameter("SomeVectorData", as_type="int"), [1, 2, 3, 4]) + self.assertEqual( + self.event.get_parameter("SomeVectorData", as_type="str"), + ["just", "some", "strings"], + ) + # as_type='float' will also retrieve double values (if the name is unambiguous) + self.assertEqual(self.event.get_parameter("SomeVectorData", as_type="float"), [0.0, 0.0]) diff --git a/python/podio/test_Reader.py b/python/podio/test_Reader.py index eca8552c8..5671244cc 100644 --- a/python/podio/test_Reader.py +++ b/python/podio/test_Reader.py @@ -3,101 +3,103 @@ class ReaderTestCaseMixin: - """Common unittests for readers. - - Inheriting actual test cases have to inhert from this and unittest.TestCase. - All test cases assume that the files are produced with the tests/write_frame.h - functionaltiy. The following members have to be setup and initialized by the - inheriting test cases: - - reader: a podio reader - """ - def test_categories(self): - """Make sure that the categories are as expected""" - reader_cats = self.reader.categories - self.assertEqual(len(reader_cats), 2) - - for cat in ('events', 'other_events'): - self.assertTrue(cat in reader_cats) - - def test_frame_iterator_valid_category(self): - """Check that the returned iterators returned by Reader.get behave as expected.""" - # NOTE: very basic iterator tests only, content tests are done elsewhere - frames = self.reader.get('other_events') - self.assertEqual(len(frames), 10) - - i = 0 - for frame in self.reader.get('events'): - # Rudimentary check here only to see whether we got the right frame - self.assertEqual(frame.get_parameter('UserEventName'), f' event_number_{i}') - i += 1 - self.assertEqual(i, 10) - - # Out of bound access should not work - with self.assertRaises(IndexError): - _ = frames[10] - with self.assertRaises(IndexError): - _ = frames[-11] - - # Again only rudimentary checks - frame = frames[7] - self.assertEqual(frame.get_parameter('UserEventName'), ' event_number_107') - # Valid negative indexing - frame = frames[-2] - self.assertEqual(frame.get_parameter('UserEventName'), ' event_number_108') - # jumping back again also works - frame = frames[3] - self.assertEqual(frame.get_parameter('UserEventName'), ' event_number_103') - - # Looping starts from where we left, i.e. 
here we have 6 frames left
-    i = 0
-    for _ in frames:
-      i += 1
-    self.assertEqual(i, 6)
-
-  def test_frame_iterator_invalid_category(self):
-    """Make sure non existant Frames are handled gracefully"""
-    non_existant = self.reader.get('non-existant')
-    self.assertEqual(len(non_existant), 0)
-
-    # Indexed access should obviously not work
-    with self.assertRaises(IndexError):
-      _ = non_existant[0]
-
-    # Loops should never be entered
-    i = 0
-    for _ in non_existant:
-      i += 1
-    self.assertEqual(i, 0)
+    """Common unittests for readers.
+
+    Inheriting actual test cases have to inherit from this and unittest.TestCase.
+    All test cases assume that the files are produced with the tests/write_frame.h
+    functionality. The following members have to be set up and initialized by the
+    inheriting test cases:
+    - reader: a podio reader
+    """
+
+    def test_categories(self):
+        """Make sure that the categories are as expected"""
+        reader_cats = self.reader.categories
+        self.assertEqual(len(reader_cats), 2)
+
+        for cat in ("events", "other_events"):
+            self.assertTrue(cat in reader_cats)
+
+    def test_frame_iterator_valid_category(self):
+        """Check that the iterators returned by Reader.get behave as expected."""
+        # NOTE: very basic iterator tests only, content tests are done elsewhere
+        frames = self.reader.get("other_events")
+        self.assertEqual(len(frames), 10)
+
+        i = 0
+        for frame in self.reader.get("events"):
+            # Rudimentary check here only to see whether we got the right frame
+            self.assertEqual(frame.get_parameter("UserEventName"), f" event_number_{i}")
+            i += 1
+        self.assertEqual(i, 10)
+
+        # Out of bounds access should not work
+        with self.assertRaises(IndexError):
+            _ = frames[10]
+        with self.assertRaises(IndexError):
+            _ = frames[-11]
+
+        # Again only rudimentary checks
+        frame = frames[7]
+        self.assertEqual(frame.get_parameter("UserEventName"), " event_number_107")
+        # Valid negative indexing
+        frame = frames[-2]
+        self.assertEqual(frame.get_parameter("UserEventName"), " event_number_108")
+        # jumping back again also works
+        frame = frames[3]
+        self.assertEqual(frame.get_parameter("UserEventName"), " event_number_103")
+
+        # Looping starts from where we left, i.e. here we have 6 frames left
+        i = 0
+        for _ in frames:
+            i += 1
+        self.assertEqual(i, 6)
+
+    def test_frame_iterator_invalid_category(self):
+        """Make sure non-existent Frames are handled gracefully"""
+        non_existant = self.reader.get("non-existant")
+        self.assertEqual(len(non_existant), 0)
+
+        # Indexed access should obviously not work
+        with self.assertRaises(IndexError):
+            _ = non_existant[0]
+
+        # Loops should never be entered
+        i = 0
+        for _ in non_existant:
+            i += 1
+        self.assertEqual(i, 0)


class LegacyReaderTestCaseMixin:
-  """Common test cases for the legacy readers python bindings.
-
-  These tests assume that input files are produced with the write_test.h header
-  and that inheriting test cases inherit from unittes.TestCase as well.
-  Additionally they have to have an initialized reader as a member.
-
-  NOTE: Since the legacy readers also use the BaseReaderMixin, many of the
-  invalid access test cases are already covered by the ReaderTestCaseMixin and
-  here we simply focus on the slightly different happy paths
-  """
-  def test_categories(self):
-    """Make sure the legacy reader returns only one category"""
-    cats = self.reader.categories
-    self.assertEqual(("events",), cats)
-
-  def test_frame_iterator(self):
-    """Make sure the FrameIterator works."""
-    frames = self.reader.get('events')
-    self.assertEqual(len(frames), 2000)
-
-    for i, frame in enumerate(frames):
-      # Rudimentary check here only to see whether we got the right frame
-      self.assertEqual(frame.get_parameter('UserEventName'), f' event_number_{i}')
-      # Only check a few Frames here
-      if i > 10:
-        break
-
-    # Index based access
-    frame = frames[123]
-    self.assertEqual(frame.get_parameter('UserEventName'), ' event_number_123')
+    """Common test cases for the legacy readers' python bindings.
+
+    These tests assume that input files are produced with the write_test.h header
+    and that inheriting test cases inherit from unittest.TestCase as well.
+    Additionally they have to have an initialized reader as a member.
+
+    NOTE: Since the legacy readers also use the BaseReaderMixin, many of the
+    invalid access test cases are already covered by the ReaderTestCaseMixin and
+    here we simply focus on the slightly different happy paths
+    """
+
+    def test_categories(self):
+        """Make sure the legacy reader returns only one category"""
+        cats = self.reader.categories
+        self.assertEqual(("events",), cats)
+
+    def test_frame_iterator(self):
+        """Make sure the FrameIterator works."""
+        frames = self.reader.get("events")
+        self.assertEqual(len(frames), 2000)
+
+        for i, frame in enumerate(frames):
+            # Rudimentary check here only to see whether we got the right frame
+            self.assertEqual(frame.get_parameter("UserEventName"), f" event_number_{i}")
+            # Only check a few Frames here
+            if i > 10:
+                break
+
+        # Index based access
+        frame = frames[123]
+        self.assertEqual(frame.get_parameter("UserEventName"), " event_number_123")
diff --git a/python/podio/test_ReaderRoot.py b/python/podio/test_ReaderRoot.py
index eeeb2256f..bfa8b0b7e 100644
--- a/python/podio/test_ReaderRoot.py
+++ b/python/podio/test_ReaderRoot.py
@@ -3,21 +3,27 @@

import unittest

-from test_Reader import ReaderTestCaseMixin, LegacyReaderTestCaseMixin  # pylint: disable=import-error
+# pylint: disable-next=import-error
+from test_Reader import (
+    ReaderTestCaseMixin,
+    LegacyReaderTestCaseMixin,
+)
from podio.test_utils import get_legacy_input

from podio.root_io import Reader, LegacyReader


class RootReaderTestCase(ReaderTestCaseMixin, unittest.TestCase):
-  """Test cases for root input files"""
-  def setUp(self):
-    """Setup the corresponding reader"""
-    self.reader = Reader('root_io/example_frame.root')
+    """Test cases for root input files"""
+
+    def setUp(self):
+        """Set up the corresponding reader"""
+        self.reader = Reader("root_io/example_frame.root")


class RootLegacyReaderTestCase(LegacyReaderTestCaseMixin, unittest.TestCase):
-  """Test cases for the legacy root input files and reader."""
-  def setUp(self):
-    """Setup a reader, reading from the example files"""
-    self.reader = LegacyReader(get_legacy_input("v00-16-06-example.root"))
+    """Test cases for the legacy root input files and reader."""
+
+    def setUp(self):
+        """Set up a reader, reading from the example files"""
+        self.reader = LegacyReader(get_legacy_input("v00-16-06-example.root"))
diff --git a/python/podio/test_ReaderSio.py
b/python/podio/test_ReaderSio.py
index 5db31a0da..c879dad00 100644
--- a/python/podio/test_ReaderSio.py
+++ b/python/podio/test_ReaderSio.py
@@ -3,23 +3,31 @@

import unittest

-from test_Reader import ReaderTestCaseMixin, LegacyReaderTestCaseMixin  # pylint: disable=import-error
+# pylint: disable-next=import-error
+from test_Reader import (
+    ReaderTestCaseMixin,
+    LegacyReaderTestCaseMixin,
+)
from podio.test_utils import SKIP_SIO_TESTS, get_legacy_input


@unittest.skipIf(SKIP_SIO_TESTS, "no SIO support")
class SioReaderTestCase(ReaderTestCaseMixin, unittest.TestCase):
-  """Test cases for root input files"""
-  def setUp(self):
-    """Setup the corresponding reader"""
-    from podio.sio_io import Reader  # pylint: disable=import-outside-toplevel
-    self.reader = Reader('sio_io/example_frame.sio')
+    """Test cases for SIO input files"""
+
+    def setUp(self):
+        """Set up the corresponding reader"""
+        from podio.sio_io import Reader  # pylint: disable=import-outside-toplevel
+
+        self.reader = Reader("sio_io/example_frame.sio")


@unittest.skipIf(SKIP_SIO_TESTS, "no SIO support")
class SIOLegacyReaderTestCase(LegacyReaderTestCaseMixin, unittest.TestCase):
-  """Test cases for the legacy root input files and reader."""
-  def setUp(self):
-    """Setup a reader, reading from the example files"""
-    from podio.sio_io import LegacyReader  # pylint: disable=import-outside-toplevel
-    self.reader = LegacyReader(get_legacy_input("v00-16-06-example.sio"))
+    """Test cases for the legacy SIO input files and reader."""
+
+    def setUp(self):
+        """Set up a reader, reading from the example files"""
+        from podio.sio_io import LegacyReader  # pylint: disable=import-outside-toplevel
+
+        self.reader = LegacyReader(get_legacy_input("v00-16-06-example.sio"))
diff --git a/python/podio/test_utils.py b/python/podio/test_utils.py
index c69de90a6..cc3f329c1 100644
--- a/python/podio/test_utils.py
+++ b/python/podio/test_utils.py
@@ -7,15 +7,15 @@

def get_legacy_input(filename):
-  """Try to get a legacy input file by name from the ExternalData that is
-  fetched by CMake.
+    """Try to get a legacy input file by name from the ExternalData that is
+    fetched by CMake.

-  Returns either the absolute path to the actual file or an empty string.
-  """
-  try:
-    datafile = os.path.join(os.environ["PODIO_BUILD_BASE"], "tests", "input_files", filename)
-    if os.path.isfile(datafile):
-      return os.path.abspath(datafile)
-  except KeyError:
-    pass
-  return ""
+    Returns either the absolute path to the actual file or an empty string.
+    """
+    try:
+        datafile = os.path.join(os.environ["PODIO_BUILD_BASE"], "tests", "input_files", filename)
+        if os.path.isfile(datafile):
+            return os.path.abspath(datafile)
+    except KeyError:
+        pass
+    return ""
diff --git a/python/podio_class_generator.py b/python/podio_class_generator.py
index c47a3d822..9254ec155 100755
--- a/python/podio_class_generator.py
+++ b/python/podio_class_generator.py
@@ -13,113 +13,176 @@

def has_clang_format():
-  """Check if clang format is available"""
-  try:
-    # This one can raise if -fallback-style is not found
-    out = subprocess.check_output(["clang-format", "-style=file", "-fallback-style=llvm", "--help"],
-                                  stderr=subprocess.STDOUT)
-    # This one doesn't raise
-    out = subprocess.check_output('echo | clang-format -style=file ', stderr=subprocess.STDOUT, shell=True)
-    if b'.clang-format' in out:
-      return False
-    return True
-  except FileNotFoundError:
-    print("ERROR: Cannot find clang-format executable")
-    print("       Please make sure it is in the PATH.")
-    return False
-  except subprocess.CalledProcessError:
-    print('ERROR: At least one argument was not recognized by clang-format')
-    print('       Most likely the version you are using is old')
-    return False
+    """Check if clang-format is available"""
+    try:
+        # This one can raise if -fallback-style is not found
+        out = subprocess.check_output(
+            ["clang-format", "-style=file", "-fallback-style=llvm", "--help"],
+            stderr=subprocess.STDOUT,
+        )
+        # This one doesn't raise
+        out = subprocess.check_output(
+            "echo | clang-format -style=file ", stderr=subprocess.STDOUT, shell=True
+        )
+        if b".clang-format" in out:
+            return False
+        return True
+    except FileNotFoundError:
+        print("ERROR: Cannot find clang-format executable")
+        print("       Please make sure it is in the PATH.")
+        return False
+    except subprocess.CalledProcessError:
+        print("ERROR: At least one argument was not recognized by clang-format")
+        print("       Most likely the version you are using is old")
+        return False


def clang_format_file(content, name):
-  """Formatter function to run clang-format on generate c++ files"""
-  if name.endswith(".jl"):
-    return content
-
-  clang_format = ["clang-format", "-style=file", "-fallback-style=llvm"]
-  with subprocess.Popen(clang_format, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as cfproc:
-    return cfproc.communicate(input=content.encode())[0].decode()
+    """Formatter function to run clang-format on generated c++ files"""
+    if name.endswith(".jl"):
+        return content
+
+    clang_format = ["clang-format", "-style=file", "-fallback-style=llvm"]
+    with subprocess.Popen(clang_format, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as cfproc:
+        return cfproc.communicate(input=content.encode())[0].decode()


def read_upstream_edm(name_path):
-  """Read an upstream EDM yaml definition file to make the types that are defined
-  in that available to the current EDM"""
-  if name_path is None:
-    return None
-
-  try:
-    name, path = name_path.split(':')
-  except ValueError as err:
-    raise argparse.ArgumentTypeError('upstream-edm argument needs to be the upstream package '
-                                     'name and the upstream edm yaml file separated by a colon') from err
-
-  if not os.path.isfile(path):
-    raise argparse.ArgumentTypeError(f'{path} needs to be an EDM yaml file')
-
-  try:
-    return PodioConfigReader.read(path, name)
-  except DefinitionError as err:
-    raise argparse.ArgumentTypeError(f'{path} does not contain a valid datamodel definition') from err
+    """Read an upstream EDM yaml definition file to make the types that are defined
+    in it available to the current EDM"""
+
if name_path is None:
+        return None
+
+    try:
+        name, path = name_path.split(":")
+    except ValueError as err:
+        raise argparse.ArgumentTypeError(
+            "upstream-edm argument needs to be the upstream package "
+            "name and the upstream edm yaml file separated by a colon"
+        ) from err
+
+    if not os.path.isfile(path):
+        raise argparse.ArgumentTypeError(f"{path} needs to be an EDM yaml file")
+
+    try:
+        return PodioConfigReader.read(path, name)
+    except DefinitionError as err:
+        raise argparse.ArgumentTypeError(
+            f"{path} does not contain a valid datamodel definition"
+        ) from err


if __name__ == "__main__":
-  import argparse
-  # pylint: disable=invalid-name # before 2.5.0 pylint is too strict with the naming here
-  parser = argparse.ArgumentParser(description='Given a description yaml file this script generates '
-                                   'the necessary c++ or julia files in the target directory')
-
-  parser.add_argument('description', help='yaml file describing the datamodel')
-  parser.add_argument('targetdir', help='Target directory where the generated data classes will be put. '
-                      'Header files will be put under <targetdir>/<packagename>/*.h. '
-                      'Source files will be put under <targetdir>/src/*.cc. '
-                      'Julia files will be put under <targetdir>/<packagename>/*.jl.')
-  parser.add_argument('packagename', help='Name of the package.')
-  parser.add_argument('iohandlers', choices=['ROOT', 'SIO'], nargs='*',
-                      help='The IO backend specific code that should be generated',
-                      default="ROOT")
-  parser.add_argument('-l', '--lang', choices=['cpp', 'julia'], default='cpp',
-                      help='Specify the programming language (default: cpp)')
-  parser.add_argument('-q', '--quiet', dest='verbose', action='store_false', default=True,
-                      help='Don\'t write a report to screen')
-  parser.add_argument('-d', '--dryrun', action='store_true', default=False,
-                      help='Do not actually write datamodel files')
-  parser.add_argument('-c', '--clangformat', action='store_true', default=False,
-                      help='Apply clang-format when generating code (with -style=file)')
-  parser.add_argument('--upstream-edm',
                      help='Make datatypes of this upstream EDM available to the current'
                      ' EDM. Format is \'<upstream-name>:<upstream.yaml>\'. 
'
-                      'Note that only the code for the current EDM will be generated',
-                      default=None, type=read_upstream_edm)
-  parser.add_argument('--old-description',
-                      help='Provide schema evolution relative to the old yaml file.',
-                      default=None, action='store')
-  parser.add_argument('-e', '--evolution_file', help='yaml file clarifying schema evolutions',
-                      default=None, action='store')
-
-  args = parser.parse_args()
-
-  install_path = args.targetdir
-  project = args.packagename
-
-  for sub_dir in ('src', project):
-    directory = os.path.join(install_path, sub_dir)
-    if not os.path.exists(directory):
-      os.makedirs(directory)
-
-  if args.lang == "julia":
-    gen = JuliaClassGenerator(args.description, args.targetdir, args.packagename,
-                              verbose=args.verbose, dryrun=args.dryrun,
-                              upstream_edm=args.upstream_edm)
-  if args.lang == "cpp":
-    gen = CPPClassGenerator(args.description, args.targetdir, args.packagename, args.iohandlers,
-                            verbose=args.verbose, dryrun=args.dryrun, upstream_edm=args.upstream_edm,
-                            old_description=args.old_description, evolution_file=args.evolution_file)
-
-  if args.clangformat and has_clang_format():
-    gen.formatter_func = clang_format_file
-
-  gen.process()
-
-  # pylint: enable=invalid-name
+    import argparse
+
+    # pylint: disable=invalid-name # before 2.5.0 pylint is too strict with the naming here
+    parser = argparse.ArgumentParser(
+        description="Given a description yaml file this script generates "
+        "the necessary c++ or julia files in the target directory"
+    )
+
+    parser.add_argument("description", help="yaml file describing the datamodel")
+    parser.add_argument(
+        "targetdir",
+        help="Target directory where the generated data classes will be put. "
+        "Header files will be put under <targetdir>/<packagename>/*.h. "
+        "Source files will be put under <targetdir>/src/*.cc. "
+        "Julia files will be put under <targetdir>/<packagename>/*.jl.",
+    )
+    parser.add_argument("packagename", help="Name of the package.")
+    parser.add_argument(
+        "iohandlers",
+        choices=["ROOT", "SIO"],
+        nargs="*",
+        help="The IO backend specific code that should be generated",
+        default="ROOT",
+    )
+    parser.add_argument(
+        "-l",
+        "--lang",
+        choices=["cpp", "julia"],
+        default="cpp",
+        help="Specify the programming language (default: cpp)",
+    )
+    parser.add_argument(
+        "-q",
+        "--quiet",
+        dest="verbose",
+        action="store_false",
+        default=True,
+        help="Don't write a report to screen",
+    )
+    parser.add_argument(
+        "-d",
+        "--dryrun",
+        action="store_true",
+        default=False,
+        help="Do not actually write datamodel files",
+    )
+    parser.add_argument(
+        "-c",
+        "--clangformat",
+        action="store_true",
+        default=False,
+        help="Apply clang-format when generating code (with -style=file)",
+    )
+    parser.add_argument(
+        "--upstream-edm",
+        help="Make datatypes of this upstream EDM available to the current"
+        " EDM. Format is '<upstream-name>:<upstream.yaml>'. 
" + "Note that only the code for the current EDM will be generated", + default=None, + type=read_upstream_edm, + ) + parser.add_argument( + "--old-description", + help="Provide schema evolution relative to the old yaml file.", + default=None, + action="store", + ) + parser.add_argument( + "-e", + "--evolution_file", + help="yaml file clarifying schema evolutions", + default=None, + action="store", + ) + + args = parser.parse_args() + + install_path = args.targetdir + project = args.packagename + + for sub_dir in ("src", project): + directory = os.path.join(install_path, sub_dir) + if not os.path.exists(directory): + os.makedirs(directory) + + if args.lang == "julia": + gen = JuliaClassGenerator( + args.description, + args.targetdir, + args.packagename, + verbose=args.verbose, + dryrun=args.dryrun, + upstream_edm=args.upstream_edm, + ) + if args.lang == "cpp": + gen = CPPClassGenerator( + args.description, + args.targetdir, + args.packagename, + args.iohandlers, + verbose=args.verbose, + dryrun=args.dryrun, + upstream_edm=args.upstream_edm, + old_description=args.old_description, + evolution_file=args.evolution_file, + ) + + if args.clangformat and has_clang_format(): + gen.formatter_func = clang_format_file + + gen.process() + + # pylint: enable=invalid-name diff --git a/python/podio_gen/cpp_generator.py b/python/podio_gen/cpp_generator.py index 82542b151..be762cd59 100644 --- a/python/podio_gen/cpp_generator.py +++ b/python/podio_gen/cpp_generator.py @@ -22,470 +22,544 @@ def replace_component_in_paths(oldname, newname, paths): - """Replace component name by another one in existing paths""" - # strip the namespace - shortoldname = oldname.split("::")[-1] - shortnewname = newname.split("::")[-1] - # and do the replace in place - for index, thePath in enumerate(paths): - if shortoldname in thePath: - newPath = thePath.replace(shortoldname, shortnewname) - paths[index] = newPath + """Replace component name by another one in existing paths""" + # strip the namespace + shortoldname = oldname.split("::")[-1] + shortnewname = newname.split("::")[-1] + # and do the replace in place + for index, thePath in enumerate(paths): + if shortoldname in thePath: + newPath = thePath.replace(shortoldname, shortnewname) + paths[index] = newPath class IncludeFrom(IntEnum): - """Enum to signify if an include is needed and from where it should come""" - NOWHERE = 0 # No include needed - INTERNAL = 1 # include from within the datamodel - EXTERNAL = 2 # include from an upstream datamodel + """Enum to signify if an include is needed and from where it should come""" + + NOWHERE = 0 # No include needed + INTERNAL = 1 # include from within the datamodel + EXTERNAL = 2 # include from an upstream datamodel class CPPClassGenerator(ClassGeneratorBaseMixin): - """The c++ class / code generator for podio""" - def __init__(self, yamlfile, install_dir, package_name, io_handlers, verbose, dryrun, upstream_edm, - old_description, evolution_file): - super().__init__(yamlfile, install_dir, package_name, verbose, dryrun, upstream_edm) - self.io_handlers = io_handlers - - # schema evolution specific code - self.old_yamlfile = old_description - self.evolution_file = evolution_file - self.old_schema_version = None - self.old_schema_version_int = None - self.old_datamodel = None - self.old_datamodels_components = set() - self.old_datamodels_datatypes = set() - self.root_schema_dict = {} # containing the root relevant schema evolution per datatype - # information to update the selection.xml - self.root_schema_component_names = set() - 
self.root_schema_datatype_names = set() - self.root_schema_iorules = set() - - def pre_process(self): - """The necessary specific pre-processing for cpp code generation""" - self._pre_process_schema_evolution() - return {} - - def post_process(self, _): - """Do the cpp specific post processing""" - self._write_edm_def_file() - if "ROOT" in self.io_handlers: - self._prepare_iorules() - self._create_selection_xml() - self._write_cmake_lists_file() - - def do_process_component(self, name, component): - """Handle everything cpp specific after the common processing of a component""" - includes = set() - includes.update(*(m.includes for m in component['Members'])) - for member in component['Members']: - if not (member.is_builtin or member.is_builtin_array): - includes.add(self._build_include(member)) - - includes.update(component.get("ExtraCode", {}).get("includes", "").split('\n')) - - component['includes'] = self._sort_includes(includes) - - self._fill_templates('Component', component) - # Add potentially older schema for schema evolution - # based on ROOT capabilities for now - if name in self.root_schema_dict: - schema_evolutions = self.root_schema_dict[name] - component = deepcopy(component) - for schema_evolution in schema_evolutions: - if isinstance(schema_evolution, RenamedMember): - for member in component['Members']: - if member.name == schema_evolution.member_name_new: - member.name = schema_evolution.member_name_old - component['class'] = DataType(name + self.old_schema_version) - else: - raise NotImplementedError - self._fill_templates('Component', component) - self.root_schema_component_names.add(name + self.old_schema_version) - - return component - - def do_process_datatype(self, name, datatype): - """Do the cpp specific processing of a datatype""" - datatype["includes_data"] = self._get_member_includes(datatype["Members"]) - self._preprocess_for_class(datatype) - self._preprocess_for_obj(datatype) - self._preprocess_for_collection(datatype) - - # ROOT schema evolution preparation - # Compute and prepare the potential schema evolution parts - schema_evolution_datatype = deepcopy(datatype) - needs_schema_evolution = False - for member in schema_evolution_datatype['Members']: - if member.is_array: - if member.array_type in self.root_schema_dict: - needs_schema_evolution = True - replace_component_in_paths(member.array_type, member.array_type + self.old_schema_version, - schema_evolution_datatype['includes_data']) - member.full_type = member.full_type.replace(member.array_type, member.array_type + self.old_schema_version) - member.array_type = member.array_type + self.old_schema_version - - else: - if member.full_type in self.root_schema_dict: - needs_schema_evolution = True - # prepare the ROOT I/O rule - replace_component_in_paths(member.full_type, member.full_type + self.old_schema_version, - schema_evolution_datatype['includes_data']) - member.full_type = member.full_type + self.old_schema_version - member.bare_type = member.bare_type + self.old_schema_version - - if needs_schema_evolution: - print(f" Preparing explicit schema evolution for {name}") - schema_evolution_datatype['class'].bare_type = schema_evolution_datatype['class'].bare_type + self.old_schema_version # noqa - schema_evolution_datatype["old_schema_version"] = self.old_schema_version_int - self._fill_templates('Data', schema_evolution_datatype) - self.root_schema_datatype_names.add(name + self.old_schema_version) - self._fill_templates('Collection', datatype, schema_evolution_datatype) - else: - 
self._fill_templates('Collection', datatype)
-
-    self._fill_templates('Data', datatype)
-    self._fill_templates('Object', datatype)
-    self._fill_templates('MutableObject', datatype)
-    self._fill_templates('Obj', datatype)
-    self._fill_templates('Collection', datatype)
-    self._fill_templates('CollectionData', datatype)
-
-    if 'SIO' in self.io_handlers:
-      self._fill_templates('SIOBlock', datatype)
-
-    return datatype
-
-  def do_process_interface(self, _, interface):
-    """Process an interface definition and generate the necesary code"""
-    interface["include_types"] = [
-        self._build_include(t) for t in interface["Types"]
-    ]
-
-    self._fill_templates("Interface", interface)
-    return interface
-
-  def print_report(self):
-    """Print a summary report about the generated code"""
-    if not self.verbose:
-      return
-    nclasses = 5 * len(self.datamodel.datatypes) + len(self.datamodel.components)
-    text = REPORT_TEXT.format(yamlfile=self.yamlfile,
-                              nclasses=nclasses,
-                              installdir=self.install_dir)
-
-    for summaryline in text.splitlines():
-      print(summaryline)
-    print()
-
-  def _preprocess_for_class(self, datatype):
-    """Do the preprocessing that is necessary for the classes and Mutable classes"""
-    includes = set(datatype['includes_data'])
-    fwd_declarations = {}
-    includes_cc = set()
-
-    for member in datatype["Members"]:
-      if self.expose_pod_members and not member.is_builtin and not member.is_array:
-        member.sub_members = self.datamodel.components[member.full_type]['Members']
-
-    for relation in datatype['OneToOneRelations']:
-      if self._is_interface(relation.full_type):
-        relation.interface_types = self.datamodel.interfaces[relation.full_type]["Types"]
-      if self._needs_include(relation.full_type):
-        if relation.namespace not in fwd_declarations:
-          fwd_declarations[relation.namespace] = []
-        fwd_declarations[relation.namespace].append(relation.bare_type)
-        fwd_declarations[relation.namespace].append('Mutable' + relation.bare_type)
-        includes_cc.add(self._build_include(relation))
-
-    if datatype['VectorMembers'] or datatype['OneToManyRelations']:
-      includes.add('#include <vector>')
-      includes.add('#include "podio/RelationRange.h"')
-
-    for relation in datatype['OneToManyRelations']:
-      if self._is_interface(relation.full_type):
-        relation.interface_types = self.datamodel.interfaces[relation.full_type]["Types"]
-      if self._needs_include(relation.full_type):
-        includes.add(self._build_include(relation))
-
-    for vectormember in datatype['VectorMembers']:
-      if vectormember.full_type in self.datamodel.components:
-        includes.add(self._build_include(vectormember))
-
-    includes.update(datatype.get('ExtraCode', {}).get('includes', '').split('\n'))
-    # TODO: in principle only the mutable classes would need these includes!  # pylint: disable=fixme
-    includes.update(datatype.get('MutableExtraCode', {}).get('includes', '').split('\n'))
-
-    # When we have a relation to the same type we have the header that we are
-    # just generating in the includes.
This would lead to a circular include, so
-    # remove "ourselves" again from the necessary includes
-    try:
-      includes.remove(self._build_include_for_class(datatype['class'].bare_type, IncludeFrom.INTERNAL))
-    except KeyError:
-      pass
-
-    datatype['includes'] = self._sort_includes(includes)
-    datatype['includes_cc'] = self._sort_includes(includes_cc)
-    datatype['forward_declarations'] = fwd_declarations
-
-  def _preprocess_for_obj(self, datatype):
-    """Do the preprocessing that is necessary for the Obj classes"""
-    fwd_declarations = defaultdict(list)
-    includes, includes_cc = set(), set()
-    for relation in datatype['OneToOneRelations']:
-      if relation.full_type != datatype['class'].full_type:
-        fwd_declarations[relation.namespace].append(relation.bare_type)
-        includes_cc.add(self._build_include(relation))
-
-    if datatype['VectorMembers'] or datatype['OneToManyRelations']:
-      includes.add('#include <vector>')
-
-    for relation in datatype['VectorMembers'] + datatype['OneToManyRelations']:
-      if not relation.is_builtin:
-        if relation.full_type == datatype['class'].full_type:
-          includes_cc.add(self._build_include(datatype['class']))
-        else:
-          includes.add(self._build_include(relation))
-
-    datatype['forward_declarations_obj'] = fwd_declarations
-    datatype['includes_obj'] = self._sort_includes(includes)
-    datatype['includes_cc_obj'] = self._sort_includes(includes_cc)
-    non_trivial_type = datatype['VectorMembers'] or datatype['OneToManyRelations'] or datatype['OneToOneRelations']
-    datatype['is_trivial_type'] = not non_trivial_type
-
-  def _preprocess_for_collection(self, datatype):
-    """Do the necessary preprocessing for the collection"""
-    includes_cc, includes = set(), set()
-
-    for relation in datatype['OneToManyRelations'] + datatype['OneToOneRelations']:
-      if datatype['class'].bare_type != relation.bare_type:
-        include_from = self._needs_include(relation.full_type)
-        if self._is_interface(relation.full_type):
-          includes_cc.add(self._build_include_for_class(relation.bare_type, include_from))
-          for int_type in relation.interface_types:
-            includes_cc.add(self._build_include_for_class(int_type.bare_type + 'Collection', include_from))
+    """The c++ class / code generator for podio"""
+
+    def __init__(
+        self,
+        yamlfile,
+        install_dir,
+        package_name,
+        io_handlers,
+        verbose,
+        dryrun,
+        upstream_edm,
+        old_description,
+        evolution_file,
+    ):
+        super().__init__(yamlfile, install_dir, package_name, verbose, dryrun, upstream_edm)
+        self.io_handlers = io_handlers
+
+        # schema evolution specific code
+        self.old_yamlfile = old_description
+        self.evolution_file = evolution_file
+        self.old_schema_version = None
+        self.old_schema_version_int = None
+        self.old_datamodel = None
+        self.old_datamodels_components = set()
+        self.old_datamodels_datatypes = set()
+        self.root_schema_dict = {}  # containing the root relevant schema evolution per datatype
+        # information to update the selection.xml
+        self.root_schema_component_names = set()
+        self.root_schema_datatype_names = set()
+        self.root_schema_iorules = set()
+
+    def pre_process(self):
+        """The necessary specific pre-processing for cpp code generation"""
+        self._pre_process_schema_evolution()
+        return {}
+
+    def post_process(self, _):
+        """Do the cpp specific post processing"""
+        self._write_edm_def_file()
+        if "ROOT" in self.io_handlers:
+            self._prepare_iorules()
+            self._create_selection_xml()
+        self._write_cmake_lists_file()
+
+    def do_process_component(self, name, component):
+        """Handle everything cpp specific after the common processing of a component"""
+        includes
= set() + includes.update(*(m.includes for m in component["Members"])) + for member in component["Members"]: + if not (member.is_builtin or member.is_builtin_array): + includes.add(self._build_include(member)) + + includes.update(component.get("ExtraCode", {}).get("includes", "").split("\n")) + + component["includes"] = self._sort_includes(includes) + + self._fill_templates("Component", component) + # Add potentially older schema for schema evolution + # based on ROOT capabilities for now + if name in self.root_schema_dict: + schema_evolutions = self.root_schema_dict[name] + component = deepcopy(component) + for schema_evolution in schema_evolutions: + if isinstance(schema_evolution, RenamedMember): + for member in component["Members"]: + if member.name == schema_evolution.member_name_new: + member.name = schema_evolution.member_name_old + component["class"] = DataType(name + self.old_schema_version) + else: + raise NotImplementedError + self._fill_templates("Component", component) + self.root_schema_component_names.add(name + self.old_schema_version) + + return component + + def do_process_datatype(self, name, datatype): + """Do the cpp specific processing of a datatype""" + datatype["includes_data"] = self._get_member_includes(datatype["Members"]) + self._preprocess_for_class(datatype) + self._preprocess_for_obj(datatype) + self._preprocess_for_collection(datatype) + + # ROOT schema evolution preparation + # Compute and prepare the potential schema evolution parts + schema_evolution_datatype = deepcopy(datatype) + needs_schema_evolution = False + for member in schema_evolution_datatype["Members"]: + if member.is_array: + if member.array_type in self.root_schema_dict: + needs_schema_evolution = True + replace_component_in_paths( + member.array_type, + member.array_type + self.old_schema_version, + schema_evolution_datatype["includes_data"], + ) + member.full_type = member.full_type.replace( + member.array_type, member.array_type + self.old_schema_version + ) + member.array_type = member.array_type + self.old_schema_version + + else: + if member.full_type in self.root_schema_dict: + needs_schema_evolution = True + # prepare the ROOT I/O rule + replace_component_in_paths( + member.full_type, + member.full_type + self.old_schema_version, + schema_evolution_datatype["includes_data"], + ) + member.full_type = member.full_type + self.old_schema_version + member.bare_type = member.bare_type + self.old_schema_version + + if needs_schema_evolution: + print(f" Preparing explicit schema evolution for {name}") + schema_evolution_datatype["class"].bare_type = ( + schema_evolution_datatype["class"].bare_type + self.old_schema_version + ) # noqa + schema_evolution_datatype["old_schema_version"] = self.old_schema_version_int + self._fill_templates("Data", schema_evolution_datatype) + self.root_schema_datatype_names.add(name + self.old_schema_version) + self._fill_templates("Collection", datatype, schema_evolution_datatype) else: - includes_cc.add(self._build_include_for_class(relation.bare_type + 'Collection', include_from)) - includes.add(self._build_include_for_class(relation.bare_type, include_from)) - - if datatype['VectorMembers']: - includes_cc.add('#include ') - - datatype['includes_coll_cc'] = self._sort_includes(includes_cc) - datatype['includes_coll_data'] = self._sort_includes(includes) - - # the ostream operator needs a bit of help from the python side in the form - # of some pre processing but also in the form of formatting, both are done - # here. 
- # TODO: also handle array members properly. These are currently simply # pylint: disable=fixme - # ignored - header_contents = [] - for member in datatype['Members']: - header = {'name': member.name} - if member.full_type in self.datamodel.components: - comps = [c.name for c in self.datamodel.components[member.full_type]['Members']] - header['components'] = comps - header_contents.append(header) - - def ostream_collection_header(member_header, col_width=12): - """Custom filter for the jinja2 templates to handle the ostream header that is - printed for the collections. Need this custom filter because it is easier - to implement the content dependent width in python than in jinja2. - """ - if not isinstance(member_header, Mapping): - # Assume that we have a string and format it according to the width - return f'{{:>{col_width}}}'.format(member_header) - - components = member_header.get('components', None) - name = member_header['name'] - if components is None: - return f'{{:>{col_width}}}'.format(name) - - n_comps = len(components) - comp_str = f'[ {", ".join(components)}]' - return f'{{:>{col_width * n_comps}}}'.format(name + ' ' + comp_str) - - datatype['ostream_collection_settings'] = { - 'header_contents': header_contents + self._fill_templates("Collection", datatype) + + self._fill_templates("Data", datatype) + self._fill_templates("Object", datatype) + self._fill_templates("MutableObject", datatype) + self._fill_templates("Obj", datatype) + self._fill_templates("Collection", datatype) + self._fill_templates("CollectionData", datatype) + + if "SIO" in self.io_handlers: + self._fill_templates("SIOBlock", datatype) + + return datatype + + def do_process_interface(self, _, interface): + """Process an interface definition and generate the necesary code""" + interface["include_types"] = [self._build_include(t) for t in interface["Types"]] + + self._fill_templates("Interface", interface) + return interface + + def print_report(self): + """Print a summary report about the generated code""" + if not self.verbose: + return + nclasses = 5 * len(self.datamodel.datatypes) + len(self.datamodel.components) + text = REPORT_TEXT.format( + yamlfile=self.yamlfile, nclasses=nclasses, installdir=self.install_dir + ) + + for summaryline in text.splitlines(): + print(summaryline) + print() + + def _preprocess_for_class(self, datatype): + """Do the preprocessing that is necessary for the classes and Mutable classes""" + includes = set(datatype["includes_data"]) + fwd_declarations = {} + includes_cc = set() + + for member in datatype["Members"]: + if self.expose_pod_members and not member.is_builtin and not member.is_array: + member.sub_members = self.datamodel.components[member.full_type]["Members"] + + for relation in datatype["OneToOneRelations"]: + if self._is_interface(relation.full_type): + relation.interface_types = self.datamodel.interfaces[relation.full_type]["Types"] + if self._needs_include(relation.full_type): + if relation.namespace not in fwd_declarations: + fwd_declarations[relation.namespace] = [] + fwd_declarations[relation.namespace].append(relation.bare_type) + fwd_declarations[relation.namespace].append("Mutable" + relation.bare_type) + includes_cc.add(self._build_include(relation)) + + if datatype["VectorMembers"] or datatype["OneToManyRelations"]: + includes.add("#include ") + includes.add('#include "podio/RelationRange.h"') + + for relation in datatype["OneToManyRelations"]: + if self._is_interface(relation.full_type): + relation.interface_types = 
self.datamodel.interfaces[relation.full_type]["Types"]
+            if self._needs_include(relation.full_type):
+                includes.add(self._build_include(relation))
+
+        for vectormember in datatype["VectorMembers"]:
+            if vectormember.full_type in self.datamodel.components:
+                includes.add(self._build_include(vectormember))
+
+        includes.update(datatype.get("ExtraCode", {}).get("includes", "").split("\n"))
+        # TODO: in principle only the mutable classes need these includes!  # pylint: disable=fixme
+        includes.update(datatype.get("MutableExtraCode", {}).get("includes", "").split("\n"))
+
+        # When we have a relation to the same type we have the header that we are
+        # just generating in the includes. This would lead to a circular include, so
+        # remove "ourselves" again from the necessary includes
+        try:
+            includes.remove(
+                self._build_include_for_class(datatype["class"].bare_type, IncludeFrom.INTERNAL)
+            )
+        except KeyError:
+            pass
+
+        datatype["includes"] = self._sort_includes(includes)
+        datatype["includes_cc"] = self._sort_includes(includes_cc)
+        datatype["forward_declarations"] = fwd_declarations
+
+    def _preprocess_for_obj(self, datatype):
+        """Do the preprocessing that is necessary for the Obj classes"""
+        fwd_declarations = defaultdict(list)
+        includes, includes_cc = set(), set()
+        for relation in datatype["OneToOneRelations"]:
+            if relation.full_type != datatype["class"].full_type:
+                fwd_declarations[relation.namespace].append(relation.bare_type)
+                includes_cc.add(self._build_include(relation))
+
+        if datatype["VectorMembers"] or datatype["OneToManyRelations"]:
+            includes.add("#include <vector>")
+
+        for relation in datatype["VectorMembers"] + datatype["OneToManyRelations"]:
+            if not relation.is_builtin:
+                if relation.full_type == datatype["class"].full_type:
+                    includes_cc.add(self._build_include(datatype["class"]))
+                else:
+                    includes.add(self._build_include(relation))
+
+        datatype["forward_declarations_obj"] = fwd_declarations
+        datatype["includes_obj"] = self._sort_includes(includes)
+        datatype["includes_cc_obj"] = self._sort_includes(includes_cc)
+        non_trivial_type = (
+            datatype["VectorMembers"]
+            or datatype["OneToManyRelations"]
+            or datatype["OneToOneRelations"]
+        )
+        datatype["is_trivial_type"] = not non_trivial_type
+
+    def _preprocess_for_collection(self, datatype):
+        """Do the necessary preprocessing for the collection"""
+        includes_cc, includes = set(), set()
+
+        for relation in datatype["OneToManyRelations"] + datatype["OneToOneRelations"]:
+            if datatype["class"].bare_type != relation.bare_type:
+                include_from = self._needs_include(relation.full_type)
+                if self._is_interface(relation.full_type):
+                    includes_cc.add(
+                        self._build_include_for_class(relation.bare_type, include_from)
+                    )
+                    for int_type in relation.interface_types:
+                        includes_cc.add(
+                            self._build_include_for_class(
+                                int_type.bare_type + "Collection", include_from
+                            )
+                        )
+                else:
+                    includes_cc.add(
+                        self._build_include_for_class(
+                            relation.bare_type + "Collection", include_from
+                        )
+                    )
+                    includes.add(self._build_include_for_class(relation.bare_type, include_from))
+
+        if datatype["VectorMembers"]:
+            includes_cc.add("#include <numeric>")
+
+        datatype["includes_coll_cc"] = self._sort_includes(includes_cc)
+        datatype["includes_coll_data"] = self._sort_includes(includes)
+
+        # the ostream operator needs a bit of help from the python side in the form
+        # of some pre processing but also in the form of formatting, both are done
+        # here.
+        # TODO: handle array members properly.
These are currently ignored # pylint: disable=fixme + header_contents = [] + for member in datatype["Members"]: + header = {"name": member.name} + if member.full_type in self.datamodel.components: + comps = [c.name for c in self.datamodel.components[member.full_type]["Members"]] + header["components"] = comps + header_contents.append(header) + + def ostream_collection_header(member_header, col_width=12): + """Custom filter for the jinja2 templates to handle the ostream header that is + printed for the collections. Need this custom filter because it is easier + to implement the content dependent width in python than in jinja2. + """ + if not isinstance(member_header, Mapping): + # Assume that we have a string and format it according to the width + return f"{{:>{col_width}}}".format(member_header) + + components = member_header.get("components", None) + name = member_header["name"] + if components is None: + return f"{{:>{col_width}}}".format(name) + + n_comps = len(components) + comp_str = f'[ {", ".join(components)}]' + return f"{{:>{col_width * n_comps}}}".format(name + " " + comp_str) + + datatype["ostream_collection_settings"] = {"header_contents": header_contents} + # Register the custom filter for it to become available in the templates + self.env.filters["ostream_collection_header"] = ostream_collection_header + + def _pre_process_schema_evolution(self): + """Process the schema evolution""" + # have to make all necessary comparisons + # which are the ones that changed? + # have to extend the selection xml file + if self.old_yamlfile: + comparator = DataModelComparator( + self.yamlfile, self.old_yamlfile, evolution_file=self.evolution_file + ) + comparator.read() + comparator.compare() + self.old_schema_version = f"v{comparator.datamodel_old.schema_version}" + self.old_schema_version_int = comparator.datamodel_old.schema_version + # some sanity checks + if len(comparator.errors) > 0: + print( + f"The given datamodels '{self.yamlfile}' and '{self.old_yamlfile}' \ +have unresolvable schema evolution incompatibilities:" + ) + for error in comparator.errors: + print(error) + sys.exit(-1) + if len(comparator.warnings) > 0: + print( + f"The given datamodels '{self.yamlfile}' and '{self.old_yamlfile}' \ +have resolvable schema evolution incompatibilities:" + ) + for warning in comparator.warnings: + print(warning) + sys.exit(-1) + + # now go through all the io_handlers and see what we have to do + if "ROOT" in self.io_handlers: + for item in root_filter(comparator.schema_changes): + # add whatever is relevant to our ROOT schema evolution + self.root_schema_dict.setdefault(item.klassname, []).append(item) + + def _prepare_iorules(self): + """Prepare the IORules to be put in the Reflex dictionary""" + for type_name, schema_changes in self.root_schema_dict.items(): + for schema_change in schema_changes: + if isinstance(schema_change, RenamedMember): + # find out the type of the renamed member + component = self.datamodel.components[type_name] + for member in component["Members"]: + if member.name == schema_change.member_name_new: + member_type = member.full_type + + iorule = RootIoRule() + iorule.sourceClass = type_name + iorule.targetClass = type_name + iorule.version = self.old_schema_version.lstrip("v") + iorule.source = f"{member_type} {schema_change.member_name_old}" + iorule.target = schema_change.member_name_new + iorule.code = f"{iorule.target} = onfile.{schema_change.member_name_old};" + self.root_schema_iorules.add(iorule) + else: + raise NotImplementedError( + f"Schema evolution for 
{schema_change} not yet implemented." + ) + + def _write_cmake_lists_file(self): + """Write the names of all generated header and src files into cmake lists""" + header_files = (f for f in self.generated_files if f.endswith(".h")) + src_files = (f for f in self.generated_files if f.endswith(".cc")) + xml_files = (f for f in self.generated_files if f.endswith(".xml")) + + def _write_list(name, target_folder, files, comment): + """Write all files into a cmake variable using the target_folder as path to the + file""" + list_cont = [] + + list_cont.append(f"# {comment}") + list_cont.append(f"SET({name}") + for full_file in files: + fname = os.path.basename(full_file) + list_cont.append(f" {os.path.join(target_folder, fname)}") + + list_cont.append(")") + + return "\n".join(list_cont) + + full_contents = ["#-- AUTOMATICALLY GENERATED FILE - DO NOT EDIT -- \n"] + full_contents.append( + _write_list( + "headers", + r"${ARG_OUTPUT_FOLDER}/${datamodel}", + header_files, + "Generated header files", + ) + ) + + full_contents.append( + _write_list( + "sources", + r"${ARG_OUTPUT_FOLDER}/src", + src_files, + "Generated source files", + ) + ) + + full_contents.append( + _write_list( + "selection_xml", + r"${ARG_OUTPUT_FOLDER}/src", + xml_files, + "Generated xml files", + ) + ) + + write_file_if_changed( + f"{self.install_dir}/podio_generated_files.cmake", + "\n".join(full_contents), + self.any_changes, + ) + + def _write_edm_def_file(self): + """Write the edm definition to a compile time string""" + model_encoder = DataModelJSONEncoder() + data = { + "package_name": self.package_name, + "edm_definition": model_encoder.encode(self.datamodel), + "incfolder": self.incfolder, + "schema_version": self.datamodel.schema_version, + "datatypes": self.datamodel.datatypes, } - # Register the custom filter for it to become available in the templates - self.env.filters['ostream_collection_header'] = ostream_collection_header - - def _pre_process_schema_evolution(self): - """Process the schema evolution""" - # have to make all necessary comparisons - # which are the ones that changed? 
- # have to extend the selection xml file - if self.old_yamlfile: - comparator = DataModelComparator(self.yamlfile, self.old_yamlfile, - evolution_file=self.evolution_file) - comparator.read() - comparator.compare() - self.old_schema_version = f"v{comparator.datamodel_old.schema_version}" - self.old_schema_version_int = comparator.datamodel_old.schema_version - # some sanity checks - if len(comparator.errors) > 0: - print(f"The given datamodels '{self.yamlfile}' and '{self.old_yamlfile}' \ -have unresolvable schema evolution incompatibilities:") - for error in comparator.errors: - print(error) - sys.exit(-1) - if len(comparator.warnings) > 0: - print(f"The given datamodels '{self.yamlfile}' and '{self.old_yamlfile}' \ -have resolvable schema evolution incompatibilities:") - for warning in comparator.warnings: - print(warning) - sys.exit(-1) - - # now go through all the io_handlers and see what we have to do - if 'ROOT' in self.io_handlers: - for item in root_filter(comparator.schema_changes): - # add whatever is relevant to our ROOT schema evolution - self.root_schema_dict.setdefault(item.klassname, []).append(item) - - def _prepare_iorules(self): - """Prepare the IORules to be put in the Reflex dictionary""" - for type_name, schema_changes in self.root_schema_dict.items(): - for schema_change in schema_changes: - if isinstance(schema_change, RenamedMember): - # find out the type of the renamed member - component = self.datamodel.components[type_name] - for member in component["Members"]: - if member.name == schema_change.member_name_new: - member_type = member.full_type - - iorule = RootIoRule() - iorule.sourceClass = type_name - iorule.targetClass = type_name - iorule.version = self.old_schema_version.lstrip("v") - iorule.source = f'{member_type} {schema_change.member_name_old}' - iorule.target = schema_change.member_name_new - iorule.code = f'{iorule.target} = onfile.{schema_change.member_name_old};' - self.root_schema_iorules.add(iorule) - else: - raise NotImplementedError(f"Schema evolution for {schema_change} not yet implemented.") - - def _write_cmake_lists_file(self): - """Write the names of all generated header and src files into cmake lists""" - header_files = (f for f in self.generated_files if f.endswith('.h')) - src_files = (f for f in self.generated_files if f.endswith('.cc')) - xml_files = (f for f in self.generated_files if f.endswith('.xml')) - - def _write_list(name, target_folder, files, comment): - """Write all files into a cmake variable using the target_folder as path to the - file""" - list_cont = [] - - list_cont.append(f'# {comment}') - list_cont.append(f'SET({name}') - for full_file in files: - fname = os.path.basename(full_file) - list_cont.append(f' {os.path.join(target_folder, fname)}') - - list_cont.append(')') - - return '\n'.join(list_cont) - - full_contents = ['#-- AUTOMATICALLY GENERATED FILE - DO NOT EDIT -- \n'] - full_contents.append(_write_list('headers', r'${ARG_OUTPUT_FOLDER}/${datamodel}', - header_files, 'Generated header files')) - - full_contents.append(_write_list('sources', r'${ARG_OUTPUT_FOLDER}/src', - src_files, 'Generated source files')) - - full_contents.append(_write_list('selection_xml', r'${ARG_OUTPUT_FOLDER}/src', - xml_files, 'Generated xml files')) - - write_file_if_changed(f'{self.install_dir}/podio_generated_files.cmake', - '\n'.join(full_contents), - self.any_changes) - - def _write_edm_def_file(self): - """Write the edm definition to a compile time string""" - model_encoder = DataModelJSONEncoder() - data = { - 'package_name': 
self.package_name, - 'edm_definition': model_encoder.encode(self.datamodel), - 'incfolder': self.incfolder, - 'schema_version': self.datamodel.schema_version, - 'datatypes': self.datamodel.datatypes, + + def quoted_sv(string): + return f'"{string}"sv' + + self.env.filters["quoted_sv"] = quoted_sv + + self._write_file( + "DatamodelDefinition.h", + self._eval_template("DatamodelDefinition.h.jinja2", data), + ) + + def _create_selection_xml(self): + """Create the selection xml that is necessary for ROOT I/O""" + data = { + "version": self.datamodel.schema_version, + "components": [DataType(c) for c in self.datamodel.components], + "datatypes": [DataType(d) for d in self.datamodel.datatypes], + "old_schema_components": [ + DataType(d) + for d in self.root_schema_datatype_names | self.root_schema_component_names + ], # noqa + "iorules": self.root_schema_iorules, } - def quoted_sv(string): - return f"\"{string}\"sv" - - self.env.filters["quoted_sv"] = quoted_sv - - self._write_file('DatamodelDefinition.h', - self._eval_template('DatamodelDefinition.h.jinja2', data)) - - def _create_selection_xml(self): - """Create the selection xml that is necessary for ROOT I/O""" - data = {'version': self.datamodel.schema_version, - 'components': [DataType(c) for c in self.datamodel.components], - 'datatypes': [DataType(d) for d in self.datamodel.datatypes], - 'old_schema_components': [DataType(d) for d in - self.root_schema_datatype_names | self.root_schema_component_names], # noqa - 'iorules': self.root_schema_iorules} - - self._write_file('selection.xml', self._eval_template('selection.xml.jinja2', data)) - - def _get_member_includes(self, members): - """Process all members and gather the necessary includes""" - includes = set() - includes.update(*(m.includes for m in members)) - for member in members: - if member.is_array and not member.is_builtin_array: - include_from = IncludeFrom.INTERNAL - if self.upstream_edm and member.array_type in self.upstream_edm.components: - include_from = IncludeFrom.EXTERNAL - includes.add(self._build_include_for_class(member.array_bare_type, include_from)) - - includes.add(self._build_include(member)) - - return self._sort_includes(includes) - - def _needs_include(self, classname) -> IncludeFrom: - """Check whether the member needs an include from within the datamodel""" - if classname in self.datamodel.components or \ - classname in self.datamodel.datatypes or \ - classname in self.datamodel.interfaces: - return IncludeFrom.INTERNAL - - if self.upstream_edm: - if classname in self.upstream_edm.components or \ - classname in self.upstream_edm.datatypes or \ - classname in self.upstream_edm.interfaces: - return IncludeFrom.EXTERNAL - - return IncludeFrom.NOWHERE - - def _build_include(self, member): - """Return the include statment for the passed member.""" - return self._build_include_for_class(member.bare_type, self._needs_include(member.full_type)) - - def _build_include_for_class(self, classname, include_from: IncludeFrom) -> str: - """Return the include statement for the passed classname""" - if include_from == IncludeFrom.INTERNAL: - return f'#include "{self.datamodel.options["includeSubfolder"]}{classname}.h"' - if include_from == IncludeFrom.EXTERNAL: - return f'#include "{self.upstream_edm.options["includeSubfolder"]}{classname}.h"' - - # The empty string is filtered by _sort_includes (plus it doesn't hurt in - # the generated code) - return '' - - def _sort_includes(self, includes): - """Sort the includes in order to try to have the std includes at the bottom""" - 
package_includes = sorted(i for i in includes if self.package_name in i) - podio_includes = sorted(i for i in includes if 'podio' in i) - stl_includes = sorted(i for i in includes if '<' in i and '>' in i) - - upstream_includes = [] - if self.upstream_edm: - upstream_includes = sorted(i for i in includes if self.upstream_edm.options['includeSubfolder'] in i) - - # Are ther includes that fulfill more than one of the above conditions? Are - # there includes that fulfill none? - - return package_includes + upstream_includes + podio_includes + stl_includes + self._write_file("selection.xml", self._eval_template("selection.xml.jinja2", data)) + + def _get_member_includes(self, members): + """Process all members and gather the necessary includes""" + includes = set() + includes.update(*(m.includes for m in members)) + for member in members: + if member.is_array and not member.is_builtin_array: + include_from = IncludeFrom.INTERNAL + if self.upstream_edm and member.array_type in self.upstream_edm.components: + include_from = IncludeFrom.EXTERNAL + includes.add(self._build_include_for_class(member.array_bare_type, include_from)) + + includes.add(self._build_include(member)) + + return self._sort_includes(includes) + + def _needs_include(self, classname) -> IncludeFrom: + """Check whether the member needs an include from within the datamodel""" + if ( + classname in self.datamodel.components + or classname in self.datamodel.datatypes + or classname in self.datamodel.interfaces + ): + return IncludeFrom.INTERNAL + + if self.upstream_edm: + if ( + classname in self.upstream_edm.components + or classname in self.upstream_edm.datatypes + or classname in self.upstream_edm.interfaces + ): + return IncludeFrom.EXTERNAL + + return IncludeFrom.NOWHERE + + def _build_include(self, member): + """Return the include statment for the passed member.""" + return self._build_include_for_class( + member.bare_type, self._needs_include(member.full_type) + ) + + def _build_include_for_class(self, classname, include_from: IncludeFrom) -> str: + """Return the include statement for the passed classname""" + if include_from == IncludeFrom.INTERNAL: + return f'#include "{self.datamodel.options["includeSubfolder"]}{classname}.h"' + if include_from == IncludeFrom.EXTERNAL: + return f'#include "{self.upstream_edm.options["includeSubfolder"]}{classname}.h"' + + # The empty string is filtered by _sort_includes (plus it doesn't hurt in + # the generated code) + return "" + + def _sort_includes(self, includes): + """Sort the includes in order to try to have the std includes at the bottom""" + package_includes = sorted(i for i in includes if self.package_name in i) + podio_includes = sorted(i for i in includes if "podio" in i) + stl_includes = sorted(i for i in includes if "<" in i and ">" in i) + + upstream_includes = [] + if self.upstream_edm: + upstream_includes = sorted( + i for i in includes if self.upstream_edm.options["includeSubfolder"] in i + ) + + # Are ther includes that fulfill more than one of the above conditions? Are + # there includes that fulfill none? + + return package_includes + upstream_includes + podio_includes + stl_includes diff --git a/python/podio_gen/generator_base.py b/python/podio_gen/generator_base.py index 4b798fd1a..b0101c80c 100644 --- a/python/podio_gen/generator_base.py +++ b/python/podio_gen/generator_base.py @@ -18,248 +18,255 @@ def write_file_if_changed(filename, content, force_write=False): - """Write the file contents only if it has changed or if the file does not exist - yet. 
Return whether the file has been written or not""" - try: - with open(filename, 'r', encoding='utf-8') as infile: - existing_content = infile.read() - changed = existing_content != content - except FileNotFoundError: - changed = True + """Write the file contents only if it has changed or if the file does not exist + yet. Return whether the file has been written or not""" + try: + with open(filename, "r", encoding="utf-8") as infile: + existing_content = infile.read() + changed = existing_content != content + except FileNotFoundError: + changed = True - if changed or force_write: - with open(filename, 'w', encoding='utf-8') as outfile: - outfile.write(content) - return True + if changed or force_write: + with open(filename, "w", encoding="utf-8") as outfile: + outfile.write(content) + return True - return False + return False class ClassGeneratorBaseMixin: - """Base class for code generation providing common functionality and - orchestration - - The base class takes care of initializing the common state that is necessary - for code generation for the different languages. It reads and valiadates the - datamodel and sets up the jinja2 environment. Furthermore it provides the - functionality for filling templates and it also does the loop over all the - components and datatypes in the datamodel offering hooks (see below) to - augment the common processing with language specifics. - - The following members are initialized and accessible from inheriting classes - - - yamlfile (the path to the yamlfile) - - install_dir (top level directory into which the code should be generated) - - package_name (the name of the package) - - verbose (whether to print some information about the code gen process) - - dryrun (whether to actually generate the datamodel or to only run the - processing without filling the contents) - - upstream_edm (an optional upstream datamodel) - - datamodel (the current datamodel read from the yamlfile) - - get_syntax (whether to use get syntax or not) - - incfolder (whether to create an includeSubfolder or not) - - expose_pod_members (whether or not to expose the pod members) - - formatter_func (an optional formatting function that is called after the - jinja template evaluation but before writing the contents to disk) - - generated_files (a list of files that have been generated) - - any_changes (a boolean indicating whether the current run of the code - generation led to any changes in the generated code wrt the one that is - already present in the output directory) - - Inheriting classes need to implement the following (potentially empty) methods: - - pre_process() -> dict: does some global pre-processing for the datamodel - before any of the components or datatypes are - processed. Needs to return a (potentially) empty - dictionary - - do_process_component(name: str, component: dict) -> dict: do some language - specific processing for a component populating the - component dictionary further. When called only the - "class" key will be populated. Return a dictionary or - None. If None, this will not be put into the "components" - list. This function also has to to take care of filling - the necessary templates! - - do_process_datatype(name: str, datatype: dict): do some language specific - processing for a datatype populating the datatype - dictionary further. When called only the "class" key will - be populated. Return a dictionary or None. If None, this - will not be put into the "datatypes" list. This function - also has to take care of filling the necessary templates! 
- - do_process_interface(name: str, interface: dict): do some language specific - processing for an interface type, populating the - interface dictionary further. When called only the - "class" key will be populated. Return a dictionary or - None. If None, this will not be put into the "interfaces" - list. This function also has to take care of filling the - necessary templates! - - post_process(datamodel: dict): do some global post processing for which all - components and datatypes need to have been processed already. - Gets called with the dictionary that has been created in - pre_proces and filled during the processing. The process - components and datatypes are accessible via the "components", - "datatypes" and "interfaces" keys respectively. - - print_report(): prints a report summarizing what has been generated - - """ - def __init__(self, yamlfile, install_dir, package_name, verbose, dryrun, upstream_edm): - self.yamlfile = yamlfile - self.install_dir = install_dir - self.package_name = package_name - self.verbose = verbose - self.dryrun = dryrun - self.upstream_edm = upstream_edm - - try: - self.datamodel = PodioConfigReader.read(yamlfile, package_name, upstream_edm) - except DefinitionError as err: - print(f"Error while generating the datamodel: {err}") - sys.exit(1) - - self.env = jinja2.Environment(loader=jinja2.FileSystemLoader(TEMPLATE_DIR), - keep_trailing_newline=True, - lstrip_blocks=True, - trim_blocks=True) - - self.get_syntax = self.datamodel.options["getSyntax"] - self.incfolder = self.datamodel.options['includeSubfolder'] - self.expose_pod_members = self.datamodel.options["exposePODMembers"] - self.upstream_edm = upstream_edm - - self.formatter_func = None - self.generated_files = [] - self.any_changes = False - - def process(self): - """Run the actual generation""" - datamodel = self.pre_process() - - datamodel['components'] = [] - datamodel['datatypes'] = [] - datamodel['interfaces'] = [] - - for name, component in self.datamodel.components.items(): - comp = self._process_component(name, component) - if comp is not None: - datamodel["components"].append(comp) - - for name, datatype in self.datamodel.datatypes.items(): - datat = self._process_datatype(name, datatype) - if datat is not None: - datamodel["datatypes"].append(datat) - - for name, interface in self.datamodel.interfaces.items(): - interf = self._process_interface(name, interface) - if interf is not None: - datamodel["interfaces"].append(interf) - - self.post_process(datamodel) - if self.verbose: - self.print_report() - - def _process_component(self, name, component): - """Process a single component into a dictionary that can be used in jinja2 - templates and return that""" - # Make a copy here and add the preprocessing steps to that such that the - # original definition can be left untouched - component = deepcopy(component) - component['class'] = DataType(name) - - return self.do_process_component(name, component) - - def _process_datatype(self, name, datatype): - """Process a single datatype into a dictionary that can be used in jinja2 - templates and return that""" - datatype = deepcopy(datatype) - datatype["class"] = DataType(name) - - return self.do_process_datatype(name, datatype) - - def _process_interface(self, name, interface): - """Process a single interface definition into a dictionary that can be used - in jinja2 templates and return that""" - interface = deepcopy(interface) - interface["class"] = DataType(name) - - return self.do_process_interface(name, interface) - - @staticmethod - def 
_get_filenames_templates(template_base, name): - """Get the list of output filenames and corresponding template names""" - # depending on which category is passed different naming conventions apply - # for the generated files. Additionally not all categories need source files. - # Listing the special cases here - def get_fn_format(tmpl): - """Get a format string for the filename""" - prefix = {'MutableObject': 'Mutable'} - postfix = {'Data': 'Data', - 'Obj': 'Obj', - 'SIOBlock': 'SIOBlock', - 'Collection': 'Collection', - 'CollectionData': 'CollectionData', - 'MutableStruct': 'Struct' - } - - return f'{prefix.get(tmpl, "")}{{name}}{postfix.get(tmpl, "")}.{{end}}' - - endings = { - 'Data': ('h',), - 'PrintInfo': ('h',), - 'Interface': ('h',), - 'MutableStruct': ('jl',), - 'ParentModule': ('jl',), - }.get(template_base, ('h', 'cc')) - - fn_templates = [] - for ending in endings: - template_name = f'{template_base}.{ending}.jinja2' - filename = get_fn_format(template_base).format(name=name, end=ending) - fn_templates.append((filename, template_name)) - - return fn_templates - - def _eval_template(self, template, data, old_schema_data=None): - """Fill the specified template""" - # merge the info of data and the old schema into a single dict - if old_schema_data: - data['OneToOneRelations_old'] = old_schema_data['OneToOneRelations'] - data['OneToManyRelations_old'] = old_schema_data['OneToManyRelations'] - data['VectorMembers_old'] = old_schema_data['VectorMembers'] - - return self.env.get_template(template).render(data) - - def _write_file(self, name, content): - """Write the content to file. Dispatch to the correct directory depending on - whether it is a header or a .cc file.""" - if name.endswith("h") or name.endswith("jl"): - fullname = os.path.join(self.install_dir, self.package_name, name) - else: - fullname = os.path.join(self.install_dir, "src", name) - if not self.dryrun: - self.generated_files.append(fullname) - if self.formatter_func is not None: - content = self.formatter_func(content, fullname) - - changed = write_file_if_changed(fullname, content) - self.any_changes = changed or self.any_changes - - def _fill_templates(self, template_base, data, old_schema_data=None): - """Fill the template and write the results to file""" - # Update the passed data with some global things that are the same for all - # files - data['package_name'] = self.package_name - data['use_get_syntax'] = self.get_syntax - data['incfolder'] = self.incfolder - for filename, template in self._get_filenames_templates(template_base, data['class'].bare_type): - self._write_file(filename, self._eval_template(template, data, old_schema_data)) - - def _is_interface(self, classname): - """Check whether this is an interface type or a regular datatype""" - all_interfaces = self.datamodel.interfaces - if self.upstream_edm: - all_interfaces = list(self.datamodel.interfaces) + list(self.upstream_edm.interfaces) - return classname in all_interfaces + """Base class for code generation providing common functionality and + orchestration + + The base class takes care of initializing the common state that is necessary + for code generation for the different languages. It reads and valiadates the + datamodel and sets up the jinja2 environment. Furthermore it provides the + functionality for filling templates and it also does the loop over all the + components and datatypes in the datamodel offering hooks (see below) to + augment the common processing with language specifics. 
+
+    The following members are initialized and accessible from inheriting classes
+
+    - yamlfile (the path to the yamlfile)
+    - install_dir (top level directory into which the code should be generated)
+    - package_name (the name of the package)
+    - verbose (whether to print some information about the code gen process)
+    - dryrun (whether to actually generate the datamodel or to only run the
+      processing without filling the contents)
+    - upstream_edm (an optional upstream datamodel)
+    - datamodel (the current datamodel read from the yamlfile)
+    - get_syntax (whether to use get syntax or not)
+    - incfolder (whether to create an includeSubfolder or not)
+    - expose_pod_members (whether or not to expose the pod members)
+    - formatter_func (an optional formatting function that is called after the
+      jinja template evaluation but before writing the contents to disk)
+    - generated_files (a list of files that have been generated)
+    - any_changes (a boolean indicating whether the current run of the code
+      generation led to any changes in the generated code wrt the one that is
+      already present in the output directory)
+
+    Inheriting classes need to implement the following (potentially empty) methods:
+
+    pre_process() -> dict: does some global pre-processing for the datamodel
+                    before any of the components or datatypes are
+                    processed. Needs to return a (potentially) empty
+                    dictionary
+
+    do_process_component(name: str, component: dict) -> dict: do some language
+                    specific processing for a component populating the
+                    component dictionary further. When called only the
+                    "class" key will be populated. Return a dictionary or
+                    None. If None, this will not be put into the "components"
+                    list. This function also has to take care of filling
+                    the necessary templates!
+
+    do_process_datatype(name: str, datatype: dict): do some language specific
+                    processing for a datatype populating the datatype
+                    dictionary further. When called only the "class" key will
+                    be populated. Return a dictionary or None. If None, this
+                    will not be put into the "datatypes" list. This function
+                    also has to take care of filling the necessary templates!
+
+    do_process_interface(name: str, interface: dict): do some language specific
+                    processing for an interface type, populating the
+                    interface dictionary further. When called only the
+                    "class" key will be populated. Return a dictionary or
+                    None. If None, this will not be put into the "interfaces"
+                    list. This function also has to take care of filling the
+                    necessary templates!
+
+    post_process(datamodel: dict): do some global post processing for which all
+                    components and datatypes need to have been processed already.
+                    Gets called with the dictionary that has been created in
+                    pre_process and filled during the processing. The processed
+                    components and datatypes are accessible via the "components",
+                    "datatypes" and "interfaces" keys respectively.
+ + print_report(): prints a report summarizing what has been generated + + """ + + def __init__(self, yamlfile, install_dir, package_name, verbose, dryrun, upstream_edm): + self.yamlfile = yamlfile + self.install_dir = install_dir + self.package_name = package_name + self.verbose = verbose + self.dryrun = dryrun + self.upstream_edm = upstream_edm + + try: + self.datamodel = PodioConfigReader.read(yamlfile, package_name, upstream_edm) + except DefinitionError as err: + print(f"Error while generating the datamodel: {err}") + sys.exit(1) + + self.env = jinja2.Environment( + loader=jinja2.FileSystemLoader(TEMPLATE_DIR), + keep_trailing_newline=True, + lstrip_blocks=True, + trim_blocks=True, + ) + + self.get_syntax = self.datamodel.options["getSyntax"] + self.incfolder = self.datamodel.options["includeSubfolder"] + self.expose_pod_members = self.datamodel.options["exposePODMembers"] + self.upstream_edm = upstream_edm + + self.formatter_func = None + self.generated_files = [] + self.any_changes = False + + def process(self): + """Run the actual generation""" + datamodel = self.pre_process() + + datamodel["components"] = [] + datamodel["datatypes"] = [] + datamodel["interfaces"] = [] + + for name, component in self.datamodel.components.items(): + comp = self._process_component(name, component) + if comp is not None: + datamodel["components"].append(comp) + + for name, datatype in self.datamodel.datatypes.items(): + datat = self._process_datatype(name, datatype) + if datat is not None: + datamodel["datatypes"].append(datat) + + for name, interface in self.datamodel.interfaces.items(): + interf = self._process_interface(name, interface) + if interf is not None: + datamodel["interfaces"].append(interf) + + self.post_process(datamodel) + if self.verbose: + self.print_report() + + def _process_component(self, name, component): + """Process a single component into a dictionary that can be used in jinja2 + templates and return that""" + # Make a copy here and add the preprocessing steps to that such that the + # original definition can be left untouched + component = deepcopy(component) + component["class"] = DataType(name) + + return self.do_process_component(name, component) + + def _process_datatype(self, name, datatype): + """Process a single datatype into a dictionary that can be used in jinja2 + templates and return that""" + datatype = deepcopy(datatype) + datatype["class"] = DataType(name) + + return self.do_process_datatype(name, datatype) + + def _process_interface(self, name, interface): + """Process a single interface definition into a dictionary that can be used + in jinja2 templates and return that""" + interface = deepcopy(interface) + interface["class"] = DataType(name) + + return self.do_process_interface(name, interface) + + @staticmethod + def _get_filenames_templates(template_base, name): + """Get the list of output filenames and corresponding template names""" + + # depending on which category is passed different naming conventions apply + # for the generated files. Additionally not all categories need source files. 
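
For concreteness, a few expected return values of _get_filenames_templates for a hypothetical "Hit" name (a sketch, assuming podio_gen is importable; "Hit" itself is made up, the mapping follows the prefix/postfix tables in the function body):

from podio_gen.generator_base import ClassGeneratorBaseMixin

fn = ClassGeneratorBaseMixin._get_filenames_templates
# "Data" is a header-only category and gets the "Data" postfix
assert fn("Data", "Hit") == [("HitData.h", "Data.h.jinja2")]
# "Collection" is not in the special-endings table, so a .h/.cc pair is produced
assert fn("Collection", "Hit") == [
    ("HitCollection.h", "Collection.h.jinja2"),
    ("HitCollection.cc", "Collection.cc.jinja2"),
]
# "MutableObject" prefixes the generated file names with "Mutable"
assert fn("MutableObject", "Hit") == [
    ("MutableHit.h", "MutableObject.h.jinja2"),
    ("MutableHit.cc", "MutableObject.cc.jinja2"),
]
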
+ # Listing the special cases here + def get_fn_format(tmpl): + """Get a format string for the filename""" + prefix = {"MutableObject": "Mutable"} + postfix = { + "Data": "Data", + "Obj": "Obj", + "SIOBlock": "SIOBlock", + "Collection": "Collection", + "CollectionData": "CollectionData", + "MutableStruct": "Struct", + } + + return f'{prefix.get(tmpl, "")}{{name}}{postfix.get(tmpl, "")}.{{end}}' + + endings = { + "Data": ("h",), + "PrintInfo": ("h",), + "Interface": ("h",), + "MutableStruct": ("jl",), + "ParentModule": ("jl",), + }.get(template_base, ("h", "cc")) + + fn_templates = [] + for ending in endings: + template_name = f"{template_base}.{ending}.jinja2" + filename = get_fn_format(template_base).format(name=name, end=ending) + fn_templates.append((filename, template_name)) + + return fn_templates + + def _eval_template(self, template, data, old_schema_data=None): + """Fill the specified template""" + # merge the info of data and the old schema into a single dict + if old_schema_data: + data["OneToOneRelations_old"] = old_schema_data["OneToOneRelations"] + data["OneToManyRelations_old"] = old_schema_data["OneToManyRelations"] + data["VectorMembers_old"] = old_schema_data["VectorMembers"] + + return self.env.get_template(template).render(data) + + def _write_file(self, name, content): + """Write the content to file. Dispatch to the correct directory depending on + whether it is a header or a .cc file.""" + if name.endswith("h") or name.endswith("jl"): + fullname = os.path.join(self.install_dir, self.package_name, name) + else: + fullname = os.path.join(self.install_dir, "src", name) + if not self.dryrun: + self.generated_files.append(fullname) + if self.formatter_func is not None: + content = self.formatter_func(content, fullname) + + changed = write_file_if_changed(fullname, content) + self.any_changes = changed or self.any_changes + + def _fill_templates(self, template_base, data, old_schema_data=None): + """Fill the template and write the results to file""" + # Update the passed data with some global things that are the same for all + # files + data["package_name"] = self.package_name + data["use_get_syntax"] = self.get_syntax + data["incfolder"] = self.incfolder + for filename, template in self._get_filenames_templates( + template_base, data["class"].bare_type + ): + self._write_file(filename, self._eval_template(template, data, old_schema_data)) + + def _is_interface(self, classname): + """Check whether this is an interface type or a regular datatype""" + all_interfaces = self.datamodel.interfaces + if self.upstream_edm: + all_interfaces = list(self.datamodel.interfaces) + list(self.upstream_edm.interfaces) + return classname in all_interfaces diff --git a/python/podio_gen/generator_utils.py b/python/podio_gen/generator_utils.py index a2d8b680c..d73afb810 100644 --- a/python/podio_gen/generator_utils.py +++ b/python/podio_gen/generator_utils.py @@ -8,288 +8,330 @@ def _get_namespace_class(full_type): - """Get the namespace and the unqualified classname from the full type. Raise a - DefinitionError if a nested namespace is found""" - cnameparts = full_type.split('::') - if len(cnameparts) > 2: - raise DefinitionError(f"'{full_type}' is a type with a nested namespace. not supported, yet.") - if len(cnameparts) == 2: - # If in std namespace, consider that to be part of the type instead and only - # split namespace if that is not the case - if cnameparts[0] != 'std': - return cnameparts - - return "", full_type + """Get the namespace and the unqualified classname from the full type. 
Raise a + DefinitionError if a nested namespace is found""" + cnameparts = full_type.split("::") + if len(cnameparts) > 2: + raise DefinitionError( + f"'{full_type}' is a type with a nested namespace. not supported, yet." + ) + if len(cnameparts) == 2: + # If in std namespace, consider that to be part of the type instead and only + # split namespace if that is not the case + if cnameparts[0] != "std": + return cnameparts + + return "", full_type def _prefix_name(name, prefix): - """Prefix the name and capitalize the first letter if the prefix is not empty""" - if prefix: - return prefix + name[0].upper() + name[1:] - return name + """Prefix the name and capitalize the first letter if the prefix is not empty""" + if prefix: + return prefix + name[0].upper() + name[1:] + return name def get_julia_type(cpp_type, is_array=False, array_type=None, array_size=None): - """Parse the given c++ type to a Julia type""" - builtin_types_map = {"int": "Int32", "float": "Float32", "double": "Float64", - "bool": "Bool", "long": "Int64", "unsigned int": "UInt32", - "unsigned long": "UInt64", "char": "Char", "short": "Int16", - "long long": "Int64", "unsigned long long": "UInt64"} - # check for cpp_type=None as cpp_type can be None in case of array members - if cpp_type and cpp_type.startswith("::"): - cpp_type = cpp_type[2:] - if cpp_type in builtin_types_map: - return builtin_types_map[cpp_type] - - if not is_array: - if cpp_type.startswith('std::'): - cpp_type = cpp_type[5:] - if cpp_type in ALLOWED_FIXED_WIDTH_TYPES: - regex_string = re.split("(u|)int(8|16|32|64)_t", cpp_type) - cpp_type = regex_string[1].upper() + "Int" + regex_string[2] - return cpp_type - - else: - array_type = get_julia_type(array_type) - if '::' in array_type: - array_type = array_type.split('::')[1] - if array_type not in builtin_types_map.values(): - array_type = array_type + 'Struct' - return f"MVector{{{array_size}, {array_type}}}" - - return cpp_type + """Parse the given c++ type to a Julia type""" + builtin_types_map = { + "int": "Int32", + "float": "Float32", + "double": "Float64", + "bool": "Bool", + "long": "Int64", + "unsigned int": "UInt32", + "unsigned long": "UInt64", + "char": "Char", + "short": "Int16", + "long long": "Int64", + "unsigned long long": "UInt64", + } + # check for cpp_type=None as cpp_type can be None in case of array members + if cpp_type and cpp_type.startswith("::"): + cpp_type = cpp_type[2:] + if cpp_type in builtin_types_map: + return builtin_types_map[cpp_type] + + if not is_array: + if cpp_type.startswith("std::"): + cpp_type = cpp_type[5:] + if cpp_type in ALLOWED_FIXED_WIDTH_TYPES: + regex_string = re.split("(u|)int(8|16|32|64)_t", cpp_type) + cpp_type = regex_string[1].upper() + "Int" + regex_string[2] + return cpp_type + + else: + array_type = get_julia_type(array_type) + if "::" in array_type: + array_type = array_type.split("::")[1] + if array_type not in builtin_types_map.values(): + array_type = array_type + "Struct" + return f"MVector{{{array_size}, {array_type}}}" + + return cpp_type class DefinitionError(Exception): - """Exception raised by the ClassDefinitionValidator for invalid definitions. - Mainly here to distinguish it from plain exceptions that are otherwise raised. - In this way this makes it possible to selectively catch exceptions related to - the datamodel definition without also catching all the rest which might point - to another problem - """ + """Exception raised by the ClassDefinitionValidator for invalid definitions. 
+ Mainly here to distinguish it from plain exceptions that are otherwise raised. + In this way this makes it possible to selectively catch exceptions related to + the datamodel definition without also catching all the rest which might point + to another problem + """ # Types considered to be builtin -BUILTIN_TYPES = ["int", "long", "float", "double", - "unsigned int", "unsigned", "unsigned long", - "char", "short", "bool", "long long", - "unsigned long long"] +BUILTIN_TYPES = [ + "int", + "long", + "float", + "double", + "unsigned int", + "unsigned", + "unsigned long", + "char", + "short", + "bool", + "long long", + "unsigned long long", +] # Fixed width types defined in . Omitting int8_t and uint8_t since they # are often only aliases for signed char and unsigned char, which tends to break # expectations towards the behavior of integer types. Also omitting the _fastN_t # and leastN_t since those are probably already covered by the usual integer # types. -ALLOWED_FIXED_WIDTH_TYPES = ["int16_t", "int32_t", "int64_t", - "uint16_t", "uint32_t", "uint64_t"] +ALLOWED_FIXED_WIDTH_TYPES = [ + "int16_t", + "int32_t", + "int64_t", + "uint16_t", + "uint32_t", + "uint64_t", +] # All fixed width integer types that may be defined in -ALL_FIXED_WIDTH_TYPES_RGX = re.compile(r'u?int(_(fast|least))?(8|16|32|64)_t') +ALL_FIXED_WIDTH_TYPES_RGX = re.compile(r"u?int(_(fast|least))?(8|16|32|64)_t") def _is_fixed_width_type(type_name): - """Check if the passed type is an fixed width type and that it is allowed""" - # Remove the potentially present std:: namespace - if type_name.startswith('std::'): - type_name = type_name[5:] + """Check if the passed type is an fixed width type and that it is allowed""" + # Remove the potentially present std:: namespace + if type_name.startswith("std::"): + type_name = type_name[5:] - if ALL_FIXED_WIDTH_TYPES_RGX.match(type_name): - if type_name not in ALLOWED_FIXED_WIDTH_TYPES: - raise DefinitionError(f"{type_name} is a fixed width integer type that is not allowed") - return True + if ALL_FIXED_WIDTH_TYPES_RGX.match(type_name): + if type_name not in ALLOWED_FIXED_WIDTH_TYPES: + raise DefinitionError(f"{type_name} is a fixed width integer type that is not allowed") + return True - return False + return False class DataType: - """Simple class to hold information about a datatype or component that is - defined in the datamodel.""" - def __init__(self, klass): - self.full_type = klass - self.namespace, self.bare_type = _get_namespace_class(self.full_type) - - def __str__(self): - if self.namespace: - scoped_type = f'::{self.namespace}::{self.bare_type}' - else: - scoped_type = self.full_type - - return scoped_type - - def _to_json(self): - """Return a string representation that can be parsed again""" - return self.full_type - - -class MemberVariable: - """Simple class to hold information about a member variable""" - def __init__(self, name, **kwargs): - self.name = name - self.full_type = kwargs.pop('type', '') - self.description = kwargs.pop('description', '') - self.default_val = kwargs.pop('default_val', None) - self.unit = kwargs.pop('unit', None) - self.is_builtin = False - self.is_builtin_array = False - self.is_array = False - # ensure that this will break somewhere if requested but not set - self.namespace, self.bare_type = None, None - self.julia_type = None - self.array_namespace, self.array_bare_type = None, None - - self.array_type = kwargs.pop('array_type', None) - self.array_size = kwargs.pop('array_size', None) - - self.includes = set() - self.jl_imports = set() - 
self.interface_types = [] # populated in the generator script if necessary - - if kwargs: - raise ValueError(f"Unused kwargs in MemberVariable: {list(kwargs.keys())}") - - if self.array_type is not None and self.array_size is not None: - self.is_array = True - self.is_builtin_array = self.array_type in BUILTIN_TYPES - # We also have to check if this is a fixed width integer array - if not self.is_builtin_array: - if _is_fixed_width_type(self.array_type): - self.is_builtin_array = True - self.array_type = self.normalize_fw_type(self.array_type) - - self.full_type = rf'std::array<{self.array_type}, {self.array_size}>' - self.includes.add('#include ') - self.jl_imports.add('using StaticArrays') - - is_fw_type = _is_fixed_width_type(self.full_type) - self.is_builtin = self.full_type in BUILTIN_TYPES or is_fw_type - - if is_fw_type: - self.full_type = self.normalize_fw_type(self.full_type) - - # Needed in case the PODs are exposed - self.sub_members = None - - if self.is_array: - self.array_namespace, self.array_bare_type = _get_namespace_class(self.array_type) - else: - self.namespace, self.bare_type = _get_namespace_class(self.full_type) - - self.julia_type = get_julia_type(self.bare_type, is_array=self.is_array, - array_type=self.array_type, array_size=self.array_size) - - @property - def signature(self): - """Get the signature for this member variable to be used in function definitions""" - return f"{self.full_type} {self.name}" - - @property - def docstring(self): - """Docstring to be used in code generation""" - if self.unit is not None: - docstring = rf'{self.description} [{self.unit}]' - else: - docstring = self.description - return docstring - - def __str__(self): - """string representation""" - # Make sure to include scope-operator if necessary - if self.namespace: - scoped_type = f'::{self.namespace}::{self.bare_type}' - else: - scoped_type = self.full_type + """Simple class to hold information about a datatype or component that is + defined in the datamodel.""" - if self.default_val: - definition = rf'{scoped_type} {self.name}{{{self.default_val}}};' - else: - definition = rf'{scoped_type} {self.name}{{}};' + def __init__(self, klass): + self.full_type = klass + self.namespace, self.bare_type = _get_namespace_class(self.full_type) - if self.docstring: - definition += rf' ///< {self.docstring}' - return definition + def __str__(self): + if self.namespace: + scoped_type = f"::{self.namespace}::{self.bare_type}" + else: + scoped_type = self.full_type - def getter_name(self, get_syntax): - """Get the getter name of the variable""" - if not get_syntax: - return self.name - return _prefix_name(self.name, "get") + return scoped_type - def getter_return_type(self, for_array=False): - """Get the return type for a getter function for a variable + def _to_json(self): + """Return a string representation that can be parsed again""" + return self.full_type - All types that are builtin will be returned by value, the rest will be - returned as const& - Args: - for_array (bool, optional): Whether the type should be for an indexed - array access - """ - if for_array: - if self.is_builtin_array: - return self.array_type - return f"const {self.array_type}&" - if self.is_builtin: - return self.full_type - # everything else will just be by const reference - return f"const {self.full_type}&" - - def setter_name(self, get_syntax, is_relation=False): - """Get the setter name of the variable""" - if is_relation: - if not get_syntax: - return 'add' + self.name - return _prefix_name(self.name, 'addTo') - - if 
not get_syntax: - return self.name - return _prefix_name(self.name, 'set') - - def normalize_fw_type(self, fw_type): - """Normalize the fixed width type and make sure to inclde """ - self.includes.add("#include ") - if not fw_type.startswith("std::"): - return f"std::{fw_type}" - return fw_type - - def _to_json(self): - """Return a string representation that can be parsed again.""" - # The __str__ method is geared towards c++ too much, so we have to build - # things again here from available information - def_val = f'{{{self.default_val}}}' if self.default_val else '' - description = f' // {self.description}' if self.description else '' - unit = f'[{self.unit}]' if self.unit else '' - return f'{self.full_type} {self.name}{def_val}{unit}{description}' +class MemberVariable: + """Simple class to hold information about a member variable""" + + def __init__(self, name, **kwargs): + self.name = name + self.full_type = kwargs.pop("type", "") + self.description = kwargs.pop("description", "") + self.default_val = kwargs.pop("default_val", None) + self.unit = kwargs.pop("unit", None) + self.is_builtin = False + self.is_builtin_array = False + self.is_array = False + # ensure that this will break somewhere if requested but not set + self.namespace, self.bare_type = None, None + self.julia_type = None + self.array_namespace, self.array_bare_type = None, None + + self.array_type = kwargs.pop("array_type", None) + self.array_size = kwargs.pop("array_size", None) + + self.includes = set() + self.jl_imports = set() + self.interface_types = [] # populated in the generator script if necessary + + if kwargs: + raise ValueError(f"Unused kwargs in MemberVariable: {list(kwargs.keys())}") + + if self.array_type is not None and self.array_size is not None: + self.is_array = True + self.is_builtin_array = self.array_type in BUILTIN_TYPES + # We also have to check if this is a fixed width integer array + if not self.is_builtin_array: + if _is_fixed_width_type(self.array_type): + self.is_builtin_array = True + self.array_type = self.normalize_fw_type(self.array_type) + + self.full_type = rf"std::array<{self.array_type}, {self.array_size}>" + self.includes.add("#include ") + self.jl_imports.add("using StaticArrays") + + is_fw_type = _is_fixed_width_type(self.full_type) + self.is_builtin = self.full_type in BUILTIN_TYPES or is_fw_type + + if is_fw_type: + self.full_type = self.normalize_fw_type(self.full_type) + + # Needed in case the PODs are exposed + self.sub_members = None + + if self.is_array: + self.array_namespace, self.array_bare_type = _get_namespace_class(self.array_type) + else: + self.namespace, self.bare_type = _get_namespace_class(self.full_type) + + self.julia_type = get_julia_type( + self.bare_type, + is_array=self.is_array, + array_type=self.array_type, + array_size=self.array_size, + ) + + @property + def signature(self): + """Get the signature for this member variable to be used in function definitions""" + return f"{self.full_type} {self.name}" + + @property + def docstring(self): + """Docstring to be used in code generation""" + if self.unit is not None: + docstring = rf"{self.description} [{self.unit}]" + else: + docstring = self.description + return docstring + + def __str__(self): + """string representation""" + # Make sure to include scope-operator if necessary + if self.namespace: + scoped_type = f"::{self.namespace}::{self.bare_type}" + else: + scoped_type = self.full_type + + if self.default_val: + definition = rf"{scoped_type} {self.name}{{{self.default_val}}};" + else: + definition = 
rf"{scoped_type} {self.name}{{}};"
+
+        if self.docstring:
+            definition += rf" ///< {self.docstring}"
+        return definition
+
+    def getter_name(self, get_syntax):
+        """Get the getter name of the variable"""
+        if not get_syntax:
+            return self.name
+        return _prefix_name(self.name, "get")
+
+    def getter_return_type(self, for_array=False):
+        """Get the return type for a getter function for a variable
+
+        All types that are builtin will be returned by value, the rest will be
+        returned as const&
+
+        Args:
+            for_array (bool, optional): Whether the type should be for an indexed
+                array access
+        """
+        if for_array:
+            if self.is_builtin_array:
+                return self.array_type
+            return f"const {self.array_type}&"
+        if self.is_builtin:
+            return self.full_type
+        # everything else will just be by const reference
+        return f"const {self.full_type}&"
+
+    def setter_name(self, get_syntax, is_relation=False):
+        """Get the setter name of the variable"""
+        if is_relation:
+            if not get_syntax:
+                return "add" + self.name
+            return _prefix_name(self.name, "addTo")
+
+        if not get_syntax:
+            return self.name
+        return _prefix_name(self.name, "set")
+
+    def normalize_fw_type(self, fw_type):
+        """Normalize the fixed width type and make sure to include <cstdint>"""
+        self.includes.add("#include <cstdint>")
+        if not fw_type.startswith("std::"):
+            return f"std::{fw_type}"
+        return fw_type
+
+    def _to_json(self):
+        """Return a string representation that can be parsed again."""
+        # The __str__ method is geared towards c++ too much, so we have to build
+        # things again here from available information
+        def_val = f"{{{self.default_val}}}" if self.default_val else ""
+        description = f" // {self.description}" if self.description else ""
+        unit = f"[{self.unit}]" if self.unit else ""
+        return f"{self.full_type} {self.name}{def_val}{unit}{description}"


 class DataModel:  # pylint: disable=too-few-public-methods
-  """A class for holding a complete datamodel read from a configuration file"""
-  def __init__(self, datatypes=None, components=None, interfaces=None, options=None, schema_version=None):
-    self.options = options or {
-        # should getters / setters be prefixed with get / set?
-        "getSyntax": False,
-        # should POD members be exposed with getters/setters in classes that have them as members?
-        "exposePODMembers": True,
-        # use subfolder when including package header files
-        "includeSubfolder": False,
+    """A class for holding a complete datamodel read from a configuration file"""
+
+    def __init__(
+        self,
+        datatypes=None,
+        components=None,
+        interfaces=None,
+        options=None,
+        schema_version=None,
+    ):
+        self.options = options or {
+            # should getters / setters be prefixed with get / set?
+            "getSyntax": False,
+            # should POD members be exposed with getters/setters in classes that
+            # have them as members?
+ "exposePODMembers": True, + # use subfolder when including package header files + "includeSubfolder": False, } - self.schema_version = schema_version - self.components = components or {} - self.datatypes = datatypes or {} - self.interfaces = interfaces or {} + self.schema_version = schema_version + self.components = components or {} + self.datatypes = datatypes or {} + self.interfaces = interfaces or {} - def _to_json(self): - """Return the dictionary, so that we can easily hook this into the pythons - JSON ecosystem""" - return self.__dict__ + def _to_json(self): + """Return the dictionary, so that we can easily hook this into the pythons + JSON ecosystem""" + return self.__dict__ class DataModelJSONEncoder(json.JSONEncoder): - """A JSON encoder for DataModels, resp. anything hat has a _to_json method.""" - - def default(self, o): - """The override for the default, first trying to call _to_json, otherwise - handing off to the default JSONEncoder""" - try: - return o._to_json() # pylint: disable=protected-access - except AttributeError: - return super().default(o) + """A JSON encoder for DataModels, resp. anything hat has a _to_json method.""" + + def default(self, o): + """The override for the default, first trying to call _to_json, otherwise + handing off to the default JSONEncoder""" + try: + return o._to_json() # pylint: disable=protected-access + except AttributeError: + return super().default(o) diff --git a/python/podio_gen/julia_generator.py b/python/podio_gen/julia_generator.py index 82fe18ffd..12b966188 100644 --- a/python/podio_gen/julia_generator.py +++ b/python/podio_gen/julia_generator.py @@ -15,127 +15,146 @@ class JuliaClassGenerator(ClassGeneratorBaseMixin): - """The julia class / code generator for podio""" - def print_report(self): - """Print a summary of the generated code""" - nfiles = len(self.datamodel.datatypes) + len(self.datamodel.components) + 1 - text = REPORT_TEXT_JULIA.format(yamlfile=self.yamlfile, - nfiles=nfiles, - installdir=self.install_dir) - - for summaryline in text.splitlines(): - print(summaryline) - print() - - def pre_process(self): - """The necessary specific pre-processing for julia code generation""" - datamodel = {} - datamodel['class'] = DataType(self.package_name.capitalize()) - datamodel['upstream_edm'] = self.upstream_edm - datamodel['upstream_edm_name'] = self.get_upstream_name() - return datamodel - - def post_process(self, datamodel): - """The necessary julia specific post processing""" - datamodel['static_arrays_import'] = self._has_static_arrays_import(datamodel['components'] + datamodel['datatypes']) - datamodel['includes'] = self._sort_components_and_datatypes(datamodel['components'] + datamodel['datatypes']) - self._fill_templates("ParentModule", datamodel) - - def do_process_component(self, _, component): - """Do the julia specific processing of a component""" - component['upstream_edm'] = self.upstream_edm - component['upstream_edm_name'] = self.get_upstream_name() - self._fill_templates("MutableStruct", component) - return component - - def do_process_datatype(self, _, datatype): - """Do the julia specific processing for a datatype""" - if any(self._is_interface(r.full_type) for r in datatype["OneToOneRelations"] + datatype["OneToManyRelations"]): - # Julia doesn't support any interfaces yet, so we have to also sort out - # all the datatypes that use them - return None - - datatype["params_jl"] = sorted(self._get_julia_params(datatype), key=lambda x: x[0]) - datatype["upstream_edm"] = self.upstream_edm - 
datatype["upstream_edm_name"] = self.get_upstream_name() - - self._fill_templates("MutableStruct", datatype) - return datatype - - def do_process_interface(self, _, __): - """Julia does not support interface types yet, so this does nothing""" - return None - - def get_upstream_name(self): - """Get the name of the upstream datamodel if any""" - if self.upstream_edm: - return self.upstream_edm.options["includeSubfolder"].split("/")[-2].capitalize() - return "" - - @staticmethod - def _get_julia_params(datatype): - """Get the relations as parametric types for MutableStructs""" - params = set() - for relation in datatype['OneToManyRelations'] + datatype['OneToOneRelations']: - if not relation.is_builtin: - params.add((relation.bare_type, relation.full_type)) - return list(params) - - @staticmethod - def _sort_components_and_datatypes(data): - """Sorts a list of components and datatypes based on dependencies, ensuring that components and datatypes - with no dependencies or dependencies on built-in types come first. The function performs - topological sorting using Kahn's algorithm.""" - # Create a dictionary to store dependencies - dependencies = {} - bare_types_mapping = {} - - for component_data in data: - full_type = component_data['class'].full_type - bare_type = component_data['class'].bare_type - bare_types_mapping[full_type] = bare_type - dependencies[full_type] = set() - - # Check dependencies in 'Members' - if 'Members' in component_data: - for member_data in component_data['Members']: - member_full_type = member_data.full_type - if not member_data.is_builtin and not member_data.is_builtin_array: - dependencies[full_type].add(member_full_type) - - # Check dependencies in 'VectorMembers' - if 'VectorMembers' in component_data: - for vector_member_data in component_data['VectorMembers']: - vector_member_full_type = vector_member_data.full_type - if not vector_member_data.is_builtin and not vector_member_data.is_builtin_array: - dependencies[full_type].add(vector_member_full_type) - - # Perform topological sorting using Kahn's algorithm - sorted_components = [] - while dependencies: - ready = {component for component, deps in dependencies.items() if not deps} - if not ready: - sorted_components.extend(bare_types_mapping[component] for component in dependencies) - break - - for component in ready: - del dependencies[component] - sorted_components.append(bare_types_mapping[component]) - - for deps in dependencies.values(): - deps -= ready - - # Return the Sorted Components (bare_types) - return sorted_components - - @staticmethod - def _has_static_arrays_import(data): - """Checks if any member within a list of components and datatypes contains the import statement - 'using StaticArrays' in its jl_imports. 
Returns True if found in any member, otherwise False.""" - for component_data in data: - members_data = component_data.get('Members', []) - for member_data in members_data: - jl_imports = member_data.jl_imports - if 'using StaticArrays' in jl_imports: - return True - return False + """The julia class / code generator for podio""" + + def print_report(self): + """Print a summary of the generated code""" + nfiles = len(self.datamodel.datatypes) + len(self.datamodel.components) + 1 + text = REPORT_TEXT_JULIA.format( + yamlfile=self.yamlfile, nfiles=nfiles, installdir=self.install_dir + ) + + for summaryline in text.splitlines(): + print(summaryline) + print() + + def pre_process(self): + """The necessary specific pre-processing for julia code generation""" + datamodel = {} + datamodel["class"] = DataType(self.package_name.capitalize()) + datamodel["upstream_edm"] = self.upstream_edm + datamodel["upstream_edm_name"] = self.get_upstream_name() + return datamodel + + def post_process(self, datamodel): + """The necessary julia specific post processing""" + datamodel["static_arrays_import"] = self._has_static_arrays_import( + datamodel["components"] + datamodel["datatypes"] + ) + datamodel["includes"] = self._sort_components_and_datatypes( + datamodel["components"] + datamodel["datatypes"] + ) + self._fill_templates("ParentModule", datamodel) + + def do_process_component(self, _, component): + """Do the julia specific processing of a component""" + component["upstream_edm"] = self.upstream_edm + component["upstream_edm_name"] = self.get_upstream_name() + self._fill_templates("MutableStruct", component) + return component + + def do_process_datatype(self, _, datatype): + """Do the julia specific processing for a datatype""" + if any( + self._is_interface(r.full_type) + for r in datatype["OneToOneRelations"] + datatype["OneToManyRelations"] + ): + # Julia doesn't support any interfaces yet, so we have to also sort out + # all the datatypes that use them + return None + + datatype["params_jl"] = sorted(self._get_julia_params(datatype), key=lambda x: x[0]) + datatype["upstream_edm"] = self.upstream_edm + datatype["upstream_edm_name"] = self.get_upstream_name() + + self._fill_templates("MutableStruct", datatype) + return datatype + + def do_process_interface(self, _, __): + """Julia does not support interface types yet, so this does nothing""" + return None + + def get_upstream_name(self): + """Get the name of the upstream datamodel if any""" + if self.upstream_edm: + return self.upstream_edm.options["includeSubfolder"].split("/")[-2].capitalize() + return "" + + @staticmethod + def _get_julia_params(datatype): + """Get the relations as parametric types for MutableStructs""" + params = set() + for relation in datatype["OneToManyRelations"] + datatype["OneToOneRelations"]: + if not relation.is_builtin: + params.add((relation.bare_type, relation.full_type)) + return list(params) + + @staticmethod + def _sort_components_and_datatypes(data): + """Sorts a list of components and datatypes based on dependencies, + ensuring that components and datatypes with no dependencies or + dependencies on built-in types come first. The function performs + topological sorting using Kahn's algorithm. 
+ + """ + # Create a dictionary to store dependencies + dependencies = {} + bare_types_mapping = {} + + for component_data in data: + full_type = component_data["class"].full_type + bare_type = component_data["class"].bare_type + bare_types_mapping[full_type] = bare_type + dependencies[full_type] = set() + + # Check dependencies in 'Members' + if "Members" in component_data: + for member_data in component_data["Members"]: + member_full_type = member_data.full_type + if not member_data.is_builtin and not member_data.is_builtin_array: + dependencies[full_type].add(member_full_type) + + # Check dependencies in 'VectorMembers' + if "VectorMembers" in component_data: + for vector_member_data in component_data["VectorMembers"]: + vector_member_full_type = vector_member_data.full_type + if ( + not vector_member_data.is_builtin + and not vector_member_data.is_builtin_array + ): + dependencies[full_type].add(vector_member_full_type) + + # Perform topological sorting using Kahn's algorithm + sorted_components = [] + while dependencies: + ready = {component for component, deps in dependencies.items() if not deps} + if not ready: + sorted_components.extend( + bare_types_mapping[component] for component in dependencies + ) + break + + for component in ready: + del dependencies[component] + sorted_components.append(bare_types_mapping[component]) + + for deps in dependencies.values(): + deps -= ready + + # Return the Sorted Components (bare_types) + return sorted_components + + @staticmethod + def _has_static_arrays_import(data): + """Checks if any member within a list of components and datatypes + contains the import statement 'using StaticArrays' in its jl_imports. + Returns True if found in any member, otherwise False. + + """ + for component_data in data: + members_data = component_data.get("Members", []) + for member_data in members_data: + jl_imports = member_data.jl_imports + if "using StaticArrays" in jl_imports: + return True + return False diff --git a/python/podio_gen/podio_config_reader.py b/python/podio_gen/podio_config_reader.py index 39275b503..b6d9a32fd 100644 --- a/python/podio_gen/podio_config_reader.py +++ b/python/podio_gen/podio_config_reader.py @@ -6,467 +6,550 @@ import warnings import yaml -from podio_gen.generator_utils import MemberVariable, DefinitionError, BUILTIN_TYPES, DataModel, DataType +from podio_gen.generator_utils import ( + MemberVariable, + DefinitionError, + BUILTIN_TYPES, + DataModel, + DataType, +) class MemberParser: - """Class to parse member variables from strings without doing too much validation""" - # Doing this with regex is non-ideal, but we should be able to at least - # enforce something that will yield valid c++ identifiers even if we might not - # cover all possibilities that are admitted by the c++ standard - - # A type can either start with a double colon, or a character (types starting - # with _ are technically allowed, but partially reserved for compilers) - # Additionally we have to take int account the possible whitespaces in the - # builtin types above. Currently this is done by simple brute-forcing. 
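
The dependency ordering above is plain Kahn's algorithm; the same idea as a stripped-down, standalone sketch (kahn_sort and the node names are illustrative only):

def kahn_sort(dependencies):
    """Topologically sort a {node: set(prerequisites)} mapping."""
    deps = {node: set(pre) for node, pre in dependencies.items()}
    ordered = []
    while deps:
        # nodes whose prerequisites have all been emitted already
        ready = {node for node, pre in deps.items() if not pre}
        if not ready:
            # cyclic remainder: emit it in arbitrary order, as above
            ordered.extend(deps)
            break
        for node in ready:
            del deps[node]
            ordered.append(node)
        for pre in deps.values():
            pre -= ready
    return ordered


# a component used by a datatype has to be emitted first
print(kahn_sort({"ExampleHit": {"Vector3f"}, "Vector3f": set()}))
# -> ['Vector3f', 'ExampleHit']
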
- # To "ensure" that the builtin matching is greedy and picks up as much as - # possible, sort the types by their length in descending order - builtin_str = r'|'.join(rf'(?:{t})' for t in sorted(BUILTIN_TYPES, key=len, reverse=True)) - type_str = rf'({builtin_str}|(?:\:{{2}})?[a-zA-Z]+[a-zA-Z0-9:_]*)' - type_re = re.compile(type_str) - - # Names can be almost anything as long as it doesn't start with a digit and - # doesn't contain anything fancy - name_str = r'([a-zA-Z_]+\w*)' - name_re = re.compile(name_str) - - # Units are given in square brakets - unit_str = r'(?:\[([a-zA-Z_*\/]+\w*)\])?' - unit_re = re.compile(unit_str) - - # Comments can be anything after // - # stripping of trailing whitespaces is done later as it is hard to do with regex - comment_str = r'\/\/ *(.*)' - # std::array declaration with some whitespace distribution freedom - array_str = rf' *std::array *< *{type_str} *, *([0-9]+) *>' - - # Default values can be anything that in curly braces, but not the empty - # default initialization, since that is what we generate in any case - def_val_str = r'(?:{(.+)})?' - - array_re = re.compile(array_str) - full_array_re = re.compile(rf'{array_str} *{name_str} *{def_val_str} *{unit_str} *{comment_str}') - member_re = re.compile(rf' *{type_str} +{name_str} *{def_val_str} *{unit_str} *{comment_str}') - - # For cases where we don't require a description - bare_member_re = re.compile(rf' *{type_str} +{name_str} *{def_val_str} *{unit_str}') - bare_array_re = re.compile(rf' *{array_str} +{name_str} *{def_val_str} *{unit_str}') - - @staticmethod - def _parse_with_regexps(string, regexps_callbacks): - """Parse the string using the passed regexps and corresponding callbacks that - take the match and return a MemberVariable from it""" - for rgx, callback in regexps_callbacks: - result = rgx.match(string) - if result: - return callback(result) - raise DefinitionError(f"'{string}' is not a valid member definition. 
Check syntax of the member definition.")
-
-  @staticmethod
-  def _full_array_conv(result):
-    """MemberVariable construction for array members with a docstring"""
-    typ, size, name, def_val, unit, comment = result.groups()
-    return MemberVariable(name=name, array_type=typ, array_size=size, description=comment.strip(),
-                          unit=unit, default_val=def_val)
-
-  @staticmethod
-  def _full_member_conv(result):
-    """MemberVariable construction for members with a docstring"""
-    klass, name, def_val, unit, comment = result.groups()
-    return MemberVariable(name=name, type=klass, description=comment.strip(), unit=unit, default_val=def_val)
-
-  @staticmethod
-  def _bare_array_conv(result):
-    """MemberVariable construction for array members without docstring"""
-    typ, size, name, def_val, unit = result.groups()
-    return MemberVariable(name=name, array_type=typ, array_size=size, unit=unit, default_val=def_val)
-
-  @staticmethod
-  def _bare_member_conv(result):
-    """MemberVarible construction for members without docstring"""
-    klass, name, def_val, unit = result.groups()
-    return MemberVariable(name=name, type=klass, unit=unit, default_val=def_val)
-
-  def parse(self, string, require_description=True):
-    """Parse the passed string"""
-    default_matchers_cbs = [
-        (self.full_array_re, self._full_array_conv),
-        (self.member_re, self._full_member_conv)]
-
-    no_desc_matchers_cbs = [
-        (self.bare_array_re, self._bare_array_conv),
-        (self.bare_member_re, self._bare_member_conv)]
-
-    if require_description:
-      try:
-        return self._parse_with_regexps(string, default_matchers_cbs)
-      except DefinitionError:
-        # check whether we could parse this if we don't require a description and
-        # provide more details in the error if we can
-        self._parse_with_regexps(string, no_desc_matchers_cbs)
-        # pylint: disable-next=raise-missing-from
-        raise DefinitionError(f"'{string}' is not a valid member definition. Description comment is missing.\n"
-                              "Correct Syntax: <type> <name> // <description>")
-
-    return self._parse_with_regexps(string, default_matchers_cbs + no_desc_matchers_cbs)
+    """Class to parse member variables from strings without doing too much validation"""
+
+    # Doing this with regex is non-ideal, but we should be able to at least
+    # enforce something that will yield valid c++ identifiers even if we might not
+    # cover all possibilities that are admitted by the c++ standard
+
+    # A type can either start with a double colon, or a character (types starting
+    # with _ are technically allowed, but partially reserved for compilers)
+    # Additionally we have to take into account the possible whitespaces in the
+    # builtin types above. Currently this is done by simple brute-forcing.
+    # To "ensure" that the builtin matching is greedy and picks up as much as
+    # possible, sort the types by their length in descending order
+    builtin_str = r"|".join(rf"(?:{t})" for t in sorted(BUILTIN_TYPES, key=len, reverse=True))
+    type_str = rf"({builtin_str}|(?:\:{{2}})?[a-zA-Z]+[a-zA-Z0-9:_]*)"
+    type_re = re.compile(type_str)
+
+    # Names can be almost anything as long as it doesn't start with a digit and
+    # doesn't contain anything fancy
+    name_str = r"([a-zA-Z_]+\w*)"
+    name_re = re.compile(name_str)
+
+    # Units are given in square brackets
+    unit_str = r"(?:\[([a-zA-Z_*\/]+\w*)\])?"
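
# Example, to make the regex fragments above concrete: a declaration such as
#   "float energy{12.5} [GeV] // measured energy"
# is matched by member_re (compiled below) with type "float", name "energy",
# default value "12.5", unit "GeV" and description "measured energy". The
# declaration itself is made up and not taken from any datamodel.
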
+    unit_re = re.compile(unit_str)
+
+    # Comments can be anything after //
+    # stripping of trailing whitespaces is done later as it is hard to do with regex
+    comment_str = r"\/\/ *(.*)"
+    # std::array declaration with some whitespace distribution freedom
+    array_str = rf" *std::array *< *{type_str} *, *([0-9]+) *>"
+
+    # Default values can be anything that is in curly braces, but not the empty
+    # default initialization, since that is what we generate in any case
+    def_val_str = r"(?:{(.+)})?"
+
+    array_re = re.compile(array_str)
+    full_array_re = re.compile(
+        rf"{array_str} *{name_str} *{def_val_str} *{unit_str} *{comment_str}"
+    )
+    member_re = re.compile(rf" *{type_str} +{name_str} *{def_val_str} *{unit_str} *{comment_str}")
+
+    # For cases where we don't require a description
+    bare_member_re = re.compile(rf" *{type_str} +{name_str} *{def_val_str} *{unit_str}")
+    bare_array_re = re.compile(rf" *{array_str} +{name_str} *{def_val_str} *{unit_str}")
+
+    @staticmethod
+    def _parse_with_regexps(string, regexps_callbacks):
+        """Parse the string using the passed regexps and corresponding callbacks that
+        take the match and return a MemberVariable from it"""
+        for rgx, callback in regexps_callbacks:
+            result = rgx.match(string)
+            if result:
+                return callback(result)
+        raise DefinitionError(
+            f"'{string}' is not a valid member definition. Check syntax of the member definition."
+        )
+
+    @staticmethod
+    def _full_array_conv(result):
+        """MemberVariable construction for array members with a docstring"""
+        typ, size, name, def_val, unit, comment = result.groups()
+        return MemberVariable(
+            name=name,
+            array_type=typ,
+            array_size=size,
+            description=comment.strip(),
+            unit=unit,
+            default_val=def_val,
+        )
+
+    @staticmethod
+    def _full_member_conv(result):
+        """MemberVariable construction for members with a docstring"""
+        klass, name, def_val, unit, comment = result.groups()
+        return MemberVariable(
+            name=name,
+            type=klass,
+            description=comment.strip(),
+            unit=unit,
+            default_val=def_val,
+        )
+
+    @staticmethod
+    def _bare_array_conv(result):
+        """MemberVariable construction for array members without docstring"""
+        typ, size, name, def_val, unit = result.groups()
+        return MemberVariable(
+            name=name, array_type=typ, array_size=size, unit=unit, default_val=def_val
+        )
+
+    @staticmethod
+    def _bare_member_conv(result):
+        """MemberVariable construction for members without docstring"""
+        klass, name, def_val, unit = result.groups()
+        return MemberVariable(name=name, type=klass, unit=unit, default_val=def_val)
+
+    def parse(self, string, require_description=True):
+        """Parse the passed string"""
+        default_matchers_cbs = [
+            (self.full_array_re, self._full_array_conv),
+            (self.member_re, self._full_member_conv),
+        ]
+
+        no_desc_matchers_cbs = [
+            (self.bare_array_re, self._bare_array_conv),
+            (self.bare_member_re, self._bare_member_conv),
+        ]
+
+        if require_description:
+            try:
+                return self._parse_with_regexps(string, default_matchers_cbs)
+            except DefinitionError:
+                # check whether we could parse this if we don't require a description and
+                # provide more details in the error if we can
+                self._parse_with_regexps(string, no_desc_matchers_cbs)
+                # pylint: disable-next=raise-missing-from
+                raise DefinitionError(
+                    f"'{string}' is not a valid member definition. "
+                    "Description comment is missing.\n"
+                    "Correct Syntax: <type> <name> // <comment>"
+                )
+
+        return self._parse_with_regexps(string, default_matchers_cbs + no_desc_matchers_cbs)
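+
+    # Usage sketch (hypothetical input, for illustration only):
+    #   MemberParser().parse("float energy [GeV] // the energy")
+    # yields a MemberVariable with type "float", name "energy", unit "GeV" and
+    # description "the energy"; without the trailing "// ..." comment the same
+    # string only parses when require_description=False.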
" + "Description comment is missing.\n" + "Correct Syntax: // " + ) + + return self._parse_with_regexps(string, default_matchers_cbs + no_desc_matchers_cbs) class ClassDefinitionValidator: - """Validate the datamodel read from the input yaml file for the most obvious - problems. - """ - # All these keys refer to datatypes only, the subset that is allowed for - # components make it possible to more easily check that in the - # _check_components method - required_datatype_keys = ( - "Description", - "Author", - ) - valid_datatype_member_keys = ( - "Members", - "VectorMembers", - "OneToOneRelations", - "OneToManyRelations", - # "TransientMembers", # not used anywhere in class generator - # "Typedefs", # not used anywhere in class generator - ) - valid_extra_datatype_keys = ( - "ExtraCode", - "MutableExtraCode" - ) - - # documented but not yet implemented - not_yet_implemented_keys = ( - "TransientMembers", - "Typedefs", - ) - - # The interface definitions need the normal datatype keys plus Types to which - # it applies and also which accessor functions to generate - required_interface_keys = required_datatype_keys + ( - "Members", - "Types" - ) - - valid_extra_code_keys = ("declaration", "implementation", "includes") - # documented but not yet implemented - not_yet_implemented_extra_code = ('declarationFile', 'implementationFile') - - @classmethod - def validate(cls, datamodel, upstream_edm=None): - """Validate the datamodel.""" - cls._check_components(datamodel, upstream_edm) - expose_pod_members = datamodel.options['exposePODMembers'] - cls._check_datatypes(datamodel, expose_pod_members, upstream_edm) - cls._check_interfaces(datamodel, upstream_edm) - - @classmethod - def _check_comp(cls, member, components, upstream_edm): - """Check if the passed member is a component defined in either the datamodel - itself or in the upstream datamodel.""" - def _from_upstream(): - return upstream_edm.components if upstream_edm else [] - - return member in components or member in _from_upstream() - - @classmethod - def _check_components(cls, datamodel, upstream_edm): - """Check the components.""" - for name, component in datamodel.components.items(): - for field in component: - if field not in ['Members', 'ExtraCode', 'Description', 'Author']: - raise DefinitionError(f"{name} defines a '{field}' field which is not allowed for a component") - - if 'ExtraCode' in component: - for key in component['ExtraCode']: - if key not in ('declaration', 'includes'): - raise DefinitionError(f"'{key}' field found in 'ExtraCode' of component '{name}'." - " Only 'declaration' and 'includes' are allowed here") - - for member in component['Members']: - is_builtin = member.is_builtin or member.is_builtin_array - is_component_array = member.is_array and cls._check_comp(member.array_type, datamodel.components, upstream_edm) - is_component = cls._check_comp(member.full_type, datamodel.components, upstream_edm) or is_component_array - - if not is_builtin and not is_component: - raise DefinitionError(f'{member.name} of component {name} is not a builtin type, ' - 'another component or one from the upstream EDM') - - @classmethod - def _check_datatypes(cls, datamodel, expose_pod_members, upstream_edm): - """Check the datatypes.""" - # Get all of the datatype names here to avoid depending on the order of - # declaration. 
NOTE: In this way also invalid classes will be considered, - # but they should hopefully be caught later - for name, definition in datamodel.datatypes.items(): - cls._check_keys(name, definition) - cls._fill_defaults(definition) - cls._check_datatype(name, definition, expose_pod_members, datamodel, upstream_edm) - - @classmethod - def _check_interfaces(cls, datamodel, upstream_edm): - """Check the interface definitions""" - all_types = list(datamodel.datatypes.keys()) - ext_types = upstream_edm.datatypes.keys() if upstream_edm else [] - all_types.extend(ext_types) - - for name, definition in datamodel.interfaces.items(): - if name in all_types: - raise DefinitionError(f"'{name}' defines an interface type with the same name as an existing datatype") - cls._check_interface_fields(name, definition) - cls._check_interface_types(name, definition, all_types) - - @classmethod - def _check_datatype(cls, classname, definition, expose_pod_members, datamodel, upstream_edm): - """Check that a datatype only defines valid types and relations.""" - cls._check_members(classname, definition.get("Members", []), expose_pod_members, datamodel, upstream_edm) - cls._check_relations(classname, definition, datamodel, upstream_edm) - - @classmethod - def _check_members(cls, classname, members, expose_pod_members, datamodel, upstream_edm): - """Check the members of a class for name clashes or undefined classes.""" - all_types = list(datamodel.components.keys()) - ext_types = upstream_edm.components.keys() if upstream_edm else [] - all_types.extend(ext_types) - - all_members = {} - for member in members: - is_builtin = member.is_builtin or member.is_builtin_array - in_definitions = member.full_type in all_types or member.array_type in all_types - - if not is_builtin and not in_definitions: - raise DefinitionError(f"'{classname}' defines member '{member.name}' of type '{member.full_type}' that is not" - " declared!") - - if member.name in all_members: - raise DefinitionError(f"'{member.name}' clashes with another member in class '{classname}', previously defined" - f" in '{all_members[member.name]}'") - - all_members[member.name] = classname - if expose_pod_members and not member.is_builtin and not member.is_array: - for sub_member in datamodel.components[member.full_type]['Members']: - if sub_member.name in all_members: - raise DefinitionError(f"'{sub_member.name}' clashes with another member name in class '{classname}'" - f" (defined in the component '{member.name}' and '{all_members[sub_member.name]}')") - - all_members[sub_member.name] = f"member '{member.name}'" - - @classmethod - def _check_relations(cls, classname, definition, datamodel, upstream_edm): - """Check the relations of a class.""" - def _valid_datatype(rel_type): - if rel_type in datamodel.datatypes or rel_type in datamodel.interfaces: - return True - if upstream_edm and (rel_type in upstream_edm.datatypes or rel_type in upstream_edm.interfaces): - return True - return False - - many_relations = definition.get("OneToManyRelations", []) - for relation in many_relations: - if not _valid_datatype(relation.full_type): - raise DefinitionError(f"'{classname}' declares a non-allowed many-relation to '{relation.full_type}'") - - one_relations = definition.get("OneToOneRelations", []) - for relation in one_relations: - if not _valid_datatype(relation.full_type): - raise DefinitionError(f"'{classname}' declares a non-allowed single-relation to '{relation.full_type}'") - - vector_members = definition.get("VectorMembers", []) - for vecmem in vector_members: - 
if not vecmem.is_builtin and not cls._check_comp(vecmem.full_type, datamodel.components, upstream_edm): - raise DefinitionError(f"'{classname}' declares a non-allowed vector member of type '{vecmem.full_type}'") - - @classmethod - def _check_keys(cls, classname, definition): - """Check the keys of a datatype.""" - allowed_keys = cls.required_datatype_keys + cls.valid_datatype_member_keys + cls.valid_extra_datatype_keys - # Give some more info for not yet implemented features - invalid_keys = [k for k in definition.keys() if k not in allowed_keys] - if invalid_keys: - not_yet_impl = [k for k in invalid_keys if k in cls.not_yet_implemented_keys] - if not_yet_impl: - not_yet_impl = f' (not yet implemented: {not_yet_impl})' - else: - not_yet_impl = '' - - raise DefinitionError(f"'{classname}' defines invalid categories: {invalid_keys}{not_yet_impl}") - - for key in cls.required_datatype_keys: - if key not in definition: - raise DefinitionError(f"'{classname}' does not define '{key}'") - - if 'ExtraCode' in definition: - extracode = definition['ExtraCode'] - invalid_keys = [k for k in extracode if k not in cls.valid_extra_code_keys] - if invalid_keys: - not_yet_impl = [k for k in invalid_keys if k in cls.not_yet_implemented_extra_code] - if not_yet_impl: - not_yet_impl = f' (not yet implemented: {not_yet_impl})' - else: - not_yet_impl = '' - - raise DefinitionError("{classname} defines invalid 'ExtraCode' categories: {invalid_keys}{not_yet_impl}") - - @classmethod - def _check_interface_fields(cls, name, definition): - """Check whether the fields of an interface definition follow the required schema""" - for key in cls.required_interface_keys: - if key not in definition: - raise DefinitionError(f"interface '{name}' does not define '{key}' field which is required") - - invalid_keys = [k for k in definition.keys() if k not in cls.required_interface_keys] - if invalid_keys: - raise DefinitionError(f"interface '{name}' defines invalid fields: {invalid_keys}") - - @classmethod - def _check_interface_types(cls, name, definition, known_datatypes): - """Check whether an interface really only uses known datatypes""" - for wrapped_type in definition["Types"]: - if wrapped_type.full_type not in known_datatypes: - raise DefinitionError(f"interface '{name}' tries to use Type '{wrapped_type}' which is not defined anywhere") - - @classmethod - def _fill_defaults(cls, definition): - """Fill some of the fields with empty defaults in order to make it easier to - handle them afterwards and not having to check every time whether they exist. - TODO: This is a rather ugly thing to do as it strongly couples all the - components (reader, validation, generator) to each other. But currently the - generator assumes that all these fields are present and would require a lot - of changes to accommodate to optionally get these. Trying to at least - encapsulate this into one place here, such that it makes it easier to remove - once the generator is more robust against missing fields + """Validate the datamodel read from the input yaml file for the most obvious + problems. 
""" - for field in cls.valid_datatype_member_keys: - if field not in definition: - definition[field] = [] - for field in cls.valid_extra_datatype_keys: - if field not in definition: - definition[field] = {} + # All these keys refer to datatypes only, the subset that is allowed for + # components make it possible to more easily check that in the + # _check_components method + required_datatype_keys = ( + "Description", + "Author", + ) + valid_datatype_member_keys = ( + "Members", + "VectorMembers", + "OneToOneRelations", + "OneToManyRelations", + # "TransientMembers", # not used anywhere in class generator + # "Typedefs", # not used anywhere in class generator + ) + valid_extra_datatype_keys = ("ExtraCode", "MutableExtraCode") + + # documented but not yet implemented + not_yet_implemented_keys = ( + "TransientMembers", + "Typedefs", + ) + + # The interface definitions need the normal datatype keys plus Types to which + # it applies and also which accessor functions to generate + required_interface_keys = required_datatype_keys + ("Members", "Types") + + valid_extra_code_keys = ("declaration", "implementation", "includes") + # documented but not yet implemented + not_yet_implemented_extra_code = ("declarationFile", "implementationFile") + + @classmethod + def validate(cls, datamodel, upstream_edm=None): + """Validate the datamodel.""" + cls._check_components(datamodel, upstream_edm) + expose_pod_members = datamodel.options["exposePODMembers"] + cls._check_datatypes(datamodel, expose_pod_members, upstream_edm) + cls._check_interfaces(datamodel, upstream_edm) + + @classmethod + def _check_comp(cls, member, components, upstream_edm): + """Check if the passed member is a component defined in either the datamodel + itself or in the upstream datamodel.""" + + def _from_upstream(): + return upstream_edm.components if upstream_edm else [] + + return member in components or member in _from_upstream() + + @classmethod + def _check_components(cls, datamodel, upstream_edm): + """Check the components.""" + for name, component in datamodel.components.items(): + for field in component: + if field not in ["Members", "ExtraCode", "Description", "Author"]: + raise DefinitionError( + f"{name} defines a '{field}' field which is not allowed for a component" + ) + + if "ExtraCode" in component: + for key in component["ExtraCode"]: + if key not in ("declaration", "includes"): + raise DefinitionError( + f"'{key}' field found in 'ExtraCode' of component '{name}'." + " Only 'declaration' and 'includes' are allowed here" + ) + + for member in component["Members"]: + is_builtin = member.is_builtin or member.is_builtin_array + is_component_array = member.is_array and cls._check_comp( + member.array_type, datamodel.components, upstream_edm + ) + is_component = ( + cls._check_comp(member.full_type, datamodel.components, upstream_edm) + or is_component_array + ) + + if not is_builtin and not is_component: + raise DefinitionError( + f"{member.name} of component {name} is not a builtin type, " + "another component or one from the upstream EDM" + ) + + @classmethod + def _check_datatypes(cls, datamodel, expose_pod_members, upstream_edm): + """Check the datatypes.""" + # Get all of the datatype names here to avoid depending on the order of + # declaration. 
NOTE: In this way also invalid classes will be considered, + # but they should hopefully be caught later + for name, definition in datamodel.datatypes.items(): + cls._check_keys(name, definition) + cls._fill_defaults(definition) + cls._check_datatype(name, definition, expose_pod_members, datamodel, upstream_edm) + + @classmethod + def _check_interfaces(cls, datamodel, upstream_edm): + """Check the interface definitions""" + all_types = list(datamodel.datatypes.keys()) + ext_types = upstream_edm.datatypes.keys() if upstream_edm else [] + all_types.extend(ext_types) + + for name, definition in datamodel.interfaces.items(): + if name in all_types: + raise DefinitionError( + f"'{name}' defines an interface type with the same name " + "as an existing datatype" + ) + cls._check_interface_fields(name, definition) + cls._check_interface_types(name, definition, all_types) + + @classmethod + def _check_datatype(cls, classname, definition, expose_pod_members, datamodel, upstream_edm): + """Check that a datatype only defines valid types and relations.""" + cls._check_members( + classname, + definition.get("Members", []), + expose_pod_members, + datamodel, + upstream_edm, + ) + cls._check_relations(classname, definition, datamodel, upstream_edm) + + @classmethod + def _check_members(cls, classname, members, expose_pod_members, datamodel, upstream_edm): + """Check the members of a class for name clashes or undefined classes.""" + all_types = list(datamodel.components.keys()) + ext_types = upstream_edm.components.keys() if upstream_edm else [] + all_types.extend(ext_types) + + all_members = {} + for member in members: + is_builtin = member.is_builtin or member.is_builtin_array + in_definitions = member.full_type in all_types or member.array_type in all_types + + if not is_builtin and not in_definitions: + raise DefinitionError( + f"'{classname}' defines member '{member.name}' of type '{member.full_type}'" + " that is not declared!" 
+                )
+
+            if member.name in all_members:
+                raise DefinitionError(
+                    f"'{member.name}' clashes with another member in class '{classname}', "
+                    f"previously defined in '{all_members[member.name]}'"
+                )
+
+            all_members[member.name] = classname
+            if expose_pod_members and not member.is_builtin and not member.is_array:
+                for sub_member in datamodel.components[member.full_type]["Members"]:
+                    if sub_member.name in all_members:
+                        raise DefinitionError(
+                            f"'{sub_member.name}' clashes with another member name in class "
+                            f"'{classname}' (defined in the component '{member.name}' "
+                            f"and '{all_members[sub_member.name]}')"
+                        )
+
+                    all_members[sub_member.name] = f"member '{member.name}'"
+
+    @classmethod
+    def _check_relations(cls, classname, definition, datamodel, upstream_edm):
+        """Check the relations of a class."""
+
+        def _valid_datatype(rel_type):
+            if rel_type in datamodel.datatypes or rel_type in datamodel.interfaces:
+                return True
+            if upstream_edm and (
+                rel_type in upstream_edm.datatypes or rel_type in upstream_edm.interfaces
+            ):
+                return True
+            return False
+
+        many_relations = definition.get("OneToManyRelations", [])
+        for relation in many_relations:
+            if not _valid_datatype(relation.full_type):
+                raise DefinitionError(
+                    f"'{classname}' declares an invalid many-relation to '{relation.full_type}'"
+                )
+
+        one_relations = definition.get("OneToOneRelations", [])
+        for relation in one_relations:
+            if not _valid_datatype(relation.full_type):
+                raise DefinitionError(
+                    f"'{classname}' declares an invalid single-relation to '{relation.full_type}'"
+                )
+
+        vector_members = definition.get("VectorMembers", [])
+        for vecmem in vector_members:
+            if not vecmem.is_builtin and not cls._check_comp(
+                vecmem.full_type, datamodel.components, upstream_edm
+            ):
+                raise DefinitionError(
+                    f"'{classname}' declares an invalid vector member of type '{vecmem.full_type}'"
+                )
+
+    @classmethod
+    def _check_keys(cls, classname, definition):
+        """Check the keys of a datatype."""
+        allowed_keys = (
+            cls.required_datatype_keys
+            + cls.valid_datatype_member_keys
+            + cls.valid_extra_datatype_keys
+        )
+        # Give some more info for not yet implemented features
+        invalid_keys = [k for k in definition.keys() if k not in allowed_keys]
+        if invalid_keys:
+            not_yet_impl = [k for k in invalid_keys if k in cls.not_yet_implemented_keys]
+            if not_yet_impl:
+                not_yet_impl = f" (not yet implemented: {not_yet_impl})"
+            else:
+                not_yet_impl = ""
+
+            raise DefinitionError(
+                f"'{classname}' defines invalid categories: {invalid_keys}{not_yet_impl}"
+            )
+
+        for key in cls.required_datatype_keys:
+            if key not in definition:
+                raise DefinitionError(f"'{classname}' does not define '{key}'")
+
+        if "ExtraCode" in definition:
+            extracode = definition["ExtraCode"]
+            invalid_keys = [k for k in extracode if k not in cls.valid_extra_code_keys]
+            if invalid_keys:
+                not_yet_impl = [k for k in invalid_keys if k in cls.not_yet_implemented_extra_code]
+                if not_yet_impl:
+                    not_yet_impl = f" (not yet implemented: {not_yet_impl})"
+                else:
+                    not_yet_impl = ""
+
+                raise DefinitionError(
+                    f"{classname} defines invalid 'ExtraCode' categories: "
+                    f"{invalid_keys}{not_yet_impl}"
+                )
+
+    @classmethod
+    def _check_interface_fields(cls, name, definition):
+        """Check whether the fields of an interface definition follow the required schema"""
+        for key in cls.required_interface_keys:
+            if key not in definition:
+                raise DefinitionError(
+                    f"interface '{name}' does not define '{key}' field which is required"
+                )
+
+        invalid_keys = [k for k in definition.keys() if k 
not in cls.required_interface_keys] + if invalid_keys: + raise DefinitionError(f"interface '{name}' defines invalid fields: {invalid_keys}") + + @classmethod + def _check_interface_types(cls, name, definition, known_datatypes): + """Check whether an interface really only uses known datatypes""" + for wrapped_type in definition["Types"]: + if wrapped_type.full_type not in known_datatypes: + raise DefinitionError( + f"interface '{name}' tries to use Type '{wrapped_type}' which is undefined" + ) + + @classmethod + def _fill_defaults(cls, definition): + """Fill some of the fields with empty defaults in order to make it easier to + handle them afterwards and not having to check every time whether they exist. + TODO: This is a rather ugly thing to do as it strongly couples all the + components (reader, validation, generator) to each other. But currently the + generator assumes that all these fields are present and would require a lot + of changes to accommodate to optionally get these. Trying to at least + encapsulate this into one place here, such that it makes it easier to remove + once the generator is more robust against missing fields + """ + for field in cls.valid_datatype_member_keys: + if field not in definition: + definition[field] = [] + + for field in cls.valid_extra_datatype_keys: + if field not in definition: + definition[field] = {} class PodioConfigReader: - """Config reader that does basic parsing of the member definitions and puts - everything into a somewhat uniform structure without doing any fancy - validation - """ - member_parser = MemberParser() - # default options - options = { - # should getters / setters be prefixed with get / set? - "getSyntax": False, - # should POD members be exposed with getters/setters in classes that have them as members? - "exposePODMembers": True, - # use subfolder when including package header files - "includeSubfolder": False, - } - - @staticmethod - def _handle_extracode(definition): - """Handle the extra code definition. Currently simply returning a copy""" - return copy.deepcopy(definition) - - @classmethod - def _read_component(cls, definition): - """Read the component and put it into an easily digestible format. + """Config reader that does basic parsing of the member definitions and puts + everything into a somewhat uniform structure without doing any fancy + validation """ - component = {} - for name, category in definition.items(): - if name == 'Members': - component['Members'] = [] - for member in definition[name]: - # for components we do not require a description in the members - component['Members'].append(cls.member_parser.parse(member, False)) - else: - component[name] = copy.deepcopy(category) - - return component - - @classmethod - def _read_datatype(cls, value): - """Read the datatype and put it into an easily digestible format""" - datatype = {} - for category, definition in value.items(): - # special handling of the member types. 
Parse them here directly - if category in ClassDefinitionValidator.valid_datatype_member_keys: - members = [] - for member in definition: - members.append(cls.member_parser.parse(member)) - datatype[category] = members - else: - datatype[category] = copy.deepcopy(definition) - - return datatype - - @classmethod - def _read_interface(cls, value): - """Read an interface definition and put it into a more easily digestible format""" - interface = {} - for category, definition in value.items(): - if category == "Members": - members = [] - for member in definition: - members.append(cls.member_parser.parse(member)) - interface["Members"] = members - elif category == "Types": - types = [] - for typ in definition: - types.append(DataType(typ)) - interface["Types"] = types - else: - interface[category] = copy.deepcopy(definition) - - return interface - - @classmethod - def parse_model(cls, model_dict, package_name, upstream_edm=None): - """Parse a model from the dictionary, e.g. read from a yaml file.""" - - if "schema_version" in model_dict: - schema_version = model_dict["schema_version"] - if int(schema_version) <= 0: - raise DefinitionError(f"schema_version has to be larger than 0 (is {schema_version})") - else: - warnings.warn("Please provide a schema_version entry. It will become mandatory. Setting it to 1 as default", - FutureWarning, stacklevel=3) - schema_version = 1 - - components = {} - if "components" in model_dict: - for klassname, value in model_dict["components"].items(): - components[klassname] = cls._read_component(value) - - datatypes = {} - if "datatypes" in model_dict: - for klassname, value in model_dict["datatypes"].items(): - datatypes[klassname] = cls._read_datatype(value) - - interfaces = {} - if "interfaces" in model_dict: - for klassname, value in model_dict["interfaces"].items(): - interfaces[klassname] = cls._read_interface(value) - - options = copy.deepcopy(cls.options) - if "options" in model_dict: - for option, value in model_dict["options"].items(): - options[option] = value - - # Normalize the includeSubfoler internally already here - if options['includeSubfolder']: - options['includeSubfolder'] = f'{package_name}/' - else: - options['includeSubfolder'] = '' - - # If this doesn't raise an exception everything should in principle work out - validator = ClassDefinitionValidator() - datamodel = DataModel(datatypes, components, interfaces, options, schema_version) - validator.validate(datamodel, upstream_edm) - return datamodel - - @classmethod - def read(cls, yamlfile, package_name, upstream_edm=None): - """Read the datamodel definition from the yamlfile.""" - with open(yamlfile, "r", encoding='utf-8') as stream: - content = yaml.load(stream, yaml.SafeLoader) - - return cls.parse_model(content, package_name, upstream_edm) + + member_parser = MemberParser() + # default options + options = { + # should getters / setters be prefixed with get / set? + "getSyntax": False, + # should POD members be exposed with getters/setters in classes that have them as members? + "exposePODMembers": True, + # use subfolder when including package header files + "includeSubfolder": False, + } + + @staticmethod + def _handle_extracode(definition): + """Handle the extra code definition. 
Currently simply returning a copy""" + return copy.deepcopy(definition) + + @classmethod + def _read_component(cls, definition): + """Read the component and put it into an easily digestible format.""" + component = {} + for name, category in definition.items(): + if name == "Members": + component["Members"] = [] + for member in definition[name]: + # for components we do not require a description in the members + component["Members"].append(cls.member_parser.parse(member, False)) + else: + component[name] = copy.deepcopy(category) + + return component + + @classmethod + def _read_datatype(cls, value): + """Read the datatype and put it into an easily digestible format""" + datatype = {} + for category, definition in value.items(): + # special handling of the member types. Parse them here directly + if category in ClassDefinitionValidator.valid_datatype_member_keys: + members = [] + for member in definition: + members.append(cls.member_parser.parse(member)) + datatype[category] = members + else: + datatype[category] = copy.deepcopy(definition) + + return datatype + + @classmethod + def _read_interface(cls, value): + """Read an interface definition and put it into a more easily digestible format""" + interface = {} + for category, definition in value.items(): + if category == "Members": + members = [] + for member in definition: + members.append(cls.member_parser.parse(member)) + interface["Members"] = members + elif category == "Types": + types = [] + for typ in definition: + types.append(DataType(typ)) + interface["Types"] = types + else: + interface[category] = copy.deepcopy(definition) + + return interface + + @classmethod + def parse_model(cls, model_dict, package_name, upstream_edm=None): + """Parse a model from the dictionary, e.g. read from a yaml file.""" + + if "schema_version" in model_dict: + schema_version = model_dict["schema_version"] + if int(schema_version) <= 0: + raise DefinitionError( + f"schema_version has to be larger than 0 (is {schema_version})" + ) + else: + warnings.warn( + "Please provide a schema_version entry. It will become mandatory. 
" + "Setting it to 1 as default", + FutureWarning, + stacklevel=3, + ) + schema_version = 1 + + components = {} + if "components" in model_dict: + for klassname, value in model_dict["components"].items(): + components[klassname] = cls._read_component(value) + + datatypes = {} + if "datatypes" in model_dict: + for klassname, value in model_dict["datatypes"].items(): + datatypes[klassname] = cls._read_datatype(value) + + interfaces = {} + if "interfaces" in model_dict: + for klassname, value in model_dict["interfaces"].items(): + interfaces[klassname] = cls._read_interface(value) + + options = copy.deepcopy(cls.options) + if "options" in model_dict: + for option, value in model_dict["options"].items(): + options[option] = value + + # Normalize the includeSubfoler internally already here + if options["includeSubfolder"]: + options["includeSubfolder"] = f"{package_name}/" + else: + options["includeSubfolder"] = "" + + # If this doesn't raise an exception everything should in principle work out + validator = ClassDefinitionValidator() + datamodel = DataModel(datatypes, components, interfaces, options, schema_version) + validator.validate(datamodel, upstream_edm) + return datamodel + + @classmethod + def read(cls, yamlfile, package_name, upstream_edm=None): + """Read the datamodel definition from the yamlfile.""" + with open(yamlfile, "r", encoding="utf-8") as stream: + content = yaml.load(stream, yaml.SafeLoader) + + return cls.parse_model(content, package_name, upstream_edm) diff --git a/python/podio_gen/test_ClassDefinitionValidator.py b/python/podio_gen/test_ClassDefinitionValidator.py index 8cd80b8b9..cf271e2ef 100644 --- a/python/podio_gen/test_ClassDefinitionValidator.py +++ b/python/podio_gen/test_ClassDefinitionValidator.py @@ -8,455 +8,597 @@ import unittest from copy import deepcopy -from podio_gen.podio_config_reader import ClassDefinitionValidator, MemberVariable, DefinitionError +from podio_gen.podio_config_reader import ( + ClassDefinitionValidator, + MemberVariable, + DefinitionError, +) from podio_gen.generator_utils import DataModel, DataType def make_dm(components, datatypes, interfaces=None, options=None): - """Small helper function to turn things into a datamodel dict as expected by - the validator""" - return DataModel(datatypes, components, interfaces, options) + """Small helper function to turn things into a datamodel dict as expected by + the validator""" + return DataModel(datatypes, components, interfaces, options) class ClassDefinitionValidatorTest(unittest.TestCase): # pylint: disable=too-many-public-methods - """Unit tests for the ClassDefinitionValidator""" - def setUp(self): - valid_component_members = [ - MemberVariable(type='int', name='anInt'), - MemberVariable(type='float', name='aFloat'), - MemberVariable(array_type='int', array_size='4', name='anArray') + """Unit tests for the ClassDefinitionValidator""" + + def setUp(self): + valid_component_members = [ + MemberVariable(type="int", name="anInt"), + MemberVariable(type="float", name="aFloat"), + MemberVariable(array_type="int", array_size="4", name="anArray"), ] - self.valid_component = { - 'Component': { - 'Members': valid_component_members, - 'ExtraCode': { - 'includes': '#include "someFancyHeader.h"', - 'declaration': 'we do not validate if this is valid c++' - } + self.valid_component = { + "Component": { + "Members": valid_component_members, + "ExtraCode": { + "includes": '#include "someFancyHeader.h"', + "declaration": "we do not validate if this is valid c++", + }, } } - valid_datatype_members = [ - 
MemberVariable(type='float', name='energy', description='energy [GeV]'), - MemberVariable(array_type='int', array_size='5', name='anArray', description='some Array') + valid_datatype_members = [ + MemberVariable(type="float", name="energy", description="energy [GeV]"), + MemberVariable( + array_type="int", + array_size="5", + name="anArray", + description="some Array", + ), ] - self.valid_datatype = { - 'DataType': { - 'Author': 'Mr. Bean', - 'Description': 'I am merely here for a test', - 'Members': valid_datatype_members, - 'ExtraCode': { - 'includes': '#include not checking for valid c++', - 'declaration': 'not necessarily valid c++', - 'implementation': 'still not checked for c++ validity', + self.valid_datatype = { + "DataType": { + "Author": "Mr. Bean", + "Description": "I am merely here for a test", + "Members": valid_datatype_members, + "ExtraCode": { + "includes": "#include not checking for valid c++", + "declaration": "not necessarily valid c++", + "implementation": "still not checked for c++ validity", + }, + "MutableExtraCode": { + "declaration": "also not checked for valid c++", + "implementation": "nothing has changed", + "includes": "#include this will appear in both includes!", }, - 'MutableExtraCode': { - 'declaration': 'also not checked for valid c++', - 'implementation': 'nothing has changed', - 'includes': '#include this will appear in both includes!' - } } } - self.valid_interface = { - "InterfaceType": { - "Author": "Karma Chameleon", - "Description": "I can be many things but only one at a time", - "Members": valid_datatype_members, - "Types": [DataType("DataType")] + self.valid_interface = { + "InterfaceType": { + "Author": "Karma Chameleon", + "Description": "I can be many things but only one at a time", + "Members": valid_datatype_members, + "Types": [DataType("DataType")], } } - # The default options that should be used for validation - self.def_opts = {'exposePODMembers': False} - - self.validator = ClassDefinitionValidator() - self.validate = self.validator.validate - - def _assert_no_exception(self, exceptions, message, func, *args, **kwargs): - """Helper function to assert a function does not raise any of the specific exceptions""" - try: - func(*args, **kwargs) - except exceptions: - self.fail(message.format(func.__name__)) - - def test_component_invalid_extra_code(self): - component = deepcopy(self.valid_component) - component['Component']['ExtraCode']['const_declaration'] = '// not even valid c++ passes here' - with self.assertRaises(DefinitionError): - self.validate(make_dm(component, {}), False) - - component = deepcopy(self.valid_component) - component['Component']['ExtraCode']['const_implementation'] = '// it does not either here' - with self.assertRaises(DefinitionError): - self.validate(make_dm(component, {}), False) - - def test_component_invalid_member(self): - # non-builtin type - component = deepcopy(self.valid_component) - component['Component']['Members'].append(MemberVariable(type='NonBuiltinType', name='foo')) - with self.assertRaises(DefinitionError): - self.validate(make_dm(component, {}), False) - - # non-builtin array that is also not in another component - component = deepcopy(self.valid_component) - component['Component']['Members'].append( - MemberVariable(array_type='NonBuiltinType', array_size=3, name='complexArray')) - with self.assertRaises(DefinitionError): - self.validate(make_dm(component, {}), False) - - def test_component_valid_members(self): - self._assert_no_exception(DefinitionError, '{} should not raise for a valid 
component', - self.validate, make_dm(self.valid_component, {}), False) - - components = deepcopy(self.valid_component) - components['SecondComponent'] = { - 'Members': [MemberVariable(array_type='Component', array_size='10', name='referToOtheComponent')] + # The default options that should be used for validation + self.def_opts = {"exposePODMembers": False} + + self.validator = ClassDefinitionValidator() + self.validate = self.validator.validate + + def _assert_no_exception(self, exceptions, message, func, *args, **kwargs): + """Helper function to assert a function does not raise any of the specific exceptions""" + try: + func(*args, **kwargs) + except exceptions: + self.fail(message.format(func.__name__)) + + def test_component_invalid_extra_code(self): + component = deepcopy(self.valid_component) + component["Component"]["ExtraCode"][ + "const_declaration" + ] = "// not even valid c++ passes here" + with self.assertRaises(DefinitionError): + self.validate(make_dm(component, {}), False) + + component = deepcopy(self.valid_component) + component["Component"]["ExtraCode"]["const_implementation"] = "// it does not either here" + with self.assertRaises(DefinitionError): + self.validate(make_dm(component, {}), False) + + def test_component_invalid_member(self): + # non-builtin type + component = deepcopy(self.valid_component) + component["Component"]["Members"].append(MemberVariable(type="NonBuiltinType", name="foo")) + with self.assertRaises(DefinitionError): + self.validate(make_dm(component, {}), False) + + # non-builtin array that is also not in another component + component = deepcopy(self.valid_component) + component["Component"]["Members"].append( + MemberVariable(array_type="NonBuiltinType", array_size=3, name="complexArray") + ) + with self.assertRaises(DefinitionError): + self.validate(make_dm(component, {}), False) + + def test_component_valid_members(self): + self._assert_no_exception( + DefinitionError, + "{} should not raise for a valid component", + self.validate, + make_dm(self.valid_component, {}), + False, + ) + + components = deepcopy(self.valid_component) + components["SecondComponent"] = { + "Members": [ + MemberVariable( + array_type="Component", array_size="10", name="referToOtheComponent" + ) + ] } - self._assert_no_exception(DefinitionError, '{} should allow for component members in components', - self.validate, make_dm(components, {}), False) - - def test_datatype_valid_members(self): - self._assert_no_exception(DefinitionError, '{} should not raise for a valid datatype', - self.validate, make_dm({}, self.valid_datatype, options=self.def_opts)) - - # things should still work if we add a component member - self.valid_datatype['DataType']['Members'].append(MemberVariable(type='Component', name='comp')) - self._assert_no_exception(DefinitionError, '{} should allow for members that are components', - self.validate, - make_dm(self.valid_component, self.valid_datatype, options=self.def_opts)) - - # also when we add an array of components - self.valid_datatype['DataType']['Members'].append(MemberVariable(array_type='Component', - array_size='3', - name='arrComp')) - self._assert_no_exception(DefinitionError, '{} should allow for arrays of components as members', - self.validate, - make_dm(self.valid_component, self.valid_datatype, options=self.def_opts)) - - # pod members can be redefined if they are note exposed - self.valid_datatype['DataType']['Members'].append(MemberVariable(type='double', name='aFloat')) - self._assert_no_exception(DefinitionError, - '{} should allow 
for re-use of component names if the components are not exposed',
-                              self.validate,
-                              make_dm(self.valid_component, self.valid_datatype, options=self.def_opts))
-
-    datatype = {
-        'DataTypeWithoutMembers': {
-            'Author': 'Anonymous', 'Description': 'A pretty useless Datatype as it is'
+        self._assert_no_exception(
+            DefinitionError,
+            "{} should not raise for a valid datatype",
+            self.validate,
+            make_dm({}, self.valid_datatype, options=self.def_opts),
+        )
+
+        # things should still work if we add a component member
+        self.valid_datatype["DataType"]["Members"].append(
+            MemberVariable(type="Component", name="comp")
+        )
+        self._assert_no_exception(
+            DefinitionError,
+            "{} should allow for members that are components",
+            self.validate,
+            make_dm(self.valid_component, self.valid_datatype, options=self.def_opts),
+        )
+
+        # also when we add an array of components
+        self.valid_datatype["DataType"]["Members"].append(
+            MemberVariable(array_type="Component", array_size="3", name="arrComp")
+        )
+        self._assert_no_exception(
+            DefinitionError,
+            "{} should allow for arrays of components as members",
+            self.validate,
+            make_dm(self.valid_component, self.valid_datatype, options=self.def_opts),
+        )
+
+        # pod members can be redefined if they are not exposed
+        self.valid_datatype["DataType"]["Members"].append(
+            MemberVariable(type="double", name="aFloat")
+        )
+        self._assert_no_exception(
+            DefinitionError,
+            "{} should allow for re-use of component names if the components are not exposed",
+            self.validate,
+            make_dm(self.valid_component, self.valid_datatype, options=self.def_opts),
+        )
+
+        datatype = {
+            "DataTypeWithoutMembers": {
+                "Author": "Anonymous",
+                "Description": "A pretty useless Datatype as it is",
            }
        }
-    self._assert_no_exception(DefinitionError, '{} should allow for almost empty datatypes',
-                              self.validate, make_dm({}, datatype, options=self.def_opts))
-
-  def test_datatype_invalid_definitions(self):
-    for required in ('Author', 'Description'):
-      datatype = deepcopy(self.valid_datatype)
-      del datatype['DataType'][required]
-      with self.assertRaises(DefinitionError):
-        self.validate(make_dm({}, datatype), False)
-
-    datatype = deepcopy(self.valid_datatype)
-    datatype['DataType']['ExtraCode']['invalid_extracode'] = 'an invalid entry to the ExtraCode'
-    with self.assertRaises(DefinitionError):
-      self.validate(make_dm({}, datatype), False)
-
-    datatype = deepcopy(self.valid_datatype)
-    datatype['InvalidCategory'] = {'key': 'invalid value'}
-    with self.assertRaises(DefinitionError):
-      self.validate(make_dm({}, datatype), False)
-
-  def test_datatype_invalid_members(self):
-    datatype = deepcopy(self.valid_datatype)
-    datatype['DataType']['Members'].append(MemberVariable(type='NonDeclaredType', name='foo'))
-    with self.assertRaises(DefinitionError):
-      self.validate(make_dm({}, datatype, self.def_opts))
-
-    datatype = deepcopy(self.valid_datatype)
-    datatype['DataType']['Members'].append(MemberVariable(type='float', name='definedTwice'))
-    datatype['DataType']['Members'].append(MemberVariable(type='int', name='definedTwice'))
-    with self.assertRaises(DefinitionError):
-      self.validate(make_dm({}, datatype, self.def_opts))
-
-    # Re-definition of a member present in a component and pod members are exposed
-    datatype = deepcopy(self.valid_datatype)
-    
datatype['DataType']['Members'].append(MemberVariable(type='Component', name='aComponent')) - datatype['DataType']['Members'].append(MemberVariable(type='float', name='aFloat')) - with self.assertRaises(DefinitionError): - self.validate(make_dm(self.valid_component, datatype, {'exposePODMembers': True})) - - datatype = deepcopy(self.valid_datatype) - datatype['AnotherType'] = { - 'Author': 'Avril L.', - 'Description': 'I\'m just a datatype', + self._assert_no_exception( + DefinitionError, + "{} should allow for almost empty datatypes", + self.validate, + make_dm({}, datatype, options=self.def_opts), + ) + + def test_datatype_invalid_definitions(self): + for required in ("Author", "Description"): + datatype = deepcopy(self.valid_datatype) + del datatype["DataType"][required] + with self.assertRaises(DefinitionError): + self.validate(make_dm({}, datatype), False) + + datatype = deepcopy(self.valid_datatype) + datatype["DataType"]["ExtraCode"][ + "invalid_extracode" + ] = "an invalid entry to the ExtraCode" + with self.assertRaises(DefinitionError): + self.validate(make_dm({}, datatype), False) + + datatype = deepcopy(self.valid_datatype) + datatype["InvalidCategory"] = {"key": "invalid value"} + with self.assertRaises(DefinitionError): + self.validate(make_dm({}, datatype), False) + + def test_datatype_invalid_members(self): + datatype = deepcopy(self.valid_datatype) + datatype["DataType"]["Members"].append(MemberVariable(type="NonDeclaredType", name="foo")) + with self.assertRaises(DefinitionError): + self.validate(make_dm({}, datatype, self.def_opts)) + + datatype = deepcopy(self.valid_datatype) + datatype["DataType"]["Members"].append(MemberVariable(type="float", name="definedTwice")) + datatype["DataType"]["Members"].append(MemberVariable(type="int", name="definedTwice")) + with self.assertRaises(DefinitionError): + self.validate(make_dm({}, datatype, self.def_opts)) + + # Re-definition of a member present in a component and pod members are exposed + datatype = deepcopy(self.valid_datatype) + datatype["DataType"]["Members"].append(MemberVariable(type="Component", name="aComponent")) + datatype["DataType"]["Members"].append(MemberVariable(type="float", name="aFloat")) + with self.assertRaises(DefinitionError): + self.validate(make_dm(self.valid_component, datatype, {"exposePODMembers": True})) + + datatype = deepcopy(self.valid_datatype) + datatype["AnotherType"] = { + "Author": "Avril L.", + "Description": "I'm just a datatype", } - datatype['DataType']['Members'].append( - MemberVariable(type='AnotherType', name='impossibleType', - description='Another datatype cannot be a member')) - with self.assertRaises(DefinitionError): - self.validate(make_dm(self.valid_component, datatype, self.def_opts)) - - def _test_datatype_valid_relations(self, rel_type): - self.valid_datatype['DataType'][rel_type] = [ - MemberVariable(type='DataType', name='selfRelation') + datatype["DataType"]["Members"].append( + MemberVariable( + type="AnotherType", + name="impossibleType", + description="Another datatype cannot be a member", + ) + ) + with self.assertRaises(DefinitionError): + self.validate(make_dm(self.valid_component, datatype, self.def_opts)) + + def _test_datatype_valid_relations(self, rel_type): + self.valid_datatype["DataType"][rel_type] = [ + MemberVariable(type="DataType", name="selfRelation") ] - self._assert_no_exception(DefinitionError, - '{} should allow for relations of datatypes to themselves', - self.validate, make_dm({}, self.valid_datatype), False) - - self.valid_datatype['BlackKnight'] 
= { - 'Author': 'John Cleese', - 'Description': 'Tis but a scratch', - 'Members': [MemberVariable(type='int', name='counter', description='number of arms')], - rel_type: [MemberVariable(type='DataType', name='relation', description='soo many relations')] + self._assert_no_exception( + DefinitionError, + "{} should allow for relations of datatypes to themselves", + self.validate, + make_dm({}, self.valid_datatype), + False, + ) + + self.valid_datatype["BlackKnight"] = { + "Author": "John Cleese", + "Description": "Tis but a scratch", + "Members": [MemberVariable(type="int", name="counter", description="number of arms")], + rel_type: [ + MemberVariable(type="DataType", name="relation", description="soo many relations") + ], } - self._assert_no_exception(DefinitionError, '{} should validate a valid relation', - self.validate, make_dm(self.valid_component, self.valid_datatype), False) - - def test_datatype_valid_many_relations(self): - self._test_datatype_valid_relations('OneToManyRelations') + self._assert_no_exception( + DefinitionError, + "{} should validate a valid relation", + self.validate, + make_dm(self.valid_component, self.valid_datatype), + False, + ) + + def test_datatype_valid_many_relations(self): + self._test_datatype_valid_relations("OneToManyRelations") + + def test_datatype_valid_single_relations(self): + self._test_datatype_valid_relations("OneToOneRelations") + + def _test_datatype_invalid_relations(self, rel_type): + datatype = deepcopy(self.valid_datatype) + datatype["DataType"][rel_type] = [MemberVariable(type="NonExistentDataType", name="aName")] + with self.assertRaises(DefinitionError): + self.validate(make_dm({}, datatype), False) + + datatype = deepcopy(self.valid_datatype) + datatype["DataType"][rel_type] = [ + MemberVariable(type="Component", name="componentRelation") + ] + with self.assertRaises(DefinitionError): + self.validate(make_dm(self.valid_component, datatype), False) - def test_datatype_valid_single_relations(self): - self._test_datatype_valid_relations('OneToOneRelations') + datatype = deepcopy(self.valid_datatype) + datatype["DataType"][rel_type] = [ + MemberVariable(array_type="int", array_size="42", name="arrayRelation") + ] + with self.assertRaises(DefinitionError): + self.validate(make_dm({}, datatype), False) - def _test_datatype_invalid_relations(self, rel_type): - datatype = deepcopy(self.valid_datatype) - datatype['DataType'][rel_type] = [MemberVariable(type='NonExistentDataType', - name='aName')] - with self.assertRaises(DefinitionError): - self.validate(make_dm({}, datatype), False) + def test_datatype_invalid_many_relations(self): + self._test_datatype_invalid_relations("OneToManyRelations") - datatype = deepcopy(self.valid_datatype) - datatype['DataType'][rel_type] = [MemberVariable(type='Component', - name='componentRelation')] - with self.assertRaises(DefinitionError): - self.validate(make_dm(self.valid_component, datatype), False) + def test_datatype_invalid_single_relations(self): + self._test_datatype_invalid_relations("OneToOneRelations") - datatype = deepcopy(self.valid_datatype) - datatype['DataType'][rel_type] = [ - MemberVariable(array_type='int', array_size='42', name='arrayRelation') + def test_datatype_valid_vector_members(self): + self.valid_datatype["DataType"]["VectorMembers"] = [ + MemberVariable(type="int", name="someInt") ] - with self.assertRaises(DefinitionError): - self.validate(make_dm({}, datatype), False) - - def test_datatype_invalid_many_relations(self): - 
self._test_datatype_invalid_relations('OneToManyRelations') - - def test_datatype_invalid_single_relations(self): - self._test_datatype_invalid_relations('OneToOneRelations') - - def test_datatype_valid_vector_members(self): - self.valid_datatype['DataType']['VectorMembers'] = [MemberVariable(type='int', name='someInt')] - self._assert_no_exception(DefinitionError, - '{} should validate builtin VectorMembers', - self.validate, make_dm({}, self.valid_datatype), False) - - self.valid_datatype['DataType']['VectorMembers'] = [MemberVariable(type='Component', name='components')] - self._assert_no_exception(DefinitionError, - '{} should validate component VectorMembers', - self.validate, make_dm(self.valid_component, self.valid_datatype), False) - - def test_datatype_invalid_vector_members(self): - datatype = deepcopy(self.valid_datatype) - datatype['DataType']['VectorMembers'] = [MemberVariable(type='DataType', name='invalid')] - with self.assertRaises(DefinitionError): - self.validate(make_dm({}, datatype), False) - - datatype['Brian'] = { - 'Author': 'Graham Chapman', - 'Description': 'Not the messiah, a very naughty boy', - 'VectorMembers': [ - MemberVariable(type='DataType', name='invalid', - description='also non-self relations are not allowed') - ] + self._assert_no_exception( + DefinitionError, + "{} should validate builtin VectorMembers", + self.validate, + make_dm({}, self.valid_datatype), + False, + ) + + self.valid_datatype["DataType"]["VectorMembers"] = [ + MemberVariable(type="Component", name="components") + ] + self._assert_no_exception( + DefinitionError, + "{} should validate component VectorMembers", + self.validate, + make_dm(self.valid_component, self.valid_datatype), + False, + ) + + def test_datatype_invalid_vector_members(self): + datatype = deepcopy(self.valid_datatype) + datatype["DataType"]["VectorMembers"] = [MemberVariable(type="DataType", name="invalid")] + with self.assertRaises(DefinitionError): + self.validate(make_dm({}, datatype), False) + + datatype["Brian"] = { + "Author": "Graham Chapman", + "Description": "Not the messiah, a very naughty boy", + "VectorMembers": [ + MemberVariable( + type="DataType", + name="invalid", + description="also non-self relations are not allowed", + ) + ], } - with self.assertRaises(DefinitionError): - self.validate(make_dm({}, datatype), False) - - datatype = deepcopy(self.valid_datatype) - datatype['DataType']['VectorMembers'] = [ - MemberVariable(type='Component', name='component', - description='not working because component will not be part of the datamodel we pass')] - with self.assertRaises(DefinitionError): - self.validate(make_dm({}, datatype), False) - - def test_component_valid_upstream(self): - """Test that a component from an upstream datamodel passes here""" - component = { - 'DownstreamComponent': { - 'Members': [ - MemberVariable(type='Component', name='UpstreamComponent'), - MemberVariable(array_type='Component', array_size='42', name='upstreamArray') + with self.assertRaises(DefinitionError): + self.validate(make_dm({}, datatype), False) + + datatype = deepcopy(self.valid_datatype) + datatype["DataType"]["VectorMembers"] = [ + MemberVariable( + type="Component", + name="component", + description="not working because component is not part of the datamodel we pass", + ) + ] + with self.assertRaises(DefinitionError): + self.validate(make_dm({}, datatype), False) + + def test_component_valid_upstream(self): + """Test that a component from an upstream datamodel passes here""" + component = { + "DownstreamComponent": { 
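+                # both members below use "Component", which is only defined in
+                # the upstream datamodel that is passed alongside this model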
+ "Members": [ + MemberVariable(type="Component", name="UpstreamComponent"), + MemberVariable(array_type="Component", array_size="42", name="upstreamArray"), ] } } - upstream_dm = make_dm(self.valid_component, {}) - - self._assert_no_exception(DefinitionError, '{} should allow to use upstream components in components', - self.validate, make_dm(component, {}, options=self.def_opts), upstream_dm) - - def test_component_invalid_upstream(self): - """Test that a component does not pass if it is not available upstream or in the - current definition""" - # Valid non-array component, invalid array - component = { - 'DownstreamComponent': { - 'Members': [ - MemberVariable(type='Component', name='UpstreamComponent'), - MemberVariable(array_type='NotAvailComponent', array_size='42', name='upstreamArray') + upstream_dm = make_dm(self.valid_component, {}) + + self._assert_no_exception( + DefinitionError, + "{} should allow to use upstream components in components", + self.validate, + make_dm(component, {}, options=self.def_opts), + upstream_dm, + ) + + def test_component_invalid_upstream(self): + """Test that a component does not pass if it is not available upstream or in the + current definition""" + # Valid non-array component, invalid array + component = { + "DownstreamComponent": { + "Members": [ + MemberVariable(type="Component", name="UpstreamComponent"), + MemberVariable( + array_type="NotAvailComponent", + array_size="42", + name="upstreamArray", + ), ] } } - upstream_dm = make_dm(self.valid_component, {}) + upstream_dm = make_dm(self.valid_component, {}) - with self.assertRaises(DefinitionError): - self.validate(make_dm(component, {}, self.def_opts), upstream_dm) + with self.assertRaises(DefinitionError): + self.validate(make_dm(component, {}, self.def_opts), upstream_dm) - # invalid non-array component, valid array - component = { - 'DownstreamComponent': { - 'Members': [ - MemberVariable(type='NotAvailComponent', name='UpstreamComponent'), - MemberVariable(array_type='Component', array_size='42', name='upstreamArray') + # invalid non-array component, valid array + component = { + "DownstreamComponent": { + "Members": [ + MemberVariable(type="NotAvailComponent", name="UpstreamComponent"), + MemberVariable(array_type="Component", array_size="42", name="upstreamArray"), ] } } - with self.assertRaises(DefinitionError): - self.validate(make_dm(component, {}, self.def_opts), upstream_dm) - - def test_datatype_valid_upstream(self): - """Test that a datatype from an upstream datamodel passes here""" - datatype = { - 'DownstreamDatatype': { - 'Members': [ - MemberVariable(type='Component', name='UpstreamComponent', description='upstream component') + with self.assertRaises(DefinitionError): + self.validate(make_dm(component, {}, self.def_opts), upstream_dm) + + def test_datatype_valid_upstream(self): + """Test that a datatype from an upstream datamodel passes here""" + datatype = { + "DownstreamDatatype": { + "Members": [ + MemberVariable( + type="Component", + name="UpstreamComponent", + description="upstream component", + ) ], - 'Description': 'A datatype with upstream components and relations', - 'Author': 'Sophisticated datamodel authors', - 'OneToOneRelations': [ - MemberVariable(type='DataType', name='upSingleRel', description='upstream single relation') + "Description": "A datatype with upstream components and relations", + "Author": "Sophisticated datamodel authors", + "OneToOneRelations": [ + MemberVariable( + type="DataType", + name="upSingleRel", + description="upstream single relation", + 
)
+            ],
-      'OneToManyRelations': [
-          MemberVariable(type='DataType', name='upManyRel', description='upstream many relation')
+            "OneToManyRelations": [
+                MemberVariable(
+                    type="DataType",
+                    name="upManyRel",
+                    description="upstream many relation",
+                )
+            ],
+            "VectorMembers": [
+                MemberVariable(
+                    type="Component",
+                    name="upVector",
+                    description="upstream component as vector member",
+                )
             ],
-      'VectorMembers': [
-          MemberVariable(type='Component', name='upVector', description='upstream component as vector member')
-          ]
         }
     }
-    upstream_dm = make_dm(self.valid_component, self.valid_datatype, self.def_opts)
-    self._assert_no_exception(DefinitionError, '{} should allow to use upstream datatypes and components',
-                              self.validate, make_dm({}, datatype, {}, self.def_opts), upstream_dm)
-
-  def test_datatype_invalid_upstream(self):
-    """Test that datatypes that are not from upstream cannot be used"""
-    basetype = {
-        'DsType': {
-            'Author': 'Less sophisticated datamodel authors',
-            'Description': 'A datatype trying to use non-existent upstream content'
+        upstream_dm = make_dm(self.valid_component, self.valid_datatype, self.def_opts)
+        self._assert_no_exception(
+            DefinitionError,
+            "{} should allow to use upstream datatypes and components",
+            self.validate,
+            make_dm({}, datatype, {}, self.def_opts),
+            upstream_dm,
+        )
+
+    def test_datatype_invalid_upstream(self):
+        """Test that datatypes that are not from upstream cannot be used"""
+        basetype = {
+            "DsType": {
+                "Author": "Less sophisticated datamodel authors",
+                "Description": "A datatype trying to use non-existent upstream content",
             }
         }
-    upstream_dm = make_dm(self.valid_component, self.valid_datatype, {}, self.def_opts)
-
-    # Check for invalid members
-    dtype = deepcopy(basetype)
-    dtype['DsType']['Members'] = [MemberVariable(type='InvalidType', name='foo', description='non existent upstream')]
-    with self.assertRaises(DefinitionError):
-      self.validate(make_dm({}, dtype, options=self.def_opts), upstream_dm)
-
-    # Check relations
-    dtype = deepcopy(basetype)
-    dtype['DsType']['OneToOneRelations'] = [MemberVariable(type='InvalidType', name='foo', description='invalid')]
-    with self.assertRaises(DefinitionError):
-      self.validate(make_dm({}, dtype, options=self.def_opts), upstream_dm)
-
-    dtype = deepcopy(basetype)
-    dtype['DsType']['OneToManyRelations'] = [MemberVariable(type='InvalidType', name='foo', description='invalid')]
-    with self.assertRaises(DefinitionError):
-      self.validate(make_dm({}, dtype, options=self.def_opts), upstream_dm)
-
-    # vector members
-    dtype = deepcopy(basetype)
-    dtype['DsType']['VectorMembers'] = [MemberVariable(type='InvalidType', name='foo', description='invalid')]
-    with self.assertRaises(DefinitionError):
-      self.validate(make_dm({}, dtype, options=self.def_opts), upstream_dm)
-
-  def test_interface_valid_def(self):
-    """Make sure that a valid interface definition inside a valid datamodel passes without exceptions"""
-    self._assert_no_exception(DefinitionError, '{} should not raise for a valid interface type',
-                              self.validate, make_dm({}, self.valid_datatype, self.valid_interface), False)
-
-  def test_interface_invalid_fields(self):
-    """Make sure that interface definitions do not contain any superfluous fields"""
-    for inv_field in ['OneToManyRelations', 'VectorMembers', 'OneToOneRelations']:
-      interface = deepcopy(self.valid_interface)
-      interface['InterfaceType'][inv_field] = ['An invalid field']
-      with self.assertRaises(DefinitionError):
-        self.validate(make_dm({}, self.valid_datatype, interface), False)
-
-  def test_interface_missing_fields(self):
-    """Make sure that interfaces have all the required types when they pass validation"""
-    for req in ('Author', 'Description', 'Members', 'Types'):
-      int_type = deepcopy(self.valid_interface)
-      del int_type['InterfaceType'][req]
-      with self.assertRaises(DefinitionError):
-        self.validate(make_dm({}, self.valid_datatype, int_type), False)
-
-  def test_interface_only_defined_datatypes(self):
-    """Make sure that the interface only uses defined datatypes"""
-    int_type = deepcopy(self.valid_interface)
-    int_type['InterfaceType']['Types'].append(DataType('UndefinedType'))
-    with self.assertRaises(DefinitionError):
-      self.validate(make_dm({}, self.valid_datatype, int_type), False)
-
-    int_type = deepcopy(self.valid_interface)
-    int_type['InterfaceType']['Types'].append(DataType('Component'))
-
-    int_type = deepcopy(self.valid_interface)
-    int_type['InterfaceType']['Types'].append(DataType('float'))
-    with self.assertRaises(DefinitionError):
-      self.validate(make_dm({}, self.valid_datatype, int_type), False)
-
-  def test_interface_no_redefining_datatype(self):
-    """Make sure that there is no datatype already with the same name"""
-    int_type = {
-        'DataType': {
-            'Author': 'Copycat',
-            'Description': 'I shall not redefine datatypes as interfaces',
-            'Members': [],
-            'Types': []
+        upstream_dm = make_dm(self.valid_component, self.valid_datatype, {}, self.def_opts)
+
+        # Check for invalid members
+        dtype = deepcopy(basetype)
+        dtype["DsType"]["Members"] = [
+            MemberVariable(type="InvalidType", name="foo", description="non existent upstream")
+        ]
+        with self.assertRaises(DefinitionError):
+            self.validate(make_dm({}, dtype, options=self.def_opts), upstream_dm)
+
+        # Check relations
+        dtype = deepcopy(basetype)
+        dtype["DsType"]["OneToOneRelations"] = [
+            MemberVariable(type="InvalidType", name="foo", description="invalid")
+        ]
+        with self.assertRaises(DefinitionError):
+            self.validate(make_dm({}, dtype, options=self.def_opts), upstream_dm)
+
+        dtype = deepcopy(basetype)
+        dtype["DsType"]["OneToManyRelations"] = [
+            MemberVariable(type="InvalidType", name="foo", description="invalid")
+        ]
+        with self.assertRaises(DefinitionError):
+            self.validate(make_dm({}, dtype, options=self.def_opts), upstream_dm)
+
+        # vector members
+        dtype = deepcopy(basetype)
+        dtype["DsType"]["VectorMembers"] = [
+            MemberVariable(type="InvalidType", name="foo", description="invalid")
+        ]
+        with self.assertRaises(DefinitionError):
+            self.validate(make_dm({}, dtype, options=self.def_opts), upstream_dm)
+
+    def test_interface_valid_def(self):
+        """Make sure that a valid interface definition inside a valid datamodel
+        passes without exceptions"""
+        self._assert_no_exception(
+            DefinitionError,
+            "{} should not raise for a valid interface type",
+            self.validate,
+            make_dm({}, self.valid_datatype, self.valid_interface),
+            False,
+        )
+
+    def test_interface_invalid_fields(self):
+        """Make sure that interface definitions do not contain any superfluous fields"""
+        for inv_field in ["OneToManyRelations", "VectorMembers", "OneToOneRelations"]:
+            interface = deepcopy(self.valid_interface)
+            interface["InterfaceType"][inv_field] = ["An invalid field"]
+            with self.assertRaises(DefinitionError):
+                self.validate(make_dm({}, self.valid_datatype, interface), False)
+
+    def test_interface_missing_fields(self):
+        """Make sure that interfaces have all the required types when they pass validation"""
+        for req in ("Author", "Description", "Members", "Types"):
+            int_type = deepcopy(self.valid_interface)
+            del int_type["InterfaceType"][req]
+            with self.assertRaises(DefinitionError):
+                self.validate(make_dm({}, self.valid_datatype, int_type), False)
+
+    def test_interface_only_defined_datatypes(self):
+        """Make sure that the interface only uses defined datatypes"""
+        int_type = deepcopy(self.valid_interface)
+        int_type["InterfaceType"]["Types"].append(DataType("UndefinedType"))
+        with self.assertRaises(DefinitionError):
+            self.validate(make_dm({}, self.valid_datatype, int_type), False)
+
+        int_type = deepcopy(self.valid_interface)
+        int_type["InterfaceType"]["Types"].append(DataType("Component"))
+
+        int_type = deepcopy(self.valid_interface)
+        int_type["InterfaceType"]["Types"].append(DataType("float"))
+        with self.assertRaises(DefinitionError):
+            self.validate(make_dm({}, self.valid_datatype, int_type), False)
+
+    def test_interface_no_redefining_datatype(self):
+        """Make sure that there is no datatype already with the same name"""
+        int_type = {
+            "DataType": {
+                "Author": "Copycat",
+                "Description": "I shall not redefine datatypes as interfaces",
+                "Members": [],
+                "Types": [],
            }
        }
-    with self.assertRaises(DefinitionError):
-      self.validate(make_dm({}, self.valid_datatype, int_type), False)
-
-  def test_datatype_uses_interface_type(self):
-    """Make sure that a data type can use a valid interface definition"""
-    datatype = deepcopy(self.valid_datatype)
-    datatype['DataType']['OneToManyRelations'] = [MemberVariable(type='InterfaceType', name='interfaceRelation')]
-    self._assert_no_exception(DefinitionError, '{} should allow to use relations to interface types',
-                              self.validate, make_dm({}, datatype, self.valid_interface), False)
-
-  def test_interface_valid_upstream(self):
-    """Make sure that we can use interface definitions from upstream models"""
-    # Create an upstream datamodel that contains the interface type
-    upstream_dm = make_dm({}, self.valid_datatype, self.valid_interface)
-
-    # Make a downstream model datatype that uses the interface from the upstream
-    # but doesn't bring along its own interface definitions
-    datatype = {'DownstreamType': deepcopy(self.valid_datatype['DataType'])}
-    datatype['DownstreamType']['OneToOneRelations'] = [MemberVariable(type='InterfaceType', name='interfaceRelation')]
-    self._assert_no_exception(DefinitionError, '{} should allow to use interface types from an upstream datamodel',
-                              self.validate, make_dm({}, datatype), upstream_dm)
-
-
-if __name__ == '__main__':
-  unittest.main()
+        with self.assertRaises(DefinitionError):
+            self.validate(make_dm({}, self.valid_datatype, int_type), False)
+
+    def test_datatype_uses_interface_type(self):
+        """Make sure that a data type can use a valid interface definition"""
+        datatype = deepcopy(self.valid_datatype)
+        datatype["DataType"]["OneToManyRelations"] = [
+            MemberVariable(type="InterfaceType", name="interfaceRelation")
+        ]
+        self._assert_no_exception(
+            DefinitionError,
+            "{} should allow to use relations to interface types",
+            self.validate,
+            make_dm({}, datatype, self.valid_interface),
+            False,
+        )
+
+    def test_interface_valid_upstream(self):
+        """Make sure that we can use interface definitions from upstream models"""
+        # Create an upstream datamodel that contains the interface type
+        upstream_dm = make_dm({}, self.valid_datatype, self.valid_interface)
+
+        # Make a downstream model datatype that uses the interface from the upstream
+        # but doesn't bring along its own interface definitions
+        datatype = {"DownstreamType": deepcopy(self.valid_datatype["DataType"])}
+        datatype["DownstreamType"]["OneToOneRelations"] = [
+            MemberVariable(type="InterfaceType", name="interfaceRelation")
+        ]
+        self._assert_no_exception(
+            DefinitionError,
+            "{} should allow to use interface types from an upstream datamodel",
+            self.validate,
+            make_dm({}, datatype),
+            upstream_dm,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/podio_gen/test_DataModelJSONEncoder.py b/python/podio_gen/test_DataModelJSONEncoder.py
index b69055681..f068b0336 100644
--- a/python/podio_gen/test_DataModelJSONEncoder.py
+++ b/python/podio_gen/test_DataModelJSONEncoder.py
@@ -8,64 +8,72 @@


 def get_member_var_json(string):
-  """Get a MemberVariable encoded as JSON from the passed string.
+    """Get a MemberVariable encoded as JSON from the passed string.

-  Passes through the whole chain of parsing and JSON encoding, as it is done
-  during data model encoding.
+    Passes through the whole chain of parsing and JSON encoding, as it is done
+    during data model encoding.

-  Args:
-      string (str): The member variable definition as a string. NOTE: here it is
-        assumed that this is a valid string that can be parsed.
+    Args:
+        string (str): The member variable definition as a string. NOTE: here it is
+          assumed that this is a valid string that can be parsed.

-  Returns:
-      str: The json encoded member variable
-  """
-  parser = MemberParser()
-  member_var = parser.parse(string, False)  # be lenient with missing descriptions
-  return DataModelJSONEncoder().encode(member_var).strip('"')  # strip quotes from JSON
+    Returns:
+        str: The json encoded member variable
+    """
+    parser = MemberParser()
+    member_var = parser.parse(string, False)  # be lenient with missing descriptions
+    return DataModelJSONEncoder().encode(member_var).strip('"')  # strip quotes from JSON


 class DataModelJSONEncoderTest(unittest.TestCase):
-  """Unit tests for the DataModelJSONEncoder and the utility functionality in MemberVariable"""
+    """Unit tests for the DataModelJSONEncoder and the utility functionality in MemberVariable"""

-  def test_encode_only_types(self):
-    """Test that encoding works for type declarations only"""
-    for mdef in (r"float someFloat",
-                 r"ArbitraryType name",
-                 r"std::int16_t fixedWidth",
-                 r"namespace::Type type"):
-      self.assertEqual(get_member_var_json(mdef), mdef)
+    def test_encode_only_types(self):
+        """Test that encoding works for type declarations only"""
+        for mdef in (
+            r"float someFloat",
+            r"ArbitraryType name",
+            r"std::int16_t fixedWidth",
+            r"namespace::Type type",
+        ):
+            self.assertEqual(get_member_var_json(mdef), mdef)

-    # Fixed width types without std are encoded with the std namespace
-    fixed_w = r"int32_t fixedWidth"
-    self.assertEqual(get_member_var_json(fixed_w), f"std::{fixed_w}")
+        # Fixed width types without std are encoded with the std namespace
+        fixed_w = r"int32_t fixedWidth"
+        self.assertEqual(get_member_var_json(fixed_w), f"std::{fixed_w}")

-  def test_encode_array_types(self):
-    """Test that encoding array member variable declarations work"""
-    for mdef in (r"std::array<float, 2> anArray",
-                 r"std::array<std::int16_t, 42> fwArr",
-                 r"std::array<ArbitraryType, 123> typeArr",
-                 r"std::array<namespace::Type, 3> namespacedTypeArr"):
-      self.assertEqual(get_member_var_json(mdef), mdef)
+    def test_encode_array_types(self):
+        """Test that encoding array member variable declarations work"""
+        for mdef in (
+            r"std::array<float, 2> anArray",
+            r"std::array<std::int16_t, 42> fwArr",
+            r"std::array<ArbitraryType, 123> typeArr",
+            r"std::array<namespace::Type, 3> namespacedTypeArr",
+        ):
+            self.assertEqual(get_member_var_json(mdef), mdef)

-  def test_encode_default_vals(self):
-    """Test that encoding definitions with default values works"""
-    for mdef in (r"int i{42}",
-                 r"std::uint32_t uint{64}",
-                 r"ArbType a{123}",
-                 r"namespace::Type t{whatever}",
-                 r"std::array<float, 2> fs{3.14f, 6.28f}",
-                 r"std::array<ArbType, 3> typeArr{1, 2, 3}"):
-      self.assertEqual(get_member_var_json(mdef), mdef)
+    def test_encode_default_vals(self):
+        """Test that encoding definitions with default values works"""
+        for mdef in (
+            r"int i{42}",
+            r"std::uint32_t uint{64}",
+            r"ArbType a{123}",
+            r"namespace::Type t{whatever}",
+            r"std::array<float, 2> fs{3.14f, 6.28f}",
+            r"std::array<ArbType, 3> typeArr{1, 2, 3}",
+        ):
+            self.assertEqual(get_member_var_json(mdef), mdef)

-  def test_encode_with_description(self):
-    """Test that encoding definitions that contain a description works"""
-    for mdef in (r"int i // an uninitialized int",
-                 r"std::uint32_t ui{42} // an initialized unsigned int",
-                 r"std::array<float, 2> fs // a float array",
-                 r"std::array<nsp::Type, 3> tA{1, 2, 3} // an initialized array of namespaced types",
-                 r"AType type // a very special type",
-                 r"nsp::Type nspT // a namespaced type",
-                 r"nsp::Type nspT{with init} // an initialized namespaced type",
-                 r"ArbitraryType arbT{42} // an initialized type"):
-      self.assertEqual(get_member_var_json(mdef), mdef)
+    def test_encode_with_description(self):
+        """Test that encoding definitions that contain a description works"""
+        for mdef in (
+            r"int i // an uninitialized int",
+            r"std::uint32_t ui{42} // an initialized unsigned int",
+            r"std::array<float, 2> fs // a float array",
+            r"std::array<nsp::Type, 3> tA{1, 2, 3} // an initialized array of namespaced types",
+            r"AType type // a very special type",
+            r"nsp::Type nspT // a namespaced type",
+            r"nsp::Type nspT{with init} // an initialized namespaced type",
+            r"ArbitraryType arbT{42} // an initialized type",
+        ):
+            self.assertEqual(get_member_var_json(mdef), mdef)
diff --git a/python/podio_gen/test_MemberParser.py b/python/podio_gen/test_MemberParser.py
index f0d352dc1..3dc79ea26 100644
--- a/python/podio_gen/test_MemberParser.py
+++ b/python/podio_gen/test_MemberParser.py
@@ -10,292 +10,309 @@


 class MemberParserTest(unittest.TestCase):
-  """Unit tests for the MemberParser"""
-
-  def test_parse_valid(self):  # pylint: disable=too-many-statements
-    """Test if valid member definitions pass"""
-    parser = MemberParser()
-
-    parsed = parser.parse(r'float someFloat // with an additional comment')
-    self.assertEqual(parsed.full_type, r'float')
-    self.assertEqual(parsed.name, r'someFloat')
-    self.assertEqual(parsed.description, r'with an additional comment')
-    self.assertTrue(parsed.default_val is None)
-    self.assertEqual(parsed.julia_type, r'Float32')
-
-    parsed = parser.parse(r'float float2 // with numbers')
-    self.assertEqual(parsed.full_type, r'float')
-    self.assertEqual(parsed.name, r'float2')
-    self.assertEqual(parsed.description, r'with numbers')
-    self.assertEqual(parsed.julia_type, r'Float32')
-
-    parsed = parser.parse(r' float spacefloat // whitespace everywhere ')
-    self.assertEqual(parsed.full_type, r'float')
-    self.assertEqual(parsed.name, r'spacefloat')
-    self.assertEqual(parsed.description, 'whitespace everywhere')
-    self.assertEqual(parsed.julia_type, r'Float32')
-
-    parsed = parser.parse(r'int snake_case // snake case')
-    self.assertEqual(parsed.full_type, r'int')
-    self.assertEqual(parsed.name, r'snake_case')
-    self.assertEqual(parsed.description, r'snake case')
-    self.assertEqual(parsed.julia_type, r'Int32')
-
-    parsed = parser.parse(r'std::string mixed_UglyCase_12 // who wants this')
-    self.assertEqual(parsed.full_type, r'std::string')
-    self.assertEqual(parsed.name, r'mixed_UglyCase_12')
-    self.assertEqual(parsed.description, r'who wants this')
-
-    # Check some of the trickier builtin types
-    parsed = parser.parse(r'unsigned long long uVar // an unsigned long variable')
-    self.assertEqual(parsed.full_type, r'unsigned long long')
-    self.assertEqual(parsed.name, r'uVar')
-    self.assertEqual(parsed.description, r'an unsigned long variable')
-    self.assertEqual(parsed.julia_type, r'UInt64')
-
-    parsed = parser.parse(r'unsigned int uInt // an unsigned integer')
-    self.assertEqual(parsed.full_type, r'unsigned int')
-    self.assertEqual(parsed.name, r'uInt')
-    self.assertEqual(parsed.description, r'an unsigned integer')
-    self.assertEqual(parsed.julia_type, r'UInt32')
-
-    # Fixed width integers in the various forms in which they can be spelled out
-    # and still be considered valid in our case
-    parsed = parser.parse(r'std::int16_t qualified // qualified fixed width ints work')
-    self.assertEqual(parsed.full_type, r'std::int16_t')
-    self.assertEqual(parsed.name, r'qualified')
-    self.assertEqual(parsed.description, r'qualified fixed width ints work')
-    self.assertTrue(parsed.is_builtin)
-    self.assertEqual(parsed.julia_type, r'Int16')
-
-    parsed = parser.parse(r'std::uint64_t bits // fixed width integer types should work')
-    self.assertEqual(parsed.full_type, r'std::uint64_t')
-    self.assertEqual(parsed.name, r'bits')
-    self.assertEqual(parsed.description, r'fixed width integer types should work')
-    self.assertTrue(parsed.is_builtin)
-    self.assertEqual(parsed.julia_type, r'UInt64')
-
-    parsed = parser.parse(r'int32_t fixedInt // fixed width signed integer should work')
-    self.assertEqual(parsed.full_type, r'std::int32_t')
-    self.assertEqual(parsed.name, r'fixedInt')
-    self.assertEqual(parsed.description, r'fixed width signed integer should work')
-    self.assertTrue(parsed.is_builtin)
-    self.assertEqual(parsed.julia_type, r'Int32')
-
-    parsed = parser.parse(r'uint16_t fixedUInt // fixed width unsigned int with 16 bits')
-    self.assertEqual(parsed.full_type, r'std::uint16_t')
-    self.assertEqual(parsed.name, r'fixedUInt')
-    self.assertEqual(parsed.description, r'fixed width unsigned int with 16 bits')
-    self.assertTrue(parsed.is_builtin)
-    self.assertEqual(parsed.julia_type, r'UInt16')
-
-    # an array definition with space everywhere it is allowed
-    parsed = parser.parse(r' std::array < double , 4 > someArray // a comment ')
-    self.assertEqual(parsed.full_type, r'std::array<double, 4>')
-    self.assertEqual(parsed.name, r'someArray')
-    self.assertEqual(parsed.description, r'a comment')
-    self.assertTrue(not parsed.is_builtin)
-    self.assertTrue(parsed.is_builtin_array)
-    self.assertEqual(int(parsed.array_size), 4)
-    self.assertEqual(parsed.array_type, r'double')
-    self.assertTrue(parsed.default_val is None)
-    self.assertEqual(parsed.julia_type, r'MVector{4, Float64}')
-
-    # an array definition as terse as possible
-    parsed = parser.parse(r'std::array<int,2>anArray//with a comment')
-    self.assertEqual(parsed.full_type, r'std::array<int, 2>')
-    self.assertEqual(parsed.name, r'anArray')
-    self.assertEqual(parsed.description, r'with a comment')
-    self.assertEqual(parsed.julia_type, r'MVector{2, Int32}')
-
-    parsed = parser.parse('::TopLevelNamespaceType aValidType // hopefully')
-    self.assertEqual(parsed.full_type, '::TopLevelNamespaceType')
-    self.assertEqual(parsed.name, r'aValidType')
-    self.assertEqual(parsed.description, 'hopefully')
-    self.assertEqual(parsed.julia_type, r'TopLevelNamespaceType')
-
-    parsed = parser.parse(r'std::array<::GlobalType, 1> anArray // with a top level type')
-    self.assertEqual(parsed.full_type, r'std::array<::GlobalType, 1>')
-    self.assertEqual(parsed.name, r'anArray')
-    self.assertEqual(parsed.description, r'with a top level type')
-    self.assertTrue(not parsed.is_builtin_array)
-    self.assertEqual(parsed.array_type, r'::GlobalType')
-    self.assertEqual(parsed.julia_type, r'MVector{1, GlobalTypeStruct}')
-
-    parsed = parser.parse(r'std::array<std::int16_t, 42> fixedWidthArray // a fixed width type array')
-    self.assertEqual(parsed.full_type, r'std::array<std::int16_t, 42>')
-    self.assertEqual(parsed.name, r'fixedWidthArray')
-    self.assertEqual(parsed.description, r'a fixed width type array')
-    self.assertTrue(parsed.is_builtin_array)
-    self.assertEqual(parsed.array_type, r'std::int16_t')
-    self.assertEqual(parsed.julia_type, r'MVector{42, Int16}')
-
-    parsed = parser.parse(r'std::array<uint32_t, 42> fixedWidthArray // a fixed width type array without namespace')
-    self.assertEqual(parsed.full_type, r'std::array<std::uint32_t, 42>')
-    self.assertEqual(parsed.name, r'fixedWidthArray')
-    self.assertEqual(parsed.description, r'a fixed width type array without namespace')
-    self.assertTrue(parsed.is_builtin_array)
-    self.assertEqual(parsed.array_type, r'std::uint32_t')
-    self.assertEqual(parsed.julia_type, r'MVector{42, UInt32}')
-
-  def test_parse_valid_default_value(self):
-    """Test that member variables can be parsed correctly if they have a user
-    defined default value"""
-    parser = MemberParser()
-
-    parsed = parser.parse(r'int fortyTwo{43} // default values can lie')
-    self.assertEqual(parsed.full_type, r'int')
-    self.assertEqual(parsed.name, r'fortyTwo')
-    self.assertEqual(parsed.description, 'default values can lie')
-    self.assertEqual(parsed.default_val, r'43')
-    self.assertEqual(str(parsed), 'int fortyTwo{43}; ///< default values can lie')
-
-    parsed = parser.parse(r'float f{3.14f}', require_description=False)
-    self.assertEqual(parsed.full_type, 'float')
-    self.assertEqual(parsed.name, 'f')
-    self.assertEqual(parsed.default_val, '3.14f')
-
-    parsed = parser.parse(r'std::array<int, 3> array{1, 2, 3} // arrays can be initialized')
-    self.assertEqual(parsed.full_type, r'std::array<int, 3>')
-    self.assertEqual(parsed.default_val, '1, 2, 3')
-    self.assertEqual(parsed.name, 'array')
-
-    parsed = parser.parse(r'std::array<int, 10> array{1, 2, 3} // we do not have to init the complete array')
-    self.assertEqual(parsed.full_type, r'std::array<int, 10>')
-    self.assertEqual(parsed.default_val, r'1, 2, 3')
-
-    # These are cases where we cannot really decide whether the initialization
-    # is valid just from the member declaration. We let them pass here
-    parsed = parser.parse('nsp::SomeValue val {42} // default values can have space')
-    self.assertEqual(parsed.full_type, 'nsp::SomeValue')
-    self.assertEqual(parsed.name, 'val')
-    self.assertEqual(parsed.default_val, '42')
-    self.assertEqual(parsed.namespace, 'nsp')
-    self.assertEqual(parsed.bare_type, 'SomeValue')
-
-    parsed = parser.parse(r'edm4hep::Vector3d v{1, 2, 3, 4} // for aggregates we do not validate init values')
-    self.assertEqual(parsed.full_type, 'edm4hep::Vector3d')
-    self.assertEqual(parsed.default_val, '1, 2, 3, 4')
-
-    # There are cases where we could technically validate this via a syntax
-    # check by the compiler but we don't do that because it is too costly and
-    # this is considered an expert feature. The generated code will not compile
-    parsed = parser.parse(r'AggType bogusInit{here, we can even put invalid c++}', False)
-    self.assertEqual(parsed.default_val, 'here, we can even put invalid c++')
-
-    parsed = parser.parse(r'std::array<int, 2> array{1, 2, 3} // too many values provided')
-    self.assertEqual(parsed.default_val, '1, 2, 3')
-
-    # Invalid user default value initialization
-    parsed = parser.parse(r'int weirdDefault{whatever, even space} // invalid c++ is not caught')
-    self.assertEqual(parsed.default_val, 'whatever, even space')
-
-    parsed = parser.parse(r'int floatInit{3.14f} // implicit conversions are not allowed in aggregate initialization')
-    self.assertEqual(parsed.default_val, '3.14f')
-
-  def test_parse_invalid(self):
-    """Test that invalid member variable definitions indeed fail during parsing"""
-    # setup an empty parser
-    parser = MemberParser()
-
-    invalid_inputs = [
-        r'int // a type without name',
-        r'int anIntWithoutDescription',
-        r'__someType name // an illformed type',
-        r'double 1WrongNamedDouble // an invalid name',
-        r'std::array<int, 2>',  # array without name and description
-        r'std::array<int, 2> // an array without a name',
-        r'std::array<int, 2> anArrayWithoutDescription',
-        r'std::array<__foo, 3> anArray // with invalid type',
-        r'std::array<int, 1.2> array // with invalid size',
-        r'int another ill formed name // some comment',
-        r'float illFormedDefault {',
-
-        # Some examples of valid c++ that are rejected by the validation
-        r'unsigned long int uLongInt // technically valid c++, but not in our builtin list',
-        r'::std::array<float, 2> a // technically valid c++, but breaks class generation',
-        r':: std :: array<int, 2> arr // also technically valid c++ but not in our case',
-        r'int8_t disallowed // fixed width ints with 8 bits are often aliased to signed char',
-        r'uint8_t disallowed // fixed width unsigned ints with 8 bits are often aliased to unsigned char',
-        r'int_least32_t disallowed // only allow fixed width integers with exact widths',
-        r'uint_fast64_t disallowed // only allow fixed width integers with exact widths',
-        r'std::int_least16_t disallowed // also adding a std namespace here does not make these allowed',
-        r'std::uint_fast16_t disallowed // also adding a std namespace here does not make these allowed',
-        r'std::array<int8_t, 3> disallowedArray // arrays should not accept disallowed fixed width types',
-
-        # Default values cannot be empty
-        r'int emptyDefault{} // valid c++, but we want an explicit default value here',
-
+    """Unit tests for the MemberParser"""
+
+    def test_parse_valid(self):  # pylint: disable=too-many-statements
+        """Test if valid member definitions pass"""
+        parser = MemberParser()
+
+        parsed = parser.parse(r"float someFloat // with an additional comment")
+        self.assertEqual(parsed.full_type, r"float")
+        self.assertEqual(parsed.name, r"someFloat")
+        self.assertEqual(parsed.description, r"with an additional comment")
+        self.assertTrue(parsed.default_val is None)
+        self.assertEqual(parsed.julia_type, r"Float32")
+
+        parsed = parser.parse(r"float float2 // with numbers")
+        self.assertEqual(parsed.full_type, r"float")
+        self.assertEqual(parsed.name, r"float2")
+        self.assertEqual(parsed.description, r"with numbers")
+        self.assertEqual(parsed.julia_type, r"Float32")
+
+        parsed = parser.parse(r" float spacefloat // whitespace everywhere ")
+        self.assertEqual(parsed.full_type, r"float")
+        self.assertEqual(parsed.name, r"spacefloat")
+        self.assertEqual(parsed.description, "whitespace everywhere")
+        self.assertEqual(parsed.julia_type, r"Float32")
+
+        parsed = parser.parse(r"int snake_case // snake case")
+        self.assertEqual(parsed.full_type, r"int")
+        self.assertEqual(parsed.name, r"snake_case")
+        self.assertEqual(parsed.description, r"snake case")
+        self.assertEqual(parsed.julia_type, r"Int32")
+
+        parsed = parser.parse(r"std::string mixed_UglyCase_12 // who wants this")
+        self.assertEqual(parsed.full_type, r"std::string")
+        self.assertEqual(parsed.name, r"mixed_UglyCase_12")
+        self.assertEqual(parsed.description, r"who wants this")
+
+        # Check some of the trickier builtin types
+        parsed = parser.parse(r"unsigned long long uVar // an unsigned long variable")
+        self.assertEqual(parsed.full_type, r"unsigned long long")
+        self.assertEqual(parsed.name, r"uVar")
+        self.assertEqual(parsed.description, r"an unsigned long variable")
+        self.assertEqual(parsed.julia_type, r"UInt64")
+
+        parsed = parser.parse(r"unsigned int uInt // an unsigned integer")
+        self.assertEqual(parsed.full_type, r"unsigned int")
+        self.assertEqual(parsed.name, r"uInt")
+        self.assertEqual(parsed.description, r"an unsigned integer")
+        self.assertEqual(parsed.julia_type, r"UInt32")
+
+        # Fixed width integers in the various forms in which they can be spelled out
+        # and still be considered valid in our case
+        parsed = parser.parse(r"std::int16_t qualified // qualified fixed width ints work")
+        self.assertEqual(parsed.full_type, r"std::int16_t")
+        self.assertEqual(parsed.name, r"qualified")
+        self.assertEqual(parsed.description, r"qualified fixed width ints work")
+        self.assertTrue(parsed.is_builtin)
+        self.assertEqual(parsed.julia_type, r"Int16")
+
+        parsed = parser.parse(r"std::uint64_t bits // fixed width integer types should work")
+        self.assertEqual(parsed.full_type, r"std::uint64_t")
+        self.assertEqual(parsed.name, r"bits")
+        self.assertEqual(parsed.description, r"fixed width integer types should work")
+        self.assertTrue(parsed.is_builtin)
+        self.assertEqual(parsed.julia_type, r"UInt64")
+
+        parsed = parser.parse(r"int32_t fixedInt // fixed width signed integer should work")
+        self.assertEqual(parsed.full_type, r"std::int32_t")
+        self.assertEqual(parsed.name, r"fixedInt")
+        self.assertEqual(parsed.description, r"fixed width signed integer should work")
+        self.assertTrue(parsed.is_builtin)
+        self.assertEqual(parsed.julia_type, r"Int32")
+
+        parsed = parser.parse(r"uint16_t fixedUInt // fixed width unsigned int with 16 bits")
+        self.assertEqual(parsed.full_type, r"std::uint16_t")
+        self.assertEqual(parsed.name, r"fixedUInt")
+        self.assertEqual(parsed.description, r"fixed width unsigned int with 16 bits")
+        self.assertTrue(parsed.is_builtin)
+        self.assertEqual(parsed.julia_type, r"UInt16")
+
+        # an array definition with space everywhere it is allowed
+        parsed = parser.parse(r" std::array < double , 4 > someArray // a comment ")
+        self.assertEqual(parsed.full_type, r"std::array<double, 4>")
+        self.assertEqual(parsed.name, r"someArray")
+        self.assertEqual(parsed.description, r"a comment")
+        self.assertTrue(not parsed.is_builtin)
+        self.assertTrue(parsed.is_builtin_array)
+        self.assertEqual(int(parsed.array_size), 4)
+        self.assertEqual(parsed.array_type, r"double")
+        self.assertTrue(parsed.default_val is None)
+        self.assertEqual(parsed.julia_type, r"MVector{4, Float64}")
+
+        # an array definition as terse as possible
+        parsed = parser.parse(r"std::array<int,2>anArray//with a comment")
+        self.assertEqual(parsed.full_type, r"std::array<int, 2>")
+        self.assertEqual(parsed.name, r"anArray")
+        self.assertEqual(parsed.description, r"with a comment")
+        self.assertEqual(parsed.julia_type, r"MVector{2, Int32}")
+
+        parsed = parser.parse("::TopLevelNamespaceType aValidType // hopefully")
+        self.assertEqual(parsed.full_type, "::TopLevelNamespaceType")
+        self.assertEqual(parsed.name, r"aValidType")
+        self.assertEqual(parsed.description, "hopefully")
+        self.assertEqual(parsed.julia_type, r"TopLevelNamespaceType")
+
+        parsed = parser.parse(r"std::array<::GlobalType, 1> anArray // with a top level type")
+        self.assertEqual(parsed.full_type, r"std::array<::GlobalType, 1>")
+        self.assertEqual(parsed.name, r"anArray")
+        self.assertEqual(parsed.description, r"with a top level type")
+        self.assertTrue(not parsed.is_builtin_array)
+        self.assertEqual(parsed.array_type, r"::GlobalType")
+        self.assertEqual(parsed.julia_type, r"MVector{1, GlobalTypeStruct}")
+
+        parsed = parser.parse(
+            r"std::array<std::int16_t, 42> fixedWidthArray // a fixed width type array"
+        )
+        self.assertEqual(parsed.full_type, r"std::array<std::int16_t, 42>")
+        self.assertEqual(parsed.name, r"fixedWidthArray")
+        self.assertEqual(parsed.description, r"a fixed width type array")
+        self.assertTrue(parsed.is_builtin_array)
+        self.assertEqual(parsed.array_type, r"std::int16_t")
+        self.assertEqual(parsed.julia_type, r"MVector{42, Int16}")
+
+        parsed = parser.parse(
+            r"std::array<uint32_t, 42> fixedWidthArray // a fixed width type array without namespace"
+        )
+        self.assertEqual(parsed.full_type, r"std::array<std::uint32_t, 42>")
+        self.assertEqual(parsed.name, r"fixedWidthArray")
+        self.assertEqual(parsed.description, r"a fixed width type array without namespace")
+        self.assertTrue(parsed.is_builtin_array)
+        self.assertEqual(parsed.array_type, r"std::uint32_t")
+        self.assertEqual(parsed.julia_type, r"MVector{42, UInt32}")
+
+    def test_parse_valid_default_value(self):
+        """Test that member variables can be parsed correctly if they have a user
+        defined default value"""
+        parser = MemberParser()
+
+        parsed = parser.parse(r"int fortyTwo{43} // default values can lie")
+        self.assertEqual(parsed.full_type, r"int")
+        self.assertEqual(parsed.name, r"fortyTwo")
+        self.assertEqual(parsed.description, "default values can lie")
+        self.assertEqual(parsed.default_val, r"43")
+        self.assertEqual(str(parsed), "int fortyTwo{43}; ///< default values can lie")
+
+        parsed = parser.parse(r"float f{3.14f}", require_description=False)
+        self.assertEqual(parsed.full_type, "float")
+        self.assertEqual(parsed.name, "f")
+        self.assertEqual(parsed.default_val, "3.14f")
+
+        parsed = parser.parse(r"std::array<int, 3> array{1, 2, 3} // arrays can be initialized")
+        self.assertEqual(parsed.full_type, r"std::array<int, 3>")
+        self.assertEqual(parsed.default_val, "1, 2, 3")
+        self.assertEqual(parsed.name, "array")
+
+        parsed = parser.parse(
+            r"std::array<int, 10> array{1, 2, 3} // we do not have to init the complete array"
+        )
+        self.assertEqual(parsed.full_type, r"std::array<int, 10>")
+        self.assertEqual(parsed.default_val, r"1, 2, 3")
+
+        # These are cases where we cannot really decide whether the initialization
+        # is valid just from the member declaration. We let them pass here
+        parsed = parser.parse("nsp::SomeValue val {42} // default values can have space")
+        self.assertEqual(parsed.full_type, "nsp::SomeValue")
+        self.assertEqual(parsed.name, "val")
+        self.assertEqual(parsed.default_val, "42")
+        self.assertEqual(parsed.namespace, "nsp")
+        self.assertEqual(parsed.bare_type, "SomeValue")
+
+        parsed = parser.parse(
+            r"edm4hep::Vector3d v{1, 2, 3, 4} // for aggregates we do not validate init values"
+        )
+        self.assertEqual(parsed.full_type, "edm4hep::Vector3d")
+        self.assertEqual(parsed.default_val, "1, 2, 3, 4")
+
+        # There are cases where we could technically validate this via a syntax
+        # check by the compiler but we don't do that because it is too costly and
+        # this is considered an expert feature. The generated code will not compile
+        parsed = parser.parse(r"AggType bogusInit{here, we can even put invalid c++}", False)
+        self.assertEqual(parsed.default_val, "here, we can even put invalid c++")
+
+        parsed = parser.parse(r"std::array<int, 2> array{1, 2, 3} // too many values provided")
+        self.assertEqual(parsed.default_val, "1, 2, 3")
+
+        # Invalid user default value initialization
+        parsed = parser.parse(
+            r"int weirdDefault{whatever, even space} // invalid c++ is not caught"
+        )
+        self.assertEqual(parsed.default_val, "whatever, even space")
+
+        parsed = parser.parse(
+            r"int floatInit{3.14f} // implicit conversions are not allowed in aggregate initialization"
+        )
+        self.assertEqual(parsed.default_val, "3.14f")
+
+    def test_parse_invalid(self):
+        """Test that invalid member variable definitions indeed fail during parsing"""
+        # setup an empty parser
+        parser = MemberParser()
+
+        invalid_inputs = [
+            r"int // a type without name",
+            r"int anIntWithoutDescription",
+            r"__someType name // an illformed type",
+            r"double 1WrongNamedDouble // an invalid name",
+            r"std::array<int, 2>",  # array without name and description
+            r"std::array<int, 2> // an array without a name",
+            r"std::array<int, 2> anArrayWithoutDescription",
+            r"std::array<__foo, 3> anArray // with invalid type",
+            r"std::array<int, 1.2> array // with invalid size",
+            r"int another ill formed name // some comment",
+            r"float illFormedDefault {",
+            # Some examples of valid c++ that are rejected by the validation
+            r"unsigned long int uLongInt // technically valid c++, but not in our builtin list",
+            r"::std::array<float, 2> a // technically valid c++, but breaks class generation",
+            r":: std :: array<int, 2> arr // also technically valid c++ but not in our case",
+            r"int8_t disallowed // fixed width ints with 8 bits are often aliased to signed char",
+            r"uint8_t disallowed // fixed width unsigned ints with 8 bits are often aliased to unsigned char",
+            r"int_least32_t disallowed // only allow fixed width integers with exact widths",
+            r"uint_fast64_t disallowed // only allow fixed width integers with exact widths",
+            r"std::int_least16_t disallowed // also adding a std namespace here does not make these allowed",
+            r"std::uint_fast16_t disallowed // also adding a std namespace here does not make these allowed",
+            r"std::array<int8_t, 3> disallowedArray // arrays should not accept disallowed fixed width types",
+            # Default values cannot be empty
+            r"int emptyDefault{} // valid c++, but we want an explicit default value here",
        ]
-    for inp in invalid_inputs:
-      try:
-        self.assertRaises(DefinitionError, parser.parse, inp)
-      except AssertionError:
-        # pylint: disable-next=raise-missing-from
-        raise AssertionError(f"'{inp}' should raise a DefinitionError from the MemberParser")
-
-  def test_parse_valid_no_description(self):
-    """Test that member variable definitions are OK without description"""
-    parser = MemberParser()
-
-    parsed = parser.parse('unsigned long long aLongWithoutDescription', False)
-    self.assertEqual(parsed.full_type, 'unsigned long long')
-    self.assertEqual(parsed.name, 'aLongWithoutDescription')
-    self.assertEqual(parsed.julia_type, r'UInt64')
-
-    parsed = parser.parse('std::array<unsigned long, 123> unDescribedArray', False)
-    self.assertEqual(parsed.full_type, 'std::array<unsigned long, 123>')
-    self.assertEqual(parsed.name, 'unDescribedArray')
-    self.assertEqual(parsed.array_type, 'unsigned long')
-    self.assertTrue(parsed.is_builtin_array)
-    self.assertEqual(parsed.julia_type, r'MVector{123, UInt64}')
-
-    parsed = parser.parse('std::array<int, 4> p [mm]', False)
-    self.assertEqual(parsed.full_type, 'std::array<int, 4>')
-    self.assertEqual(parsed.name, 'p')
-    self.assertEqual(parsed.array_type, 'int')
-    self.assertTrue(parsed.is_builtin_array)
-    self.assertEqual(parsed.julia_type, r'MVector{4, Int32}')
-
-    parsed = parser.parse('unsigned long longWithReallyStupidName', False)
-    self.assertEqual(parsed.full_type, 'unsigned long')
-    self.assertEqual(parsed.name, 'longWithReallyStupidName')
-    self.assertEqual(parsed.julia_type, r'UInt64')
-
-    parsed = parser.parse('NonBuiltIn aType // descriptions are not ignored even though they are not required', False)
-    self.assertEqual(parsed.full_type, 'NonBuiltIn')
-    self.assertEqual(parsed.name, 'aType')
-    self.assertEqual(parsed.description, 'descriptions are not ignored even though they are not required')
-    self.assertTrue(not parsed.is_builtin)
-    self.assertEqual(parsed.julia_type, r'NonBuiltIn')
-
-  def test_parse_unit(self):
-    """Test that units are properly parsed"""
-    parser = MemberParser()
-
-    parsed = parser.parse('unsigned long long var [GeV] // description')
-    self.assertEqual(parsed.unit, 'GeV')
-
-    parsed = parser.parse('unsigned long long var{42} [GeV] // description')
-    self.assertEqual(parsed.unit, 'GeV')
-    self.assertEqual(parsed.default_val, '42')
-
-  def test_string_representation(self):
-    """Test that the string representation that is used in the jinja2 templates
-    includes the default initialization"""
-    parser = MemberParser()
-
-    parsed = parser.parse('unsigned long long var // description')
-    self.assertEqual(str(parsed), r'unsigned long long var{}; ///< description')
-
-    # Also works without a description
-    parsed = parser.parse('SomeType memberVar', False)
-    self.assertEqual(str(parsed), r'SomeType memberVar{};')
-
-    parsed = parser.parse('Type var//with very close comment')
-    self.assertEqual(str(parsed), r'Type var{}; ///< with very close comment')
-
-
-if __name__ == '__main__':
-  unittest.main()
+        for inp in invalid_inputs:
+            try:
+                self.assertRaises(DefinitionError, parser.parse, inp)
+            except AssertionError:
+                # pylint: disable-next=raise-missing-from
+                raise AssertionError(
+                    f"'{inp}' should raise a DefinitionError from the MemberParser"
+                )
+
+    def test_parse_valid_no_description(self):
+        """Test that member variable definitions are OK without description"""
+        parser = MemberParser()
+
+        parsed = parser.parse("unsigned long long aLongWithoutDescription", False)
+        self.assertEqual(parsed.full_type, "unsigned long long")
+        self.assertEqual(parsed.name, "aLongWithoutDescription")
+        self.assertEqual(parsed.julia_type, r"UInt64")
+
+        parsed = parser.parse("std::array<unsigned long, 123> unDescribedArray", False)
+        self.assertEqual(parsed.full_type, "std::array<unsigned long, 123>")
+        self.assertEqual(parsed.name, "unDescribedArray")
+        self.assertEqual(parsed.array_type, "unsigned long")
+        self.assertTrue(parsed.is_builtin_array)
+        self.assertEqual(parsed.julia_type, r"MVector{123, UInt64}")
+
+        parsed = parser.parse("std::array<int, 4> p [mm]", False)
+        self.assertEqual(parsed.full_type, "std::array<int, 4>")
+        self.assertEqual(parsed.name, "p")
+        self.assertEqual(parsed.array_type, "int")
+        self.assertTrue(parsed.is_builtin_array)
+        self.assertEqual(parsed.julia_type, r"MVector{4, Int32}")
+
+        parsed = parser.parse("unsigned long longWithReallyStupidName", False)
+        self.assertEqual(parsed.full_type, "unsigned long")
+        self.assertEqual(parsed.name, "longWithReallyStupidName")
+        self.assertEqual(parsed.julia_type, r"UInt64")
+
+        parsed = parser.parse(
+            "NonBuiltIn aType // descriptions are not ignored even though they are not required",
+            False,
+        )
+        self.assertEqual(parsed.full_type, "NonBuiltIn")
+        self.assertEqual(parsed.name, "aType")
+        self.assertEqual(
+            parsed.description,
+            "descriptions are not ignored even though they are not required",
+        )
+        self.assertTrue(not parsed.is_builtin)
+        self.assertEqual(parsed.julia_type, r"NonBuiltIn")
+
+    def test_parse_unit(self):
+        """Test that units are properly parsed"""
+        parser = MemberParser()
+
+        parsed = parser.parse("unsigned long long var [GeV] // description")
+        self.assertEqual(parsed.unit, "GeV")
+
+        parsed = parser.parse("unsigned long long var{42} [GeV] // description")
+        self.assertEqual(parsed.unit, "GeV")
+        self.assertEqual(parsed.default_val, "42")
+
+    def test_string_representation(self):
+        """Test that the string representation that is used in the jinja2 templates
+        includes the default initialization"""
+        parser = MemberParser()
+
+        parsed = parser.parse("unsigned long long var // description")
+        self.assertEqual(str(parsed), r"unsigned long long var{}; ///< description")
+
+        # Also works without a description
+        parsed = parser.parse("SomeType memberVar", False)
+        self.assertEqual(str(parsed), r"SomeType memberVar{};")
+
+        parsed = parser.parse("Type var//with very close comment")
+        self.assertEqual(str(parsed), r"Type var{}; ///< with very close comment")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/podio_schema_evolution.py b/python/podio_schema_evolution.py
index b2c3b082b..b3b498fa2 100755
--- a/python/podio_schema_evolution.py
+++ b/python/podio_schema_evolution.py
@@ -13,369 +13,460 @@


 class SchemaChange:
-  """The base class for all schema changes providing a brief description as representation"""
-  def __init__(self, description):
-    self.description = description
+    """The base class for all schema changes providing a brief description as representation"""

-  def __str__(self) -> str:
-    return self.description
+    def __init__(self, description):
+        self.description = description

-  def __repr__(self) -> str:
-    return self.description
+    def __str__(self) -> str:
+        return self.description
+
+    def __repr__(self) -> str:
+        return self.description


 class AddedComponent(SchemaChange):
-  """Class representing an added component"""
-  def __init__(self, component, name):
-    self.component = component
-    self.name = name
-    super().__init__(f"'{self.component.name}' has been added")
+    """Class representing an added component"""
+
+    def __init__(self, component, name):
+        self.component = component
+        self.name = name
+        super().__init__(f"'{self.component.name}' has been added")


 class DroppedComponent(SchemaChange):
-  """Class representing a dropped component"""
-  def __init__(self, component, name):
-    self.component = component
-    self.name = name
-    self.klassname = name
-    super().__init__(f"'{self.name}' has been dropped")
+    """Class representing a dropped component"""
+
+    def __init__(self, component, name):
+        self.component = component
+        self.name = name
+        self.klassname = name
+        super().__init__(f"'{self.name}' has been dropped")


 class AddedDatatype(SchemaChange):
-  """Class representing an added datatype"""
-  def __init__(self, datatype, name):
-    self.datatype = datatype
-    self.name = name
-    self.klassname = name
-    super().__init__(f"'{self.name}' has been added")
+    """Class representing an added datatype"""
+
+    def __init__(self, datatype, name):
+        self.datatype = datatype
+        self.name = name
+        self.klassname = name
+        super().__init__(f"'{self.name}' has been added")


 class DroppedDatatype(SchemaChange):
-  """Class representing a dropped datatype"""
-  def __init__(self, datatype, name):
-    self.datatype = datatype
-    self.name = name
-    self.klassname = name
-    super().__init__(f"'{self.name}' has been dropped")
+    """Class representing a dropped datatype"""
+
+    def __init__(self, datatype, name):
+        self.datatype = datatype
+        self.name = name
+        self.klassname = name
+        super().__init__(f"'{self.name}' has been dropped")


 class RenamedDataType(SchemaChange):
-  """Class representing a renamed datatype"""
-  def __init__(self, name_old, name_new):
-    self.name_old = name_old
-    self.name_new = name_new
-    super().__init__(f"'{self.name_new}': datatype '{self.name_old}' renamed to '{self.name_new}'.")
+    """Class representing a renamed datatype"""
+
+    def __init__(self, name_old, name_new):
+        self.name_old = name_old
+        self.name_new = name_new
+        super().__init__(
+            f"'{self.name_new}': datatype '{self.name_old}' renamed to '{self.name_new}'."
+        )


 class AddedMember(SchemaChange):
-  """Class representing an added member"""
-  def __init__(self, member, definition_name):
-    self.member = member
-    self.definition_name = definition_name
-    self.klassname = definition_name
-    super().__init__(f"'{self.definition_name}' has an added member '{self.member.name}'")
+    """Class representing an added member"""
+
+    def __init__(self, member, definition_name):
+        self.member = member
+        self.definition_name = definition_name
+        self.klassname = definition_name
+        super().__init__(f"'{self.definition_name}' has an added member '{self.member.name}'")


 class DroppedMember(SchemaChange):
-  """Class representing a dropped member"""
-  def __init__(self, member, definition_name):
-    self.member = member
-    self.definition_name = definition_name
-    self.klassname = definition_name
-    super().__init__(f"'{self.definition_name}' has a dropped member '{self.member.name}'")
+    """Class representing a dropped member"""
+
+    def __init__(self, member, definition_name):
+        self.member = member
+        self.definition_name = definition_name
+        self.klassname = definition_name
+        super().__init__(f"'{self.definition_name}' has a dropped member '{self.member.name}'")


 class ChangedMember(SchemaChange):
-  """Class representing a type change in a member"""
-  def __init__(self, name, member_name, old_member, new_member):
-    self.name = name
-    self.member_name = member_name
-    self.old_member = old_member
-    self.new_member = new_member
-    self.klassname = name
-    super().__init__(f"'{self.name}.{self.member_name}' changed type from "
-                     f"'{self.old_member.full_type}' to '{self.new_member.full_type}'")
+    """Class representing a type change in a member"""
+
+    def __init__(self, name, member_name, old_member, new_member):
+        self.name = name
+        self.member_name = member_name
+        self.old_member = old_member
+        self.new_member = new_member
+        self.klassname = name
+        super().__init__(
+            f"'{self.name}.{self.member_name}' changed type from "
+            f"'{self.old_member.full_type}' to '{self.new_member.full_type}'"
+        )


 class RenamedMember(SchemaChange):
-  """Class representing a renamed member"""
-  def __init__(self, name, member_name_old, member_name_new):
-    self.name = name
-    self.member_name_old = member_name_old
-    self.member_name_new = member_name_new
-    self.klassname = name
-    super().__init__(f"'{self.name}': member '{self.member_name_old}' renamed to '{self.member_name_new}'.")
+    """Class representing a renamed member"""
+
+    def __init__(self, name, member_name_old, member_name_new):
+        self.name = name
+        self.member_name_old = member_name_old
+        self.member_name_new = member_name_new
+        self.klassname = name
+        super().__init__(
+            f"'{self.name}': member '{self.member_name_old}' renamed to '{self.member_name_new}'."
+        )


 class RootIoRule:
-  """A placeholder IORule class"""
-  def __init__(self):
-    self.sourceClass = None
-    self.targetClass = None
-    self.version = None
-    self.source = None
-    self.target = None
-    self.code = None
+    """A placeholder IORule class"""
+
+    def __init__(self):
+        self.sourceClass = None
+        self.targetClass = None
+        self.version = None
+        self.source = None
+        self.target = None
+        self.code = None


 def sio_filter(schema_changes):
-  """
-  Checks what is required/supported for the SIO backend
-
-  At this point in time all schema changes have to be handled on PODIO side
-
-  """
-  return schema_changes
+    """
+    Checks what is required/supported for the SIO backend
+
+    At this point in time all schema changes have to be handled on PODIO side
+
+    """
+    return schema_changes


 def root_filter(schema_changes):
-  """
-  Checks what is required/supported for the ROOT backend
-
-  At this point in time we are only interested in renames.
-  Everything else will be done by ROOT automatically
-  """
-  relevant_schema_changes = []
-  for schema_change in schema_changes:
-    if isinstance(schema_change, RenamedMember):
-      relevant_schema_changes.append(schema_change)
-  return relevant_schema_changes
+    """
+    Checks what is required/supported for the ROOT backend
+
+    At this point in time we are only interested in renames.
+    Everything else will be done by ROOT automatically
+    """
+    relevant_schema_changes = []
+    for schema_change in schema_changes:
+        if isinstance(schema_change, RenamedMember):
+            relevant_schema_changes.append(schema_change)
+    return relevant_schema_changes


 class DataModelComparator:
-  """
-  Compares two datamodels and extracts required schema evolution
-  """
-  def __init__(self, yamlfile_new, yamlfile_old, evolution_file=None) -> None:
-    self.yamlfile_new = yamlfile_new
-    self.yamlfile_old = yamlfile_old
-    self.evolution_file = evolution_file
-    self.reader = PodioConfigReader()
-
-    self.datamodel_new = None
-    self.datamodel_old = None
-    self.detected_schema_changes = []
-    self.read_schema_changes = []
-    self.schema_changes = []
-
-    self.warnings = []
-    self.errors = []
-
-  def compare(self) -> None:
-    """execute the comparison on the pre-loaded datamodel definitions"""
-    self._compare_components()
-    self._compare_datatypes()
-    self.heuristics()
-
-  def _compare_components(self) -> None:
-    """compare component definitions of old and new datamodel"""
-    # first check for dropped, added and kept components
-    added_components, dropped_components, kept_components = self._compare_keys(self.datamodel_new.components.keys(),
-                                                                               self.datamodel_old.components.keys())
-    # Make findings known globally
-    self.detected_schema_changes.extend([AddedComponent(self.datamodel_new.components[name], name)
-                                         for name in added_components])
-    self.detected_schema_changes.extend([DroppedComponent(self.datamodel_old.components[name], name)
-                                         for name in dropped_components])
-
-    self._compare_definitions(kept_components, self.datamodel_new.components, self.datamodel_old.components, "Members")
-
-  def _compare_datatypes(self) -> None:
-    """compare datatype definitions of old and new datamodel"""
-    # first check for dropped, added and kept datatypes
-    added_types, dropped_types, kept_types = self._compare_keys(self.datamodel_new.datatypes.keys(),
-                                                                self.datamodel_old.datatypes.keys())
-    # Make findings known globally
-    self.detected_schema_changes.extend([AddedDatatype(self.datamodel_new.datatypes[name], name)
-                                         for name in added_types])
-    self.detected_schema_changes.extend([DroppedDatatype(self.datamodel_old.datatypes[name], name)
-                                         for name in dropped_types])
-
-    self._compare_definitions(kept_types, self.datamodel_new.datatypes, self.datamodel_old.datatypes, "Members")
-
-  def _compare_definitions(self, definitions, first, second, category) -> None:
-    """compare member definitions in old and new datamodel"""
-    for name in definitions:
-      # we are only interested in members, not the extra code
-      members1 = {member.name: member for member in first[name][category]}
-      members2 = {member.name: member for member in second[name][category]}
-      added_members, dropped_members, kept_members = self._compare_keys(members1.keys(),
-                                                                        members2.keys())
-      # Make findings known globally
-      self.detected_schema_changes.extend([AddedMember(members1[member], name) for member in added_members])
-      self.detected_schema_changes.extend([DroppedMember(members2[member], name) for member in dropped_members])
-
-      # now let's compare old and new for the kept members
-      for member_name in kept_members:
-        new = members1[member_name]
-        old = members2[member_name]
-        if old.full_type != new.full_type:
-          self.detected_schema_changes.append(ChangedMember(name, member_name, old, new))
-
-  @staticmethod
-  def _compare_keys(keys1, keys2):
-    """compare keys of two given dicts. return added, dropped and overlapping keys"""
-    added = set(keys1).difference(keys2)
-    dropped = set(keys2).difference(keys1)
-    kept = set(keys1).intersection(keys2)
-    return added, dropped, kept
-
-  def get_changed_schemata(self, schema_filter=None):
-    """return the schemata which actually changed"""
-    if schema_filter:
-      schema_changes = schema_filter(self.schema_changes)
-    else:
-      schema_changes = self.schema_changes
-    changed_klasses = {}
-    for schema_change in schema_changes:
-      changed_klass = changed_klasses.setdefault(schema_change.klassname, [])
-      changed_klass.append(schema_change)
-    return changed_klasses
-
-  def heuristics_members(self, added_members, dropped_members, schema_changes):
-    """make an analysis of member changes in a given data type"""
-    for dropped_member in dropped_members:
-      added_members_in_definition = [member for member in added_members if
-                                     dropped_member.definition_name == member.definition_name]
-      for added_member in added_members_in_definition:
-        if added_member.member.full_type == dropped_member.member.full_type:
-          # this is a rename candidate. So let's see whether it has been explicitly declared by the user
-          is_rename = False
-          for schema_change in self.read_schema_changes:
-            if isinstance(schema_change, RenamedMember) and \
-               (schema_change.name == dropped_member.definition_name) and \
-               (schema_change.member_name_old == dropped_member.member.name) and \
-               (schema_change.member_name_new == added_member.member.name):
-              # remove the dropping/adding from the schema changes and replace it by the rename
-              schema_changes.remove(dropped_member)
-              schema_changes.remove(added_member)
-              schema_changes.append(schema_change)
-              is_rename = True
-          if not is_rename:
-            self.warnings.append(f"Definition '{dropped_member.definition_name}' has a potential rename "
-                                 f"'{dropped_member.member.name}' -> '{added_member.member.name}' of type "
-                                 f"'{dropped_member.member.full_type}'.")
-
-  def heuristics(self):
-    """make an analysis of the data model changes:
-    - check which can be auto-resolved
-    - check which need extra information from the user
-    - check which ones are plain forbidden/impossible"""
-    # let's analyse the changes in more detail
-    # make a copy that can be altered along the way
-    schema_changes = self.detected_schema_changes.copy()
-    # are there dropped/added member pairs that could be interpreted as rename?
-    dropped_members = [change for change in schema_changes if isinstance(change, DroppedMember)]
-    added_members = [change for change in schema_changes if isinstance(change, AddedMember)]
-    self.heuristics_members(added_members, dropped_members, schema_changes)
-
-    # are the member changes actually supported/supportable?
-    changed_members = [change for change in schema_changes if isinstance(change, ChangedMember)]
-    for change in changed_members:
-      # changes between arrays and basic types are forbidden
-      if change.old_member.is_array != change.new_member.is_array:
-        self.errors.append(f"Forbidden schema change in '{change.name}' for '{change.member_name}' from "
-                           f"'{change.old_member.full_type}' to '{change.new_member.full_type}'")
-
-    # are there dropped/added datatype pairs that could be interpreted as rename?
- # for now assuming no change to the individual datatype definition - # I do not think more complicated heuristics are needed at this point in time - dropped_datatypes = [change for change in schema_changes if isinstance(change, DroppedDatatype)] - added_datatypes = [change for change in schema_changes if isinstance(change, AddedDatatype)] - - for dropped in dropped_datatypes: - dropped_members = {member.name: member for member in dropped.datatype["Members"]} - is_known_evolution = False - for added in added_datatypes: - added_members = {member.name: member for member in added.datatype["Members"]} - if set(dropped_members.keys()) == set(added_members.keys()): - for schema_change in self.read_schema_changes: - if isinstance(schema_change, RenamedDataType) and \ - (schema_change.name_old == dropped.name and schema_change.name_new == added.name): - schema_changes.remove(dropped) - schema_changes.remove(added) - schema_changes.append(schema_change) - is_known_evolution = True - if not is_known_evolution: - self.warnings.append(f"Potential rename of '{dropped.name}' into '{added.name}'.") - - # are there dropped/added component pairs that could be interpreted as rename? - dropped_components = [change for change in schema_changes if isinstance(change, DroppedComponent)] - added_components = [change for change in schema_changes if isinstance(change, AddedComponent)] - - for dropped in dropped_components: - dropped_members = {member.name: member for member in dropped.component["Members"]} - for added in added_components: - added_members = {member.name: member for member in added.component["Members"]} - if set(dropped_members.keys()) == set(added_members.keys()): - self.warnings.append(f"Potential rename of '{dropped.name}' into '{added.name}'.") - - # make the results of the heuristics known to the instance - self.schema_changes = schema_changes - - def print_comparison(self): - """print the result of the datamodel comparison""" - print(f"Comparing datamodel versions {self.datamodel_new.schema_version} and {self.datamodel_old.schema_version}") - - print(f"Detected {len(self.schema_changes)} schema changes:") - for change in self.schema_changes: - print(f" - {change}") - - if len(self.warnings) > 0: - print("Warnings:") - for warning in self.warnings: - print(f" - {warning}") - - if len(self.errors) > 0: - print("ERRORS:") - for error in self.errors: - print(f" - {error}") - - def read(self) -> None: - """read datamodels from yaml files""" - self.datamodel_new = self.reader.read(self.yamlfile_new, package_name="new") - self.datamodel_old = self.reader.read(self.yamlfile_old, package_name="old") - if self.evolution_file: - self.read_evolution_file() - - def read_evolution_file(self) -> None: - """read and parse evolution file""" - supported_operations = ('member_rename', 'class_renamed_to') - with open(self.evolution_file, "r", encoding='utf-8') as stream: - content = yaml.load(stream, yaml.SafeLoader) - from_schema_version = content["from_schema_version"] - to_schema_version = content["to_schema_version"] - if (from_schema_version != self.datamodel_old.schema_version) or (to_schema_version != self.datamodel_new.schema_version): # nopep8 # noqa - raise BaseException("Versions in schema evolution file do not match versions in data model descriptions.") # nopep8 # noqa - - if "evolutions" in content: - for klassname, value in content["evolutions"].items(): - # now let's go through the various supported evolutions - for operation, details in value.items(): - if operation not in supported_operations: - 
raise BaseException(f'Schema evolution operation {operation} in {klassname} unknown or not supported')  # nopep8 # noqa
-                    if operation == 'member_rename':
-                        schema_change = RenamedMember(klassname, details[0], details[1])
-                        self.read_schema_changes.append(schema_change)
-                    elif operation == 'class_renamed_to':
-                        schema_change = RenamedDataType(klassname, details)
-                        self.read_schema_changes.append(schema_change)
+    Compares two datamodels and extracts required schema evolution
+    """
+
+    def __init__(self, yamlfile_new, yamlfile_old, evolution_file=None) -> None:
+        self.yamlfile_new = yamlfile_new
+        self.yamlfile_old = yamlfile_old
+        self.evolution_file = evolution_file
+        self.reader = PodioConfigReader()
+
+        self.datamodel_new = None
+        self.datamodel_old = None
+        self.detected_schema_changes = []
+        self.read_schema_changes = []
+        self.schema_changes = []
+
+        self.warnings = []
+        self.errors = []
+
+    def compare(self) -> None:
+        """execute the comparison on pre-loaded datamodel definitions"""
+        self._compare_components()
+        self._compare_datatypes()
+        self.heuristics()
+
+    def _compare_components(self) -> None:
+        """compare component definitions of old and new datamodel"""
+        # first check for dropped, added and kept components
+        added_components, dropped_components, kept_components = self._compare_keys(
+            self.datamodel_new.components.keys(), self.datamodel_old.components.keys()
+        )
+        # Make findings known globally
+        self.detected_schema_changes.extend(
+            [
+                AddedComponent(self.datamodel_new.components[name], name)
+                for name in added_components
+            ]
+        )
+        self.detected_schema_changes.extend(
+            [
+                DroppedComponent(self.datamodel_old.components[name], name)
+                for name in dropped_components
+            ]
+        )
+
+        self._compare_definitions(
+            kept_components,
+            self.datamodel_new.components,
+            self.datamodel_old.components,
+            "Members",
+        )
+
+    def _compare_datatypes(self) -> None:
+        """compare datatype definitions of old and new datamodel"""
+        # first check for dropped, added and kept components
+        added_types, dropped_types, kept_types = self._compare_keys(
+            self.datamodel_new.datatypes.keys(), self.datamodel_old.datatypes.keys()
+        )
+        # Make findings known globally
+        self.detected_schema_changes.extend(
+            [AddedDatatype(self.datamodel_new.datatypes[name], name) for name in added_types]
+        )
+        self.detected_schema_changes.extend(
+            [DroppedDatatype(self.datamodel_old.datatypes[name], name) for name in dropped_types]
+        )
+
+        self._compare_definitions(
+            kept_types,
+            self.datamodel_new.datatypes,
+            self.datamodel_old.datatypes,
+            "Members",
+        )
+
+    def _compare_definitions(self, definitions, first, second, category) -> None:
+        """compare member definitions in old and new datamodel"""
+        for name in definitions:
+            # we are only interested in members, not the extracode
+            members1 = {member.name: member for member in first[name][category]}
+            members2 = {member.name: member for member in second[name][category]}
+            added_members, dropped_members, kept_members = self._compare_keys(
+                members1.keys(), members2.keys()
+            )
+            # Make findings known globally
+            self.detected_schema_changes.extend(
+                [AddedMember(members1[member], name) for member in added_members]
+            )
+            self.detected_schema_changes.extend(
+                [DroppedMember(members2[member], name) for member in dropped_members]
+            )
+
+            # now let's compare old and new for the kept members
+            for member_name in kept_members:
+                new = members1[member_name]
+                old = members2[member_name]
+                if old.full_type != new.full_type:
+                    self.detected_schema_changes.append(ChangedMember(name, 
member_name, old, new))
+
+    @staticmethod
+    def _compare_keys(keys1, keys2):
+        """compare keys of two given dicts. return added, dropped and overlapping keys"""
+        added = set(keys1).difference(keys2)
+        dropped = set(keys2).difference(keys1)
+        kept = set(keys1).intersection(keys2)
+        return added, dropped, kept
+
+    def get_changed_schemata(self, schema_filter=None):
+        """return the schemata which actually changed"""
+        if schema_filter:
+            schema_changes = schema_filter(self.schema_changes)
+        else:
+            schema_changes = self.schema_changes
+        changed_klasses = {}
+        for schema_change in schema_changes:
+            changed_klass = changed_klasses.setdefault(schema_change.klassname, [])
+            changed_klass.append(schema_change)
+        return changed_klasses
+
+    def heuristics_members(self, added_members, dropped_members, schema_changes):
+        """make an analysis of member changes in a given data type"""
+        for dropped_member in dropped_members:
+            added_members_in_definition = [
+                member
+                for member in added_members
+                if dropped_member.definition_name == member.definition_name
+            ]
+            for added_member in added_members_in_definition:
+                if added_member.member.full_type == dropped_member.member.full_type:
+                    # this is a rename candidate. So let's see whether it has
+                    # been explicitly declared by the user
+                    is_rename = False
+                    for schema_change in self.read_schema_changes:
+                        if (
+                            isinstance(schema_change, RenamedMember)
+                            and (schema_change.name == dropped_member.definition_name)
+                            and (schema_change.member_name_old == dropped_member.member.name)
+                            and (schema_change.member_name_new == added_member.member.name)
+                        ):
+                            # remove the dropping/adding from the schema changes
+                            # and replace it by the rename
+                            schema_changes.remove(dropped_member)
+                            schema_changes.remove(added_member)
+                            schema_changes.append(schema_change)
+                            is_rename = True
+                    if not is_rename:
+                        self.warnings.append(
+                            f"Definition '{dropped_member.definition_name}' has a potential "
+                            f"rename: '{dropped_member.member.name}' -> "
+                            f"'{added_member.member.name}' of type "
+                            f"'{dropped_member.member.full_type}'."
+                        )
+
+    def heuristics(self):
+        """make an analysis of the data model changes:
+        - check which can be auto-resolved
+        - check which need extra information from the user
+        - check which ones are plain forbidden/impossible
+        """
+        # let's analyse the changes in more detail
+        # make a copy that can be altered along the way
+        schema_changes = self.detected_schema_changes.copy()
+        # are there dropped/added member pairs that could be interpreted as rename?
+        dropped_members = [
+            change for change in schema_changes if isinstance(change, DroppedMember)
+        ]
+        added_members = [change for change in schema_changes if isinstance(change, AddedMember)]
+        self.heuristics_members(added_members, dropped_members, schema_changes)
+
+        # are the member changes actually supported/supportable?
+        changed_members = [
+            change for change in schema_changes if isinstance(change, ChangedMember)
+        ]
+        for change in changed_members:
+            # changes between arrays and basic types are forbidden
+            if change.old_member.is_array != change.new_member.is_array:
+                self.errors.append(
+                    f"Forbidden schema change in '{change.name}' for '{change.member_name}' from "
+                    f"'{change.old_member.full_type}' to '{change.new_member.full_type}'"
+                )
+
+        # are there dropped/added datatype pairs that could be interpreted as rename? 
+ # for now assuming no change to the individual datatype definition + # I do not think more complicated heuristics are needed at this point in time + dropped_datatypes = [ + change for change in schema_changes if isinstance(change, DroppedDatatype) + ] + added_datatypes = [ + change for change in schema_changes if isinstance(change, AddedDatatype) + ] + + for dropped in dropped_datatypes: + dropped_members = {member.name: member for member in dropped.datatype["Members"]} + is_known_evolution = False + for added in added_datatypes: + added_members = {member.name: member for member in added.datatype["Members"]} + if set(dropped_members.keys()) == set(added_members.keys()): + for schema_change in self.read_schema_changes: + if isinstance(schema_change, RenamedDataType) and ( + schema_change.name_old == dropped.name + and schema_change.name_new == added.name + ): + schema_changes.remove(dropped) + schema_changes.remove(added) + schema_changes.append(schema_change) + is_known_evolution = True + if not is_known_evolution: + self.warnings.append( + f"Potential rename of '{dropped.name}' into '{added.name}'." + ) + + # are there dropped/added component pairs that could be interpreted as rename? + dropped_components = [ + change for change in schema_changes if isinstance(change, DroppedComponent) + ] + added_components = [ + change for change in schema_changes if isinstance(change, AddedComponent) + ] + + for dropped in dropped_components: + dropped_members = {member.name: member for member in dropped.component["Members"]} + for added in added_components: + added_members = {member.name: member for member in added.component["Members"]} + if set(dropped_members.keys()) == set(added_members.keys()): + self.warnings.append( + f"Potential rename of '{dropped.name}' into '{added.name}'." + ) + + # make the results of the heuristics known to the instance + self.schema_changes = schema_changes + + def print_comparison(self): + """print the result of the datamodel comparison""" + print( + f"Comparing datamodel versions {self.datamodel_new.schema_version}" + f" and {self.datamodel_old.schema_version}" + ) + + print(f"Detected {len(self.schema_changes)} schema changes:") + for change in self.schema_changes: + print(f" - {change}") + + if len(self.warnings) > 0: + print("Warnings:") + for warning in self.warnings: + print(f" - {warning}") + + if len(self.errors) > 0: + print("ERRORS:") + for error in self.errors: + print(f" - {error}") + + def read(self) -> None: + """read datamodels from yaml files""" + self.datamodel_new = self.reader.read(self.yamlfile_new, package_name="new") + self.datamodel_old = self.reader.read(self.yamlfile_old, package_name="old") + if self.evolution_file: + self.read_evolution_file() + + def read_evolution_file(self) -> None: + """read and parse evolution file""" + supported_operations = ("member_rename", "class_renamed_to") + with open(self.evolution_file, "r", encoding="utf-8") as stream: + content = yaml.load(stream, yaml.SafeLoader) + from_schema_version = content["from_schema_version"] + to_schema_version = content["to_schema_version"] + if (from_schema_version != self.datamodel_old.schema_version) or ( + to_schema_version != self.datamodel_new.schema_version + ): + raise BaseException( + "Versions in schema evolution file do not match versions in " + "data model descriptions." 
+                )
+
+        if "evolutions" in content:
+            for klassname, value in content["evolutions"].items():
+                # now let's go through the various supported evolutions
+                for operation, details in value.items():
+                    if operation not in supported_operations:
+                        raise BaseException(
+                            f"Schema evolution operation {operation} in {klassname} unknown"
+                            " or not supported"
+                        )
+                    if operation == "member_rename":
+                        schema_change = RenamedMember(klassname, details[0], details[1])
+                        self.read_schema_changes.append(schema_change)
+                    elif operation == "class_renamed_to":
+                        schema_change = RenamedDataType(klassname, details)
+                        self.read_schema_changes.append(schema_change)
 
 
 ##########################
 if __name__ == "__main__":
-    import argparse
-    parser = argparse.ArgumentParser(description='Given two yaml files this script analyzes '
-                                     'the difference of the two datamodels')
-
-    parser.add_argument('new', help='yaml file describing the new datamodel')
-    parser.add_argument('old', help='yaml file describing the old datamodel')
-    parser.add_argument('-e', '--evo', help='yaml file clarifying schema evolutions', action='store')
-    args = parser.parse_args()
-
-    comparator = DataModelComparator(args.new, args.old, evolution_file=args.evo)
-    comparator.read()
-    comparator.compare()
-    comparator.print_comparison()
-    print(comparator.get_changed_schemata(schema_filter=root_filter))
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Given two yaml files this script analyzes "
+        "the difference between the two datamodels"
+    )
+
+    parser.add_argument("new", help="yaml file describing the new datamodel")
+    parser.add_argument("old", help="yaml file describing the old datamodel")
+    parser.add_argument(
+        "-e", "--evo", help="yaml file clarifying schema evolutions", action="store"
+    )
+    args = parser.parse_args()
+
+    comparator = DataModelComparator(args.new, args.old, evolution_file=args.evo)
+    comparator.read()
+    comparator.compare()
+    comparator.print_comparison()
+    print(comparator.get_changed_schemata(schema_filter=root_filter))
diff --git a/tests/write_frame.py b/tests/write_frame.py
index 313c72a9e..1ef8564ff 100644
--- a/tests/write_frame.py
+++ b/tests/write_frame.py
@@ -7,74 +7,74 @@ import ROOT
 
 if ROOT.gSystem.Load("libTestDataModelDict.so") < 0:  # noqa: E402
-  raise RuntimeError("Could not load TestDataModel dictionary")
+    raise RuntimeError("Could not load TestDataModel dictionary")
 
 from ROOT import (  # pylint: disable=wrong-import-position
     ExampleHitCollection,
     ExampleClusterCollection,
-  )  # noqa: E402
+)  # noqa: E402
 
 from podio import Frame  # pylint: disable=wrong-import-position
 
 
 def create_hit_collection():
-  """Create a simple hit collection with two hits for testing"""
-  hits = ExampleHitCollection()
-  hits.create(0xBAD, 0.0, 0.0, 0.0, 23.0)
-  hits.create(0xCAFFEE, 1.0, 0.0, 0.0, 12.0)
+    """Create a simple hit collection with two hits for testing"""
+    hits = ExampleHitCollection()
+    hits.create(0xBAD, 0.0, 0.0, 0.0, 23.0)
+    hits.create(0xCAFFEE, 1.0, 0.0, 0.0, 12.0)
 
-  return hits
+    return hits
 
 
 def create_cluster_collection():
-  """Create a simple cluster collection with two clusters"""
-  clusters = ExampleClusterCollection()
-  clu0 = clusters.create()
-  clu0.energy(3.14)
-  clu1 = clusters.create()
-  clu1.energy(1.23)
+    """Create a simple cluster collection with two clusters"""
+    clusters = ExampleClusterCollection()
+    clu0 = clusters.create()
+    clu0.energy(3.14)
+    clu1 = clusters.create()
+    clu1.energy(1.23)
 
-  return clusters
+    return clusters
 
 
 def create_frame():
-  """Create a frame with an ExampleHit 
and an ExampleCluster collection""" - frame = Frame() - hits = create_hit_collection() - frame.put(hits, "hits_from_python") - clusters = create_cluster_collection() - frame.put(clusters, "clusters_from_python") + """Create a frame with an ExampleHit and an ExampleCluster collection""" + frame = Frame() + hits = create_hit_collection() + frame.put(hits, "hits_from_python") + clusters = create_cluster_collection() + frame.put(clusters, "clusters_from_python") - frame.put_parameter("an_int", 42) - frame.put_parameter("some_floats", [1.23, 7.89, 3.14]) - frame.put_parameter("greetings", ["from", "python"]) - frame.put_parameter("real_float", 3.14, as_type="float") - frame.put_parameter("more_real_floats", [1.23, 4.56, 7.89], as_type="float") + frame.put_parameter("an_int", 42) + frame.put_parameter("some_floats", [1.23, 7.89, 3.14]) + frame.put_parameter("greetings", ["from", "python"]) + frame.put_parameter("real_float", 3.14, as_type="float") + frame.put_parameter("more_real_floats", [1.23, 4.56, 7.89], as_type="float") - return frame + return frame def write_file(writer_type, filename): - """Write a file using the given Writer type and put one Frame into it under - the events category - """ - io_backend, writer_name = writer_type.split(".") - io_module = importlib.import_module(f"podio.{io_backend}") + """Write a file using the given Writer type and put one Frame into it under + the events category + """ + io_backend, writer_name = writer_type.split(".") + io_module = importlib.import_module(f"podio.{io_backend}") - writer = getattr(io_module, writer_name)(filename) - event = create_frame() - writer.write_frame(event, "events") + writer = getattr(io_module, writer_name)(filename) + event = create_frame() + writer.write_frame(event, "events") if __name__ == "__main__": - import argparse + import argparse - parser = argparse.ArgumentParser() - parser.add_argument("outputfile", help="Output file name") - parser.add_argument("writer", help="The writer type to use") + parser = argparse.ArgumentParser() + parser.add_argument("outputfile", help="Output file name") + parser.add_argument("writer", help="The writer type to use") - args = parser.parse_args() + args = parser.parse_args() - io_format = args.outputfile.split(".")[-1] + io_format = args.outputfile.split(".")[-1] - write_file(args.writer, args.outputfile) + write_file(args.writer, args.outputfile) diff --git a/tools/podio-dump b/tools/podio-dump index ede7a148e..b0d5bbeae 100755 --- a/tools/podio-dump +++ b/tools/podio-dump @@ -11,165 +11,185 @@ from podio_version import __version__ def print_general_info(reader, filename): - """Print an overview of the file contents at the very beginning. - - This prints things like the available categories (and how many entries they - have) as well as the filename, etc. - - Args: - reader (root_io.Reader, sio_io.Reader): An initialized reader - """ - legacy_text = ' (this is a legacy file!)' if reader.is_legacy else '' - print(f'input file: {filename}{legacy_text}\n') - print(f'datamodel model definitions stored in this file: {", ".join(reader.datamodel_definitions)}') - print() - print('Frame categories in this file:') - print(f'{"Name":<20} {"Entries":<10}') - print('-' * 31) - for category in reader.categories: - print(f'{category:<20} {len(reader.get(category)):<10}') - print() + """Print an overview of the file contents at the very beginning. + + This prints things like the available categories (and how many entries they + have) as well as the filename, etc. 
+
+    Args:
+        reader (root_io.Reader, sio_io.Reader): An initialized reader
+    """
+    legacy_text = " (this is a legacy file!)" if reader.is_legacy else ""
+    print(f"input file: {filename}{legacy_text}\n")
+    print(
+        "datamodel definitions stored in this file: "
+        f'{", ".join(reader.datamodel_definitions)}'
+    )
+    print()
+    print("Frame categories in this file:")
+    print(f'{"Name":<20} {"Entries":<10}')
+    print("-" * 31)
+    for category in reader.categories:
+        print(f"{category:<20} {len(reader.get(category)):<10}")
+    print()
 
 
 def print_frame_detailed(frame):
-  """Print the Frame in all its glory, dumping every collection via print
-
-  Args:
-    frame (podio.Frame): The frame to print
-  """
-  print('Collections:')
-  for name in sorted(frame.getAvailableCollections(), key=str.casefold):
-    coll = frame.get(name)
-    print(name, flush=True)
-    coll.print()
+    """Print the Frame in all its glory, dumping every collection via print
+
+    Args:
+        frame (podio.Frame): The frame to print
+    """
+    print("Collections:")
+    for name in sorted(frame.getAvailableCollections(), key=str.casefold):
+        coll = frame.get(name)
+        print(name, flush=True)
+        coll.print()
+        print(flush=True)
+
+    print("\nParameters:", flush=True)
+    frame.get_parameters().print()
     print(flush=True)
 
-  print('\nParameters:', flush=True)
-  frame.get_parameters().print()
-  print(flush=True)
-
 
 def print_frame_overview(frame):
-  """Print a Frame overview, dumping just collection names, types and sizes
-
-  Args:
-    frame (podio.Frame): The frame to print
-  """
-  rows = []
-  for name in sorted(frame.getAvailableCollections(), key=str.casefold):
-    coll = frame.get(name)
-    rows.append(
-        (name, coll.getValueTypeName().data(), len(coll), f'{coll.getID():0>8x}')
-    )
-  print('Collections:')
-  print(tabulate(rows, headers=["Name", "ValueType", "Size", "ID"]))
-
-  rows = []
-  for name in sorted(frame.parameters, key=str.casefold):
-    for par_type, n_pars in frame.get_param_info(name).items():
-      rows.append([name, par_type, n_pars])
-  print('\nParameters:')
-  print(tabulate(rows, headers=["Name", "Type", "Elements"]))
+    """Print a Frame overview, dumping just collection names, types and sizes
+
+    Args:
+        frame (podio.Frame): The frame to print
+    """
+    rows = []
+    for name in sorted(frame.getAvailableCollections(), key=str.casefold):
+        coll = frame.get(name)
+        rows.append((name, coll.getValueTypeName().data(), len(coll), f"{coll.getID():0>8x}"))
+    print("Collections:")
+    print(tabulate(rows, headers=["Name", "ValueType", "Size", "ID"]))
+
+    rows = []
+    for name in sorted(frame.parameters, key=str.casefold):
+        for par_type, n_pars in frame.get_param_info(name).items():
+            rows.append([name, par_type, n_pars])
+    print("\nParameters:")
+    print(tabulate(rows, headers=["Name", "Type", "Elements"]))
 
 
 def print_frame(frame, cat_name, ientry, detailed):
-  """Print a Frame.
-
-  Args:
-    frame (podio.Frame): The frame to print
-    cat_name (str): The category name
-    ientry (int): The entry number of this Frame
-    detailed (bool): Print just an overview or dump the whole contents
-  """
-  print('{:#^82}'.format(f' {cat_name}: {ientry} '))  # pylint: disable=consider-using-f-string
-
-  if detailed:
-    print_frame_detailed(frame)
-  else:
-    print_frame_overview(frame)
+    """Print a Frame. 
+ + Args: + frame (podio.Frame): The frame to print + cat_name (str): The category name + ientry (int): The entry number of this Frame + detailed (bool): Print just an overview or dump the whole contents + """ + print("{:#^82}".format(f" {cat_name}: {ientry} ")) # pylint: disable=consider-using-f-string + + if detailed: + print_frame_detailed(frame) + else: + print_frame_overview(frame) - # Additional new line before the next entry - print('\n', flush=True) + # Additional new line before the next entry + print("\n", flush=True) def dump_model(reader, model_name): - """Dump the model in yaml format""" - if model_name not in reader.datamodel_definitions: - print(f'ERROR: Cannot dump model \'{model_name}\' (not present in file)') - return False + """Dump the model in yaml format""" + if model_name not in reader.datamodel_definitions: + print(f"ERROR: Cannot dump model '{model_name}' (not present in file)") + return False - model_def = json.loads(reader.get_datamodel_definition(model_name)) - print(yaml.dump(model_def, sort_keys=False, default_flow_style=False)) + model_def = json.loads(reader.get_datamodel_definition(model_name)) + print(yaml.dump(model_def, sort_keys=False, default_flow_style=False)) - return True + return True def main(args): - """Main""" - from podio.reading import get_reader # pylint: disable=import-outside-toplevel - try: - reader = get_reader(args.inputfile) - except ValueError as err: - print(f'ERROR: Cannot open file \'{args.inputfile}\': {err}') - sys.exit(1) - - if args.dump_edm is not None: - if dump_model(reader, args.dump_edm): - sys.exit(0) - else: - sys.exit(1) - - print_general_info(reader, args.inputfile) - if args.category not in reader.categories: - print(f'ERROR: Cannot print category \'{args.category}\' (not present in file)') - sys.exit(1) + """Main""" + from podio.reading import get_reader # pylint: disable=import-outside-toplevel - frames = reader.get(args.category) - for ient in args.entries: try: - print_frame(frames[ient], args.category, ient, args.detailed) - except IndexError: - print(f'WARNING: Entry no. {ient} in "{args.category}" not present in the file!') + reader = get_reader(args.inputfile) + except ValueError as err: + print(f"ERROR: Cannot open file '{args.inputfile}': {err}") + sys.exit(1) + + if args.dump_edm is not None: + if dump_model(reader, args.dump_edm): + sys.exit(0) + else: + sys.exit(1) + + print_general_info(reader, args.inputfile) + if args.category not in reader.categories: + print(f"ERROR: Cannot print category '{args.category}' (not present in file)") + sys.exit(1) + + frames = reader.get(args.category) + for ient in args.entries: + try: + print_frame(frames[ient], args.category, ient, args.detailed) + except IndexError: + print(f'WARNING: Entry no. 
{ient} in "{args.category}" not present in the file!') def parse_entry_range(ent_string): - """Parse which entries to print""" - try: - return [int(ent_string)] - except ValueError: - pass - - try: - return [int(i) for i in ent_string.split(',')] - except ValueError: - pass - - try: - first, last = [int(i) for i in ent_string.split(':')] - return list(range(first, last + 1)) - except ValueError: - pass - - raise argparse.ArgumentTypeError(f'\'{ent_string}\' cannot be parsed into a list of entries') - - -if __name__ == '__main__': - import argparse - # pylint: disable=invalid-name # before 2.5.0 pylint is too strict with the naming here - parser = argparse.ArgumentParser(description='Dump contents of a podio file to stdout') - parser.add_argument('inputfile', help='Name of the file to dump content from') - parser.add_argument('-c', '--category', help='Which Frame category to dump', - default='events', type=str) - parser.add_argument('-e', '--entries', - help='Which entries to print. A single number, comma separated list of numbers' - ' or "first:last" for an inclusive range of entries. Defaults to the first entry.', - type=parse_entry_range, default=[0]) - parser.add_argument('-d', '--detailed', help='Dump the full contents not just the collection info', - action='store_true', default=False) - parser.add_argument('--dump-edm', - help='Dump the specified EDM definition from the file in yaml format', - type=str, default=None) - parser.add_argument('--version', action='version', version=f'podio {__version__}') - - clargs = parser.parse_args() - main(clargs) + """Parse which entries to print""" + try: + return [int(ent_string)] + except ValueError: + pass + + try: + return [int(i) for i in ent_string.split(",")] + except ValueError: + pass + + try: + first, last = [int(i) for i in ent_string.split(":")] + return list(range(first, last + 1)) + except ValueError: + pass + + raise argparse.ArgumentTypeError(f"'{ent_string}' cannot be parsed into a list of entries") + + +if __name__ == "__main__": + import argparse + + # pylint: disable=invalid-name # before 2.5.0 pylint is too strict with the naming here + parser = argparse.ArgumentParser(description="Dump contents of a podio file to stdout") + parser.add_argument("inputfile", help="Name of the file to dump content from") + parser.add_argument( + "-c", + "--category", + help="Which Frame category to dump", + default="events", + type=str, + ) + parser.add_argument( + "-e", + "--entries", + help="Which entries to print. A single number, comma separated list of numbers" + ' or "first:last" for an inclusive range of entries. 
Defaults to the first entry.',
+        type=parse_entry_range,
+        default=[0],
+    )
+    parser.add_argument(
+        "-d",
+        "--detailed",
+        help="Dump the full contents, not just the collection info",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--dump-edm",
+        help="Dump the specified EDM definition from the file in yaml format",
+        type=str,
+        default=None,
+    )
+    parser.add_argument("--version", action="version", version=f"podio {__version__}")
+
+    clargs = parser.parse_args()
+    main(clargs)
diff --git a/tools/podio-ttree-to-rntuple b/tools/podio-ttree-to-rntuple
index ace64514d..2e32601d1 100755
--- a/tools/podio-ttree-to-rntuple
+++ b/tools/podio-ttree-to-rntuple
@@ -4,21 +4,24 @@ import argparse
 import podio.root_io
 
-parser = argparse.ArgumentParser(description='podio-ttree-to-rntuple tool to create'
-                                 'an rntuple file from a ttree file or viceversa')
-parser.add_argument('input_file', help='input file')
-parser.add_argument('output_file', help='output file')
-parser.add_argument('-r', '--reverse', action='store_true',
-                    help='reverse the conversion (from RNTuple to TTree)')
+parser = argparse.ArgumentParser(
+    description="podio-ttree-to-rntuple tool to create "
+    "an rntuple file from a ttree file or vice versa"
+)
+parser.add_argument("input_file", help="input file")
+parser.add_argument("output_file", help="output file")
+parser.add_argument(
+    "-r", "--reverse", action="store_true", help="reverse the conversion (from RNTuple to TTree)"
+)
 args = parser.parse_args()
 
 if not args.reverse:
-  reader = podio.root_io.Reader(args.input_file)
-  writer = podio.root_io.RNTupleWriter(args.output_file)
+    reader = podio.root_io.Reader(args.input_file)
+    writer = podio.root_io.RNTupleWriter(args.output_file)
else:
-  reader = podio.root_io.RNTupleReader(args.input_file)
-  writer = podio.root_io.Writer(args.output_file)
+    reader = podio.root_io.RNTupleReader(args.input_file)
+    writer = podio.root_io.Writer(args.output_file)
 
 for category in reader.categories:
-  for frame in reader.get(category):
-    writer.write_frame(frame, category)
+    for frame in reader.get(category):
+        writer.write_frame(frame, category)
diff --git a/tools/podio-vis b/tools/podio-vis
index 667b5feab..f0e1a3334 100755
--- a/tools/podio-vis
+++ b/tools/podio-vis
@@ -1,109 +1,128 @@
 #!/usr/bin/env python3
-'''Tool to transform data model descriptions in YAML to a graph that can be visualized'''
+"""Tool to transform data model descriptions in YAML to a graph that can be visualized"""
 import sys
 import argparse
 import yaml
 from podio_gen.podio_config_reader import PodioConfigReader
+
 try:
-  from graphviz import Digraph
+    from graphviz import Digraph
 except ImportError:
-  print('graphviz is not installed. please run pip install graphviz')
-  sys.exit(1)
+    print("graphviz is not installed. 
please run pip install graphviz") + sys.exit(1) class ModelToGraphviz: - """Class to transform a data model description into a graphical representation""" - - def __init__(self, yamlfile, dot, fmt, filename, graph_conf): - self.yamlfile = yamlfile - self.use_dot = dot - self.datamodel = PodioConfigReader.read(yamlfile, 'podio') - self.graph = Digraph(node_attr={'shape': 'box'}) - self.graph.attr(rankdir='RL', size='8,5') - self.fmt = fmt - self.filename = filename - self.graph_conf = {} - self.remove = set() - if graph_conf: - with open(graph_conf, encoding='utf8') as inp: - self.graph_conf = yaml.safe_load(inp) - if 'Filter' in self.graph_conf: - self.remove = set(self.graph_conf['Filter']) - - def make_viz(self): - '''Make the graph and render it in the chosen format''' - - # Make the grouped nodes first - # It doesn't matter if they are remade latter so we don't need - # to check for that - for i, (label, group) in enumerate(self.graph_conf.items()): - with self.graph.subgraph(name=f'cluster{i+1}') as subgraph: - subgraph.attr(label=label) - for name in group: - if name in self.remove: - continue - subgraph.node(name.replace('::', '_'), label=name) - - with_association = False - for name, datatype in self.datamodel.datatypes.items(): - if name in self.remove: - continue - if 'Association' in name: - with_association = True - self.graph.edge(datatype['OneToOneRelations'][0].full_type.replace('::', '_'), - datatype['OneToOneRelations'][1].full_type.replace('::', '_'), - label=name.replace('edm4hep::', ''), color='black', dir='both') - continue - - compatible_name = name.replace('::', '_') # graphviz gets confused with C++ '::' and formatting strings - self.graph.node(compatible_name, label=name) - self.graph.attr('edge', color='blue') - for rel in datatype["OneToOneRelations"]: - if rel.full_type in self.remove: - continue - compatible_type = rel.full_type.replace('::', '_') - self.graph.edge(compatible_name, compatible_type) - self.graph.attr('edge', color='red') - for rel in datatype["OneToManyRelations"]: - if rel.full_type in self.remove: - continue - compatible_type = rel.full_type.replace('::', '_') - self.graph.edge(compatible_name, compatible_type) - - with self.graph.subgraph(name='cluster0') as subg: - subg.attr('node', shape='plaintext') - subg.node('l1', '') - subg.node('r1', 'One to One Relation') - subg.edge('l1', 'r1', color='blue') - subg.node('l2', '') - subg.node('r2', 'One to Many Relation') - subg.edge('l2', 'r2', color='red') - if with_association: - subg.node('r3', 'Association') - subg.node('l3', '') - subg.edge('l3', 'r3', color='black', dir='both') - - if self.use_dot: - self.graph.save() - else: - print(f'Saving file {self.filename} and {self.filename}.{self.fmt}') - self.graph.render(filename=self.filename, format=self.fmt) + """Class to transform a data model description into a graphical representation""" + + def __init__(self, yamlfile, dot, fmt, filename, graph_conf): + self.yamlfile = yamlfile + self.use_dot = dot + self.datamodel = PodioConfigReader.read(yamlfile, "podio") + self.graph = Digraph(node_attr={"shape": "box"}) + self.graph.attr(rankdir="RL", size="8,5") + self.fmt = fmt + self.filename = filename + self.graph_conf = {} + self.remove = set() + if graph_conf: + with open(graph_conf, encoding="utf8") as inp: + self.graph_conf = yaml.safe_load(inp) + if "Filter" in self.graph_conf: + self.remove = set(self.graph_conf["Filter"]) + + def make_viz(self): + """Make the graph and render it in the chosen format""" + + # Make the grouped nodes first + # 
It doesn't matter if they are remade later so we don't need
+        # to check for that
+        for i, (label, group) in enumerate(self.graph_conf.items()):
+            with self.graph.subgraph(name=f"cluster{i+1}") as subgraph:
+                subgraph.attr(label=label)
+                for name in group:
+                    if name in self.remove:
+                        continue
+                    subgraph.node(name.replace("::", "_"), label=name)
+
+        with_association = False
+        for name, datatype in self.datamodel.datatypes.items():
+            if name in self.remove:
+                continue
+            if "Association" in name:
+                with_association = True
+                self.graph.edge(
+                    datatype["OneToOneRelations"][0].full_type.replace("::", "_"),
+                    datatype["OneToOneRelations"][1].full_type.replace("::", "_"),
+                    label=name.replace("edm4hep::", ""),
+                    color="black",
+                    dir="both",
+                )
+                continue
+
+            compatible_name = name.replace(
+                "::", "_"
+            )  # graphviz gets confused with C++ '::' and formatting strings
+            self.graph.node(compatible_name, label=name)
+            self.graph.attr("edge", color="blue")
+            for rel in datatype["OneToOneRelations"]:
+                if rel.full_type in self.remove:
+                    continue
+                compatible_type = rel.full_type.replace("::", "_")
+                self.graph.edge(compatible_name, compatible_type)
+            self.graph.attr("edge", color="red")
+            for rel in datatype["OneToManyRelations"]:
+                if rel.full_type in self.remove:
+                    continue
+                compatible_type = rel.full_type.replace("::", "_")
+                self.graph.edge(compatible_name, compatible_type)
+
+        with self.graph.subgraph(name="cluster0") as subg:
+            subg.attr("node", shape="plaintext")
+            subg.node("l1", "")
+            subg.node("r1", "One to One Relation")
+            subg.edge("l1", "r1", color="blue")
+            subg.node("l2", "")
+            subg.node("r2", "One to Many Relation")
+            subg.edge("l2", "r2", color="red")
+            if with_association:
+                subg.node("r3", "Association")
+                subg.node("l3", "")
+                subg.edge("l3", "r3", color="black", dir="both")
+
+        if self.use_dot:
+            self.graph.save()
+        else:
+            print(f"Saving file {self.filename} and {self.filename}.{self.fmt}")
+            self.graph.render(filename=self.filename, format=self.fmt)
 
 
 if __name__ == "__main__":
-  parser = argparse.ArgumentParser(description='Given a description yaml file this script generates '
-                                   'a visualization in the target directory')
+    parser = argparse.ArgumentParser(
+        description="Given a description yaml file this script generates "
+        "a visualization in the target directory"
+    )
 
-  parser.add_argument('description', help='yaml file describing the datamodel')
-  parser.add_argument('-d', '--dot', action='store_true', default=False,
-                      help='just write the dot file')
-  parser.add_argument('--fmt', default='svg', help='Which format to use for saving the file')
-  parser.add_argument('--filename', default='gv', help='Which filename to use for the output')
-  parser.add_argument('--graph-conf', help='Configuration file for defining groups')
+    parser.add_argument("description", help="yaml file describing the datamodel")
+    parser.add_argument(
+        "-d",
+        "--dot",
+        action="store_true",
+        default=False,
+        help="just write the dot file",
+    )
+    parser.add_argument("--fmt", default="svg", help="Which format to use for saving the file")
+    parser.add_argument("--filename", default="gv", help="Which filename to use for the output")
+    parser.add_argument("--graph-conf", help="Configuration file for defining groups")
 
-  args = parser.parse_args()
+    args = parser.parse_args()
 
-  vis = ModelToGraphviz(args.description, args.dot, fmt=args.fmt,
-                        filename=args.filename, graph_conf=args.graph_conf)
-  vis.make_viz()
+    vis = ModelToGraphviz(
+        args.description,
+        args.dot,
+        fmt=args.fmt,
+        filename=args.filename,
+        
graph_conf=args.graph_conf, + ) + vis.make_viz()
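
A note on reproducing the formatting above: the reformatted sources consistently wrap just
under 99 characters, so the result of this patch can be checked with a plain black run rather
than by eye. Below is a minimal Python sketch of such a check, assuming black is installed
(pip install black) and using its public format_str/Mode API; the 99-character line length is
an assumption inferred from the wrapped lines above, not stated anywhere in this patch:

    import black

    # A deliberately mis-formatted snippet; format_str should normalize the
    # spacing the same way the hunks above were normalized.
    SRC = "def f(  a,b ):\n    return a+b\n"

    mode = black.Mode(line_length=99)  # assumed line length, inferred from the wrapping above
    formatted = black.format_str(SRC, mode=mode)

    print(formatted)         # def f(a, b): / return a + b
    print(formatted == SRC)  # False -> SRC was not yet black-clean

Running the same round-trip over each touched file and comparing the output to the file
contents is a quick way to confirm that later edits stay black-clean.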