Merge branch '14383-java-sdk-double-slash'. Fixes #14383.
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import object
8 from builtins import str
9 from future.utils import viewvalues
10
11 import fnmatch
12 import os
13 import errno
14 import urllib.parse
15 import re
16 import logging
17 import threading
18 from collections import OrderedDict
19
20 import ruamel.yaml as yaml
21
22 import cwltool.stdfsaccess
23 from cwltool.pathmapper import abspath
24 import cwltool.resolver
25
26 import arvados.util
27 import arvados.collection
28 import arvados.arvfile
29 import arvados.errors
30
31 from googleapiclient.errors import HttpError
32
33 from schema_salad.ref_resolver import DefaultFetcher
34
# Logger shared by the arvados-cwl-runner modules.
logger = logging.getLogger("arvados.cwl-runner")

# Portable data hash with size: 32 lowercase hex digits, "+", a decimal
# size, then any number of "+hint" suffixes.  Group 2 is the size field.
pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')
38
class CollectionCache(object):
    """Size-capped cache of CollectionReader objects keyed by locator.

    Entries live in an OrderedDict ordered from least to most recently
    used.  Each entry is a ``(reader, size_estimate)`` tuple and
    ``total`` tracks the sum of the estimates currently cached.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        # Client handles and retry count passed to every CollectionReader.
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries
        # Eviction policy knobs.
        self.cap = cap
        self.min_entries = min_entries
        # Cache state; 'lock' guards lookups and insertions in get().
        self.lock = threading.Lock()
        self.collections = OrderedDict()
        self.total = 0

    def set_cap(self, cap):
        """Replace the size cap used by subsequent evictions."""
        self.cap = cap

    def cap_cache(self, required):
        """Evict oldest entries until ``required`` units fit under the cap.

        Stops as soon as enough room is available, or once fewer than
        ``min_entries`` entries remain.
        """
        while self.collections:
            headroom = self.cap - self.total
            if headroom >= required:
                return
            if len(self.collections) < self.min_entries:
                return
            # The front of the ordered dict is the oldest entry.
            oldest, entry = self.collections.popitem(last=False)
            logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", oldest, self.cap, self.total, required)
            self.total -= entry[1]

    def get(self, pdh):
        """Return the CollectionReader for ``pdh``, creating it on a miss.

        A hit moves the entry to the most-recently-used end.  On a miss,
        room is made first (when ``pdh`` carries a size field) and the new
        reader is stored with an estimate of 128 times its manifest
        text length.
        """
        with self.lock:
            if pdh in self.collections:
                # Re-insert so the entry becomes the newest one.
                entry = self.collections.pop(pdh)
                self.collections[pdh] = entry
                return entry[0]

            size_match = pdh_size.match(pdh)
            if size_match is not None:
                self.cap_cache(int(size_match.group(2)) * 128)
            logger.debug("Creating collection reader for %s", pdh)
            reader = arvados.collection.CollectionReader(
                pdh,
                api_client=self.api_client,
                keep_client=self.keep_client,
                num_retries=self.num_retries)
            estimate = len(reader.manifest_text()) * 128
            self.collections[pdh] = (reader, estimate)
            self.total += estimate
            return reader
85
86
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<locator>[/<path>]`` are resolved against
    collections obtained from ``collection_cache``; every other path is
    delegated to the ``StdFsAccess`` base class.
    """

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split ``path`` into ``(collection, rest)``.

        When the first "/"-separated segment is ``keep:`` followed by
        something matching ``arvados.util.keep_locator_pattern``, the
        collection is fetched from the cache and ``rest`` is the
        URL-unquoted remainder (``None`` when there is no "/").
        Otherwise returns ``(None, path)``.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
            pdh = p[5:]
            rest = urllib.parse.unquote(sp[1]) if len(sp) == 2 else None
            return (self.collection_cache.get(pdh), rest)
        return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Return paths under ``collection`` matching ``patternsegments``.

        ``patternsegments`` is the glob pattern split on "/"; ``parent`` is
        the path prefix prepended to each match.  Returns an empty list
        when there are no segments left or ``collection`` is not a
        ``RichCollectionBase``.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo" so just shift past
            # the "./".  This is done once, outside the loop below;
            # doing it per entry would repeat every match once for each
            # item in the collection.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Return the sorted list of paths matching ``pattern``.

        A ``keep:`` pattern with no path part matches the collection
        itself.  Patterns that are not ``keep:`` paths are handled by the
        base class, like every other method of this class.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a keep: path; there is no collection to search.
            return super(CollectionFsAccess, self).glob(pattern)
        if not rest:
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode):
        """Open ``fn`` from its collection, or from the base class via ``_abs``."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode)
        return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return whether ``fn`` exists.

        An ``HttpError`` with status 404 raised while looking up the
        collection is reported as ``False``; other ``HttpError``s propagate.
        """
        try:
            collection, rest = self.get_collection(fn)
        except HttpError as err:
            if err.resp.status == 404:
                return False
            raise
        if collection is None:
            return super(CollectionFsAccess, self).exists(fn)
        if rest:
            return collection.exists(rest)
        # The path names the collection itself.
        return True

    def size(self, fn):  # type: (unicode) -> int
        """Return the size of the file ``fn``.

        Raises ``IOError(EINVAL)`` when ``fn`` is inside a collection but
        does not resolve to an ``ArvadosFile``.
        """
        collection, rest = self.get_collection(fn)
        if collection is None:
            return super(CollectionFsAccess, self).size(fn)
        if rest:
            arvfile = collection.find(rest)
            if isinstance(arvfile, arvados.arvfile.ArvadosFile):
                return arvfile.size()
        raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return whether ``fn`` is a file; a bare collection is not."""
        collection, rest = self.get_collection(fn)
        if collection is None:
            return super(CollectionFsAccess, self).isfile(fn)
        if rest:
            return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
        return False

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return whether ``fn`` is a directory; a bare collection is one."""
        collection, rest = self.get_collection(fn)
        if collection is None:
            return super(CollectionFsAccess, self).isdir(fn)
        if rest:
            return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
        return True

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """List the entries of directory ``fn`` as paths joined onto ``fn``.

        Raises ``IOError(ENOENT)`` when the path inside the collection is
        missing or is not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is None:
            return super(CollectionFsAccess, self).listdir(fn)
        target = collection.find(rest) if rest else collection
        if target is None:
            raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
        if not isinstance(target, arvados.collection.RichCollectionBase):
            raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
        return [abspath(name, fn) for name in list(target.keys())]

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components; a final ``keep:`` locator replaces the rest."""
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return ``path`` resolved with ``os.path.realpath`` if it is local.

        Paths starting with ``$(task.tmpdir)`` or ``$(task.outdir)`` and
        paths inside a collection are returned unchanged.
        """
        if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        return os.path.realpath(path)
214
class CollectionFetcher(DefaultFetcher):
    """Fetcher that adds the ``keep:`` and ``arvwf:`` URL schemes.

    ``keep:`` URLs are read through ``fs_access``; ``arvwf:`` URLs are
    looked up as workflow records through ``api_client``.  Everything
    else is handled by ``DefaultFetcher``.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries

    def fetch_text(self, url):
        """Return the text content behind ``url``.

        For ``arvwf:`` URLs the result is the workflow record's
        ``definition`` with a ``label`` line appended, built from the
        record's ``name`` with double quotes backslash-escaped.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            definition = record["definition"] + ('\nlabel: "%s"\n' % record["name"].replace('"', '\\"'))
            return definition
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return whether ``url`` can be resolved.

        URLs starting with ``http://arvados.org/cwl`` are reported as
        existing without being fetched.  ``NotFoundError`` yields
        ``False``; any other exception is logged and also yields
        ``False``.
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            if url.startswith("keep:"):
                return self.fsaccess.exists(url)
            if url.startswith("arvwf:"):
                if self.fetch_text(url):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            # Deliberately best-effort, but limited to Exception so that
            # KeyboardInterrupt and SystemExit are not swallowed.
            logger.exception("Got unexpected exception checking if file exists:")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` relative to ``base_url``.

        ``url`` is returned unchanged when it has a scheme or when
        ``base_url`` is empty; an empty ``url`` yields ``base_url``.  For
        ``keep:`` and ``arvwf:`` bases the first path segment of the base
        is always kept, and the result carries the fragment of ``url``.
        Other schemes are delegated to the base class.

        Raises ``IOError(EINVAL)`` when the base has an empty path or,
        for ``keep:``, when its first segment is not a valid locator.
        """
        if not url:
            return base_url

        urlsp = urllib.parse.urlsplit(url)
        if urlsp.scheme or not base_url:
            return url

        basesp = urllib.parse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            # First segment of the base path; never replaced by 'url'.
            pdh = baseparts.pop(0)

            if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # Absolute path: discard the base's path and the empty
                # leading segment produced by splitting "/...".
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Relative path: replace the last segment of the base.
                baseparts.pop()

            path = "/".join([pdh] + baseparts + urlparts)
            return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)

    # URL schemes reported by supported_schemes().
    schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]

    def supported_schemes(self):  # type: () -> List[Text]
        """Return the list of URL schemes this fetcher reports."""
        return self.schemes
285
286
# UUID shapes recognised by collectionResolver(): five characters, a fixed
# infix ("7fd4e" or "p5p6p"), then fifteen characters, all [a-z0-9].
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')

def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Map ``uri`` to a ``keep:`` or ``arvwf:`` reference where possible.

    Checked in order:

    * ``keep:`` / ``arvwf:`` URIs are returned as-is;
    * a workflow UUID becomes ``arvwf:<uuid>#main``;
    * a pipeline template UUID is fetched and resolved to ``keep:`` plus
      the ``cwl:tool`` script parameter of its first component;
    * a path starting with a Keep locator gets a ``keep:`` prefix;
    * a path starting with a collection UUID has the UUID replaced by
      the collection's portable data hash.

    Anything else is passed to ``cwltool.resolver.tool_resolver``.
    """
    if uri.startswith(("keep:", "arvwf:")):
        return str(uri)

    if workflow_uuid_pattern.match(uri):
        return u"arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # Dict views cannot be indexed, so take the first value through
        # an iterator.
        first_component = next(iter(pt["components"].values()))
        return u"keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return u"keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        record = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return u"keep:%s%s" % (record["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)