@@ -134,30 +134,31 @@ void compute_weights (const amrex::ParticleReal xp,
134134 * \param W 2D array of weights for each neighbouring node
135135 * \param scalar_field Array4 of the nodal scalar field, either full array or tile.
136136 */
137+ template <int DIMS = AMREX_SPACEDIM>
137138AMREX_GPU_HOST_DEVICE AMREX_INLINE
138139amrex::Real interp_field_nodal (int i, int j, int k,
139140 const amrex::Real W[AMREX_SPACEDIM][2 ],
140141 amrex::Array4<const amrex::Real> const & scalar_field) noexcept
141142{
142143 amrex::Real value = 0 ;
143- # if (defined WARPX_DIM_3D)
144- value += scalar_field (i, j , k ) * W[0 ][0 ] * W[1 ][0 ] * W[2 ][0 ];
145- value += scalar_field (i+1 , j , k ) * W[0 ][1 ] * W[1 ][0 ] * W[2 ][0 ];
146- value += scalar_field (i, j+1 , k ) * W[0 ][0 ] * W[1 ][1 ] * W[2 ][0 ];
147- value += scalar_field (i+1 , j+1 , k ) * W[0 ][1 ] * W[1 ][1 ] * W[2 ][0 ];
148- value += scalar_field (i, j , k+1 ) * W[0 ][0 ] * W[1 ][0 ] * W[2 ][1 ];
149- value += scalar_field (i+1 , j , k+1 ) * W[0 ][1 ] * W[1 ][0 ] * W[2 ][1 ];
150- value += scalar_field (i , j+1 , k+1 ) * W[0 ][0 ] * W[1 ][1 ] * W[2 ][1 ];
151- value += scalar_field (i+1 , j+1 , k+1 ) * W[0 ][1 ] * W[1 ][1 ] * W[2 ][1 ];
152- # elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ)
153- value += scalar_field (i, j , k) * W[0 ][0 ] * W[1 ][0 ];
154- value += scalar_field (i+1 , j , k) * W[0 ][1 ] * W[1 ][0 ];
155- value += scalar_field (i, j+1 , k) * W[0 ][0 ] * W[1 ][1 ];
156- value += scalar_field (i+1 , j+1 , k) * W[0 ][1 ] * W[1 ][1 ];
157- # else
158- value += scalar_field (i, j , k) * W[0 ][0 ];
159- value += scalar_field (i+1 , j , k) * W[0 ][1 ];
160- # endif
144+ if constexpr (DIMS == 3 ) {
145+ value += scalar_field (i, j , k ) * W[0 ][0 ] * W[1 ][0 ] * W[2 ][0 ];
146+ value += scalar_field (i+1 , j , k ) * W[0 ][1 ] * W[1 ][0 ] * W[2 ][0 ];
147+ value += scalar_field (i, j+1 , k ) * W[0 ][0 ] * W[1 ][1 ] * W[2 ][0 ];
148+ value += scalar_field (i+1 , j+1 , k ) * W[0 ][1 ] * W[1 ][1 ] * W[2 ][0 ];
149+ value += scalar_field (i, j , k+1 ) * W[0 ][0 ] * W[1 ][0 ] * W[2 ][1 ];
150+ value += scalar_field (i+1 , j , k+1 ) * W[0 ][1 ] * W[1 ][0 ] * W[2 ][1 ];
151+ value += scalar_field (i , j+1 , k+1 ) * W[0 ][0 ] * W[1 ][1 ] * W[2 ][1 ];
152+ value += scalar_field (i+1 , j+1 , k+1 ) * W[0 ][1 ] * W[1 ][1 ] * W[2 ][1 ];
153+ } else if constexpr (DIMS == 2 ) {
154+ value += scalar_field (i, j , k) * W[0 ][0 ] * W[1 ][0 ];
155+ value += scalar_field (i+1 , j , k) * W[0 ][1 ] * W[1 ][0 ];
156+ value += scalar_field (i, j+1 , k) * W[0 ][0 ] * W[1 ][1 ];
157+ value += scalar_field (i+1 , j+1 , k) * W[0 ][1 ] * W[1 ][1 ];
158+ } else {
159+ value += scalar_field (i, j , k) * W[0 ][0 ];
160+ value += scalar_field (i+1 , j , k) * W[0 ][1 ];
161+ }
161162 return value;
162163}
163164
@@ -170,6 +171,7 @@ amrex::Real interp_field_nodal (int i, int j, int k,
170171 * \param dxi inverse 3D cell spacing
171172 * \param lo Index lower bounds of domain.
172173 */
174+ template <int DIMS = AMREX_SPACEDIM>
173175AMREX_GPU_HOST_DEVICE AMREX_INLINE
174176amrex::Real doGatherScalarFieldNodal (const amrex::ParticleReal xp,
175177 const amrex::ParticleReal yp,
@@ -183,7 +185,7 @@ amrex::Real doGatherScalarFieldNodal (const amrex::ParticleReal xp,
183185 amrex::Real W[AMREX_SPACEDIM][2 ];
184186 compute_weights<amrex::IndexType::NODE>(xp, yp, zp, lo, dxi, ii, jj, kk, W);
185187
186- return interp_field_nodal (ii, jj, kk, W, scalar_field);
188+ return interp_field_nodal<DIMS> (ii, jj, kk, W, scalar_field);
187189}
188190
189191/* *
@@ -195,6 +197,7 @@ amrex::Real doGatherScalarFieldNodal (const amrex::ParticleReal xp,
195197 * \param dxi inverse 3D cell spacing
196198 * \param lo Index lower bounds of domain.
197199 */
200+ template <int DIMS = AMREX_SPACEDIM>
198201AMREX_GPU_HOST_DEVICE AMREX_INLINE
199202amrex::GpuArray<amrex::Real, 3 >
200203doGatherVectorFieldNodal (const amrex::ParticleReal xp,
@@ -212,9 +215,9 @@ doGatherVectorFieldNodal (const amrex::ParticleReal xp,
212215 compute_weights<amrex::IndexType::NODE>(xp, yp, zp, lo, dxi, ii, jj, kk, W);
213216
214217 amrex::GpuArray<amrex::Real, 3 > const field_interp = {
215- interp_field_nodal (ii, jj, kk, W, vector_field_x),
216- interp_field_nodal (ii, jj, kk, W, vector_field_y),
217- interp_field_nodal (ii, jj, kk, W, vector_field_z)
218+ interp_field_nodal<DIMS> (ii, jj, kk, W, vector_field_x),
219+ interp_field_nodal<DIMS> (ii, jj, kk, W, vector_field_y),
220+ interp_field_nodal<DIMS> (ii, jj, kk, W, vector_field_z)
218221 };
219222
220223 return field_interp;
0 commit comments