14322: CollectionFsAccess accepts UUIDs
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import object
8 from builtins import str
9 from future.utils import viewvalues
10
11 import fnmatch
12 import os
13 import errno
14 import urllib.parse
15 import re
16 import logging
17 import threading
18 from collections import OrderedDict
19
20 import ruamel.yaml as yaml
21
22 import cwltool.stdfsaccess
23 from cwltool.pathmapper import abspath
24 import cwltool.resolver
25
26 import arvados.util
27 import arvados.collection
28 import arvados.arvfile
29 import arvados.errors
30
31 from googleapiclient.errors import HttpError
32
33 from schema_salad.ref_resolver import DefaultFetcher
34
# Matches a portable data hash locator prefix: group(1) is the 32-hex-digit
# hash, group(2) the "+<size>" number, group(3) the last of any "+hint"s.
pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')

logger = logging.getLogger('arvados.cwl-runner')
38
class CollectionCache(object):
    """LRU cache of CollectionReader objects keyed by the locator passed to get().

    Each entry's cost is estimated as 128 times the length of its manifest
    text.  When room is needed, the oldest entries are evicted until the
    running total fits under ``cap``; eviction stops once fewer than
    ``min_entries`` readers remain cached.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries
        # locator -> (CollectionReader, estimated size), oldest first.
        self.collections = OrderedDict()
        # Held by get() for the whole lookup/load/insert sequence.
        self.lock = threading.Lock()
        self.total = 0
        self.cap = cap
        self.min_entries = min_entries

    def set_cap(self, cap):
        """Change the size cap; applied on the next eviction check."""
        self.cap = cap

    def cap_cache(self, required):
        """Evict oldest entries until ``required`` more bytes fit under the cap.

        Stops early once fewer than ``min_entries`` entries remain.  get()
        calls this while holding ``self.lock``.
        """
        # ordered dict iterates from oldest to newest
        for locator, (_, size) in list(self.collections.items()):
            available = self.cap - self.total
            if available >= required or len(self.collections) < self.min_entries:
                return
            # cut it loose
            logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", locator, self.cap, self.total, required)
            del self.collections[locator]
            self.total -= size

    def get(self, locator):
        """Return a CollectionReader for ``locator``, loading it on a miss.

        ``locator`` may be a portable data hash or a collection UUID.  A hit
        marks the entry as most recently used.
        """
        with self.lock:
            if locator in self.collections:
                cr, sz = self.collections.pop(locator)
                # bump it to the back (most recently used)
                self.collections[locator] = (cr, sz)
                return cr

            m = pdh_size.match(locator)
            if m:
                # A PDH encodes its manifest size, so room can be made before loading.
                self.cap_cache(int(m.group(2)) * 128)
            logger.debug("Creating collection reader for %s", locator)
            cr = arvados.collection.CollectionReader(locator, api_client=self.api_client,
                                                     keep_client=self.keep_client,
                                                     num_retries=self.num_retries)
            sz = len(cr.manifest_text()) * 128
            if not m:
                # Other locators (e.g. collection UUIDs) carry no size; enforce
                # the cap now that the size is known so the cache stays bounded.
                self.cap_cache(sz)
            self.collections[locator] = (cr, sz)
            self.total += sz
            return cr
85
86
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<locator>[/<path>]``, where <locator> is a
    portable data hash or a collection UUID, are served through
    ``collection_cache``.  All other paths are delegated to cwltool's
    StdFsAccess.
    """

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        # Object with get(locator) -> collection reader; needed for keep: paths.
        self.collection_cache = collection_cache

    @staticmethod
    def _is_collection_ref(ref):
        """Return True if ``ref`` is "keep:" followed by a PDH or collection UUID."""
        if not ref.startswith("keep:"):
            return False
        locator = ref[5:]
        return bool(arvados.util.keep_locator_pattern.match(locator) or
                    arvados.util.collection_uuid_pattern.match(locator))

    def get_collection(self, path):
        """Split ``path`` into ``(collection, rest)``.

        For a keep: reference, ``collection`` is the cached reader for the
        locator, and ``rest`` is the URL-unquoted path inside it (None when
        ``path`` names the collection itself).  Any other path yields
        ``(None, path)``.
        """
        sp = path.split("/", 1)
        head = sp[0]
        if not self._is_collection_ref(head):
            return (None, path)
        collection = self.collection_cache.get(head[5:])
        rest = urllib.parse.unquote(sp[1]) if len(sp) == 2 else None
        return (collection, rest)

    def _match(self, collection, patternsegments, parent):
        """Return paths under ``parent`` matching ``patternsegments`` (fnmatch per segment)."""
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo": shift past the "./"
            # once, rather than once per entry in the collection.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Expand ``pattern``; keep: patterns are matched inside the collection.

        A pattern naming a whole collection is returned as-is.  Non-keep
        patterns are delegated to StdFsAccess.glob.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            return super(CollectionFsAccess, self).glob(pattern)
        if not rest:
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode):
        """Open ``fn`` from its collection, or from the local filesystem."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return whether ``fn`` exists; an HTTP 404 looking up the collection means False."""
        try:
            collection, rest = self.get_collection(fn)
        except HttpError as err:
            if err.resp.status == 404:
                return False
            else:
                raise
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def size(self, fn):  # type: (unicode) -> int
        """Return the size of file ``fn``.

        Raises IOError(EINVAL) if a keep: path does not name a file.
        """
        collection, rest = self.get_collection(fn)
        if collection is None:
            return super(CollectionFsAccess, self).size(fn)
        if rest:
            item = collection.find(rest)
            if isinstance(item, arvados.arvfile.ArvadosFile):
                return item.size()
        raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return True if ``fn`` is a file (a bare collection reference is not)."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return True if ``fn`` is a directory (a bare collection reference is)."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """List the entries of directory ``fn`` as paths joined onto ``fn``.

        Raises IOError(ENOENT) if a keep: path is missing or not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is None:
            return super(CollectionFsAccess, self).listdir(fn)
        target = collection.find(rest) if rest else collection
        if target is None:
            raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
        if not isinstance(target, arvados.collection.RichCollectionBase):
            raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
        return [abspath(name, fn) for name in target.keys()]

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components; a trailing keep: collection reference wins outright."""
        if paths and self._is_collection_ref(paths[-1]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Resolve ``path`` locally; keep: paths and task placeholders pass through."""
        if path.startswith(("$(task.tmpdir)", "$(task.outdir)")):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
215
class CollectionFetcher(DefaultFetcher):
    """schema-salad fetcher that also resolves keep: and arvwf: URLs.

    keep: URLs are read through ``fs_access``.  arvwf:<uuid> URLs are loaded
    from the Arvados workflows API.  Everything else goes to DefaultFetcher.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries

    def fetch_text(self, url):
        """Return the text content of ``url``."""
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            # Append the workflow record's name as a label, escaping double quotes.
            definition = record["definition"] + ('\nlabel: "%s"\n' % record["name"].replace('"', '\\"'))
            return definition
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return whether ``url`` can be fetched.

        Lookup failures for Arvados schemes are reported as False (unexpected
        ones are logged) rather than raised.
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            if url.startswith("keep:"):
                return self.fsaccess.exists(url)
            if url.startswith("arvwf:"):
                # Answer directly: DefaultFetcher.check_exists does not
                # handle the arvwf scheme.
                return bool(self.fetch_text(url))
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            logger.exception("Got unexpected exception checking if file exists")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` relative to ``base_url``.

        For keep:/arvwf: bases, the first path segment (the locator) is kept
        fixed, and relative paths are resolved against the remaining
        segments.  Raises IOError(EINVAL) for an empty or invalid keep
        locator.
        """
        if not url:
            return base_url

        urlsp = urllib.parse.urlsplit(url)
        if urlsp.scheme or not base_url:
            return url

        basesp = urllib.parse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            locator = baseparts.pop(0)

            # A keep: base must start with a portable data hash or collection UUID.
            if (basesp.scheme == "keep" and
                (not arvados.util.keep_locator_pattern.match(locator)) and
                (not arvados.util.collection_uuid_pattern.match(locator))):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # Absolute path: resolve from the collection root.
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Drop the base's last segment (the "file" part) before appending.
                baseparts.pop()

            path = "/".join([locator] + baseparts + urlparts)
            return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)

    schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]

    def supported_schemes(self):  # type: () -> List[Text]
        """Return the URL schemes this fetcher accepts."""
        return self.schemes
288
289
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')

def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Resolve a tool reference ``uri`` to a keep: or arvwf: URI.

    - keep:/arvwf: URIs are returned unchanged.
    - A workflow UUID becomes ``arvwf:<uuid>#main``.
    - A pipeline template UUID resolves to the ``cwl:tool`` script parameter
      of its first component.
    - A portable data hash (optionally followed by a path) gets a ``keep:``
      prefix.
    - A collection UUID (optionally followed by a path) is replaced by the
      collection's current portable data hash.
    - Anything else is handed to cwltool's default tool resolver.
    """
    if uri.startswith(("keep:", "arvwf:")):
        return str(uri)

    if workflow_uuid_pattern.match(uri):
        return u"arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # Dict views are not indexable on Python 3; take the first component.
        first_component = next(iter(pt["components"].values()))
        return u"keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return u"keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        record = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return u"keep:%s%s" % (record["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)