Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 61 additions & 34 deletions xtas/tasks/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import celery
import warnings

from xtas.tasks.es import is_es_document, es_address
from xtas.tasks.es import get_single_result, store_single, fetch
Expand All @@ -12,53 +13,79 @@
def pipeline(doc, pipeline, store_final=True, store_intermediate=False,
             block=True):
    """
    Apply a sequence of operations to a document and return the result.

    Parameters
    ----------
    doc : es_document or string
        The document to process, either as a string, or a result of
        es_document().

    pipeline : list of dicts
        A list of dicts, each with members "task" and "arguments", e.g.
        [{"task" : "tokenize"},
         {"task" : "pos_tag", "arguments" : {"model" : "nltk"}}]
        Using "module" instead of "task" is now obsolete, but still supported.

    store_final : bool
        If True, store the final result in ElasticSearch.

    store_intermediate : bool
        If True, store all intermediate results as well.

    block : bool
        If True, block, and return the result when it arrives.
        If False, return the result directly if it's available immediately,
        otherwise return an AsyncResult.
    """
    # form basic pipeline by resolving task dictionaries to task objects
    tasks = [_get_task(t) for t in pipeline]

    if is_es_document(doc):
        idx, typ, id, field = es_address(doc)

        def result_name(task_i):
            "Results are named after the task that created them"
            return "__".join(t.task for t in tasks[:(task_i + 1)])

        # Walk the pipeline back-to-front looking for a cached result,
        # so that only the uncached suffix of tasks needs to run.
        last_cached_result = None
        # we always have doc, which is result -1, the input to task 0
        last_cached_i = -1
        for task_i in reversed(range(0, len(tasks))):
            last_cached_result = get_single_result(result_name(task_i),
                                                   idx, typ, id)
            if last_cached_result:
                last_cached_i = task_i
                break

        if last_cached_i == -1:
            # nothing cached at all: start from the raw document
            doc = fetch(doc)
        else:
            doc = last_cached_result

        # Build the chain of remaining tasks, appending a store command
        # after each task whose result should be persisted.
        new_tasks = []
        for task_i in range(last_cached_i + 1, len(tasks)):
            new_tasks.append(tasks[task_i])
            # BUG FIX: range() never yields len(tasks), so comparing
            # task_i against len(tasks) meant the final result was never
            # stored even with store_final=True; the last task has index
            # len(tasks) - 1.
            if (task_i == len(tasks) - 1 and store_final) \
                    or store_intermediate:
                new_tasks.append(store_single.s(result_name(task_i),
                                                idx, typ, id))
        tasks = new_tasks

    if not tasks:  # final result was cached, return it immediately
        return doc

    chain = celery.chain(*tasks).delay(doc)
    if block:
        return chain.get()
    else:
        return chain


def _get_task(task_dict):
"Create a celery task object from a dictionary with module and arguments"
task = task_dict['module']
"Create a celery task object from a dictionary with task name and arguments"
if 'task' not in task_dict and 'module' in task_dict:
warnings.warn('The "module" key is deprecated,'
' please use "task" instead.', DeprecationWarning, stacklevel=2)
task_dict['task'] = task_dict['module']
task = task_dict['task']
if isinstance(task, (str, unicode)):
task = app.tasks[task]
args = task_dict.get('arguments')
Expand Down
8 changes: 6 additions & 2 deletions xtas/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,16 @@ def test_pipeline():
assert_equal(result, expected)
with eager_celery():
# do we get correct result from pipeline?
r = pipeline(s, [{"task": tokenize},
{"task": pos_tag, "arguments": {"model": "nltk"}}])
assert_equal(r, expected)
# is the old syntax still supported?
r = pipeline(s, [{"module": tokenize},
{"module": pos_tag, "arguments": {"model": "nltk"}}])
assert_equal(r, expected)
# can we specify modules by name?
r = pipeline(s, [{"module": "xtas.tasks.single.tokenize"},
{"module": "xtas.tasks.single.pos_tag",
r = pipeline(s, [{"task": "xtas.tasks.single.tokenize"},
{"task": "xtas.tasks.single.pos_tag",
"arguments": {"model": "nltk"}}])
assert_equal(r, expected)

Expand Down