13306: Changes to arvados-cwl-runner code after running futurize --stage2
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 from future import standard_library
2 standard_library.install_aliases()
3 from builtins import object
4 # Copyright (C) The Arvados Authors. All rights reserved.
5 #
6 # SPDX-License-Identifier: Apache-2.0
7
8 import fnmatch
9 import os
10 import errno
11 import urllib.parse
12 import re
13 import logging
14 import threading
15 from collections import OrderedDict
16
17 import ruamel.yaml as yaml
18
19 import cwltool.stdfsaccess
20 from cwltool.pathmapper import abspath
21 import cwltool.resolver
22
23 import arvados.util
24 import arvados.collection
25 import arvados.arvfile
26 import arvados.errors
27
28 from googleapiclient.errors import HttpError
29
30 from schema_salad.ref_resolver import DefaultFetcher
31
# Logger shared by the code in this module.
logger = logging.getLogger("arvados.cwl-runner")

# Matches "<32 lowercase hex digits>+<decimal digits>" followed by any number
# of "+<non-whitespace>" suffixes.  Group 1 is the hex digest and group 2 is
# the decimal field (read as a size by CollectionCache.get).
pdh_size = re.compile(r"([0-9a-f]{32})\+(\d+)(\+\S+)*")
35
class CollectionCache(object):
    """Size-capped cache of CollectionReader objects keyed by locator.

    Entries are kept in an OrderedDict ordered from least to most recently
    used.  Each value is a ``(reader, size_estimate)`` pair, where the size
    estimate is the manifest text length multiplied by 128, and ``total``
    is the running sum of those estimates.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        """Create an empty cache.

        api_client, keep_client and num_retries are passed through to every
        CollectionReader the cache creates.  cap is the budget that
        ``total`` is compared against; eviction stops once fewer than
        min_entries entries remain.
        """
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries
        self.cap = cap
        self.min_entries = min_entries
        self.total = 0
        self.collections = OrderedDict()
        self.lock = threading.Lock()

    def set_cap(self, cap):
        """Replace the size budget; nothing is evicted until the next miss."""
        self.cap = cap

    def cap_cache(self, required):
        """Evict oldest entries until `required` units of budget are free.

        Stops early as soon as enough room is available or the cache holds
        fewer than ``min_entries`` entries.  This method does not take
        ``self.lock`` itself; get() calls it while holding the lock.
        """
        # Walk a snapshot (oldest first) so entries can be deleted as we go.
        for pdh, (_, entry_size) in list(self.collections.items()):
            has_room = (self.cap - self.total) >= required
            below_minimum = len(self.collections) < self.min_entries
            if has_room or below_minimum:
                return
            logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", pdh, self.cap, self.total, required)
            del self.collections[pdh]
            self.total -= entry_size

    def get(self, pdh):
        """Return the CollectionReader for `pdh`, creating it on a miss.

        A hit moves the entry to the most-recently-used end.  On a miss,
        when `pdh` matches ``pdh_size``, room is made first using the size
        field of the locator (times 128) as the requirement; the new entry
        is then charged at the manifest text length times 128.
        """
        with self.lock:
            if pdh in self.collections:
                # Remove and re-add so the entry becomes the newest one.
                entry = self.collections.pop(pdh)
                self.collections[pdh] = entry
                return entry[0]

            parsed = pdh_size.match(pdh)
            if parsed:
                self.cap_cache(int(parsed.group(2)) * 128)

            logger.debug("Creating collection reader for %s", pdh)
            reader = arvados.collection.CollectionReader(
                pdh,
                api_client=self.api_client,
                keep_client=self.keep_client,
                num_retries=self.num_retries)
            charged = len(reader.manifest_text()) * 128
            self.collections[pdh] = (reader, charged)
            self.total += charged
            return reader
82
83
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths whose first segment is ``keep:<locator>`` are served from
    collection readers obtained through ``collection_cache``.  Every other
    path is delegated to the ``StdFsAccess`` base class.
    """

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split `path` into a collection reader and the path inside it.

        Returns ``(reader, inner_path)`` when the first segment of `path`
        is ``keep:`` followed by a Keep locator.  ``inner_path`` is
        URL-unquoted, and is None when nothing follows the locator.
        Returns ``(None, path)`` for any other path.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
            pdh = p[5:]
            return (self.collection_cache.get(pdh), urllib.parse.unquote(sp[1]) if len(sp) == 2 else None)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Recursively match glob segments against entries of `collection`.

        `patternsegments` is the pattern split on "/"; `parent` is the path
        prefix prepended to every match.  Returns a list of matching paths,
        which is empty when there are no segments left or `collection` is
        not a RichCollectionBase.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo" so just shift past the
            # "./".  This is done once, before looking at any entry, so each
            # match is reported a single time rather than once per entry.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Return the sorted list of paths matching `pattern`.

        A pattern naming only a collection (nothing after the locator)
        matches itself.  Patterns that are not ``keep:`` references are
        delegated to the base class, as the other methods of this class do.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a Keep reference: there is no collection to walk.
            return super(CollectionFsAccess, self).glob(pattern)
        if not rest:
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode):
        """Open `fn` from its collection, or via the base class otherwise."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return whether `fn` exists.

        An HTTP 404 raised while looking up the collection is reported as
        False; any other HttpError propagates.
        """
        try:
            collection, rest = self.get_collection(fn)
        except HttpError as err:
            if err.resp.status == 404:
                return False
            else:
                raise
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                # The collection itself was found.
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def size(self, fn):  # type: (unicode) -> int
        """Return the size of the file at `fn`.

        Raises IOError(EINVAL) when `fn` is inside a collection but does
        not name a file.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                arvfile = collection.find(rest)
                if isinstance(arvfile, arvados.arvfile.ArvadosFile):
                    return arvfile.size()
            raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
        else:
            return super(CollectionFsAccess, self).size(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return True if `fn` names a file; a bare collection is not one."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return True if `fn` names a directory; a bare collection is one."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """List the entries of directory `fn`, each passed through abspath().

        Raises IOError(ENOENT) when the path is missing from the collection
        or is not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            # Named 'target' rather than 'dir' to avoid shadowing the builtin.
            if rest:
                target = collection.find(rest)
            else:
                target = collection
            if target is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(target, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(l, fn) for l in target.keys()]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components like os.path.join.

        When the last component is itself a ``keep:`` locator reference it
        is returned on its own.
        """
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return `path` unchanged for placeholders and Keep references.

        Paths starting with ``$(task.tmpdir)`` or ``$(task.outdir)`` and
        paths inside a collection are returned as given; anything else goes
        through os.path.realpath().
        """
        if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
