Skip to content

Commit

Permalink
Cleanup - 2 (deepchem#4113)
Browse files Browse the repository at this point in the history
* dftutils update

* update dftutils

* change testing env from dqc to torch

* added spin value
  • Loading branch information
sudo-rsingh authored Sep 11, 2024
1 parent 632c427 commit eef62ce
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 15 deletions.
8 changes: 2 additions & 6 deletions deepchem/utils/dftutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@
except Exception as e:
warnings.warn("Could not import torch. Skipping tests." + str(e))

try:
import xitorch as xt
except Exception as e:
warnings.warn("Could not import xitorch. Skipping tests." + str(e))

from deepchem.utils.differentiation_utils import EditableModule

T = TypeVar('T')

Expand Down Expand Up @@ -231,7 +227,7 @@ def hashstr(s: str) -> str:
return str(hashlib.blake2s(str.encode(s)).hexdigest())


class BaseGrid(xt.EditableModule):
class BaseGrid(EditableModule):
"""
Interface to DQC's BaseGrid class. BaseGrid is a class that regulates the integration points over the spatial
dimensions.
Expand Down
16 changes: 7 additions & 9 deletions deepchem/utils/test/test_dftutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
warnings.warn("Could not import torch. Skipping tests." + str(e))


@pytest.mark.dqc
@pytest.mark.torch
def test_dftutils():
import dqc
from dqc.system.mol import Mol
from dqc.qccalc.ks import KS
from deepchem.utils.dft_utils import parse_moldesc, Mol, KS
from deepchem.utils.dftutils import KSCalc
system = {
'type': 'mol',
Expand All @@ -22,16 +20,16 @@ def test_dftutils():
'basis': '6-311++G(3df,3pd)'
}
}
atomzs, atomposs = dqc.parse_moldesc(system["kwargs"]["moldesc"])
mol = Mol(**system["kwargs"])
atomzs, atomposs = parse_moldesc(system["kwargs"]["moldesc"])
mol = Mol(**system["kwargs"], spin=0.0)
qc = KS(mol, xc='lda_x').run()
qcs = KSCalc(qc)
a = qcs.energy()
b = torch.tensor(-99.1360, dtype=torch.float64)
assert torch.allclose(a, b)


@pytest.mark.dqc
@pytest.mark.torch
def test_SpinParam_sum():
from deepchem.utils.dftutils import SpinParam
dens_u = torch.rand(10)
Expand All @@ -41,7 +39,7 @@ def test_SpinParam_sum():
assert torch.all(sp.sum().eq(dens_u + dens_d)).item()


@pytest.mark.dqc
@pytest.mark.torch
def test_SpinParam_reduce():
from deepchem.utils.dftutils import SpinParam
dens_u = torch.rand(10)
Expand All @@ -54,7 +52,7 @@ def fcn(a, b):
assert torch.all(sp.reduce(fcn).eq(dens_u * dens_d)).item()


@pytest.mark.dqc
@pytest.mark.torch
def test_str():
from deepchem.utils.dftutils import hashstr
s = "hydrogen fluoride"
Expand Down

0 comments on commit eef62ce

Please # to comment.