Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions python/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pyarrow.flight as flight
import pyodbc
import requests
from arango import ArangoClient
from gqlalchemy import Memgraph
from neo4j import GraphDatabase
from neo4j.time import DateTime as Neo4jDateTime
Expand Down Expand Up @@ -836,6 +837,109 @@ def cleanup_migrate_memgraph():
mgp.add_batch_read_proc(memgraph, init_migrate_memgraph, cleanup_migrate_memgraph)


# ArangoDB dictionary to store connections per thread
arangodb_dict = {}


def init_migrate_arangodb(
collection_or_query: str,
config: mgp.Map,
config_path: str = "",
params: mgp.Nullable[mgp.Any] = None,
):
"""
Initialize connection to ArangoDB and prepare query execution.

:param collection_or_query: Collection name or AQL query
:param config: Connection configuration (host, port, username, password, database)
:param config_path: Path to JSON file containing configuration parameters
:param params: Optional bind variables for parameterized AQL queries
"""
global arangodb_dict

thread_id = threading.get_native_id()
if thread_id not in arangodb_dict:
arangodb_dict[thread_id] = {}

if len(config_path) > 0:
config = _combine_config(config=config, config_path=config_path)

if params:
_check_params_type(params, (dict,))

# Build connection URL
host = config.get(Constants.HOST, "localhost")
port = config.get(Constants.PORT, 8529)
username = config.get(Constants.USERNAME, "root")
password = config.get(Constants.PASSWORD, "")
database = config.get(Constants.DATABASE, "_system")

# Connect to ArangoDB
client = ArangoClient(hosts=f"http://{host}:{port}")
db = client.db(database, username=username, password=password)

# Formulate AQL query
aql_query = _formulate_aql_query(collection_or_query, db)

# Execute query with optional bind variables
bind_vars = params if params is not None else {}
cursor = db.aql.execute(aql_query, bind_vars=bind_vars, stream=True)

arangodb_dict[thread_id][Constants.CONNECTION] = client
arangodb_dict[thread_id][Constants.DATABASE] = db
arangodb_dict[thread_id][Constants.CURSOR] = iter(cursor)


def arangodb(
collection_or_query: str,
config: mgp.Map,
config_path: str = "",
params: mgp.Nullable[mgp.Any] = None,
) -> mgp.Record(row=mgp.Map):
"""
Migrate data from ArangoDB to Memgraph. Can migrate a specific collection,
or execute a custom AQL query.

:param collection_or_query: Collection name or an AQL query
:param config: Connection configuration for ArangoDB
(host, port, username, password, database)
:param config_path: Path to a JSON file containing connection parameters
:param params: Optional bind variables for parameterized AQL queries
:return: Stream of rows from ArangoDB
"""
global arangodb_dict

thread_id = threading.get_native_id()
cursor = arangodb_dict[thread_id][Constants.CURSOR]

batch = []
for _ in range(Constants.BATCH_SIZE):
try:
row = _convert_arangodb_value(next(cursor))
batch.append(mgp.Record(row=row))
except StopIteration:
break

return batch


def cleanup_migrate_arangodb():
"""
Clean up ArangoDB dictionary references per-thread.
"""
global arangodb_dict

thread_id = threading.get_native_id()
if thread_id in arangodb_dict:
client = arangodb_dict[thread_id].get(Constants.CONNECTION)
if client:
client.close()
arangodb_dict.pop(thread_id, None)


mgp.add_batch_read_proc(arangodb, init_migrate_arangodb, cleanup_migrate_arangodb)


servicenow_dict = {}


Expand Down Expand Up @@ -1096,3 +1200,72 @@ def _build_neo4j_uri(config: mgp.Map) -> str:
port = config.get(Constants.PORT, 7687)
uri_scheme = config.get(Constants.URI_SCHEME, "bolt")
return f"{uri_scheme}://{host}:{port}"


def _formulate_aql_query(collection_or_query: str, db) -> str:
"""
Formulate an AQL query from a collection name or return the query if it's already AQL.

:param collection_or_query: Collection name or AQL query
:param db: ArangoDB database connection
:return: AQL query string
"""
words = collection_or_query.split()
# If it contains multiple words or AQL keywords, treat it as an AQL query
if len(words) > 1 or any(
keyword.lower() in collection_or_query.lower()
for keyword in ["FOR", "RETURN", "FILTER", "LET", "COLLECT", "SORT", "LIMIT"]
):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: AQL Keyword Matching Fails on Substrings

The _formulate_aql_query function uses substring matching to detect AQL keywords, which is incorrect. It checks keyword.lower() in collection_or_query.lower(), but this will incorrectly match keywords that are part of a word (e.g., "FOR" would match "FOREVER"). This causes collection names like "FOREVER" to be treated as AQL queries, which will then fail at execution time. The keyword matching should use word boundary checking or proper tokenization instead of simple substring matching.

Fix in Cursor Fix in Web

return collection_or_query

# Otherwise, treat it as a collection name
collection_name = collection_or_query.strip()
# Check if collection exists
if not db.has_collection(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")

# Return AQL query to fetch all documents from the collection
return f"FOR doc IN {collection_name} RETURN doc"


def _convert_arangodb_value(value):
"""
Convert ArangoDB values to Python-compatible formats.
Handles ArangoDB-specific types like _id, _key, _rev, and nested structures.
"""
if value is None:
return None

# Handle dict types (documents)
if isinstance(value, dict):
result = {}
for key, val in value.items():
# Convert nested structures recursively
result[key] = _convert_arangodb_value(val)
return result

# Handle list types
if isinstance(value, list):
return [_convert_arangodb_value(item) for item in value]

# Handle bytes
if isinstance(value, bytes):
try:
return value.decode("utf-8")
except UnicodeDecodeError:
return base64.b64encode(value).decode("ascii")

# Handle datetime objects
if isinstance(value, datetime.datetime):
return value.isoformat()

# Handle date objects
if isinstance(value, datetime.date):
return value.isoformat()

# Handle Decimal types
if isinstance(value, Decimal):
return float(value)

# For other types (int, float, bool, str), return as is
return value
1 change: 1 addition & 0 deletions python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ platformdirs==4.3.6
pyarrow==19.0.1
psycopg2-binary==2.9.10
pyodbc==5.2.0
python-arango==8.2.2
python-Levenshtein==0.26.1
scikit-learn==1.5.2
scipy==1.13.1
Expand Down
Loading