Merge branch 'master' of git.curoverse.com:arvados into 11876-r-sdk
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import fnmatch
6 import os
7 import errno
8 import urlparse
9 import re
10 import logging
11 import threading
12 from collections import OrderedDict
13
14 import ruamel.yaml as yaml
15
16 import cwltool.stdfsaccess
17 from cwltool.pathmapper import abspath
18 import cwltool.resolver
19
20 import arvados.util
21 import arvados.collection
22 import arvados.arvfile
23 import arvados.errors
24
25 from schema_salad.ref_resolver import DefaultFetcher
26
logger = logging.getLogger('arvados.cwl-runner')


class CollectionCache(object):
    """LRU cache of CollectionReader objects keyed by portable data hash.

    Each entry is charged an estimated size; once the running total exceeds
    ``cap`` the least recently used readers are evicted.

    :param api_client: Arvados API client passed to every CollectionReader.
    :param keep_client: Keep client passed to every CollectionReader.
    :param num_retries: retry count passed to every CollectionReader.
    :param cap: size budget (in the same units as the per-entry estimate).
    :param min_entries: eviction is skipped while fewer than this many
        readers are cached (see cap_cache).
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        self.api_client = api_client
        self.keep_client = keep_client
        # Previously accepted but dropped; now forwarded to CollectionReader.
        self.num_retries = num_retries
        # pdh -> (reader, estimated size); insertion order == recency order.
        self.collections = OrderedDict()
        self.lock = threading.Lock()
        self.total = 0
        self.cap = cap
        self.min_entries = min_entries

    def cap_cache(self):
        """Evict least recently used entries while the total exceeds the cap.

        Called by get() with ``self.lock`` held; it does not take the lock
        itself. Eviction stops once the total drops below ``cap`` or when
        fewer than ``min_entries`` entries are cached. Because that check runs
        before each deletion, the cache may shrink to ``min_entries - 1``
        entries. NOTE(review): confirm that floor is intended.
        """
        if self.total > self.cap:
            # OrderedDict iterates from oldest to newest.  Iterate over a
            # snapshot because entries are deleted inside the loop (mutating
            # a dict during iteration raises RuntimeError on Python 3).
            for pdh, v in list(self.collections.items()):
                if self.total < self.cap or len(self.collections) < self.min_entries:
                    break
                # cut it loose
                logger.debug("Evicting collection reader %s from cache", pdh)
                del self.collections[pdh]
                self.total -= v[1]

    def get(self, pdh):
        """Return a (possibly cached) CollectionReader for ``pdh``.

        A cache hit moves the entry to the most-recently-used position. A miss
        creates a new reader, charges it an estimated size and may trigger
        eviction.
        """
        with self.lock:
            if pdh not in self.collections:
                logger.debug("Creating collection reader for %s", pdh)
                cr = arvados.collection.CollectionReader(pdh, api_client=self.api_client,
                                                         keep_client=self.keep_client,
                                                         num_retries=self.num_retries)
                # Heuristic size estimate: 128x the manifest text length
                # (presumably approximating in-memory overhead -- TODO confirm).
                sz = len(cr.manifest_text()) * 128
                self.collections[pdh] = (cr, sz)
                self.total += sz
                self.cap_cache()
            else:
                cr, sz = self.collections[pdh]
                # bump it to the back (most recently used)
                del self.collections[pdh]
                self.collections[pdh] = (cr, sz)
            return cr
68
69
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<locator>[/<path>]`` are resolved against
    collection readers obtained from ``collection_cache``; every other path
    is handled by cwltool's StdFsAccess (the local filesystem).
    """

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        # NOTE(review): defaults to None, but any keep: path calls
        # self.collection_cache.get(), so a cache must be supplied for keep:
        # paths to work.
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split ``path`` into ``(collection_reader, path_in_collection)``.

        For ``keep:<locator>/<rest>`` returns the cached reader and ``rest``;
        for a bare ``keep:<locator>`` the second element is None. Paths that
        are not keep references return ``(None, path)`` unchanged.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
            pdh = p[5:]
            return (self.collection_cache.get(pdh), sp[1] if len(sp) == 2 else None)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Recursively expand glob ``patternsegments`` inside ``collection``.

        ``parent`` is the path prefix used for entries of ``collection``.
        Returns the (unsorted) list of matching paths. An exhausted pattern,
        or a ``collection`` that is not a (sub)collection, yields no matches.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo": "." names this same
            # collection, so consume it exactly once.  (Handling it inside the
            # per-entry loop below would repeat the recursion once per entry
            # and return duplicate matches.)
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Return the sorted list of paths matching ``pattern``.

        A pattern naming a collection root (``keep:<locator>``,
        ``keep:<locator>/`` or ``keep:<locator>/.``) matches itself. Patterns
        that are not keep references are delegated to StdFsAccess.glob.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a keep: reference -- there is no collection to match
            # against (and no manifest_locator to call), so use the local
            # filesystem implementation.
            return super(CollectionFsAccess, self).glob(pattern)
        if rest in (None, "", "."):
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode):
        """Open ``fn`` from its collection, or from the local filesystem."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return True if ``fn`` exists; a bare collection reference always does."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return True if ``fn`` names a file (a collection root is not a file)."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return True if ``fn`` names a directory (a collection root is one)."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """List entries of directory ``fn`` as paths joined onto ``fn``.

        Raises IOError(ENOENT) if a keep path is missing or is not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                entry = collection.find(rest)
            else:
                entry = collection
            if entry is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(entry, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(l, fn) for l in entry.keys()]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components; a trailing keep: reference is returned as-is."""
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return ``path`` unchanged for keep paths and paths starting with the
        literal ``$(task.tmpdir)`` / ``$(task.outdir)`` placeholders; otherwise
        resolve it with os.path.realpath."""
        if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
