Skip to content

Commit 2387be5

Browse files
hansentclaude and co-author authored
feat(sam3): enable SDK-based remote execution for SAM3 workflow blocks (#2042)
* feat(sam3): enable SDK-based remote execution for SAM3 workflow blocks SAM3 workflow blocks previously always used the API inference proxy for remote execution, regardless of configuration. This change decouples the workflow step execution mode from SAM3_EXEC_MODE, enabling SAM3 blocks to use the standard SDK-based remote execution pattern used by all other model blocks when WORKFLOWS_STEP_EXECUTION_MODE=remote. Also adds SAM3 SDK client methods (concept_segment, visual_segment, embed_image) and a new SAM3_FINE_TUNED_MODELS_ENABLED env var to control fine-tuned model access independently of execution mode. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * make style --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9105c2c commit 2387be5

File tree

6 files changed

+500
-47
lines changed

6 files changed

+500
-47
lines changed

inference/core/env.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,13 @@
173173
raise ValueError(
174174
f"Invalid SAM3 execution mode in ENVIRONMENT var SAM3_EXEC_MODE (local or remote): {SAM3_EXEC_MODE}"
175175
)
176+
# Whether fine-tuned SAM3 models (non-sam3/ prefix) are allowed.
177+
# Defaults to False when SAM3_EXEC_MODE=remote (backward compat with existing proxy deployments),
178+
# True otherwise (self-hosted users can use fine-tuned models).
179+
_sam3_fine_tuned_default = "False" if SAM3_EXEC_MODE == "remote" else "True"
180+
SAM3_FINE_TUNED_MODELS_ENABLED = str2bool(
181+
os.getenv("SAM3_FINE_TUNED_MODELS_ENABLED", _sam3_fine_tuned_default)
182+
)
176183

177184
# Flag to enable GAZE core model, default is True
178185
CORE_MODEL_GAZE_ENABLED = str2bool(os.getenv("CORE_MODEL_GAZE_ENABLED", True))

inference/core/interfaces/http/http_api.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@
175175
ROBOFLOW_INTERNAL_SERVICE_SECRET,
176176
ROBOFLOW_SERVICE_SECRET,
177177
SAM3_EXEC_MODE,
178+
SAM3_FINE_TUNED_MODELS_ENABLED,
178179
USE_INFERENCE_MODELS,
179180
WEBRTC_WORKER_ENABLED,
180181
WORKFLOWS_MAX_CONCURRENT_STEPS,
@@ -2586,12 +2587,14 @@ def sam3_segment_image(
25862587
countinference: Optional[bool] = None,
25872588
service_secret: Optional[str] = None,
25882589
):
2589-
if SAM3_EXEC_MODE == "remote":
2590+
if not SAM3_FINE_TUNED_MODELS_ENABLED:
25902591
if not inference_request.model_id.startswith("sam3/"):
25912592
raise HTTPException(
25922593
status_code=501,
2593-
detail="Fine-tuned SAM3 models are not supported in remote execution mode yet. Please use a workflow or self-host the server.",
2594+
detail="Fine-tuned SAM3 models are not supported on this deployment. Please use a workflow or self-host the server.",
25942595
)
2596+
2597+
if SAM3_EXEC_MODE == "remote":
25952598
endpoint = f"{API_BASE_URL}/inferenceproxy/seg-preview"
25962599

25972600
# Construct payload for remote API

inference/core/workflows/core_steps/models/foundation/segment_anything3/v1.py

Lines changed: 91 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
from inference.core.entities.responses.sam3 import Sam3SegmentationPrediction
1818
from inference.core.env import (
1919
API_BASE_URL,
20+
HOSTED_CORE_MODEL_URL,
21+
LOCAL_INFERENCE_API_URL,
2022
ROBOFLOW_INTERNAL_SERVICE_NAME,
2123
ROBOFLOW_INTERNAL_SERVICE_SECRET,
2224
SAM3_EXEC_MODE,
25+
WORKFLOWS_REMOTE_API_TARGET,
2326
)
2427
from inference.core.managers.base import ModelManager
2528
from inference.core.roboflow_api import build_roboflow_api_headers
@@ -54,6 +57,7 @@
5457
WorkflowBlock,
5558
WorkflowBlockManifest,
5659
)
60+
from inference_sdk import InferenceHTTPClient
5761

