Skip to content

Commit 1581a07

Browse files
akshaychitneni (Akshay Chitneni) and co-authors authored
feat(cache): KEP-2655: Adding cache initializer (#2793)
Signed-off-by: Akshay Chitneni <[email protected]> Co-authored-by: Akshay Chitneni <[email protected]>
1 parent 2256a81 commit 1581a07

File tree

7 files changed

+529
-1
lines changed

7 files changed

+529
-1
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[flake8]
22
max-line-length = 100
3-
extend-ignore = W503
3+
extend-ignore = W503, E203
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
huggingface-hub>=0.27.0,<0.28
2+
kubernetes>=27.2.0

pkg/initializers/dataset/__main__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from urllib.parse import urlparse
44

55
import pkg.initializers.utils.utils as utils
6+
from pkg.initializers.dataset.cache import CacheInitializer
67
from pkg.initializers.dataset.huggingface import HuggingFace
78

89
logging.basicConfig(
@@ -27,6 +28,10 @@ def main():
2728
hf = HuggingFace()
2829
hf.load_config()
2930
hf.download_dataset()
31+
case utils.CACHE_SCHEME:
32+
cache = CacheInitializer()
33+
cache.load_config()
34+
cache.download_dataset()
3035
case _:
3136
logging.error("STORAGE_URI must have the valid dataset provider")
3237
raise Exception

pkg/initializers/dataset/cache.py

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
import logging
2+
import time
3+
4+
from kubernetes import client, config
5+
from kubernetes.client.rest import ApiException
6+
from kubernetes.dynamic.exceptions import ConflictError
7+
8+
import pkg.initializers.types.types as types
9+
import pkg.initializers.utils.utils as utils
10+
11+
logging.basicConfig(
12+
format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
13+
datefmt="%Y-%m-%dT%H:%M:%SZ",
14+
level=logging.INFO,
15+
)
16+
17+
18+
def get_namespace() -> str:
    """Return the namespace this pod runs in.

    Reads the namespace from the mounted service-account token directory.
    When the file is absent (e.g. running outside a Kubernetes cluster),
    falls back to the 'default' namespace with a warning.
    """
    namespace_file = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
    try:
        handle = open(namespace_file)
    except FileNotFoundError:
        logging.warning(
            "Service account namespace file not found, using 'default' namespace"
        )
        return "default"
    with handle:
        return handle.readline().strip()
28+
29+
30+
class CacheInitializer(utils.DatasetProvider):
    """Bootstraps a distributed cache cluster for a TrainJob.

    Creates a ServiceAccount, a LeaderWorkerSet (one head + workers) and a
    Service fronting the head, all owner-referenced to the TrainJob so they
    are garbage-collected with it, then waits until the LeaderWorkerSet
    reports the ``Available`` condition.

    The storage URI has the form ``cache://<SCHEMA_NAME>/<TABLE_NAME>``.
    """

    def load_config(self):
        """Read configuration from the environment and parse the storage URI.

        Raises:
            ValueError: if ``storage_uri`` is not of the form
                ``cache://<SCHEMA_NAME>/<TABLE_NAME>``.
        """
        config_dict = utils.get_config_from_env(types.CacheDatasetInitializer)
        self.config = types.CacheDatasetInitializer(**config_dict)

        # Parse schema_name and table_name from storage_uri
        # Format: cache://<SCHEMA_NAME>/<TABLE_NAME>
        uri_path = self.config.storage_uri[len("cache://") :]
        parts = uri_path.split("/")
        # Fail fast with a clear message instead of an opaque IndexError.
        if len(parts) < 2 or not parts[0] or not parts[1]:
            raise ValueError(
                "storage_uri must be of the form "
                f"cache://<SCHEMA_NAME>/<TABLE_NAME>, got: {self.config.storage_uri}"
            )
        self.schema_name = parts[0]
        self.table_name = parts[1]

    def download_dataset(self):
        """Bootstrap cache cluster with dataset.

        On any Kubernetes API failure after the owner reference is resolved,
        the ServiceAccount is cleaned up (best effort) and the method returns
        without raising, matching the original best-effort contract.
        """
        logging.info(
            f"Cache initializer called with storage URI: {self.config.storage_uri}"
        )

        train_job_name = self.config.train_job_name
        namespace = get_namespace()

        # Load in-cluster Kubernetes configuration (requires the mounted
        # service-account token).
        config.load_incluster_config()

        api_client = client.ApiClient()
        core_v1 = client.CoreV1Api(api_client)
        custom_api = client.CustomObjectsApi(api_client)

        # Get TrainJob for owner reference so every resource created below is
        # garbage-collected together with the TrainJob.
        try:
            owner_ref_dict = self._build_owner_reference(
                custom_api, namespace, train_job_name
            )
        except ApiException as e:
            logging.error(f"Failed to get TrainJob {train_job_name}: {e}")
            return

        try:
            service_account = self._create_service_account(
                core_v1, namespace, train_job_name, owner_ref_dict
            )

            lws_body = self._build_lws_body(
                namespace,
                train_job_name,
                service_account.metadata.name,
                owner_ref_dict,
            )
            # Create LeaderWorkerSet
            custom_api.create_namespaced_custom_object(
                group="leaderworkerset.x-k8s.io",
                version="v1",
                namespace=namespace,
                plural="leaderworkersets",
                body=lws_body,
            )
            logging.info(f"Created LeaderWorkerSet {lws_body['metadata']['name']}")

            self._create_service(core_v1, namespace, train_job_name, owner_ref_dict)

            self._wait_for_lws_ready(
                custom_api, namespace, lws_body["metadata"]["name"]
            )
        except ApiException as e:
            logging.error(f"Cache cluster creation failed: {e}")
            # Cleanup on failure
            try:
                core_v1.delete_namespaced_service_account(
                    name=f"{train_job_name}-cache", namespace=namespace
                )
            except Exception as cleanup_error:
                logging.error(f"Error cleaning up ServiceAccount: {cleanup_error}")
            return

        logging.info("Cache cluster creation completed")

    def _build_owner_reference(self, custom_api, namespace, train_job_name):
        """Fetch the TrainJob and return an owner-reference dict for it.

        Raises:
            ApiException: if the TrainJob cannot be fetched.
        """
        training_job = custom_api.get_namespaced_custom_object(
            group="trainer.kubeflow.org",
            version="v1alpha1",
            plural="trainjobs",
            namespace=namespace,
            name=train_job_name,
        )
        logging.info(f"TrainJob: {training_job}")

        # Create owner reference dictionary
        logging.info(
            f"Creating owner reference from TrainJob: {training_job['metadata']['name']}"
        )
        owner_ref_dict = {
            "apiVersion": training_job["apiVersion"],
            "kind": training_job["kind"],
            "name": training_job["metadata"]["name"],
            "uid": training_job["metadata"]["uid"],
            "controller": True,
            "blockOwnerDeletion": True,
        }
        logging.info(
            f"Owner reference created with apiVersion='{training_job['apiVersion']}', "
            f"kind='{training_job['kind']}')"
        )
        return owner_ref_dict

    def _create_service_account(
        self, core_v1, namespace, train_job_name, owner_ref_dict
    ):
        """Create the cache ServiceAccount (idempotent) and return its manifest."""
        service_account = client.V1ServiceAccount(
            metadata=client.V1ObjectMeta(
                name=f"{train_job_name}-cache",
                namespace=namespace,
                # IRSA annotations so cache pods can assume the given IAM role.
                annotations={
                    "eks.amazonaws.com/sts-regional-endpoints": "true",
                    "eks.amazonaws.com/role-arn": self.config.iam_role,
                },
                owner_references=[owner_ref_dict],
            )
        )
        try:
            core_v1.create_namespaced_service_account(
                namespace=namespace, body=service_account
            )
            logging.info(f"Created ServiceAccount {service_account.metadata.name}")
        except ApiException as e:
            # 409 Conflict means it already exists; that is fine.
            if e.status == 409:
                logging.info(
                    f"ServiceAccount {service_account.metadata.name} "
                    f"already exists, skipping creation"
                )
            else:
                raise
        return service_account

    def _build_lws_body(
        self, namespace, train_job_name, service_account_name, owner_ref_dict
    ):
        """Build the LeaderWorkerSet manifest for the cache cluster."""
        # Prepare environment variables shared by head and worker containers.
        env_vars = []
        if self.config.metadata_loc:
            env_vars.append(
                {"name": "METADATA_LOC", "value": self.config.metadata_loc}
            )
        if self.table_name:
            env_vars.append({"name": "TABLE_NAME", "value": self.table_name})
        if self.schema_name:
            env_vars.append({"name": "SCHEMA_NAME", "value": self.schema_name})

        def container_spec(name, cpu, mem):
            # Head and worker containers differ only in entrypoint and resources.
            return {
                "name": name,
                "image": self.config.cache_image,
                "command": [name],
                "args": ["0.0.0.0", "50051"],
                "resources": {
                    "limits": {"cpu": cpu, "memory": mem},
                    "requests": {"cpu": cpu, "memory": mem},
                },
                "env": env_vars,
                "ports": [{"containerPort": 50051}],
            }

        return {
            "apiVersion": "leaderworkerset.x-k8s.io/v1",
            "kind": "LeaderWorkerSet",
            "metadata": {
                "name": f"{train_job_name}-cache",
                "namespace": namespace,
                "ownerReferences": [owner_ref_dict],
            },
            "spec": {
                "replicas": 1,
                "leaderWorkerTemplate": {
                    "size": int(self.config.cluster_size),
                    "leaderTemplate": {
                        # The label is the selector target of the head Service.
                        "metadata": {
                            "labels": {"app": f"{train_job_name}-cache-head"}
                        },
                        "spec": {
                            "serviceAccountName": service_account_name,
                            "containers": [
                                container_spec(
                                    "head",
                                    self.config.head_cpu,
                                    self.config.head_mem,
                                )
                            ],
                        },
                    },
                    "workerTemplate": {
                        "spec": {
                            "serviceAccountName": service_account_name,
                            "containers": [
                                container_spec(
                                    "worker",
                                    self.config.worker_cpu,
                                    self.config.worker_mem,
                                )
                            ],
                        }
                    },
                },
            },
        }

    def _create_service(self, core_v1, namespace, train_job_name, owner_ref_dict):
        """Create the Service fronting the cache head (idempotent)."""
        service = client.V1Service(
            metadata=client.V1ObjectMeta(
                name=f"{train_job_name}-cache-service",
                namespace=namespace,
                owner_references=[owner_ref_dict],
            ),
            spec=client.V1ServiceSpec(
                selector={"app": f"{train_job_name}-cache-head"},
                ports=[
                    client.V1ServicePort(
                        protocol="TCP", port=50051, target_port=50051
                    )
                ],
            ),
        )
        try:
            core_v1.create_namespaced_service(namespace=namespace, body=service)
            logging.info(f"Created Service {service.metadata.name}")
        except ApiException as e:
            # Fix: the original tested `e is ConflictError` (an instance
            # against a class), which is always False and made an existing
            # Service abort the bootstrap. Check the HTTP status instead,
            # consistent with the ServiceAccount path.
            if e.status == 409:
                logging.info(
                    f"Service {service.metadata.name} already exists, "
                    f"skipping creation"
                )
            else:
                raise

    def _wait_for_lws_ready(self, custom_api, namespace, lws_name):
        """Block until the LeaderWorkerSet reports the Available condition.

        Raises:
            ApiException: if polling the LeaderWorkerSet fails.
        """
        # TODO:// refactor to use watch API
        while True:
            lws = custom_api.get_namespaced_custom_object(
                group="leaderworkerset.x-k8s.io",
                version="v1",
                plural="leaderworkersets",
                name=lws_name,
                namespace=namespace,
            )

            conditions = lws.get("status", {}).get("conditions", [])
            if any(
                c["type"] == "Available" and c["status"] == "True"
                for c in conditions
            ):
                logging.info(f"LeaderWorkerSet {lws_name} is ready")
                return

            time.sleep(5)

0 commit comments

Comments
 (0)