Skip to content

Commit

Permalink
Reformatted xs.py and photon_polarization.py
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Jan 15, 2024
1 parent 69eb720 commit a6e2a22
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 80 deletions.
34 changes: 20 additions & 14 deletions simulationworkflowschema/photon_polarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,46 +20,52 @@
from nomad.metainfo import SubSection, Quantity, Reference
from nomad.datamodel.metainfo.simulation.method import BSE as BSEMethodology
from nomad.datamodel.metainfo.simulation.calculation import Spectra
from .general import SimulationWorkflowResults, SimulationWorkflowMethod, ParallelSimulation
from .general import (
SimulationWorkflowResults,
SimulationWorkflowMethod,
ParallelSimulation,
)


class PhotonPolarizationResults(SimulationWorkflowResults):
'''Groups all polarization outputs: spectrum.
'''
"""Groups all polarization outputs: spectrum."""

n_polarizations = Quantity(
type=np.int32,
description='''
description="""
Number of polarizations for the phonons used for the calculations.
''')
""",
)

spectrum_polarization = Quantity(
type=Reference(Spectra),
shape=['n_polarizations'],
description='''
shape=["n_polarizations"],
description="""
Spectrum for a given polarization of the photon.
''')
""",
)


class PhotonPolarizationMethod(SimulationWorkflowMethod):
'''Defines the full macroscopic dielectric tensor methodology: BSE method reference.
'''
"""Defines the full macroscopic dielectric tensor methodology: BSE method reference."""

# TODO add TDDFT methodology reference.

bse_method_ref = Quantity(
type=Reference(BSEMethodology),
description='''
description="""
BSE methodology reference.
''')
""",
)


class PhotonPolarization(ParallelSimulation):
'''The PhotonPolarization workflow is generated in an extra EntryArchive FOR all polarization
"""The PhotonPolarization workflow is generated in an extra EntryArchive FOR all polarization
EntryArchives present in the upload. It groups them for a set of given method parameters.
This entry is also recognized as the full macroscopic dielectric tensor entry (e.g. calculated
via BSE).
'''
"""

method = SubSection(sub_section=PhotonPolarizationMethod)

Expand Down
96 changes: 30 additions & 66 deletions simulationworkflowschema/xs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,77 +15,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from nomad.metainfo import SubSection, Quantity, Reference
from nomad.metainfo import SubSection
from nomad.datamodel.metainfo.simulation.calculation import (
BandGap, Dos, BandStructure, Spectra,
ElectronicStructureProvenance
)
from .general import (
SimulationWorkflowResults, SimulationWorkflowMethod, SerialSimulation
Spectra,
ElectronicStructureProvenance,
)
from .general import DFTOutputs, GWOutputs, SimulationWorkflowMethod, SerialSimulation
from .photon_polarization import PhotonPolarizationResults


class XSResults(SimulationWorkflowResults):
'''Groups DFT, GW and PhotonPolarization outputs: band gaps (DFT, GW), DOS (DFT, GW),
class XSResults(DFTOutputs, GWOutputs):
"""
Groups DFT, GW and PhotonPolarization outputs: band gaps (DFT, GW), DOS (DFT, GW),
band structures (DFT, GW), spectra (PhotonPolarization). The ResultsNormalizer takes
care of adding a label 'DFT' or 'GW' in the method `get_xs_workflow_properties`.
'''

band_gap_dft = Quantity(
type=Reference(BandGap),
shape=['*'],
description='''
Reference to the DFT band gap.
''')

band_gap_gw = Quantity(
type=Reference(BandGap),
shape=['*'],
description='''
Reference to the GW band gap.
''')

band_structure_dft = Quantity(
type=Reference(BandStructure),
shape=['*'],
description='''
Reference to the DFT density of states.
''')

band_structure_gw = Quantity(
type=Reference(BandStructure),
shape=['*'],
description='''
Reference to the GW density of states.
''')

dos_dft = Quantity(
type=Reference(Dos),
shape=['*'],
description='''
Reference to the DFT band structure.
''')

dos_gw = Quantity(
type=Reference(Dos),
shape=['*'],
description='''
Reference to the GW band structure.
''')
"""

spectra = SubSection(sub_section=PhotonPolarizationResults, repeats=True)


class XSMethod(SimulationWorkflowMethod):

pass


class XS(SerialSimulation):
'''The XS workflow is generated in an extra EntryArchive IF both the DFT SinglePoint
"""
The XS workflow is generated in an extra EntryArchive IF both the DFT SinglePoint
and the PhotonPolarization EntryArchives are present in the upload.
'''
"""

# TODO extend to reference a GW SinglePoint.

method = SubSection(sub_section=XSMethod)
Expand All @@ -96,39 +54,45 @@ def normalize(self, archive, logger):
super().normalize(archive, logger)

if len(self.tasks) < 2:
logger.error('Expected more than one task: DFT+PhotonPolarization or DFT+GW+PhotonPolarization.')
logger.error(
"Expected more than one task: DFT+PhotonPolarization or DFT+GW+PhotonPolarization."
)
return

dft_task = self.tasks[0]
xs_tasks = [self.tasks[i] for i in range(1, len(self.tasks))]
gw_task = None
# Check if the xs_tasks contain GW SinglePoint or a list of PhotonPolarizations
if xs_tasks[0].task.m_def.name != 'PhotonPolarization':
if xs_tasks[0].task.m_def.name != "PhotonPolarization":
gw_task = xs_tasks[0]
xs_tasks.pop(0) # we delete the [0] element associated with GW in case DFT+GW+PhotonPolarization workflow
xs_tasks.pop(
0
) # we delete the [0] element associated with GW in case DFT+GW+PhotonPolarization workflow

if not self.results:
self.results = XSResults()

for name, section in self.results.m_def.all_quantities.items():
calc_name = '_'.join(name.split('_')[:-1])
if calc_name in ['dos', 'band_structure']:
calc_name = f'{calc_name}_electronic'
calc_name = "_".join(name.split("_")[:-1])
if calc_name in ["dos", "band_structure"]:
calc_name = f"{calc_name}_electronic"
calc_section = []
if 'dft' in name:
if "dft" in name:
calc_section = getattr(dft_task.outputs[-1].section, calc_name)
elif 'gw' in name and gw_task:
elif "gw" in name and gw_task:
calc_section = getattr(gw_task.outputs[-1].section, calc_name)
elif name == 'spectra':
elif name == "spectra":
pass
if calc_section:
self.results.m_set(section, calc_section)
for xs in xs_tasks:
if xs.m_xpath('task.results'):
if xs.m_xpath("task.results"):
photon_results = xs.task.results
# Adding provenance to BSE method section, in addition to the existent 'photon' provenance
if xs.task.m_xpath('inputs[1].section'):
if xs.task.m_xpath("inputs[1].section"):
for spectra in photon_results.spectrum_polarization:
provenance = ElectronicStructureProvenance(methodology=xs.task.inputs[1].section, label='bse')
provenance = ElectronicStructureProvenance(
methodology=xs.task.inputs[1].section, label="bse"
)
spectra.m_add_sub_section(Spectra.provenance, provenance)
self.results.m_add_sub_section(XSResults.spectra, photon_results)

0 comments on commit a6e2a22

Please # to comment.