Merge branch 'wtsi/13113-acr-collectioncache-retries' refs #13113
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import fnmatch
6 import os
7 import errno
8 import urlparse
9 import re
10 import logging
11 import threading
12 from collections import OrderedDict
13
14 import ruamel.yaml as yaml
15
16 import cwltool.stdfsaccess
17 from cwltool.pathmapper import abspath
18 import cwltool.resolver
19
20 import arvados.util
21 import arvados.collection
22 import arvados.arvfile
23 import arvados.errors
24
25 from schema_salad.ref_resolver import DefaultFetcher
26
# Module logger, shared under the arvados-cwl-runner logger name.
logger = logging.getLogger("arvados.cwl-runner")
28
class CollectionCache(object):
    """Size-capped, least-recently-used cache of CollectionReader objects keyed by PDH.

    Each cached entry is charged an estimated size. When the running total
    exceeds ``cap``, the oldest (least recently used) entries are evicted.
    Lookups through get() are serialized by ``self.lock``.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        self.api_client = api_client
        self.keep_client = keep_client
        # Passed to every CollectionReader this cache creates.
        self.num_retries = num_retries
        # pdh -> (CollectionReader, estimated size), ordered oldest to most recently used.
        self.collections = OrderedDict()
        self.lock = threading.Lock()
        # Sum of the estimated sizes of all cached entries.
        self.total = 0
        self.cap = cap
        self.min_entries = min_entries

    def cap_cache(self):
        """Evict the oldest entries until the estimated total is no longer over ``cap``.

        Eviction also stops, even if still over the cap, once fewer than
        ``min_entries`` entries remain. This method does not take ``self.lock``
        itself; get() calls it while already holding the lock.
        """
        if self.total <= self.cap:
            return
        # Iterate over a snapshot: entries are deleted inside the loop, which is
        # not permitted while iterating an OrderedDict view directly on Python 3.
        # The OrderedDict iterates from oldest to newest.
        for pdh, (_, sz) in list(self.collections.items()):
            # Stop as soon as we are within the cap (the same boundary used to
            # decide whether eviction is needed at all).
            if self.total <= self.cap or len(self.collections) < self.min_entries:
                break
            # cut it loose
            logger.debug("Evicting collection reader %s from cache", pdh)
            del self.collections[pdh]
            self.total -= sz

    def get(self, pdh):
        """Return the CollectionReader for ``pdh``, creating and caching it on a miss.

        On a hit the entry is moved to the most-recently-used end. On a miss a
        new CollectionReader is built with this cache's clients and retry count,
        its estimated size is added to the total, and cap_cache() runs.
        """
        with self.lock:
            if pdh not in self.collections:
                logger.debug("Creating collection reader for %s", pdh)
                cr = arvados.collection.CollectionReader(pdh, api_client=self.api_client,
                                                         keep_client=self.keep_client,
                                                         num_retries=self.num_retries)
                # Heuristic size estimate: 128 units per manifest character
                # (presumably approximating reader memory use; not measured here).
                sz = len(cr.manifest_text()) * 128
                self.collections[pdh] = (cr, sz)
                self.total += sz
                self.cap_cache()
            else:
                cr, sz = self.collections[pdh]
                # bump it to the back (most recently used)
                del self.collections[pdh]
                self.collections[pdh] = (cr, sz)
            return cr
70
71
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form "keep:<portable data hash>[/<path>]" are served from the
    collection returned by ``collection_cache.get(pdh)``. Every other path is
    delegated to the local-filesystem implementation in StdFsAccess.
    """

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        # Must provide get(pdh) (e.g. a CollectionCache) for keep: paths to work.
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split ``path`` into (collection, path inside the collection).

        Returns ``(collection, rest)`` for "keep:<pdh>[/rest]" paths. ``rest``
        is None when the path names the collection itself. Returns
        ``(None, path)`` for anything else.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
            pdh = p[5:]
            return (self.collection_cache.get(pdh), sp[1] if len(sp) == 2 else None)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Return paths under ``collection`` that match a split glob pattern.

        collection: a RichCollectionBase (collection or subcollection). Anything
            else, such as a file, yields no matches.
        patternsegments: the glob pattern split on "/". Each segment is matched
            with fnmatch against one level of the tree.
        parent: prefix joined (os.path.join) in front of every returned path.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo", so just shift past the "./".
            # This is done once, outside the loop over entries, so each match
            # is reported exactly once.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Expand a glob pattern and return the sorted list of matching paths.

        keep: patterns are matched against the collection's contents. A pattern
        naming just the collection returns itself. Any other pattern is globbed
        on the local filesystem by StdFsAccess.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a keep: reference: there is no collection to match against.
            return super(CollectionFsAccess, self).glob(pattern)
        if not rest:
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode):
        """Open ``fn`` from its collection, or from the local filesystem."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return whether ``fn`` exists; a bare keep:<pdh> is always considered present."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                # The collection itself was loaded successfully.
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return whether ``fn`` names a file (ArvadosFile) inside a collection, or a local file."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                # A collection root is a directory, never a file.
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return whether ``fn`` names a (sub)collection, or a local directory."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """List the entries of directory ``fn`` as paths built with abspath(entry, fn).

        Raises:
            IOError(ENOENT): the keep: path does not exist or is not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                node = collection.find(rest)
            else:
                node = collection
            if node is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(node, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(l, fn) for l in node.keys()]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components; a trailing keep:<pdh>... component is returned as-is."""
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Resolve ``path`` locally; keep: paths and $(task.*) placeholders are returned unchanged."""
        # NOTE(review): $(task.tmpdir)/$(task.outdir) look like placeholders
        # substituted later at runtime; they must not be resolved here.
        if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
