20613: Use GoogleHTTPClientFilter in a-c-r
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import object
8 from builtins import str
9 from future.utils import viewvalues
10
11 import fnmatch
12 import os
13 import errno
14 import urllib.parse
15 import re
16 import logging
17 import threading
18 from collections import OrderedDict
19 from io import StringIO
20
21 import ruamel.yaml
22
23 import cwltool.stdfsaccess
24 from cwltool.pathmapper import abspath
25 import cwltool.resolver
26
27 import arvados.util
28 import arvados.collection
29 import arvados.arvfile
30 import arvados.errors
31
32 from googleapiclient.errors import HttpError
33
34 from schema_salad.ref_resolver import DefaultFetcher
35
# Module logger, registered under the 'arvados.cwl-runner' name.
logger = logging.getLogger('arvados.cwl-runner')

# Matches a locator made of 32 lowercase hex digits, "+", a decimal number
# (captured as group 2) and any number of trailing "+<non-whitespace>" parts.
pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')
39
class CollectionCache(object):
    """Size-capped cache of CollectionReader objects keyed by locator.

    Entries live in an OrderedDict ordered from least to most recently
    used.  Each entry is a ``(reader, cost)`` tuple, where ``cost`` is 128
    times the length of the reader's manifest text; ``total`` is the sum
    of the costs of all cached entries.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        # Handles passed through to every CollectionReader we create.
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries
        # Eviction policy settings.
        self.cap = cap
        self.min_entries = min_entries
        # Cache state; get() takes the lock before touching it.
        self.total = 0
        self.collections = OrderedDict()
        self.lock = threading.Lock()

    def set_cap(self, cap):
        """Replace the cost cap used by cap_cache()."""
        self.cap = cap

    def cap_cache(self, required):
        """Evict the oldest entries until ``required`` cost units are free.

        Eviction stops early once fewer than ``min_entries`` entries
        remain, even if the requested room is still unavailable.
        """
        # Walk a snapshot (oldest first) so entries can be deleted as we go.
        for key, entry in list(self.collections.items()):
            has_room = (self.cap - self.total) >= required
            if has_room or len(self.collections) < self.min_entries:
                break
            logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", key, self.cap, self.total, required)
            del self.collections[key]
            self.total -= entry[1]

    def get(self, locator):
        """Return the CollectionReader for ``locator``, creating it if needed.

        A cache hit moves the entry to the most-recently-used position.
        On a miss, if the locator matches ``pdh_size``, room is first made
        for 128 times the number embedded in the locator.

        Raises IOError(ENOENT) when creating the reader raises
        arvados.errors.ApiError.
        """
        with self.lock:
            if locator in self.collections:
                # Re-insert the entry so it becomes the newest one.
                reader, cost = self.collections.pop(locator)
                self.collections[locator] = (reader, cost)
                return reader

            parsed = pdh_size.match(locator)
            if parsed is not None:
                self.cap_cache(int(parsed.group(2)) * 128)

            logger.debug("Creating collection reader for %s", locator)
            try:
                reader = arvados.collection.CollectionReader(
                    locator,
                    api_client=self.api_client,
                    keep_client=self.keep_client,
                    num_retries=self.num_retries)
            except arvados.errors.ApiError as ap:
                raise IOError(errno.ENOENT, "Could not access collection '%s': %s" % (locator, str(ap._get_reason())))

            cost = len(reader.manifest_text()) * 128
            self.collections[locator] = (reader, cost)
            self.total += cost
            return reader
89
90
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<locator>[/<path>]`` are served from
    collection readers obtained through ``collection_cache``.  Every other
    path is delegated to the base class implementation.
    """

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        # Object with a get(locator) method returning a collection reader
        # (see CollectionCache).
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split ``path`` into ``(collection, path inside the collection)``.

        When the first path segment is ``keep:`` followed by something
        matching the Keep locator or collection UUID pattern, return the
        reader from the collection cache and the URL-unquoted, normalized
        remainder of the path (None when there is no remainder).

        Otherwise return ``(None, path)`` with ``path`` unchanged.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and (arvados.util.keep_locator_pattern.match(p[5:]) or
                                      arvados.util.collection_uuid_pattern.match(p[5:])):
            locator = p[5:]
            rest = os.path.normpath(urllib.parse.unquote(sp[1])) if len(sp) == 2 else None
            return (self.collection_cache.get(locator), rest)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Recursively match glob segments against a collection tree.

        ``patternsegments`` is the list of remaining fnmatch patterns, one
        per directory level, and ``parent`` is the path prefix of
        ``collection``.  Returns the list of matching paths.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo" so just shift
            # past the "./".  This is done once, outside the loop below,
            # so results are not repeated for every entry in 'collection'.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Return the sorted list of paths matching ``pattern``.

        Patterns that are not inside a Keep collection are handled by the
        base class, like the other methods of this class.  A pattern that
        names a collection root is returned as the only match.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a keep: reference; there is no collection to search.
            return super(CollectionFsAccess, self).glob(pattern)
        if rest in (None, "", "."):
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode, encoding=None):
        """Open ``fn`` from its collection, or through the base class.

        ``encoding`` is only passed on when the file is in a collection.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode, encoding=encoding)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return whether ``fn`` exists.

        A 404 HttpError or an ENOENT IOError raised while looking up the
        collection is reported as "does not exist"; other errors propagate.
        """
        try:
            collection, rest = self.get_collection(fn)
        except HttpError as err:
            if err.resp.status == 404:
                return False
            else:
                raise
        except IOError as err:
            if err.errno == errno.ENOENT:
                return False
            else:
                raise
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                # The collection itself was found.
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def size(self, fn):  # type: (unicode) -> int
        """Return the size of the file at ``fn``.

        Raises IOError(EINVAL) when ``fn`` is inside a collection but does
        not name a file.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                arvfile = collection.find(rest)
                if isinstance(arvfile, arvados.arvfile.ArvadosFile):
                    return arvfile.size()
            raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
        else:
            return super(CollectionFsAccess, self).size(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return whether ``fn`` names a file; a collection root does not."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return whether ``fn`` names a directory; a collection root does."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """Return the entries of directory ``fn``, each combined with ``fn``.

        Raises IOError(ENOENT) when ``fn`` is inside a collection but is
        missing or is not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                target = collection.find(rest)
            else:
                target = collection
            if target is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(target, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(name, fn) for name in target.keys()]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components like os.path.join.

        If the last component is a ``keep:`` reference whose locator
        matches the Keep locator pattern, it is returned on its own.
        """
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return ``path`` resolved with os.path.realpath.

        Paths starting with ``$(task.tmpdir)`` or ``$(task.outdir)`` and
        paths inside a collection are returned unchanged.
        """
        if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
