Skip to content

Commit f9e1a9c

Browse files
author
Valentin Boussot
committed
ENH: Enable LibTorch build caching in
1 parent 6f1ff0b commit f9e1a9c

File tree

2 files changed

+136
-16
lines changed

2 files changed

+136
-16
lines changed

.github/actions/build_libtorch/action.yml

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,28 @@ outputs:
4646
runs:
4747
using: "composite"
4848
steps:
49+
- name: Cache libtorch install
50+
id: cache-libtorch
51+
uses: actions/cache@v4
52+
with:
53+
path: ${{ inputs.install_prefix }}
54+
key: libtorch-install-${{ runner.os }}-${{ inputs.pytorch_ref }}-${{ inputs.build_type }}-cuda${{ inputs.use_cuda }}-mps${{ inputs.use_mps }}
55+
4956
- name: Setup Python
57+
if: ${{ steps['cache-libtorch'].outputs['cache-hit'] != 'true' }}
5058
uses: actions/setup-python@v5
5159
with:
5260
python-version: ${{ inputs.python_version }}
5361

5462
- name: Install Python tooling
63+
if: ${{ steps['cache-libtorch'].outputs['cache-hit'] != 'true' }}
5564
shell: bash
5665
run: |
5766
python -m pip install --upgrade pip
5867
python -m pip install ninja
5968
6069
- name: Clone PyTorch
70+
if: ${{ steps['cache-libtorch'].outputs['cache-hit'] != 'true' }}
6171
shell: bash
6272
run: |
6373
git clone ${{ inputs.pytorch_repo }} "${{ github.workspace }}/pytorch"
@@ -67,18 +77,15 @@ runs:
6777
git submodule update --init --recursive
6878
6979
- name: Install PyTorch requirements (build only)
80+
if: ${{ steps['cache-libtorch'].outputs['cache-hit'] != 'true' }}
7081
shell: bash
7182
working-directory: ${{ github.workspace }}/pytorch
7283
run: |
7384
python -m pip install -r requirements.txt
7485
75-
- name: Cache libtorch install
76-
uses: actions/cache@v4
77-
with:
78-
path: ${{ inputs.install_prefix }}
79-
key: libtorch-install-${{ runner.os }}-${{ inputs.pytorch_ref }}-${{ inputs.build_type }}-cuda${{ inputs.use_cuda }}-mps${{ inputs.use_mps }}
8086
8187
- name: Configure (CMake)
88+
if: ${{ steps['cache-libtorch'].outputs['cache-hit'] != 'true' }}
8289
shell: bash
8390
run: |
8491
cmake -S "${{ github.workspace }}/pytorch" \
@@ -97,6 +104,7 @@ runs:
97104
${{ inputs.extra_cmake_flags }}
98105
99106
- name: Build & Install
107+
if: ${{ steps['cache-libtorch'].outputs['cache-hit'] != 'true' }}
100108
shell: bash
101109
run: |
102110
ninja -C "${{ inputs.build_dir }}" install

Components/Metrics/Impact/ImpactTensorUtils.hxx

Lines changed: 123 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424
#include <ATen/autocast_mode.h>
2525

2626
/**
27-
* ******************* ImageToTensor ***********************
27+
* ImageToTensor: Converts ITK image to torch tensor with spatial resampling
28+
*
29+
* @tparam TImage ITK image type with GetLargestPossibleRegion(), GetSpacing(), GetOrigin()
30+
* @tparam TInterpolator Interpolator type supporting Evaluate(point)
31+
* @param transformPoint Optional transformation function for each sampled point
32+
* @returns torch::Tensor with shape (D,H,W) for 3D or (H,W) for 2D
2833
*/
2934
namespace ImpactTensorUtils
3035
{
@@ -74,7 +79,12 @@ ImageToTensor(typename TImage::ConstPointer
7479
} // end ImageToTensor
7580

7681
/**
77-
* ******************* TensorToImage ***********************
82+
* TensorToImage: Maps torch tensor back to ITK vector image
83+
*
84+
* @tparam TImage Reference image type for geometry info
85+
* @tparam TFeatureImage Output vector image type (typically itk::VectorImage)
86+
* @param layers Input tensor shape (C,D,H,W) or (C,H,W)
87+
* @returns ITK image preserving spatial properties with C-channel vectors
7888
*/
7989
template <typename TImage, typename TFeatureImage>
8090
typename TFeatureImage::Pointer
@@ -140,7 +150,15 @@ TensorToImage(typename TImage::ConstPointer image, torch::Tensor layers)
140150
} // end TensorToImage
141151

