2424#include < ATen/autocast_mode.h>
2525
2626/* *
27- * ******************* ImageToTensor ***********************
27+ * ImageToTensor: Converts ITK image to torch tensor with spatial resampling
28+ *
29+ * @tparam TImage ITK image type with GetLargestPossibleRegion(), GetSpacing(), GetOrigin()
30+ * @tparam TInterpolator Interpolator type supporting Evaluate(point)
31+ * @param transformPoint Optional transformation function for each sampled point
32+ * @returns torch::Tensor with shape (D,H,W) for 3D or (H,W) for 2D
2833 */
2934namespace ImpactTensorUtils
3035{
@@ -74,7 +79,12 @@ ImageToTensor(typename TImage::ConstPointer
7479} // end ImageToTensor
7580
7681/* *
77- * ******************* TensorToImage ***********************
82+ * TensorToImage: Maps torch tensor back to ITK vector image
83+ *
84+ * @tparam TImage Reference image type for geometry info
85+ * @tparam TFeatureImage Output vector image type (typically itk::VectorImage)
86+ * @param layers Input tensor shape (C,D,H,W) or (C,H,W)
87+ * @returns ITK image preserving spatial properties with C-channel vectors
7888 */
7989template <typename TImage, typename TFeatureImage>
8090typename TFeatureImage::Pointer
@@ -140,7 +150,15 @@ TensorToImage(typename TImage::ConstPointer image, torch::Tensor layers)
140150} // end TensorToImage
141151
142152/* *
143- * ******************* generateCartesianProduct ***********************
153+ * generateCartesianProduct: Computes n-dimensional Cartesian product of index sets
154+ *
155+ * @param startIndex Vector of 1D index arrays to combine
156+ * @param current Working array for current combination
157+ * @param depth Current recursion depth
158+ * @param result Output vector storing all index combinations
159+ *
160+ * Recursively builds all possible combinations by selecting one value from each input array
161+ * For N input arrays of lengths L1,L2,...,LN, generates L1*L2*...*LN total combinations
144162 */
145163inline void
146164generateCartesianProduct (const std::vector<std::vector<int >> & startIndex,
@@ -163,7 +181,15 @@ generateCartesianProduct(const std::vector<std::vector<int>> & startIndex,
163181} // end generateCartesianProduct
164182
165183/* *
166- * ******************* getPatch ***********************
184+ * getPatch: Extracts and pads image patch from input tensor
185+ *
186+ * @param slice Starting coordinates for patch extraction
187+ * @param patchSize Desired output patch dimensions
188+ * @param input Source tensor to extract patch from
189+ * @returns Tensor containing extracted and padded patch
190+ *
191+ * Handles both 2D and 3D input tensors
192+ * Zero-pads extracted patch if smaller than patchSize
167193 */
168194inline torch::Tensor
169195getPatch (std::vector<int > slice, std::vector<int64_t > patchSize, torch::Tensor input)
@@ -191,7 +217,13 @@ getPatch(std::vector<int> slice, std::vector<int64_t> patchSize, torch::Tensor i
191217} // end getPatch
192218
193219/* *
194- * ******************* pcaFit ***********************
220+ * pcaFit: Computes top principal components for dimensionality reduction
221+ *
222+ * @param input Tensor of shape (C,N) where C=channels, N=flattened spatial dims
223+ * @param new_C Target number of components to keep
224+ * @returns Principal component matrix shape (C,new_C) in descending eigenvalue order
225+ *
226+ * Note: Centers data and uses SVD for numerical stability
195227 */
196228inline torch::Tensor
197229pcaFit (torch::Tensor input, int new_C)
@@ -213,7 +245,17 @@ pcaFit(torch::Tensor input, int new_C)
213245} // end pcaFit
214246
215247/* *
216- * ******************* pca_transform ***********************
248+ * pcaTransform: Project data onto principal component basis
249+ *
250+ * @param input Tensor of shape (C,D,H,W) or (C,H,W) to transform
251+ * @param principalComponents Principal component matrix from pcaFit
252+ * @returns Transformed tensor with reduced channels, preserving spatial dims
253+ *
254+ * Implementation:
255+ * 1. Reshapes input to (C,N) where N=prod(spatial_dims)
256+ * 2. Centers data by channel mean
257+ * 3. Projects using principal component matrix
258+ * 4. Reshapes back to original spatial dimensions
217259 */
218260inline torch::Tensor
219261pcaTransform (torch::Tensor input, torch::Tensor principalComponents)
@@ -229,7 +271,30 @@ pcaTransform(torch::Tensor input, torch::Tensor principalComponents)
229271} // end pcaTransform
230272
231273/* *
232- * ******************* GetFeaturesMaps ***********************
274+ * GetFeaturesMaps: Extract deep features from image using configured models
275+ *
276+ * @tparam TImage Input image type
277+ * @tparam FeaturesMaps Output feature maps container type
278+ * @tparam InterpolatorType Interpolator for image sampling
279+ * @tparam FeaturesImageType Output feature image type
280+ * @param image Input image to extract features from
281+ * @param interpolator Interpolator instance for sampling
282+ * @param modelsConfiguration List of deep model configurations
283+ * @param device Computation device (CPU/GPU)
284+ * @param pca Dimensions for PCA reduction per layer
285+ * @param principalComponents PCA matrices for dimensionality reduction
286+ * @param writeInputImage Optional callback to save input patches
287+ * @param transformPoint Optional point transformation
288+ * @returns Vector of feature maps, one per selected model layer
289+ *
290+ * Processing workflow:
291+ * 1. Converts input to tensor with proper spacing
292+ * 2. Extracts patches according to model config
293+ * 3. Processes patches through models
294+ * 4. Optionally reduces dimensionality via PCA
295+ * 5. Converts results back to ITK images
296+ *
297+ * Handles both 2D and 3D inputs with proper dimension management
233298 */
234299template <typename TImage, typename FeaturesMaps, typename InterpolatorType, typename FeaturesImageType>
235300std::vector<FeaturesMaps>
@@ -462,14 +527,33 @@ GetFeaturesMaps(
462527
463528
464529/* *
465- * **************** GetModelOutputsExample ****************
530+ * GetModelOutputsExample: Validate model configurations with dummy inputs
531+ *
532+ * @param modelsConfig Vector of model configurations to validate
533+ * @param modelType String identifier for error reporting
534+ * @param device Computation device (CPU/GPU)
535+ * @returns Vector of example output tensors from each model
536+ *
537+ * Validation steps:
538+ * 1. Creates zero-filled dummy patches matching config specs
539+ * 2. Runs patches through models to verify layer structure
540+ * 3. Verifies layer mask compatibility
541+ * 4. Computes center indices for feature extraction
542+ *
543+ * Error handling:
544+ * - Validates dimension/channel compatibility
545+ * - Checks layer mask alignment with outputs
546+ * - Reports detailed configuration issues
547+ *
548+ * Note: Uses no_grad mode for efficiency
466549 */
467550inline std::vector<torch::Tensor>
468551GetModelOutputsExample (std::vector<itk::ImpactModelConfiguration> & modelsConfig,
469552 const std::string & modelType,
470553 torch::Device device)
471554{
472555
556+ // For each model, create dummy patch and get output layers to check structure
473557 std::vector<torch::Tensor> outputsTensor;
474558 {
475559 torch::NoGradGuard noGrad;
@@ -544,7 +628,14 @@ GetModelOutputsExample(std::vector<itk::ImpactModelConfiguration> & modelsConfig
544628} // end GetModelOutputsExample
545629
546630/* *
547- * ******************* GetPatchIndex ***********************
631+ * GetPatchIndex: Generates sampling grid for patch extraction
632+ *
633+ * @param modelConfiguration Model-specific patch configuration
634+ * @param randomGenerator RNG for stochastic patch orientation
635+ * @param dimension Target space dimension
636+ * @returns List of sampling coordinates per patch point
637+ *
638+ * For 2D patches: Applies random rotation to sampling
548639 */
549640inline std::vector<std::vector<float >>
550641GetPatchIndex (const itk::ImpactModelConfiguration & modelConfiguration,
@@ -612,7 +703,15 @@ GetPatchIndex(const itk::ImpactModelConfiguration & modelConfiguration,
612703} // end GetPatchIndex
613704
614705/* *
615- * ******************* GenerateOutputs ***********************
706+ * GenerateOutputs: Batch processing of image patches through deep models
707+ *
708+ * @param modelConfig Configuration per model including architecture and layers
709+ * @param fixedPoints List of control points to extract patches around
710+ * @param patchIndex Pre-computed sampling indices for each patch
711+ * @param subsetsOfFeatures Selected feature channels per layer
712+ * @returns List of output tensors, one per selected layer per model
713+ *
714+ * Performance note: Processes patches in batches to maximize GPU utilization
616715 */
617716template <typename ImagePointType>
618717std::vector<torch::Tensor>
@@ -675,7 +774,20 @@ GenerateOutputs(const std::vector<itk::ImpactModelConfiguration> &
675774} // end GenerateOutputs
676775
677776/* *
678- * ******************* GenerateOutputsAndJacobian ***********************
777+ * GenerateOutputsAndJacobian: Computes model outputs and their Jacobians
778+ *
779+ * @param modelConfig List of model configurations with architectures
780+ * @param fixedPoints Control points for patch extraction
781+ * @param patchIndex Sampling grid coordinates per patch
782+ * @param subsetsOfFeatures Feature channel selections
783+ * @param fixedOutputsTensor Reference outputs for loss calculation
784+ * @param device Computation device (CPU/GPU)
785+ * @param losses Loss function objects per output layer
786+ * @param imagesPatchValuesAndJacobiansEvaluator Callback for patch and gradient evaluation
787+ * @returns List of Jacobian tensors for each model output
788+ *
789+ * Performance note: Batches computation and uses autograd for efficiency
790+ * Handles multiple models and multiple output layers per model
679791 */
680792template <typename ImagePointType>
681793std::vector<torch::Tensor>
0 commit comments