* fix lem * fix slem * update band plot * update e3 doc * update shift in loss analysis * update onsite shift loss for all * move E3 statistics initialization into dataset, optmize nested tensor support in hr2hk and eigvals compute * fix test build models * fix bug in batch mu and refactor visualization in loss analysis * update temp * rearrange doc * update doc and e3 example * doing grammer check for the updated doc
1 parent
aaa6375
commit d3c88f2
Showing 46 changed files with 1,026 additions and 325 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# More on Input Parameters | ||
In `common_options`, the user should define global parameters such as: | |
```JSON | ||
"common_options": { | ||
"basis": { | ||
"C": "2s2p1d", | ||
"N": "2s2p1d" | ||
}, | ||
"device": "cuda", | ||
"dtype": "float32", | ||
"overlap": true, | ||
"seed": 42 | ||
} | ||
``` | ||
- `basis` should match the basis set used in the LCAO DFT calculations. Here `"2s2p1d"` denotes two `s` orbitals, two `p` orbitals and one `d` orbital. | |
- `seed` sets the global random seed of all related packages. | |
- `dtype` can be `float32` or `float64`; the former is accurate enough in most cases. | |
- `device` selects the compute device. If you have multiple GPUs, it can be set to `cuda:0`, `cuda:1` and so on, where the number is the device id. | |
- `overlap` controls the fitting of the overlap matrix. If `overlap` is set to `true`, the user must provide the overlap matrices in the dataset when configuring `data_options`. | |
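To make the basis notation concrete, here is a minimal sketch that splits a string such as `"2s2p1d"` into shell counts and totals the orbitals per atom. The helpers `parse_basis` and `count_orbitals` are illustrative names and not part of DeePTB; the `f` entry is included only for completeness.

```python
import re

# Number of orbitals in one shell of each angular momentum: 2l + 1.
ORBITALS_PER_SHELL = {"s": 1, "p": 3, "d": 5, "f": 7}


def parse_basis(basis):
    """Split a basis string such as "2s2p1d" into shell counts per angular momentum."""
    shells = re.findall(r"(\d+)([spdf])", basis)
    # Reject strings that contain anything besides <count><letter> pairs.
    if "".join(n + l for n, l in shells) != basis:
        raise ValueError(f"unrecognized basis string: {basis!r}")
    return {l: int(n) for n, l in shells}


def count_orbitals(basis):
    """Total number of orbitals per atom described by the basis string."""
    return sum(n * ORBITALS_PER_SHELL[l] for l, n in parse_basis(basis).items())


print(parse_basis("2s2p1d"))     # {'s': 2, 'p': 2, 'd': 1}
print(count_orbitals("2s2p1d"))  # 13
```

With two `s`, two `p` and one `d` shell, each atom carries 2x1 + 2x3 + 1x5 = 13 orbitals.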
|
||
In `train_options`, a typical configuration looks like this: | |
```JSON | ||
"train_options": { | ||
"num_epoch": 500, | ||
"batch_size": 1, | ||
"optimizer": { | ||
"lr": 0.05, | ||
"type": "Adam" | ||
}, | ||
"lr_scheduler": { | ||
"type": "rop", | ||
"factor": 0.8, | ||
"patience": 6, | ||
"min_lr": 1e-6 | ||
}, | ||
"loss_options":{ | ||
"train": {"method": "hamil_abs", "onsite_shift": false}, | ||
"validation": {"method": "hamil_abs", "onsite_shift": false} | ||
}, | ||
"save_freq": 10, | ||
"validation_freq": 10, | ||
"display_freq": 10 | ||
} | ||
``` | ||
For `lr_scheduler`, please make sure that `patience` x `num_samples` / `batch_size` lies between 2000 and 6000. | |
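The rule of thumb for `patience` can be checked with a few lines of arithmetic. The sketch below is only an illustration: the helper names are hypothetical, and the dataset size of 500 samples is an assumed value, not one taken from the documentation.

```python
import math


def scheduler_window(patience, num_samples, batch_size):
    """Value of patience x num_samples / batch_size for a given setup."""
    return patience * num_samples / batch_size


def patience_range(num_samples, batch_size, low=2000, high=6000):
    """Smallest and largest integer patience that keep the window in [low, high]."""
    steps_per_epoch = num_samples / batch_size
    return math.ceil(low / steps_per_epoch), math.floor(high / steps_per_epoch)


# Assumed example: 500 training samples, batch size 1, patience 6 as in the config above.
print(scheduler_window(patience=6, num_samples=500, batch_size=1))  # 3000.0
print(patience_range(num_samples=500, batch_size=1))                # (4, 12)
```

If the second value returned by `patience_range` is smaller than the first, no integer `patience` satisfies the rule for that dataset size and batch size.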
|
||
When the dataset contains multiple elements and you are fitting the Hamiltonian, we suggest enabling the `onsite_shift` tag in `loss_options` for better performance. Most DFT software allows a uniform shift when computing the electrostatic potential, which introduces an extra degree of freedom. The `onsite_shift` tag accounts for this freedom and makes the model generalizable to all sorts of element combinations: | |
```JSON | ||
"loss_options":{ | ||
"train": {"method": "hamil_abs", "onsite_shift": true}, | ||
"validation": {"method": "hamil_abs", "onsite_shift" : true} | ||
} | ||
``` | ||
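The idea behind a shift-tolerant loss can be illustrated with a toy example on a handful of onsite (diagonal) energies. This is a sketch under simplifying assumptions, not DeePTB's implementation: it removes the mean offset between prediction and reference before computing the MAE, and it ignores the overlap matrix entirely. The function names are hypothetical.

```python
def mae(pred, ref):
    """Plain mean absolute error."""
    return sum(abs(p - r) for p, r in zip(pred, ref)) / len(ref)


def mae_with_uniform_shift(pred, ref):
    """MAE after removing the mean offset between prediction and reference."""
    shift = sum(r - p for p, r in zip(pred, ref)) / len(ref)
    return mae([p + shift for p in pred], ref), shift


# Toy onsite energies: the prediction differs from the reference by a constant 0.5.
ref = [-5.0, -3.0, 1.0, 4.0]
pred = [r + 0.5 for r in ref]

print(mae(pred, ref))  # 0.5
loss, shift = mae_with_uniform_shift(pred, ref)
print(loss, shift)     # 0.0 -0.5
```

A prediction that is correct up to a constant is penalized by the plain loss but not by the shift-tolerant one, which is the extra freedom the tag is meant to allow.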
|
||
In `model_options`, we support two types of e3 group equivariant embedding methods: Strictly Localized Equivariant Message-passing (`slem`) and Localized Equivariant Message-passing (`lem`). The former ensures strict localization by truncating the propagation of information from distant neighbours, and is therefore suitable for bulk systems where electron localization is enhanced by the scattering effect. The `lem` method, on the other hand, incorporates such localization inherently through learnable decaying functions that describe the dependency on distance. | |
|
||
The model options for `slem` and `lem` are the same. Here is a short example: | |
```JSON | |
"model_options": { | |
"embedding": { | |
"method": "slem", | |
"r_max": {"Mo":7.0, "S":7.0, "W": 8.0}, | |
"irreps_hidden": "64x0e+32x1o+32x2e+32x3o+32x4e+16x5o+8x6e+4x7o+4x8e", | |
"n_layers": 4, | |
"env_embed_multiplicity": 10, | |
"avg_num_neighbors": 51, | |
"latent_dim": 64 | |
}, | |
"prediction":{ | |
"method": "e3tb", | |
"neurons": [32, 32] | |
} | |
} | |
``` | |
Here, `method` in `embedding` indicates the e3 descriptor employed and can be `slem` or `lem`; `method` in `prediction` needs to be set to `e3tb`. | |
|
||
`r_max` can be a float or int, or a dict of atom species-specific floats/ints. It sets the cutoff of the envelope function, which smoothly decays the effect of distant atoms. We highly suggest that the user check the orbitals' radial cutoff information in the DFT calculation files to decide how large this value should be. | |
|
||
`irreps_hidden`: Very important! This parameter largely determines the representation capacity of the model, along with the model size and GPU memory consumption. It specifies the irreps of the hidden equivariant space. In this notation, for example, `64x0e` denotes `64` irreducible representations with `l=0` and `even` parity. For each basis set, we provide a tool to generate the minimal essential `irreps_hidden`. We also highly suggest using at least 3 times the number of essential irreps to enhance the representation capacity. | |
|
||
```IPYTHON | ||
In [5]: from dptb.data import OrbitalMapper | ||
In [6]: idp = OrbitalMapper(basis={"Si": "2s2p1d"}) | ||
In [7]: idp.get_irreps_ess() | ||
Out[7]: 7x0e+6x1o+6x2e+2x3o+1x4e | ||
``` | ||
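As a starting point for choosing `irreps_hidden`, the essential irreps string can be scaled by a constant factor. The sketch below does this with plain string handling; `scale_irreps` is a hypothetical helper, not a DeePTB function, and the factor of 3 simply follows the suggestion above.

```python
def scale_irreps(irreps, factor):
    """Multiply the multiplicity of every term in an irreps string by `factor`."""
    scaled = []
    for term in irreps.split("+"):
        mul, ir = term.split("x", 1)  # "7x0e" -> ("7", "0e")
        scaled.append(f"{int(mul) * factor}x{ir}")
    return "+".join(scaled)


essential = "7x0e+6x1o+6x2e+2x3o+1x4e"  # output of get_irreps_ess() shown above
print(scale_irreps(essential, 3))       # 21x0e+18x1o+18x2e+6x3o+3x4e
```

The result only scales the terms already present; the example configuration above also adds irreps with higher `l`, which remains a manual choice.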
|
||
`n_layers`: the number of layers of the network. | |
|
||
`env_embed_multiplicity`: determines the number of irreps used when initializing the edge and node features. | |
|
||
`avg_num_neighbors`: the average number of neighbours per atom in the system, given the cutoff radius `r_max`. We recommend computing this statistic for the system you are modelling, but simply picking a number between 50 and 100 is also okay. | |
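One way to obtain this statistic is a brute-force count over periodic images. The sketch below is a minimal, pure-Python illustration that assumes a single scalar cutoff and a fully periodic cell; `average_num_neighbors` and its `n_images` argument are hypothetical and not part of DeePTB. For large cells a proper neighbour-list library would be much faster.

```python
import itertools
import math


def average_num_neighbors(positions, cell, r_max, n_images=2):
    """Average number of neighbours within r_max, counted over periodic images.

    positions: Cartesian coordinates, one (x, y, z) tuple per atom.
    cell: the three lattice vectors.
    n_images: number of periodic images scanned in each direction; it is
    assumed to be large enough to cover r_max.
    """
    shifts = list(itertools.product(range(-n_images, n_images + 1), repeat=3))
    total = 0
    for i, pi in enumerate(positions):
        for j, pj in enumerate(positions):
            for s in shifts:
                if i == j and s == (0, 0, 0):
                    continue  # skip the atom itself
                image = [pj[k] + sum(s[a] * cell[a][k] for a in range(3)) for k in range(3)]
                if math.dist(pi, image) <= r_max:
                    total += 1
    return total / len(positions)


# Simple cubic lattice with lattice constant 1 and one atom per cell.
cell = [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)]
positions = [(0.0, 0.0, 0.0)]
print(average_num_neighbors(positions, cell, r_max=1.1))  # 6.0
print(average_num_neighbors(positions, cell, r_max=1.5))  # 18.0
```

The simple cubic example has 6 nearest neighbours at distance 1 and 12 second neighbours at distance sqrt(2), which is why the count jumps from 6 to 18.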
|
||
`latent_dim`: the dimension of the scalar channel. 32, 64 or 128 is good enough. | |
|
||
For the parameters in `prediction`, there is not much to change; the settings shown above work well. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Data Preparation | ||
We suggest that the user use the data parsing tool [dftio](https://github.com/floatingCatty/dftio) to convert the output of DFT calculations directly into readable datasets. Our implementation supports the parsed dataset format of `dftio`. To install it, clone the `dftio` repository and run `pip install .` in its root directory. The following command then parses the DFT output directly, optionally in parallel: | |
```bash | ||
usage: dftio parse [-h] [-ll {DEBUG,3,INFO,2,WARNING,1,ERROR,0}] [-lp LOG_PATH] [-m MODE] [-n NUM_WORKERS] [-r ROOT] [-p PREFIX] [-o OUTROOT] [-f FORMAT] [-ham] [-ovp] [-dm] [-eig] | ||
|
||
optional arguments: | ||
-h, --help show this help message and exit | ||
-ll {DEBUG,3,INFO,2,WARNING,1,ERROR,0}, --log-level {DEBUG,3,INFO,2,WARNING,1,ERROR,0} | ||
set verbosity level by string or number, 0=ERROR, 1=WARNING, 2=INFO and 3=DEBUG (default: INFO) | ||
-lp LOG_PATH, --log-path LOG_PATH | ||
set log file to log messages to disk, if not specified, the logs will only be output to console (default: None) | ||
-m MODE, --mode MODE The name of the DFT software. (default: abacus) | ||
-n NUM_WORKERS, --num_workers NUM_WORKERS | ||
The number of workers used to parse the dataset. (For n>1, we use the multiprocessing to accelerate io.) (default: 1) | ||
-r ROOT, --root ROOT The root directory of the DFT files. (default: ./) | ||
-p PREFIX, --prefix PREFIX | ||
The prefix of the DFT files under root. (default: frame) | ||
-o OUTROOT, --outroot OUTROOT | ||
The output root directory. (default: ./) | ||
-f FORMAT, --format FORMAT | ||
The output root directory. (default: dat) | ||
-ham, --hamiltonian Whether to parse the Hamiltonian matrix. (default: False) | ||
-ovp, --overlap Whether to parse the Overlap matrix (default: False) | ||
-dm, --density_matrix | ||
Whether to parse the Density matrix (default: False) | ||
-eig, --eigenvalue Whether to parse the kpoints and eigenvalues (default: False) | ||
``` | ||
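A typical invocation can be assembled from the flags listed in the help text above. The sketch below only builds and prints the command line; the paths `./dft_runs` and `./data` are placeholders, and `build_parse_command` is a hypothetical helper.

```python
import shlex


def build_parse_command(root, prefix, outroot, mode="abacus", fmt="dat",
                        num_workers=1, hamiltonian=True, overlap=True):
    """Assemble a `dftio parse` command using only flags from the help text."""
    cmd = ["dftio", "parse", "-m", mode, "-n", str(num_workers),
           "-r", root, "-p", prefix, "-o", outroot, "-f", fmt]
    if hamiltonian:
        cmd.append("-ham")  # parse the Hamiltonian matrix
    if overlap:
        cmd.append("-ovp")  # parse the overlap matrix
    return cmd


# Placeholder paths; replace them with your own DFT output and dataset folders.
cmd = build_parse_command(root="./dft_runs", prefix="frame", outroot="./data", num_workers=4)
print(shlex.join(cmd))
# dftio parse -m abacus -n 4 -r ./dft_runs -p frame -o ./data -f dat -ham -ovp
```

To actually run the parser, the list can be passed to `subprocess.run(cmd, check=True)` once `dftio` is installed.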
After parsing, the user needs to write an `info.json` file and put it in the dataset. For the default dataset type, the `info.json` looks like: | |
```JSON | ||
{ | ||
"nframes": 1, | ||
"pos_type": "cart", | ||
"AtomicData_options": { | ||
"r_max": 7.0, | ||
"pbc": true | ||
} | ||
} | ||
|
||
``` | ||
Here `pos_type` can be `cart`, `dirc` or `ase`. For datasets produced by `dftio`, we use `cart` by default. The `r_max` should, in principle, align with the orbital cutoff in the DFT calculation. For a single element, `r_max` should be a float, indicating the largest bond distance included. When the system has multiple atom species, `r_max` can also be a dict of species-specific numbers like `{"A": 7.0, "B": 8.0}`. The largest bond distance is then 7 for `A-A`, (7+8)/2=7.5 for `A-B`, and 8 for `B-B`. `pbc` can be a bool that turns the periodic boundary conditions of the model on or off. It can also be a list of three bools like `[true, true, false]`, which sets the periodicity of each direction independently. | |
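The pairwise cutoffs implied by a species-specific `r_max` follow from averaging the two species values. A minimal sketch of that arithmetic, with `pair_cutoffs` as a hypothetical helper name:

```python
from itertools import combinations_with_replacement


def pair_cutoffs(r_max):
    """Largest bond distance per species pair: the mean of the two species cutoffs."""
    return {f"{a}-{b}": (r_max[a] + r_max[b]) / 2
            for a, b in combinations_with_replacement(sorted(r_max), 2)}


print(pair_cutoffs({"A": 7.0, "B": 8.0}))
# {'A-A': 7.0, 'A-B': 7.5, 'B-B': 8.0}
```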
For the LMDB dataset type, the `info.json` is much simpler and looks like this: | |
```JSON | ||
{ | ||
"r_max": 7.0 | ||
} | ||
``` | ||
The other information is already stored in the dataset. The LMDB dataset is designed for handling very large data that cannot fit into memory directly. | |
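Either form of `info.json` can also be generated from Python, which helps when preparing many datasets. The sketch below writes the default-type example into a temporary folder; `write_info` and the folder name `frame.0` are assumptions for illustration, so adjust the target directory to wherever your dataset lives.

```python
import json
import tempfile
from pathlib import Path


def write_info(dataset_dir, info):
    """Write `info` as info.json inside dataset_dir and return the file path."""
    path = Path(dataset_dir) / "info.json"
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(info, indent=4))
    return path


default_info = {
    "nframes": 1,
    "pos_type": "cart",
    "AtomicData_options": {"r_max": 7.0, "pbc": True},
}
lmdb_info = {"r_max": 7.0}  # the simpler LMDB variant

# Written to a temporary folder here so the example leaves no files behind.
with tempfile.TemporaryDirectory() as tmp:
    path = write_info(Path(tmp) / "frame.0", default_info)
    written = path.read_text()

print(written)
```

Python's `True` is serialized as JSON `true`, matching the hand-written example above.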
Then you can set the `data_options` in the input parameters to point directly to the prepared dataset, like: | ||
```JSON | ||
"data_options": { | ||
"train": { | ||
"root": "./data", | ||
"prefix": "Si64", | ||
"get_Hamiltonian": true, | ||
"get_overlap": true | ||
} | ||
} | ||
``` | ||
If you are using a Python script, the dataset can be built with the same parameters using `build_dataset`: | |
```Python | ||
from dptb.data import build_dataset | ||
|
||
dataset = build_dataset( | ||
root="your dataset root", | ||
type="DefaultDataset", | ||
prefix="frame", | ||
get_overlap=True, | ||
get_Hamiltonian=True, | ||
basis={"Si":"2s2p1d"} | ||
) | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
================================================= | ||
E3TB Advanced | ||
================================================= | ||
|
||
.. toctree:: | ||
:maxdepth: 1 | ||
:caption: Examples | ||
|
||
advanced_input | ||
data_preparation | ||
loss_analysis |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Loss Analysis | ||
## Function | |
**DeePTB** contains a module that helps the user better understand the details of the error of the **E3TB** model. | |
We decompose the error of the **E3TB** model into several parts: | |
- onsite blocks: for the diagonal blocks of the predicted quantum tensors, the onsite block errors are further arranged according to the atom species. | |
- hopping blocks: for the off-diagonal blocks, the hopping block errors are further arranged according to the atom-pair types. | |
|
||
## Usage | |
To use this function, we need a dataset and a model, so build them in advance: | |
```Python | ||
from dptb.data import build_dataset | ||
from dptb.nn import build_model | ||
|
||
dataset = build_dataset( | ||
root="your dataset root", | ||
type="DefaultDataset", | ||
prefix="frame", | ||
get_overlap=True, | ||
get_Hamiltonian=True, | ||
basis={"Si":"2s2p1d"} | ||
) | ||
|
||
model = build_model("./ovp/checkpoint/nnenv.best.pth", common_options={"device":"cuda"}) | ||
model.eval() | ||
``` | ||
|
||
Then, the user should iterate over the dataset using the dataloader and perform the analysis with a running average. The code looks like: | |
```Python | ||
import torch | ||
from dptb.nnops.loss import HamilLossAnalysis | ||
from dptb.data.dataloader import DataLoader | ||
from tqdm import tqdm | ||
from dptb.data import AtomicData | ||
|
||
ana = HamilLossAnalysis(idp=model.idp, device=model.device, decompose=True, overlap=True) | ||
|
||
loader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=0) | ||
|
||
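# Accumulate error statistics batch by batch without tracking gradients. | |
for data in tqdm(loader, desc="doing error analysis"): | |
    with torch.no_grad(): | |
        # Convert the batch to a dict on the GPU, run the model, then compare. | |
        ref_data = AtomicData.to_AtomicDataDict(data.to("cuda")) | |
        data = model(ref_data) | |
        ana(data, ref_data, running_avg=True) | |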
``` | ||
The analysis results are stored in `ana.stats`, which is a dictionary of statistics. The user can check the values directly or display the results with: | |
|
||
```Python | ||
ana.report() | ||
``` | ||
Here is an example of the output: | ||
``` | ||
TOTAL: | ||
MAE: 0.00012021172733511776 | ||
RMSE: 0.00034208124270662665 | ||
Onsite: | ||
Si: | ||
MAE: 0.0012505357153713703 | ||
RMSE: 0.0023699181620031595 | ||
``` | ||
![MAE onsite](../../img/MAE_onsite.png) | ||
![RMSE onsite](../../img/RMSE_onsite.png) | ||
|
||
``` | ||
Hopping: | ||
Si-Si: | ||
MAE: 0.00016888207755982876 | ||
RMSE: 0.0003886453341692686 | ||
``` | ||
![MAE hopping](../../img/MAE_hopping.png) | ||
![RMSE hopping](../../img/RMSE_hopping.png) | ||
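The MAE and RMSE values in such a report are aggregated over many batches. One way a running average of this kind could be kept is sketched below; `RunningError` is a hypothetical class written for illustration and is not the implementation used by `HamilLossAnalysis`.

```python
import math


class RunningError:
    """Accumulate MAE and RMSE over batches without storing all errors."""

    def __init__(self):
        self.count = 0
        self.abs_sum = 0.0
        self.sq_sum = 0.0

    def update(self, pred, ref):
        """Add one batch of predicted and reference values."""
        for p, r in zip(pred, ref):
            diff = p - r
            self.abs_sum += abs(diff)
            self.sq_sum += diff * diff
            self.count += 1

    @property
    def mae(self):
        return self.abs_sum / self.count

    @property
    def rmse(self):
        # Average the squared errors first and take the root at the end.
        return math.sqrt(self.sq_sum / self.count)


stats = RunningError()
stats.update([1.0, -1.0], [0.0, 0.0])  # first batch
stats.update([3.0], [0.0])             # second batch
print(stats.mae, stats.rmse)
```

Keeping sums of absolute and squared errors, rather than averaging per-batch RMSE values, gives the same result as computing the metrics over the whole dataset at once.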
|
||
If the user wants to see the loss decomposed by irreps, set the `decompose` argument of the `HamilLossAnalysis` class to `True` and rerun the analysis. The decomposed results can then be displayed using the following code: | |
```Python | ||
import matplotlib.pyplot as plt | ||
import torch | ||
|
||
ana_result = ana.stats | ||
|
||
for bt, err in ana_result["hopping"].items(): | |
    print("rmse err for bond {bt}: {rmserr} \t mae err for bond {bt}: {maerr}".format(bt=bt, rmserr=err["rmse"], maerr=err["mae"])) | |
|
||
for bt, err in ana_result["onsite"].items(): | |
    print("rmse err for atom {bt}: {rmserr} \t mae err for atom {bt}: {maerr}".format(bt=bt, rmserr=err["rmse"], maerr=err["mae"])) | |
|
||
for bt, err in ana_result["hopping"].items(): | |
    x = list(range(model.idp.orbpair_irreps.num_irreps)) | |
    rmserr = err["rmse_per_irreps"] | |
    maerr = err["mae_per_irreps"] | |
    sort_index = torch.LongTensor(model.idp.orbpair_irreps.sort().inv) | |
|
||
    # rmserr = rmserr[sort_index] | |
    # maerr = maerr[sort_index] | |
|
||
    plt.figure(figsize=(20,3)) | |
    plt.bar(x, rmserr.cpu().detach(), label="RMSE per rme") | |
    plt.bar(x, maerr.cpu().detach(), alpha=0.6, label="MAE per rme") | |
    plt.legend() | |
    # plt.yscale("log") | |
    # plt.ylim([1e-5, 5e-4]) | |
    plt.title("rme specific error of bond type: {bt}".format(bt=bt)) | |
    plt.show() | |
|
||
for at, err in ana_result["onsite"].items(): | |
    x = list(range(model.idp.orbpair_irreps.num_irreps)) | |
    rmserr = err["rmse_per_irreps"] | |
    maerr = err["mae_per_irreps"] | |
    sort_index = torch.LongTensor(model.idp.orbpair_irreps.sort().inv) | |
|
||
    rmserr = rmserr[sort_index] | |
    maerr = maerr[sort_index] | |
|
||
    plt.figure(figsize=(20,3)) | |
    plt.bar(x, rmserr.cpu().detach(), label="RMSE per rme") | |
    plt.bar(x, maerr.cpu().detach(), alpha=0.6, label="MAE per rme") | |
    plt.legend() | |
    # plt.yscale("log") | |
    # plt.ylim([1e-5, 2.e-2]) | |
    plt.title("rme specific error of atom type: {at}".format(at=at)) | |
    plt.show() | |
|
||
``` |
File renamed without changes.
File renamed without changes.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
================================================= | ||
SKTB Advanced | ||
================================================= | ||
|
||
.. toctree:: | ||
:maxdepth: 1 | ||
:caption: Examples | ||
|
||
dftb | ||
dptb_env | ||
nrl_tb | ||
soc |
File renamed without changes.
File renamed without changes.