182
class CollectionFetcher(DefaultFetcher):
    """schema-salad fetcher that also understands keep: and arvwf: URLs.

    keep: URLs are read through ``fs_access``. arvwf:<uuid> URLs are fetched
    from the Arvados workflows API. Everything else is delegated to
    DefaultFetcher.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        # Retry count passed to Arvados API execute() calls.
        self.num_retries = num_retries

    def fetch_text(self, url):
        """Return the text behind ``url``.

        For arvwf: URLs this is the stored workflow definition with a
        ``label:`` line appended that carries the workflow's name.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            # The name goes inside a double-quoted YAML scalar. Backslashes must be
            # escaped before quotes; otherwise a name containing "\" produces a
            # broken or misinterpreted escape sequence.
            name = record["name"].replace('\\', '\\\\').replace('"', '\\"')
            definition = record["definition"] + ('\nlabel: "%s"\n' % name)
            return definition
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return whether ``url`` can be loaded.

        URLs under http://arvados.org/cwl always exist. keep: URLs are checked
        through fs_access. arvwf: URLs are checked by fetching them. NotFoundError
        and any other error from those checks yields False (unexpected errors
        are logged).
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            if url.startswith("keep:"):
                return self.fsaccess.exists(url)
            if url.startswith("arvwf:"):
                if self.fetch_text(url):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            # Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit propagate instead of being logged and swallowed.
            logger.exception("Got unexpected exception checking if file exists:")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` relative to ``base_url``, with support for keep: and arvwf: bases.

        For keep:/arvwf: bases, the first path segment (the PDH or workflow UUID)
        is preserved. An absolute ``url`` path is resolved from that segment; a
        relative one replaces the base's last segment. The fragment of ``url``
        is kept.

        Raises:
            IOError(EINVAL): the base has no path, or a keep: base does not
                start with a valid Keep locator.
        """
        if not url:
            return base_url

        urlsp = urlparse.urlsplit(url)
        if urlsp.scheme or not base_url:
            return url

        basesp = urlparse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            pdh = baseparts.pop(0)

            if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # Absolute path: resolve from the root of the collection/workflow.
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Drop the base's last segment so the relative path resolves
                # against its containing directory.
                baseparts.pop()

            path = "/".join([pdh] + baseparts + urlparts)
            return urlparse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)
248
# Arvados UUID shapes "xxxxx-<type infix>-<15 chars>" for workflows (7fd4e) and
# pipeline templates (p5p6p). Only the start of the string is anchored, since
# callers use .match().
workflow_uuid_pattern = re.compile(r"[a-z0-9]{5}-7fd4e-[a-z0-9]{15}")
pipeline_template_uuid_pattern = re.compile(r"[a-z0-9]{5}-p5p6p-[a-z0-9]{15}")
251
def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Resolve a tool reference to a URI that the loader can fetch.

    Returns:
    - keep: and arvwf: URIs unchanged.
    - A workflow UUID as "arvwf:<uuid>#main".
    - A pipeline template UUID as the keep: URI of its "cwl:tool" script parameter.
    - "<pdh>[/path]" as "keep:<pdh>[/path]".
    - "<collection uuid>[/path]" as "keep:<that collection's pdh>[/path]".
    - Anything else via cwltool's default tool_resolver.

    ``num_retries`` is passed to every Arvados API call made here.
    """
    if uri.startswith("keep:") or uri.startswith("arvwf:"):
        return uri

    if workflow_uuid_pattern.match(uri):
        return "arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # Take the first component yielded by the dict. NOTE(review): this
        # presumably assumes single-component templates; confirm. list() keeps
        # this working where dict.values() is a non-indexable view (Python 3).
        first_component = list(pt["components"].values())[0]
        return "keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return "keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        # Retry like the other API lookups in this function.
        record = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return "keep:%s%s" % (record["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)