Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion api/transformerlab/routers/compute_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2019,6 +2019,17 @@ async def launch_template_on_provider(
post_hook=str(post_task_hook) if post_task_hook is not None else None,
)

# Apply provider-level setup hooks (pre/post) around the resolved setup script (if any).
pre_setup_hook = extra_config_for_hooks.get("pre_setup_hook")
post_setup_hook = extra_config_for_hooks.get("post_setup_hook")
setup_with_hooks = final_setup
if setup_with_hooks and str(setup_with_hooks).strip():
setup_with_hooks = build_hooked_command(
str(setup_with_hooks),
pre_hook=str(pre_setup_hook) if pre_setup_hook is not None else None,
post_hook=str(post_setup_hook) if post_setup_hook is not None else None,
)

# Wrap the user command with tfl-remote-trap so we can track live_status in job_data.
# This uses the tfl-remote-trap helper from the transformerlab SDK, which:
# - sets job_data.live_status="started" when execution begins
Expand All @@ -2031,7 +2042,7 @@ async def launch_template_on_provider(
provider_name=provider_display_name,
provider_id=provider.id,
run=wrapped_run,
setup=final_setup,
setup=setup_with_hooks,
env_vars=env_vars,
cpus=request.cpus,
memory=request.memory,
Expand Down
145 changes: 116 additions & 29 deletions src/renderer/components/Team/ProviderDetailsModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import { useAPI, useAuth } from 'renderer/lib/authContext';
import { getPath } from 'renderer/lib/api-client/urls';
import { Endpoints } from 'renderer/lib/api-client/endpoints';
import { useNotification } from 'renderer/components/Shared/NotificationSystem';
import { ChevronDownIcon, ChevronRightIcon } from 'lucide-react';

interface ProviderDetailsModalProps {
open: boolean;
Expand Down Expand Up @@ -86,6 +87,9 @@ export default function ProviderDetailsModal({
);
const [preTaskHook, setPreTaskHook] = useState<string>('');
const [postTaskHook, setPostTaskHook] = useState<string>('');
const [preSetupHook, setPreSetupHook] = useState<string>('');
const [postSetupHook, setPostSetupHook] = useState<string>('');
const [hooksExpanded, setHooksExpanded] = useState(false);

// SLURM-specific form fields
const [slurmMode, setSlurmMode] = useState<'ssh' | 'rest'>('ssh');
Expand Down Expand Up @@ -183,9 +187,13 @@ export default function ProviderDetailsModal({
if (extraConfig && typeof extraConfig === 'object') {
setPreTaskHook(extraConfig.pre_task_hook || '');
setPostTaskHook(extraConfig.post_task_hook || '');
setPreSetupHook(extraConfig.pre_setup_hook || '');
setPostSetupHook(extraConfig.post_setup_hook || '');
} else {
setPreTaskHook('');
setPostTaskHook('');
setPreSetupHook('');
setPostSetupHook('');
}

// Extract supported_accelerators into dedicated state, but do not show it in raw JSON.
Expand All @@ -207,6 +215,8 @@ export default function ProviderDetailsModal({
setSupportedAccelerators([]);
setPreTaskHook('');
setPostTaskHook('');
setPreSetupHook('');
setPostSetupHook('');
setSlurmMode('ssh');
setSlurmSshHost('');
setSlurmSshUser('');
Expand All @@ -226,6 +236,8 @@ export default function ProviderDetailsModal({
setSupportedAccelerators([]);
setPreTaskHook('');
setPostTaskHook('');
setPreSetupHook('');
setPostSetupHook('');
setSlurmMode('ssh');
setSlurmSshHost('');
setSlurmSshUser('');
Expand Down Expand Up @@ -441,6 +453,16 @@ export default function ProviderDetailsModal({
} else {
delete parsedConfig.extra_config.post_task_hook;
}
if (preSetupHook.trim()) {
parsedConfig.extra_config.pre_setup_hook = preSetupHook;
} else {
delete parsedConfig.extra_config.pre_setup_hook;
}
if (postSetupHook.trim()) {
parsedConfig.extra_config.post_setup_hook = postSetupHook;
} else {
delete parsedConfig.extra_config.post_setup_hook;
}
if (Object.keys(parsedConfig.extra_config).length === 0) {
delete parsedConfig.extra_config;
}
Expand Down Expand Up @@ -709,36 +731,101 @@ export default function ProviderDetailsModal({
</>
)}

<FormControl sx={{ mt: 2 }}>
<FormLabel>Harness hooks (bash)</FormLabel>
<Typography
level="body-sm"
sx={{ color: 'text.tertiary', mb: 1 }}
<Box
sx={{
mt: 2,
border: '1px solid',
borderColor: 'neutral.outlinedBorder',
borderRadius: 'sm',
}}
>
<Button
variant="plain"
color="neutral"
onClick={() => setHooksExpanded((v) => !v)}
sx={{
width: '100%',
justifyContent: 'space-between',
borderRadius: 'sm',
p: 1,
}}
>
These are concatenated around the task command as: pre ; task
; post
</Typography>
<FormLabel sx={{ mt: 1 }}>Pre hook (optional)</FormLabel>
<Textarea
value={preTaskHook}
onChange={(event) =>
setPreTaskHook(event.currentTarget.value)
}
placeholder="Commands to run before the run script in every task"
minRows={2}
maxRows={6}
/>
<FormLabel sx={{ mt: 1 }}>Post hook (optional)</FormLabel>
<Textarea
value={postTaskHook}
onChange={(event) =>
setPostTaskHook(event.currentTarget.value)
}
placeholder="Commands to run after the run script in every task"
minRows={2}
maxRows={6}
/>
</FormControl>
<Box
sx={{
display: 'flex',
flexDirection: 'column',
alignItems: 'flex-start',
}}
>
<Typography level="title-sm">
Advanced: Harness hooks (optional)
</Typography>
<Typography level="body-xs" sx={{ color: 'text.tertiary' }}>
Add hooks to run before and after the provider setup and
run scripts.
</Typography>
</Box>
{hooksExpanded ? (
<ChevronDownIcon size={18} />
) : (
<ChevronRightIcon size={18} />
)}
</Button>

{hooksExpanded && (
<FormControl sx={{ px: 1, pb: 1 }}>
<FormLabel sx={{ mt: 1 }}>
Setup pre hook (optional)
</FormLabel>
<Textarea
value={preSetupHook}
onChange={(event) =>
setPreSetupHook(event.currentTarget.value)
}
placeholder="Commands to run before the provider setup script"
minRows={2}
maxRows={6}
/>
<FormLabel sx={{ mt: 1 }}>
Setup post hook (optional)
</FormLabel>
<Textarea
value={postSetupHook}
onChange={(event) =>
setPostSetupHook(event.currentTarget.value)
}
placeholder="Commands to run after the provider setup script"
minRows={2}
maxRows={6}
/>

<FormLabel sx={{ mt: 2 }}>
Run pre hook (optional)
</FormLabel>
<Textarea
value={preTaskHook}
onChange={(event) =>
setPreTaskHook(event.currentTarget.value)
}
placeholder="Commands to run before the run script in every task"
minRows={2}
maxRows={6}
/>
<FormLabel sx={{ mt: 1 }}>
Run post hook (optional)
</FormLabel>
<Textarea
value={postTaskHook}
onChange={(event) =>
setPostTaskHook(event.currentTarget.value)
}
placeholder="Commands to run after the run script in every task"
minRows={2}
maxRows={6}
/>
</FormControl>
)}
</Box>

{/* Generic JSON config for non-SLURM providers or advanced editing */}
{type !== 'slurm' && type !== 'local' && (
Expand Down
Loading