211
class CollectionFetcher(DefaultFetcher):
    """Fetcher that also understands ``keep:`` and ``arvwf:`` URLs.

    ``keep:`` URLs are read through the supplied fs_access object and
    ``arvwf:`` URLs are looked up as workflow records through api_client.
    Everything else is handled by DefaultFetcher.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries

    def fetch_text(self, url):
        """Return the text content behind `url`.

        For ``arvwf:<uuid>`` the workflow record's definition is returned
        with a ``label:`` line appended that holds the record's name.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            # The name is embedded in a double-quoted YAML scalar, where a
            # backslash starts an escape sequence.  Escape backslashes
            # first, then quotes, so the name is read back unchanged.
            label = record["name"].replace('\\', '\\\\').replace('"', '\\"')
            definition = record["definition"] + ('\nlabel: "%s"\n' % label)
            return definition
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return whether `url` can be resolved.

        URLs under ``http://arvados.org/cwl`` are always reported as
        existing.  Errors raised while checking are logged and reported as
        "does not exist", except NotFoundError which is reported silently.
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            if url.startswith("keep:"):
                return self.fsaccess.exists(url)
            if url.startswith("arvwf:"):
                if self.fetch_text(url):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            # Catch Exception rather than everything, so KeyboardInterrupt
            # and SystemExit still propagate to the caller.
            logger.exception("Got unexpected exception checking if file exists:")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve `url` relative to `base_url`.

        An empty `url` yields `base_url`; a `url` with its own scheme, or an
        empty `base_url`, yields `url` unchanged.  ``keep:`` and ``arvwf:``
        bases are joined here, keeping the first path segment of the base
        as the root; other schemes are delegated to DefaultFetcher.

        Raises IOError(EINVAL) when a keep/arvwf base has no path, or when a
        ``keep:`` base does not start with a valid Keep locator.
        """
        if not url:
            return base_url

        urlsp = urllib.parse.urlsplit(url)
        if urlsp.scheme or not base_url:
            return url

        basesp = urllib.parse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            # First segment of the base path; it always stays at the front.
            pdh = baseparts.pop(0)

            if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # Absolute reference: discard the base's inner path and the
                # empty segment produced by the leading slash.
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Relative reference: replace the last segment of the base.
                baseparts.pop()

            path = "/".join([pdh] + baseparts + urlparts)
            return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)

    # URL schemes reported by supported_schemes().
    schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]

    def supported_schemes(self):  # type: () -> List[Text]
        """Return the list of URL schemes this fetcher accepts."""
        return self.schemes
282
283
# Both patterns have the shape "<5 chars>-<fixed infix>-<15 chars>", where the
# variable parts are lowercase letters or digits.  collectionResolver() applies
# them with match(), so they are anchored at the start of the string only.
workflow_uuid_pattern = re.compile(r"[a-z0-9]{5}-7fd4e-[a-z0-9]{15}")
pipeline_template_uuid_pattern = re.compile(r"[a-z0-9]{5}-p5p6p-[a-z0-9]{15}")
286
def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Resolve a workflow reference to a URI that can be fetched.

    The checks are applied in this order:

    * ``keep:`` and ``arvwf:`` URIs are returned unchanged.
    * A workflow UUID becomes ``arvwf:<uuid>#main``.
    * A pipeline template UUID is looked up and resolved to ``keep:`` plus
      the ``cwl:tool`` script parameter of its first component.
    * A string whose first path segment is a Keep locator gets a ``keep:``
      prefix.
    * A string whose first path segment is a collection UUID is looked up
      and the UUID is replaced by the collection's portable data hash.
    * Anything else is passed to cwltool's tool_resolver.

    num_retries is passed to every API request made here.
    """
    if uri.startswith(("keep:", "arvwf:")):
        return uri

    if workflow_uuid_pattern.match(uri):
        return "arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        return "keep:" + list(pt["components"].values())[0]["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return "keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        # Use the caller's retry budget here too, as the pipeline template
        # lookup above does.
        collection = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        # Keep whatever followed the UUID (the path inside the collection).
        return "keep:%s%s" % (collection["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)