import os
import tempfile
import uuid
from airflow.decorators import dag, task
from airflow.models import Variable
from airflow.models.connection import Connection
from airflow.operators.python import get_current_context
from airflow.providers.ssh.operators.ssh import SSHOperator
import pendulum
from datacat_integration.connection import DataCatalogEntry
from datacat_integration.hooks import DataCatalogHook
import docker_cmd as doc
from b2shareoperator import (
    add_file,
    create_draft_record,
    download_file,
    get_file_list,
    get_object_md,
    get_record_template,
    submit_draft,
)
from decors import get_connection
from docker_cmd import WORKER_DATA_LOCATION
"""This piplines is a test case for starting a clusterting algorithm with HeAT, running in a Docker environment.
A test set of parameters with a HeAT example:
Data Catalog Integration example: {"oid": "e13bcab6-3664-4090-bebb-defdb58483e0", "image": "ghcr.io/helmholtz-analytics/heat:1.1.1-alpha", "entrypoint": "/bin/bash", "command": "python demo_knn.py iris.h5 calc_res.txt", "register":"True"}
Data Catalog Integration example: {"oid": "e13bcab6-3664-4090-bebb-defdb58483e0", "image":"hello-world", "register":"True"}
Params:
oid (str): oid of the data (e.g, from data catalog)
image (str): a docker contianer image
job_args (str):
Optional: a string of further arguments which might be needed for the task execution
entrypoint (str):
Optional: you can specify or overwrite the docker entrypoint
command (str):
Optional: you can specify or override the command to be executed
args_to_dockerrun (str):
Optional: docker run additional arguments
register (True, False):
Optional, default is False: register the resulsts in the data catalog
"""
default_args = {
    "owner": "airflow",
}


@dag(
    default_args=default_args,
    schedule=None,
    start_date=pendulum.today('UTC'),
    tags=["example", "docker", "datacat"],
)
def docker_in_worker():
    DW_CONNECTION_ID = "docker_worker"

    @task()
    def stagein(**kwargs):
        """Stage-in task.
        This task gets the 'datacat_oid' or 'oid' from the DAG params to retrieve a connection from it (b2share for now).
        It then downloads all data from the b2share entry to the local disk and returns a mapping of these files to their local download locations,
        which can be used by the following tasks.
        """
        params = kwargs["params"]
        datacat_hook = DataCatalogHook()
        if "oid" not in params:  # {"oid": "b143bf73efd24d149bba4c081964b459"}
            if "datacat_oid" not in params:
                print("Missing object id in pipeline parameters")
                return -1  # non-zero exit code is a task failure
            params["oid"] = params["datacat_oid"]
        oid_split = params["oid"].split("/")
        datacat_type = "dataset"
        oid = "placeholder_text"
        if len(oid_split) == 2:
            datacat_type = oid_split[0]
            oid = oid_split[1]
        elif len(oid_split) == 1:
            oid = oid_split[0]
        else:
            print("Malformed oid passed as parameter.")
            return -1
        entry = DataCatalogEntry.from_json(datacat_hook.get_entry(datacat_type, oid))
        print(f"using entry: {entry}")
        b2share_server_uri = entry.url
        # TODO general stage in based on type metadata
        # using only b2share for now
        b2share_oid = entry.metadata["b2share_oid"]
        obj = get_object_md(server=b2share_server_uri, oid=b2share_oid)
        print(f"Retrieved object {oid}: {obj}")
        flist = get_file_list(obj)
        name_mappings = {}
        tmp_dir = Variable.get("working_dir", default_var="/tmp/")
        print(f"Local working dir is: {tmp_dir}")
        for fname, url in flist.items():
            print(f"Processing: {fname} --> {url}")
            tmpname = download_file(url=url, target_dir=tmp_dir)
            name_mappings[fname] = tmpname
        return name_mappings
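
    # stagein() returns a dict mapping the original file names to their local
    # download locations, e.g. (illustrative) {"iris.h5": "/tmp/tmpab12cd34"};
    # downstream tasks use the keys as the "true" file names on the worker.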

    @task()
    def move_to_docker_host(files: dict, **kwargs):
        """This task copies the data onto the remote docker worker,
        which gives the following tasks access to the data.
        Args:
            files (dict): the files that will be stored on the docker worker
        Returns:
            target_dir: the location of the files on the docker worker
        """
        print(f"Using {DW_CONNECTION_ID} connection")
        ssh_hook = get_connection(conn_id=DW_CONNECTION_ID)
        user_dir_name = str(uuid.uuid4())
        target_dir = os.path.join(WORKER_DATA_LOCATION, user_dir_name)
        with ssh_hook.get_conn() as ssh_client:
            sftp_client = ssh_client.open_sftp()
            sftp_client.mkdir(target_dir, mode=0o755)
            for [truename, local] in files.items():
                print(
                    f"Copying {local} --> {DW_CONNECTION_ID}:{os.path.join(target_dir, truename)}"
                )
                sftp_client.put(local, os.path.join(target_dir, truename))
                # or separate cleanup task?
                os.unlink(local)
        return target_dir
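
    # Layout on the docker worker after this task (illustrative):
    #   <WORKER_DATA_LOCATION>/<uuid4>/iris.h5
    # The returned target_dir (that per-run directory) is what run_container()
    # and the later cleanup receive as data_location / data_on_worker.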

    @task
    def run_container(data_location, **kwargs):
        """A task which runs on the docker worker and spins up a docker container with the given image and parameters.
        Args:
            image (str): a docker container image
            job_args (str):
                Optional: a string of further arguments which might be needed for the task execution
            entrypoint (str):
                Optional: you can specify or overwrite the docker entrypoint
            command (str):
                Optional: you can specify or override the command to be executed
            args_to_dockerrun (str):
                Optional: additional arguments for docker run
        """
        params = kwargs["params"]
        cmd = doc.get_dockercmd(params, data_location)
        print(f"Executing docker command {cmd}")
        print(f"Using {DW_CONNECTION_ID} connection")
        hook = get_connection(conn_id=DW_CONNECTION_ID)
        task_calculate = SSHOperator(task_id="calculate", ssh_hook=hook, command=cmd)
        context = get_current_context()
        task_calculate.execute(context)
        return data_location
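
    # Note: doc.get_dockercmd() (from the local docker_cmd module) is assumed to
    # assemble the full 'docker run' command line from the DAG params (image,
    # entrypoint, command, args_to_dockerrun) plus the data location; the command
    # is then executed on the remote worker via the SSHOperator above.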

    @task
    def ls_results(output_dir):
        if not output_dir:
            return "No output to stage out. Nothing more to do."
        hook = get_connection(conn_id=DW_CONNECTION_ID)
        cmd = f"ls -al {output_dir}"
        process = SSHOperator(task_id="print_results", ssh_hook=hook, command=cmd)
        context = get_current_context()
        process.execute(context)

    @task()
    def retrieve_res(output_dir: str, input_files: dict, **kwargs):
        """This task copies the data from the remote docker worker back to the Airflow workspace.
        Args:
            output_dir (str): the folder on the docker worker containing all the user files for the executed task
            input_files (dict): the input file mapping, used to skip input files when collecting results
        Returns:
            name_mappings (dict): the paths of the files copied back to the Airflow host
        """
        working_dir = Variable.get("working_dir", default_var="/tmp/")
        name_mappings = {}
        print(f"Using {DW_CONNECTION_ID} connection")
        ssh_hook = get_connection(conn_id=DW_CONNECTION_ID)
        with ssh_hook.get_conn() as ssh_client:
            sftp_client = ssh_client.open_sftp()
            for fname in sftp_client.listdir(output_dir):
                if fname not in input_files.keys():
                    tmpname = tempfile.mktemp(dir=working_dir)
                    local = os.path.join(working_dir, tmpname)
                    print(f"Copying {os.path.join(output_dir, fname)} to {local}")
                    sftp_client.get(os.path.join(output_dir, fname), local)
                    name_mappings[fname] = local
        return name_mappings
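
    # retrieve_res() skips anything that was part of the input set, so the returned
    # mapping contains only newly produced files, e.g. (illustrative)
    # {"calc_res.txt": "/tmp/tmpxy98zw76"}.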

    @task()
    def cleanup_doc_worker(res_fpaths_local, data_on_worker, **kwargs):
        """This task deletes all the files from the docker worker.
        Args:
            res_fpaths_local: used only to define the order of tasks within the DAG, i.e. wait for the previous task to complete before cleaning the worker space
            data_on_worker (str): the folder with the user data to delete from the docker worker
        """
        print(f"Using {DW_CONNECTION_ID} connection")
        ssh_hook = get_connection(conn_id=DW_CONNECTION_ID)
        with ssh_hook.get_conn() as ssh_client:
            sftp_client = ssh_client.open_sftp()
            d = os.path.join(WORKER_DATA_LOCATION, data_on_worker)
            for f in sftp_client.listdir(d):
                print(f"Deleting file {f}")
                sftp_client.remove(os.path.join(d, f))
            print(f"Deleting directory {DW_CONNECTION_ID}:{d}")
            sftp_client.rmdir(d)

    @task
    def stageout_results(output_mappings: dict):
        """This task transfers the output files to b2share.
        Args:
            output_mappings (dict): {true_filename: local_path} a dictionary of the output files to be submitted to the remote storage, e.g., b2share
        Returns:
            a b2share record
        """
        if not output_mappings:
            print("No output to stage out. Nothing more to do.")
            return -1
        connection = Connection.get_connection_from_secrets("default_b2share")
        server = "https://" + connection.host
        token = ""
        if "access_token" in connection.extra_dejson.keys():
            token = connection.extra_dejson["access_token"]
        print(f"Registering data to {server}")
        template = get_record_template()
        r = create_draft_record(server=server, token=token, record=template)
        print(f"record {r}")
        if "id" in r:
            print(f"Draft record created {r['id']} --> {r['links']['self']}")
        else:
            print("Something went wrong with registration", r, r.text)
            return -1
        for [truename, local] in output_mappings.items():
            print(f"Uploading {truename}")
            _ = add_file(record=r, fname=local, token=token, remote=truename)
            # delete the local copy after the upload
            os.unlink(local)
        print("Submitting record for publication")
        submitted = submit_draft(record=r, token=token)
        print(f"Record created {submitted}")
        return submitted["links"]["publication"]
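
    # B2SHARE flow used above: create a draft record from a template, attach each
    # result file, then submit the draft for publication and return the new
    # record's publication link.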

    @task()
    def register(object_url, additional_metadata={}, **kwargs):
        """This task registers the b2share record in the data catalog.
        Args:
            object_url: the record URL from b2share
            additional_metadata (dict): optional extra metadata for the catalog entry
        """
        params = kwargs["params"]
        reg = params.get("register", False)
        if not reg:
            print("Skipping registration as 'register' parameter is not set")
            return 0
        hook = DataCatalogHook()
        print("Connected to datacat via hook")
        if not additional_metadata.get("author", False):
            additional_metadata["author"] = "DLS on behalf of eFlows"
        if not additional_metadata.get("access", False):
            additional_metadata["access"] = "hook-based"
        entry = DataCatalogEntry(
            name=f"DLS results {kwargs['run_id']}",
            url=object_url,
            metadata=additional_metadata,
        )
        try:
            r = hook.create_entry(datacat_type="dataset", entry=entry)
            print("Hook registration returned: ", r)
            return f"{hook.base_url}/dataset/{r}"
        except ConnectionError as e:
            print("Registration failed", e)
            return -1
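
    # On success register() returns the catalog URL of the new entry, built as
    # f"{hook.base_url}/dataset/{entry_id}"; on failure it returns -1, which is
    # why the wiring below names the stageout value url_or_errcode.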

    input_files = stagein()
    data_location = move_to_docker_host(input_files)
    data_on_worker = run_container(data_location)
    ls_results(data_on_worker)
    res_fpaths = retrieve_res(data_on_worker, input_files)
    cleanup_doc_worker(res_fpaths, data_on_worker)
    url_or_errcode = stageout_results(res_fpaths)
    register(url_or_errcode)
    # files >> data_locations >> output_fnames >> ls_results(output_fnames) >> files >> stageout_results(files) >> cleanup()


dag = docker_in_worker()