diff --git a/simulationworkflowschema/photon_polarization.py b/simulationworkflowschema/photon_polarization.py index fa45338..c684b8a 100644 --- a/simulationworkflowschema/photon_polarization.py +++ b/simulationworkflowschema/photon_polarization.py @@ -20,46 +20,52 @@ from nomad.metainfo import SubSection, Quantity, Reference from nomad.datamodel.metainfo.simulation.method import BSE as BSEMethodology from nomad.datamodel.metainfo.simulation.calculation import Spectra -from .general import SimulationWorkflowResults, SimulationWorkflowMethod, ParallelSimulation +from .general import ( + SimulationWorkflowResults, + SimulationWorkflowMethod, + ParallelSimulation, +) class PhotonPolarizationResults(SimulationWorkflowResults): - '''Groups all polarization outputs: spectrum. - ''' + """Groups all polarization outputs: spectrum.""" n_polarizations = Quantity( type=np.int32, - description=''' + description=""" Number of polarizations for the phonons used for the calculations. - ''') + """, + ) spectrum_polarization = Quantity( type=Reference(Spectra), - shape=['n_polarizations'], - description=''' + shape=["n_polarizations"], + description=""" Spectrum for a given polarization of the photon. - ''') + """, + ) class PhotonPolarizationMethod(SimulationWorkflowMethod): - '''Defines the full macroscopic dielectric tensor methodology: BSE method reference. - ''' + """Defines the full macroscopic dielectric tensor methodology: BSE method reference.""" + # TODO add TDDFT methodology reference. bse_method_ref = Quantity( type=Reference(BSEMethodology), - description=''' + description=""" BSE methodology reference. - ''') + """, + ) class PhotonPolarization(ParallelSimulation): - '''The PhotonPolarization workflow is generated in an extra EntryArchive FOR all polarization + """The PhotonPolarization workflow is generated in an extra EntryArchive FOR all polarization EntryArchives present in the upload. It groups them for a set of given method parameters. This entry is also recognized as the full macroscopic dielectric tensor entry (e.g. calculated via BSE). - ''' + """ method = SubSection(sub_section=PhotonPolarizationMethod) diff --git a/simulationworkflowschema/xs.py b/simulationworkflowschema/xs.py index d08bcd6..77e280b 100644 --- a/simulationworkflowschema/xs.py +++ b/simulationworkflowschema/xs.py @@ -15,77 +15,35 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from nomad.metainfo import SubSection, Quantity, Reference +from nomad.metainfo import SubSection from nomad.datamodel.metainfo.simulation.calculation import ( - BandGap, Dos, BandStructure, Spectra, - ElectronicStructureProvenance -) -from .general import ( - SimulationWorkflowResults, SimulationWorkflowMethod, SerialSimulation + Spectra, + ElectronicStructureProvenance, ) +from .general import DFTOutputs, GWOutputs, SimulationWorkflowMethod, SerialSimulation from .photon_polarization import PhotonPolarizationResults -class XSResults(SimulationWorkflowResults): - '''Groups DFT, GW and PhotonPolarization outputs: band gaps (DFT, GW), DOS (DFT, GW), +class XSResults(DFTOutputs, GWOutputs): + """ + Groups DFT, GW and PhotonPolarization outputs: band gaps (DFT, GW), DOS (DFT, GW), band structures (DFT, GW), spectra (PhotonPolarization). The ResultsNormalizer takes care of adding a label 'DFT' or 'GW' in the method `get_xs_workflow_properties`. - ''' - - band_gap_dft = Quantity( - type=Reference(BandGap), - shape=['*'], - description=''' - Reference to the DFT band gap. - ''') - - band_gap_gw = Quantity( - type=Reference(BandGap), - shape=['*'], - description=''' - Reference to the GW band gap. - ''') - - band_structure_dft = Quantity( - type=Reference(BandStructure), - shape=['*'], - description=''' - Reference to the DFT density of states. - ''') - - band_structure_gw = Quantity( - type=Reference(BandStructure), - shape=['*'], - description=''' - Reference to the GW density of states. - ''') - - dos_dft = Quantity( - type=Reference(Dos), - shape=['*'], - description=''' - Reference to the DFT band structure. - ''') - - dos_gw = Quantity( - type=Reference(Dos), - shape=['*'], - description=''' - Reference to the GW band structure. - ''') + """ spectra = SubSection(sub_section=PhotonPolarizationResults, repeats=True) class XSMethod(SimulationWorkflowMethod): - pass class XS(SerialSimulation): - '''The XS workflow is generated in an extra EntryArchive IF both the DFT SinglePoint + """ + The XS workflow is generated in an extra EntryArchive IF both the DFT SinglePoint and the PhotonPolarization EntryArchives are present in the upload. - ''' + """ + # TODO extend to reference a GW SinglePoint. method = SubSection(sub_section=XSMethod) @@ -96,39 +54,45 @@ def normalize(self, archive, logger): super().normalize(archive, logger) if len(self.tasks) < 2: - logger.error('Expected more than one task: DFT+PhotonPolarization or DFT+GW+PhotonPolarization.') + logger.error( + "Expected more than one task: DFT+PhotonPolarization or DFT+GW+PhotonPolarization." + ) return dft_task = self.tasks[0] xs_tasks = [self.tasks[i] for i in range(1, len(self.tasks))] gw_task = None # Check if the xs_tasks contain GW SinglePoint or a list of PhotonPolarizations - if xs_tasks[0].task.m_def.name != 'PhotonPolarization': + if xs_tasks[0].task.m_def.name != "PhotonPolarization": gw_task = xs_tasks[0] - xs_tasks.pop(0) # we delete the [0] element associated with GW in case DFT+GW+PhotonPolarization workflow + xs_tasks.pop( + 0 + ) # we delete the [0] element associated with GW in case DFT+GW+PhotonPolarization workflow if not self.results: self.results = XSResults() for name, section in self.results.m_def.all_quantities.items(): - calc_name = '_'.join(name.split('_')[:-1]) - if calc_name in ['dos', 'band_structure']: - calc_name = f'{calc_name}_electronic' + calc_name = "_".join(name.split("_")[:-1]) + if calc_name in ["dos", "band_structure"]: + calc_name = f"{calc_name}_electronic" calc_section = [] - if 'dft' in name: + if "dft" in name: calc_section = getattr(dft_task.outputs[-1].section, calc_name) - elif 'gw' in name and gw_task: + elif "gw" in name and gw_task: calc_section = getattr(gw_task.outputs[-1].section, calc_name) - elif name == 'spectra': + elif name == "spectra": pass if calc_section: self.results.m_set(section, calc_section) for xs in xs_tasks: - if xs.m_xpath('task.results'): + if xs.m_xpath("task.results"): photon_results = xs.task.results # Adding provenance to BSE method section, in addition to the existent 'photon' provenance - if xs.task.m_xpath('inputs[1].section'): + if xs.task.m_xpath("inputs[1].section"): for spectra in photon_results.spectrum_polarization: - provenance = ElectronicStructureProvenance(methodology=xs.task.inputs[1].section, label='bse') + provenance = ElectronicStructureProvenance( + methodology=xs.task.inputs[1].section, label="bse" + ) spectra.m_add_sub_section(Spectra.provenance, provenance) self.results.m_add_sub_section(XSResults.spectra, photon_results)