Merge branch '4019-query-properties' closes #4019
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import fnmatch
6 import os
7 import errno
8 import urlparse
9 import re
10 import logging
11 import threading
12 from collections import OrderedDict
13
14 import ruamel.yaml as yaml
15
16 import cwltool.stdfsaccess
17 from cwltool.pathmapper import abspath
18 import cwltool.resolver
19
20 import arvados.util
21 import arvados.collection
22 import arvados.arvfile
23 import arvados.errors
24
25 from schema_salad.ref_resolver import DefaultFetcher
26
# Module-level logger, registered under the 'arvados.cwl-runner' name.
logger = logging.getLogger('arvados.cwl-runner')
28
class CollectionCache(object):
    """Size-bounded, least-recently-used cache of CollectionReader objects.

    Entries are keyed by portable data hash and stored in an OrderedDict
    that is kept ordered from least to most recently used.  Each entry is
    charged an estimated cost of 128 units per character of its manifest
    text; the running sum is compared against ``cap`` to decide when the
    oldest entries should be evicted.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        """Create an empty cache.

        :param api_client: handed to every CollectionReader created here.
        :param keep_client: handed to every CollectionReader created here.
        :param num_retries: retry count handed to every CollectionReader.
        :param cap: threshold for the estimated total size above which
            eviction starts.
        :param min_entries: eviction stops as soon as fewer than this
            many entries remain.
        """
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries
        self.collections = OrderedDict()
        self.lock = threading.Lock()
        self.total = 0
        self.cap = cap
        self.min_entries = min_entries

    def cap_cache(self):
        """Evict the oldest entries while the estimated total exceeds the cap.

        This method does not acquire ``self.lock`` itself; get() calls it
        with the lock already held.
        """
        if self.total <= self.cap:
            return
        # Iterate over a snapshot (oldest first): entries are deleted from
        # the dict during the walk, which is not allowed on a live view.
        for pdh, entry in list(self.collections.items()):
            if self.total < self.cap or len(self.collections) < self.min_entries:
                break
            # cut it loose
            logger.debug("Evicting collection reader %s from cache", pdh)
            del self.collections[pdh]
            self.total -= entry[1]

    def get(self, pdh):
        """Return the CollectionReader for ``pdh``, creating it on a miss.

        A hit moves the entry to the most-recently-used end of the cache.
        A miss creates the reader, records its estimated size and then
        runs cap_cache().

        :param pdh: portable data hash identifying the collection.
        :returns: the cached or newly created CollectionReader.
        """
        with self.lock:
            if pdh not in self.collections:
                logger.debug("Creating collection reader for %s", pdh)
                cr = arvados.collection.CollectionReader(pdh, api_client=self.api_client,
                                                         keep_client=self.keep_client,
                                                         num_retries=self.num_retries)
                sz = len(cr.manifest_text()) * 128
                self.collections[pdh] = (cr, sz)
                self.total += sz
                self.cap_cache()
            else:
                cr, sz = self.collections[pdh]
                # bump it to the back
                del self.collections[pdh]
                self.collections[pdh] = (cr, sz)
            return cr
68
69
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<locator>[/<path>]`` are served from
    collection readers obtained through ``collection_cache``; every other
    path is delegated to the StdFsAccess base class.
    """

    def __init__(self, basedir, collection_cache=None):
        """Set up access rooted at ``basedir``.

        :param basedir: base directory passed to StdFsAccess.
        :param collection_cache: object whose ``get(pdh)`` returns a
            collection reader; used for all ``keep:`` paths.
        """
        super(CollectionFsAccess, self).__init__(basedir)
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split ``path`` into a collection reader and a path inside it.

        :returns: ``(reader, subpath)`` when the first path segment is
            ``keep:`` followed by a valid Keep locator (``subpath`` is
            None when nothing follows the locator); otherwise
            ``(None, path)``.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
            pdh = p[5:]
            return (self.collection_cache.get(pdh), sp[1] if len(sp) == 2 else None)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Recursively match glob segments against a collection tree.

        :param collection: the (sub)collection to search.
        :param patternsegments: remaining glob pattern, split on "/".
        :param parent: path prefix prepended to every match.
        :returns: list of matching paths, each prefixed with ``parent``.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo" so just shift past
            # the "./".  This must happen once, outside the loop below;
            # doing it per entry repeats every match once for each entry
            # in the collection.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Expand a glob pattern.

        A ``keep:`` pattern is matched against the contents of the
        referenced collection; a bare collection reference is returned
        as-is.  Any other pattern is handled by the base class.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a keep: reference; there is no collection to search.
            return super(CollectionFsAccess, self).glob(pattern)
        if not rest:
            return [pattern]
        patternsegments = rest.split("/")
        return self._match(collection, patternsegments, "keep:" + collection.manifest_locator())

    def open(self, fn, mode):
        """Open ``fn`` from its collection, or from the base class otherwise."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return whether ``fn`` exists; a bare collection reference does."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return whether ``fn`` is a file; a bare collection is not."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return whether ``fn`` is a directory; a bare collection is."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """List the entries of directory ``fn``.

        :raises IOError: (ENOENT) when a ``keep:`` path is missing from
            its collection or does not refer to a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                target = collection.find(rest)
            else:
                target = collection
            if target is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(target, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(l, fn) for l in target.keys()]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components.

        When the last component is itself a ``keep:`` reference with a
        valid locator it is returned unchanged, discarding the others.
        """
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return ``path`` resolved with os.path.realpath.

        Paths starting with the ``$(task.tmpdir)`` or ``$(task.outdir)``
        placeholders and ``keep:`` references are returned unchanged.
        """
        if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