142152
/**
143-
* ******************* generateCartesianProduct ***********************
153+
* generateCartesianProduct: Computes n-dimensional Cartesian product of index sets
154+
*
155+
* @param startIndex Vector of 1D index arrays to combine
156+
* @param current Working array for current combination
157+
* @param depth Current recursion depth
158+
* @param result Output vector storing all index combinations
159+
*
160+
* Recursively builds all possible combinations by selecting one value from each input array
161+
* For N input arrays of lengths L1,L2,...,LN, generates L1*L2*...*LN total combinations
144162
*/
145163
inline void
146164
generateCartesianProduct(const std::vector<std::vector<int>> & startIndex,
@@ -163,7 +181,15 @@ generateCartesianProduct(const std::vector<std::vector<int>> & startIndex,
163181
} // end generateCartesianProduct
164182

165183
/**
166-
* ******************* getPatch ***********************
184+
* getPatch: Extracts and pads image patch from input tensor
185+
*
186+
* @param slice Starting coordinates for patch extraction
187+
* @param patchSize Desired output patch dimensions
188+
* @param input Source tensor to extract patch from
189+
* @returns Tensor containing extracted and padded patch
190+
*
191+
* Handles both 2D and 3D input tensors
192+
* Zero-pads extracted patch if smaller than patchSize
167193
*/
168194
inline torch::Tensor
169195
getPatch(std::vector<int> slice, std::vector<int64_t> patchSize, torch::Tensor input)
@@ -191,7 +217,13 @@ getPatch(std::vector<int> slice, std::vector<int64_t> patchSize, torch::Tensor i
191217
} // end getPatch
192218

193219
/**
194-
* ******************* pcaFit ***********************
220+
* pcaFit: Computes top principal components for dimensionality reduction
221+
*
222+
* @param input Tensor of shape (C,N) where C=channels, N=flattened spatial dims
223+
* @param new_C Target number of components to keep
224+
* @returns Principal component matrix shape (C,new_C) in descending eigenvalue order
225+
*
226+
* Note: Centers data and uses SVD for numerical stability
195227
*/
196228
inline torch::Tensor
197229
pcaFit(torch::Tensor input, int new_C)
@@ -213,7 +245,17 @@ pcaFit(torch::Tensor input, int new_C)
213245
} // end pcaFit
214246

215247
/**
216-
* ******************* pca_transform ***********************
248+
* pcaTransform: Project data onto principal component basis
249+
*
250+
* @param input Tensor of shape (C,D,H,W) or (C,H,W) to transform
251+
* @param principalComponents Principal component matrix from pcaFit
252+
* @returns Transformed tensor with reduced channels, preserving spatial dims
253+
*
254+
* Implementation:
255+
* 1. Reshapes input to (C,N) where N=prod(spatial_dims)
256+
* 2. Centers data by channel mean
257+
* 3. Projects using principal component matrix
258+
* 4. Reshapes back to original spatial dimensions
217259
*/
218260
inline torch::Tensor
219261
pcaTransform(torch::Tensor input, torch::Tensor principalComponents)
@@ -229,7 +271,30 @@ pcaTransform(torch::Tensor input, torch::Tensor principalComponents)
229271
} // end pcaTransform
230272