180
class CollectionFetcher(DefaultFetcher):
    """schema-salad fetcher that also understands keep: and arvwf: URLs.

    keep: URLs are read through the supplied ``fs_access`` object; arvwf:
    URLs are fetched as workflow records through ``api_client``. Everything
    else is delegated to DefaultFetcher.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries

    def fetch_text(self, url):
        """Return the text document at ``url``.

        For arvwf:<uuid> the workflow record's definition is returned with a
        ``label:`` line appended, carrying the record name.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r") as f:
                return f.read()
        if url.startswith("arvwf:"):
            # Local import so this block needs no file-level import change.
            import json
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            # json.dumps produces a fully escaped double-quoted string
            # (quotes, backslashes, control characters), which is also a valid
            # YAML double-quoted scalar.  Escaping only '"' left names that
            # contain backslashes producing broken or altered YAML.
            definition = record["definition"] + ('\nlabel: %s\n' % json.dumps(record["name"], ensure_ascii=False))
            return definition
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return True if ``url`` refers to an existing document.

        URLs under http://arvados.org/cwl are always reported as existing.
        Arvados NotFoundError, and any other Exception (which is logged), are
        reported as False.
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            if url.startswith("keep:"):
                return self.fsaccess.exists(url)
            if url.startswith("arvwf:"):
                if self.fetch_text(url):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            # Not a bare except: let KeyboardInterrupt/SystemExit propagate.
            logger.exception("Got unexpected exception checking if file exists:")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` against ``base_url``.

        When ``base_url`` has a keep: or arvwf: scheme, the first path
        component of the base (the locator/uuid) is always preserved. A
        relative ``url`` replaces the base's last path component; an absolute
        path (leading "/") is taken from the locator root. The fragment of
        ``url`` is kept. Raises IOError(EINVAL) for a keep: base without a
        valid locator. Other schemes are delegated to DefaultFetcher.
        """
        if not url:
            return base_url

        urlsp = urlparse.urlsplit(url)
        if urlsp.scheme or not base_url:
            return url

        basesp = urlparse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            pdh = baseparts.pop(0)

            if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # Absolute path within the collection: drop base dirs and the
                # empty component produced by the leading "/".
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Drop the base's last component (the referring document).
                baseparts.pop()

            path = "/".join([pdh] + baseparts + urlparts)
            return urlparse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)
246
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')


def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Resolve a user-supplied tool/workflow reference to a fetchable URI.

    - keep:/arvwf: URIs are returned unchanged.
    - A workflow UUID becomes ``arvwf:<uuid>#main``.
    - A pipeline template UUID is looked up and its first component's
      ``cwl:tool`` script parameter is returned as a keep: URI.
    - ``<locator>[/path]`` becomes ``keep:<locator>[/path]``.
    - ``<collection uuid>[/path]`` is looked up and rewritten to
      ``keep:<portable data hash>[/path]``.
    - Anything else is passed to cwltool's default tool resolver.

    API calls are made with ``num_retries`` retries.
    """
    if uri.startswith(("keep:", "arvwf:")):
        return uri

    if workflow_uuid_pattern.match(uri):
        return "arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # list(...)[0] works on both Python 2 lists and Python 3 dict views
        # (dict views are not indexable).  NOTE(review): with more than one
        # component, which one is "first" depends on dict ordering.
        first_component = list(pt["components"].values())[0]
        return "keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return "keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        # Honor num_retries here too, consistent with the lookup above.
        record = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return "keep:%s%s" % (record["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)