19688: Make registered workflows lightweight wrappers
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import object
8 from builtins import str
9 from future.utils import viewvalues
10
11 import fnmatch
12 import os
13 import errno
14 import urllib.parse
15 import re
16 import logging
17 import threading
18 from collections import OrderedDict
19
20 import ruamel.yaml as yaml
21
22 import cwltool.stdfsaccess
23 from cwltool.pathmapper import abspath
24 import cwltool.resolver
25
26 import arvados.util
27 import arvados.collection
28 import arvados.arvfile
29 import arvados.errors
30
31 from googleapiclient.errors import HttpError
32
33 from schema_salad.ref_resolver import DefaultFetcher
34
# Module-wide logger; all messages from this file go to 'arvados.cwl-runner'.
logger = logging.getLogger('arvados.cwl-runner')

# Matches a locator of the form "<32 hex digits>+<decimal number>[+<hint>...]".
# Group 1 is the hex digest and group 2 the decimal number; CollectionCache.get
# reads group 2 to estimate how much cache space a new reader will need.
pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')
38
class CollectionCache(object):
    """Size-bounded cache of CollectionReader objects keyed by locator.

    Entries are stored as ``(reader, size)`` tuples in an OrderedDict that is
    kept in least-recently-used-first order: a cache hit moves the entry to
    the back, and eviction walks the dict from the front.  ``total`` is the
    sum of the recorded sizes and ``cap`` is the budget eviction aims for.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        # Eviction policy.
        self.cap = cap
        self.min_entries = min_entries
        # Bookkeeping, guarded by self.lock inside get().
        self.total = 0
        self.collections = OrderedDict()
        self.lock = threading.Lock()
        # Passed through to every CollectionReader that gets created.
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries

    def set_cap(self, cap):
        """Replace the size budget; nothing is evicted until cap_cache() runs."""
        self.cap = cap

    def cap_cache(self, required):
        """Evict oldest entries until ``required`` units fit under the cap.

        Eviction also stops as soon as the cache holds fewer than
        ``min_entries`` entries, even if the budget is still exceeded.
        """
        # Iterate over a snapshot (oldest first) because entries are removed
        # from the dict while looping.
        for pdh, (_reader, entry_size) in list(self.collections.items()):
            has_room = (self.cap - self.total) >= required
            too_few_entries = len(self.collections) < self.min_entries
            if has_room or too_few_entries:
                return
            logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", pdh, self.cap, self.total, required)
            self.collections.pop(pdh)
            self.total -= entry_size

    def get(self, locator):
        """Return the CollectionReader for ``locator``, creating it on a miss.

        Raises IOError(ENOENT) when the API server reports an error while the
        reader is being created.
        """
        with self.lock:
            if locator in self.collections:
                # Cache hit: re-insert so the entry becomes the newest one.
                entry = self.collections.pop(locator)
                self.collections[locator] = entry
                return entry[0]

            # Cache miss: when the locator carries a size field, make room
            # for the new reader before creating it.
            size_match = pdh_size.match(locator)
            if size_match is not None:
                self.cap_cache(int(size_match.group(2)) * 128)
            logger.debug("Creating collection reader for %s", locator)
            try:
                reader = arvados.collection.CollectionReader(
                    locator,
                    api_client=self.api_client,
                    keep_client=self.keep_client,
                    num_retries=self.num_retries)
            except arvados.errors.ApiError as ap:
                raise IOError(errno.ENOENT, "Could not access collection '%s': %s" % (locator, str(ap._get_reason())))
            # Recorded size is the manifest text length times 128.
            recorded_size = len(reader.manifest_text()) * 128
            self.collections[locator] = (reader, recorded_size)
            self.total += recorded_size
            return reader
88
89
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form "keep:<locator>[/<path>]" are served from collection
    readers obtained through ``collection_cache``; every other path is
    delegated to the base class implementation.
    """

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split ``path`` into ``(collection, rest)``.

        For a "keep:" path whose first segment is a Keep locator or a
        collection UUID, ``collection`` is the cached reader and ``rest`` is
        the unquoted, normalized path inside it (None when there is no "/"
        after the locator).  For anything else the result is ``(None, path)``.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and (arvados.util.keep_locator_pattern.match(p[5:]) or
                                      arvados.util.collection_uuid_pattern.match(p[5:])):
            locator = p[5:]
            rest = os.path.normpath(urllib.parse.unquote(sp[1])) if len(sp) == 2 else None
            return (self.collection_cache.get(locator), rest)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Return paths under ``collection`` matching ``patternsegments``.

        ``patternsegments`` is a glob pattern already split on "/" and
        ``parent`` is the path prefix prepended to each match.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo" so just shift past the
            # "./".  This must happen once, outside the loop below: doing it
            # per entry would return every match once for each item in the
            # collection.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Return the sorted list of paths matching ``pattern``."""
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a keep: reference; let the base class handle it, as the
            # other methods of this class do.
            return super(CollectionFsAccess, self).glob(pattern)
        if rest in (None, "", "."):
            # The pattern names the collection root itself.
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode, encoding=None):
        """Open ``fn`` from its collection, or via the base class otherwise.

        ``encoding`` is only passed on when opening from a collection.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode, encoding=encoding)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return True if ``fn`` exists.

        A 404 from the API or an ENOENT IOError while looking up the
        collection is reported as "does not exist"; other errors propagate.
        """
        try:
            collection, rest = self.get_collection(fn)
        except HttpError as err:
            if err.resp.status == 404:
                return False
            else:
                raise
        except IOError as err:
            if err.errno == errno.ENOENT:
                return False
            else:
                raise
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def size(self, fn):  # type: (unicode) -> int
        """Return the size of the file ``fn``.

        Raises IOError(EINVAL) when ``fn`` is inside a collection but does
        not name a file.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                arvfile = collection.find(rest)
                if isinstance(arvfile, arvados.arvfile.ArvadosFile):
                    return arvfile.size()
            raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
        else:
            return super(CollectionFsAccess, self).size(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return True if ``fn`` names a file; a collection root is not one."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return True if ``fn`` names a directory or a collection root."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """Return the entries of directory ``fn``, each resolved against ``fn``.

        Raises IOError(ENOENT) when the path is missing from the collection
        or is not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                target = collection.find(rest)
            else:
                target = collection
            if target is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(target, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(name, fn) for name in target.keys()]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components; a final "keep:<locator>" component wins outright."""
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return ``path`` unchanged for task placeholders and keep: paths.

        Any other path is resolved with os.path.realpath.
        """
        if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
