Skip to content

Commit 5fc611b

Browse files
author
Akshay Chitneni
committed
Adding cache initializer
Signed-off-by: Akshay Chitneni <[email protected]>
1 parent bbaca24 commit 5fc611b

File tree

6 files changed

+545
-0
lines changed

6 files changed

+545
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
huggingface-hub>=0.27.0,<0.28
2+
kubernetes>=27.2.0

pkg/initializers/dataset/__main__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from urllib.parse import urlparse
44

55
import pkg.initializers.utils.utils as utils
6+
from pkg.initializers.dataset.cache import CacheInitializer
67
from pkg.initializers.dataset.huggingface import HuggingFace
78

89
logging.basicConfig(
@@ -27,6 +28,10 @@ def main():
2728
hf = HuggingFace()
2829
hf.load_config()
2930
hf.download_dataset()
31+
case utils.CACHE_SCHEME:
32+
cache = CacheInitializer()
33+
cache.load_config()
34+
cache.download_dataset()
3035
case _:
3136
logging.error("STORAGE_URI must have the valid dataset provider")
3237
raise Exception

pkg/initializers/dataset/cache.py

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
import logging
2+
import time
3+
4+
from kubernetes import client, config
5+
from kubernetes.client.rest import ApiException
6+
7+
import pkg.initializers.types.types as types
8+
import pkg.initializers.utils.utils as utils
9+
10+
logging.basicConfig(
11+
format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
12+
datefmt="%Y-%m-%dT%H:%M:%SZ",
13+
level=logging.INFO,
14+
)
15+
16+
17+
def get_namespace() -> str:
18+
"""Get the current namespace from the service account token."""
19+
try:
20+
with open("/var/run/secrets/kubernetes.io/serviceaccount/namespace") as f:
21+
return f.readline().strip()
22+
except FileNotFoundError:
23+
logging.warning(
24+
"Service account namespace file not found, using 'default' namespace"
25+
)
26+
return "default"
27+
28+
29+
class CacheInitializer(utils.DatasetProvider):
30+
31+
def load_config(self):
32+
config_dict = utils.get_config_from_env(types.CacheDatasetInitializer)
33+
self.config = types.CacheDatasetInitializer(**config_dict)
34+
35+
def download_dataset(self):
36+
self.create_cache_cluster()
37+
38+
def create_cache_cluster(self):
39+
"""Create cache cluster resources using Kubernetes client SDK."""
40+
logging.info(
41+
f"Cache initializer called with storage URI: {self.config.storage_uri}"
42+
)
43+
44+
train_job_name = self.config.train_job_name
45+
cache_image = self.config.cache_image
46+
cluster_size = int(self.config.cluster_size)
47+
iam_role = self.config.iam_role
48+
head_cpu = self.config.head_cpu
49+
head_mem = self.config.head_mem
50+
worker_cpu = self.config.worker_cpu
51+
worker_mem = self.config.worker_mem
52+
namespace = get_namespace()
53+
metadata_loc = self.config.metadata_loc
54+
table_name = self.config.table_name
55+
schema_name = self.config.schema_name
56+
57+
# Load Kubernetes configuration
58+
config.load_incluster_config()
59+
60+
api_client = client.ApiClient()
61+
core_v1 = client.CoreV1Api(api_client)
62+
custom_api = client.CustomObjectsApi(api_client)
63+
64+
# Get TrainJob for owner reference
65+
try:
66+
training_job = custom_api.get_namespaced_custom_object(
67+
group="trainer.kubeflow.org",
68+
version="v1alpha1",
69+
plural="trainjobs",
70+
namespace=namespace,
71+
name=train_job_name,
72+
)
73+
logging.info(f"TrainJob: {training_job}")
74+
75+
# Create owner reference dictionary
76+
logging.info(
77+
f"Creating owner reference from TrainJob: {training_job['metadata']['name']}"
78+
)
79+
80+
owner_ref_dict = {
81+
"apiVersion": training_job["apiVersion"],
82+
"kind": training_job["kind"],
83+
"name": training_job["metadata"]["name"],
84+
"uid": training_job["metadata"]["uid"],
85+
"controller": True,
86+
"blockOwnerDeletion": True,
87+
}
88+
89+
logging.info(
90+
f"Owner reference created with apiVersion='{training_job['apiVersion']}', "
91+
f"kind='{training_job['kind']}')"
92+
)
93+
except ApiException as e:
94+
logging.error(f"Failed to get TrainJob {train_job_name}: {e}")
95+
return
96+
97+
try:
98+
# Create ServiceAccount
99+
service_account = client.V1ServiceAccount(
100+
metadata=client.V1ObjectMeta(
101+
name=f"{train_job_name}-cache",
102+
namespace=namespace,
103+
annotations={
104+
"eks.amazonaws.com/sts-regional-endpoints": "true",
105+
"eks.amazonaws.com/role-arn": iam_role,
106+
},
107+
owner_references=[owner_ref_dict],
108+
)
109+
)
110+
111+
try:
112+
core_v1.create_namespaced_service_account(
113+
namespace=namespace, body=service_account
114+
)
115+
logging.info(f"Created ServiceAccount {service_account.metadata.name}")
116+
except ApiException as e:
117+
if e.status == 409:
118+
logging.info(
119+
f"ServiceAccount {service_account.metadata.name} "
120+
f"already exists, skipping creation"
121+
)
122+
else:
123+
raise e
124+
125+
# Prepare environment variables
126+
env_vars = []
127+
if metadata_loc:
128+
env_vars.append({"name": "METADATA_LOC", "value": metadata_loc})
129+
if table_name:
130+
env_vars.append({"name": "TABLE_NAME", "value": table_name})
131+
if schema_name:
132+
env_vars.append({"name": "SCHEMA_NAME", "value": schema_name})
133+
134+
# Create LeaderWorkerSet
135+
lws_body = {
136+
"apiVersion": "leaderworkerset.x-k8s.io/v1",
137+
"kind": "LeaderWorkerSet",
138+
"metadata": {
139+
"name": f"{train_job_name}-cache",
140+
"namespace": namespace,
141+
"ownerReferences": [owner_ref_dict],
142+
},
143+
"spec": {
144+
"replicas": 1,
145+
"leaderWorkerTemplate": {
146+
"size": cluster_size,
147+
"leaderTemplate": {
148+
"metadata": {
149+
"labels": {"app": f"{train_job_name}-cache-head"}
150+
},
151+
"spec": {
152+
"serviceAccountName": f"{train_job_name}-cache",
153+
"containers": [
154+
{
155+
"name": "head",
156+
"image": cache_image,
157+
"command": ["head"],
158+
"args": ["0.0.0.0", "50051"],
159+
"resources": {
160+
"limits": {
161+
"cpu": head_cpu,
162+
"memory": head_mem,
163+
},
164+
"requests": {
165+
"cpu": head_cpu,
166+
"memory": head_mem,
167+
},
168+
},
169+
"env": env_vars,
170+
"ports": [{"containerPort": 50051}],
171+
}
172+
],
173+
},
174+
},
175+
"workerTemplate": {
176+
"spec": {
177+
"serviceAccountName": f"{train_job_name}-cache",
178+
"containers": [
179+
{
180+
"name": "worker",
181+
"image": cache_image,
182+
"command": ["worker"],
183+
"args": ["0.0.0.0", "50051"],
184+
"resources": {
185+
"limits": {
186+
"cpu": worker_cpu,
187+
"memory": worker_mem,
188+
},
189+
"requests": {
190+
"cpu": worker_cpu,
191+
"memory": worker_mem,
192+
},
193+
},
194+
"env": env_vars,
195+
"ports": [{"containerPort": 50051}],
196+
}
197+
],
198+
}
199+
},
200+
},
201+
},
202+
}
203+
204+
# Create LeaderWorkerSet
205+
custom_api.create_namespaced_custom_object(
206+
group="leaderworkerset.x-k8s.io",
207+
version="v1",
208+
namespace=namespace,
209+
plural="leaderworkersets",
210+
body=lws_body,
211+
)
212+
logging.info(f"Created LeaderWorkerSet {lws_body['metadata']['name']}")
213+
214+
# Create Service
215+
service = client.V1Service(
216+
metadata=client.V1ObjectMeta(
217+
name=f"{train_job_name}-cache-service",
218+
namespace=namespace,
219+
owner_references=[owner_ref_dict],
220+
),
221+
spec=client.V1ServiceSpec(
222+
selector={"app": f"{train_job_name}-cache-head"},
223+
ports=[
224+
client.V1ServicePort(
225+
protocol="TCP", port=50051, target_port=50051
226+
)
227+
],
228+
),
229+
)
230+
231+
try:
232+
core_v1.create_namespaced_service(namespace=namespace, body=service)
233+
logging.info(f"Created Service {service.metadata.name}")
234+
except ApiException as e:
235+
if e.status == 409:
236+
logging.info(
237+
f"Service {service.metadata.name} already exists, "
238+
f"skipping creation"
239+
)
240+
else:
241+
raise e
242+
243+
# Wait for LeaderWorkerSet to become ready
244+
# TODO:// refactor to use watch API
245+
while True:
246+
try:
247+
lws = custom_api.get_namespaced_custom_object(
248+
group="leaderworkerset.x-k8s.io",
249+
version="v1",
250+
plural="leaderworkersets",
251+
name=lws_body["metadata"]["name"],
252+
namespace=namespace,
253+
)
254+
255+
conditions = lws.get("status", {}).get("conditions", [])
256+
if any(
257+
c["type"] == "Available" and c["status"] == "True"
258+
for c in conditions
259+
):
260+
logging.info(
261+
f"LeaderWorkerSet {lws_body['metadata']['name']} is ready"
262+
)
263+
break
264+
265+
time.sleep(5)
266+
except ApiException as e:
267+
raise e
268+
269+
except ApiException as e:
270+
logging.error(f"Cache cluster creation failed: {e}")
271+
# Cleanup on failure
272+
try:
273+
core_v1.delete_namespaced_service_account(
274+
name=f"{train_job_name}-cache", namespace=namespace
275+
)
276+
except Exception as cleanup_error:
277+
logging.error(f"Error cleaning up ServiceAccount: {cleanup_error}")
278+
return
279+
280+
logging.info("Cache cluster creation completed")

0 commit comments

Comments
 (0)