
Update remote sampler and feature store with Neo4j examples#10673

Draft
victorneo4j wants to merge 7 commits into pyg-team:master from victorneo4j:neo4j


@victorneo4j victorneo4j commented Apr 24, 2026

I have tested that training a GNN with this works. However, there are some methods in the Neo4j graph store and feature store that I haven't tested yet.

Change Summary

New and Modified Files

torch_geometric/data/graph_store.py (+259 lines, modified)

Added the database graph store infrastructure:

  • ResultSchema — abstract marker base for declarative record-shape descriptions. Carries is_hetero flag.
  • HomogeneousSchema — default schema; expects nodes (list of global IDs) and edges (list of [src, dst] pairs) in the query result.
  • HeterogeneousSchema — schema for hetero graphs; expects node_dict and edge_dict keyed by type. Edge keys can be 3-tuples or separator-encoded strings.
  • DatabaseGraphStore — abstract base that wraps GraphStore. Subclasses implement one hook: _fetch_subgraph(query, kwargs). Decoding into PyG tensors (_decode_subgraph, _decode_homogeneous, _decode_heterogeneous) and empty-result fallback (_empty_result) are all handled here. sample_subgraph is the single entry point used by samplers.
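
To illustrate the homogeneous decoding step described above, here is a minimal sketch in plain Python. The function name `decode_homogeneous` and the list-based output are assumptions for illustration; the PR's `_decode_homogeneous` returns PyG tensors and may differ in detail. It maps global node IDs to local positions, builds a COO edge index, and falls back to an empty result when the query returns nothing.

```python
from typing import Dict, List, Optional, Sequence, Tuple


def decode_homogeneous(
    record: Optional[Dict[str, Sequence]],
) -> Tuple[List[int], List[List[int]]]:
    """Hypothetical helper: turn a {nodes, edges} record into local COO.

    `nodes` is a list of global IDs and `edges` a list of [src, dst]
    pairs, as expected by HomogeneousSchema.
    """
    if record is None:
        # Empty-result fallback: no nodes and an empty 2 x 0 edge index.
        return [], [[], []]

    nodes = list(record["nodes"])
    # Position in `nodes` becomes the local index of each global ID.
    local = {gid: i for i, gid in enumerate(nodes)}
    rows = [local[src] for src, _ in record["edges"]]
    cols = [local[dst] for _, dst in record["edges"]]
    return nodes, [rows, cols]


record = {"nodes": [42, 7, 13], "edges": [[42, 7], [13, 42]]}
node_ids, edge_index = decode_homogeneous(record)
print(node_ids)    # [42, 7, 13]
print(edge_index)  # [[0, 2], [1, 0]]
```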

torch_geometric/sampler/database_sampler.py (new, 216 lines)

Abstract sampler that pushes multi-hop sampling into a database:

  • Compiles node_sampling_query and edge_sampling_query once at construction via abstract hooks _build_node_sampling_query / _build_edge_sampling_query (both return None by default — override to support either mode).
  • sample_from_nodes / sample_from_edges extract seeds, call _build_query_params (abstract), run graph_store.sample_subgraph, and wrap the result in SamplerOutput or HeteroSamplerOutput via _build_output.
  • Schema handling mirrors DatabaseGraphSAINTSampler: passed explicitly or derived from is_hetero; mismatch raises ValueError.
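
The schema rule in the last bullet can be sketched as follows. The classes below are simplified stand-ins that only carry the `is_hetero` flag, and `resolve_schema` is a hypothetical name; the actual constructor logic in the PR may be organised differently.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ResultSchema:
    """Stand-in for the marker base; only carries the is_hetero flag."""
    is_hetero: bool = False


@dataclass
class HomogeneousSchema(ResultSchema):
    is_hetero: bool = False


@dataclass
class HeterogeneousSchema(ResultSchema):
    is_hetero: bool = True


def resolve_schema(schema: Optional[ResultSchema], is_hetero: bool) -> ResultSchema:
    """Return the explicit schema, or derive a default from is_hetero."""
    if schema is None:
        return HeterogeneousSchema() if is_hetero else HomogeneousSchema()
    if schema.is_hetero != is_hetero:
        raise ValueError(
            f"schema.is_hetero={schema.is_hetero} does not match "
            f"is_hetero={is_hetero}"
        )
    return schema


print(type(resolve_schema(None, is_hetero=True)).__name__)  # HeterogeneousSchema
```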

torch_geometric/sampler/__init__.py (+2 lines)

Removed stale RemoteSampler export; added DatabaseSampler to __all__.


torch_geometric/loader/database_graph_saint.py (new, 290 lines)

DatabaseGraphSAINTSampler — a DataLoader subclass for database-backed GraphSAINT:

  • Subclasses implement three hooks: _sample_nodes (which nodes), _build_subgraph_query (one-time Cypher/query string), _build_data (assemble Data from decoded tensors).
  • __getitem__ draws one subgraph record per step; _collate decodes it via graph_store._decode_subgraph using self.schema, stacks COO edge_index, looks up per-node node_norm, and calls _build_data.
  • _compute_norm pre-samples subgraphs until N × sample_coverage node visits have accumulated, then computes node_norm[v] = num_samples / count[v] / N (matches the PyG formula). Norms are cached to disk via save_dir.
  • _setup() post-init hook runs after _build_subgraph_query; subclasses can build extra queries there.
  • Supports both homogeneous and heterogeneous schemas; hetero path skips norm (not yet implemented for hetero).
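
As a worked example of the normalisation formula above, the following sketch computes `node_norm` from repeated samples. The function name, the dict-based output, and the deterministic toy sampler are assumptions for illustration. Nodes that are never visited are simply left out here, and the sampler is assumed to return non-empty batches; the PR may handle both cases differently.

```python
import itertools
from collections import Counter
from typing import Callable, Dict, List


def compute_node_norm(
    sample_nodes: Callable[[], List[int]],
    num_nodes: int,
    sample_coverage: int,
) -> Dict[int, float]:
    """node_norm[v] = num_samples / count[v] / N for every visited node v."""
    counts: Counter = Counter()
    num_samples = 0
    total_visits = 0
    # Pre-sample until N * sample_coverage node visits have accumulated.
    while total_visits < num_nodes * sample_coverage:
        batch = sample_nodes()
        counts.update(batch)
        total_visits += len(batch)
        num_samples += 1
    return {v: num_samples / c / num_nodes for v, c in counts.items()}


# Deterministic toy sampler that alternates between two subgraphs.
batches = itertools.cycle([[0, 1], [1, 2]])
norm = compute_node_norm(lambda: next(batches), num_nodes=3, sample_coverage=2)
# Three samples are drawn: [0, 1], [1, 2], [0, 1].
print(norm[0], norm[2])  # 0.5 1.0
```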

examples/neo4j/data/neo4j_feature_store.py (new, 350 lines)

Concrete DatabaseFeatureStore backed by Neo4j:

  • attr_map accepts either a flat {attr_name: spec} (homogeneous) or a nested {node_label: {attr_name: spec}} (heterogeneous); normalised to nested form internally.
  • Each spec carries property (Neo4j property name), dtype (float32 / int64 / str), and encoding (f64[] or byte[]).
  • _build_query: single Cypher UNWIND … MATCH … RETURN that fetches all requested attrs in one round-trip.
  • _fetch_remote_attrs / _decode_remote_attrs: single-attr hooks used by the base class pipeline.
  • Float decoding: byte[] via np.frombuffer, f64[] via direct cast. String labels encoded to int64 via a per-instance vocabulary (_labels dict), built lazily.
  • _put_tensor_db / _remove_tensor_db: write-back via UNWIND … SET / REMOVE.
  • apoc_available(): probes via RETURN apoc.version().
  • Driver created lazily per-process; __getstate__ / __setstate__ drop and re-create it for safe DataLoader pickling. atexit closes it on exit.
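
The lazy-driver and pickling pattern from the last bullet might look roughly like this. `FakeDriver` stands in for a real database driver so that the sketch runs without a database, and `LazyDriverStore` and the URI are hypothetical names. In this sketch the driver is re-created on first access after unpickling, which is one possible reading of "drop and re-create".

```python
import atexit
import pickle


class FakeDriver:
    """Stand-in for a database driver (real drivers hold sockets)."""

    def __init__(self, uri: str) -> None:
        self.uri = uri
        self.closed = False

    def close(self) -> None:
        self.closed = True


class LazyDriverStore:
    def __init__(self, uri: str) -> None:
        self.uri = uri
        self._driver = None

    @property
    def driver(self) -> FakeDriver:
        # Create the driver on first use in the current process.
        if self._driver is None:
            self._driver = FakeDriver(self.uri)
            atexit.register(self._driver.close)
        return self._driver

    def __getstate__(self) -> dict:
        # Drop the live driver so worker processes never inherit it.
        state = self.__dict__.copy()
        state["_driver"] = None
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)


store = LazyDriverStore("bolt://localhost:7687")  # hypothetical URI
_ = store.driver                                  # force driver creation
clone = pickle.loads(pickle.dumps(store))
print(clone._driver is None)                # True
print(clone.driver is not store.driver)     # True
```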

examples/neo4j/data/neo4j_graph_store.py (new, 157 lines)

Concrete DatabaseGraphStore backed by Neo4j:

  • Implements _fetch_subgraph: runs a query with fetch_size=-1, returns first record dict (or None).
  • _get_edge_index: full MATCH … RETURN src, dst scan; sorts for CSC/CSR layouts.
  • _put_edge_index: UNWIND … MERGE from global-ID pairs.
  • _remove_edge_index: DELETE all relationships of the given type.
  • apoc_available(): same probe pattern as the feature store.
  • Same lazy-driver / pickling / atexit pattern as Neo4jFeatureStore.
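
The layout-dependent sort mentioned for `_get_edge_index` can be sketched like this. The function name and string layout keys are assumptions; the sketch only shows the ordering step (by source for CSR, by destination for CSC) and omits any compression of the sorted axis into a pointer array.

```python
from typing import List, Tuple


def sort_edges(
    src: List[int], dst: List[int], layout: str
) -> Tuple[List[int], List[int]]:
    """Order (src, dst) pairs for the requested layout."""
    pairs = list(zip(src, dst))
    if layout == "csr":
        pairs.sort(key=lambda p: (p[0], p[1]))  # group edges by source
    elif layout == "csc":
        pairs.sort(key=lambda p: (p[1], p[0]))  # group edges by destination
    elif layout != "coo":
        raise ValueError(f"unknown layout: {layout!r}")
    return [p[0] for p in pairs], [p[1] for p in pairs]


src, dst = [2, 0, 1, 0], [1, 2, 0, 1]
print(sort_edges(src, dst, "csr"))  # ([0, 0, 1, 2], [1, 2, 0, 1])
print(sort_edges(src, dst, "csc"))  # ([1, 0, 2, 0], [0, 1, 1, 2])
```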

examples/neo4j/neo4j_samplers.py (new, 238 lines)

Two classes for Neo4j-backed GraphSAGE neighbor sampling:

Neo4jSampler(DatabaseSampler) — abstract base for all Neo4j samplers:

  • Validates node_label and rel_type as safe Cypher identifiers (injection guard) via _validate_cypher_ident.
  • Stores nodeid_property, node_label, rel_type, profile.
  • _probe_apoc: optional helper subclasses call when their query uses apoc.*.
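
Labels and relationship types cannot be passed as Cypher query parameters, so they end up interpolated into the query string, which is why an identifier check matters. A minimal sketch of such a guard follows; the exact pattern used by `_validate_cypher_ident` in the PR is not shown in this description, so the regular expression here is an assumption.

```python
import re

# Assumed rule: ASCII letters, digits, underscore; no leading digit.
_IDENT_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def validate_cypher_ident(value: str, name: str) -> str:
    """Return `value` unchanged if it is a safe identifier, else raise."""
    if not isinstance(value, str) or _IDENT_RE.fullmatch(value) is None:
        raise ValueError(f"{name}={value!r} is not a valid Cypher identifier")
    return value


label = validate_cypher_ident("Paper", "node_label")
print(f"MATCH (n:{label}) RETURN count(n)")

try:
    validate_cypher_ident("Paper) DETACH DELETE (n", "node_label")
except ValueError as err:
    print("rejected:", err)
```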

Neo4jGraphSAGESampler(Neo4jSampler) — concrete GraphSAGE sampler:

  • Requires APOC (apoc.coll.toSet / flatten / randomItems); probed in __init__.
  • _build_node_sampling_query: generates a multi-hop BFS Cypher query parameterised by $seed_ids. Each hop is a CALL {} subquery block. Mirrors pyg-lib semantics: take-all when k < 0 or k ≥ |neighbourhood|, otherwise uniform sample via apoc.coll.randomItems. Returns edges + nodes matching HomogeneousSchema.
  • Supports direction: incoming / outgoing / undirected (adjusts edge pattern and neighbor-node expression).
  • Fixed bug from previous version: query now tracks only visited (no nodes_by_hop); final RETURN collects [n IN visited | n.nodeId].
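
The per-hop rule described above (take all neighbours when k < 0 or k ≥ |neighbourhood|, otherwise sample uniformly) can be illustrated in memory, without Cypher. The function names and the adjacency-dict input are assumptions for illustration, and the sketch covers only the incoming direction, where each sampled edge points from neighbour to frontier node.

```python
import random
from typing import Dict, List, Sequence, Tuple


def sample_neighbors(neighbors: Sequence[int], k: int, rng: random.Random) -> List[int]:
    """Take all neighbours if k < 0 or k >= len(neighbors), else sample k."""
    if k < 0 or k >= len(neighbors):
        return list(neighbors)
    return rng.sample(list(neighbors), k)


def sample_subgraph(
    adj: Dict[int, List[int]],
    seeds: Sequence[int],
    fanouts: Sequence[int],
    rng: random.Random,
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Multi-hop expansion that tracks a single `visited` collection."""
    visited = list(dict.fromkeys(seeds))  # deduplicate, keep order
    seen = set(visited)
    frontier = list(visited)
    edges: List[Tuple[int, int]] = []
    for k in fanouts:
        next_frontier = []
        for node in frontier:
            for nbr in sample_neighbors(adj.get(node, []), k, rng):
                edges.append((nbr, node))
                if nbr not in seen:
                    seen.add(nbr)
                    visited.append(nbr)
                    next_frontier.append(nbr)
        frontier = next_frontier
    return visited, edges


adj = {0: [1, 2], 1: [3], 2: [3], 3: []}
nodes, edges = sample_subgraph(adj, seeds=[0], fanouts=[-1, -1], rng=random.Random(0))
print(nodes)  # [0, 1, 2, 3]
print(edges)  # [(1, 0), (2, 0), (3, 1), (3, 2)]
```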

examples/neo4j/neo4j_graphsaint.py (new, 297 lines)

Two classes for Neo4j-backed GraphSAINT:

Neo4jGraphSAINTSampler(DatabaseGraphSAINTSampler) — abstract Neo4j base:

  • Implements _get_total_nodes via MATCH (n{label}) RETURN count(n).
  • _build_subgraph_query: induced-subgraph query for $nodes; returns nodes, edges, and optionally splits (when split_property is set) in one round-trip.
  • _build_data: builds TensorAttr list from user-configured feature_attrs, calls feature_store.multi_get_tensor, assembles Data, attaches optional train_mask from splits column, attaches node_norm.
  • feature_attrs: {data_key: (store_attr_name, dtype)} mapping — decouples Data attribute names from Neo4j property names.
  • Cypher-identifier validation for node_label, rel_type, and split_property.
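
To show how a `feature_attrs` mapping decouples `Data` attribute names from stored property names, here is a small sketch with a stub in place of the feature store. The `build_data` and `cast_value` helpers, the `fetch` callback, and the example attribute names are all hypothetical; the PR builds a TensorAttr list and calls `feature_store.multi_get_tensor` instead.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# {data_key: (store_attr_name, dtype)} as described above.
FeatureAttrs = Dict[str, Tuple[str, str]]


def cast_value(value, dtype: str):
    """Cast a scalar or a flat list according to a dtype name."""
    caster = float if dtype.startswith("float") else int
    if isinstance(value, (list, tuple)):
        return [caster(v) for v in value]
    return caster(value)


def build_data(
    feature_attrs: FeatureAttrs,
    node_ids: Sequence[int],
    fetch: Callable[[str, Sequence[int]], List],
) -> Dict[str, List]:
    """Fetch each configured attribute and store it under its data key."""
    data = {}
    for data_key, (store_attr, dtype) in feature_attrs.items():
        values = fetch(store_attr, node_ids)
        data[data_key] = [cast_value(v, dtype) for v in values]
    return data


# Stub store keyed by attribute name, then by global node ID.
stub = {"embedding": {10: [1, 2], 11: [3, 4]}, "label": {10: 0, 11: 1}}
feature_attrs = {"x": ("embedding", "float32"), "y": ("label", "int64")}
data = build_data(feature_attrs, [11, 10], lambda attr, ids: [stub[attr][i] for i in ids])
print(data)  # {'x': [[3.0, 4.0], [1.0, 2.0]], 'y': [1, 0]}
```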

Neo4jGraphSAINTRandomWalkSampler(Neo4jGraphSAINTSampler) — concrete random-walk variant:

  • _build_walk_query: pure-Cypher BFS random walk — picks $batch_size random roots, unrolls walk_length hop CALL {} blocks, returns deduplicated visited IDs via apoc.coll.toSet. APOC required; probed eagerly in __init__.
  • _sample_nodes: runs the walk query via graph_store._fetch_subgraph.
  • _setup: builds _walk_query after base init (uses walk_length, available by then).
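
For intuition, the walk performed by the Cypher query can be mirrored in memory as below. This is only an analogy under stated assumptions: the function name and adjacency-dict input are hypothetical, roots are drawn with replacement, and a walker stays in place at a dead end. The PR's query may make different choices on both points.

```python
import random
from typing import Dict, List


def random_walk_nodes(
    adj: Dict[int, List[int]],
    batch_size: int,
    walk_length: int,
    rng: random.Random,
) -> List[int]:
    """Return the deduplicated node IDs visited by `batch_size` walks."""
    all_nodes = list(adj)
    current = [rng.choice(all_nodes) for _ in range(batch_size)]
    visited = list(current)
    for _ in range(walk_length):
        nxt = []
        for node in current:
            nbrs = adj.get(node, [])
            # Assumption: a walker with no neighbours stays where it is.
            nxt.append(rng.choice(nbrs) if nbrs else node)
        visited.extend(nxt)
        current = nxt
    return sorted(set(visited))


# Directed ring: every node has exactly one successor.
ring = {i: [(i + 1) % 5] for i in range(5)}
walk = random_walk_nodes(ring, batch_size=1, walk_length=2, rng=random.Random(0))
print(len(walk))  # 3
```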