231273
/**
232-
* ******************* GetFeaturesMaps ***********************
274+
* GetFeaturesMaps: Extract deep features from image using configured models
275+
*
276+
* @tparam TImage Input image type
277+
* @tparam FeaturesMaps Output feature maps container type
278+
* @tparam InterpolatorType Interpolator for image sampling
279+
* @tparam FeaturesImageType Output feature image type
280+
* @param image Input image to extract features from
281+
* @param interpolator Interpolator instance for sampling
282+
* @param modelsConfiguration List of deep model configurations
283+
* @param device Computation device (CPU/GPU)
284+
* @param pca Dimensions for PCA reduction per layer
285+
* @param principalComponents PCA matrices for dimensionality reduction
286+
* @param writeInputImage Optional callback to save input patches
287+
* @param transformPoint Optional point transformation
288+
* @returns Vector of feature maps, one per selected model layer
289+
*
290+
* Processing workflow:
291+
* 1. Converts input to tensor with proper spacing
292+
* 2. Extracts patches according to model config
293+
* 3. Processes patches through models
294+
* 4. Optionally reduces dimensionality via PCA
295+
* 5. Converts results back to ITK images
296+
*
297+
* Handles both 2D and 3D inputs with proper dimension management
233298
*/
234299
template <typename TImage, typename FeaturesMaps, typename InterpolatorType, typename FeaturesImageType>
235300
std::vector<FeaturesMaps>
@@ -462,14 +527,33 @@ GetFeaturesMaps(
462527

463528

464529
/**
465-
* **************** GetModelOutputsExample ****************
530+
* GetModelOutputsExample: Validate model configurations with dummy inputs
531+
*
532+
* @param modelsConfig Vector of model configurations to validate
533+
* @param modelType String identifier for error reporting
534+
* @param device Computation device (CPU/GPU)
535+
* @returns Vector of example output tensors from each model
536+
*
537+
* Validation steps:
538+
* 1. Creates zero-filled dummy patches matching config specs
539+
* 2. Runs patches through models to verify layer structure
540+
* 3. Verifies layer mask compatibility
541+
* 4. Computes center indices for feature extraction
542+
*
543+
* Error handling:
544+
* - Validates dimension/channel compatibility
545+
* - Checks layer mask alignment with outputs
546+
* - Reports detailed configuration issues
547+
*
548+
* Note: Uses no_grad mode for efficiency
466549
*/
467550
inline std::vector<torch::Tensor>
468551
GetModelOutputsExample(std::vector<itk::ImpactModelConfiguration> & modelsConfig,
469552
const std::string & modelType,
470553
torch::Device device)
471554
{
472555

556+
// For each model, create dummy patch and get output layers to check structure
473557
std::vector<torch::Tensor> outputsTensor;
474558
{
475559
torch::NoGradGuard noGrad;
@@ -544,7 +628,14 @@ GetModelOutputsExample(std::vector<itk::ImpactModelConfiguration> & modelsConfig
544628
} // end GetModelOutputsExample
545629

546630
/**
547-
* ******************* GetPatchIndex ***********************
631+
* GetPatchIndex: Generates sampling grid for patch extraction
632+
*
633+
* @param modelConfiguration Model-specific patch configuration
634+
* @param randomGenerator RNG for stochastic patch orientation
635+
* @param dimension Target space dimension
636+
* @returns List of sampling coordinates per patch point
637+
*
638+
* For 2D patches: Applies random rotation to sampling
548639
*/
549640
inline std::vector<std::vector<float>>
550641
GetPatchIndex(const itk::ImpactModelConfiguration & modelConfiguration,
@@ -612,7 +703,15 @@ GetPatchIndex(const itk::ImpactModelConfiguration & modelConfiguration,
612703
} // end GetPatchIndex
613704

614705
/**
615-
* ******************* GenerateOutputs ***********************
706+
* GenerateOutputs: Batch processing of image patches through deep models
707+
*
708+
* @param modelConfig Configuration per model including architecture and layers
709+
* @param fixedPoints List of control points to extract patches around
710+
* @param patchIndex Pre-computed sampling indices for each patch
711+
* @param subsetsOfFeatures Selected feature channels per layer
712+
* @returns List of output tensors, one per selected layer per model
713+
*
714+
* Performance note: Processes patches in batches to maximize GPU utilization
616715
*/
617716
template <typename ImagePointType>
618717
std::vector<torch::Tensor>
@@ -675,7 +774,20 @@ GenerateOutputs(const std::vector<itk::ImpactModelConfiguration> &
675774
} // end GenerateOutputs
676775

677776
/**
678-
* ******************* GenerateOutputsAndJacobian ***********************
777+
* GenerateOutputsAndJacobian: Computes model outputs and their Jacobians
778+
*
779+
* @param modelConfig List of model configurations with architectures
780+
* @param fixedPoints Control points for patch extraction
781+
* @param patchIndex Sampling grid coordinates per patch
782+
* @param subsetsOfFeatures Feature channel selections
783+
* @param fixedOutputsTensor Reference outputs for loss calculation
784+
* @param device Computation device (CPU/GPU)
785+
* @param losses Loss function objects per output layer
786+
* @param imagesPatchValuesAndJacobiansEvaluator Callback for patch and gradient evaluation
787+
* @returns List of Jacobian tensors for each model output
788+
*
789+
* Performance note: Batches computation and uses autograd for efficiency
790+
* Handles multiple models and multiple output layers per model
679791
*/
680792
template <typename ImagePointType>
681793
std::vector<torch::Tensor>

0 commit comments

Comments
 (0)