2626 raise ModuleNotFoundError ("This submodule requires pyscf installed" )
2727
2828import numpy as np
29- from typing import Tuple
29+ from typing import Optional , Tuple
3030
3131from .interface import (
3232 DispersionModel ,
@@ -107,14 +107,15 @@ class DFTD3Dispersion(lib.StreamObject):
107107 array(-0.00574289)
108108 """
109109
110- def __init__ (self , mol , xc = "hf" , version = "d3bj" , atm = False ):
110+ def __init__ (self , mol : gto . Mole , xc : str = "hf" , version : str = "d3bj" , atm : bool = False , param : Optional [ Dict [ str , float ]] = None ):
111111 self .mol = mol
112112 self .verbose = mol .verbose
113113 self .xc = xc
114+ self .param = param
114115 self .atm = atm
115116 self .version = version
116117
117- def dump_flags (self , verbose = None ):
118+ def dump_flags (self , verbose : Optional [ bool ] = None ):
118119 """
119120 Show options used for the DFT-D3 dispersion correction.
120121 """
@@ -168,16 +169,19 @@ def kernel(self) -> Tuple[float, np.ndarray]:
168169 mol .atom_coords (),
169170 )
170171
171- param = _damping_param [self .version ](
172- method = self .xc ,
173- atm = self .atm ,
174- )
172+ if self .param is not None :
173+ param = _damping_param [self .version ](** self .param )
174+ else :
175+ param = _damping_param [self .version ](
176+ method = self .xc ,
177+ atm = self .atm ,
178+ )
175179
176180 res = disp .get_dispersion (param = param , grad = True )
177181
178182 return res .get ("energy" ), res .get ("gradient" )
179183
180- def reset (self , mol ):
184+ def reset (self , mol : gto . Mole ):
181185 """Reset mol and clean up relevant attributes for scanner mode"""
182186 self .mol = mol
183187 return self
@@ -199,7 +203,7 @@ class _DFTD3Grad:
199203 pass
200204
201205
202- def energy (mf ) :
206+ def energy (mf : scf . hf . Scf , ** kwargs ) -> scf . hf . Scf :
203207 """
204208 Apply DFT-D3 corrections to SCF or MCSCF methods by returning an
205209 instance of a new class built from the original instances class.
@@ -248,6 +252,7 @@ def energy(mf):
248252 xc = "hf"
249253 if isinstance (mf , casci .CASCI )
250254 else getattr (mf , "xc" , "HF" ).upper ().replace (" " , "" ),
255+ ** kwargs ,
251256 )
252257
253258 if isinstance (mf , _DFTD3 ):
@@ -287,7 +292,7 @@ def nuc_grad_method(self):
287292 return DFTD3 (mf , with_dftd3 )
288293
289294
290- def grad (scf_grad ):
295+ def grad (scf_grad , ** kwargs ):
291296 """
292297 Apply DFT-D3 corrections to SCF or MCSCF nuclear gradients methods
293298 by returning an instance of a new class built from the original class.
@@ -337,7 +342,7 @@ def grad(scf_grad):
337342
338343 # Ensure that the zeroth order results include DFTD3 corrections
339344 if not getattr (scf_grad .base , "with_dftd3" , None ):
340- scf_grad .base = dftd3 (scf_grad .base )
345+ scf_grad .base = energy (scf_grad .base , ** kwargs )
341346
342347 class DFTD3Grad (_DFTD3Grad , scf_grad .__class__ ):
343348 def grad_nuc (self , mol = None , atmlst = None ):
0 commit comments