Merge branch '16265-security-updates' into dependabot/bundler/apps/workbench/loofah...
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import object
8 from builtins import str
9 from future.utils import viewvalues
10
11 import fnmatch
12 import os
13 import errno
14 import urllib.parse
15 import re
16 import logging
17 import threading
18 from collections import OrderedDict
19
20 import ruamel.yaml as yaml
21
22 import cwltool.stdfsaccess
23 from cwltool.pathmapper import abspath
24 import cwltool.resolver
25
26 import arvados.util
27 import arvados.collection
28 import arvados.arvfile
29 import arvados.errors
30
31 from googleapiclient.errors import HttpError
32
33 from schema_salad.ref_resolver import DefaultFetcher
34
35 logger = logging.getLogger('arvados.cwl-runner')
36
37 pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')
38
class CollectionCache(object):
    """Size-capped cache of CollectionReader objects keyed by locator.

    Entries live in an OrderedDict ordered from least to most recently
    used.  Each entry is a ``(reader, size)`` tuple, where ``size`` is the
    length of the collection's manifest text multiplied by 128; the sum of
    those sizes is tracked in ``total`` and compared against ``cap``.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        """Create an empty cache.

        api_client, keep_client and num_retries are passed through to
        every CollectionReader the cache creates.  cap is the size budget
        used by cap_cache(); min_entries is the entry count below which
        cap_cache() stops evicting.
        """
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries
        self.cap = cap
        self.min_entries = min_entries
        self.total = 0
        self.collections = OrderedDict()
        self.lock = threading.Lock()

    def set_cap(self, cap):
        """Replace the size budget; nothing is evicted until cap_cache() runs."""
        self.cap = cap

    def cap_cache(self, required):
        """Evict least recently used entries until `required` fits under the cap.

        Eviction stops as soon as ``cap - total >= required`` or the cache
        holds fewer than ``min_entries`` entries.  This method does not
        take ``self.lock`` itself; get() calls it while holding the lock.
        """
        # Walk a snapshot (oldest first) so entries can be deleted from
        # the live dict while iterating.
        snapshot = list(self.collections.items())
        for key, entry in snapshot:
            if self.cap - self.total >= required:
                return
            if len(self.collections) < self.min_entries:
                return
            # cut it loose
            logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", key, self.cap, self.total, required)
            del self.collections[key]
            self.total -= entry[1]

    def get(self, locator):
        """Return the CollectionReader for `locator`, creating it on a miss.

        Raises IOError(errno.ENOENT) when creating the reader fails with
        arvados.errors.ApiError.
        """
        with self.lock:
            if locator in self.collections:
                # Cache hit: remove and re-insert so the entry becomes
                # the newest one in the ordering.
                reader, cost = self.collections.pop(locator)
                self.collections[locator] = (reader, cost)
                return reader

            # Cache miss: when the locator carries a size hint, make room
            # for it before loading.
            hint = pdh_size.match(locator)
            if hint:
                self.cap_cache(int(hint.group(2)) * 128)
            logger.debug("Creating collection reader for %s", locator)
            try:
                reader = arvados.collection.CollectionReader(
                    locator,
                    api_client=self.api_client,
                    keep_client=self.keep_client,
                    num_retries=self.num_retries)
            except arvados.errors.ApiError as ap:
                raise IOError(errno.ENOENT, "Could not access collection '%s': %s" % (locator, str(ap._get_reason())))
            cost = len(reader.manifest_text()) * 128
            self.collections[locator] = (reader, cost)
            self.total += cost
            return reader
88
89
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<locator>[/<path>]`` are served from
    collection readers obtained through ``collection_cache``; every other
    path is delegated to the StdFsAccess base class.
    """

    def __init__(self, basedir, collection_cache=None):
        """Store the CollectionCache used to look up ``keep:`` paths."""
        super(CollectionFsAccess, self).__init__(basedir)
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split `path` into a collection reader and a path inside it.

        Returns ``(reader, rest)`` when the first path segment is
        ``keep:`` followed by a Keep locator or a collection UUID.  `rest`
        is the URL-unquoted remainder after the first "/", or None when
        there is no remainder.  Any other path yields ``(None, path)``.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and (arvados.util.keep_locator_pattern.match(p[5:]) or
                                      arvados.util.collection_uuid_pattern.match(p[5:])):
            locator = p[5:]
            return (self.collection_cache.get(locator), urllib.parse.unquote(sp[1]) if len(sp) == 2 else None)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Recursively match glob pattern segments against `collection`.

        `patternsegments` is the pattern split on "/"; `parent` is the
        path prefix of `collection`.  Returns the list of matching paths.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo" so just shift past
            # the "./".  This is done once, before looking at the entries;
            # doing it per entry would return the same matches once for
            # every entry in the directory.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Return the sorted list of paths matching `pattern`.

        A pattern naming only a collection matches itself.  Patterns that
        are not ``keep:`` references are handled by the base class.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a keep: reference; there is no collection to search, so
            # fall back to the base class like the other methods do.
            return super(CollectionFsAccess, self).glob(pattern)
        if not rest:
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode, encoding=None):
        """Open `fn` from its collection, or via the base class otherwise.

        `encoding` is only forwarded when opening inside a collection.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode, encoding=encoding)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return whether `fn` exists.

        A collection lookup that fails with HTTP 404, or with an IOError
        whose errno is ENOENT, is reported as False; other lookup errors
        propagate.
        """
        try:
            collection, rest = self.get_collection(fn)
        except HttpError as err:
            if err.resp.status == 404:
                return False
            else:
                raise
        except IOError as err:
            # CollectionCache.get() reports an inaccessible collection as
            # IOError(ENOENT) rather than as the underlying API error.
            if err.errno == errno.ENOENT:
                return False
            else:
                raise
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def size(self, fn):  # type: (unicode) -> int
        """Return the size of the file `fn`.

        Raises IOError(errno.EINVAL) when `fn` is inside a collection but
        does not name a file.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                arvfile = collection.find(rest)
                if isinstance(arvfile, arvados.arvfile.ArvadosFile):
                    return arvfile.size()
            raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
        else:
            return super(CollectionFsAccess, self).size(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return True when `fn` names a file; a bare collection is not a file."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return True when `fn` names a directory; a bare collection is one."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """Return the entries of directory `fn`, each resolved against `fn`.

        Raises IOError(errno.ENOENT) when `fn` is inside a collection but
        is missing or is not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                subdir = collection.find(rest)
            else:
                subdir = collection
            if subdir is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(subdir, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(entry, fn) for entry in list(subdir.keys())]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components.

        When the last component is ``keep:`` followed by a Keep locator it
        is returned unchanged; otherwise this is os.path.join().
        """
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return the canonical form of `path`.

        Paths starting with ``$(task.tmpdir)`` or ``$(task.outdir)`` and
        paths that resolve to a collection are returned unchanged; other
        paths go through os.path.realpath().
        """
        if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
