13306: Fixing previous commit to use viewvalues instead of listvalues
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import object
8 from future.utils import viewvalues
9
10 import fnmatch
11 import os
12 import errno
13 import urllib.parse
14 import re
15 import logging
16 import threading
17 from collections import OrderedDict
18
19 import ruamel.yaml as yaml
20
21 import cwltool.stdfsaccess
22 from cwltool.pathmapper import abspath
23 import cwltool.resolver
24
25 import arvados.util
26 import arvados.collection
27 import arvados.arvfile
28 import arvados.errors
29
30 from googleapiclient.errors import HttpError
31
32 from schema_salad.ref_resolver import DefaultFetcher
33
# Module-wide logger, registered under the 'arvados.cwl-runner' name.
logger = logging.getLogger('arvados.cwl-runner')

# Matches (from the start of the string) 32 lowercase hex digits, a "+", a
# decimal number, and any number of further "+<non-whitespace>" suffixes.
# Group 2 (the decimal number) is what CollectionCache.get() multiplies by 128
# to estimate how much cache room a collection needs before loading it.
# NOTE(review): presumably that number is the manifest size encoded in a
# portable data hash -- confirm against the Arvados locator format.
pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')
37
class CollectionCache(object):
    """Size-capped cache of CollectionReader objects keyed by locator string.

    Entries live in an OrderedDict ordered from least to most recently
    requested.  Each value is a ``(reader, size)`` tuple where ``size`` is
    128 times the length of the reader's manifest text; ``total`` is the sum
    of the sizes currently held.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        """Remember the client handles and start with an empty cache.

        ``cap`` is the size budget compared against ``total``;
        ``min_entries`` stops eviction once fewer than that many entries
        remain.
        """
        # Configuration
        self.cap = cap
        self.min_entries = min_entries
        self.num_retries = num_retries
        # Client handles passed through to every CollectionReader
        self.api_client = api_client
        self.keep_client = keep_client
        # Cache state; get() takes the lock around every access
        self.lock = threading.Lock()
        self.collections = OrderedDict()
        self.total = 0

    def set_cap(self, cap):
        """Replace the size budget; nothing is evicted until the next miss."""
        self.cap = cap

    def cap_cache(self, required):
        """Evict oldest entries until ``required`` units of room are free.

        Stops early when enough room is available or when fewer than
        ``min_entries`` entries remain.  Takes no lock itself; get() calls
        it while holding ``self.lock``.
        """
        while self.collections:
            enough_room = (self.cap - self.total) >= required
            too_few_left = len(self.collections) < self.min_entries
            if enough_room or too_few_left:
                return
            # The first key of the ordered dict is the oldest entry.
            oldest = next(iter(self.collections))
            entry = self.collections[oldest]
            logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", oldest, self.cap, self.total, required)
            del self.collections[oldest]
            self.total -= entry[1]

    def get(self, pdh):
        """Return the CollectionReader for ``pdh``, creating it on a miss.

        A hit moves the entry to the newest position.  A miss first makes
        room based on the size parsed from ``pdh`` (when it matches
        ``pdh_size``), then builds the reader and records its actual size.
        """
        with self.lock:
            if pdh in self.collections:
                # Re-insert so the entry becomes the most recent one.
                reader, size = self.collections.pop(pdh)
                self.collections[pdh] = (reader, size)
                return reader

            parsed = pdh_size.match(pdh)
            if parsed:
                self.cap_cache(int(parsed.group(2)) * 128)
            logger.debug("Creating collection reader for %s", pdh)
            reader = arvados.collection.CollectionReader(
                pdh,
                api_client=self.api_client,
                keep_client=self.keep_client,
                num_retries=self.num_retries)
            size = len(reader.manifest_text()) * 128
            self.collections[pdh] = (reader, size)
            self.total += size
            return reader
84
85
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<locator>[/<path>]`` are resolved against
    collection readers obtained from ``collection_cache``.  Every other path
    is delegated to the StdFsAccess base class.
    """

    def __init__(self, basedir, collection_cache=None):
        """Store the collection cache and initialise the base class with ``basedir``."""
        super(CollectionFsAccess, self).__init__(basedir)
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split ``path`` into ``(collection, path_inside_collection)``.

        For a ``keep:<locator>`` path the collection comes from the cache and
        the remainder is URL-unquoted (``None`` when there is no remainder).
        For any other path the result is ``(None, path)``.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
            pdh = p[5:]
            return (self.collection_cache.get(pdh), urllib.parse.unquote(sp[1]) if len(sp) == 2 else None)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Return paths under ``collection`` matching the glob ``patternsegments``.

        ``parent`` is the path prefix joined onto every match.  Each segment
        is matched with fnmatch against the entries of one directory level.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo" so just shift past the
            # "./".  This is done once, before looking at the entries:
            # doing it per entry would re-run the remaining match for every
            # file in the directory and return each result many times.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Return the sorted paths matching ``pattern``.

        A bare ``keep:<locator>`` pattern matches itself.  Patterns that do
        not refer to a collection are handled by the base class, like every
        other method of this class does.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a keep: reference; there is no collection to search.
            return super(CollectionFsAccess, self).glob(pattern)
        if not rest:
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode):
        """Open ``fn`` from its collection, or from the local filesystem."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return whether ``fn`` exists.

        An HTTP 404 while fetching the collection means "does not exist";
        any other HttpError is re-raised.
        """
        try:
            collection, rest = self.get_collection(fn)
        except HttpError as err:
            if err.resp.status == 404:
                return False
            else:
                raise
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def size(self, fn):  # type: (unicode) -> int
        """Return the size of the file at ``fn``.

        Raises IOError(EINVAL) when ``fn`` is inside a collection but is not
        a file (including when it names the collection itself).
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                arvfile = collection.find(rest)
                if isinstance(arvfile, arvados.arvfile.ArvadosFile):
                    return arvfile.size()
            raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
        else:
            return super(CollectionFsAccess, self).size(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return True when ``fn`` is a file; a bare collection is not a file."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return True when ``fn`` is a directory; a bare collection counts as one."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """Return the entries of directory ``fn``, each resolved against ``fn``.

        Raises IOError(ENOENT) when the path inside the collection is missing
        or is not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                target = collection.find(rest)
            else:
                target = collection
            if target is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(target, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(entry, fn) for entry in list(target.keys())]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components; a final ``keep:<locator>`` component wins outright."""
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return ``path`` unchanged for task placeholders and keep references,
        otherwise the local filesystem real path."""
        if path.startswith(("$(task.tmpdir)", "$(task.outdir)")):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
