20257: Fix use of dataclass
[arvados.git] / sdk / python / arvados / util.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from __future__ import division
6 from builtins import range
7
8 import fcntl
9 import hashlib
10 import httplib2
11 import os
12 import random
13 import re
14 import subprocess
15 import errno
16 import sys
17
18 import arvados
19 from arvados.collection import CollectionReader
20
21 HEX_RE = re.compile(r'^[0-9a-fA-F]+$')
22 CR_UNCOMMITTED = 'Uncommitted'
23 CR_COMMITTED = 'Committed'
24 CR_FINAL = 'Final'
25
26 keep_locator_pattern = re.compile(r'[0-9a-f]{32}\+\d+(\+\S+)*')
27 signed_locator_pattern = re.compile(r'[0-9a-f]{32}\+\d+(\+\S+)*\+A\S+(\+\S+)*')
28 portable_data_hash_pattern = re.compile(r'[0-9a-f]{32}\+\d+')
29 uuid_pattern = re.compile(r'[a-z0-9]{5}-[a-z0-9]{5}-[a-z0-9]{15}')
30 collection_uuid_pattern = re.compile(r'[a-z0-9]{5}-4zz18-[a-z0-9]{15}')
31 group_uuid_pattern = re.compile(r'[a-z0-9]{5}-j7d0g-[a-z0-9]{15}')
32 user_uuid_pattern = re.compile(r'[a-z0-9]{5}-tpzed-[a-z0-9]{15}')
33 link_uuid_pattern = re.compile(r'[a-z0-9]{5}-o0j2j-[a-z0-9]{15}')
34 job_uuid_pattern = re.compile(r'[a-z0-9]{5}-8i9sb-[a-z0-9]{15}')
35 container_uuid_pattern = re.compile(r'[a-z0-9]{5}-dz642-[a-z0-9]{15}')
36 manifest_pattern = re.compile(r'((\S+)( +[a-f0-9]{32}(\+\d+)(\+\S+)*)+( +\d+:\d+:\S+)+$)+', flags=re.MULTILINE)
37
38 def clear_tmpdir(path=None):
39     """
40     Ensure the given directory (or TASK_TMPDIR if none given)
41     exists and is empty.
42     """
43     if path is None:
44         path = arvados.current_task().tmpdir
45     if os.path.exists(path):
46         p = subprocess.Popen(['rm', '-rf', path])
47         stdout, stderr = p.communicate(None)
48         if p.returncode != 0:
49             raise Exception('rm -rf %s: %s' % (path, stderr))
50     os.mkdir(path)
51
52 def run_command(execargs, **kwargs):
53     kwargs.setdefault('stdin', subprocess.PIPE)
54     kwargs.setdefault('stdout', subprocess.PIPE)
55     kwargs.setdefault('stderr', sys.stderr)
56     kwargs.setdefault('close_fds', True)
57     kwargs.setdefault('shell', False)
58     p = subprocess.Popen(execargs, **kwargs)
59     stdoutdata, stderrdata = p.communicate(None)
60     if p.returncode != 0:
61         raise arvados.errors.CommandFailedError(
62             "run_command %s exit %d:\n%s" %
63             (execargs, p.returncode, stderrdata))
64     return stdoutdata, stderrdata
65
66 def git_checkout(url, version, path):
67     if not re.search('^/', path):
68         path = os.path.join(arvados.current_job().tmpdir, path)
69     if not os.path.exists(path):
70         run_command(["git", "clone", url, path],
71                     cwd=os.path.dirname(path))
72     run_command(["git", "checkout", version],
73                 cwd=path)
74     return path
75
76 def tar_extractor(path, decompress_flag):
77     return subprocess.Popen(["tar",
78                              "-C", path,
79                              ("-x%sf" % decompress_flag),
80                              "-"],
81                             stdout=None,
82                             stdin=subprocess.PIPE, stderr=sys.stderr,
83                             shell=False, close_fds=True)
84
85 def tarball_extract(tarball, path):
86     """Retrieve a tarball from Keep and extract it to a local
87     directory.  Return the absolute path where the tarball was
88     extracted. If the top level of the tarball contained just one
89     file or directory, return the absolute path of that single
90     item.
91
92     tarball -- collection locator
93     path -- where to extract the tarball: absolute, or relative to job tmp
94     """
95     if not re.search('^/', path):
96         path = os.path.join(arvados.current_job().tmpdir, path)
97     lockfile = open(path + '.lock', 'w')
98     fcntl.flock(lockfile, fcntl.LOCK_EX)
99     try:
100         os.stat(path)
101     except OSError:
102         os.mkdir(path)
103     already_have_it = False
104     try:
105         if os.readlink(os.path.join(path, '.locator')) == tarball:
106             already_have_it = True
107     except OSError:
108         pass
109     if not already_have_it:
110
111         # emulate "rm -f" (i.e., if the file does not exist, we win)
112         try:
113             os.unlink(os.path.join(path, '.locator'))
114         except OSError:
115             if os.path.exists(os.path.join(path, '.locator')):
116                 os.unlink(os.path.join(path, '.locator'))
117
118         for f in CollectionReader(tarball).all_files():
119             if re.search('\.(tbz|tar.bz2)$', f.name()):
120                 p = tar_extractor(path, 'j')
121             elif re.search('\.(tgz|tar.gz)$', f.name()):
122                 p = tar_extractor(path, 'z')
123             elif re.search('\.tar$', f.name()):
124                 p = tar_extractor(path, '')
125             else:
126                 raise arvados.errors.AssertionError(
127                     "tarball_extract cannot handle filename %s" % f.name())
128             while True:
129                 buf = f.read(2**20)
130                 if len(buf) == 0:
131                     break
132                 p.stdin.write(buf)
133             p.stdin.close()
134             p.wait()
135             if p.returncode != 0:
136                 lockfile.close()
137                 raise arvados.errors.CommandFailedError(
138                     "tar exited %d" % p.returncode)
139         os.symlink(tarball, os.path.join(path, '.locator'))
140     tld_extracts = [f for f in os.listdir(path) if f != '.locator']
141     lockfile.close()
142     if len(tld_extracts) == 1:
143         return os.path.join(path, tld_extracts[0])
144     return path
145
146 def zipball_extract(zipball, path):
147     """Retrieve a zip archive from Keep and extract it to a local
148     directory.  Return the absolute path where the archive was
149     extracted. If the top level of the archive contained just one
150     file or directory, return the absolute path of that single
151     item.
152
153     zipball -- collection locator
154     path -- where to extract the archive: absolute, or relative to job tmp
155     """
156     if not re.search('^/', path):
157         path = os.path.join(arvados.current_job().tmpdir, path)
158     lockfile = open(path + '.lock', 'w')
159     fcntl.flock(lockfile, fcntl.LOCK_EX)
160     try:
161         os.stat(path)
162     except OSError:
163         os.mkdir(path)
164     already_have_it = False
165     try:
166         if os.readlink(os.path.join(path, '.locator')) == zipball:
167             already_have_it = True
168     except OSError:
169         pass
170     if not already_have_it:
171
172         # emulate "rm -f" (i.e., if the file does not exist, we win)
173         try:
174             os.unlink(os.path.join(path, '.locator'))
175         except OSError:
176             if os.path.exists(os.path.join(path, '.locator')):
177                 os.unlink(os.path.join(path, '.locator'))
178
179         for f in CollectionReader(zipball).all_files():
180             if not re.search('\.zip$', f.name()):
181                 raise arvados.errors.NotImplementedError(
182                     "zipball_extract cannot handle filename %s" % f.name())
183             zip_filename = os.path.join(path, os.path.basename(f.name()))
184             zip_file = open(zip_filename, 'wb')
185             while True:
186                 buf = f.read(2**20)
187                 if len(buf) == 0:
188                     break
189                 zip_file.write(buf)
190             zip_file.close()
191
192             p = subprocess.Popen(["unzip",
193                                   "-q", "-o",
194                                   "-d", path,
195                                   zip_filename],
196                                  stdout=None,
197                                  stdin=None, stderr=sys.stderr,
198                                  shell=False, close_fds=True)
199             p.wait()
200             if p.returncode != 0:
201                 lockfile.close()
202                 raise arvados.errors.CommandFailedError(
203                     "unzip exited %d" % p.returncode)
204             os.unlink(zip_filename)
205         os.symlink(zipball, os.path.join(path, '.locator'))
206     tld_extracts = [f for f in os.listdir(path) if f != '.locator']
207     lockfile.close()
208     if len(tld_extracts) == 1:
209         return os.path.join(path, tld_extracts[0])
210     return path
211
212 def collection_extract(collection, path, files=[], decompress=True):
213     """Retrieve a collection from Keep and extract it to a local
214     directory.  Return the absolute path where the collection was
215     extracted.
216
217     collection -- collection locator
218     path -- where to extract: absolute, or relative to job tmp
219     """
220     matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
221     if matches:
222         collection_hash = matches.group(1)
223     else:
224         collection_hash = hashlib.md5(collection).hexdigest()
225     if not re.search('^/', path):
226         path = os.path.join(arvados.current_job().tmpdir, path)
227     lockfile = open(path + '.lock', 'w')
228     fcntl.flock(lockfile, fcntl.LOCK_EX)
229     try:
230         os.stat(path)
231     except OSError:
232         os.mkdir(path)
233     already_have_it = False
234     try:
235         if os.readlink(os.path.join(path, '.locator')) == collection_hash:
236             already_have_it = True
237     except OSError:
238         pass
239
240     # emulate "rm -f" (i.e., if the file does not exist, we win)
241     try:
242         os.unlink(os.path.join(path, '.locator'))
243     except OSError:
244         if os.path.exists(os.path.join(path, '.locator')):
245             os.unlink(os.path.join(path, '.locator'))
246
247     files_got = []
248     for s in CollectionReader(collection).all_streams():
249         stream_name = s.name()
250         for f in s.all_files():
251             if (files == [] or
252                 ((f.name() not in files_got) and
253                  (f.name() in files or
254                   (decompress and f.decompressed_name() in files)))):
255                 outname = f.decompressed_name() if decompress else f.name()
256                 files_got += [outname]
257                 if os.path.exists(os.path.join(path, stream_name, outname)):
258                     continue
259                 mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
260                 outfile = open(os.path.join(path, stream_name, outname), 'wb')
261                 for buf in (f.readall_decompressed() if decompress
262                             else f.readall()):
263                     outfile.write(buf)
264                 outfile.close()
265     if len(files_got) < len(files):
266         raise arvados.errors.AssertionError(
267             "Wanted files %s but only got %s from %s" %
268             (files, files_got,
269              [z.name() for z in CollectionReader(collection).all_files()]))
270     os.symlink(collection_hash, os.path.join(path, '.locator'))
271
272     lockfile.close()
273     return path
274
275 def mkdir_dash_p(path):
276     if not os.path.isdir(path):
277         try:
278             os.makedirs(path)
279         except OSError as e:
280             if e.errno == errno.EEXIST and os.path.isdir(path):
281                 # It is not an error if someone else creates the
282                 # directory between our exists() and makedirs() calls.
283                 pass
284             else:
285                 raise
286
287 def stream_extract(stream, path, files=[], decompress=True):
288     """Retrieve a stream from Keep and extract it to a local
289     directory.  Return the absolute path where the stream was
290     extracted.
291
292     stream -- StreamReader object
293     path -- where to extract: absolute, or relative to job tmp
294     """
295     if not re.search('^/', path):
296         path = os.path.join(arvados.current_job().tmpdir, path)
297     lockfile = open(path + '.lock', 'w')
298     fcntl.flock(lockfile, fcntl.LOCK_EX)
299     try:
300         os.stat(path)
301     except OSError:
302         os.mkdir(path)
303
304     files_got = []
305     for f in stream.all_files():
306         if (files == [] or
307             ((f.name() not in files_got) and
308              (f.name() in files or
309               (decompress and f.decompressed_name() in files)))):
310             outname = f.decompressed_name() if decompress else f.name()
311             files_got += [outname]
312             if os.path.exists(os.path.join(path, outname)):
313                 os.unlink(os.path.join(path, outname))
314             mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
315             outfile = open(os.path.join(path, outname), 'wb')
316             for buf in (f.readall_decompressed() if decompress
317                         else f.readall()):
318                 outfile.write(buf)
319             outfile.close()
320     if len(files_got) < len(files):
321         raise arvados.errors.AssertionError(
322             "Wanted files %s but only got %s from %s" %
323             (files, files_got, [z.name() for z in stream.all_files()]))
324     lockfile.close()
325     return path
326
327 def listdir_recursive(dirname, base=None, max_depth=None):
328     """listdir_recursive(dirname, base, max_depth)
329
330     Return a list of file and directory names found under dirname.
331
332     If base is not None, prepend "{base}/" to each returned name.
333
334     If max_depth is None, descend into directories and return only the
335     names of files found in the directory tree.
336
337     If max_depth is a non-negative integer, stop descending into
338     directories at the given depth, and at that point return directory
339     names instead.
340
341     If max_depth==0 (and base is None) this is equivalent to
342     sorted(os.listdir(dirname)).
343     """
344     allfiles = []
345     for ent in sorted(os.listdir(dirname)):
346         ent_path = os.path.join(dirname, ent)
347         ent_base = os.path.join(base, ent) if base else ent
348         if os.path.isdir(ent_path) and max_depth != 0:
349             allfiles += listdir_recursive(
350                 ent_path, base=ent_base,
351                 max_depth=(max_depth-1 if max_depth else None))
352         else:
353             allfiles += [ent_base]
354     return allfiles
355
356 def is_hex(s, *length_args):
357     """is_hex(s[, length[, max_length]]) -> boolean
358
359     Return True if s is a string of hexadecimal digits.
360     If one length argument is given, the string must contain exactly
361     that number of digits.
362     If two length arguments are given, the string must contain a number of
363     digits between those two lengths, inclusive.
364     Return False otherwise.
365     """
366     num_length_args = len(length_args)
367     if num_length_args > 2:
368         raise arvados.errors.ArgumentError(
369             "is_hex accepts up to 3 arguments ({} given)".format(1 + num_length_args))
370     elif num_length_args == 2:
371         good_len = (length_args[0] <= len(s) <= length_args[1])
372     elif num_length_args == 1:
373         good_len = (len(s) == length_args[0])
374     else:
375         good_len = True
376     return bool(good_len and HEX_RE.match(s))
377
378 def list_all(fn, num_retries=0, **kwargs):
379     # Default limit to (effectively) api server's MAX_LIMIT
380     kwargs.setdefault('limit', sys.maxsize)
381     items = []
382     offset = 0
383     items_available = sys.maxsize
384     while len(items) < items_available:
385         c = fn(offset=offset, **kwargs).execute(num_retries=num_retries)
386         items += c['items']
387         items_available = c['items_available']
388         offset = c['offset'] + len(c['items'])
389     return items
390
391 def keyset_list_all(fn, order_key="created_at", num_retries=0, ascending=True, **kwargs):
392     pagesize = 1000
393     kwargs["limit"] = pagesize
394     kwargs["count"] = 'none'
395     asc = "asc" if ascending else "desc"
396     kwargs["order"] = ["%s %s" % (order_key, asc), "uuid %s" % asc]
397     other_filters = kwargs.get("filters", [])
398
399     if "select" in kwargs and "uuid" not in kwargs["select"]:
400         kwargs["select"].append("uuid")
401
402     nextpage = []
403     tot = 0
404     expect_full_page = True
405     seen_prevpage = set()
406     seen_thispage = set()
407     lastitem = None
408     prev_page_all_same_order_key = False
409
410     while True:
411         kwargs["filters"] = nextpage+other_filters
412         items = fn(**kwargs).execute(num_retries=num_retries)
413
414         if len(items["items"]) == 0:
415             if prev_page_all_same_order_key:
416                 nextpage = [[order_key, ">" if ascending else "<", lastitem[order_key]]]
417                 prev_page_all_same_order_key = False
418                 continue
419             else:
420                 return
421
422         seen_prevpage = seen_thispage
423         seen_thispage = set()
424
425         for i in items["items"]:
426             # In cases where there's more than one record with the
427             # same order key, the result could include records we
428             # already saw in the last page.  Skip them.
429             if i["uuid"] in seen_prevpage:
430                 continue
431             seen_thispage.add(i["uuid"])
432             yield i
433
434         firstitem = items["items"][0]
435         lastitem = items["items"][-1]
436
437         if firstitem[order_key] == lastitem[order_key]:
438             # Got a page where every item has the same order key.
439             # Switch to using uuid for paging.
440             nextpage = [[order_key, "=", lastitem[order_key]], ["uuid", ">" if ascending else "<", lastitem["uuid"]]]
441             prev_page_all_same_order_key = True
442         else:
443             # Start from the last order key seen, but skip the last
444             # known uuid to avoid retrieving the same row twice.  If
445             # there are multiple rows with the same order key it is
446             # still likely we'll end up retrieving duplicate rows.
447             # That's handled by tracking the "seen" rows for each page
448             # so they can be skipped if they show up on the next page.
449             nextpage = [[order_key, ">=" if ascending else "<=", lastitem[order_key]], ["uuid", "!=", lastitem["uuid"]]]
450             prev_page_all_same_order_key = False
451
452
453 def ca_certs_path(fallback=httplib2.CA_CERTS):
454     """Return the path of the best available CA certs source.
455
456     This function searches for various distribution sources of CA
457     certificates, and returns the first it finds.  If it doesn't find any,
458     it returns the value of `fallback` (httplib2's CA certs by default).
459     """
460     for ca_certs_path in [
461         # SSL_CERT_FILE and SSL_CERT_DIR are openssl overrides - note
462         # that httplib2 itself also supports HTTPLIB2_CA_CERTS.
463         os.environ.get('SSL_CERT_FILE'),
464         # Arvados specific:
465         '/etc/arvados/ca-certificates.crt',
466         # Debian:
467         '/etc/ssl/certs/ca-certificates.crt',
468         # Red Hat:
469         '/etc/pki/tls/certs/ca-bundle.crt',
470         ]:
471         if ca_certs_path and os.path.exists(ca_certs_path):
472             return ca_certs_path
473     return fallback
474
475 def new_request_id():
476     rid = "req-"
477     # 2**104 > 36**20 > 2**103
478     n = random.getrandbits(104)
479     for _ in range(20):
480         c = n % 36
481         if c < 10:
482             rid += chr(c+ord('0'))
483         else:
484             rid += chr(c+ord('a')-10)
485         n = n // 36
486     return rid
487
488 def get_config_once(svc):
489     if not svc._rootDesc.get('resources').get('configs', False):
490         # Old API server version, no config export endpoint
491         return {}
492     if not hasattr(svc, '_cached_config'):
493         svc._cached_config = svc.configs().get().execute()
494     return svc._cached_config
495
496 def get_vocabulary_once(svc):
497     if not svc._rootDesc.get('resources').get('vocabularies', False):
498         # Old API server version, no vocabulary export endpoint
499         return {}
500     if not hasattr(svc, '_cached_vocabulary'):
501         svc._cached_vocabulary = svc.vocabularies().get().execute()
502     return svc._cached_vocabulary
503
504 def trim_name(collectionname):
505     """
506     trim_name takes a record name (collection name, project name, etc)
507     and trims it to fit the 255 character name limit, with additional
508     space for the timestamp added by ensure_unique_name, by removing
509     excess characters from the middle and inserting an ellipse
510     """
511
512     max_name_len = 254 - 28
513
514     if len(collectionname) > max_name_len:
515         over = len(collectionname) - max_name_len
516         split = int(max_name_len/2)
517         collectionname = collectionname[0:split] + "…" + collectionname[split+over:]
518
519     return collectionname