Update remote sampler and feature store with Neo4j examples #10673
Draft
victorneo4j wants to merge 7 commits into pyg-team:master from
I have tested that training a GNN with this works. There are some methods in the Neo4j graph and feature stores that I haven't tested yet, though.
Change Summary
New Files
`torch_geometric/data/graph_store.py` (+259 lines, modified)
Added the database graph store infrastructure:
- `ResultSchema`: abstract marker base for declarative record-shape descriptions. Carries an `is_hetero` flag.
- `HomogeneousSchema`: default schema; expects `nodes` (list of global IDs) and `edges` (list of `[src, dst]` pairs) in the query result.
- `HeterogeneousSchema`: schema for hetero graphs; expects `node_dict` and `edge_dict` keyed by type. Edge keys can be 3-tuples or separator-encoded strings.
- `DatabaseGraphStore`: abstract base that wraps `GraphStore`. Subclasses implement one hook: `_fetch_subgraph(query, kwargs)`. Decoding into PyG tensors (`_decode_subgraph`, `_decode_homogeneous`, `_decode_heterogeneous`) and the empty-result fallback (`_empty_result`) are all handled here. `sample_subgraph` is the single entry point used by samplers.

`torch_geometric/sampler/database_sampler.py` (new, 216 lines)
Abstract sampler that pushes multi-hop sampling into a database:
- Builds `node_sampling_query` and `edge_sampling_query` once at construction via abstract hooks `_build_node_sampling_query` / `_build_edge_sampling_query` (both return `None` by default; override to support either mode).
- `sample_from_nodes` / `sample_from_edges` extract seeds, call `_build_query_params` (abstract), run `graph_store.sample_subgraph`, and wrap the result in `SamplerOutput` or `HeteroSamplerOutput` via `_build_output`.
- `DatabaseGraphSAINTSampler`: passed explicitly or derived from `is_hetero`; mismatch raises `ValueError`.

`torch_geometric/sampler/__init__.py` (+2 lines)
Removed stale
`RemoteSampler` export; added `DatabaseSampler` to `__all__`.

`torch_geometric/loader/database_graph_saint.py` (new, 290 lines)
`DatabaseGraphSAINTSampler`, a `DataLoader` subclass for database-backed GraphSAINT:
- `_sample_nodes` (which nodes), `_build_subgraph_query` (one-time Cypher/query string), `_build_data` (assemble `Data` from decoded tensors).
- `__getitem__` draws one subgraph record per step; `_collate` decodes it via `graph_store._decode_subgraph` using `self.schema`, stacks the COO `edge_index`, looks up per-node `node_norm`, and calls `_build_data`.
- `_compute_norm` pre-samples until `N × sample_coverage` visits, then computes `node_norm[v] = num_samples / count[v] / N` (matches the PyG formula). Norms are cached to disk via `save_dir`.
- `_setup()` post-init hook runs after `_build_subgraph_query`; subclasses can build extra queries there.

`examples/neo4j/data/neo4j_feature_store.py` (new, 350 lines)
Concrete
`DatabaseFeatureStore` backed by Neo4j:
- `attr_map` accepts either a flat `{attr_name: spec}` (homogeneous) or a nested `{node_label: {attr_name: spec}}` (heterogeneous); normalised to nested form internally.
- `spec` carries `property` (Neo4j property name), `dtype` (`float32`/`int64`/`str`), and `encoding` (`f64[]` or `byte[]`).
- `_build_query`: single Cypher `UNWIND … MATCH … RETURN` that fetches all requested attrs in one round-trip.
- `_fetch_remote_attrs` / `_decode_remote_attrs`: single-attr hooks used by the base class pipeline.
- `byte[]` via `np.frombuffer`, `f64[]` via direct cast. String labels encoded to `int64` via a per-instance vocabulary (`_labels` dict), built lazily.
- `_put_tensor_db` / `_remove_tensor_db`: write-back via `UNWIND … SET`/`REMOVE`.
- `apoc_available()`: probes via `RETURN apoc.version()`.
- `__getstate__` / `__setstate__` drop and re-create it for safe DataLoader pickling. `atexit` closes it on exit.

`examples/neo4j/data/neo4j_graph_store.py` (new, 157 lines)
Concrete
`DatabaseGraphStore` backed by Neo4j:
- `_fetch_subgraph`: runs a query with `fetch_size=-1`, returns the first record dict (or `None`).
- `_get_edge_index`: full `MATCH … RETURN src, dst` scan; sorts for CSC/CSR layouts.
- `_put_edge_index`: `UNWIND … MERGE` from global-ID pairs.
- `_remove_edge_index`: `DELETE` all relationships of the given type.
- `apoc_available()`: same probe pattern as the feature store.
- `Neo4jFeatureStore`.

`examples/neo4j/neo4j_samplers.py` (new, 238 lines)
Two classes for Neo4j-backed GraphSAGE neighbor sampling:
- `Neo4jSampler(DatabaseSampler)`, abstract base for all Neo4j samplers:
  - Validates `node_label` and `rel_type` as safe Cypher identifiers (injection guard) via `_validate_cypher_ident`.
  - `nodeid_property`, `node_label`, `rel_type`, `profile`.
  - `_probe_apoc`: optional helper subclasses call when their query uses `apoc.*`.
- `Neo4jGraphSAGESampler(Neo4jSampler)`, concrete GraphSAGE sampler:
  - Requires APOC (`apoc.coll.toSet`/`flatten`/`randomItems`); probed in `__init__`.
  - `_build_node_sampling_query`: generates a multi-hop BFS Cypher query parameterised by `$seed_ids`. Each hop is a `CALL {}` subquery block. Mirrors pyg-lib semantics: take-all when `k < 0` or `k ≥ |neighbourhood|`, otherwise uniform sample via `apoc.coll.randomItems`. Returns `edges` + `nodes` matching `HomogeneousSchema`.
  - `direction`: `incoming`/`outgoing`/`undirected` (adjusts edge pattern and neighbor-node expression).
  - `visited` (no `nodes_by_hop`); final `RETURN` collects `[n IN visited | n.nodeId]`.

`examples/neo4j/neo4j_graphsaint.py` (new, 297 lines)
Two classes for Neo4j-backed GraphSAINT:
- `Neo4jGraphSAINTSampler(DatabaseGraphSAINTSampler)`, abstract Neo4j base:
  - `_get_total_nodes` via `MATCH (n{label}) RETURN count(n)`.
  - `_build_subgraph_query`: induced-subgraph query for `$nodes`; returns `nodes`, `edges`, and optionally `splits` (when `split_property` is set) in one round-trip.
  - `_build_data`: builds a `TensorAttr` list from user-configured `feature_attrs`, calls `feature_store.multi_get_tensor`, assembles `Data`, attaches optional `train_mask` from the `splits` column, attaches `node_norm`.
  - `feature_attrs`: `{data_key: (store_attr_name, dtype)}` mapping; decouples `Data` attribute names from Neo4j property names.
  - `node_label`, `rel_type`, and `split_property`.
- `Neo4jGraphSAINTRandomWalkSampler(Neo4jGraphSAINTSampler)`, concrete random-walk variant:
  - `_build_walk_query`: pure-Cypher BFS random walk; picks `$batch_size` random roots, unrolls `walk_length` hop `CALL {}` blocks, returns deduplicated visited IDs via `apoc.coll.toSet`. APOC required; probed eagerly in `__init__`.
  - `_sample_nodes`: runs the walk query via `graph_store._fetch_subgraph`.
  - `_setup`: builds `_walk_query` after base init (uses `walk_length`, available by then).
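
For reviewers who want to sanity-check the normalisation described for `_compute_norm` above, here is a minimal pure-Python sketch of the stated formula `node_norm[v] = num_samples / count[v] / N`. It is not the code in this PR: the function name `compute_node_norm`, the `sample_nodes` callable, and the `0.1` placeholder count for nodes that were never visited are assumptions made for illustration only.

```python
from collections import Counter
from typing import Callable, List, Sequence


def compute_node_norm(
    sample_nodes: Callable[[], Sequence[int]],  # hypothetical: returns one batch of node IDs
    num_nodes: int,
    sample_coverage: float,
) -> List[float]:
    """Pre-sample until N * sample_coverage visits, then normalise per node."""
    if num_nodes <= 0:
        raise ValueError("num_nodes must be positive")

    counts: Counter = Counter()
    total_visits = 0
    num_samples = 0
    target = num_nodes * sample_coverage

    # Keep drawing subgraph node sets until enough node visits were observed.
    while total_visits < target:
        nodes = sample_nodes()
        if len(nodes) == 0:
            # Guard against an endless loop if the sampler yields nothing.
            raise ValueError("sampler returned an empty node set")
        counts.update(nodes)
        total_visits += len(nodes)
        num_samples += 1

    norms = []
    for v in range(num_nodes):
        # Assumption of this sketch: unvisited nodes get a small placeholder
        # count so the division below is always defined.
        count = counts[v] if counts[v] > 0 else 0.1
        norms.append(num_samples / count / num_nodes)
    return norms
```

In a database-backed setting, `sample_nodes` would stand in for whatever `_sample_nodes` returns per step; frequently visited nodes end up with a smaller norm than rarely visited ones.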