1 # Copyright (C) The Arvados Authors. All rights reserved.
3 # SPDX-License-Identifier: Apache-2.0
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import object
8 from builtins import str
9 from future.utils import viewvalues
18 from collections import OrderedDict
20 import ruamel.yaml as yaml
22 import cwltool.stdfsaccess
23 from cwltool.pathmapper import abspath
24 import cwltool.resolver
27 import arvados.collection
28 import arvados.arvfile
31 from googleapiclient.errors import HttpError
33 from schema_salad.ref_resolver import DefaultFetcher
# Module-level logger, registered under the name 'arvados.cwl-runner'.
35 logger = logging.getLogger('arvados.cwl-runner')
# Matches 32 lowercase hex digits, a "+", a decimal number (group 2), and
# then zero or more further "+<non-whitespace>" suffixes.
37 pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')
# Cache of arvados CollectionReader objects keyed by locator.
# Entries are kept in an OrderedDict as (reader, size estimate) tuples, and
# cap_cache() removes the oldest entries first when room is needed.
39 class CollectionCache(object):
# Keep the API client, Keep client and retry count that are later handed to
# CollectionReader, and set up the empty ordered cache, a threading.Lock and
# the min_entries setting that cap_cache() consults.
40 def __init__(self, api_client, keep_client, num_retries,
43 self.api_client = api_client
44 self.keep_client = keep_client
45 self.num_retries = num_retries
46 self.collections = OrderedDict()
47 self.lock = threading.Lock()
50 self.min_entries = min_entries
# Takes a new cap value; presumably stores it as self.cap (the value read by
# cap_cache) -- TODO confirm.
52 def set_cap(self, cap):
# Make room for `required` size units by evicting cached readers, starting
# with the oldest entry.
55 def cap_cache(self, required):
56 # ordered dict iterates from oldest to newest
# Iterate over a list() snapshot because entries are deleted inside the loop.
57 for pdh, v in list(self.collections.items()):
# Headroom still available under the cap.
58 available = self.cap - self.total
# Checks whether the headroom already covers the request, or the cache holds
# fewer than min_entries entries; presumably eviction stops in that case --
# TODO confirm.
59 if available >= required or len(self.collections) < self.min_entries:
62 logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", pdh, self.cap, self.total, required)
63 del self.collections[pdh]
# Look up the reader cached for `locator`, creating and caching a new
# CollectionReader on a miss.  Callers use the result as the collection
# object (see CollectionFsAccess.get_collection).
66 def get(self, locator):
68 if locator not in self.collections:
69 m = pdh_size.match(locator)
# Reserve room before loading: the number after the first "+" in the locator
# (group 2 of pdh_size) times 128 is used as the up-front size estimate.
71 self.cap_cache(int(m.group(2)) * 128)
72 logger.debug("Creating collection reader for %s", locator)
74 cr = arvados.collection.CollectionReader(locator, api_client=self.api_client,
75 keep_client=self.keep_client,
76 num_retries=self.num_retries)
# An API error is converted to IOError(ENOENT) carrying the API's reason text.
77 except arvados.errors.ApiError as ap:
78 raise IOError(errno.ENOENT, "Could not access collection '%s': %s" % (locator, str(ap._get_reason())))
# Size recorded for the entry: length of the manifest text times 128.
79 sz = len(cr.manifest_text()) * 128
80 self.collections[locator] = (cr, sz)
83 cr, sz = self.collections[locator]
# Delete and re-insert the entry so it moves to the newest end of the
# OrderedDict and becomes the last candidate for eviction.
85 del self.collections[locator]
86 self.collections[locator] = (cr, sz)
90 class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
91 """Implement the cwltool FsAccess interface for Arvados Collections."""
# Initialise the StdFsAccess base with `basedir` and remember the collection
# cache (an object with a get(locator) method) used to resolve "keep:" paths.
93 def __init__(self, basedir, collection_cache=None):
94 super(CollectionFsAccess, self).__init__(basedir)
95 self.collection_cache = collection_cache
# Resolve `path` to a (collection reader, path inside the collection) pair.
# The path is split at the first "/"; the text after a leading "keep:" prefix
# must match either the Keep locator pattern or the collection UUID pattern.
97 def get_collection(self, path):
98 sp = path.split("/", 1)
100 if p.startswith("keep:") and (arvados.util.keep_locator_pattern.match(p[5:]) or
101 arvados.util.collection_uuid_pattern.match(p[5:])):
# Second element is the URL-unquoted remainder of the path, or None when the
# path had no "/" after the first component.
103 return (self.collection_cache.get(locator), urllib.parse.unquote(sp[1]) if len(sp) == 2 else None)
# Recursive helper for glob(): match the list of pattern segments against the
# entries of `collection`; `parent` is the path prefix of the current level.
107 def _match(self, collection, patternsegments, parent):
108 if not patternsegments:
# Anything that is not a RichCollectionBase cannot be iterated for entries.
111 if not isinstance(collection, arvados.collection.RichCollectionBase):
115 # iterate over the files and subcollections in 'collection'
116 for filename in collection:
117 if patternsegments[0] == '.':
118 # Pattern contains something like "./foo" so just shift
120 ret.extend(self._match(collection, patternsegments[1:], parent))
# Shell-style wildcard match of the entry name against the current segment.
121 elif fnmatch.fnmatch(filename, patternsegments[0]):
122 cur = os.path.join(parent, filename)
# A single remaining segment means this is the last level of the pattern.
123 if len(patternsegments) == 1:
# Otherwise descend into the matched entry with the remaining segments.
126 ret.extend(self._match(collection[filename], patternsegments[1:], cur))
# Expand a glob pattern that points into a collection.  The result is the
# sorted list produced by _match(), with paths rooted at
# "keep:" + the collection's manifest locator.
129 def glob(self, pattern):
130 collection, rest = self.get_collection(pattern)
# Special case: the pattern names a collection with no path inside it.
131 if collection is not None and not rest:
133 patternsegments = rest.split("/")
134 return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))
# Open `fn` from its collection when it resolves to one, otherwise fall back
# to the base class on the absolute local path.
# NOTE(review): `encoding` is passed to collection.open() but not to the
# superclass open() call below -- confirm this is intentional.
136 def open(self, fn, mode, encoding=None):
137 collection, rest = self.get_collection(fn)
138 if collection is not None:
139 return collection.open(rest, mode, encoding=encoding)
141 return super(CollectionFsAccess, self).open(self._abs(fn), mode)
# Report whether `fn` exists, either inside a collection or, for any other
# path, according to the base class.
143 def exists(self, fn):
145 collection, rest = self.get_collection(fn)
# A 404 response from the API and an ENOENT IOError are both singled out
# here; presumably they mean "does not exist" rather than a failure --
# TODO confirm.
146 except HttpError as err:
147 if err.resp.status == 404:
151 except IOError as err:
152 if err.errno == errno.ENOENT:
156 if collection is not None:
158 return collection.exists(rest)
162 return super(CollectionFsAccess, self).exists(fn)
# Return the size of the file at `fn`.  For a collection path the entry must
# be an ArvadosFile; the IOError(EINVAL) below is raised for collection paths
# that do not yield a file size.
# NOTE(review): the type comment declares "-> bool" although the value
# returned is arvfile.size() -- presumably an integer; confirm and correct.
164 def size(self, fn): # type: (unicode) -> bool
165 collection, rest = self.get_collection(fn)
166 if collection is not None:
168 arvfile = collection.find(rest)
169 if isinstance(arvfile, arvados.arvfile.ArvadosFile):
170 return arvfile.size()
171 raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
173 return super(CollectionFsAccess, self).size(fn)
# True when `fn` resolves to an ArvadosFile inside a collection; paths that
# are not in a collection are delegated to the base class.
175 def isfile(self, fn): # type: (unicode) -> bool
176 collection, rest = self.get_collection(fn)
177 if collection is not None:
179 return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
183 return super(CollectionFsAccess, self).isfile(fn)
# True when `fn` resolves to a RichCollectionBase (a collection or
# subcollection); paths that are not in a collection are delegated to the
# base class.
185 def isdir(self, fn): # type: (unicode) -> bool
186 collection, rest = self.get_collection(fn)
187 if collection is not None:
189 return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
193 return super(CollectionFsAccess, self).isdir(fn)
# List the entries of the directory at `fn`.  For a collection path an
# IOError(ENOENT) is raised when the directory is not found or the entry is
# not a RichCollectionBase.
195 def listdir(self, fn): # type: (unicode) -> List[unicode]
196 collection, rest = self.get_collection(fn)
197 if collection is not None:
# NOTE(review): the local name `dir` shadows the builtin of the same name.
199 dir = collection.find(rest)
203 raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
204 if not isinstance(dir, arvados.collection.RichCollectionBase):
205 raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
# Each entry name is passed through abspath() together with `fn`.
206 return [abspath(l, fn) for l in list(dir.keys())]
208 return super(CollectionFsAccess, self).listdir(fn)
# Join path components.  A final component that is itself a "keep:" locator
# is detected separately (presumably it is returned as-is -- TODO confirm);
# otherwise this is os.path.join.
210 def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
211 if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
213 return os.path.join(path, *paths)
# Resolve `path`.  Paths starting with the "$(task.tmpdir)" or
# "$(task.outdir)" placeholders and paths inside a collection are handled
# separately; everything else goes through os.path.realpath.
215 def realpath(self, path):
216 if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
218 collection, rest = self.get_collection(path)
219 if collection is not None:
222 return os.path.realpath(path)
# Fetcher that extends DefaultFetcher with the "keep:" and "arvwf:" URL
# schemes; all other URLs fall through to the base class.
224 class CollectionFetcher(DefaultFetcher):
# Initialise the base fetcher with `cache` and `session`, and keep the API
# client, the filesystem-access object and the retry count for later calls.
225 def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
226 super(CollectionFetcher, self).__init__(cache, session)
227 self.api_client = api_client
228 self.fsaccess = fs_access
229 self.num_retries = num_retries
# Return the text behind `url`.
231 def fetch_text(self, url):
# "keep:" URLs are opened through the filesystem-access object as UTF-8 text.
232 if url.startswith("keep:"):
233 with self.fsaccess.open(url, "r", encoding="utf-8") as f:
# "arvwf:" URLs name a workflow record: url[6:] (the text after "arvwf:") is
# used as its UUID.  The record's "definition" is parsed as YAML, its "name"
# is copied into the definition's "label", and the result is dumped back to
# YAML text.
235 if url.startswith("arvwf:"):
236 record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
237 definition = yaml.round_trip_load(record["definition"])
238 definition["label"] = record["name"]
239 return yaml.round_trip_dump(definition)
240 return super(CollectionFetcher, self).fetch_text(url)
# Report whether `url` can be resolved.  URLs under "http://arvados.org/cwl",
# "keep:" URLs and "arvwf:" URLs each have their own check; anything else is
# delegated to the base class.
242 def check_exists(self, url):
244 if url.startswith("http://arvados.org/cwl"):
246 if url.startswith("keep:"):
247 return self.fsaccess.exists(url)
# A workflow is probed by fetching its text.
248 if url.startswith("arvwf:"):
249 if self.fetch_text(url):
# NotFoundError has its own handler; the logger.exception() call below
# records a traceback for an exception that was not anticipated.
251 except arvados.errors.NotFoundError:
254 logger.exception("Got unexpected exception checking if file exists")
256 return super(CollectionFetcher, self).check_exists(url)
# Join `url` onto `base_url`, with custom handling when the base uses the
# "keep" or "arvwf" scheme and delegation to the base class otherwise.
258 def urljoin(self, base_url, url):
262 urlsp = urllib.parse.urlsplit(url)
# Special case: `url` already carries its own scheme, or there is no base.
263 if urlsp.scheme or not base_url:
266 basesp = urllib.parse.urlsplit(base_url)
267 if basesp.scheme in ("keep", "arvwf"):
269 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
271 baseparts = basesp.path.split("/")
272 urlparts = urlsp.path.split("/") if urlsp.path else []
# The first component of the base path is the locator; the remaining
# components stay in baseparts.
274 locator = baseparts.pop(0)
# For the "keep" scheme the locator must match either the Keep locator
# pattern or the collection UUID pattern.
276 if (basesp.scheme == "keep" and
277 (not arvados.util.keep_locator_pattern.match(locator)) and
278 (not arvados.util.collection_uuid_pattern.match(locator))):
279 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
281 if urlsp.path.startswith("/"):
285 if baseparts and urlsp.path:
# Rebuild the URL as <base scheme>:<locator>/<base parts>/<url parts>, with
# the fragment taken from `url`.
288 path = "/".join([locator] + baseparts + urlparts)
289 return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))
291 return super(CollectionFetcher, self).urljoin(base_url, url)
# URL schemes listed for this fetcher.
293 schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]
# Presumably returns the `schemes` list above -- TODO confirm.
295 def supported_schemes(self): # type: () -> List[Text]
# UUID shapes recognised by collectionResolver: five lowercase alphanumerics,
# a fixed infix, then fifteen lowercase alphanumerics.  The infix "7fd4e" is
# used for workflows and "p5p6p" for pipeline templates.
# Neither pattern is anchored at the end, so match() accepts any string that
# starts with such a UUID.
299 workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
300 pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')
# Resolve a user-supplied workflow reference `uri` to a URL.
# Workflow UUIDs become "arvwf:<uuid>#main", pipeline template UUIDs and
# collection references become "keep:" URLs, and anything unrecognised is
# handed to cwltool's default tool_resolver.
302 def collectionResolver(api_client, document_loader, uri, num_retries=4):
# References that already carry a "keep:" or "arvwf:" scheme.
303 if uri.startswith("keep:") or uri.startswith("arvwf:"):
306 if workflow_uuid_pattern.match(uri):
307 return u"arvwf:%s#main" % (uri)
# Pipeline template: fetch the record and use the "cwl:tool" script parameter
# of its first component as the path under "keep:".
# NOTE(review): viewvalues() yields a dict view, and dict views are not
# subscriptable, so the [0] below looks like it raises TypeError; something
# like next(iter(...)) may be intended -- confirm.
309 if pipeline_template_uuid_pattern.match(uri):
310 pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
311 return u"keep:" + viewvalues(pt["components"])[0]["script_parameters"]["cwl:tool"]
# p[0] is tested first as a Keep locator; in that case the whole uri is
# prefixed with "keep:".
314 if arvados.util.keep_locator_pattern.match(p[0]):
315 return u"keep:%s" % (uri)
# p[0] as a collection UUID: the collection record is fetched and its
# "portable_data_hash" goes into the "keep:" URL.
# NOTE(review): this execute() call does not pass num_retries, unlike the
# pipeline template lookup above -- confirm whether that is intentional.
317 if arvados.util.collection_uuid_pattern.match(p[0]):
318 return u"keep:%s%s" % (api_client.collections().
319 get(uuid=p[0]).execute()["portable_data_hash"],
322 return cwltool.resolver.tool_resolver(document_loader, uri)