213
class CollectionFetcher(DefaultFetcher):
    """Fetcher that adds the ``keep:`` and ``arvwf:`` schemes to DefaultFetcher.

    ``keep:`` URLs are read through ``fs_access``; ``arvwf:`` URLs are
    looked up as workflow records through ``api_client``.  Everything else
    is delegated to the base class.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        """Store the API client, filesystem access object and retry count."""
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries

    def fetch_text(self, url):
        """Return the text content behind ``url``.

        For ``arvwf:`` URLs the workflow record's definition is returned with
        a ``label:`` line appended, built from the record's name.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            # Double quotes in the name are escaped so the appended label
            # stays a single quoted string.
            definition = record["definition"] + ('\nlabel: "%s"\n' % record["name"].replace('"', '\\"'))
            return definition
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return whether ``url`` can be fetched.

        An arvados NotFoundError means False.  Any other ordinary exception
        is logged and also reported as False.
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            if url.startswith("keep:"):
                return self.fsaccess.exists(url)
            if url.startswith("arvwf:"):
                if self.fetch_text(url):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            # Deliberately best-effort, but limited to Exception so that
            # KeyboardInterrupt and SystemExit still propagate.
            logger.exception("Got unexpected exception checking if file exists:")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` relative to ``base_url``.

        ``keep:`` and ``arvwf:`` bases are joined here so the leading
        locator/identifier segment is always preserved; other bases go to
        the base class.  Raises IOError(EINVAL) for a ``keep:``/``arvwf:``
        base with an empty path, or a ``keep:`` base whose first segment is
        not a Keep locator.
        """
        if not url:
            return base_url

        urlsp = urllib.parse.urlsplit(url)
        if urlsp.scheme or not base_url:
            # Already absolute, or nothing to resolve against.
            return url

        basesp = urllib.parse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            # The first segment identifies the collection/workflow and is
            # kept regardless of how the rest is resolved.
            pdh = baseparts.pop(0)

            if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # Absolute path: discard the base path, and the empty
                # segment that precedes the leading "/".
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Relative path: replace the last segment of the base.
                baseparts.pop()

            path = "/".join([pdh] + baseparts + urlparts)
            return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)

    # Schemes reported by supported_schemes().
    schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]

    def supported_schemes(self):  # type: () -> List[Text]
        """Return the list of URL schemes this fetcher handles."""
        return self.schemes
284
285
# UUID shapes recognised by collectionResolver(): five lowercase alphanumerics,
# a fixed infix ("7fd4e" for workflows, "p5p6p" for pipeline templates), then
# fifteen lowercase alphanumerics.  Neither pattern is anchored at the end.
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')
288
def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Resolve a tool/workflow reference to a URI the fetcher can load.

    Handled in order:
      * ``keep:`` and ``arvwf:`` URIs are returned unchanged;
      * a workflow UUID becomes ``arvwf:<uuid>#main``;
      * a pipeline template UUID becomes ``keep:`` plus the ``cwl:tool``
        script parameter of the template's first component;
      * a path starting with a Keep locator gets a ``keep:`` prefix;
      * a path starting with a collection UUID has that UUID replaced by
        the collection's portable data hash, prefixed with ``keep:``;
      * anything else is passed to cwltool's default tool resolver.

    ``num_retries`` is forwarded to every API request made here.
    """
    if uri.startswith(("keep:", "arvwf:")):
        return uri

    if workflow_uuid_pattern.match(uri):
        return "arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # Dict views cannot be indexed, so materialise the values before
        # picking the first component.
        first_component = list(viewvalues(pt["components"]))[0]
        return "keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return "keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        collection = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        # Keep whatever followed the UUID (the path inside the collection).
        return "keep:%s%s" % (collection["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)
308
309     return cwltool.resolver.tool_resolver(document_loader, uri)