1 # Copyright (C) The Arvados Authors. All rights reserved.
3 # SPDX-License-Identifier: Apache-2.0
5 from future import standard_library
6 standard_library.install_aliases()
7 from builtins import object
8 from builtins import str
9 from future.utils import viewvalues
18 from collections import OrderedDict
20 import ruamel.yaml as yaml
22 import cwltool.stdfsaccess
23 from cwltool.pathmapper import abspath
24 import cwltool.resolver
27 import arvados.collection
28 import arvados.arvfile
31 from googleapiclient.errors import HttpError
33 from schema_salad.ref_resolver import DefaultFetcher
35 logger = logging.getLogger('arvados.cwl-runner')
# Matches a Keep locator that starts with a 32-hex-digit hash followed by
# "+<digits>" and optional "+hint" suffixes. group(1) is the hash and
# group(2) the numeric size field; CollectionCache.get multiplies group(2)
# by 128 to decide how much room to make in the cache.
# NOTE(review): used with .match(), which anchors only at the start of the
# string, so trailing text after the hints is not rejected.
37 pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')
# Size-bounded cache of arvados.collection.CollectionReader objects keyed by
# locator. Entries live in an OrderedDict (oldest first) together with an
# estimated size, and cap_cache() evicts from the oldest end.
39 class CollectionCache(object):
40 def __init__(self, api_client, keep_client, num_retries,
43 self.api_client = api_client
44 self.keep_client = keep_client
45 self.num_retries = num_retries
# Maps locator -> (CollectionReader, estimated_size); insertion order is
# used as recency order (see get() and cap_cache()).
46 self.collections = OrderedDict()
# NOTE(review): presumably guards self.collections inside get(); the code
# that acquires it is not visible here — confirm.
47 self.lock = threading.Lock()
50 self.min_entries = min_entries
52 def set_cap(self, cap):
# Evict the oldest entries until `required` fits in the remaining budget
# (self.cap - self.total), stopping once fewer than self.min_entries
# entries remain.
# NOTE(review): assumes the branch after the check below stops the loop and
# that each eviction also decrements self.total; those lines are not shown.
55 def cap_cache(self, required):
56 # ordered dict iterates from oldest to newest
57 for pdh, v in list(self.collections.items()):
58 available = self.cap - self.total
59 if available >= required or len(self.collections) < self.min_entries:
62 logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", pdh, self.cap, self.total, required)
63 del self.collections[pdh]
# Look up the reader for `locator`, creating and caching it on a miss
# (callers use the result as the collection object).
# On a miss, room is first made via cap_cache() using the locator's size
# field times 128, and the new entry's size is estimated as
# len(manifest_text()) * 128. An ApiError raised by CollectionReader is
# re-raised as IOError(ENOENT).
# On a hit, the entry is deleted and re-inserted so that it becomes the
# newest.
# NOTE(review): the 128 multiplier looks like an overhead estimate for the
# in-memory reader relative to manifest/locator size — confirm.
66 def get(self, locator):
68 if locator not in self.collections:
69 m = pdh_size.match(locator)
# NOTE(review): m is None for locators that do not match pdh_size (e.g.
# collection UUIDs, which CollectionFsAccess.get_collection accepts);
# confirm that a None check guards the next line.
71 self.cap_cache(int(m.group(2)) * 128)
72 logger.debug("Creating collection reader for %s", locator)
74 cr = arvados.collection.CollectionReader(locator, api_client=self.api_client,
75 keep_client=self.keep_client,
76 num_retries=self.num_retries)
77 except arvados.errors.ApiError as ap:
78 raise IOError(errno.ENOENT, "Could not access collection '%s': %s" % (locator, str(ap._get_reason())))
79 sz = len(cr.manifest_text()) * 128
80 self.collections[locator] = (cr, sz)
83 cr, sz = self.collections[locator]
# Re-insert to move the entry to the newest end of the OrderedDict
# (OrderedDict.move_to_end(locator) would do this in one step).
85 del self.collections[locator]
86 self.collections[locator] = (cr, sz)
90 class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
91 """Implement the cwltool FsAccess interface for Arvados Collections."""
# collection_cache must provide get(locator) (used by get_collection).
# NOTE(review): it defaults to None, in which case resolving any keep: path
# would fail — confirm callers always pass a cache.
93 def __init__(self, basedir, collection_cache=None):
94 super(CollectionFsAccess, self).__init__(basedir)
95 self.collection_cache = collection_cache
# Split "keep:<locator>/<rest>" into (collection, rest). `rest` is
# URL-unquoted and normpath-normalized, or None when there is no "/".
# Only applies when the first path component is "keep:" followed by a Keep
# locator or a collection UUID. `p` and `locator` are assigned on lines not
# visible here, as is the fallback return for other paths.
97 def get_collection(self, path):
98 sp = path.split("/", 1)
100 if p.startswith("keep:") and (arvados.util.keep_locator_pattern.match(p[5:]) or
101 arvados.util.collection_uuid_pattern.match(p[5:])):
103 rest = os.path.normpath(urllib.parse.unquote(sp[1])) if len(sp) == 2 else None
104 return (self.collection_cache.get(locator), rest)
# Recursively expand the glob `patternsegments` (one pattern per path
# component) against `collection`, building matching paths under `parent`.
# Each entry name is matched with fnmatch; a "." segment recurses into the
# same collection with the remaining segments.
# Early exits for an exhausted pattern or a non-directory collection are on
# lines not visible here.
108 def _match(self, collection, patternsegments, parent):
109 if not patternsegments:
112 if not isinstance(collection, arvados.collection.RichCollectionBase):
116 # iterate over the files and subcollections in 'collection'
117 for filename in collection:
# NOTE(review): this "." check sits inside the per-entry loop, so the
# recursive result is added once per entry of `collection` (duplicate
# matches), and not at all for an empty collection. Checking
# patternsegments[0] == '.' once before the loop would avoid both problems.
118 if patternsegments[0] == '.':
119 # Pattern contains something like "./foo" so just shift
121 ret.extend(self._match(collection, patternsegments[1:], parent))
122 elif fnmatch.fnmatch(filename, patternsegments[0]):
123 cur = os.path.join(parent, filename)
124 if len(patternsegments) == 1:
127 ret.extend(self._match(collection[filename], patternsegments[1:], cur))
# Expand a keep: glob pattern into a sorted list of
# "keep:<manifest_locator>/..." paths.
# NOTE(review): if get_collection returns collection=None, the check below
# is skipped and collection.manifest_locator() would raise AttributeError —
# confirm that non-keep patterns never reach this method.
130 def glob(self, pattern):
131 collection, rest = self.get_collection(pattern)
132 if collection is not None and rest in (None, "", "."):
134 patternsegments = rest.split("/")
135 return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))
# keep: paths are opened from the collection and honor `encoding`; other
# paths go to StdFsAccess.open on the absolute path.
# NOTE(review): `encoding` is not forwarded to the superclass call, possibly
# because StdFsAccess.open does not accept it — confirm.
137 def open(self, fn, mode, encoding=None):
138 collection, rest = self.get_collection(fn)
139 if collection is not None:
140 return collection.open(rest, mode, encoding=encoding)
142 return super(CollectionFsAccess, self).open(self._abs(fn), mode)
# Return whether `fn` exists. HttpError 404 and IOError ENOENT raised while
# resolving a keep: collection are caught below; their handler bodies are
# not visible here (presumably they report "does not exist").
# For keep: paths the answer is collection.exists(rest); other paths go to
# StdFsAccess.exists.
144 def exists(self, fn):
146 collection, rest = self.get_collection(fn)
147 except HttpError as err:
148 if err.resp.status == 404:
152 except IOError as err:
153 if err.errno == errno.ENOENT:
157 if collection is not None:
159 return collection.exists(rest)
163 return super(CollectionFsAccess, self).exists(fn)
# Return the size of `fn`. For keep: paths it is arvfile.size() of the
# ArvadosFile found at `rest`; anything that is not a file raises
# IOError(EINVAL).
# NOTE(review): the type comment says "-> bool", but this returns a size
# value rather than a boolean; the annotation looks wrong.
165 def size(self, fn): # type: (unicode) -> bool
166 collection, rest = self.get_collection(fn)
167 if collection is not None:
169 arvfile = collection.find(rest)
170 if isinstance(arvfile, arvados.arvfile.ArvadosFile):
171 return arvfile.size()
172 raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
174 return super(CollectionFsAccess, self).size(fn)
# True when `rest` inside the collection resolves to an ArvadosFile;
# presumably False when find() returns None for a missing path.
176 def isfile(self, fn): # type: (unicode) -> bool
177 collection, rest = self.get_collection(fn)
178 if collection is not None:
180 return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
184 return super(CollectionFsAccess, self).isfile(fn)
# True when `rest` inside the collection resolves to a RichCollectionBase
# (a collection directory).
186 def isdir(self, fn): # type: (unicode) -> bool
187 collection, rest = self.get_collection(fn)
188 if collection is not None:
190 return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
194 return super(CollectionFsAccess, self).isdir(fn)
# List a collection directory, returning each entry as abspath(name, fn).
# Raises IOError(ENOENT) when the directory is not found (guard condition
# not visible here) or when the path is not a directory.
196 def listdir(self, fn): # type: (unicode) -> List[unicode]
197 collection, rest = self.get_collection(fn)
198 if collection is not None:
200 dir = collection.find(rest)
204 raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
205 if not isinstance(dir, arvados.collection.RichCollectionBase):
206 raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
207 return [abspath(l, fn) for l in list(dir.keys())]
209 return super(CollectionFsAccess, self).listdir(fn)
# A last component that is itself a "keep:<locator>" path is special-cased
# (branch body not visible here; presumably returned as an absolute path).
# Everything else uses os.path.join.
211 def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
212 if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
214 return os.path.join(path, *paths)
# Paths starting with the $(task.tmpdir) / $(task.outdir) placeholders, and
# paths that resolve to a collection, are handled by branches whose bodies
# are not visible here (presumably returned unchanged). All other paths go
# through os.path.realpath.
216 def realpath(self, path):
217 if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
219 collection, rest = self.get_collection(path)
220 if collection is not None:
223 return os.path.realpath(path)
# schema-salad fetcher that supports keep: URLs (read through fs_access)
# and arvwf: URLs (workflow records fetched through api_client) on top of
# DefaultFetcher.
225 class CollectionFetcher(DefaultFetcher):
226 def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
227 super(CollectionFetcher, self).__init__(cache, session)
228 self.api_client = api_client
229 self.fsaccess = fs_access
230 self.num_retries = num_retries
# Return the text at `url`.
# - keep: URLs are opened through self.fsaccess as UTF-8.
# - arvwf:<uuid> URLs load the workflow record, parse its YAML "definition",
#   set "label" to the record's "name", and re-serialize with round_trip_dump.
# - Anything else goes to DefaultFetcher.fetch_text.
# NOTE(review): content_types is not forwarded to the superclass call.
232 def fetch_text(self, url, content_types=None):
233 if url.startswith("keep:"):
234 with self.fsaccess.open(url, "r", encoding="utf-8") as f:
236 if url.startswith("arvwf:"):
237 record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
238 definition = yaml.round_trip_load(record["definition"])
239 definition["label"] = record["name"]
240 return yaml.round_trip_dump(definition)
241 return super(CollectionFetcher, self).fetch_text(url)
# Report whether `url` exists.
# - http://arvados.org/cwl URLs are special-cased (body not visible here).
# - keep: URLs use self.fsaccess.exists.
# - arvwf: URLs are checked by whether fetch_text returns non-empty text.
# NotFoundError is caught, and other exceptions are logged with
# logger.exception; what those handlers return is not visible here.
243 def check_exists(self, url):
245 if url.startswith("http://arvados.org/cwl"):
247 if url.startswith("keep:"):
248 return self.fsaccess.exists(url)
249 if url.startswith("arvwf:"):
250 if self.fetch_text(url):
252 except arvados.errors.NotFoundError:
255 logger.exception("Got unexpected exception checking if file exists")
257 return super(CollectionFetcher, self).check_exists(url)
# Resolve `url` against a keep: or arvwf: base URL.
# The base path is split into its leading locator plus remaining segments.
# The url's path segments are appended, and the url's fragment is kept.
# A url with its own scheme, or an empty base, is handled early (body not
# visible here).
# keep: bases whose first segment is neither a Keep locator nor a collection
# UUID raise IOError(EINVAL); another IOError(EINVAL) is raised under a
# condition on a line not visible here.
# Leading-"/" url paths and the baseparts adjustment are handled on lines
# not visible here. Other schemes go to DefaultFetcher.urljoin.
259 def urljoin(self, base_url, url):
263 urlsp = urllib.parse.urlsplit(url)
264 if urlsp.scheme or not base_url:
267 basesp = urllib.parse.urlsplit(base_url)
268 if basesp.scheme in ("keep", "arvwf"):
270 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
272 baseparts = basesp.path.split("/")
273 urlparts = urlsp.path.split("/") if urlsp.path else []
275 locator = baseparts.pop(0)
277 if (basesp.scheme == "keep" and
278 (not arvados.util.keep_locator_pattern.match(locator)) and
279 (not arvados.util.collection_uuid_pattern.match(locator))):
280 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
282 if urlsp.path.startswith("/"):
286 if baseparts and urlsp.path:
289 path = "/".join([locator] + baseparts + urlparts)
290 return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))
292 return super(CollectionFetcher, self).urljoin(base_url, url)
# URL schemes this fetcher handles, including the Arvados-specific keep and
# arvwf schemes.
294 schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]
# Body not visible here; presumably returns self.schemes — confirm.
296 def supported_schemes(self): # type: () -> List[Text]
# UUID patterns: a 5-character [a-z0-9] prefix, a fixed 5-character type
# infix, and a 15-character [a-z0-9] identifier. collectionResolver uses
# them to recognise bare workflow and pipeline-template UUIDs.
# NOTE(review): used with .match(), which anchors only at the start, so any
# suffix after the UUID (e.g. "#main" or "/path") still matches — confirm
# that is intended.
300 workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
301 pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')
# Resolve a tool reference `uri` to a URI that CollectionFetcher can load:
#  - keep: and arvwf: URIs are handled first (branch body not visible here);
#  - a workflow UUID becomes "arvwf:<uri>#main";
#  - a pipeline template UUID is fetched (with num_retries) and mapped to
#    "keep:" + the "cwl:tool" script parameter of a component;
#  - if the first segment p[0] is a Keep locator, the result is "keep:<uri>";
#  - if p[0] is a collection UUID, it is replaced by the collection's
#    portable_data_hash, followed by a second value on a line not visible
#    here (presumably the rest of the path);
#  - anything else falls back to cwltool.resolver.tool_resolver.
# `p` is assigned on a line not visible here; presumably uri split on "/".
303 def collectionResolver(api_client, document_loader, uri, num_retries=4):
304 if uri.startswith("keep:") or uri.startswith("arvwf:"):
307 if workflow_uuid_pattern.match(uri):
308 return u"arvwf:%s#main" % (uri)
310 if pipeline_template_uuid_pattern.match(uri):
311 pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
# NOTE(review): future.utils.viewvalues() returns a dict values view (on
# both Python 2 and 3). Assuming pt["components"] is a plain dict from the
# API response, that view does not support indexing, so "[0]" raises
# TypeError. next(iter(viewvalues(...))) would be the working form.
312 return u"keep:" + viewvalues(pt["components"])[0]["script_parameters"]["cwl:tool"]
315 if arvados.util.keep_locator_pattern.match(p[0]):
316 return u"keep:%s" % (uri)
318 if arvados.util.collection_uuid_pattern.match(p[0]):
# NOTE(review): unlike the pipeline_templates lookup above, this .execute()
# call does not pass num_retries.
319 return u"keep:%s%s" % (api_client.collections().
320 get(uuid=p[0]).execute()["portable_data_hash"],
323 return cwltool.resolver.tool_resolver(document_loader, uri)