Skip to content

Commit 81a1f43

Browse files
committed
fix: can set slices for PythonOPTemplate after initialization
Signed-off-by: zjgemi <[email protected]>
1 parent 9441971 commit 81a1f43

1 file changed

Lines changed: 38 additions & 30 deletions

File tree

src/dflow/python/python_op_template.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -145,22 +145,6 @@ def __init__(self,
145145
class_name = op_class.__name__
146146
input_sign = op_class.get_input_sign()
147147
output_sign = op_class.get_output_sign()
148-
if slices is not None:
149-
assert isinstance(slices, Slices)
150-
if slices.input_artifact and not slices.sub_path:
151-
input_artifact_slices = {
152-
name: slices.slices for name in slices.input_artifact}
153-
if slices.input_parameter:
154-
input_parameter_slices = {
155-
name: slices.slices for name in slices.input_parameter}
156-
if slices.output_artifact:
157-
output_artifact_slices = {}
158-
for name in slices.output_artifact:
159-
output_artifact_slices[name] = slices.slices
160-
output_sign[name].archive = None # not archive for default
161-
if slices.output_parameter:
162-
output_parameter_slices = {
163-
name: slices.slices for name in slices.output_parameter}
164148
if output_artifact_save is not None:
165149
for name, save in output_artifact_save.items():
166150
output_sign[name].save = save
@@ -175,7 +159,6 @@ def __init__(self,
175159
outputs=Outputs(), volumes=volumes, mounts=mounts,
176160
requests=requests, limits=limits, envs=envs,
177161
init_containers=init_containers)
178-
self.slices = slices
179162
if timeout is not None:
180163
self.timeout = "%ss" % timeout
181164
if retry_on_transient_error is not None:
@@ -188,18 +171,9 @@ def __init__(self,
188171
self.dflow_vars = {}
189172
for name, sign in input_sign.items():
190173
if isinstance(sign, Artifact):
191-
if self.slices is not None and self.slices.sub_path and name \
192-
in self.slices.input_artifact:
193-
self.inputs.parameters["dflow_%s_sub_path" %
194-
name] = InputParameter(value=".")
195-
self.inputs.artifacts[name] = InputArtifact(
196-
path="/tmp/inputs/artifacts/%s/{{inputs.parameters."
197-
"dflow_%s_sub_path}}" % (name, name),
198-
optional=sign.optional, type=sign.type)
199-
else:
200-
self.inputs.artifacts[name] = InputArtifact(
201-
path="/tmp/inputs/artifacts/" + name,
202-
optional=sign.optional, type=sign.type)
174+
self.inputs.artifacts[name] = InputArtifact(
175+
path="/tmp/inputs/artifacts/" + name,
176+
optional=sign.optional, type=sign.type)
203177
elif isinstance(sign, BigParameter):
204178
self.inputs.parameters[name] = InputParameter(
205179
save_as_artifact=True, path="/tmp/inputs/parameters/"
@@ -287,7 +261,41 @@ def __init__(self,
287261
self.input_parameter_slices = input_parameter_slices
288262
self.output_artifact_slices = output_artifact_slices
289263
self.output_parameter_slices = output_parameter_slices
290-
self.render_script()
264+
self.slices = slices
265+
266+
def __setattr__(self, key, value):
267+
super().__setattr__(key, value)
268+
if key == "slices":
269+
self.init_slices(value)
270+
self.render_script()
271+
272+
def init_slices(self, slices):
273+
if slices is not None:
274+
assert isinstance(slices, Slices)
275+
if slices.input_artifact and not slices.sub_path:
276+
self.input_artifact_slices = {
277+
name: slices.slices for name in slices.input_artifact}
278+
if slices.input_parameter:
279+
self.input_parameter_slices = {
280+
name: slices.slices for name in slices.input_parameter}
281+
if slices.output_artifact:
282+
self.output_artifact_slices = {}
283+
for name in slices.output_artifact:
284+
self.output_artifact_slices[name] = slices.slices
285+
self.outputs.artifacts[name].archive = None # no archive
286+
if slices.output_parameter:
287+
self.output_parameter_slices = {
288+
name: slices.slices for name in slices.output_parameter}
289+
290+
if slices.sub_path:
291+
for name in slices.input_artifact:
292+
self.inputs.parameters["dflow_%s_sub_path" %
293+
name] = InputParameter(value=".")
294+
sign = self.input_sign[name]
295+
self.inputs.artifacts[name] = InputArtifact(
296+
path="/tmp/inputs/artifacts/%s/{{inputs.parameters."
297+
"dflow_%s_sub_path}}" % (name, name),
298+
optional=sign.optional, type=sign.type)
291299

292300
def render_script(self):
293301
op_class = self.op_class

0 commit comments

Comments
 (0)