Skip to content

Commit 0aec8ca

Browse files
authored
[tensoflow] Add truncated normal operation (deepjavalibrary#1005)
1 parent d8e7e1d commit 0aec8ca

File tree

3 files changed

+90
-0
lines changed

3 files changed

+90
-0
lines changed

api/src/main/java/ai/djl/ndarray/BaseNDManager.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,12 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy
153153
throw new UnsupportedOperationException("Not supported!");
154154
}
155155

156+
/** {@inheritDoc} */
157+
@Override
158+
public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) {
159+
throw new UnsupportedOperationException("Not supported!");
160+
}
161+
156162
/** {@inheritDoc} */
157163
@Override
158164
public NDArray randomMultinomial(int n, NDArray pValues) {

api/src/main/java/ai/djl/ndarray/NDManager.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,65 @@ default NDArray randomNormal(
12321232
return newSubManager(device).randomNormal(loc, scale, shape, dataType);
12331233
}
12341234

1235+
/**
1236+
* Draws random samples from a normal (Gaussian) distribution with mean 0 and standard deviation
1237+
* 1, discarding and re-drawing any samples that are more than two standard deviations from the
1238+
* mean.
1239+
*
1240+
* <p>Samples are distributed according to a normal distribution parametrized by mean = 0 and
1241+
* standard deviation = 1.
1242+
*
1243+
* @param shape the output {@link Shape}
1244+
* @return the drawn samples {@link NDArray}
1245+
*/
1246+
default NDArray truncatedNormal(Shape shape) {
1247+
return truncatedNormal(0f, 1f, shape, DataType.FLOAT32);
1248+
}
1249+
1250+
/**
1251+
* Draws random samples from a normal (Gaussian) distribution with mean 0 and standard deviation
1252+
* 1, discarding and re-drawing any samples that are more than two standard deviations from the
1253+
* mean.
1254+
*
1255+
* @param shape the output {@link Shape}
1256+
* @param dataType the {@link DataType} of the {@link NDArray}
1257+
* @return the drawn samples {@link NDArray}
1258+
*/
1259+
default NDArray truncatedNormal(Shape shape, DataType dataType) {
1260+
return truncatedNormal(0.0f, 1.0f, shape, dataType);
1261+
}
1262+
1263+
/**
1264+
* Draws random samples from a normal (Gaussian) distribution, discarding and re-drawing any
1265+
* samples that are more than two standard deviations from the mean.
1266+
*
1267+
* @param loc the mean (centre) of the distribution
1268+
* @param scale the standard deviation (spread or "width") of the distribution
1269+
* @param shape the output {@link Shape}
1270+
* @param dataType the {@link DataType} of the {@link NDArray}
1271+
* @return the drawn samples {@link NDArray}
1272+
*/
1273+
NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType);
1274+
1275+
/**
1276+
* Draws random samples from a normal (Gaussian) distribution, discarding and re-drawing any
1277+
* samples that are more than two standard deviations from the mean.
1278+
*
1279+
* @param loc the mean (centre) of the distribution
1280+
* @param scale the standard deviation (spread or "width") of the distribution
1281+
* @param shape the output {@link Shape}
1282+
* @param dataType the {@link DataType} of the {@link NDArray}
1283+
* @param device the {@link Device} of the {@link NDArray}
1284+
* @return the drawn samples {@link NDArray}
1285+
*/
1286+
default NDArray truncatedNormal(
1287+
float loc, float scale, Shape shape, DataType dataType, Device device) {
1288+
if (device == null || device.equals(getDevice())) {
1289+
return truncatedNormal(loc, scale, shape, dataType);
1290+
}
1291+
return newSubManager(device).truncatedNormal(loc, scale, shape, dataType);
1292+
}
1293+
12351294
/**
12361295
* Draw samples from a multinomial distribution.
12371296
*

tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,31 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy
242242
}
243243
}
244244

245+
/** {@inheritDoc} */
246+
@Override
247+
public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) {
248+
if (DataType.STRING.equals(dataType)) {
249+
throw new IllegalArgumentException("String data type is not supported!");
250+
}
251+
NDArray axes = create(shape.getShape());
252+
TfOpExecutor opBuilder =
253+
opExecutor("TruncatedNormal").addInput(axes).addParam("dtype", dataType);
254+
Integer seed = getEngine().getSeed();
255+
if (seed != null) {
256+
// seed1 is graph-level seed
257+
// set it to default graph seed used by tensorflow
258+
// https://github.com/tensorflow/tensorflow/blob/85c8b2a817f95a3e979ecd1ed95bff1dc1335cff/tensorflow/python/framework/random_seed.py#L31
259+
opBuilder.addParam("seed", 87654321);
260+
opBuilder.addParam("seed2", seed);
261+
}
262+
try (NDArray array = opBuilder.buildSingletonOrThrow();
263+
NDArray temp = array.mul(scale)) {
264+
return temp.add(loc);
265+
} finally {
266+
axes.close();
267+
}
268+
}
269+
245270
/** {@inheritDoc} */
246271
@Override
247272
public TfNDManager newSubManager(Device device) {

0 commit comments

Comments
 (0)