13306: Fixed formatting for newly introduced imports
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import object
8
9 import fnmatch
10 import os
11 import errno
12 import urllib.parse
13 import re
14 import logging
15 import threading
16 from collections import OrderedDict
17
18 import ruamel.yaml as yaml
19
20 import cwltool.stdfsaccess
21 from cwltool.pathmapper import abspath
22 import cwltool.resolver
23
24 import arvados.util
25 import arvados.collection
26 import arvados.arvfile
27 import arvados.errors
28
29 from googleapiclient.errors import HttpError
30
31 from schema_salad.ref_resolver import DefaultFetcher
32
logger = logging.getLogger('arvados.cwl-runner')

# A Keep portable data hash followed by its "+<size>" hint and any further
# "+..." hints.  group(1) is the 32-hex-digit hash, group(2) the size hint.
pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')

class CollectionCache(object):
    """Least-recently-used cache of ``arvados.collection.CollectionReader`` objects.

    Every cached reader is charged ``len(manifest_text) * 128`` units against
    ``cap`` (presumably a rough in-memory size estimate -- confirm).  Before a
    new reader is created, the oldest entries are evicted until the projected
    charge fits, but eviction stops once fewer than ``min_entries`` readers
    remain, so the budget is a soft limit.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        # Budget and eviction floor.
        self.cap = cap
        self.min_entries = min_entries
        # Passed straight through to every CollectionReader built by get().
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries
        # locator -> (reader, charged size); iteration order is oldest first.
        self.collections = OrderedDict()
        # Running sum of the charged sizes currently in self.collections.
        self.total = 0
        # Serializes get(), including the cap_cache() call it makes.
        self.lock = threading.Lock()

    def set_cap(self, cap):
        """Replace the size budget.  Nothing is evicted by this call itself."""
        self.cap = cap

    def cap_cache(self, required):
        """Evict the oldest readers until ``required`` more units fit under ``cap``.

        Returns as soon as there is enough room, or as soon as fewer than
        ``min_entries`` readers are cached.  get() calls this while holding
        ``self.lock``.
        """
        # Walk a snapshot (oldest -> newest) because entries are deleted as we go.
        for locator, (_, charged) in list(self.collections.items()):
            has_room = self.cap - self.total >= required
            at_floor = len(self.collections) < self.min_entries
            if has_room or at_floor:
                return
            logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", locator, self.cap, self.total, required)
            del self.collections[locator]
            self.total -= charged

    def get(self, pdh):
        """Return a reader for ``pdh``, creating and caching it on a miss.

        A hit moves the entry to the most-recently-used position.  On a miss,
        if ``pdh`` carries a "+size" hint, room for ``size * 128`` units is made
        first via cap_cache(); the new reader is then charged by its actual
        manifest length.
        """
        with self.lock:
            cached = self.collections.pop(pdh, None)
            if cached is not None:
                # Re-inserting places the entry at the newest end of the OrderedDict.
                self.collections[pdh] = cached
                return cached[0]

            hint = pdh_size.match(pdh)
            if hint:
                self.cap_cache(int(hint.group(2)) * 128)
            logger.debug("Creating collection reader for %s", pdh)
            reader = arvados.collection.CollectionReader(pdh, api_client=self.api_client,
                                                         keep_client=self.keep_client,
                                                         num_retries=self.num_retries)
            charged = len(reader.manifest_text()) * 128
            self.collections[pdh] = reader, charged
            self.total += charged
            return reader
83
84
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<locator>[/<path>]`` are served from collection
    readers obtained through ``collection_cache``; every other path is
    delegated to ``cwltool.stdfsaccess.StdFsAccess``.
    """

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        # Must provide get(locator) -> collection reader (e.g. CollectionCache);
        # it is required as soon as any keep: path is accessed.
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split ``path`` into ``(collection, path_within_collection)``.

        For ``keep:<locator>/<rest>`` returns the cached collection reader and
        the URL-unquoted ``<rest>`` (``None`` when there is no ``/<rest>``).
        For anything else returns ``(None, path)``.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
            pdh = p[5:]
            return (self.collection_cache.get(pdh), urllib.parse.unquote(sp[1]) if len(sp) == 2 else None)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Return paths ``parent/...`` in ``collection`` matching the glob segments.

        ``patternsegments`` is a glob pattern split on "/"; each segment is
        matched with fnmatch against one directory level.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo": "." is the current
            # directory, so shift past it once.  (Handling this inside the
            # per-entry loop would repeat the recursion for every entry and
            # duplicate each result.)
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Return the sorted keep: paths matching ``pattern``.

        A bare collection reference (no path, or just ".", as produced by a
        CWL ``glob: .``) matches itself.  Non-keep patterns are delegated to
        the local-filesystem implementation.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a keep: reference; previously this fell through and failed
            # on None.manifest_locator().
            return super(CollectionFsAccess, self).glob(pattern)
        if rest in (None, "", "."):
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode):
        """Open a file inside a collection, or a local file relative to basedir."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return whether ``fn`` exists.

        A keep: reference whose collection lookup raises an HttpError with
        status 404 is reported as not existing; other HttpErrors propagate.
        """
        try:
            collection, rest = self.get_collection(fn)
        except HttpError as err:
            if err.resp.status == 404:
                return False
            else:
                raise
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def size(self, fn):  # type: (unicode) -> int
        """Return the size of file ``fn``.

        Raises IOError(EINVAL) when a keep: path does not name a file inside
        the collection (including a bare collection reference).
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                arvfile = collection.find(rest)
                if isinstance(arvfile, arvados.arvfile.ArvadosFile):
                    return arvfile.size()
            raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
        else:
            return super(CollectionFsAccess, self).size(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return whether ``fn`` is a file; a bare collection reference is not."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return whether ``fn`` is a directory; a bare collection reference is one."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """Return the entries of directory ``fn`` as paths joined onto ``fn``.

        Raises IOError(ENOENT) when a keep: path is missing or is not a
        directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                subdir = collection.find(rest)
            else:
                subdir = collection
            if subdir is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(subdir, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            # NOTE(review): entry names are not URL-quoted here, yet
            # get_collection() unquotes paths; names containing '%' may not
            # round-trip -- confirm against cwltool's expectations.
            return [abspath(name, fn) for name in list(subdir.keys())]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components; a trailing keep: locator reference wins outright."""
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return ``path`` unchanged for keep: references and $(task.*) placeholders.

        Other paths are resolved with os.path.realpath.  Note that a keep:
        reference is looked up through the collection cache to decide this.
        """
        if path.startswith(("$(task.tmpdir)", "$(task.outdir)")):
            return path
        collection, _ = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
212
class CollectionFetcher(DefaultFetcher):
    """schema-salad fetcher that additionally understands keep: and arvwf: URLs.

    keep: documents are read through ``fs_access``; arvwf:<uuid> documents are
    the stored definition of an Arvados workflow record.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries

    def fetch_text(self, url):
        """Return the text of ``url``.

        For arvwf: URLs the workflow record's definition is returned with a
        ``label:`` line appended that carries the record's name.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            # The name goes inside a YAML double-quoted scalar, where "\" starts
            # an escape sequence: escape backslashes before quotes so names
            # like 'a\b' or 'x\' keep their literal value.
            name = record["name"].replace('\\', '\\\\').replace('"', '\\"')
            definition = record["definition"] + ('\nlabel: "%s"\n' % name)
            return definition
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return whether ``url`` can be fetched.

        Lookup failures for keep:/arvwf: URLs are logged and reported as
        False instead of propagating.
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            if url.startswith("keep:"):
                return self.fsaccess.exists(url)
            if url.startswith("arvwf:"):
                if self.fetch_text(url):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            # Deliberately best-effort, but no longer swallows
            # KeyboardInterrupt/SystemExit as the bare "except:" did.
            logger.exception("Got unexpected exception checking if file exists:")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` against ``base_url``, with keep:/arvwf: support.

        For a keep: or arvwf: base, the first path component is the locator
        and is preserved.  A relative ``url`` replaces the last component of
        the base path, and an absolute ("/...") ``url`` is resolved from the
        locator root.  The result keeps ``url``'s fragment; its query is
        dropped.  Raises IOError(EINVAL) for an empty base path or (keep:
        only) an invalid locator.
        """
        if not url:
            return base_url

        urlsp = urllib.parse.urlsplit(url)
        if urlsp.scheme or not base_url:
            return url

        basesp = urllib.parse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            pdh = baseparts.pop(0)

            if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # Absolute path: start from the locator root and drop the
                # empty component produced by the leading "/".
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Relative path: replace the last component (the "file") of the base.
                baseparts.pop()

            path = "/".join([pdh] + baseparts + urlparts)
            return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)

    schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]

    def supported_schemes(self):  # type: () -> List[Text]
        """Return the URL schemes this fetcher accepts."""
        return self.schemes
283
284
# NOTE(review): re.match() anchors only at the start, so these also accept
# strings that merely begin with a uuid -- confirm whether that is intended.
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')


def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Resolve a tool reference to a keep:/arvwf: URI, or defer to cwltool.

    - keep:/arvwf: URIs are returned unchanged.
    - A workflow uuid becomes ``arvwf:<uuid>#main``.
    - A pipeline template uuid resolves to the ``cwl:tool`` script parameter
      of its first component.
    - ``<portable data hash>[/path]`` becomes ``keep:<pdh>[/path]``.
    - ``<collection uuid>[/path]`` is looked up and rewritten to its
      portable data hash.
    - Anything else goes to ``cwltool.resolver.tool_resolver``.

    ``num_retries`` is passed to every API request made here.
    """
    if uri.startswith(("keep:", "arvwf:")):
        return uri

    if workflow_uuid_pattern.match(uri):
        return "arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        return "keep:" + list(pt["components"].values())[0]["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return "keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        # Previously executed without num_retries, unlike the template lookup above.
        record = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return "keep:%s%s" % (record["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)