14008: Merge branch 'master' into 14008-containers-index
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import fnmatch
6 import os
7 import errno
8 import urlparse
9 import re
10 import logging
11 import threading
12 from collections import OrderedDict
13
14 import ruamel.yaml as yaml
15
16 import cwltool.stdfsaccess
17 from cwltool.pathmapper import abspath
18 import cwltool.resolver
19
20 import arvados.util
21 import arvados.collection
22 import arvados.arvfile
23 import arvados.errors
24
25 from googleapiclient.errors import HttpError
26
27 from schema_salad.ref_resolver import DefaultFetcher
28
29 logger = logging.getLogger('arvados.cwl-runner')
30
class CollectionCache(object):
    """Cache of CollectionReader objects keyed by portable data hash.

    Entries are stored in ``self.collections`` as ``(reader, size)`` tuples,
    ordered from least recently used to most recently used.  ``self.total``
    is the sum of the ``size`` estimates of all cached entries.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        """Create an empty cache.

        api_client, keep_client and num_retries are passed through to every
        CollectionReader the cache creates.  cap is the limit that
        ``self.total`` is compared against; min_entries is the entry count
        below which eviction stops.
        """
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries
        self.collections = OrderedDict()
        self.lock = threading.Lock()
        self.total = 0
        self.cap = cap
        self.min_entries = min_entries

    def cap_cache(self):
        """Evict the oldest readers while the size estimate exceeds the cap.

        Does not take ``self.lock`` itself; ``get()`` calls it while holding
        the lock.
        """
        if self.total <= self.cap:
            return

        # Walk a snapshot, oldest entry first.  Entries are deleted inside
        # the loop, and deleting from an OrderedDict while iterating its live
        # items() view raises RuntimeError on Python 3.
        for pdh, entry in list(self.collections.items()):
            # Stop as soon as the cache fits again (same boundary as the
            # guard above) or fewer than min_entries entries remain.
            if self.total <= self.cap or len(self.collections) < self.min_entries:
                break
            # cut it loose
            logger.debug("Evicting collection reader %s from cache", pdh)
            del self.collections[pdh]
            self.total -= entry[1]

    def get(self, pdh):
        """Return the CollectionReader for pdh, creating it on a cache miss.

        A miss creates a new reader, records a size estimate (manifest text
        length times 128) and then trims the cache.  A hit moves the entry to
        the most-recently-used end.  The whole operation runs under
        ``self.lock``.
        """
        with self.lock:
            if pdh not in self.collections:
                logger.debug("Creating collection reader for %s", pdh)
                cr = arvados.collection.CollectionReader(pdh, api_client=self.api_client,
                                                         keep_client=self.keep_client,
                                                         num_retries=self.num_retries)
                sz = len(cr.manifest_text()) * 128
                self.collections[pdh] = (cr, sz)
                self.total += sz
                self.cap_cache()
            else:
                # bump it to the back by removing and re-inserting
                cr, sz = self.collections.pop(pdh)
                self.collections[pdh] = (cr, sz)
            return cr
72
73
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<locator>/<path>`` are served from collection
    readers obtained through ``collection_cache``; every other path is
    delegated to the StdFsAccess base class.
    """

    def __init__(self, basedir, collection_cache=None):
        """Store the collection cache used to resolve ``keep:`` paths."""
        super(CollectionFsAccess, self).__init__(basedir)
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split path into a collection reader and a path inside it.

        Returns ``(reader, rest)`` when the first path segment is ``keep:``
        followed by a valid Keep locator; ``rest`` is the URL-unquoted
        remainder, or None when there is no remainder.  Returns
        ``(None, path)`` for any other path.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
            pdh = p[5:]
            return (self.collection_cache.get(pdh), urlparse.unquote(sp[1]) if len(sp) == 2 else None)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Return the paths under collection that match patternsegments.

        patternsegments is the glob pattern split on "/"; each segment is
        matched against one directory level with fnmatch.  parent is the
        path prefix joined onto every result.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo" so just shift past
            # the "./".  This is done once, before looking at the entries,
            # so the result is not repeated for every entry in 'collection'.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Return the sorted list of paths matching pattern.

        Patterns that do not refer to a Keep collection are delegated to the
        base class, like the other methods of this class.  A pattern naming
        only the collection itself is returned unchanged.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a keep: reference, so there is no collection to search.
            return super(CollectionFsAccess, self).glob(pattern)
        if not rest:
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode):
        """Open fn from its collection, or from the base class otherwise."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return whether fn exists.

        An HTTP 404 raised while looking up the collection is reported as
        False; any other HttpError is re-raised.
        """
        try:
            collection, rest = self.get_collection(fn)
        except HttpError as err:
            if err.resp.status == 404:
                return False
            else:
                raise
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def size(self, fn):  # type: (unicode) -> int
        """Return the size of the file fn.

        Raises IOError(EINVAL) when fn is inside a collection but does not
        refer to a file.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                arvfile = collection.find(rest)
                if isinstance(arvfile, arvados.arvfile.ArvadosFile):
                    return arvfile.size()
            raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
        else:
            return super(CollectionFsAccess, self).size(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return whether fn refers to a file; a bare collection is not one."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return whether fn refers to a directory; a bare collection is one."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """Return the entries of directory fn, each resolved against fn.

        Raises IOError(ENOENT) when fn is inside a collection but is missing
        or is not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            # 'target' rather than 'dir' to avoid shadowing the builtin.
            if rest:
                target = collection.find(rest)
            else:
                target = collection
            if target is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(target, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(l, fn) for l in target.keys()]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components; a trailing keep: reference replaces the rest."""
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return path unchanged for task placeholders and keep: references.

        Any other path is resolved with os.path.realpath.
        """
        if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
