17801: Handle glob capture of keep URI with a trailing slash
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import object
8 from builtins import str
9 from future.utils import viewvalues
10
11 import fnmatch
12 import os
13 import errno
14 import urllib.parse
15 import re
16 import logging
17 import threading
18 from collections import OrderedDict
19
20 import ruamel.yaml as yaml
21
22 import cwltool.stdfsaccess
23 from cwltool.pathmapper import abspath
24 import cwltool.resolver
25
26 import arvados.util
27 import arvados.collection
28 import arvados.arvfile
29 import arvados.errors
30
31 from googleapiclient.errors import HttpError
32
33 from schema_salad.ref_resolver import DefaultFetcher
34
# Logger shared by the definitions in this module.
logger = logging.getLogger("arvados.cwl-runner")

# Matches "<32 lowercase hex digits>+<decimal size>" followed by any number
# of "+hint" suffixes.  Group 1 is the hex digest, group 2 the size.
pdh_size = re.compile(
    r'([0-9a-f]{32})\+(\d+)(\+\S+)*'
)
38
class CollectionCache(object):
    """Size-capped cache of collection readers keyed by locator.

    Entries live in an OrderedDict as ``(reader, estimated_size)`` tuples,
    ordered from least to most recently used.  ``total`` is the running
    sum of the estimated sizes of the cached entries.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries
        self.cap = cap
        self.min_entries = min_entries
        self.total = 0
        self.collections = OrderedDict()
        self.lock = threading.Lock()

    def set_cap(self, cap):
        """Replace the size cap used by cap_cache()."""
        self.cap = cap

    def cap_cache(self, required):
        """Evict oldest entries until 'required' fits under the cap.

        Stops early once fewer than 'min_entries' entries remain.
        """
        # Snapshot the items (oldest first) because entries are deleted
        # while walking them.
        for pdh, (_, entry_size) in list(self.collections.items()):
            if self.cap - self.total >= required:
                return
            if len(self.collections) < self.min_entries:
                return
            # cut it loose
            logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", pdh, self.cap, self.total, required)
            del self.collections[pdh]
            self.total -= entry_size

    def get(self, locator):
        """Return the collection reader for 'locator', creating it on a miss.

        Raises IOError(ENOENT) when creating the reader fails with an
        arvados ApiError.
        """
        with self.lock:
            if locator in self.collections:
                # Cache hit: remove and re-insert so the entry becomes the
                # most recently used one.
                entry = self.collections.pop(locator)
                self.collections[locator] = entry
                return entry[0]

            # Cache miss: when the locator carries a size, make room
            # based on it before creating the reader.
            parsed = pdh_size.match(locator)
            if parsed is not None:
                self.cap_cache(int(parsed.group(2)) * 128)

            logger.debug("Creating collection reader for %s", locator)
            try:
                reader = arvados.collection.CollectionReader(
                    locator,
                    api_client=self.api_client,
                    keep_client=self.keep_client,
                    num_retries=self.num_retries)
            except arvados.errors.ApiError as ap:
                raise IOError(errno.ENOENT, "Could not access collection '%s': %s" % (locator, str(ap._get_reason())))

            estimated = len(reader.manifest_text()) * 128
            self.collections[locator] = (reader, estimated)
            self.total += estimated
            return reader
88
89
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<locator>[/<path>]`` are resolved against
    collection readers obtained from ``collection_cache``.  For other
    paths, the file operations fall back to the StdFsAccess base class.
    """

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split 'path' into a (collection reader, path inside it) pair.

        When the first path component is "keep:" followed by a Keep
        locator or a collection UUID, the reader comes from the
        collection cache and the second item is the unquoted, normalized
        remainder of the path (None when there is no "/" in 'path').

        Otherwise returns (None, path).
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and (arvados.util.keep_locator_pattern.match(p[5:]) or
                                      arvados.util.collection_uuid_pattern.match(p[5:])):
            locator = p[5:]
            rest = os.path.normpath(urllib.parse.unquote(sp[1])) if len(sp) == 2 else None
            return (self.collection_cache.get(locator), rest)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Return the paths under 'collection' matching 'patternsegments'.

        'patternsegments' holds one fnmatch pattern per path level and
        'parent' is the prefix joined onto each matching entry name.
        Returns an empty list when there are no segments left or when
        'collection' is not a RichCollectionBase.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo" so just shift past
            # the "./".  Doing this once, before iterating, avoids adding
            # the same matches again for every entry in the collection.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Return the sorted list of paths matching 'pattern'.

        A keep: reference with no path inside the collection (including
        one with only a trailing slash) matches itself.  Patterns that
        are not keep: references are handled by the base class.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a keep: reference; there is no collection to search,
            # so defer to the base class like the other methods do.
            return super(CollectionFsAccess, self).glob(pattern)
        if rest in (None, "", "."):
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode, encoding=None):
        """Open 'fn' from its collection, or via the base class otherwise."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode, encoding=encoding)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return whether 'fn' exists.

        A 404 HttpError or an ENOENT IOError raised while looking up the
        collection is reported as False; other errors propagate.
        """
        try:
            collection, rest = self.get_collection(fn)
        except HttpError as err:
            if err.resp.status == 404:
                return False
            else:
                raise
        except IOError as err:
            if err.errno == errno.ENOENT:
                return False
            else:
                raise
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def size(self, fn):  # type: (unicode) -> int
        """Return the size of the file at 'fn'.

        Raises IOError(EINVAL) when 'fn' is inside a collection but does
        not refer to a file.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                arvfile = collection.find(rest)
                if isinstance(arvfile, arvados.arvfile.ArvadosFile):
                    return arvfile.size()
            raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
        else:
            return super(CollectionFsAccess, self).size(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return whether 'fn' refers to a file."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                # A bare collection reference is never a file.
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return whether 'fn' refers to a directory."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                # A bare collection reference is treated as a directory.
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """Return the entries of the directory 'fn', made absolute against 'fn'.

        Raises IOError(ENOENT) when 'fn' is inside a collection but is
        missing or is not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                target = collection.find(rest)
            else:
                target = collection
            if target is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(target, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(name, fn) for name in target.keys()]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components.

        When the last component is a "keep:" reference to a Keep locator
        it is returned as-is; otherwise this is os.path.join().
        """
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return 'path' unchanged for task placeholders and keep: references.

        Any other path is resolved with os.path.realpath().
        """
        if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