218
class CollectionFetcher(DefaultFetcher):
    """DefaultFetcher extended with handling for "keep:" and "arvwf:" URLs."""

    # URL schemes reported by supported_schemes().
    schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        """Keep the API client, filesystem access object and retry count."""
        super(CollectionFetcher, self).__init__(cache, session)
        self.num_retries = num_retries
        self.fsaccess = fs_access
        self.api_client = api_client

    def fetch_text(self, url):
        """Return the text content behind `url`.

        "keep:" URLs are read as UTF-8 through the filesystem access
        object.  "arvwf:" URLs fetch the workflow record from the API,
        set the definition's "label" to the record's name and return the
        definition dumped back to YAML.  Anything else is handled by the
        base class.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r", encoding="utf-8") as handle:
                return handle.read()
        if url.startswith("arvwf:"):
            request = self.api_client.workflows().get(uuid=url[6:])
            record = request.execute(num_retries=self.num_retries)
            definition = yaml.round_trip_load(record["definition"])
            definition["label"] = record["name"]
            return yaml.round_trip_dump(definition)
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return whether `url` can be resolved.

        NotFoundError means False.  Any other exception raised by the
        checks inside the try block is logged and also reported as False.
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            if url.startswith("keep:"):
                return self.fsaccess.exists(url)
            if url.startswith("arvwf:") and self.fetch_text(url):
                return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            logger.exception("Got unexpected exception checking if file exists")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve `url` relative to `base_url`.

        For "keep:" and "arvwf:" bases the first path segment (the
        locator) is always preserved and the rest of the path is resolved
        after it.  Other bases are handled by the base class.  Raises
        IOError(errno.EINVAL) when the base has no path or a "keep:" base
        has an invalid locator.
        """
        if not url:
            return base_url

        relative = urllib.parse.urlsplit(url)
        if relative.scheme or not base_url:
            # Already absolute, or nothing to resolve against.
            return url

        base = urllib.parse.urlsplit(base_url)
        if base.scheme not in ("keep", "arvwf"):
            return super(CollectionFetcher, self).urljoin(base_url, url)

        if not base.path:
            raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

        base_segments = base.path.split("/")
        rel_segments = relative.path.split("/") if relative.path else []

        # The first segment of the base path is the locator itself.
        locator = base_segments.pop(0)

        if base.scheme == "keep" and not (
                arvados.util.keep_locator_pattern.match(locator) or
                arvados.util.collection_uuid_pattern.match(locator)):
            raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

        if relative.path.startswith("/"):
            # Rooted reference: discard the base path and the empty
            # segment produced by the leading "/".
            base_segments = []
            rel_segments.pop(0)
        elif base_segments and relative.path:
            # Relative reference: replace the last segment of the base.
            base_segments.pop()

        joined = "/".join([locator] + base_segments + rel_segments)
        return urllib.parse.urlunsplit((base.scheme, "", joined, "", relative.fragment))

    def supported_schemes(self):  # type: () -> List[Text]
        """Return the list of URL schemes this fetcher accepts."""
        return self.schemes
292
293
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')

def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Resolve an Arvados reference to a URI the document loader can fetch.

    The checks are applied in this order:

    * ``keep:`` and ``arvwf:`` URIs are returned unchanged.
    * A workflow UUID becomes ``arvwf:<uuid>#main``.
    * A pipeline template UUID is looked up through the API and resolves
      to ``keep:`` plus the ``cwl:tool`` script parameter of its first
      component.
    * A path starting with a Keep locator gets a ``keep:`` prefix.
    * A path starting with a collection UUID has the UUID replaced by the
      collection's portable data hash, looked up through the API.
    * Everything else is passed to cwltool's tool_resolver.

    num_retries is forwarded to every API request made here.
    """
    if uri.startswith(("keep:", "arvwf:")):
        return str(uri)

    if workflow_uuid_pattern.match(uri):
        return u"arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # Dict views cannot be subscripted, so materialize the values
        # before taking the first component.
        components = list(pt["components"].values())
        return u"keep:" + components[0]["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return u"keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        collection = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return u"keep:%s%s" % (collection["portable_data_hash"],
                              uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)