201
class CollectionFetcher(DefaultFetcher):
    """Fetcher that understands ``keep:`` and ``arvwf:`` references.

    Anything else is delegated to the DefaultFetcher base class.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        """Store the API client, filesystem access object and retry count."""
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries

    def fetch_text(self, url):
        """Return the text content behind url.

        ``keep:`` URLs are read through ``self.fsaccess``.  ``arvwf:`` URLs
        fetch the workflow record and return its definition with a ``label``
        line appended, built from the record name with double quotes escaped.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            definition = record["definition"] + ('\nlabel: "%s"\n' % record["name"].replace('"', '\\"'))
            return definition
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return whether url can be resolved.

        A NotFoundError yields False.  Any other Exception is logged and
        also yields False.
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            if url.startswith("keep:"):
                return self.fsaccess.exists(url)
            if url.startswith("arvwf:"):
                if self.fetch_text(url):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            # Deliberately not a bare "except:": KeyboardInterrupt and
            # SystemExit must propagate instead of being reported as
            # "does not exist".
            logger.exception("Got unexpected exception checking if file exists:")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve url relative to base_url.

        When base_url uses the ``keep`` or ``arvwf`` scheme, the first path
        segment of base_url is kept as the root of the result.  Raises
        IOError(EINVAL) when base_url has no path, or when a ``keep`` base
        does not start with a valid Keep locator.
        """
        if not url:
            return base_url

        urlsp = urlparse.urlsplit(url)
        if urlsp.scheme or not base_url:
            return url

        basesp = urlparse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            # The first segment of the base path is always carried over.
            pdh = baseparts.pop(0)

            if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # A leading "/" discards the base path; also drop the empty
                # string that split() produced in front of it.
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Drop the last segment of the base path before appending.
                baseparts.pop()

            path = "/".join([pdh] + baseparts + urlparts)
            return urlparse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)
267
# Patterns used by collectionResolver to recognize workflow and pipeline
# template UUIDs.  They are applied with .match(), so only the beginning of
# the string is checked.
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')

def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Translate uri into a reference that can be fetched.

    - ``keep:`` and ``arvwf:`` URIs are returned unchanged.
    - A workflow UUID becomes ``arvwf:<uuid>#main``.
    - A pipeline template UUID becomes ``keep:`` plus the ``cwl:tool``
      script parameter of the template's first component.
    - A path starting with a Keep locator gets a ``keep:`` prefix.
    - A path starting with a collection UUID has that UUID replaced by the
      collection's portable data hash.
    - Anything else is passed to cwltool's tool_resolver.

    num_retries is passed to every API request made here.
    """
    if uri.startswith("keep:") or uri.startswith("arvwf:"):
        return uri

    if workflow_uuid_pattern.match(uri):
        return "arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # list() so that indexing also works on Python 3, where values()
        # returns a view that cannot be subscripted.
        first_component = list(pt["components"].values())[0]
        return "keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return "keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        # Honor num_retries here too, as the pipeline template lookup does.
        record = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return "keep:%s%s" % (record["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)