Merge branch '11162-wes-support' refs #11162
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import fnmatch
6 import os
7 import errno
8 import urlparse
9 import re
10 import logging
11 import threading
12 from collections import OrderedDict
13
14 import ruamel.yaml as yaml
15
16 import cwltool.stdfsaccess
17 from cwltool.pathmapper import abspath
18 import cwltool.resolver
19
20 import arvados.util
21 import arvados.collection
22 import arvados.arvfile
23 import arvados.errors
24
25 from googleapiclient.errors import HttpError
26
27 from schema_salad.ref_resolver import DefaultFetcher
28
# Module-wide logger; shares the "arvados.cwl-runner" channel with the rest
# of the CWL runner so its messages are configured together.
logger = logging.getLogger("arvados.cwl-runner")
30
class CollectionCache(object):
    """LRU cache of CollectionReader objects keyed by portable data hash.

    Each entry is stored together with an estimated size.  When the running
    total of estimated sizes exceeds ``cap``, the least recently used entries
    are evicted (see ``cap_cache``).  ``get`` serializes access with a lock.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries
        # pdh -> (CollectionReader, estimated size).  Insertion order doubles
        # as recency order: oldest first, most recently used last.
        self.collections = OrderedDict()
        self.lock = threading.Lock()
        self.total = 0
        self.cap = cap
        self.min_entries = min_entries

    def cap_cache(self):
        """Evict oldest entries while the estimated total exceeds ``cap``.

        Eviction stops as soon as the total drops below ``cap`` or the number
        of cached entries drops below ``min_entries``.  Called from ``get``
        while ``self.lock`` is held.
        """
        if self.total > self.cap:
            # Iterate over a snapshot: entries are deleted inside the loop,
            # which is an error if items() is a live view of the dict.
            for pdh, v in list(self.collections.items()):
                if self.total < self.cap or len(self.collections) < self.min_entries:
                    break
                # cut it loose
                logger.debug("Evicting collection reader %s from cache", pdh)
                del self.collections[pdh]
                self.total -= v[1]

    def get(self, pdh):
        """Return a CollectionReader for ``pdh``, creating and caching it on a miss.

        On a hit the entry is moved to the most-recently-used end.  On a miss
        a new reader is built with this cache's API/Keep clients and retry
        count, added to the cache, and the cache is trimmed.
        """
        with self.lock:
            if pdh not in self.collections:
                logger.debug("Creating collection reader for %s", pdh)
                cr = arvados.collection.CollectionReader(pdh, api_client=self.api_client,
                                                         keep_client=self.keep_client,
                                                         num_retries=self.num_retries)
                # Size estimate: manifest length scaled by a fixed factor
                # (128 looks like a heuristic for in-memory overhead -- confirm).
                sz = len(cr.manifest_text()) * 128
                self.collections[pdh] = (cr, sz)
                self.total += sz
                self.cap_cache()
            else:
                cr, sz = self.collections[pdh]
                # bump it to the back (most recently used)
                del self.collections[pdh]
                self.collections[pdh] = (cr, sz)
            return cr
71             return cr
72
73
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<portable data hash>[/<path>]`` are resolved
    inside the corresponding collection, obtained from ``collection_cache``.
    Any other path is delegated to the ``StdFsAccess`` base class.
    """

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split ``path`` into ``(collection, path_inside_collection)``.

        For ``keep:<pdh>[/rest]`` returns the cached CollectionReader and
        ``rest`` (None when the path names the collection root).  For any
        other path returns ``(None, path)``.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
            pdh = p[5:]
            return (self.collection_cache.get(pdh), sp[1] if len(sp) == 2 else None)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Recursively match glob segments against a collection tree.

        :param collection: collection or subcollection to search; anything
            that is not a RichCollectionBase yields no matches.
        :param patternsegments: remaining "/"-separated glob segments.
        :param parent: path prefix joined onto every match.
        :returns: list of matching paths (unsorted).
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo": shift past the "./".
            # Done once here rather than once per entry in the collection,
            # which would report every match len(collection) times.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Return the sorted list of paths matching ``pattern``.

        A pattern naming a collection root (optionally followed by "/" or
        "/.") matches itself.  Non-Keep patterns are handled by the base
        class instead of dereferencing a missing collection.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            return super(CollectionFsAccess, self).glob(pattern)
        if rest in (None, "", "."):
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode):
        """Open a file inside a collection, or a local file resolved against basedir."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return True if ``fn`` exists.

        Returns False (instead of raising) when loading the collection fails
        with an HTTP 404; other HTTP errors propagate.  A bare
        ``keep:<pdh>`` path exists whenever its collection could be loaded.
        """
        try:
            collection, rest = self.get_collection(fn)
        except HttpError as err:
            if err.resp.status == 404:
                return False
            else:
                raise
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return True if ``fn`` is a file; a collection root is never a file."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return True if ``fn`` is a directory; a collection root always is."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """List the entries of directory ``fn`` as paths joined onto ``fn``.

        :raises IOError: (ENOENT) if the path is missing or is not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                entry = collection.find(rest)
            else:
                entry = collection
            if entry is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(entry, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(l, fn) for l in entry.keys()]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components.

        If the last component is itself a ``keep:<pdh>`` reference it is
        returned unchanged, since it does not depend on ``path``.
        """
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return ``path`` unchanged for task placeholders and Keep paths, else os.path.realpath."""
        if path.startswith(("$(task.tmpdir)", "$(task.outdir)")):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
189             return os.path.realpath(path)
190
class CollectionFetcher(DefaultFetcher):
    """schema-salad fetcher that also understands ``keep:`` and ``arvwf:`` URLs.

    ``keep:`` URLs are read through ``fs_access``.  ``arvwf:<uuid>`` URLs are
    loaded from the workflows API.  Everything else goes to DefaultFetcher.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries

    def fetch_text(self, url):
        """Return the text document at ``url``.

        For ``arvwf:`` URLs the workflow record's ``definition`` is returned
        with a ``label:`` line appended, carrying the record's name.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            # Escape backslashes before quotes so the name forms a valid YAML
            # double-quoted scalar even when it contains "\" characters.
            escaped_name = record["name"].replace('\\', '\\\\').replace('"', '\\"')
            definition = record["definition"] + ('\nlabel: "%s"\n' % escaped_name)
            return definition
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return True if ``url`` can be fetched.

        ``http://arvados.org/cwl*`` URLs are always treated as existing.
        NotFoundError, and any other unexpected error (which is logged),
        yields False.
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            if url.startswith("keep:"):
                return self.fsaccess.exists(url)
            if url.startswith("arvwf:"):
                if self.fetch_text(url):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            # Deliberately best-effort, but let KeyboardInterrupt/SystemExit
            # propagate instead of being swallowed by a bare except.
            logger.exception("Got unexpected exception checking if file exists:")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` relative to ``base_url``.

        For ``keep:`` and ``arvwf:`` bases the first path segment (the
        collection locator / workflow uuid) is preserved.  Relative paths
        replace the last base segment; absolute paths ("/...") are rooted at
        that first segment.  The fragment comes from ``url``.

        :raises IOError: (EINVAL) if the base has no path or a ``keep:`` base
            does not start with a valid Keep locator.
        """
        if not url:
            return base_url

        urlsp = urlparse.urlsplit(url)
        if urlsp.scheme or not base_url:
            return url

        basesp = urlparse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            pdh = baseparts.pop(0)

            if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # Absolute path: discard the base's sub-path and the empty
                # segment produced by the leading "/".
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Relative path replaces the base's final (file) segment.
                baseparts.pop()

            path = "/".join([pdh] + baseparts + urlparts)
            return urlparse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)
