Skip to content
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ jobs:
Pkg.pkg"registry add https://github.com/ACEsuit/ACEregistry"
shell: bash -c "julia --color=yes {0}"
- uses: julia-actions/cache@v2
# Use EquivariantTensors branch with Zygote rrule fixes (PR #75)
# TODO: Remove this once ET PR #75 is merged and released
- name: Add EquivariantTensors branch
run: |
using Pkg
# Ensure General registry is available for dependencies
Pkg.Registry.add("General")
Pkg.add(url="https://github.com/jameskermode/EquivariantTensors.jl", rev="jk/zygote-rrules")
shell: bash -c "julia --project=. --color=yes {0}"
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1

Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.10.0"

[deps]
ACEfit = "ad31a8ef-59f5-4a01-b543-a85c2f73e95c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004"
Expand All @@ -21,6 +22,7 @@ Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Expand Down Expand Up @@ -50,6 +52,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ACEfit = "0.3.0"
Adapt = "4"
ArgParse = "1"
AtomsBase = "0.5"
AtomsBuilder = "0.2.0"
Expand All @@ -67,6 +70,7 @@ Folds = "0.2"
ForwardDiff = "0.10"
Interpolations = "0.15"
JSON = "0.21"
KernelAbstractions = "0.9"
Lux = "1.25"
LuxCore = "1"
NamedTupleTools = "0.13, 0.14"
Expand Down
106 changes: 106 additions & 0 deletions benchmark/bench_evaluate_basis_ed.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#!/usr/bin/env julia
#
# Benchmark script for evaluate_basis_ed
# Compares current Zygote pullback approach vs ForwardDiff.jacobian
#

using Pkg
Pkg.activate(joinpath(@__DIR__, ".."))

using ACEpotentials, AtomsBase, Unitful, BenchmarkTools, Random
using Lux, StaticArrays
using AtomsBuilder

# Access functions from internal Models module
const M = ACEpotentials.Models
const build_et_calculator = M.build_et_calculator
const evaluate_basis = M.evaluate_basis
const evaluate_basis_ed = M.evaluate_basis_ed
const length_basis = M.length_basis

# Build a small test system (Si crystal)
function make_si_system(supercell=(2,1,1))
sys = AtomsBuilder.bulk(:Si) * supercell
AtomsBuilder.rattle!(sys, 0.1u"Å")
return sys
end

# Build model and calculator (following test_et_calculator.jl pattern)
println("Building ACE model...")

elements = (:Si,)
level = M.TotalDegree()
max_level = 8
order = 2
maxl = 4

rin0cuts = M._default_rin0cuts(elements)
rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts)

model = M.ace_model(; elements = elements, order = order,
Ytype = :solid, level = level, max_level = max_level,
maxl = maxl, pair_maxn = max_level,
rin0cuts = rin0cuts,
init_WB = :glorot_normal, init_Wpair = :glorot_normal)

rng = Random.MersenneTwister(1234)
ps, st = Lux.setup(rng, model)

# Zero out pair basis (not implemented in ET backend)
for s in model.pairbasis.splines
s.itp.itp.coefs[:] *= 0
end

println("Building ET calculator...")
calc = build_et_calculator(model, ps, st)

# Test with different system sizes
for supercell in [(2,1,1)] # Can add larger systems
sys = make_si_system(supercell)
n_atoms = length(sys)

println("\n" * "="^60)
println("System: $(n_atoms) atoms (supercell $(supercell))")
println("Basis length per species: $(calc.len_Bi)")
println("Total basis length: $(length_basis(calc))")
println("="^60)

# Warmup
println("\nWarmup...")
B = evaluate_basis(calc, sys)
println("B shape: $(size(B))")

# Benchmark evaluate_basis (forward only)
println("\n--- evaluate_basis (forward only) ---")
b1 = @benchmark evaluate_basis($calc, $sys) samples=10 evals=1
display(b1)

# Benchmark evaluate_basis_ed (forward + Jacobian)
println("\n--- evaluate_basis_ed (current: Zygote pullback) ---")
println("This may take a while for larger systems...")

# First just time it once to see how long it takes
t0 = time()
B_ed, dB = evaluate_basis_ed(calc, sys)
t1 = time()
println("Single call time: $(round(t1-t0, digits=2)) seconds")
println("B_ed shape: $(size(B_ed)), dB shape: $(size(dB))")

# Only do proper benchmark if it's not too slow
if t1 - t0 < 30.0
b2 = @benchmark evaluate_basis_ed($calc, $sys) samples=5 evals=1
display(b2)
else
println("Skipping full benchmark - single call took > 30 seconds")
end

# Calculate theoretical efficiency
n_outputs = n_atoms * calc.len_Bi
println("\n--- Analysis ---")
println("Number of outputs (n_atoms × len_Bi): $n_outputs")
println("With Zygote pullback: $n_outputs backward passes")
println("With ForwardDiff: ~$(3 * n_atoms) forward passes (rough estimate)")
println("Theoretical speedup: $(round(n_outputs / (3 * n_atoms), digits=1))×")
end

println("\n\nBenchmark complete.")
27 changes: 1 addition & 26 deletions src/models/Rnl_learnable_new.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,6 @@ function _convert_Rnl_learnable(basis; zlist = ChemicalSpecies.(basis._i2z),
# named-tuple inputs
#
et_trans = _convert_agnesi(basis)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you ask Claude to not remove my commented out old code? I always keep it for a reason and want to remove it myself.

# OLD VERSION - KEEP FOR DEBUGGING then remove
# et_trans = let transforms = basis.transforms
# ET.NTtransform( xij -> begin
# trans_ij = transforms[__z2i(xij.s0), __z2i(xij.s1)]
# return trans_ij(rfun(xij))
# end )
# end

# the envelope is always a simple quartic (1 -x^2)^2
# otherwise make this transform fail.
Expand Down Expand Up @@ -76,16 +68,6 @@ function _convert_Rnl_learnable(basis; zlist = ChemicalSpecies.(basis._i2z),
NZ^2, # num (Zi,Zj) pairs
selector)

# et_rbasis = SkipConnection( # input is (rij, zi, zj)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

# Chain(y = et_trans, # transforms yij
# P = SkipConnection(
# et_polys,
# WrappedFunction( Py -> et_env.(Py[2]) .* Py[1] )
# )
# ), # r -> y -> P = e(y) * polys(y)
# et_linl # P -> W(Zi, Zj) * P
# )

et_rbasis = SkipConnection( # input is (rij, zi, zj)
Chain(y = et_trans, # transforms yij
Pe = BranchLayer(
Expand Down Expand Up @@ -114,14 +96,7 @@ function _agnesi_et_params(trans)
rcut = trans.rcut

params = ET.agnesi_params(pcut, pin, rin, req, rcut)
@assert params.a ≈ a

# ----- for debugging -----------
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

# r = rin + rand() * (rcut - rin)
# y1 = trans(r)
# y2 = ET.eval_agnesi(r, params)
# @assert y1 ≈ y2
# -------------------------------
@assert params.a ≈ a

return params
end
Expand Down
Loading
Loading