Skip to content

Commit 753febe

Browse files
authored
Merge pull request #2920 from nicolossus/fix_DumpLayerConnections
Filter connections also by target when dumping layer connections
2 parents 7e91422 + c63fb17 commit 753febe

11 files changed

Lines changed: 420 additions & 489 deletions

File tree

nestkernel/conn_builder.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -619,14 +619,14 @@ nest::OneToOneBuilder::connect_()
619619
Node* target = n->get_node();
620620

621621
const size_t tnode_id = n->get_node_id();
622-
const int idx = targets_->find( tnode_id );
623-
if ( idx < 0 ) // Is local node in target list?
622+
const long lid = targets_->get_lid( tnode_id );
623+
if ( lid < 0 ) // Is local node in target list?
624624
{
625625
continue;
626626
}
627627

628628
// one-to-one, thus we can use target idx for source as well
629-
const size_t snode_id = ( *sources_ )[ idx ];
629+
const size_t snode_id = ( *sources_ )[ lid ];
630630
if ( not allow_autapses_ and snode_id == tnode_id )
631631
{
632632
// no skipping required / possible,
@@ -819,7 +819,7 @@ nest::AllToAllBuilder::connect_()
819819
const size_t tnode_id = n->get_node_id();
820820

821821
// Is the local node in the targets list?
822-
if ( targets_->find( tnode_id ) < 0 )
822+
if ( targets_->get_lid( tnode_id ) < 0 )
823823
{
824824
continue;
825825
}
@@ -1102,7 +1102,7 @@ nest::FixedInDegreeBuilder::connect_()
11021102
const size_t tnode_id = n->get_node_id();
11031103

11041104
// Is the local node in the targets list?
1105-
if ( targets_->find( tnode_id ) < 0 )
1105+
if ( targets_->get_lid( tnode_id ) < 0 )
11061106
{
11071107
continue;
11081108
}
@@ -1534,7 +1534,7 @@ nest::BernoulliBuilder::connect_()
15341534
const size_t tnode_id = n->get_node_id();
15351535

15361536
// Is the local node in the targets list?
1537-
if ( targets_->find( tnode_id ) < 0 )
1537+
if ( targets_->get_lid( tnode_id ) < 0 )
15381538
{
15391539
continue;
15401540
}

nestkernel/layer_impl.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,9 @@ Layer< D >::dump_connections( std::ostream& out,
306306
std::vector< std::pair< Position< D >, size_t > >* src_vec = get_global_positions_vector( node_collection );
307307

308308
// Dictionary with parameters for get_connections()
309-
DictionaryDatum ncdict( new Dictionary );
310-
def( ncdict, names::synapse_model, syn_model );
309+
DictionaryDatum conn_filter( new Dictionary );
310+
def( conn_filter, names::synapse_model, syn_model );
311+
def( conn_filter, names::target, NodeCollectionDatum( target_layer->get_node_collection() ) );
311312

312313
// Avoid setting up new array for each iteration of the loop
313314
std::vector< size_t > source_array( 1 );
@@ -321,8 +322,8 @@ Layer< D >::dump_connections( std::ostream& out,
321322
const Position< D > source_pos = src_iter->first;
322323

323324
source_array[ 0 ] = source_node_id;
324-
def( ncdict, names::source, NodeCollectionDatum( NodeCollection::create( source_array ) ) );
325-
ArrayDatum connectome = kernel().connection_manager.get_connections( ncdict );
325+
def( conn_filter, names::source, NodeCollectionDatum( NodeCollection::create( source_array ) ) );
326+
ArrayDatum connectome = kernel().connection_manager.get_connections( conn_filter );
326327

327328
// Print information about all local connections for current source
328329
for ( size_t i = 0; i < connectome.size(); ++i )
@@ -344,8 +345,9 @@ Layer< D >::dump_connections( std::ostream& out,
344345
Layer< D >* tgt_layer = dynamic_cast< Layer< D >* >( target_layer.get() );
345346

346347
out << ' ';
347-
const size_t tnode_id = tgt_layer->node_collection_->find( target_node_id );
348-
tgt_layer->compute_displacement( source_pos, tnode_id ).print( out );
348+
const long tnode_lid = tgt_layer->node_collection_->get_lid( target_node_id );
349+
assert( tnode_lid >= 0 );
350+
tgt_layer->compute_displacement( source_pos, tnode_lid ).print( out );
349351
out << '\n';
350352
}
351353
}

nestkernel/nestmodule.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1072,7 +1072,7 @@ NestModule::Find_g_iFunction::execute( SLIInterpreter* i ) const
10721072
NodeCollectionDatum nodecollection = getValue< NodeCollectionDatum >( i->OStack.pick( 1 ) );
10731073
const long node_id = getValue< long >( i->OStack.pick( 0 ) );
10741074

1075-
const auto res = nodecollection->find( node_id );
1075+
const auto res = nodecollection->get_lid( node_id );
10761076
i->OStack.pop( 2 );
10771077
i->OStack.push( res );
10781078
i->EStack.pop();

nestkernel/node_collection.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -981,11 +981,11 @@ NodeCollectionComposite::merge_parts_( std::vector< NodeCollectionPrimitive >& p
981981
bool
982982
NodeCollectionComposite::contains( const size_t node_id ) const
983983
{
984-
return find( node_id ) != -1;
984+
return get_lid( node_id ) != -1;
985985
}
986986

987987
long
988-
NodeCollectionComposite::find( const size_t node_id ) const
988+
NodeCollectionComposite::get_lid( const size_t node_id ) const
989989
{
990990
const auto add_size_op = []( const long a, const NodeCollectionPrimitive& b ) { return a + b.size(); };
991991

@@ -1018,7 +1018,7 @@ NodeCollectionComposite::find( const size_t node_id ) const
10181018
// Need to find number of nodes in previous parts to know if the the step hits the node_id.
10191019
const auto num_prev_nodes =
10201020
std::accumulate( parts_.begin(), parts_.begin() + middle, static_cast< size_t >( 0 ), add_size_op );
1021-
const auto absolute_pos = num_prev_nodes + parts_[ middle ].find( node_id );
1021+
const auto absolute_pos = num_prev_nodes + parts_[ middle ].get_lid( node_id );
10221022

10231023
// The first or the last node can be somewhere in the middle part.
10241024
const auto absolute_part_start = start_part_ == middle ? start_offset_ : 0;
@@ -1041,7 +1041,7 @@ NodeCollectionComposite::find( const size_t node_id ) const
10411041
// Since NC is not sliced, we can just calculate and return the local ID.
10421042
const auto sum_pre =
10431043
std::accumulate( parts_.begin(), parts_.begin() + middle, static_cast< size_t >( 0 ), add_size_op );
1044-
return sum_pre + parts_[ middle ].find( node_id );
1044+
return sum_pre + parts_[ middle ].get_lid( node_id );
10451045
}
10461046
}
10471047
}

nestkernel/node_collection.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ class NodeCollection
386386
*
387387
* @return Index of node with given node ID; -1 if node not in NodeCollection.
388388
*/
389-
virtual long find( const size_t ) const = 0;
389+
virtual long get_lid( const size_t ) const = 0;
390390

391391
/**
392392
* Returns whether the NodeCollection contains any nodes with proxies or not.
@@ -512,7 +512,7 @@ class NodeCollectionPrimitive : public NodeCollection
512512
bool is_range() const override;
513513
bool empty() const override;
514514

515-
long find( const size_t ) const override;
515+
long get_lid( const size_t ) const override;
516516

517517
bool has_proxies() const override;
518518

@@ -650,7 +650,7 @@ class NodeCollectionComposite : public NodeCollection
650650
bool is_range() const override;
651651
bool empty() const override;
652652

653-
long find( const size_t ) const override;
653+
long get_lid( const size_t ) const override;
654654

655655
bool has_proxies() const override;
656656
};
@@ -807,7 +807,7 @@ NodeCollectionPrimitive::empty() const
807807
}
808808

809809
inline long
810-
NodeCollectionPrimitive::find( const size_t neuron_id ) const
810+
NodeCollectionPrimitive::get_lid( const size_t neuron_id ) const
811811
{
812812
if ( neuron_id > last_ )
813813
{

pynest/nest/lib/hl_api_helper.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,16 @@
2424
API of the PyNEST wrapper.
2525
"""
2626

27-
import warnings
28-
import json
2927
import functools
30-
import textwrap
28+
import json
3129
import os
3230
import pydoc
33-
31+
import textwrap
32+
import warnings
3433
from string import Template
3534

36-
from ..ll_api import sli_func, sps, sr, spp
3735
from .. import pynestkernel as kernel
36+
from ..ll_api import sli_func, spp, sps, sr
3837

3938
__all__ = [
4039
"broadcast",
@@ -52,6 +51,7 @@
5251
"restructure_data",
5352
"show_deprecation_warning",
5453
"show_help_with_pager",
54+
"stringify_path",
5555
"SuppressedDeprecationWarning",
5656
]
5757

@@ -148,6 +148,31 @@ def new_func(*args, **kwargs):
148148
return deprecated_decorator
149149

150150

151+
def stringify_path(filepath):
152+
"""
153+
Convert path-like object to string form.
154+
155+
Attempt to convert path-like object to a string by coercing objects
156+
supporting the fspath protocol to its ``__fspath__`` method. Anything that
157+
is not path-like, which includes bytes and strings, is passed through
158+
unchanged.
159+
160+
Parameters
161+
----------
162+
filepath : object
163+
Object representing file system path.
164+
165+
Returns
166+
-------
167+
filepath : str
168+
Stringified filepath.
169+
"""
170+
171+
if isinstance(filepath, os.PathLike):
172+
filepath = filepath.__fspath__() # should return str or bytes object
173+
return filepath
174+
175+
151176
def is_literal(obj):
152177
"""Check whether obj is a "literal": a unicode string or SLI literal
153178

pynest/nest/lib/hl_api_spatial.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,20 @@
2323
Functions relating to spatial properties of nodes
2424
"""
2525

26+
import os
2627

2728
import numpy as np
2829

2930
from ..ll_api import sli_func
30-
from .hl_api_helper import is_iterable
3131
from .hl_api_connections import GetConnections
32+
from .hl_api_helper import is_iterable, stringify_path
3233
from .hl_api_parallel_computing import NumProcesses, Rank
3334
from .hl_api_types import NodeCollection
3435

3536
try:
3637
import matplotlib as mpl
37-
import matplotlib.path as mpath
3838
import matplotlib.patches as mpatches
39+
import matplotlib.path as mpath
3940

4041
HAVE_MPL = True
4142
except ImportError:
@@ -530,9 +531,12 @@ def DumpLayerNodes(layer, outname):
530531
nest.DumpLayerNodes(s_nodes, 'positions.txt')
531532
532533
"""
534+
533535
if not isinstance(layer, NodeCollection):
534536
raise TypeError("layer must be a NodeCollection")
535537

538+
outname = stringify_path(outname)
539+
536540
sli_func(
537541
"""
538542
(w) file exch DumpLayerNodes close
@@ -599,11 +603,15 @@ def DumpLayerConnections(source_layer, target_layer, synapse_model, outname):
599603
# write connectivity information to file
600604
nest.DumpLayerConnections(s_nodes, s_nodes, 'static_synapse', 'conns.txt')
601605
"""
606+
602607
if not isinstance(source_layer, NodeCollection):
603608
raise TypeError("source_layer must be a NodeCollection")
609+
604610
if not isinstance(target_layer, NodeCollection):
605611
raise TypeError("target_layer must be a NodeCollection")
606612

613+
outname = stringify_path(outname)
614+
607615
sli_func(
608616
"""
609617
/oname Set
@@ -1467,8 +1475,8 @@ def _create_mask_patches(mask, periodic, extent, source_pos, face_color="yellow"
14671475

14681476
# import pyplot here and not at toplevel to avoid preventing users
14691477
# from changing matplotlib backend after importing nest
1470-
import matplotlib.pyplot as plt
14711478
import matplotlib as mtpl
1479+
import matplotlib.pyplot as plt
14721480

14731481
edge_color = "black"
14741482
alpha = 0.2
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# test_issue_2629.py
4+
#
5+
# This file is part of NEST.
6+
#
7+
# Copyright (C) 2004 The NEST Initiative
8+
#
9+
# NEST is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 2 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# NEST is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
21+
22+
"""
23+
Regression test for Issue #2629 (GitHub).
24+
25+
The issue was that ``DumpLayerConnections`` failed when a source layer was
26+
connected to more than one target layer. The test ensures that this is no
27+
longer the case.
28+
29+
For each connection between the specified source and target layer,
30+
``DumpLayerConnections`` writes the following to file:
31+
32+
source_node_id target_node_id weight delay dx dy [dz]
33+
34+
where (dx, dy [, dz]) is the displacement from source to target node.
35+
36+
This test uses the ``tmp_path`` Pytest fixture, which will provide a
37+
temporary directory unique to the test invocation. ``tmp_path`` is a
38+
``pathlib.Path`` object. Hence, the test also implicitly verifies that it
39+
is possible to pass a ``pathlib.Path`` object as filename.
40+
"""
41+
42+
import pytest
43+
44+
import nest
45+
46+
47+
@pytest.fixture(scope="module")
48+
def network():
49+
"""Fixture for building network."""
50+
51+
grid = nest.spatial.grid(shape=[2, 1])
52+
src_layer = nest.Create("iaf_psc_alpha", positions=grid)
53+
tgt_layer_1 = nest.Create("iaf_psc_alpha", positions=grid)
54+
tgt_layer_2 = nest.Create("iaf_psc_alpha", positions=grid)
55+
56+
nest.Connect(src_layer, tgt_layer_1, "all_to_all")
57+
nest.Connect(src_layer, tgt_layer_2, "one_to_one")
58+
59+
return src_layer, tgt_layer_1, tgt_layer_2
60+
61+
62+
def test_dump_layer_connections_target_1(tmp_path, network):
63+
"""Test that dumping connections with target layer 1 works."""
64+
65+
src_layer, tgt_layer_1, _ = network
66+
67+
fname_1 = tmp_path / "conns_1.txt"
68+
nest.DumpLayerConnections(src_layer, tgt_layer_1, "static_synapse", fname_1)
69+
expected_dump_1 = [
70+
"1 3 1 1 0 0",
71+
"1 4 1 1 0.5 0",
72+
"2 3 1 1 -0.5 0",
73+
"2 4 1 1 0 0",
74+
]
75+
actual_dump_1 = fname_1.read_text().splitlines()
76+
assert actual_dump_1 == expected_dump_1
77+
78+
79+
def test_dump_layer_connections_target_2(tmp_path, network):
80+
"""Test that dumping connections with target layer 2 works."""
81+
82+
src_layer, _, tgt_layer_2 = network
83+
84+
fname_2 = tmp_path / "conns_2.txt"
85+
nest.DumpLayerConnections(src_layer, tgt_layer_2, "static_synapse", fname_2)
86+
expected_dump_2 = [
87+
"1 5 1 1 0 0",
88+
"2 6 1 1 0 0",
89+
]
90+
actual_dump_2 = fname_2.read_text().splitlines()
91+
assert actual_dump_2 == expected_dump_2

0 commit comments

Comments
 (0)