1 # Copyright (C) The Arvados Authors. All rights reserved.
3 # SPDX-License-Identifier: Apache-2.0
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import object
8 from builtins import str
9 from future.utils import viewvalues
18 from collections import OrderedDict
19 from io import StringIO
23 import cwltool.stdfsaccess
24 from cwltool.pathmapper import abspath
25 import cwltool.resolver
28 import arvados.collection
29 import arvados.arvfile
32 from googleapiclient.errors import HttpError
34 from schema_salad.ref_resolver import DefaultFetcher
logger = logging.getLogger('arvados.cwl-runner')

# Locator shape: 32 hex digits, "+", a decimal number, then optional
# "+..." suffixes.  Group 2 is the decimal number.
pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')


class CollectionCache(object):
    """Size-capped cache of CollectionReader objects keyed by locator.

    Entries live in an OrderedDict that is kept in least-recently-used
    order: a hit re-inserts the entry at the back, and eviction walks from
    the front.  Each entry is stored as ``(reader, cost)`` where ``cost``
    is 128 times the length of the reader's manifest text; ``total`` is
    the sum of the costs currently cached.
    """

    def __init__(self, api_client, keep_client, num_retries,
                 cap=256*1024*1024,
                 min_entries=2):
        """Create an empty cache.

        api_client, keep_client and num_retries are passed through to
        every CollectionReader the cache creates.  ``cap`` bounds
        ``total``; ``min_entries`` is the entry count below which
        cap_cache() stops evicting.
        """
        # NOTE(review): the default values of `cap` and `min_entries` are
        # not visible in the reviewed listing -- confirm against upstream.
        self.collections = OrderedDict()
        self.lock = threading.Lock()
        self.total = 0
        self.cap = cap
        self.min_entries = min_entries
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries

    def set_cap(self, cap):
        """Replace the cost cap; nothing is evicted by this call."""
        self.cap = cap

    def cap_cache(self, required):
        """Evict entries, oldest first, until ``required`` fits under the cap.

        Returns as soon as ``cap - total`` is at least ``required`` or
        fewer than ``min_entries`` readers remain.
        """
        # Walk a snapshot because entries are deleted along the way; the
        # ordered dict yields them from oldest to newest.
        for key, entry in list(self.collections.items()):
            enough_room = (self.cap - self.total) >= required
            if enough_room or len(self.collections) < self.min_entries:
                return
            # cut it loose
            logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", key, self.cap, self.total, required)
            del self.collections[key]
            self.total -= entry[1]

    def get(self, locator):
        """Return the CollectionReader for ``locator``, creating it on a miss.

        Runs entirely under ``self.lock``.  Raises IOError(ENOENT) when
        creating the reader raises arvados.errors.ApiError.
        """
        with self.lock:
            if locator in self.collections:
                # Hit: bump it to the back so it is evicted last.
                reader, cost = self.collections.pop(locator)
                self.collections[locator] = (reader, cost)
                return reader

            # Miss: make room first, using the number embedded in the
            # locator (when it has one) as the estimate.
            found = pdh_size.match(locator)
            if found:
                self.cap_cache(int(found.group(2)) * 128)
            logger.debug("Creating collection reader for %s", locator)
            try:
                reader = arvados.collection.CollectionReader(
                    locator,
                    api_client=self.api_client,
                    keep_client=self.keep_client,
                    num_retries=self.num_retries)
            except arvados.errors.ApiError as ap:
                raise IOError(errno.ENOENT, "Could not access collection '%s': %s" % (locator, str(ap._get_reason())))
            cost = len(reader.manifest_text()) * 128
            self.collections[locator] = (reader, cost)
            self.total += cost
            return reader
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections."""

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split ``path`` into ``(collection, rest)``.

        When the first path segment is "keep:" followed by something that
        matches keep_locator_pattern or collection_uuid_pattern, the
        collection is fetched from the collection cache and ``rest`` is the
        unquoted, normalized remainder (None when there is no remainder).
        Otherwise returns ``(None, path)``.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and (arvados.util.keep_locator_pattern.match(p[5:]) or
                                      arvados.util.collection_uuid_pattern.match(p[5:])):
            locator = p[5:]
            rest = os.path.normpath(urllib.parse.unquote(sp[1])) if len(sp) == 2 else None
            return (self.collection_cache.get(locator), rest)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Return the paths under ``collection`` matching ``patternsegments``.

        ``patternsegments`` is a glob pattern already split on "/";
        ``parent`` is the path prefix prepended to every match.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo" so just shift past
            # the "./".  Do it once, up front: shifting inside the loop
            # below would repeat the same sub-match for every entry and
            # return duplicates.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Return the sorted list of paths matching ``pattern``."""
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a "keep:" reference: defer to the base class, as every
            # other method here does, instead of dereferencing None.
            return super(CollectionFsAccess, self).glob(pattern)
        if rest in (None, "", "."):
            # The pattern names the collection root itself.
            return [pattern]
        patternsegments = rest.split("/")
        return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))

    def open(self, fn, mode, encoding=None):
        """Open ``fn`` from its collection, or through the base class.

        ``encoding`` is only passed on when opening inside a collection.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode, encoding=encoding)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return whether ``fn`` exists.

        An HTTP 404 or an IOError with errno ENOENT raised while resolving
        the collection counts as "does not exist"; other errors propagate.
        """
        try:
            collection, rest = self.get_collection(fn)
        except HttpError as err:
            if err.resp.status == 404:
                return False
            else:
                raise
        except IOError as err:
            if err.errno == errno.ENOENT:
                return False
            else:
                raise
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                # The collection root exists whenever the collection does.
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def size(self, fn):  # type: (unicode) -> int
        """Return the size of the file at ``fn``.

        Raises IOError(EINVAL) when ``fn`` is inside a collection but is
        not an ArvadosFile.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                arvfile = collection.find(rest)
                if isinstance(arvfile, arvados.arvfile.ArvadosFile):
                    return arvfile.size()
            raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
        else:
            return super(CollectionFsAccess, self).size(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return whether ``fn`` is a file; a collection root is not."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return whether ``fn`` is a directory; a collection root is."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """List the entries of directory ``fn`` as paths joined onto ``fn``.

        Raises IOError(ENOENT) when the path is missing from the collection
        or is not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                target = collection.find(rest)
            else:
                target = collection
            if target is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(target, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(name, fn) for name in target.keys()]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths):  # type: (unicode, *unicode) -> unicode
        """Join path components; a trailing "keep:" locator is returned as is."""
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return ``path`` resolved on the local filesystem.

        "$(task.tmpdir)" / "$(task.outdir)" placeholders and collection
        paths are returned unchanged.
        """
        if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
            return path
        collection, rest = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
class CollectionFetcher(DefaultFetcher):
    """Fetcher that adds "keep:" and "arvwf:" URLs on top of DefaultFetcher."""

    # URL schemes reported by supported_schemes().
    schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
        super(CollectionFetcher, self).__init__(cache, session)
        self.num_retries = num_retries
        self.fsaccess = fs_access
        self.api_client = api_client

    def fetch_text(self, url, content_types=None):
        """Return the text behind ``url``.

        "keep:" URLs are read through the fs_access object as UTF-8.
        "arvwf:" URLs load the workflow record via the API client and return
        its definition re-serialized as YAML, with "label" set to the
        record's name.  Anything else goes to the base class;
        ``content_types`` is accepted but not forwarded.
        """
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r", encoding="utf-8") as handle:
                return handle.read()
        if url.startswith("arvwf:"):
            workflow = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            round_trip = ruamel.yaml.YAML(typ='rt', pure=True)
            definition = round_trip.load(workflow["definition"])
            definition["label"] = workflow["name"]
            buf = StringIO()
            round_trip.dump(definition, buf)
            return buf.getvalue()
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return whether ``url`` can be resolved.

        Any exception raised while checking a URL is treated as "does not
        exist"; exceptions other than NotFoundError are logged as well.
        """
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            # Drop the fragment before looking anything up.
            urld = urllib.parse.urldefrag(url)[0]
            if urld.startswith("keep:"):
                return self.fsaccess.exists(urld)
            if urld.startswith("arvwf:") and self.fetch_text(urld):
                return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            logger.exception("Got unexpected exception checking if file exists")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` relative to ``base_url``.

        "keep:" and "arvwf:" bases are joined here, keeping the first path
        segment of the base fixed; every other base is handled by the base
        class.  Raises IOError(EINVAL) for a "keep:"/"arvwf:" base with an
        empty path, or a "keep:" base whose first segment matches neither
        locator pattern.
        """
        if not url:
            return base_url

        urlsp = urllib.parse.urlsplit(url)
        if urlsp.scheme or not base_url:
            # Already absolute, or nothing to resolve against.
            return url

        basesp = urllib.parse.urlsplit(base_url)
        if basesp.scheme not in ("keep", "arvwf"):
            return super(CollectionFetcher, self).urljoin(base_url, url)

        if not basesp.path:
            raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

        segments = basesp.path.split("/")
        locator = segments[0]
        baseparts = segments[1:]
        urlparts = urlsp.path.split("/") if urlsp.path else []

        if basesp.scheme == "keep" and not (
                arvados.util.keep_locator_pattern.match(locator) or
                arvados.util.collection_uuid_pattern.match(locator)):
            raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

        if urlsp.path.startswith("/"):
            # Absolute path: discard the base's directories and the empty
            # segment that precedes the leading slash.
            baseparts = []
            urlparts.pop(0)

        if baseparts and urlsp.path:
            # Relative path: replace the last component of the base.
            baseparts.pop()

        path = "/".join([locator] + baseparts + urlparts)
        return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

    def supported_schemes(self):  # type: () -> List[Text]
        """Return the list of URL schemes this fetcher reports."""
        return self.schemes
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')


def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Translate a user-supplied reference into a loadable URI.

    Tried in order:
      * "keep:" / "arvwf:" URIs are returned unchanged.
      * A string matching workflow_uuid_pattern becomes "arvwf:<uuid>#main".
      * A string matching pipeline_template_uuid_pattern is looked up via
        the API and resolved to "keep:" plus the "cwl:tool" script
        parameter of its first component.
      * A string whose first path segment matches keep_locator_pattern gets
        a "keep:" prefix.
      * A string whose first path segment matches collection_uuid_pattern
        is looked up via the API and rewritten to use the collection's
        portable data hash.
      * Anything else is passed to cwltool's default tool resolver.

    ``num_retries`` is passed to every API request made here.
    """
    if uri.startswith("keep:") or uri.startswith("arvwf:"):
        return str(uri)

    if workflow_uuid_pattern.match(uri):
        return u"arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # A dict view cannot be subscripted, so take the first component
        # through an iterator instead of indexing with [0].
        first_component = next(iter(pt["components"].values()))
        return u"keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return u"keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        record = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        # Keep whatever followed the UUID (the path inside the collection).
        return u"keep:%s%s" % (record["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)