224
class CollectionFetcher(DefaultFetcher):
    """Fetcher that also understands "keep:" and "arvwf:" references."""

    schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries

    def fetch_text(self, url, content_types=None):
        """Return the text behind 'url'.

        "keep:" references are read through the filesystem access object
        and "arvwf:" references are loaded from the workflow record's
        definition; anything else goes to the base class.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r", encoding="utf-8") as stream:
                return stream.read()

        if url.startswith("arvwf:"):
            request = self.api_client.workflows().get(uuid=url[6:])
            record = request.execute(num_retries=self.num_retries)
            definition = yaml.round_trip_load(record["definition"])
            # The record's name becomes the label of the returned document.
            definition["label"] = record["name"]
            return yaml.round_trip_dump(definition)

        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return whether 'url' can be resolved.

        Any exception raised while checking is reported as False;
        unexpected ones are logged first.
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            elif url.startswith("keep:"):
                return self.fsaccess.exists(url)
            elif url.startswith("arvwf:") and self.fetch_text(url):
                return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            logger.exception("Got unexpected exception checking if file exists")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve 'url' relative to 'base_url'.

        Bases with the "keep" or "arvwf" scheme are joined here, keeping
        the locator (the first path component of the base) in place;
        other bases are handled by the base class.

        Raises IOError(EINVAL) when a keep/arvwf base has no path, or
        when a keep base does not start with a valid locator.
        """
        if not url:
            return base_url

        target = urllib.parse.urlsplit(url)
        if target.scheme or not base_url:
            return url

        base = urllib.parse.urlsplit(base_url)
        if base.scheme not in ("keep", "arvwf"):
            return super(CollectionFetcher, self).urljoin(base_url, url)

        if not base.path:
            raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

        base_segments = base.path.split("/")
        locator = base_segments.pop(0)
        target_segments = target.path.split("/") if target.path else []

        if base.scheme == "keep":
            valid_locator = (arvados.util.keep_locator_pattern.match(locator) or
                             arvados.util.collection_uuid_pattern.match(locator))
            if not valid_locator:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

        if target.path.startswith("/"):
            # Absolute path: discard the base's path and the empty
            # segment produced by the leading slash.
            base_segments = []
            target_segments.pop(0)
        elif base_segments and target.path:
            # Relative path: it replaces the last segment of the base.
            base_segments.pop()

        joined = "/".join([locator] + base_segments + target_segments)
        return urllib.parse.urlunsplit((base.scheme, "", joined, "", target.fragment))

    def supported_schemes(self):  # type: () -> List[Text]
        return self.schemes
298
299
# UUID shapes that collectionResolver maps to arvwf: / keep: references.
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')

def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Resolve 'uri' to a reference that can be loaded.

    - "keep:" and "arvwf:" references are returned as they are.
    - A workflow UUID becomes "arvwf:<uuid>#main".
    - A pipeline template UUID is looked up through 'api_client' and
      becomes "keep:" plus the "cwl:tool" script parameter of its first
      component.
    - A path starting with a Keep locator gets a "keep:" prefix.
    - A path starting with a collection UUID is looked up through
      'api_client' and rewritten to use the portable data hash.
    - Anything else is passed to cwltool's tool_resolver.

    'num_retries' is passed to each API request made here.
    """
    if uri.startswith("keep:") or uri.startswith("arvwf:"):
        return str(uri)

    if workflow_uuid_pattern.match(uri):
        return u"arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # A dict view cannot be subscripted, so materialize the values
        # before taking the first component.
        first_component = list(pt["components"].values())[0]
        return u"keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return u"keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        record = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return u"keep:%s%s" % (record["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)