14360: Merge branch 'master' into 14360-dispatch-cloud
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import fnmatch
6 import os
7 import errno
8 import urlparse
9 import re
10 import logging
11 import threading
12 from collections import OrderedDict
13
14 import ruamel.yaml as yaml
15
16 import cwltool.stdfsaccess
17 from cwltool.pathmapper import abspath
18 import cwltool.resolver
19
20 import arvados.util
21 import arvados.collection
22 import arvados.arvfile
23 import arvados.errors
24
25 from googleapiclient.errors import HttpError
26
27 from schema_salad.ref_resolver import DefaultFetcher
28
# Module logger, registered under the "arvados.cwl-runner" name.
logger = logging.getLogger("arvados.cwl-runner")

# Portable data hash with size hint: group(1) is the 32 hex-digit hash,
# group(2) the decimal size; any "+..." hint suffixes may follow.
pdh_size = re.compile(r"([0-9a-f]{32})\+(\d+)(\+\S+)*")
32
class CollectionCache(object):
    """Lock-protected LRU cache of CollectionReader objects keyed by locator.

    Each entry's memory footprint is estimated as ``_SIZE_FACTOR`` times the
    length of its manifest text.  Before a reader is created for a locator
    that carries a "+size" hint, least-recently-used entries are evicted
    until the hinted estimate fits under ``cap``, or until fewer than
    ``min_entries`` entries remain.
    """

    # Multiplier from manifest (or locator size hint) byte count to the
    # estimated memory held by a cached reader.
    # NOTE(review): heuristic factor; its derivation is not recorded here.
    _SIZE_FACTOR = 128

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries
        # locator -> (CollectionReader, estimated size), ordered from least
        # to most recently used.
        self.collections = OrderedDict()
        self.lock = threading.Lock()
        # Sum of the estimated sizes of all cached entries.
        self.total = 0
        self.cap = cap
        self.min_entries = min_entries

    def set_cap(self, cap):
        """Set the estimated-size cap; existing entries are only evicted on a later cap_cache()."""
        self.cap = cap

    def cap_cache(self, required):
        """Evict least-recently-used entries until ``required`` more units fit under the cap.

        Stops as soon as enough room is available or fewer than
        ``min_entries`` entries remain.  get() calls this while holding
        ``self.lock``.
        """
        # Iterate over a snapshot (oldest first): entries are deleted inside
        # the loop, and mutating an OrderedDict while iterating a live view
        # raises RuntimeError on Python 3.
        for pdh, (_, sz) in list(self.collections.items()):
            available = self.cap - self.total
            if available >= required or len(self.collections) < self.min_entries:
                return
            # cut it loose
            logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", pdh, self.cap, self.total, required)
            del self.collections[pdh]
            self.total -= sz

    def get(self, pdh):
        """Return a CollectionReader for ``pdh``, creating and caching it on a miss.

        A hit moves the entry to the most-recently-used position.
        """
        with self.lock:
            if pdh not in self.collections:
                m = pdh_size.match(pdh)
                if m:
                    # Make room up front using the size hint in the locator.
                    self.cap_cache(int(m.group(2)) * self._SIZE_FACTOR)
                logger.debug("Creating collection reader for %s", pdh)
                cr = arvados.collection.CollectionReader(pdh, api_client=self.api_client,
                                                         keep_client=self.keep_client,
                                                         num_retries=self.num_retries)
                sz = len(cr.manifest_text()) * self._SIZE_FACTOR
                self.collections[pdh] = (cr, sz)
                self.total += sz
            else:
                cr, sz = self.collections[pdh]
                # bump it to the back (most recently used)
                del self.collections[pdh]
                self.collections[pdh] = (cr, sz)
            return cr
