Skip to content

Commit 3c40e8d

Browse files
committed
Add pass-through arguments and allow setting parameters in PySCF wrapper
1 parent cf9d7a8 commit 3c40e8d

1 file changed

Lines changed: 16 additions & 11 deletions

File tree

python/dftd3/pyscf.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
raise ModuleNotFoundError("This submodule requires pyscf installed")
2727

2828
import numpy as np
29-
from typing import Tuple
29+
from typing import Optional, Tuple
3030

3131
from .interface import (
3232
DispersionModel,
@@ -107,14 +107,15 @@ class DFTD3Dispersion(lib.StreamObject):
107107
array(-0.00574289)
108108
"""
109109

110-
def __init__(self, mol, xc="hf", version="d3bj", atm=False):
110+
def __init__(self, mol: gto.Mole, xc: str = "hf", version: str = "d3bj", atm: bool = False, param: Optional[Dict[str, float]] = None):
111111
self.mol = mol
112112
self.verbose = mol.verbose
113113
self.xc = xc
114+
self.param = param
114115
self.atm = atm
115116
self.version = version
116117

117-
def dump_flags(self, verbose=None):
118+
def dump_flags(self, verbose: Optional[bool] = None):
118119
"""
119120
Show options used for the DFT-D3 dispersion correction.
120121
"""
@@ -168,16 +169,19 @@ def kernel(self) -> Tuple[float, np.ndarray]:
168169
mol.atom_coords(),
169170
)
170171

171-
param = _damping_param[self.version](
172-
method=self.xc,
173-
atm=self.atm,
174-
)
172+
if self.param is not None:
173+
param = _damping_param[self.version](**self.param)
174+
else:
175+
param = _damping_param[self.version](
176+
method=self.xc,
177+
atm=self.atm,
178+
)
175179

176180
res = disp.get_dispersion(param=param, grad=True)
177181

178182
return res.get("energy"), res.get("gradient")
179183

180-
def reset(self, mol):
184+
def reset(self, mol: gto.Mole):
181185
"""Reset mol and clean up relevant attributes for scanner mode"""
182186
self.mol = mol
183187
return self
@@ -199,7 +203,7 @@ class _DFTD3Grad:
199203
pass
200204

201205

202-
def energy(mf):
206+
def energy(mf: scf.hf.Scf, **kwargs) -> scf.hf.Scf:
203207
"""
204208
Apply DFT-D3 corrections to SCF or MCSCF methods by returning an
205209
instance of a new class built from the original instances class.
@@ -248,6 +252,7 @@ def energy(mf):
248252
xc="hf"
249253
if isinstance(mf, casci.CASCI)
250254
else getattr(mf, "xc", "HF").upper().replace(" ", ""),
255+
**kwargs,
251256
)
252257

253258
if isinstance(mf, _DFTD3):
@@ -287,7 +292,7 @@ def nuc_grad_method(self):
287292
return DFTD3(mf, with_dftd3)
288293

289294

290-
def grad(scf_grad):
295+
def grad(scf_grad, **kwargs):
291296
"""
292297
Apply DFT-D3 corrections to SCF or MCSCF nuclear gradients methods
293298
by returning an instance of a new class built from the original class.
@@ -337,7 +342,7 @@ def grad(scf_grad):
337342

338343
# Ensure that the zeroth order results include DFTD3 corrections
339344
if not getattr(scf_grad.base, "with_dftd3", None):
340-
scf_grad.base = dftd3(scf_grad.base)
345+
scf_grad.base = energy(scf_grad.base, **kwargs)
341346

342347
class DFTD3Grad(_DFTD3Grad, scf_grad.__class__):
343348
def grad_nuc(self, mol=None, atmlst=None):

0 commit comments

Comments
 (0)