1 # Copyright (C) The Arvados Authors. All rights reserved.
3 # SPDX-License-Identifier: Apache-2.0
12 from collections import OrderedDict
13 from io import StringIO
17 import cwltool.stdfsaccess
18 from cwltool.pathmapper import abspath
19 import cwltool.resolver
22 import arvados.collection
23 import arvados.arvfile
26 from googleapiclient.errors import HttpError
28 from schema_salad.ref_resolver import DefaultFetcher
30 logger = logging.getLogger('arvados.cwl-runner')
# The logger above is shared by this module and is registered under the
# name 'arvados.cwl-runner'.
#
# pdh_size (below) matches 32 lowercase hex digits, a '+', a decimal number,
# and then zero or more further '+'-prefixed runs of non-whitespace.
# Group 2 (the decimal number) is what CollectionCache.get() reads.
# Presumably this is the "hash+size" form of a portable data hash -- confirm.
32 pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')
34 class CollectionCache(object):
# Cache of arvados.collection.CollectionReader objects keyed by locator.
#
# Entries are stored in an OrderedDict as (reader, size_estimate) tuples.
# get() deletes and re-inserts an entry it already holds, so iteration order
# runs from least recently used to most recently used; cap_cache() walks the
# dict in that order when evicting.
#
# Constructor: api_client, keep_client and num_retries are stored as-is and
# forwarded to CollectionReader by get(). min_entries is the entry count
# below which cap_cache() stops considering evictions.
# NOTE(review): cap_cache() also reads self.cap and self.total; their
# initialisation is not among the assignments shown here -- confirm.
35 def __init__(self, api_client, keep_client, num_retries,
38 self.api_client = api_client
39 self.keep_client = keep_client
40 self.num_retries = num_retries
41 self.collections = OrderedDict()
42 self.lock = threading.Lock()
45 self.min_entries = min_entries
# set_cap(cap): takes a new cap value; presumably replaces the self.cap that
# cap_cache() compares against -- confirm.
47 def set_cap(self, cap):
# cap_cache(required): evict cached readers, oldest first, to make room for
# an entry whose estimated size is `required` (same units as self.cap and
# self.total).
50 def cap_cache(self, required):
51 # ordered dict iterates from oldest to newest
# Iterate over a list() snapshot so entries can be deleted from the dict
# while looping.
52 for pdh, v in list(self.collections.items()):
53 available = self.cap - self.total
# Condition for leaving the entry alone: there is already enough headroom,
# or the cache holds fewer than min_entries entries. Presumably eviction
# stops altogether at this point -- confirm.
54 if available >= required or len(self.collections) < self.min_entries:
57 logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", pdh, self.cap, self.total, required)
58 del self.collections[pdh]
# get(locator): return the reader for `locator`, creating and caching one
# when it is not cached yet.
# Raises IOError(errno.ENOENT) when creating the reader fails with
# arvados.errors.ApiError.
61 def get(self, locator):
63 if locator not in self.collections:
# Cache miss. The size field of the locator (group 2 of pdh_size), times
# 128, is the amount of room requested from cap_cache() before the new
# reader is created.
# NOTE(review): m is None when the locator does not match pdh_size;
# presumably a guard precedes the m.group(2) call -- confirm.
64 m = pdh_size.match(locator)
66 self.cap_cache(int(m.group(2)) * 128)
67 logger.debug("Creating collection reader for %s", locator)
69 cr = arvados.collection.CollectionReader(locator, api_client=self.api_client,
70 keep_client=self.keep_client,
71 num_retries=self.num_retries)
# API failures are reported to callers as "no such entry" I/O errors, with
# the API's reason included in the message.
72 except arvados.errors.ApiError as ap:
73 raise IOError(errno.ENOENT, "Could not access collection '%s': %s" % (locator, str(ap._get_reason())))
# The size estimate stored with the entry is the manifest text length times
# 128. The rationale for the factor 128 is not stated in this code.
74 sz = len(cr.manifest_text()) * 128
75 self.collections[locator] = (cr, sz)
# Cache hit: delete and re-insert so the entry becomes the newest one in
# the OrderedDict's iteration order.
78 cr, sz = self.collections[locator]
80 del self.collections[locator]
81 self.collections[locator] = (cr, sz)
85 class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
86 """Implement the cwltool FsAccess interface for Arvados Collections."""
# Pattern used by most methods below: call get_collection() on the path; when
# it yields a collection, answer from the collection, otherwise defer to the
# StdFsAccess implementation.
#
# Constructor: basedir is passed through to StdFsAccess; collection_cache is
# the object whose get(locator) method get_collection() calls (it defaults to
# None).
88 def __init__(self, basedir, collection_cache=None):
89 super(CollectionFsAccess, self).__init__(basedir)
90 self.collection_cache = collection_cache
# get_collection(path): split `path` at the first "/" into a head and an
# optional remainder. When the head is "keep:" followed by something that
# matches arvados.util.keep_locator_pattern or collection_uuid_pattern,
# return (reader from the cache, remainder). The remainder is URL-unquoted
# and normalised with os.path.normpath, or None when the path has no "/".
# Callers test the first element against None, so presumably non-keep paths
# yield (None, path) -- confirm.
92 def get_collection(self, path):
93 sp = path.split("/", 1)
95 if p.startswith("keep:") and (arvados.util.keep_locator_pattern.match(p[5:]) or
96 arvados.util.collection_uuid_pattern.match(p[5:])):
98 rest = os.path.normpath(urllib.parse.unquote(sp[1])) if len(sp) == 2 else None
99 return (self.collection_cache.get(locator), rest)
# _match(collection, patternsegments, parent): recursive helper for glob().
# `patternsegments` is the glob pattern split on "/"; `parent` is the path
# prefix accumulated so far. Matches are collected in `ret`.
103 def _match(self, collection, patternsegments, parent):
# Two special cases are checked first: no pattern segments left, and a node
# that is not a RichCollectionBase (so it cannot be iterated as a directory).
104 if not patternsegments:
107 if not isinstance(collection, arvados.collection.RichCollectionBase):
111 # iterate over the files and subcollections in 'collection'
112 for filename in collection:
113 if patternsegments[0] == '.':
114 # Pattern contains something like "./foo" so just shift
116 ret.extend(self._match(collection, patternsegments[1:], parent))
# Otherwise the entry name is tested against the current segment with
# fnmatch; on a match, `cur` is the entry's full path.
117 elif fnmatch.fnmatch(filename, patternsegments[0]):
118 cur = os.path.join(parent, filename)
# A single remaining segment means this entry is itself a result; with more
# segments left, recurse into the entry.
119 if len(patternsegments) == 1:
122 ret.extend(self._match(collection[filename], patternsegments[1:], cur))
# glob(pattern): expand a glob pattern that points into a collection.
# Results are sorted and prefixed with "keep:" + the collection's
# manifest_locator().
125 def glob(self, pattern):
126 collection, rest = self.get_collection(pattern)
# A pattern naming the collection root (no remainder, "" or ".") is
# special-cased before any segment matching happens.
127 if collection is not None and rest in (None, "", "."):
129 patternsegments = rest.split("/")
130 return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))
# open(fn, mode, encoding=None): open a file inside a collection, or fall
# back to StdFsAccess.open on the absolute form of the path.
# NOTE(review): `encoding` is forwarded to collection.open() but not to the
# StdFsAccess fallback -- confirm this is intended.
132 def open(self, fn, mode, encoding=None):
133 collection, rest = self.get_collection(fn)
134 if collection is not None:
135 return collection.open(rest, mode, encoding=encoding)
137 return super(CollectionFsAccess, self).open(self._abs(fn), mode)
# exists(fn): report whether `fn` exists. An HttpError with status 404 and an
# IOError with errno ENOENT raised during the collection lookup are handled
# specially; presumably both mean "does not exist" -- confirm.
139 def exists(self, fn):
141 collection, rest = self.get_collection(fn)
142 except HttpError as err:
143 if err.resp.status == 404:
147 except IOError as err:
148 if err.errno == errno.ENOENT:
152 if collection is not None:
154 return collection.exists(rest)
158 return super(CollectionFsAccess, self).exists(fn)
# size(fn): return arvfile.size() for a file inside a collection; raise
# IOError(errno.EINVAL) when the path does not resolve to an ArvadosFile.
# NOTE(review): the type comment on the def line says "-> bool", but the
# visible return value is arvfile.size().
160 def size(self, fn): # type: (unicode) -> bool
161 collection, rest = self.get_collection(fn)
162 if collection is not None:
164 arvfile = collection.find(rest)
165 if isinstance(arvfile, arvados.arvfile.ArvadosFile):
166 return arvfile.size()
167 raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
169 return super(CollectionFsAccess, self).size(fn)
# isfile(fn): inside a collection, true when the path resolves to an
# ArvadosFile.
171 def isfile(self, fn): # type: (unicode) -> bool
172 collection, rest = self.get_collection(fn)
173 if collection is not None:
175 return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
179 return super(CollectionFsAccess, self).isfile(fn)
# isdir(fn): inside a collection, true when the path resolves to a
# RichCollectionBase (a collection or subcollection).
181 def isdir(self, fn): # type: (unicode) -> bool
182 collection, rest = self.get_collection(fn)
183 if collection is not None:
185 return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
189 return super(CollectionFsAccess, self).isdir(fn)
# listdir(fn): list the entries of a directory inside a collection; each key
# of the directory is passed through abspath(key, fn).
# Raises IOError(errno.ENOENT) both for a missing path and for a path that is
# not a directory.
191 def listdir(self, fn): # type: (unicode) -> List[unicode]
192 collection, rest = self.get_collection(fn)
193 if collection is not None:
195 dir = collection.find(rest)
199 raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
200 if not isinstance(dir, arvados.collection.RichCollectionBase):
201 raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
202 return [abspath(l, fn) for l in list(dir.keys())]
204 return super(CollectionFsAccess, self).listdir(fn)
# join(path, *paths): like os.path.join, except that a final component of the
# form "keep:<locator>" is special-cased; presumably it is returned on its
# own, like an absolute path -- confirm.
206 def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
207 if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
209 return os.path.join(path, *paths)
# realpath(path): paths that start with the "$(task.tmpdir)" or
# "$(task.outdir)" placeholders, and paths inside a collection, are handled
# before falling back to os.path.realpath.
211 def realpath(self, path):
212 if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
214 collection, rest = self.get_collection(path)
215 if collection is not None:
218 return os.path.realpath(path)
220 class CollectionFetcher(DefaultFetcher):
# schema_salad fetcher that adds the "keep:" and "arvwf:" URL schemes on top
# of DefaultFetcher.
#
# Constructor: cache and session go to DefaultFetcher. api_client is used for
# workflow record lookups, fs_access for "keep:" URLs, and num_retries is
# passed to the API execute() call in fetch_text().
221 def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
222 super(CollectionFetcher, self).__init__(cache, session)
223 self.api_client = api_client
224 self.fsaccess = fs_access
225 self.num_retries = num_retries
# fetch_text(url, content_types=None): return the text behind `url`.
#   keep:  -> opened through self.fsaccess in text mode as UTF-8
#   arvwf: -> workflow record fetched by UUID (the part after "arvwf:"); its
#             "definition" is parsed as YAML, "label" is set to the record's
#             "name", and the result is dumped back to a string
#   other  -> DefaultFetcher.fetch_text(url)
# NOTE(review): content_types is not forwarded to the DefaultFetcher call.
227 def fetch_text(self, url, content_types=None):
228 if url.startswith("keep:"):
229 with self.fsaccess.open(url, "r", encoding="utf-8") as f:
231 if url.startswith("arvwf:"):
232 record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
# Round-trip ('rt') loader, pure-Python implementation.
233 yaml = ruamel.yaml.YAML(typ='rt', pure=True)
234 definition = yaml.load(record["definition"])
235 definition["label"] = record["name"]
237 yaml.dump(definition, stream)
238 return stream.getvalue()
239 return super(CollectionFetcher, self).fetch_text(url)
# check_exists(url): report whether `url` can be resolved. URLs under
# "http://arvados.org/cwl" are special-cased first. The fragment is then
# stripped; "keep:" URLs are checked through self.fsaccess.exists and
# "arvwf:" URLs by attempting fetch_text(). arvados.errors.NotFoundError is
# handled explicitly, another failure path is logged with its traceback, and
# remaining URLs go to DefaultFetcher.check_exists.
241 def check_exists(self, url):
243 if url.startswith("http://arvados.org/cwl"):
245 urld, _ = urllib.parse.urldefrag(url)
246 if urld.startswith("keep:"):
247 return self.fsaccess.exists(urld)
248 if urld.startswith("arvwf:"):
249 if self.fetch_text(urld):
251 except arvados.errors.NotFoundError:
254 logger.exception("Got unexpected exception checking if file exists")
256 return super(CollectionFetcher, self).check_exists(url)
# urljoin(base_url, url): resolve `url` against `base_url`, with dedicated
# handling when the base uses the "keep" or "arvwf" scheme; other bases are
# delegated to DefaultFetcher.urljoin.
# Raises IOError(errno.EINVAL) for an invalid Keep locator in base_url.
258 def urljoin(self, base_url, url):
262 urlsp = urllib.parse.urlsplit(url)
# A `url` with its own scheme, or an empty base_url, is special-cased;
# presumably `url` is returned unchanged -- confirm.
263 if urlsp.scheme or not base_url:
266 basesp = urllib.parse.urlsplit(base_url)
267 if basesp.scheme in ("keep", "arvwf"):
269 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
271 baseparts = basesp.path.split("/")
272 urlparts = urlsp.path.split("/") if urlsp.path else []
# The first path component of the base is the locator; it is removed from
# baseparts and re-attached when the joined path is built below.
274 locator = baseparts.pop(0)
# For "keep" bases the locator must match either the keep locator pattern
# or the collection UUID pattern.
276 if (basesp.scheme == "keep" and
277 (not arvados.util.keep_locator_pattern.match(locator)) and
278 (not arvados.util.collection_uuid_pattern.match(locator))):
279 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
# Two adjustments are made before joining: one when the url path starts
# with "/", one when both baseparts and the url path are non-empty.
281 if urlsp.path.startswith("/"):
285 if baseparts and urlsp.path:
# The result keeps the base scheme, has an empty netloc and query, and
# takes its fragment from `url`.
288 path = "/".join([locator] + baseparts + urlparts)
289 return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))
291 return super(CollectionFetcher, self).urljoin(base_url, url)
# URL schemes listed for this fetcher, including the two Arvados-specific
# ones ("keep" and "arvwf").
293 schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]
# supported_schemes(): presumably returns the `schemes` list above --
# confirm.
295 def supported_schemes(self): # type: () -> List[Text]
# UUID patterns used by collectionResolver(): five lowercase alphanumerics, a
# fixed infix ("7fd4e" for workflows, "p5p6p" for pipeline templates), then
# fifteen lowercase alphanumerics. Neither pattern is anchored at the end, and
# collectionResolver() applies them with match(), so only the start of the
# string is checked.
299 workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
300 pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')
302 def collectionResolver(api_client, document_loader, uri, num_retries=4):
# Map a user-supplied workflow reference to a URI the loader can fetch.
#
# Cases, in the order they are tested:
#   - uri already starts with "keep:" or "arvwf:" (special-cased first)
#   - workflow UUID          -> "arvwf:<uuid>#main"
#   - pipeline template UUID -> "keep:" + the "cwl:tool" script parameter of
#                               a component of the template
#   - keep locator           -> "keep:<uri>"
#   - collection UUID        -> "keep:" + the collection's
#                               portable_data_hash (+ a suffix)
#   - anything else          -> cwltool.resolver.tool_resolver
303 if uri.startswith("keep:") or uri.startswith("arvwf:"):
306 if workflow_uuid_pattern.match(uri):
307 return u"arvwf:%s#main" % (uri)
309 if pipeline_template_uuid_pattern.match(uri):
310 pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
# NOTE(review): if pt["components"] is a plain dict, next() on .values()
# raises TypeError under Python 3, because a values view is iterable but is
# not an iterator; next(iter(...)) would be needed -- confirm.
311 return u"keep:" + next(pt["components"].values())["script_parameters"]["cwl:tool"]
# p[0] is tested against the locator patterns; presumably p is uri split on
# "/" -- confirm.
314 if arvados.util.keep_locator_pattern.match(p[0]):
315 return u"keep:%s" % (uri)
# NOTE(review): this execute() call does not pass num_retries, unlike the
# pipeline_templates lookup above -- confirm whether that is intended.
317 if arvados.util.collection_uuid_pattern.match(p[0]):
318 return u"keep:%s%s" % (api_client.collections().
319 get(uuid=p[0]).execute()["portable_data_hash"],
322 return cwltool.resolver.tool_resolver(document_loader, uri)