diff --git a/xtas/tasks/pipeline.py b/xtas/tasks/pipeline.py
index 2a7f35d..2f871b9 100644
--- a/xtas/tasks/pipeline.py
+++ b/xtas/tasks/pipeline.py
@@ -3,6 +3,7 @@
 """
 
 import celery
+import warnings
 
 from xtas.tasks.es import is_es_document, es_address
 from xtas.tasks.es import get_single_result, store_single, fetch
@@ -12,44 +13,66 @@
 def pipeline(doc, pipeline, store_final=True, store_intermediate=False,
              block=True):
     """
-    Get the result for a given document.
-    Pipeline should be a list of dicts, with members task and argument
-    e.g. [{"module" : "tokenize"},
-          {"module" : "pos_tag", "arguments" : {"model" : "nltk"}}]
-    @param block: if True, it will block and return the actual result.
-                  If False, it will return an AsyncResult unless the result was
-                  cached, in which case it returns the result immediately (!)
-    @param store_final: if True, store the final result
-    @param store_intermediate: if True, store all intermediate results as well
+    Apply a sequence of operations to a document and return the result.
+
+    Parameters
+    ----------
+    doc : es_document or string
+        The document to process, either as a string, or a result of es_document().
+
+    pipeline : list of dicts
+        A list of dicts, each with members "task" and "arguments", e.g.
+        [{"task" : "tokenize"},
+         {"task" : "pos_tag", "arguments" : {"model" : "nltk"}}]
+        Using "module" instead of "task" is now obsolete, but still supported.
+
+    store_final : bool
+        If True, store the final result in ElasticSearch.
+
+    store_intermediate : bool
+        If True, store all intermediate results as well.
+
+    block : bool
+        If True, block, and return the result when it arrives.
+        If False, return the result directly if it's available immediately,
+        otherwise return an AsyncResult.
     """
-    # form basic pipeline by resolving task dictionaries to task objects
-    tasks = [_get_task(t) for t in pipeline]
+    tasks = [_get_task(t) for t in pipeline]
     if is_es_document(doc):
         idx, typ, id, field = es_address(doc)
-        chain = []
-        input = None
-        # Check cache for existing documents
-        # Iterate over tasks in reverse order, check cached result, and
-        # otherwise prepend task (and cache store command) to chain
-        for i in range(len(tasks), 0, -1):
-            taskname = "__".join(t.task for t in tasks[:i])
-            input = get_single_result(taskname, idx, typ, id)
-            if input:
+
+        def result_name(task_i):
+            "Results are named after the task that created them"
+            return "__".join(t.task for t in tasks[:(task_i + 1)])
+
+        last_cached_result = None
+        # we always have doc, which is result -1, the input to task 0
+        last_cached_i = -1
+        for task_i in reversed(range(0, len(tasks))):
+            last_cached_result = get_single_result(result_name(task_i),
+                                                   idx, typ, id)
+            if last_cached_result:
+                last_cached_i = task_i
                 break
-            if (i == len(tasks) and store_final) or store_intermediate:
-                chain.insert(0, store_single.s(taskname, idx, typ, id))
-            chain.insert(0, tasks[i-1])
-        if not chain:  # final result was cached, good!
-            return input
-        elif input is None:
-            input = fetch(doc)
-    else:
-        # the doc is a string, so we can't use caching
-        chain = tasks
-        input = doc
-    chain = celery.chain(*chain).delay(input)
+        if last_cached_i == -1:
+            doc = fetch(doc)
+        else:
+            doc = last_cached_result
+
+        new_tasks = []
+        for task_i in range(last_cached_i + 1, len(tasks)):
+            new_tasks.append(tasks[task_i])
+            if (task_i == len(tasks) - 1 and store_final) or store_intermediate:
+                new_tasks.append(store_single.s(result_name(task_i),
+                                                idx, typ, id))
+        tasks = new_tasks
+
+    if not tasks:
+        return doc
+
+    chain = celery.chain(*tasks).delay(doc)
     if block:
         return chain.get()
     else:
@@ -57,8 +80,12 @@ def pipeline(doc, pipeline, store_final=True, store_intermediate=False,
 
 
 def _get_task(task_dict):
-    "Create a celery task object from a dictionary with module and arguments"
-    task = task_dict['module']
+    "Create a celery task object from a dictionary with task name and arguments"
+    if 'task' not in task_dict and 'module' in task_dict:
+        warnings.warn('The "module" key is deprecated,'
+                      ' please use "task" instead.', DeprecationWarning, stacklevel=2)
+        task_dict['task'] = task_dict['module']
+    task = task_dict['task']
     if isinstance(task, (str, unicode)):
         task = app.tasks[task]
     args = task_dict.get('arguments')
diff --git a/xtas/tests/test_pipeline.py b/xtas/tests/test_pipeline.py
index f06e512..ef58f79 100644
--- a/xtas/tests/test_pipeline.py
+++ b/xtas/tests/test_pipeline.py
@@ -33,12 +33,16 @@ def test_pipeline():
     assert_equal(result, expected)
     with eager_celery():
         # do we get correct result from pipeline?
+        r = pipeline(s, [{"task": tokenize},
+                         {"task": pos_tag, "arguments": {"model": "nltk"}}])
+        assert_equal(r, expected)
+        # is the old syntax still supported?
         r = pipeline(s, [{"module": tokenize},
                          {"module": pos_tag, "arguments": {"model": "nltk"}}])
         assert_equal(r, expected)
         # can we specify modules by name?
-        r = pipeline(s, [{"module": "xtas.tasks.single.tokenize"},
-                         {"module": "xtas.tasks.single.pos_tag",
+        r = pipeline(s, [{"task": "xtas.tasks.single.tokenize"},
+                         {"task": "xtas.tasks.single.pos_tag",
                           "arguments": {"model": "nltk"}}])
         assert_equal(r, expected)
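
Usage sketch (not part of the patch): this mirrors the tests above and shows the new "task" key applied to a plain string document. The task names come from the test file; the example sentence is illustrative. The tests run this under eager_celery(); outside the tests, a configured Celery broker and worker are assumed.

    from xtas.tasks.pipeline import pipeline

    # Tokenize and POS-tag a raw string. String documents bypass the
    # ElasticSearch cache, so store_final/store_intermediate have no effect here.
    result = pipeline("The cat sat on the mat.",
                      [{"task": "xtas.tasks.single.tokenize"},
                       {"task": "xtas.tasks.single.pos_tag",
                        "arguments": {"model": "nltk"}}],
                      block=True)

With block=True (the default) the call returns the pos_tag output itself rather than an AsyncResult.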