79
80
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<locator>[/<path>]`` are resolved against
    collections obtained from ``collection_cache``; every other path is
    handled by :class:`cwltool.stdfsaccess.StdFsAccess`.
    """

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        # Must provide get(locator) -> collection reader (e.g. CollectionCache).
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split ``path`` into ``(collection, path_inside_collection)``.

        For ``keep:<locator>[/rest]`` returns the cached collection and the
        URL-unquoted ``rest`` (None when there is no "/").  For any other
        path returns ``(None, path)``.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
            pdh = p[5:]
            return (self.collection_cache.get(pdh), urlparse.unquote(sp[1]) if len(sp) == 2 else None)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Return paths under ``collection`` matching the glob ``patternsegments``.

        ``parent`` is the path prefix joined onto each match.  Returns an
        empty list when there are no segments left or ``collection`` is not
        a directory-like collection object.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo": shift past the "./".
            # Done once here rather than per entry, which would repeat the
            # same matches once for every file in the collection.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Return the sorted paths matching ``pattern``.

        A bare ``keep:<locator>`` (or ``keep:<locator>/``) matches itself.
        Non-keep patterns are delegated to the local filesystem
        implementation.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a keep: reference; previously this dereferenced None.
            return super(CollectionFsAccess, self).glob(pattern)
        if not rest:
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode):
        """Open ``fn`` inside a collection, or on the local filesystem."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return True if ``fn`` exists.

        An HTTP 404 while loading the collection is reported as False; other
        HTTP errors propagate.
        """
        try:
            collection, rest = self.get_collection(fn)
        except HttpError as err:
            if err.resp.status == 404:
                return False
            else:
                raise
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def size(self, fn):  # type: (unicode) -> int
        """Return the size in bytes of file ``fn``.

        Raises IOError(EINVAL) if a keep: path does not name a file.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                arvfile = collection.find(rest)
                if isinstance(arvfile, arvados.arvfile.ArvadosFile):
                    return arvfile.size()
            raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
        else:
            return super(CollectionFsAccess, self).size(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return True if ``fn`` names a file (a bare collection is not a file)."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return True if ``fn`` names a directory (a bare collection is one)."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """Return the absolute paths of the entries of directory ``fn``.

        Raises IOError(ENOENT) if a keep: path does not exist and
        IOError(ENOTDIR) if it exists but is not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                target = collection.find(rest)
            else:
                target = collection
            if target is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(target, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOTDIR, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(l, fn) for l in target.keys()]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components; a trailing keep:<locator> component replaces the whole path."""
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return ``path`` unchanged for task placeholders and keep: paths, else os.path.realpath."""
        if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
208
class CollectionFetcher(DefaultFetcher):
    """schema-salad fetcher that also understands ``keep:`` and ``arvwf:`` URLs.

    ``keep:`` URLs are read through ``fs_access``; ``arvwf:<uuid>`` URLs are
    loaded from the Arvados workflows API.  Everything else is handled by
    ``DefaultFetcher``.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries

    def fetch_text(self, url):
        """Return the text of ``url``.

        For ``arvwf:`` URLs this is the stored workflow definition with a
        ``label:`` line appended that carries the workflow record's name.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            definition = record["definition"] + ('\nlabel: "%s"\n' % record["name"].replace('"', '\\"'))
            return definition
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return True if ``url`` exists.

        ``http://arvados.org/cwl`` URLs always exist.  Errors raised while
        checking keep:/arvwf: URLs are treated as "does not exist" (and
        logged unless they are NotFoundError).
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            if url.startswith("keep:"):
                return self.fsaccess.exists(url)
            if url.startswith("arvwf:"):
                if self.fetch_text(url):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            # Best effort: report unexpected failures as non-existent, but do
            # not swallow KeyboardInterrupt/SystemExit as a bare except did.
            logger.exception("Got unexpected exception checking if file exists:")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` relative to ``base_url``.

        For keep:/arvwf: bases the first path segment (locator or UUID) is
        always kept; relative paths replace the base's last segment and
        absolute paths ("/x") restart directly under it.  Raises
        IOError(EINVAL) for an empty base path or an invalid Keep locator.
        """
        if not url:
            return base_url

        urlsp = urlparse.urlsplit(url)
        if urlsp.scheme or not base_url:
            return url

        basesp = urlparse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            pdh = baseparts.pop(0)

            if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # Absolute path: discard the base's subpath and the empty
                # segment produced by the leading "/".
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Relative path replaces the base's final (file) segment.
                baseparts.pop()

            path = "/".join([pdh] + baseparts + urlparts)
            return urlparse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)

    schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]

    def supported_schemes(self):  # type: () -> List[Text]
        """Return the URL schemes this fetcher accepts."""
        return self.schemes
279
280
# UUID prefix patterns for workflow ("7fd4e") and pipeline template
# ("p5p6p") records.  Used with .match(), so only the start of the string
# is checked; trailing text is not rejected.
workflow_uuid_pattern = re.compile(r"[a-z0-9]{5}-7fd4e-[a-z0-9]{15}")
pipeline_template_uuid_pattern = re.compile(r"[a-z0-9]{5}-p5p6p-[a-z0-9]{15}")
283
def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Resolve an Arvados-style tool reference to a loadable URI.

    - ``keep:``/``arvwf:`` URIs are returned unchanged.
    - A workflow UUID becomes ``arvwf:<uuid>#main``.
    - A pipeline template UUID resolves to ``keep:`` + the ``cwl:tool``
      script parameter of the first component in the template record.
    - A leading Keep locator gets a ``keep:`` prefix.
    - A leading collection UUID is replaced by the collection's portable
      data hash, prefixed with ``keep:``.
    - Anything else goes to cwltool's default tool resolver.

    API calls use ``num_retries`` retries.
    """
    if uri.startswith("keep:") or uri.startswith("arvwf:"):
        return uri

    if workflow_uuid_pattern.match(uri):
        return "arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # list(...)[0] works on both Python 2 and 3 (dict views are not
        # indexable on Python 3) and still raises IndexError if empty.
        first_component = list(pt["components"].values())[0]
        return "keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return "keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        # Retry like the other API call in this function.
        record = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return "keep:%s%s" % (record["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)
303
304     return cwltool.resolver.tool_resolver(document_loader, uri)