5862
DETECTIONS_CLASS_NAME_FIELD = "class_name"
5963
DETECTION_ID_FIELD = "detection_id"
@@ -169,30 +173,26 @@ def run(
169173
else:
170174
raise ValueError(f"Invalid class names type: {type(class_names)}")
171175

172-
exec_mode = self._step_execution_mode
173-
if SAM3_EXEC_MODE == "local":
174-
exec_mode = self._step_execution_mode
175-
elif SAM3_EXEC_MODE == "remote":
176-
exec_mode = (
177-
StepExecutionMode.REMOTE
178-
) # if SAM3_EXEC_MODE == "remote" then force remote execution mode only
179-
else:
180-
raise ValueError(
181-
f"Invalid SAM3 execution mode in ENVIRONMENT var SAM3_EXEC_MODE (local or remote): {SAM3_EXEC_MODE}"
176+
if SAM3_EXEC_MODE == "remote":
177+
logger.debug("Running SAM3 v1 via inference proxy (SAM3_EXEC_MODE=remote)")
178+
return self.run_via_request(
179+
images=images,
180+
class_names=class_names,
181+
threshold=threshold,
182182
)
183-
184-
if exec_mode is StepExecutionMode.LOCAL:
185-
logger.debug(f"Running SAM3 locally")
183+
elif self._step_execution_mode is StepExecutionMode.LOCAL:
184+
logger.debug("Running SAM3 v1 locally")
186185
return self.run_locally(
187186
images=images,
188187
model_id=model_id,
189188
class_names=class_names,
190189
threshold=threshold,
191190
)
192-
elif exec_mode is StepExecutionMode.REMOTE:
193-
logger.debug(f"Running SAM3 remotely")
194-
return self.run_via_request(
191+
elif self._step_execution_mode is StepExecutionMode.REMOTE:
192+
logger.debug("Running SAM3 v1 remotely via SDK")
193+
return self.run_remotely(
195194
images=images,
195+
model_id=model_id,
196196
class_names=class_names,
197197
threshold=threshold,
198198
)
@@ -276,6 +276,81 @@ def run_locally(
276276
predictions=predictions,
277277
)
278278

279+
def run_remotely(
280+
self,
281+
images: Batch[WorkflowImageData],
282+
model_id: str,
283+
class_names: Optional[List[str]],
284+
threshold: float,
285+
) -> BlockResult:
286+
predictions = []
287+
if class_names is None:
288+
class_names = []
289+
if len(class_names) == 0:
290+
class_names.append(None)
291+
292+
api_url = (
293+
LOCAL_INFERENCE_API_URL
294+
if WORKFLOWS_REMOTE_API_TARGET != "hosted"
295+
else HOSTED_CORE_MODEL_URL
296+
)
297+
client = InferenceHTTPClient(
298+
api_url=api_url,
299+
api_key=self._api_key,
300+
)
301+
if WORKFLOWS_REMOTE_API_TARGET == "hosted":
302+
client.select_api_v0()
303+
304+
for single_image in images:
305+
prompt_class_ids: List[Optional[int]] = []
306+
prompt_class_names: List[Optional[str]] = []
307+
prompt_detection_ids: List[Optional[str]] = []
308+
309+
http_prompts: List[dict] = []
310+
for class_name in class_names:
311+
http_prompts.append({"type": "text", "text": class_name})
312+
313+
resp_json = client.sam3_concept_segment(
314+
inference_input=single_image.base64_image,
315+
prompts=http_prompts,
316+
model_id=model_id,
317+
output_prob_thresh=threshold,
318+
)
319+
320+
class_predictions: List[InstanceSegmentationPrediction] = []
321+
for prompt_result in resp_json.get("prompt_results", []):
322+
idx = prompt_result.get("prompt_index", 0)
323+
class_name = class_names[idx] if idx < len(class_names) else None
324+
raw_predictions = prompt_result.get("predictions", [])
325+
adapted_predictions = [SimpleNamespace(**p) for p in raw_predictions]
326+
class_pred = convert_sam3_segmentation_response_to_inference_instances_seg_response(
327+
sam3_segmentation_predictions=adapted_predictions, # type: ignore[arg-type]
328+
image=single_image,
329+
prompt_class_ids=prompt_class_ids,
330+
prompt_class_names=prompt_class_names,
331+
prompt_detection_ids=prompt_detection_ids,
332+
threshold=threshold,
333+
text_prompt=class_name,
334+
specific_class_id=idx,
335+
)
336+
class_predictions.extend(class_pred.predictions)
337+
338+
image_width = single_image.numpy_image.shape[1]
339+
image_height = single_image.numpy_image.shape[0]
340+
final_inference_prediction = InstanceSegmentationInferenceResponse(
341+
predictions=class_predictions,
342+
image=InferenceResponseImage(width=image_width, height=image_height),
343+
)
344+
predictions.append(final_inference_prediction)
345+
346+
predictions = [
347+
e.model_dump(by_alias=True, exclude_none=True) for e in predictions
348+
]
349+
return self._post_process_result(
350+
images=images,
351+
predictions=predictions,
352+
)
353+
279354
def run_via_request(
280355
self,
281356
images: Batch[WorkflowImageData],

inference/core/workflows/core_steps/models/foundation/segment_anything3/v2.py

Lines changed: 101 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
from inference.core.entities.responses.sam3 import Sam3SegmentationPrediction
1818
from inference.core.env import (
1919
API_BASE_URL,
20+
HOSTED_CORE_MODEL_URL,
21+
LOCAL_INFERENCE_API_URL,
2022
ROBOFLOW_INTERNAL_SERVICE_NAME,
2123
ROBOFLOW_INTERNAL_SERVICE_SECRET,
2224
SAM3_EXEC_MODE,
25+
WORKFLOWS_REMOTE_API_TARGET,
2326
)
2427
from inference.core.managers.base import ModelManager
2528
from inference.core.roboflow_api import build_roboflow_api_headers
@@ -54,6 +57,7 @@
5457
WorkflowBlock,
5558
WorkflowBlockManifest,
5659
)
60+
from inference_sdk import InferenceHTTPClient
5761

5862
DETECTIONS_CLASS_NAME_FIELD = "class_name"
5963
DETECTION_ID_FIELD = "detection_id"
@@ -223,20 +227,18 @@ def run(
223227
else:
224228
raise ValueError(f"Invalid class names type: {type(class_names)}")
225229

226-
exec_mode = self._step_execution_mode
227-
if SAM3_EXEC_MODE == "local":
228-
exec_mode = self._step_execution_mode
229-
elif SAM3_EXEC_MODE == "remote":
230-
exec_mode = (
231-
StepExecutionMode.REMOTE
232-
) # if SAM3_EXEC_MODE == "remote" then force remote execution mode only
233-
else:
234-
raise ValueError(
235-
f"Invalid SAM3 execution mode in ENVIRONMENT var SAM3_EXEC_MODE (local or remote): {SAM3_EXEC_MODE}"
230+
if SAM3_EXEC_MODE == "remote":
231+
logger.debug("Running SAM3 v2 via inference proxy (SAM3_EXEC_MODE=remote)")
232+
return self.run_via_request(
233+
images=images,
234+
class_names=class_names,
235+
confidence=confidence,
236+
per_class_confidence=per_class_confidence,
237+
apply_nms=apply_nms,
238+
nms_iou_threshold=nms_iou_threshold,
236239
)
237-
238-
if exec_mode is StepExecutionMode.LOCAL:
239-
logger.debug(f"Running SAM3 locally")
240+
elif self._step_execution_mode is StepExecutionMode.LOCAL:
241+
logger.debug("Running SAM3 v2 locally")
240242
return self.run_locally(
241243
images=images,
242244
model_id=model_id,
@@ -246,10 +248,11 @@ def run(
246248
apply_nms=apply_nms,
247249
nms_iou_threshold=nms_iou_threshold,
248250
)
249-
elif exec_mode is StepExecutionMode.REMOTE:
250-
logger.debug(f"Running SAM3 remotely")
251-
return self.run_via_request(
251+
elif self._step_execution_mode is StepExecutionMode.REMOTE:
252+
logger.debug("Running SAM3 v2 remotely via SDK")
253+
return self.run_remotely(
252254
images=images,
255+
model_id=model_id,
253256
class_names=class_names,
254257
confidence=confidence,
255258
per_class_confidence=per_class_confidence,
@@ -348,6 +351,88 @@ def run_locally(
348351
predictions=predictions,
349352
)
350353

354+
def run_remotely(
355+
self,
356+
images: Batch[WorkflowImageData],
357+
model_id: str,
358+
class_names: Optional[List[str]],
359+
confidence: float,
360+
per_class_confidence: Optional[List[float]] = None,
361+
apply_nms: bool = True,
362+
nms_iou_threshold: float = 0.9,
363+
) -> BlockResult:
364+
predictions = []
365+
if class_names is None:
366+
class_names = []
367+
if len(class_names) == 0:
368+
class_names.append(None)
369+
370+
api_url = (
371+
LOCAL_INFERENCE_API_URL
372+
if WORKFLOWS_REMOTE_API_TARGET != "hosted"
373+
else HOSTED_CORE_MODEL_URL
374+
)
375+
client = InferenceHTTPClient(
376+
api_url=api_url,
377+
api_key=self._api_key,
378+
)
379+
if WORKFLOWS_REMOTE_API_TARGET == "hosted":
380+
client.select_api_v0()
381+
382+
for single_image in images:
383+
prompt_class_ids: List[Optional[int]] = []
384+
prompt_class_names: List[Optional[str]] = []
385+
prompt_detection_ids: List[Optional[str]] = []
386+
387+
http_prompts: List[dict] = []
388+
for idx, class_name in enumerate(class_names):
389+
prompt_data = {"type": "text", "text": class_name}
390+
if per_class_confidence is not None:
391+
prompt_data["output_prob_thresh"] = per_class_confidence[idx]
392+
http_prompts.append(prompt_data)
393+
394+
resp_json = client.sam3_concept_segment(
395+
inference_input=single_image.base64_image,
396+
prompts=http_prompts,
397+
model_id=model_id,
398+
output_prob_thresh=confidence,
399+
nms_iou_threshold=nms_iou_threshold if apply_nms else None,
400+
)
401+
402+
class_predictions: List[InstanceSegmentationPrediction] = []
403+
for prompt_result in resp_json.get("prompt_results", []):
404+
idx = prompt_result.get("prompt_index", 0)
405+
class_name = class_names[idx] if idx < len(class_names) else None
406+
raw_predictions = prompt_result.get("predictions", [])
407+
adapted_predictions = [SimpleNamespace(**p) for p in raw_predictions]
408+
class_pred = convert_sam3_segmentation_response_to_inference_instances_seg_response(
409+
sam3_segmentation_predictions=adapted_predictions, # type: ignore[arg-type]
410+
image=single_image,
411+
prompt_class_ids=prompt_class_ids,
412+
prompt_class_names=prompt_class_names,
413+
prompt_detection_ids=prompt_detection_ids,
414+
confidence=confidence,
415+
text_prompt=class_name,
416+
specific_class_id=idx,
417+
)
418+
class_predictions.extend(class_pred.predictions)
419+
420+
image_width = single_image.numpy_image.shape[1]
421+
image_height = single_image.numpy_image.shape[0]
422+
final_inference_prediction = InstanceSegmentationInferenceResponse(
423+
predictions=class_predictions,
424+
image=InferenceResponseImage(width=image_width, height=image_height),
425+
)
426+
predictions.append(final_inference_prediction)
427+
428+
predictions = [
429+
e.model_dump(by_alias=True, exclude_none=True) for e in predictions
430+
]
431+
return self._post_process_result(
432+
images=images,
433+
predictions=predictions,
434+
)
435+
351436
def run_via_request(
352437
self,
353438
images: Batch[WorkflowImageData],

0 commit comments

Comments (0)