From eef62ce714e80720f597fd9bab7c9b5ccec5a995 Mon Sep 17 00:00:00 2001 From: Rakshit Kumar Singh Date: Wed, 11 Sep 2024 22:15:05 +0530 Subject: [PATCH] Cleanup - 2 (#4113) * dftutils update * update dftutils * change testing env from dqc to torch * added spin value --- deepchem/utils/dftutils.py | 8 ++------ deepchem/utils/test/test_dftutils.py | 16 +++++++--------- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/deepchem/utils/dftutils.py b/deepchem/utils/dftutils.py index dd9ef5693e..e984b474bc 100644 --- a/deepchem/utils/dftutils.py +++ b/deepchem/utils/dftutils.py @@ -11,11 +11,7 @@ except Exception as e: warnings.warn("Could not import torch. Skipping tests." + str(e)) -try: - import xitorch as xt -except Exception as e: - warnings.warn("Could not import xitorch. Skipping tests." + str(e)) - +from deepchem.utils.differentiation_utils import EditableModule T = TypeVar('T') @@ -231,7 +227,7 @@ def hashstr(s: str) -> str: return str(hashlib.blake2s(str.encode(s)).hexdigest()) -class BaseGrid(xt.EditableModule): +class BaseGrid(EditableModule): """ Interface to DQC's BaseGrid class. BaseGrid is a class that regulates the integration points over the spatial dimensions. diff --git a/deepchem/utils/test/test_dftutils.py b/deepchem/utils/test/test_dftutils.py index fe9b724259..814ddcdd4f 100644 --- a/deepchem/utils/test/test_dftutils.py +++ b/deepchem/utils/test/test_dftutils.py @@ -9,11 +9,9 @@ warnings.warn("Could not import torch. Skipping tests." + str(e)) -@pytest.mark.dqc +@pytest.mark.torch def test_dftutils(): - import dqc - from dqc.system.mol import Mol - from dqc.qccalc.ks import KS + from deepchem.utils.dft_utils import parse_moldesc, Mol, KS from deepchem.utils.dftutils import KSCalc system = { 'type': 'mol', @@ -22,8 +20,8 @@ def test_dftutils(): 'basis': '6-311++G(3df,3pd)' } } - atomzs, atomposs = dqc.parse_moldesc(system["kwargs"]["moldesc"]) - mol = Mol(**system["kwargs"]) + atomzs, atomposs = parse_moldesc(system["kwargs"]["moldesc"]) + mol = Mol(**system["kwargs"], spin=0.0) qc = KS(mol, xc='lda_x').run() qcs = KSCalc(qc) a = qcs.energy() @@ -31,7 +29,7 @@ def test_dftutils(): assert torch.allclose(a, b) -@pytest.mark.dqc +@pytest.mark.torch def test_SpinParam_sum(): from deepchem.utils.dftutils import SpinParam dens_u = torch.rand(10) @@ -41,7 +39,7 @@ def test_SpinParam_sum(): assert torch.all(sp.sum().eq(dens_u + dens_d)).item() -@pytest.mark.dqc +@pytest.mark.torch def test_SpinParam_reduce(): from deepchem.utils.dftutils import SpinParam dens_u = torch.rand(10) @@ -54,7 +52,7 @@ def fcn(a, b): assert torch.all(sp.reduce(fcn).eq(dens_u * dens_d)).item() -@pytest.mark.dqc +@pytest.mark.torch def test_str(): from deepchem.utils.dftutils import hashstr s = "hydrogen fluoride"