180
class CollectionFetcher(DefaultFetcher):
    """Fetcher that understands ``keep:`` and ``arvwf:`` references.

    Anything that is not one of those schemes (and is not listed in
    ``overrides``) is delegated to DefaultFetcher.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4, overrides=None):
        """Set up the fetcher.

        :param cache: passed to DefaultFetcher.
        :param session: passed to DefaultFetcher.
        :param api_client: API client used to look up ``arvwf:`` records.
        :param fs_access: object whose ``open``/``exists`` serve ``keep:``
            references.
        :param num_retries: retry count for API requests.
        :param overrides: optional mapping of url -> text that takes
            precedence over any real lookup.
        """
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries
        self.overrides = overrides if overrides else {}

    def fetch_text(self, url):
        """Return the text content behind ``url``.

        For ``arvwf:`` references the workflow record's definition is
        returned with a ``label`` line built from the record's name
        appended to it.
        """
        if url in self.overrides:
            return self.overrides[url]
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            definition = record["definition"] + ('\nlabel: "%s"\n' % record["name"].replace('"', '\\"'))
            return definition
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return whether ``url`` refers to something that can be fetched.

        Errors raised while checking are logged and reported as "does
        not exist" rather than propagated.
        """
        if url in self.overrides:
            return True
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            if url.startswith("keep:"):
                return self.fsaccess.exists(url)
            if url.startswith("arvwf:"):
                if self.fetch_text(url):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            # Deliberately broad, but not a bare "except:": that would
            # also swallow KeyboardInterrupt and SystemExit.
            logger.exception("Got unexpected exception checking if file exists:")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` relative to ``base_url``.

        ``keep:`` and ``arvwf:`` bases are handled here; the first path
        segment of the base (the locator or uuid) is always preserved.
        Other bases are delegated to DefaultFetcher.

        :raises IOError: (EINVAL) when the base has no path, or is a
            ``keep:`` reference whose first segment is not a valid Keep
            locator.
        """
        if not url:
            return base_url

        urlsp = urlparse.urlsplit(url)
        if urlsp.scheme or not base_url:
            return url

        basesp = urlparse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            pdh = baseparts.pop(0)

            if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # Absolute path: ignore the base's directories and drop
                # the empty segment produced by the leading "/".
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Relative path: replace the last component of the base.
                baseparts.pop()

            path = "/".join([pdh] + baseparts + urlparts)
            return urlparse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)
251
# UUID shapes recognized by collectionResolver.  Both are applied with
# .match(), so they anchor at the start of the string only.
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')

def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Translate a user-supplied workflow reference into a fetchable URI.

    The checks are applied in this order:

    * ``keep:`` and ``arvwf:`` URIs are returned unchanged.
    * A workflow uuid becomes ``arvwf:<uuid>#main``.
    * A pipeline template uuid is looked up and the ``cwl:tool`` script
      parameter of its first component is returned as a ``keep:`` URI.
    * A path starting with a Keep locator gets a ``keep:`` prefix.
    * A path starting with a collection uuid has the uuid replaced by
      the collection's portable data hash, prefixed with ``keep:``.
    * Anything else is passed to cwltool's default tool resolver.

    :param api_client: API client used for the uuid lookups.
    :param document_loader: passed through to cwltool's resolver.
    :param uri: the reference to resolve.
    :param num_retries: retry count for API requests.
    """
    if uri.startswith("keep:") or uri.startswith("arvwf:"):
        return uri

    if workflow_uuid_pattern.match(uri):
        return "arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # list(...) keeps this working when values() returns a view
        # rather than a list.
        first_component = list(pt["components"].values())[0]
        return "keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return "keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        collection = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return "keep:%s%s" % (collection["portable_data_hash"],
                              uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)