Skip to content

Commit

Permalink
refactor(data preprocess): remove the cut off options from info.json (#…
Browse files Browse the repository at this point in the history
…200)

* refactor(data preprocess): remove the cut off options from info.json and collect the values from input.json

* update LMDB info.json. not need anymore.

* refactor(default_dataset): refactor the _TrajData for ase data.

Previous the ase data will be transferred into text file and then loaded by the _TrajData. now i refactor the function.

both text and ase data are treated equally. will works as a class funtion to initial the _TrajData class.

* add print logo in main and format some of  the logger.info

* update argcheck  collect_cutoffs.  add new function with  get_cutoffs_from_model_options .

* Fix(get_cutoffs_from_model_options) : fix rcut in  powerlaw and varTang96.

For powerlaw and varTang96, the rs is not exactly the hard cutoff. so when extract the r_max for data. we have to use rs + 5 * w; but for other method just use rs.

* update band post process.

* update test

* update test

* update build and get_cutoffs_from_model_options to support the rmax to be dict.

* refactor(build dataset): change build_dataset from function to a class instance and add from_model class function.

note, compared to the previous build_dataset, this one is more flexible.
previous build_dataset is a function. now i define a class DataBuilder and re-defined __call__ function.  then build_dataset is an instance of DataBuilder class. so i can use build_dataset.from_model() to build dataset from model. at the same time the previous way to use  build_dataset is still available. like build_dataset(...).

* add checkcutoff in dataset builder.

* update AtomicData_options to make it compatible with older versions

* Update argcheck.py

---------

Co-authored-by: Yinzhanghao Zhou <64253517+floatingCatty@users.noreply.github.com>
  • Loading branch information
QG-phy and floatingCatty authored Aug 13, 2024
1 parent c5ca916 commit caa903d
Show file tree
Hide file tree
Showing 37 changed files with 1,513 additions and 566 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ dptb/tests/**/*.pth
dptb/tests/**/*.npy
dptb/tests/**/*.traj
dptb/tests/**/out*/*
dptb/tests/**/out*/*
dptb/tests/**/*lmdb
dptb/tests/**/*h5
examples/_*
*.dat
*log*
Expand Down
38 changes: 36 additions & 2 deletions dptb/__main__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,39 @@
from dptb.entrypoints.main import main
from dptb.entrypoints.main import main as entry_main
import logging
import pyfiglet
from dptb import __version__

logging.basicConfig(level=logging.INFO, format='%(message)s')
log = logging.getLogger(__name__)

def print_logo():
f = pyfiglet.Figlet(font='dos_rebel') # 您可以选择您喜欢的字体
logo = f.renderText("DeePTB")
log.info(" ")
log.info(" ")
log.info("#"*81)
log.info("#" + " "*79 + "#")
log.info("#" + " "*79 + "#")
for line in logo.split('\n'):
if line.strip(): # 避免记录空行
log.info('# '+line+ ' #')
log.info("#" + " "*79 + "#")
version_info = f"Version: {__version__}"
padding = (79 - len(version_info)) // 2
nspace = 79-padding
format_str = "#" + "{}"+"{:<"+f"{nspace}" + "}"+ "#"
log.info(format_str.format(" "*padding, version_info))
log.info("#" + " "*79 + "#")
log.info("#"*81)
log.info(" ")
log.info(" ")
def main() -> None:
"""
The main entry point for the dptb package.
"""
print_logo()
entry_main()

if __name__ == '__main__':
main()
#print_logo()
main()
2 changes: 1 addition & 1 deletion dptb/data/AtomicData.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def from_points(
def from_ase(
cls,
atoms,
r_max,
r_max: Union[float, int, dict],
er_max: Optional[float] = None,
oer_max: Optional[float] = None,
key_mapping: Optional[Dict[str, str]] = {},
Expand Down
Loading

0 comments on commit caa903d

Please # to comment.