Skip to content

Commit ebff16d

Browse files
committed
Allow to add data to a System
1 parent 14d2a72 commit ebff16d

39 files changed

+229
-188
lines changed

rascaline-c-api/src/calculator.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,11 +393,11 @@ pub unsafe extern fn rascal_calculator_compute(
393393
}
394394
check_pointers!(calculator, descriptor, systems);
395395

396-
// Create a Vec<Box<dyn System>> from the passed systems
396+
// Create a Vec<System> from the passed systems
397397
let c_systems = std::slice::from_raw_parts_mut(systems, systems_count);
398398
let mut systems = Vec::with_capacity(c_systems.len());
399399
for system in c_systems {
400-
systems.push(Box::new(system) as Box<dyn System>);
400+
systems.push(System::new(system));
401401
}
402402

403403
let c_gradients = std::slice::from_raw_parts(options.gradients, options.gradients_count);

rascaline-c-api/src/system.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::ffi::CStr;
33

44
use rascaline::types::{Vector3D, Matrix3};
55
use rascaline::systems::{SimpleSystem, Pair, UnitCell};
6-
use rascaline::{Error, System};
6+
use rascaline::{Error, SystemBase};
77

88
use crate::RASCAL_SYSTEM_ERROR;
99

@@ -111,7 +111,7 @@ pub struct rascal_system_t {
111111
unsafe impl Send for rascal_system_t {}
112112
unsafe impl Sync for rascal_system_t {}
113113

114-
impl<'a> System for &'a mut rascal_system_t {
114+
impl<'a> SystemBase for &'a mut rascal_system_t {
115115
fn size(&self) -> Result<usize, Error> {
116116
let function = self.size.ok_or_else(|| Error::External {
117117
status: RASCAL_SYSTEM_ERROR,
@@ -424,7 +424,7 @@ pub unsafe extern fn rascal_basic_systems_read(
424424
catch_unwind(move || {
425425
check_pointers!(path, systems, count);
426426
let path = CStr::from_ptr(path).to_str()?;
427-
let simple_systems = rascaline::systems::read_from_file(path)?;
427+
let simple_systems = rascaline::systems::read_simple_systems_from_file(path)?;
428428

429429
let mut c_systems = Vec::with_capacity(simple_systems.len());
430430
for system in simple_systems {

rascaline/benches/lode-spherical-expansion.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,16 @@
11
#![allow(clippy::needless_return)]
2-
use rascaline::{Calculator, System, CalculationOptions};
2+
use rascaline::{Calculator, CalculationOptions};
33

44
use criterion::{BenchmarkGroup, Criterion, measurement::WallTime, SamplingMode};
55
use criterion::{criterion_group, criterion_main};
66

7-
fn load_systems(path: &str) -> Vec<Box<dyn System>> {
8-
let systems = rascaline::systems::read_from_file(format!("benches/data/{}", path))
9-
.expect("failed to read file");
10-
11-
return systems.into_iter()
12-
.map(|s| Box::new(s) as Box<dyn System>)
13-
.collect()
14-
}
15-
167
fn run_spherical_expansion(mut group: BenchmarkGroup<WallTime>,
178
path: &str,
189
gradients: bool,
1910
test_mode: bool,
2011
) {
21-
let mut systems = load_systems(path);
12+
let mut systems = rascaline::systems::read_from_file(format!("benches/data/{}", path))
13+
.expect("failed to read file");
2214

2315
if test_mode {
2416
// Reduce the time/RAM required to test the benchmarks code.

rascaline/benches/soap-power-spectrum.rs

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,18 @@
11
#![allow(clippy::needless_return)]
22

3-
use rascaline::{Calculator, System, CalculationOptions};
3+
use rascaline::{Calculator, CalculationOptions};
44

55
use criterion::{BenchmarkGroup, Criterion, measurement::WallTime, SamplingMode};
66
use criterion::{criterion_group, criterion_main};
77

8-
9-
fn load_systems(path: &str) -> Vec<Box<dyn System>> {
10-
let systems = rascaline::systems::read_from_file(format!("benches/data/{}", path))
11-
.expect("failed to read file");
12-
13-
return systems.into_iter()
14-
.map(|s| Box::new(s) as Box<dyn System>)
15-
.collect()
16-
}
17-
188
fn run_soap_power_spectrum(
199
mut group: BenchmarkGroup<WallTime>,
2010
path: &str,
2111
gradients: bool,
2212
test_mode: bool,
2313
) {
24-
let mut systems = load_systems(path);
14+
let mut systems = rascaline::systems::read_from_file(format!("benches/data/{}", path))
15+
.expect("failed to read file");
2516

2617
if test_mode {
2718
// Reduce the time/RAM required to test the benchmarks code.

rascaline/benches/soap-spherical-expansion.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,16 @@
11
#![allow(clippy::needless_return)]
2-
use rascaline::{Calculator, System, CalculationOptions};
2+
use rascaline::{Calculator, CalculationOptions};
33

44
use criterion::{BenchmarkGroup, Criterion, measurement::WallTime, SamplingMode};
55
use criterion::{criterion_group, criterion_main};
66

7-
fn load_systems(path: &str) -> Vec<Box<dyn System>> {
8-
let systems = rascaline::systems::read_from_file(format!("benches/data/{}", path))
9-
.expect("failed to read file");
10-
11-
return systems.into_iter()
12-
.map(|s| Box::new(s) as Box<dyn System>)
13-
.collect()
14-
}
15-
167
fn run_spherical_expansion(mut group: BenchmarkGroup<WallTime>,
178
path: &str,
189
gradients: bool,
1910
test_mode: bool,
2011
) {
21-
let mut systems = load_systems(path);
12+
let mut systems = rascaline::systems::read_from_file(format!("benches/data/{}", path))
13+
.expect("failed to read file");
2214

2315
if test_mode {
2416
// Reduce the time/RAM required to test the benchmarks code.

rascaline/examples/compute-soap.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
use metatensor::Labels;
2-
use rascaline::{Calculator, System, CalculationOptions};
2+
use rascaline::{Calculator, CalculationOptions};
33

44
fn main() -> Result<(), Box<dyn std::error::Error>> {
55
// load the systems from command line argument
66
let path = std::env::args().nth(1).expect("expected a command line argument");
7-
let systems = rascaline::systems::read_from_file(path)?;
8-
// transform systems into a vector of trait objects (`Vec<Box<dyn System>>`)
9-
let mut systems = systems.into_iter()
10-
.map(|s| Box::new(s) as Box<dyn System>)
11-
.collect::<Vec<_>>();
7+
let mut systems = rascaline::systems::read_from_file(path)?;
128

139
// pass hyper-parameters as JSON
1410
let parameters = r#"{

rascaline/examples/profiling.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use metatensor::{TensorMap, Labels};
2-
use rascaline::{Calculator, System, CalculationOptions};
2+
use rascaline::{Calculator, CalculationOptions};
33

44
fn main() -> Result<(), Box<dyn std::error::Error>> {
55
let path = std::env::args().nth(1).expect("expected a command line argument");
@@ -28,10 +28,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
2828
/// Compute SOAP power spectrum, this is the same code as the 'compute-soap'
2929
/// example
3030
fn compute_soap(path: &str) -> Result<TensorMap, Box<dyn std::error::Error>> {
31-
let systems = rascaline::systems::read_from_file(path)?;
32-
let mut systems = systems.into_iter()
33-
.map(|s| Box::new(s) as Box<dyn System>)
34-
.collect::<Vec<_>>();
31+
let mut systems = rascaline::systems::read_from_file(path)?;
3532

3633
let parameters = r#"{
3734
"cutoff": 5.0,

rascaline/src/calculator.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ impl Calculator {
285285
}
286286

287287
#[time_graph::instrument(name="Calculator::prepare")]
288-
fn prepare(&mut self, systems: &mut [Box<dyn System>], options: CalculationOptions) -> Result<TensorMap, Error> {
288+
fn prepare(&mut self, systems: &mut [System], options: CalculationOptions) -> Result<TensorMap, Error> {
289289
let default_keys = self.implementation.keys(systems)?;
290290
let keys = match options.selected_keys {
291291
Some(keys) if keys.is_empty() => {
@@ -446,14 +446,14 @@ impl Calculator {
446446
/// features.
447447
pub fn compute(
448448
&mut self,
449-
systems: &mut [Box<dyn System>],
449+
systems: &mut [System],
450450
options: CalculationOptions,
451451
) -> Result<TensorMap, Error> {
452452
let mut native_systems;
453453
let systems = if options.use_native_system {
454454
native_systems = Vec::with_capacity(systems.len());
455455
for system in systems {
456-
native_systems.push(Box::new(SimpleSystem::try_from(&**system)?) as Box<dyn System>);
456+
native_systems.push(System::new(SimpleSystem::try_from(&**system)?) as System);
457457
}
458458
&mut native_systems
459459
} else {

rascaline/src/calculators/atomic_composition.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ impl CalculatorBase for AtomicComposition {
3535
&[]
3636
}
3737

38-
fn keys(&self, systems: &mut [Box<dyn System>]) -> Result<Labels, Error> {
38+
fn keys(&self, systems: &mut [System]) -> Result<Labels, Error> {
3939
return CenterSpeciesKeys.keys(systems);
4040
}
4141

@@ -47,7 +47,7 @@ impl CalculatorBase for AtomicComposition {
4747
return vec!["structure", "center"];
4848
}
4949

50-
fn samples(&self, keys: &Labels, systems: &mut [Box<dyn System>]) -> Result<Vec<Labels>, Error> {
50+
fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result<Vec<Labels>, Error> {
5151
assert_eq!(keys.names(), ["species_center"]);
5252
let mut samples = Vec::new();
5353
for [species_center_key] in keys.iter_fixed_size() {
@@ -84,7 +84,7 @@ impl CalculatorBase for AtomicComposition {
8484
&self,
8585
keys: &Labels,
8686
_samples: &[Labels],
87-
_systems: &mut [Box<dyn System>],
87+
_systems: &mut [System],
8888
) -> Result<Vec<Labels>, Error> {
8989
// Positions/cell gradients of the composition are zero everywhere.
9090
// Therefore, we only return a vector of empty labels (one for each key).
@@ -110,7 +110,7 @@ impl CalculatorBase for AtomicComposition {
110110

111111
fn compute(
112112
&mut self,
113-
systems: &mut [Box<dyn System>],
113+
systems: &mut [System],
114114
descriptor: &mut TensorMap,
115115
) -> Result<(), Error> {
116116
assert_eq!(descriptor.keys().names(), ["species_center"]);

rascaline/src/calculators/dummy_calculator.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,15 @@ impl CalculatorBase for DummyCalculator {
4343
std::slice::from_ref(&self.cutoff)
4444
}
4545

46-
fn keys(&self, systems: &mut [Box<dyn System>]) -> Result<Labels, Error> {
46+
fn keys(&self, systems: &mut [System]) -> Result<Labels, Error> {
4747
return CenterSpeciesKeys.keys(systems);
4848
}
4949

5050
fn sample_names(&self) -> Vec<&str> {
5151
AtomCenteredSamples::sample_names()
5252
}
5353

54-
fn samples(&self, keys: &Labels, systems: &mut [Box<dyn System>]) -> Result<Vec<Labels>, Error> {
54+
fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result<Vec<Labels>, Error> {
5555
assert_eq!(keys.names(), ["species_center"]);
5656
let mut samples = Vec::new();
5757
for [species_center] in keys.iter_fixed_size() {
@@ -75,7 +75,7 @@ impl CalculatorBase for DummyCalculator {
7575
}
7676
}
7777

78-
fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box<dyn System>]) -> Result<Vec<Labels>, Error> {
78+
fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result<Vec<Labels>, Error> {
7979
debug_assert_eq!(keys.count(), samples.len());
8080
let mut gradient_samples = Vec::new();
8181
for ([species_center], samples) in keys.iter_fixed_size().zip(samples) {
@@ -110,7 +110,7 @@ impl CalculatorBase for DummyCalculator {
110110
}
111111

112112
#[time_graph::instrument(name = "DummyCalculator::compute")]
113-
fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
113+
fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> {
114114
if self.name.contains("log-test-info:") {
115115
info!("{}", self.name);
116116
} else if self.name.contains("log-test-warn:") {

0 commit comments

Comments
 (0)