21718: Replace .decode method with str(bytes, "utf-8")
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import fnmatch
6 import os
7 import errno
8 import urllib.parse
9 import re
10 import logging
11 import threading
12 from collections import OrderedDict
13 from io import StringIO
14
15 import ruamel.yaml
16
17 import cwltool.stdfsaccess
18 from cwltool.pathmapper import abspath
19 import cwltool.resolver
20
21 import arvados.util
22 import arvados.collection
23 import arvados.arvfile
24 import arvados.errors
25
26 from googleapiclient.errors import HttpError
27
28 from schema_salad.ref_resolver import DefaultFetcher
29
# Recognizes a portable data hash followed by its size hint and optional
# extra "+hint" suffixes; group(1) is the 32-hex-digit hash and group(2)
# is the size hint.
pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')

# Shared logger for the Arvados CWL runner.
logger = logging.getLogger('arvados.cwl-runner')
33
class CollectionCache(object):
    """LRU cache of ``arvados.collection.CollectionReader`` objects by locator.

    Each entry's cost is estimated as 128 times the length of its manifest
    text. Before loading a locator that carries a size hint (a portable data
    hash with ``+<size>``), least recently used entries are evicted until the
    estimated cost fits under ``cap``. Eviction stops early once fewer than
    ``min_entries`` entries remain. ``get`` runs under ``self.lock``.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries
        # locator -> (CollectionReader, estimated size); iteration order is
        # recency order, least recently used first.
        self.collections = OrderedDict()
        self.lock = threading.Lock()
        # Sum of the estimated sizes of all cached entries.
        self.total = 0
        self.cap = cap
        self.min_entries = min_entries

    def set_cap(self, cap):
        """Change the size budget; existing entries are not evicted here."""
        self.cap = cap

    def cap_cache(self, required):
        """Evict least recently used entries until ``required`` units fit.

        Stops as soon as ``cap - total >= required`` or fewer than
        ``min_entries`` entries remain. This method does not take
        ``self.lock`` itself; ``get`` calls it while holding the lock.
        """
        while self.collections:
            available = self.cap - self.total
            if available >= required or len(self.collections) < self.min_entries:
                return
            # cut loose the oldest entry
            pdh, (_, sz) = self.collections.popitem(last=False)
            logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", pdh, self.cap, self.total, required)
            self.total -= sz

    def get(self, locator):
        """Return a CollectionReader for ``locator``, loading it on a miss.

        :param locator: a Keep locator / portable data hash or collection UUID.
        :raises IOError: with ``errno.ENOENT`` if the API server rejects the
            request (the original ``ApiError`` is chained as the cause).
        """
        with self.lock:
            entry = self.collections.get(locator)
            if entry is not None:
                # Cache hit: mark as most recently used.
                self.collections.move_to_end(locator)
                return entry[0]

            m = pdh_size.match(locator)
            if m:
                # Use the locator's size hint (scaled like the estimate
                # below) to make room before loading.
                self.cap_cache(int(m.group(2)) * 128)
            logger.debug("Creating collection reader for %s", locator)
            try:
                cr = arvados.collection.CollectionReader(locator, api_client=self.api_client,
                                                         keep_client=self.keep_client,
                                                         num_retries=self.num_retries)
            except arvados.errors.ApiError as ap:
                raise IOError(errno.ENOENT, "Could not access collection '%s': %s" % (locator, str(ap._get_reason()))) from ap
            sz = len(cr.manifest_text()) * 128
            self.collections[locator] = (cr, sz)
            self.total += sz
            return cr
83
84
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<locator>[/<path>]``, where ``<locator>`` is a
    Keep locator or a collection UUID, are served through the collection
    cache. Every other path is delegated to cwltool's ``StdFsAccess``.
    """

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        # Object whose get(locator) returns a collection reader for keep: paths.
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split ``path`` into ``(collection, rest)``.

        For a ``keep:`` path, ``collection`` is the reader returned by the
        collection cache. ``rest`` is the URL-unquoted, ``normpath``-normalized
        path inside the collection, or ``None`` when nothing follows the
        locator. For any other path, returns ``(None, path)``.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and (arvados.util.keep_locator_pattern.match(p[5:]) or
                                      arvados.util.collection_uuid_pattern.match(p[5:])):
            locator = p[5:]
            rest = os.path.normpath(urllib.parse.unquote(sp[1])) if len(sp) == 2 else None
            return (self.collection_cache.get(locator), rest)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Return paths under ``collection`` that match ``patternsegments``.

        ``patternsegments`` is a glob pattern already split on "/"; each
        segment is matched with ``fnmatch`` against one directory level.
        ``parent`` is the path prefix joined onto every returned path.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            # Not a (sub)collection, so nothing beneath it can match.
            return []

        head, tail = patternsegments[0], patternsegments[1:]
        if head == '.':
            # Pattern contains something like "./foo": skip past the "./"
            # once. Doing this per directory entry would repeat every match
            # once for each entry in the collection.
            return self._match(collection, tail, parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if not fnmatch.fnmatch(filename, head):
                continue
            cur = os.path.join(parent, filename)
            if tail:
                ret.extend(self._match(collection[filename], tail, cur))
            else:
                ret.append(cur)
        return ret

    def glob(self, pattern):
        """Expand a glob ``pattern`` into a sorted list of matching paths.

        Non-``keep:`` patterns are delegated to ``StdFsAccess.glob``, as every
        other method in this class does for non-Keep paths.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            return super(CollectionFsAccess, self).glob(pattern)
        if rest in (None, "", "."):
            # The pattern names the collection root itself.
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode, encoding=None):
        """Open ``fn``; ``encoding`` is only forwarded for ``keep:`` paths."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode, encoding=encoding)
        else:
            # NOTE(review): encoding is not passed here, presumably because
            # StdFsAccess.open takes only (fn, mode) — confirm.
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return True if ``fn`` exists.

        A ``keep:`` path whose collection cannot be fetched (HTTP 404, or an
        ENOENT IOError from the cache) is reported as missing. Other errors
        propagate.
        """
        try:
            collection, rest = self.get_collection(fn)
        except HttpError as err:
            if err.resp.status == 404:
                return False
            else:
                raise
        except IOError as err:
            if err.errno == errno.ENOENT:
                return False
            else:
                raise
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def size(self, fn):  # type: (str) -> int
        """Return the size of file ``fn``.

        :raises IOError: with ``errno.EINVAL`` if a ``keep:`` path does not
            name a file inside the collection.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                arvfile = collection.find(rest)
                if isinstance(arvfile, arvados.arvfile.ArvadosFile):
                    return arvfile.size()
            raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
        else:
            return super(CollectionFsAccess, self).size(fn)

    def isfile(self, fn):  # type: (str) -> bool
        """Return True if ``fn`` is a file; a collection root never is."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (str) -> bool
        """Return True if ``fn`` is a directory (collection root or subcollection)."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (str) -> List[str]
        """Return absolute paths of the entries in directory ``fn``.

        :raises IOError: with ``errno.ENOENT`` if a ``keep:`` path is missing
            or is not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is None:
            return super(CollectionFsAccess, self).listdir(fn)
        node = collection.find(rest) if rest else collection
        if node is None:
            raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
        if not isinstance(node, arvados.collection.RichCollectionBase):
            raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
        return [abspath(entry, fn) for entry in node.keys()]

    def join(self, path, *paths): # type: (str, *str) -> str
        """Join path components; a trailing ``keep:<locator>`` component wins."""
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return ``path`` unchanged for task placeholders and ``keep:`` paths.

        Other paths are resolved with ``os.path.realpath``. Note that a
        ``keep:`` path is looked up through the collection cache first.
        """
        if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
            return path
        collection, _ = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
219
class CollectionFetcher(DefaultFetcher):
    """schema-salad fetcher that also understands ``keep:`` and ``arvwf:`` URLs.

    ``keep:`` documents are read through ``fs_access``. ``arvwf:<uuid>``
    documents are the stored ``definition`` of an Arvados workflow record.
    All other URLs are handled by ``DefaultFetcher``.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries

    def fetch_text(self, url, content_types=None):
        """Return the text of the document at ``url``.

        For ``arvwf:`` URLs, the workflow's YAML ``definition`` is returned
        with its ``label`` set to the workflow record's ``name``.
        ``content_types`` is forwarded to ``DefaultFetcher`` for other URLs.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r", encoding="utf-8") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            # 'rt' (round-trip) mode so the re-dumped document stays close to
            # the stored text.
            yaml = ruamel.yaml.YAML(typ='rt', pure=True)
            definition = yaml.load(record["definition"])
            definition["label"] = record["name"]
            stream = StringIO()
            yaml.dump(definition, stream)
            return stream.getvalue()
        # Pass the caller's content_types through instead of dropping it.
        return super(CollectionFetcher, self).fetch_text(url, content_types)

    def check_exists(self, url):
        """Return True if the document at ``url`` can be fetched.

        URLs under ``http://arvados.org/cwl`` are always accepted.
        Unexpected errors are logged and reported as missing.
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            urld, _ = urllib.parse.urldefrag(url)
            if urld.startswith("keep:"):
                return self.fsaccess.exists(urld)
            if urld.startswith("arvwf:"):
                if self.fetch_text(urld):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            logger.exception("Got unexpected exception checking if file exists")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` relative to ``base_url``.

        For ``keep:``/``arvwf:`` bases, the first path component (the locator)
        is always kept:

        - An absolute ``url`` path replaces everything after the locator.
        - A relative path replaces the last base component.
        - A fragment-only ``url`` keeps the base path.

        :raises IOError: with ``errno.EINVAL`` if the base has no path, or if
            a ``keep:`` base does not start with a Keep locator or collection
            UUID.
        """
        if not url:
            return base_url

        urlsp = urllib.parse.urlsplit(url)
        if urlsp.scheme or not base_url:
            return url

        basesp = urllib.parse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            locator = baseparts.pop(0)

            if (basesp.scheme == "keep" and
                (not arvados.util.keep_locator_pattern.match(locator)) and
                (not arvados.util.collection_uuid_pattern.match(locator))):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # Absolute path: drop the base path and the leading "" part.
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Relative path: replace the base document's own name.
                baseparts.pop()

            path = "/".join([locator] + baseparts + urlparts)
            return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)

    schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]

    def supported_schemes(self):  # type: () -> List[Text]
        """Return the URL schemes this fetcher accepts."""
        return self.schemes
297
298
# Arvados object UUIDs have the shape "<5-char cluster>-<type infix>-<15 chars>".
# The infix is "p5p6p" for pipeline templates and "7fd4e" for workflows.
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
301
def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Resolve a tool reference to a URI that the CWL loader can fetch.

    Cases, checked in order:

    - ``keep:`` / ``arvwf:`` URIs are returned unchanged.
    - A workflow UUID becomes ``arvwf:<uuid>#main``.
    - A pipeline template UUID resolves to the ``cwl:tool`` script parameter
      of the template's first component, prefixed with ``keep:``.
    - A path starting with a Keep locator is prefixed with ``keep:``.
    - A path starting with a collection UUID is rewritten to use that
      collection's portable data hash.

    Anything else goes to cwltool's default ``tool_resolver``.
    """
    if uri.startswith(("keep:", "arvwf:")):
        return str(uri)

    if workflow_uuid_pattern.match(uri):
        return u"arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # dict views are not iterators, so next() needs an explicit iter().
        first_component = next(iter(pt["components"].values()))
        return u"keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return u"keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        record = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return u"keep:%s%s" % (record["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)