254
255         return super(CollectionFetcher, self).urljoin(base_url, url)
256
# UUID shapes "<5 chars>-<type infix>-<15 chars>" for workflow ("7fd4e") and
# pipeline template ("p5p6p") records.  The patterns are not anchored at the
# end; collectionResolver tests them with .match(), which anchors only at
# the start of the string.
workflow_uuid_pattern = re.compile(
    r"[a-z0-9]{5}"
    r"-7fd4e-"
    r"[a-z0-9]{15}"
)
pipeline_template_uuid_pattern = re.compile(
    r"[a-z0-9]{5}"
    r"-p5p6p-"
    r"[a-z0-9]{15}"
)
259
def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Resolve a tool reference into a URI the Arvados CWL loader can fetch.

    - ``keep:``/``arvwf:`` URIs are returned unchanged.
    - A workflow UUID becomes ``arvwf:<uuid>#main``.
    - A pipeline template UUID is looked up.  The ``cwl:tool`` script
      parameter of its first component is returned as a ``keep:`` URI.
    - A bare Keep locator (optionally followed by a path) gets a ``keep:``
      prefix.
    - A collection UUID (optionally followed by a path) is replaced by the
      collection's portable data hash.
    - Anything else is delegated to ``cwltool.resolver.tool_resolver``.

    ``num_retries`` is passed to every API request made here.
    """
    if uri.startswith(("keep:", "arvwf:")):
        return uri

    if workflow_uuid_pattern.match(uri):
        return "arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # list() so indexing works whether values() returns a list or a view.
        first_component = list(pt["components"].values())[0]
        return "keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return "keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        # Retry like the pipeline-template lookup above does.
        record = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return "keep:%s%s" % (record["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)
279
280     return cwltool.resolver.tool_resolver(document_loader, uri)