1 # Copyright (C) The Arvados Authors. All rights reserved.
3 # SPDX-License-Identifier: Apache-2.0
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import object
8 from builtins import str
9 from future.utils import viewvalues
18 from collections import OrderedDict
20 import ruamel.yaml as yaml
22 import cwltool.stdfsaccess
23 from cwltool.pathmapper import abspath
24 import cwltool.resolver
27 import arvados.collection
28 import arvados.arvfile
31 from googleapiclient.errors import HttpError
33 from schema_salad.ref_resolver import DefaultFetcher
# Module-level logger, registered under the name 'arvados.cwl-runner'.
35 logger = logging.getLogger('arvados.cwl-runner')
# Matches, from the start of a string, a locator of the form
# <32 lowercase hex digits>+<decimal number>, optionally followed by any
# number of '+'-prefixed, whitespace-free suffixes.  Group 1 is the hex
# digest and group 2 is the decimal number (read by CollectionCache.get).
37 pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')
# Cache of arvados.collection.CollectionReader objects keyed by locator.
# Entries are stored in an OrderedDict as (reader, size_estimate) tuples.
# cap_cache() walks the dict from its oldest entry and deletes entries;
# get() deletes and re-inserts an entry it already holds, which moves that
# entry to the newest end of the OrderedDict.
# NOTE(review): this excerpt is missing some original lines (the embedded
# numbering has gaps), so the assignments of self.cap / self.total, the
# loop exit in cap_cache() and any use of self.lock are not visible here —
# confirm against the full file.
39 class CollectionCache(object):
# Keeps the API client, Keep client and retry count that get() later hands
# to CollectionReader, and creates the empty OrderedDict and a
# threading.Lock.  The rest of the parameter list (which has to supply
# min_entries) is not visible in this excerpt.
40 def __init__(self, api_client, keep_client, num_retries,
43 self.api_client = api_client
44 self.keep_client = keep_client
45 self.num_retries = num_retries
46 self.collections = OrderedDict()
47 self.lock = threading.Lock()
50 self.min_entries = min_entries
# Body not visible in this excerpt; presumably updates self.cap — confirm.
52 def set_cap(self, cap):
# Evict cached readers, oldest first, to make room for `required` units.
# The visible condition stops caring about an entry once
# (self.cap - self.total) >= required, or once fewer than
# self.min_entries entries remain; what happens inside that branch is not
# visible here (presumably the loop exits — confirm).
55 def cap_cache(self, required):
56 # ordered dict iterates from oldest to newest
# list(...) snapshots the items so entries can be deleted while looping.
57 for pdh, v in list(self.collections.items()):
58 available = self.cap - self.total
59 if available >= required or len(self.collections) < self.min_entries:
62 logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", pdh, self.cap, self.total, required)
63 del self.collections[pdh]
# Return data for `locator` from the cache, creating a CollectionReader
# when the locator is not cached yet.  (The return statement itself is not
# visible in this excerpt; callers use the result as a collection object.)
# Raises IOError(errno.ENOENT) when CollectionReader raises
# arvados.errors.ApiError.
66 def get(self, locator):
68 if locator not in self.collections:
# For a locator matching pdh_size, room is requested up front using the
# number embedded in the locator multiplied by 128.  The guard for a
# non-matching locator (m is None) is not visible here — confirm.
69 m = pdh_size.match(locator)
71 self.cap_cache(int(m.group(2)) * 128)
72 logger.debug("Creating collection reader for %s", locator)
74 cr = arvados.collection.CollectionReader(locator, api_client=self.api_client,
75 keep_client=self.keep_client,
76 num_retries=self.num_retries)
77 except arvados.errors.ApiError as ap:
78 raise IOError(errno.ENOENT, "Could not access collection '%s': %s" % (locator, str(ap._get_reason())))
# Size estimate stored with the reader: manifest text length times 128.
# The rationale for the factor 128 is not stated in this excerpt.
79 sz = len(cr.manifest_text()) * 128
80 self.collections[locator] = (cr, sz)
# Existing entry: delete and re-insert so it becomes the newest entry.
83 cr, sz = self.collections[locator]
85 del self.collections[locator]
86 self.collections[locator] = (cr, sz)
# NOTE(review): this excerpt is missing some original lines (the embedded
# numbering has gaps); comments below describe only the visible statements
# and mark anything inferred as an assumption.
90 class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
91 """Implement the cwltool FsAccess interface for Arvados Collections."""
# basedir is forwarded to StdFsAccess; collection_cache is the object whose
# .get(locator) is called by get_collection() (it must not be None by the
# time a "keep:" path is resolved).
93 def __init__(self, basedir, collection_cache=None):
94 super(CollectionFsAccess, self).__init__(basedir)
95 self.collection_cache = collection_cache
# Split `path` at its first "/" and, when the leading part is "keep:"
# followed by something matching arvados.util.keep_locator_pattern or
# collection_uuid_pattern, return (collection, rest): the collection comes
# from self.collection_cache.get(locator); `rest` is the URL-unquoted,
# normpath'ed remainder, or None when there is no "/" in the path.
# The assignments of `p` and `locator`, and the return value for paths
# that are not Keep references, are not visible here; every caller tests
# the first element against None, so presumably (None, ...) is returned.
97 def get_collection(self, path):
98 sp = path.split("/", 1)
100 if p.startswith("keep:") and (arvados.util.keep_locator_pattern.match(p[5:]) or
101 arvados.util.collection_uuid_pattern.match(p[5:])):
103 rest = os.path.normpath(urllib.parse.unquote(sp[1])) if len(sp) == 2 else None
104 return (self.collection_cache.get(locator), rest)
# Recursive helper for glob(): match the list `patternsegments` (one
# fnmatch pattern per path component) against the entries of `collection`,
# extending `ret` with results.  `parent` is the path prefix joined onto
# each matching entry name.  Initialisation/return of `ret` and the
# results of the two early checks are not visible in this excerpt.
108 def _match(self, collection, patternsegments, parent):
109 if not patternsegments:
112 if not isinstance(collection, arvados.collection.RichCollectionBase):
116 # iterate over the files and subcollections in 'collection'
117 for filename in collection:
118 if patternsegments[0] == '.':
119 # Pattern contains something like "./foo" so just shift
# Same collection, same parent: only the leading "." segment is dropped.
121 ret.extend(self._match(collection, patternsegments[1:], parent))
122 elif fnmatch.fnmatch(filename, patternsegments[0]):
123 cur = os.path.join(parent, filename)
124 if len(patternsegments) == 1:
# More segments remain: descend into the matched entry.
127 ret.extend(self._match(collection[filename], patternsegments[1:], cur))
# Expand a glob `pattern`.  The path part after the locator is split on
# "/" and matched via _match(); results are sorted and rooted at
# "keep:" + collection.manifest_locator().  The result for a pattern that
# names only the collection itself (rest is None, "" or ".") is not
# visible in this excerpt.
130 def glob(self, pattern):
131 collection, rest = self.get_collection(pattern)
132 if collection is not None and rest in (None, "", "."):
134 patternsegments = rest.split("/")
135 return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))
# Open `fn`: inside a collection when it resolves to one, otherwise via
# the base class on self._abs(fn).
# NOTE(review): `encoding` is forwarded only on the collection path; the
# base-class fallback is called with `mode` alone — confirm intended.
137 def open(self, fn, mode, encoding=None):
138 collection, rest = self.get_collection(fn)
139 if collection is not None:
140 return collection.open(rest, mode, encoding=encoding)
142 return super(CollectionFsAccess, self).open(self._abs(fn), mode)
# Report whether `fn` exists.  HttpError with status 404 and IOError with
# errno ENOENT raised while resolving the collection are handled
# specially; the handler bodies are not visible here (presumably they
# return False and re-raise other errors — confirm).
144 def exists(self, fn):
146 collection, rest = self.get_collection(fn)
147 except HttpError as err:
148 if err.resp.status == 404:
152 except IOError as err:
153 if err.errno == errno.ENOENT:
157 if collection is not None:
159 return collection.exists(rest)
163 return super(CollectionFsAccess, self).exists(fn)
# Return arvfile.size() for a file inside a collection, or defer to the
# base class for other paths.  Raises IOError(errno.EINVAL) when the path
# inside the collection is not an ArvadosFile.
# NOTE(review): the type comment says "-> bool" although the collection
# branch returns arvfile.size() — looks copied from isfile; confirm.
165 def size(self, fn): # type: (unicode) -> bool
166 collection, rest = self.get_collection(fn)
167 if collection is not None:
169 arvfile = collection.find(rest)
170 if isinstance(arvfile, arvados.arvfile.ArvadosFile):
171 return arvfile.size()
172 raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
174 return super(CollectionFsAccess, self).size(fn)
# True when `fn` resolves to an ArvadosFile inside a collection; other
# paths are delegated to the base class.
176 def isfile(self, fn): # type: (unicode) -> bool
177 collection, rest = self.get_collection(fn)
178 if collection is not None:
180 return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
184 return super(CollectionFsAccess, self).isfile(fn)
# True when `fn` resolves to a RichCollectionBase (a collection or
# subcollection); other paths are delegated to the base class.
186 def isdir(self, fn): # type: (unicode) -> bool
187 collection, rest = self.get_collection(fn)
188 if collection is not None:
190 return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
194 return super(CollectionFsAccess, self).isdir(fn)
# List the entries of directory `fn`, each passed through abspath(l, fn).
# Raises IOError(errno.ENOENT) both when the path is not found and when it
# is not a RichCollectionBase.  The condition guarding the "not found"
# raise is not visible in this excerpt.
# NOTE(review): the "is not a Directory" case also uses ENOENT — confirm
# that callers do not expect ENOTDIR.
196 def listdir(self, fn): # type: (unicode) -> List[unicode]
197 collection, rest = self.get_collection(fn)
198 if collection is not None:
200 dir = collection.find(rest)
204 raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
205 if not isinstance(dir, arvados.collection.RichCollectionBase):
206 raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
207 return [abspath(l, fn) for l in list(dir.keys())]
209 return super(CollectionFsAccess, self).listdir(fn)
# Join path components with os.path.join, except when the last component
# is a "keep:" reference matching keep_locator_pattern; the result of that
# branch is not visible here (presumably the Keep reference itself).
211 def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
212 if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
214 return os.path.join(path, *paths)
# Resolve `path` with os.path.realpath unless it starts with the
# "$(task.tmpdir)" / "$(task.outdir)" placeholders or resolves to a
# collection; the results of those two branches are not visible here
# (presumably the path is returned unchanged — confirm).
216 def realpath(self, path):
217 if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
219 collection, rest = self.get_collection(path)
220 if collection is not None:
223 return os.path.realpath(path)
# schema_salad fetcher that adds handling for "keep:" and "arvwf:" URLs on
# top of DefaultFetcher.
# NOTE(review): this excerpt is missing some original lines (the embedded
# numbering has gaps); comments below describe only the visible statements
# and mark anything inferred as an assumption.
225 class CollectionFetcher(DefaultFetcher):
# cache and session go to DefaultFetcher; api_client is used for "arvwf:"
# lookups, fs_access for "keep:" reads/existence checks, and num_retries
# is passed to the workflows().get(...).execute() call.
226 def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
227 super(CollectionFetcher, self).__init__(cache, session)
228 self.api_client = api_client
229 self.fsaccess = fs_access
230 self.num_retries = num_retries
# Return the text behind `url`.
#   "keep:"  -> opened through self.fsaccess in text mode as UTF-8 (the
#               read/return inside the with-block is not visible here).
#   "arvwf:" -> the workflow record is fetched by UUID (url[6:]), its
#               "definition" is YAML round-trip loaded, "label" is set to
#               the record's "name", and the result is dumped back to text.
#   other    -> delegated to DefaultFetcher.fetch_text(url).
# NOTE(review): content_types is accepted but not forwarded to the base
# class — confirm intended.
232 def fetch_text(self, url, content_types=None):
233 if url.startswith("keep:"):
234 with self.fsaccess.open(url, "r", encoding="utf-8") as f:
236 if url.startswith("arvwf:"):
237 record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
238 definition = yaml.round_trip_load(record["definition"])
239 definition["label"] = record["name"]
240 return yaml.round_trip_dump(definition)
241 return super(CollectionFetcher, self).fetch_text(url)
# Report whether `url` exists.  The fragment is stripped with urldefrag
# before the scheme checks: "keep:" asks self.fsaccess.exists(), "arvwf:"
# tries fetch_text().  arvados.errors.NotFoundError is caught and other
# unexpected errors are logged via logger.exception; the values returned
# from those branches, and from the "http://arvados.org/cwl" prefix check,
# are not visible in this excerpt.  Anything else goes to the base class.
243 def check_exists(self, url):
245 if url.startswith("http://arvados.org/cwl"):
247 urld, _ = urllib.parse.urldefrag(url)
248 if urld.startswith("keep:"):
249 return self.fsaccess.exists(urld)
250 if urld.startswith("arvwf:"):
251 if self.fetch_text(urld):
253 except arvados.errors.NotFoundError:
256 logger.exception("Got unexpected exception checking if file exists")
258 return super(CollectionFetcher, self).check_exists(url)
# Join `url` onto `base_url`.  For a base with scheme "keep" or "arvwf"
# the first path component of the base is treated as the locator and is
# always kept; the joined result is
#   <scheme>:<locator>/<base parts>/<url parts>#<url fragment>.
# Raises IOError(errno.EINVAL, "Invalid Keep locator", base_url) when a
# "keep" base does not start with something matching keep_locator_pattern
# or collection_uuid_pattern (the condition guarding the first such raise
# is not visible here).  Other bases go to DefaultFetcher.urljoin.
# The bodies of several branches (url already has a scheme or base is
# empty, url path starting with "/", trimming of baseparts) are not
# visible in this excerpt.
260 def urljoin(self, base_url, url):
264 urlsp = urllib.parse.urlsplit(url)
265 if urlsp.scheme or not base_url:
268 basesp = urllib.parse.urlsplit(base_url)
269 if basesp.scheme in ("keep", "arvwf"):
271 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
273 baseparts = basesp.path.split("/")
274 urlparts = urlsp.path.split("/") if urlsp.path else []
# First component of the base path is the locator.
276 locator = baseparts.pop(0)
278 if (basesp.scheme == "keep" and
279 (not arvados.util.keep_locator_pattern.match(locator)) and
280 (not arvados.util.collection_uuid_pattern.match(locator))):
281 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
283 if urlsp.path.startswith("/"):
287 if baseparts and urlsp.path:
290 path = "/".join([locator] + baseparts + urlparts)
291 return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))
293 return super(CollectionFetcher, self).urljoin(base_url, url)
# URL schemes listed for this fetcher, including the Arvados-specific
# "keep" and "arvwf".
295 schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]
# Body not visible in this excerpt; presumably returns `schemes` — confirm.
297 def supported_schemes(self): # type: () -> List[Text]
# UUID shapes: five [a-z0-9] characters, a fixed infix ("7fd4e" for
# workflows, "p5p6p" for pipeline templates), then fifteen [a-z0-9]
# characters.  Neither pattern is anchored at the end, and
# collectionResolver uses .match(), so only the start of the string is
# checked.
301 workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
302 pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')
# Resolve a user-supplied `uri` to a URL the loader can fetch.
#   - a workflow UUID becomes "arvwf:<uuid>#main";
#   - a pipeline template UUID is looked up through the API and mapped to
#     "keep:" + the "cwl:tool" script parameter of a component;
#   - a Keep locator (first element of `p`) becomes "keep:<uri>";
#   - a collection UUID is looked up and replaced by its
#     portable_data_hash, prefixed with "keep:";
#   - anything else is delegated to cwltool.resolver.tool_resolver.
# NOTE(review): this excerpt is missing some original lines, so the result
# for URIs already starting with "keep:"/"arvwf:", the definition of `p`,
# and the tail of the final format expression are not visible here.
304 def collectionResolver(api_client, document_loader, uri, num_retries=4):
305 if uri.startswith("keep:") or uri.startswith("arvwf:"):
308 if workflow_uuid_pattern.match(uri):
309 return u"arvwf:%s#main" % (uri)
311 if pipeline_template_uuid_pattern.match(uri):
312 pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
# NOTE(review): future.utils.viewvalues returns a dict view; if
# pt["components"] is a plain dict, indexing that view with [0] raises
# TypeError.  Likely needs list(...)[0] or next(iter(...)) — confirm.
313 return u"keep:" + viewvalues(pt["components"])[0]["script_parameters"]["cwl:tool"]
316 if arvados.util.keep_locator_pattern.match(p[0]):
317 return u"keep:%s" % (uri)
# NOTE(review): unlike the pipeline template lookup above, this execute()
# call does not pass num_retries — confirm intended.
319 if arvados.util.collection_uuid_pattern.match(p[0]):
320 return u"keep:%s%s" % (api_client.collections().
321 get(uuid=p[0]).execute()["portable_data_hash"],
324 return cwltool.resolver.tool_resolver(document_loader, uri)