@@ -219,11 +219,67 @@ those inputs where it is safe and beneficial to do so.
219219
220220.. testcode ::
221221
222- def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
223- """Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs `."""
224- # Implementation would create a new Op with appropriate destroy_map
225- # Return self by default if no inplace version is available
226- return self
222+ import numpy as np
223+ import pytensor
224+ import pytensor.tensor as pt
225+ from pytensor.graph.basic import Apply
226+ from pytensor.graph.op import Op
227+ from pytensor.tensor.blockwise import Blockwise
228+
229+ class MyOpWithInplace(Op):
230+ __props__ = ("destroy_a",)
231+
232+ def __init__(self, destroy_a):
233+ self.destroy_a = destroy_a
234+ if destroy_a:
235+ self.destroy_map = {0: [0]}
236+
237+ def make_node(self, a):
238+ return Apply(self, [a], [a.type()])
239+
240+ def perform(self, node, inputs, output_storage):
241+ [a] = inputs
242+ if not self.destroy_a:
243+ a = a.copy()
244+ a[0] += 1
245+ output_storage[0][0] = a
246+
247+ def inplace_on_inputs(self, allowed_inplace_inputs):
248+ if 0 in allowed_inplace_inputs:
249+ return MyOpWithInplace(destroy_a=True)
250+ return self
251+
252+ a = pt.vector("a")
253+ # Only Blockwise trigger inplace automatically for now
254+ # Since the Blockwise isn't needed in this case, it will be removed after the inplace optimization
255+ op = Blockwise(MyOpWithInplace(destroy_a=False), signature="(a)->(a)")
256+ out = op(a)
257+
258+ # Give PyTensor permission to inplace on user provided inputs
259+ fn = pytensor.function([pytensor.In(a, mutable=True)], out)
260+
261+ # Confirm that we have the inplace version of the Op
262+ fn.dprint(print_destroy_map=True)
263+
264+ .. testoutput ::
265+
266+ Blockwise{MyOpWithInplace{destroy_a=True}, (a)->(a)} [id A] '' 5
267+ └─ a [id B]
268+
269+ The output shows that the function now uses the inplace version (`destroy_a=True `).
270+
271+ .. testcode ::
272+
273+ # Test that inplace modification works
274+ test_a = np.zeros(5)
275+ result = fn(test_a)
276+ print("Function result:", result)
277+ print("Original array after function call:", test_a)
278+
279+ .. testoutput ::
280+
281+ Function result: [1. 0. 0. 0. 0.]
282+ Original array after function call: [1. 0. 0. 0. 0.]
227283
228284Currently, this method is primarily used with Blockwise operations through PyTensor's
229285rewriting system, but it will be extended to support core ops directly in future versions.
0 commit comments