@@ -1232,6 +1232,65 @@ default NDArray randomNormal(
12321232 return newSubManager (device ).randomNormal (loc , scale , shape , dataType );
12331233 }
12341234
1235+ /**
1236+ * Draws random samples from a normal (Gaussian) distribution with mean 0 and standard deviation
1237+ * 1, discarding and re-drawing any samples that are more than two standard deviations from the
1238+ * mean.
1239+ *
1240+ * <p>Samples are distributed according to a normal distribution parametrized by mean = 0 and
1241+ * standard deviation = 1.
1242+ *
1243+ * @param shape the output {@link Shape}
1244+ * @return the drawn samples {@link NDArray}
1245+ */
1246+ default NDArray truncatedNormal (Shape shape ) {
1247+ return truncatedNormal (0f , 1f , shape , DataType .FLOAT32 );
1248+ }
1249+
1250+ /**
1251+ * Draws random samples from a normal (Gaussian) distribution with mean 0 and standard deviation
1252+ * 1, discarding and re-drawing any samples that are more than two standard deviations from the
1253+ * mean.
1254+ *
1255+ * @param shape the output {@link Shape}
1256+ * @param dataType the {@link DataType} of the {@link NDArray}
1257+ * @return the drawn samples {@link NDArray}
1258+ */
1259+ default NDArray truncatedNormal (Shape shape , DataType dataType ) {
1260+ return truncatedNormal (0.0f , 1.0f , shape , dataType );
1261+ }
1262+
1263+ /**
1264+ * Draws random samples from a normal (Gaussian) distribution, discarding and re-drawing any
1265+ * samples that are more than two standard deviations from the mean.
1266+ *
1267+ * @param loc the mean (centre) of the distribution
1268+ * @param scale the standard deviation (spread or "width") of the distribution
1269+ * @param shape the output {@link Shape}
1270+ * @param dataType the {@link DataType} of the {@link NDArray}
1271+ * @return the drawn samples {@link NDArray}
1272+ */
1273+ NDArray truncatedNormal (float loc , float scale , Shape shape , DataType dataType );
1274+
1275+ /**
1276+ * Draws random samples from a normal (Gaussian) distribution, discarding and re-drawing any
1277+ * samples that are more than two standard deviations from the mean.
1278+ *
1279+ * @param loc the mean (centre) of the distribution
1280+ * @param scale the standard deviation (spread or "width") of the distribution
1281+ * @param shape the output {@link Shape}
1282+ * @param dataType the {@link DataType} of the {@link NDArray}
1283+ * @param device the {@link Device} of the {@link NDArray}
1284+ * @return the drawn samples {@link NDArray}
1285+ */
1286+ default NDArray truncatedNormal (
1287+ float loc , float scale , Shape shape , DataType dataType , Device device ) {
1288+ if (device == null || device .equals (getDevice ())) {
1289+ return truncatedNormal (loc , scale , shape , dataType );
1290+ }
1291+ return newSubManager (device ).truncatedNormal (loc , scale , shape , dataType );
1292+ }
1293+
12351294 /**
12361295 * Draw samples from a multinomial distribution.
12371296 *
0 commit comments