225
class CollectionFetcher(DefaultFetcher):
    """Fetcher that adds support for ``keep:`` and ``arvwf:`` URLs.

    URLs with any other scheme are handled by the base class.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries

    def fetch_text(self, url, content_types=None):
        """Return the text content found at ``url``.

        ``keep:`` URLs are read as UTF-8 through ``fsaccess``.  ``arvwf:``
        URLs fetch the workflow record whose UUID follows the prefix, and
        return its definition re-serialized as YAML with ``label`` set to
        the record's name.  Anything else goes to the base class.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r", encoding="utf-8") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            yaml = ruamel.yaml.YAML(typ='rt', pure=True)
            definition = yaml.load(record["definition"])
            definition["label"] = record["name"]
            stream = StringIO()
            yaml.dump(definition, stream)
            return stream.getvalue()
        if content_types is None:
            return super(CollectionFetcher, self).fetch_text(url)
        # Forward the caller's content types instead of silently dropping
        # them.  They are only passed when supplied, so a call without them
        # reaches the base class exactly as before.
        return super(CollectionFetcher, self).fetch_text(url, content_types=content_types)

    def check_exists(self, url):
        """Return whether ``url`` refers to something that exists.

        URLs starting with ``http://arvados.org/cwl`` always exist.  Any
        exception raised while checking a ``keep:`` or ``arvwf:`` URL
        results in False; exceptions other than NotFoundError are logged.
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            # Ignore the fragment when checking for existence.
            urld, _ = urllib.parse.urldefrag(url)
            if urld.startswith("keep:"):
                return self.fsaccess.exists(urld)
            if urld.startswith("arvwf:"):
                if self.fetch_text(urld):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            logger.exception("Got unexpected exception checking if file exists")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` relative to ``base_url``.

        An empty ``url`` yields ``base_url``; a ``url`` with its own
        scheme, or an empty ``base_url``, yields ``url`` unchanged.  For
        ``keep:`` and ``arvwf:`` bases the first path segment (the locator)
        is always kept and the rest is resolved relative to it.  Other
        bases are handled by the base class.

        Raises IOError(EINVAL) for a ``keep:``/``arvwf:`` base with no
        path, or a ``keep:`` base whose first segment is neither a Keep
        locator nor a collection UUID.
        """
        if not url:
            return base_url

        urlsp = urllib.parse.urlsplit(url)
        if urlsp.scheme or not base_url:
            return url

        basesp = urllib.parse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            locator = baseparts.pop(0)

            if (basesp.scheme == "keep" and
                (not arvados.util.keep_locator_pattern.match(locator)) and
                (not arvados.util.collection_uuid_pattern.match(locator))):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # Absolute path: resolve from the locator, and drop the
                # empty segment produced by the leading "/".
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Relative path: replace the last segment of the base.
                baseparts.pop()

            path = "/".join([locator] + baseparts + urlparts)
            return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)

    schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]

    def supported_schemes(self):  # type: () -> List[Text]
        """Return the list of URL schemes this fetcher accepts."""
        return self.schemes
303
304
# UUID shapes recognized by collectionResolver(); both are matched at the
# start of the string only.
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')

def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Turn a workflow reference into a URI the fetcher can load.

    The checks are applied in this order:

    * ``keep:`` and ``arvwf:`` URIs are returned as-is.
    * A workflow UUID becomes ``arvwf:<uuid>#main``.
    * A pipeline template UUID is looked up through ``api_client`` and
      becomes ``keep:`` plus the ``cwl:tool`` script parameter of its
      first component.
    * A path starting with a Keep locator gets a ``keep:`` prefix.
    * A path starting with a collection UUID is looked up through
      ``api_client`` and rewritten to use the portable data hash.
    * Anything else is passed to cwltool's default tool resolver.

    ``num_retries`` is passed to every API request made here.
    """
    if uri.startswith("keep:") or uri.startswith("arvwf:"):
        return str(uri)

    if workflow_uuid_pattern.match(uri):
        return u"arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # Dict views cannot be indexed on Python 3, so build a list first.
        first_component = list(pt["components"].values())[0]
        return u"keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return u"keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        record = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        # Keep whatever followed the UUID (the path inside the collection).
        return u"keep:%s%s" % (record["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)