Skip to content

Commit 9fd9c93

Browse files
TensorFlow Federated Teamcopybara-github
authored andcommitted
Add integration test that exercises a tensorflow dataset.
PiperOrigin-RevId: 773442984
1 parent cf64c1c commit 9fd9c93

3 files changed

Lines changed: 46 additions & 0 deletions

File tree

tensorflow_federated/python/core/impl/executor_stacks/cpp_executor_factory.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ def local_cpp_executor_factory(
122122
] = None,
123123
) -> federated_language.framework.ExecutorFactory:
124124
"""Local ExecutorFactory backed by C++ Executor bindings."""
125+
126+
print(
127+
"dalyk in cpp_executor_factory local_cpp_executor_factory, meaning we're"
128+
' using pybind'
129+
)
130+
125131
_check_num_clients_is_valid(default_num_clients)
126132

127133
def _executor_fn(

tensorflow_federated/python/core/impl/executor_stacks/executor_factory.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ def local_cpp_executor_factory(
6060
RuntimeError: If an internal C++ worker binary can not be found.
6161
"""
6262

63+
print(
64+
"dalyk in executor_factory local_cpp_executor_factory, meaning we're"
65+
' using grpc'
66+
)
67+
6368
# This path is specified relative to this file because the relative location
6469
# of the worker binary will remain the same when this function is executed
6570
# from the Python package and from a Bazel test.

tensorflow_federated/python/tests/backend_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,41 @@ def foo():
173173
result = foo()
174174
self.assertEqual(result, 10)
175175

176+
@test_contexts.with_contexts(*test_contexts.get_all_contexts())
177+
def test_dataset(self):
178+
179+
client_data_type = federated_language.FederatedType(
180+
federated_language.TensorType(np.int32, [3]),
181+
federated_language.CLIENTS,
182+
)
183+
server_state_type = federated_language.FederatedType(
184+
np.int32, federated_language.SERVER
185+
)
186+
187+
def reduce_fn(state, batch):
188+
return state + batch
189+
190+
@tff.tensorflow.computation(client_data_type.member)
191+
def client_transform(data):
192+
dataset = tf.data.Dataset.from_tensor_slices(data)
193+
return dataset.reduce(0, reduce_fn)
194+
195+
@federated_language.federated_computation(
196+
[client_data_type, server_state_type]
197+
)
198+
def my_comp(client_data, server_data):
199+
transformed_client_data = federated_language.federated_map(
200+
client_transform, client_data
201+
)
202+
return (
203+
federated_language.federated_sum(transformed_client_data),
204+
server_data,
205+
)
206+
207+
result_1, result_2 = my_comp([[1, 2, 3], [4, 5, 6]], 30)
208+
self.assertEqual(result_1, 21)
209+
self.assertEqual(result_2, 30)
210+
176211
@test_contexts.with_contexts(*test_contexts.get_all_contexts())
177212
def test_empty_tuple(self):
178213

0 commit comments

Comments
 (0)