Commit e3b9fad

Author: Akshay Chitneni (authored and committed)

Adding cache initializer

Signed-off-by: Akshay Chitneni <[email protected]>
1 parent bbaca24 commit e3b9fad

6 files changed: +759 -0 lines changed

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 huggingface-hub>=0.27.0,<0.28
+kubernetes>=27.2.0

pkg/initializers/dataset/__main__.py

Lines changed: 5 additions & 0 deletions
@@ -3,6 +3,7 @@
 from urllib.parse import urlparse
 
 import pkg.initializers.utils.utils as utils
+from pkg.initializers.dataset.cache import CacheInitializer
 from pkg.initializers.dataset.huggingface import HuggingFace
 
 logging.basicConfig(
@@ -27,6 +28,10 @@ def main():
             hf = HuggingFace()
             hf.load_config()
             hf.download_dataset()
+        case utils.CACHE_SCHEME:
+            cache = CacheInitializer()
+            cache.load_config()
+            cache.download_dataset()
         case _:
             logging.error("STORAGE_URI must have the valid dataset provider")
             raise Exception
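
A minimal sketch of how this dispatch is exercised, assuming utils.CACHE_SCHEME is a plain string constant such as "cache" (its actual value is defined in pkg/initializers/utils/utils.py and is not part of this diff), with a hypothetical storage URI for illustration:

from urllib.parse import urlparse

# Hypothetical values; the real scheme constant and URI format come from
# pkg/initializers/utils/utils.py and the TrainJob spec.
CACHE_SCHEME = "cache"
storage_uri = "cache://catalog/schema/table"

# __main__.py parses STORAGE_URI and matches on its scheme; for
# CACHE_SCHEME it runs CacheInitializer().load_config() followed by
# download_dataset(), which provisions the cache cluster defined below.
assert urlparse(storage_uri).scheme == CACHE_SCHEME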

pkg/initializers/dataset/cache.py

Lines changed: 295 additions & 0 deletions
@@ -0,0 +1,295 @@
import logging
import time
from typing import Optional

from kubernetes import client, config
from kubernetes.client.rest import ApiException

import pkg.initializers.types.types as types
import pkg.initializers.utils.utils as utils

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
    level=logging.INFO,
)


class CacheInitializer(utils.DatasetProvider):

    def load_config(self):
        config_dict = utils.get_config_from_env(types.CacheDatasetInitializer)
        self.config = types.CacheDatasetInitializer(**config_dict)

    def download_dataset(self):
        logging.info(
            f"Cache initializer called with storage URI: {self.config.storage_uri}"
        )

        train_job_name = self.config.train_job_name
        cache_image = self.config.cache_image
        cluster_size = int(self.config.cluster_size)
        iam_role = self.config.iam_role
        head_cpu = self.config.head_cpu
        head_mem = self.config.head_mem
        worker_cpu = self.config.worker_cpu
        worker_mem = self.config.worker_mem

        # Create Kubernetes resources using client SDK
        create_cache_resources(
            train_job_name=train_job_name,
            iam_role=iam_role,
            cluster_size=cluster_size,
            cache_image=cache_image,
            head_cpu=head_cpu,
            head_mem=head_mem,
            worker_cpu=worker_cpu,
            worker_mem=worker_mem,
            namespace=self.config.namespace,
            metadata_loc=self.config.metadata_loc,
            table_name=self.config.table_name,
            schema_name=self.config.schema_name,
        )

        logging.info("Cache dataset initialization completed")


def create_cache_resources(
    train_job_name: str,
    iam_role: str,
    cluster_size: int,
    cache_image: str,
    head_cpu: str,
    head_mem: str,
    worker_cpu: str,
    worker_mem: str,
    namespace: str,
    metadata_loc: Optional[str] = None,
    table_name: Optional[str] = None,
    schema_name: Optional[str] = None,
) -> bool:
    """
    Creates Kubernetes resources for cache initializer using the client SDK.

    Args:
        train_job_name: Name of the training job
        iam_role: IAM role ARN for the service account
        cluster_size: Number of workers in the cluster
        cache_image: Container image to use
        head_cpu: CPU limit/request for head node
        head_mem: Memory limit/request for head node
        worker_cpu: CPU limit/request for worker nodes
        worker_mem: Memory limit/request for worker nodes
        metadata_loc: Optional metadata location
        table_name: Optional table name
        schema_name: Optional schema name
        namespace: Target Kubernetes namespace

    Returns:
        bool: True if deployment succeeded
    """
    # Load Kubernetes configuration
    config.load_incluster_config()

    api_client = client.ApiClient()
    core_v1 = client.CoreV1Api(api_client)
    custom_api = client.CustomObjectsApi(api_client)

    # Get TrainingJob for owner reference
    try:
        training_job = custom_api.get_namespaced_custom_object(
            group="trainer.kubeflow.org",
            version="v1alpha1",
            plural="trainjobs",
            namespace=namespace,
            name=train_job_name,
        )
        logging.info(f"TrainJob: {training_job}")

        # Create owner reference
        owner_ref = {
            "apiVersion": training_job["apiVersion"],
            "kind": training_job["kind"],
            "name": training_job["metadata"]["name"],
            "uid": training_job["metadata"]["uid"],
            "controller": True,
            "blockOwnerDeletion": True,
        }
    except ApiException as e:
        logging.error(f"Failed to get TrainingJob {train_job_name}: {e}")
        return False

    try:
        # Create ServiceAccount
        service_account = client.V1ServiceAccount(
            metadata=client.V1ObjectMeta(
                name=f"{train_job_name}-sa",
                namespace=namespace,
                annotations={
                    "eks.amazonaws.com/sts-regional-endpoints": "true",
                    "eks.amazonaws.com/role-arn": iam_role,
                },
                owner_references=[owner_ref],
            )
        )

        try:
            core_v1.create_namespaced_service_account(
                namespace=namespace, body=service_account
            )
            logging.info(f"Created ServiceAccount {train_job_name}-sa")
        except ApiException as e:
            if e.status == 409:
                logging.info(
                    f"ServiceAccount {train_job_name}-sa already exists, skipping creation"
                )
            else:
                raise

        # Prepare environment variables
        env_vars = []
        if metadata_loc:
            env_vars.append({"name": "METADATA_LOC", "value": metadata_loc})
        if table_name:
            env_vars.append({"name": "TABLE_NAME", "value": table_name})
        if schema_name:
            env_vars.append({"name": "SCHEMA_NAME", "value": schema_name})

        # Create LeaderWorkerSet
        lws_body = {
            "apiVersion": "leaderworkerset.x-k8s.io/v1",
            "kind": "LeaderWorkerSet",
            "metadata": {
                "name": f"{train_job_name}-cache",
                "namespace": namespace,
                "ownerReferences": [owner_ref],
            },
            "spec": {
                "replicas": 1,
                "leaderWorkerTemplate": {
                    "size": cluster_size,
                    "leaderTemplate": {
                        "metadata": {"labels": {"app": f"{train_job_name}-cache-head"}},
                        "spec": {
                            "serviceAccountName": f"{train_job_name}-sa",
                            "containers": [
                                {
                                    "name": "head",
                                    "image": cache_image,
                                    "command": ["head"],
                                    "args": ["0.0.0.0", "50051"],
                                    "resources": {
                                        "limits": {"cpu": head_cpu, "memory": head_mem},
                                        "requests": {
                                            "cpu": head_cpu,
                                            "memory": head_mem,
                                        },
                                    },
                                    "env": env_vars,
                                    "ports": [{"containerPort": 50051}],
                                }
                            ],
                        },
                    },
                    "workerTemplate": {
                        "spec": {
                            "serviceAccountName": f"{train_job_name}-sa",
                            "containers": [
                                {
                                    "name": "worker",
                                    "image": cache_image,
                                    "command": ["worker"],
                                    "args": ["0.0.0.0", "50051"],
                                    "resources": {
                                        "limits": {
                                            "cpu": worker_cpu,
                                            "memory": worker_mem,
                                        },
                                        "requests": {
                                            "cpu": worker_cpu,
                                            "memory": worker_mem,
                                        },
                                    },
                                    "env": env_vars,
                                    "ports": [{"containerPort": 50051}],
                                }
                            ],
                        }
                    },
                },
            },
        }

        # Create LeaderWorkerSet
        custom_api.create_namespaced_custom_object(
            group="leaderworkerset.x-k8s.io",
            version="v1",
            namespace=namespace,
            plural="leaderworkersets",
            body=lws_body,
        )
        logging.info(f"Created LeaderWorkerSet {train_job_name}-cache")

        # Create Service
        service = client.V1Service(
            metadata=client.V1ObjectMeta(
                name=f"{train_job_name}-cache-service",
                namespace=namespace,
                owner_references=[owner_ref],
            ),
            spec=client.V1ServiceSpec(
                selector={"app": f"{train_job_name}-cache-head"},
                ports=[
                    client.V1ServicePort(protocol="TCP", port=50051, target_port=50051)
                ],
            ),
        )

        try:
            core_v1.create_namespaced_service(namespace=namespace, body=service)
            logging.info(f"Created Service {train_job_name}-cache-service")
        except ApiException as e:
            if e.status == 409:
                logging.info(
                    f"Service {train_job_name}-cache-service already exists, skipping creation"
                )
            else:
                raise

        # Wait for LeaderWorkerSet to become ready
        lws_name = f"{train_job_name}-cache"

        while True:
            try:
                lws = custom_api.get_namespaced_custom_object(
                    group="leaderworkerset.x-k8s.io",
                    version="v1",
                    plural="leaderworkersets",
                    name=lws_name,
                    namespace=namespace,
                )

                conditions = lws.get("status", {}).get("conditions", [])
                if any(
                    c["type"] == "Available" and c["status"] == "True"
                    for c in conditions
                ):
                    logging.info(f"LeaderWorkerSet {lws_name} is ready")
                    break

                time.sleep(5)
            except ApiException:
                time.sleep(2)

        return True

    except ApiException as e:
        logging.error(f"Deployment failed: {e}")
        # Cleanup on failure
        try:
            core_v1.delete_namespaced_service_account(
                name=f"{train_job_name}-sa", namespace=namespace
            )
        except Exception as cleanup_error:
            logging.error(f"Error cleaning up ServiceAccount: {cleanup_error}")
        return False
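
load_config above depends on types.CacheDatasetInitializer and utils.get_config_from_env, neither of which appears in this diff. The sketch below is a hypothetical reconstruction of that config object, inferred only from the attributes read in download_dataset; the real field names, defaults, and env-var mapping live in pkg/initializers/types/types.py and pkg/initializers/utils/utils.py.

from dataclasses import dataclass
from typing import Optional

# Hypothetical shape, for illustration only; not the committed definition.
@dataclass
class CacheDatasetInitializer:
    storage_uri: str                     # cache:// URI that selected this initializer
    train_job_name: str                  # TrainJob that owns the cache resources
    namespace: str                       # namespace for the LeaderWorkerSet and Service
    cache_image: str                     # container image for head and worker pods
    cluster_size: str                    # converted to int in download_dataset
    iam_role: str                        # IAM role ARN bound to the ServiceAccount
    head_cpu: str
    head_mem: str
    worker_cpu: str
    worker_mem: str
    metadata_loc: Optional[str] = None   # passed to pods as METADATA_LOC
    table_name: Optional[str] = None     # passed to pods as TABLE_NAME
    schema_name: Optional[str] = None    # passed to pods as SCHEMA_NAME

Because every created object carries an ownerReference to the TrainJob, Kubernetes garbage collection removes the ServiceAccount, LeaderWorkerSet, and Service automatically when the TrainJob is deleted.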
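
One design note on the readiness wait: the while True loop polls the LeaderWorkerSet indefinitely, so the initializer will hang if the cache cluster never reports an Available condition. A minimal sketch of a bounded variant, using a hypothetical helper rather than the committed code:

import logging
import time

def wait_for_available(get_lws, timeout_seconds=1800, poll_interval=5):
    """Poll a LeaderWorkerSet until an Available=True condition appears.

    get_lws is a zero-argument callable returning the object as a dict,
    e.g. a functools.partial around CustomObjectsApi.get_namespaced_custom_object.
    Returns True when ready, False once the deadline expires.
    """
    deadline = time.monotonic() + timeout_seconds
    while time.monotonic() < deadline:
        try:
            lws = get_lws()
        except Exception as e:  # e.g. kubernetes.client.rest.ApiException
            logging.warning(f"Error fetching LeaderWorkerSet, retrying: {e}")
            time.sleep(poll_interval)
            continue
        conditions = lws.get("status", {}).get("conditions", [])
        if any(c["type"] == "Available" and c["status"] == "True" for c in conditions):
            return True
        time.sleep(poll_interval)
    return False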