224
class CollectionFetcher(DefaultFetcher):
    """Fetcher that adds the "keep:" and "arvwf:" URL schemes.

    "keep:" URLs are read through ``fs_access``; "arvwf:" URLs are resolved
    to workflow records through ``api_client``.  Every other URL is handled
    by DefaultFetcher.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries

    def fetch_text(self, url, content_types=None):
        """Return the text behind ``url``.

        For "arvwf:" URLs the workflow record's definition is loaded, its
        "label" is set to the record's name, and the result is dumped back
        to text.  ``content_types`` is only meaningful for URLs handled by
        the base class and is passed through to it.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r", encoding="utf-8") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            definition = yaml.round_trip_load(record["definition"])
            definition["label"] = record["name"]
            return yaml.round_trip_dump(definition)
        # Forward content_types instead of silently dropping the caller's value.
        return super(CollectionFetcher, self).fetch_text(url, content_types)

    def check_exists(self, url):
        """Return True if ``url`` can be resolved.

        Any exception raised while checking is treated as "does not exist";
        unexpected ones are logged with their traceback first.
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            # Drop the "#fragment" before checking the underlying document.
            urld, _ = urllib.parse.urldefrag(url)
            if urld.startswith("keep:"):
                return self.fsaccess.exists(urld)
            if urld.startswith("arvwf:"):
                if self.fetch_text(urld):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            logger.exception("Got unexpected exception checking if file exists")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` relative to ``base_url``.

        "keep:" and "arvwf:" bases are joined here, keeping the first path
        segment of the base (the locator) fixed; other bases are delegated
        to the base class.  Raises IOError(EINVAL) when the base has no path
        or, for "keep:", when the locator segment is not a Keep locator or a
        collection UUID.
        """
        if not url:
            return base_url

        urlsp = urllib.parse.urlsplit(url)
        if urlsp.scheme or not base_url:
            # url is already absolute, or there is nothing to resolve against.
            return url

        basesp = urllib.parse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            locator = baseparts.pop(0)

            if (basesp.scheme == "keep" and
                (not arvados.util.keep_locator_pattern.match(locator)) and
                (not arvados.util.collection_uuid_pattern.match(locator))):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # A leading "/" makes url relative to the collection root:
                # discard the base path and the empty first segment of url.
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Relative reference: replace the last segment of the base.
                baseparts.pop()

            path = "/".join([locator] + baseparts + urlparts)
            return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)

    schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]

    def supported_schemes(self):  # type: () -> List[Text]
        """Return the list of URL schemes this fetcher accepts."""
        return self.schemes
299
300
# UUID shapes for workflow and pipeline template records: a 5-character
# prefix, a fixed type infix, and a 15-character suffix.
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')


def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Turn a workflow reference into a URI that can be fetched.

    Handled in order:

    * "keep:" / "arvwf:" URIs are returned unchanged (as ``str``).
    * A workflow UUID becomes "arvwf:<uuid>#main".
    * A pipeline template UUID is looked up and becomes "keep:" plus the
      "cwl:tool" script parameter of its first component.
    * "<keep locator>[/path]" becomes "keep:<locator>[/path]".
    * "<collection uuid>[/path]" is looked up and becomes
      "keep:<portable data hash>[/path]".
    * Anything else is passed to cwltool's default tool resolver.

    ``num_retries`` is applied to every API request made here.
    """
    if uri.startswith("keep:") or uri.startswith("arvwf:"):
        return str(uri)

    if workflow_uuid_pattern.match(uri):
        return u"arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # Dict views cannot be subscripted, so materialize the values before
        # taking the first component.
        first_component = list(pt["components"].values())[0]
        return u"keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return u"keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        collection = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return u"keep:%s%s" % (collection["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)