keras/src/backend/jax/distribution_lib.py (52 additions, 0 deletions)

@@ -3,7 +3,10 @@
import jax
import numpy as np

from keras.src.backend.common import global_state
from keras.src.random import seed_generator
from keras.src.utils import jax_utils
from keras.src.utils import rng_utils


def list_devices(device_type=None):
@@ -185,6 +188,52 @@ def distribute_data_input(per_process_batch, layout, batch_dim_name):
return global_batch_array


def initialize_rng():
"""Initializes the global random number generator across processes.

This is required for consistent initialization in multi-host settings.
"""
global_seed = rng_utils.get_random_seed()
    # Only set a random seed if one has not already been set
    # via keras.utils.set_random_seed().
if global_seed is None:
# Generate a random seed on each CPU host and psum them to get a single
# consistent seed across all processes.
cpu_devices = jax.devices("cpu")
num_local_cpu_devices = jax.local_device_count("cpu")
        # The seed must lie in [0, 2^32 - 1]; uint32 keeps it in range and
        # avoids signed integer overflow during the sum.
local_seed = jax.numpy.asarray(
[seed_generator.make_default_seed()] * num_local_cpu_devices,
dtype=jax.numpy.uint32,
)
# Sum across processes and pull out the first item.
global_seed = jax.pmap(
lambda x: jax.lax.psum(x, "all"),
axis_name="all",
devices=cpu_devices,
)(local_seed).item(0)
# Set the global seed.
rng_utils.set_random_seed(global_seed)

    # If the global seed generator is already set but was created without an
    # explicit seed, rebuild it so that it uses the shared global seed.
global_seed_generator = global_state.get_global_attribute(
"global_seed_generator"
)
if global_seed_generator is not None:
seed = global_seed_generator.get_config()["seed"]
if seed is None:
global_state.set_global_attribute(
"global_seed_generator",
seed_generator.SeedGenerator(
seed=global_seed,
name=global_seed_generator.name,
backend=global_seed_generator.backend,
),
)
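The `pmap`/`psum` combination above is what guarantees agreement: every process contributes a random `uint32` (replicated over its local CPU devices), and the collective sum over all devices is, by construction, identical everywhere. Below is a minimal standalone sketch of the same pattern; the helper name `agree_on_seed` is ours, it assumes multi-process JAX has already been set up via `jax.distributed.initialize`, and on a single host it degenerates to a sum over local CPU devices.

```python
import jax
import jax.numpy as jnp
import numpy as np


def agree_on_seed():
    """Return a uint32 seed that is identical on every participating process."""
    cpu_devices = jax.devices("cpu")  # global device list in multi-process JAX
    num_local = jax.local_device_count("cpu")
    # One independent draw per process, replicated across its local CPU
    # devices so pmap sees exactly one shard per local device.
    local_seed = jnp.asarray(
        [np.random.randint(0, 2**32, dtype=np.uint32)] * num_local,
        dtype=jnp.uint32,
    )
    # psum over the global "all" axis returns the same total everywhere;
    # uint32 arithmetic simply wraps on overflow, which is fine for a seed.
    summed = jax.pmap(
        lambda x: jax.lax.psum(x, "all"),
        axis_name="all",
        devices=cpu_devices,
    )(local_seed)
    return int(summed[0])
```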


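The second half of the function handles a subtler case: a global `SeedGenerator` created before initialization with `seed=None` would keep drawing from process-local entropy even after the processes agree on a seed, so it is rebuilt in place. A hedged sketch of the observable effect, relying on the fact that `keras.random` ops with no explicit seed draw from this global generator:

```python
import keras

# With no explicit seed argument, keras.random ops fall back to the global
# SeedGenerator. Once initialize_rng() has rebuilt it with the shared seed,
# draws like this produce the same values on every host:
x = keras.random.normal((2, 2))
```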
def initialize(job_addresses, num_processes, process_id):
if job_addresses and "," in job_addresses:
        # When the user provides all the job addresses, we will split and get the
@@ -208,6 +257,9 @@ def initialize(job_addresses, num_processes, process_id):
process_id=process_id,
)

# Ensure the random number generator is initialized across processes.
initialize_rng()
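For context, this code is reached through the public entry point on the JAX backend. A sketch of per-process startup; the addresses, process count, and ids below are illustrative:

```python
import keras

# Run on every process, each with its own process_id.
keras.distribution.initialize(
    job_addresses="10.0.0.1:1234,10.0.0.2:1234",  # illustrative addresses
    num_processes=2,
    process_id=0,  # 1 on the second process
)
# initialize() now also synchronizes the global RNG seed across processes,
# so subsequent weight initialization matches on every host.
```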


def num_processes():
"""Return the number of processes for the current distribution setting."""
keras/src/utils/rng_utils.py (15 additions, 0 deletions)

@@ -4,8 +4,11 @@

from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.backend.common import global_state
from keras.src.utils.module_utils import tensorflow as tf

GLOBAL_RANDOM_SEED = "global_random_seed"


@keras_export("keras.utils.set_random_seed")
def set_random_seed(seed):
@@ -46,6 +49,9 @@ def set_random_seed(seed):
"Expected `seed` argument to be an integer. "
f"Received: seed={seed} (of type {type(seed)})"
)

    # Store the seed in global state so it can later be queried
    # via get_random_seed().
global_state.set_global_attribute(GLOBAL_RANDOM_SEED, seed)
random.seed(seed)
np.random.seed(seed)
if tf.available:
@@ -54,3 +60,12 @@ def set_random_seed(seed):
import torch

torch.manual_seed(seed)


def get_random_seed():
"""Returns the explicit integer random seed if set.

If the seed has been explicitly set via `set_random_seed`, then
returns the seed. Otherwise, returns `None`.
"""
return global_state.get_global_attribute(GLOBAL_RANDOM_SEED)
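A quick round trip of the new helper (a usage sketch against the internal module patched above; the values shown assume a fresh session):

```python
from keras.src.utils import rng_utils

print(rng_utils.get_random_seed())  # None: no seed has been set yet
rng_utils.set_random_seed(1337)     # seeds random, NumPy, and TF/torch if present
print(rng_utils.get_random_seed())  # 1337, read back from global state
```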