1 # Copyright (C) The Arvados Authors. All rights reserved.
3 # SPDX-License-Identifier: Apache-2.0
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import object
8 from future.utils import viewvalues
17 from collections import OrderedDict
19 import ruamel.yaml as yaml
21 import cwltool.stdfsaccess
22 from cwltool.pathmapper import abspath
23 import cwltool.resolver
26 import arvados.collection
27 import arvados.arvfile
30 from googleapiclient.errors import HttpError
32 from schema_salad.ref_resolver import DefaultFetcher
# Logger named 'arvados.cwl-runner'; used below for debug messages and for
# logging unexpected exceptions.
34 logger = logging.getLogger('arvados.cwl-runner')
# Matches the start of a Keep portable data hash: 32 lowercase hex digits
# (group 1), "+", a run of digits (group 2), then optional "+..." hints.
# The collection cache lookup below uses group 2 as a size estimate input.
# .match() anchors only at the start of the string.
36 pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')
# Cache of arvados.collection.CollectionReader objects keyed by portable data
# hash. Entries live in an OrderedDict so the oldest ones can be evicted first
# (cap_cache) when the estimated total footprint would exceed self.cap.
# NOTE(review): this excerpt omits several lines of the class (the rest of the
# __init__ signature, the initialisation of self.total/self.cap, the set_cap
# body and the def line of the lookup method); the comments below describe
# only what the visible lines show.
38 class CollectionCache(object):
# api_client, keep_client and num_retries are stored so they can be handed to
# every CollectionReader the lookup code creates.
39 def __init__(self, api_client, keep_client, num_retries,
42 self.api_client = api_client
43 self.keep_client = keep_client
44 self.num_retries = num_retries
# Insertion order doubles as recency order: oldest entry first.
45 self.collections = OrderedDict()
# NOTE(review): the statement(s) that acquire this lock are not visible here;
# presumably the lookup path holds it — confirm.
46 self.lock = threading.Lock()
# cap_cache consults this: once fewer than min_entries readers remain, the
# (omitted) branch at the top of its loop is taken.
49 self.min_entries = min_entries
# Body not visible in this excerpt; presumably replaces self.cap — TODO confirm.
51 def set_cap(self, cap):
# Evict entries, oldest first, to make room for `required` more estimated
# bytes under self.cap. Each stored value v is the (reader, estimated_size)
# tuple written by the lookup code below. Iterating over a list() snapshot
# keeps the `del` inside the loop safe.
54 def cap_cache(self, required):
55 # ordered dict iterates from oldest to newest
56 for pdh, v in list(self.collections.items()):
57 available = self.cap - self.total
# The statement run when there is already enough room, or when fewer than
# min_entries readers remain, is on lines not shown here (presumably a
# return/break — confirm).
58 if available >= required or len(self.collections) < self.min_entries:
61 logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", pdh, self.cap, self.total, required)
62 del self.collections[pdh]
# NOTE(review): the visible lines do not show self.total being reduced by
# v[1] after the del; verify the omitted line does so, otherwise `available`
# never grows and the loop evicts down to min_entries every time.
# Lookup path: its def line (and any lock acquisition) falls on lines not
# shown in this excerpt. On a miss, first make room for 128x the size field
# parsed from the PDH, then create a CollectionReader and record its
# estimated footprint as 128x the length of its manifest text.
67 if pdh not in self.collections:
68 m = pdh_size.match(pdh)
# NOTE(review): pdh_size.match returns None for strings that are not of the
# "<32 hex>+<digits>" form; check that the omitted line before this call
# guards against m being None, otherwise m.group raises AttributeError.
70 self.cap_cache(int(m.group(2)) * 128)
71 logger.debug("Creating collection reader for %s", pdh)
72 cr = arvados.collection.CollectionReader(pdh, api_client=self.api_client,
73 keep_client=self.keep_client,
74 num_retries=self.num_retries)
75 sz = len(cr.manifest_text()) * 128
76 self.collections[pdh] = (cr, sz)
# Already cached: delete and re-insert the entry so it moves to the newest
# end of the OrderedDict, so cap_cache evicts least-recently-used readers
# first.
79 cr, sz = self.collections[pdh]
81 del self.collections[pdh]
82 self.collections[pdh] = (cr, sz)
86 class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
87 """Implement the cwltool FsAccess interface for Arvados Collections."""
# Paths of the form "keep:<locator>[/<path>]" are served from collections
# obtained through collection_cache; other paths fall through to the
# StdFsAccess implementation. NOTE(review): this excerpt omits lines inside
# most methods; comments describe only the visible lines.
# NOTE(review): collection_cache defaults to None, but get_collection calls
# self.collection_cache.get(...) for keep: paths without a None check.
89 def __init__(self, basedir, collection_cache=None):
90 super(CollectionFsAccess, self).__init__(basedir)
91 self.collection_cache = collection_cache
# Split path on its first "/" into a leading segment and a remainder. For a
# keep: segment whose locator matches keep_locator_pattern, return the cached
# collection plus the URL-unquoted remainder (None when there is no "/").
# `p` and `pdh` are assigned on lines not shown here (presumably p = sp[0]
# and pdh = p[5:]); the return value for other paths is also not shown.
93 def get_collection(self, path):
94 sp = path.split("/", 1)
96 if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
98 return (self.collection_cache.get(pdh), urllib.parse.unquote(sp[1]) if len(sp) == 2 else None)
# Recursively match the glob segments in patternsegments against the entries
# of collection, building result paths under parent with os.path.join. The
# early exits for an empty pattern and for a non-collection argument, and the
# creation of `ret`, are on lines not shown here.
# NOTE(review): the "." check sits inside the per-entry loop, so unless an
# omitted line breaks out of it, a "./x" segment recurses once per entry in
# the collection and repeats every match that many times; the check likely
# belongs before the loop.
102 def _match(self, collection, patternsegments, parent):
103 if not patternsegments:
106 if not isinstance(collection, arvados.collection.RichCollectionBase):
110 # iterate over the files and subcollections in 'collection'
111 for filename in collection:
112 if patternsegments[0] == '.':
113 # Pattern contains something like "./foo" so just shift
115 ret.extend(self._match(collection, patternsegments[1:], parent))
116 elif fnmatch.fnmatch(filename, patternsegments[0]):
117 cur = os.path.join(parent, filename)
118 if len(patternsegments) == 1:
121 ret.extend(self._match(collection[filename], patternsegments[1:], cur))
# Expand a keep: glob pattern into a sorted list of matching "keep:..." paths
# rooted at the collection's manifest locator. The statement taken when the
# pattern names the collection itself (no remainder) is on a line not shown.
# NOTE(review): when get_collection finds no collection, the visible code
# goes on to rest.split("/") and collection.manifest_locator(); confirm the
# omitted lines handle that case, otherwise this raises.
124 def glob(self, pattern):
125 collection, rest = self.get_collection(pattern)
126 if collection is not None and not rest:
128 patternsegments = rest.split("/")
129 return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))
# keep: paths are opened inside the collection; other paths go to
# StdFsAccess.open after passing through self._abs (the line between the two
# returns is not shown, presumably an else).
131 def open(self, fn, mode):
132 collection, rest = self.get_collection(fn)
133 if collection is not None:
134 return collection.open(rest, mode)
136 return super(CollectionFsAccess, self).open(self._abs(fn), mode)
# get_collection may raise HttpError; the `try:` line and what is returned
# for a 404 (or re-raised otherwise) are on lines not shown here. For keep:
# paths the answer comes from collection.exists(rest); other paths use
# StdFsAccess.exists.
138 def exists(self, fn):
140 collection, rest = self.get_collection(fn)
141 except HttpError as err:
142 if err.resp.status == 404:
146 if collection is not None:
148 return collection.exists(rest)
152 return super(CollectionFsAccess, self).exists(fn)
# Size of a keep: file via ArvadosFile.size(); otherwise raises
# IOError(EINVAL). Non-keep paths use StdFsAccess.size. Returns a size, not a
# bool, hence the "-> int" type comment.
154 def size(self, fn): # type: (unicode) -> int
155 collection, rest = self.get_collection(fn)
156 if collection is not None:
158 arvfile = collection.find(rest)
159 if isinstance(arvfile, arvados.arvfile.ArvadosFile):
160 return arvfile.size()
161 raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
163 return super(CollectionFsAccess, self).size(fn)
# For keep: paths, True when rest resolves to an ArvadosFile (handling of an
# empty rest is on lines not shown); other paths use StdFsAccess.isfile.
165 def isfile(self, fn): # type: (unicode) -> bool
166 collection, rest = self.get_collection(fn)
167 if collection is not None:
169 return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
173 return super(CollectionFsAccess, self).isfile(fn)
# For keep: paths, True when rest resolves to a collection/subcollection
# (handling of an empty rest is on lines not shown); other paths use
# StdFsAccess.isdir.
175 def isdir(self, fn): # type: (unicode) -> bool
176 collection, rest = self.get_collection(fn)
177 if collection is not None:
179 return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
183 return super(CollectionFsAccess, self).isdir(fn)
# List a keep: directory as abspath(name, fn) for each entry name. Raises
# IOError(ENOENT) with a "not found" message (its triggering condition is on
# lines not shown) or when the path is not a RichCollectionBase. Other paths
# use StdFsAccess.listdir. NOTE(review): the local `dir` shadows the builtin.
185 def listdir(self, fn): # type: (unicode) -> List[unicode]
186 collection, rest = self.get_collection(fn)
187 if collection is not None:
189 dir = collection.find(rest)
193 raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
194 if not isinstance(dir, arvados.collection.RichCollectionBase):
195 raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
196 return [abspath(l, fn) for l in list(dir.keys())]
198 return super(CollectionFsAccess, self).listdir(fn)
# When the last component is itself a keep: locator, the omitted line
# presumably returns it unchanged; otherwise os.path.join is used.
200 def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
201 if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
203 return os.path.join(path, *paths)
# $(task.tmpdir)/$(task.outdir) placeholders and keep: paths are handled on
# lines not shown here; everything else goes through os.path.realpath.
205 def realpath(self, path):
206 if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
208 collection, rest = self.get_collection(path)
209 if collection is not None:
212 return os.path.realpath(path)
214 class CollectionFetcher(DefaultFetcher):
215 def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
216 super(CollectionFetcher, self).__init__(cache, session)
217 self.api_client = api_client
218 self.fsaccess = fs_access
219 self.num_retries = num_retries
221 def fetch_text(self, url):
222 if url.startswith("keep:"):
223 with self.fsaccess.open(url, "r") as f:
225 if url.startswith("arvwf:"):
226 record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
227 definition = record["definition"] + ('\nlabel: "%s"\n' % record["name"].replace('"', '\\"'))
229 return super(CollectionFetcher, self).fetch_text(url)
231 def check_exists(self, url):
233 if url.startswith("http://arvados.org/cwl"):
235 if url.startswith("keep:"):
236 return self.fsaccess.exists(url)
237 if url.startswith("arvwf:"):
238 if self.fetch_text(url):
240 except arvados.errors.NotFoundError:
243 logger.exception("Got unexpected exception checking if file exists:")
245 return super(CollectionFetcher, self).check_exists(url)
247 def urljoin(self, base_url, url):
251 urlsp = urllib.parse.urlsplit(url)
252 if urlsp.scheme or not base_url:
255 basesp = urllib.parse.urlsplit(base_url)
256 if basesp.scheme in ("keep", "arvwf"):
258 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
260 baseparts = basesp.path.split("/")
261 urlparts = urlsp.path.split("/") if urlsp.path else []
263 pdh = baseparts.pop(0)
265 if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
266 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
268 if urlsp.path.startswith("/"):
272 if baseparts and urlsp.path:
275 path = "/".join([pdh] + baseparts + urlparts)
276 return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))
278 return super(CollectionFetcher, self).urljoin(base_url, url)
280 schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]
282 def supported_schemes(self): # type: () -> List[Text]
# Arvados UUID shapes: 5 chars, a 5-char type infix, then 15 chars (the
# leading 5 are presumably the cluster prefix — confirm). "7fd4e" is matched
# for workflows and "p5p6p" for pipeline templates in collectionResolver.
# Those callers use .match(), which anchors only at the start of the string.
286 workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
287 pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')
289 def collectionResolver(api_client, document_loader, uri, num_retries=4):
290 if uri.startswith("keep:") or uri.startswith("arvwf:"):
291 return uri.encode("utf-8").decode()
293 if workflow_uuid_pattern.match(uri):
294 return u"arvwf:%s#main" % (uri)
296 if pipeline_template_uuid_pattern.match(uri):
297 pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
298 return u"keep:" + viewvalues(pt["components"])[0]["script_parameters"]["cwl:tool"]
301 if arvados.util.keep_locator_pattern.match(p[0]):
302 return u"keep:%s" % (uri)
304 if arvados.util.collection_uuid_pattern.match(p[0]):
305 return u"keep:%s%s" % (api_client.collections().
306 get(uuid=p[0]).execute()["portable_data_hash"],
309 return cwltool.resolver.tool_resolver(document_loader, uri)