1 # Copyright (C) The Arvados Authors. All rights reserved.
3 # SPDX-License-Identifier: Apache-2.0
12 from collections import OrderedDict
14 import ruamel.yaml as yaml
16 import cwltool.stdfsaccess
17 from cwltool.pathmapper import abspath
18 import cwltool.resolver
21 import arvados.collection
22 import arvados.arvfile
25 from googleapiclient.errors import HttpError
27 from schema_salad.ref_resolver import DefaultFetcher
# Module logger shared with the rest of the CWL runner.
logger = logging.getLogger('arvados.cwl-runner')

# Matches a portable data hash with its size field: 32 lowercase hex digits,
# '+', a decimal size, then any number of '+hint' suffixes.
# Group 1 is the hash, group 2 the size field (CollectionCache multiplies it by
# 128 to estimate how much cache room a new reader needs), group 3 the last
# '+hint' suffix if any.
# NOTE: the pattern has no '$' anchor; used with .match() it only requires the
# string to *start* with a matching locator.
pdh_size = re.compile(r'([0-9a-f]{32})\+(\d+)(\+\S+)*')
33 class CollectionCache(object):
34 def __init__(self, api_client, keep_client, num_retries,
37 self.api_client = api_client
38 self.keep_client = keep_client
39 self.num_retries = num_retries
40 self.collections = OrderedDict()
41 self.lock = threading.Lock()
44 self.min_entries = min_entries
46 def set_cap(self, cap):
49 def cap_cache(self, required):
50 # ordered dict iterates from oldest to newest
51 for pdh, v in self.collections.items():
52 available = self.cap - self.total
53 if available >= required or len(self.collections) < self.min_entries:
56 logger.debug("Evicting collection reader %s from cache (cap %s total %s required %s)", pdh, self.cap, self.total, required)
57 del self.collections[pdh]
62 if pdh not in self.collections:
63 m = pdh_size.match(pdh)
65 self.cap_cache(int(m.group(2)) * 128)
66 logger.debug("Creating collection reader for %s", pdh)
67 cr = arvados.collection.CollectionReader(pdh, api_client=self.api_client,
68 keep_client=self.keep_client,
69 num_retries=self.num_retries)
70 sz = len(cr.manifest_text()) * 128
71 self.collections[pdh] = (cr, sz)
74 cr, sz = self.collections[pdh]
76 del self.collections[pdh]
77 self.collections[pdh] = (cr, sz)
81 class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
82 """Implement the cwltool FsAccess interface for Arvados Collections."""
def __init__(self, basedir, collection_cache=None):
    """Create a filesystem accessor that also understands keep: paths.

    :param basedir: base directory, forwarded unchanged to the cwltool
        StdFsAccess parent initializer.
    :param collection_cache: object whose ``get(pdh)`` method is used by
        ``get_collection`` to obtain collection readers; defaults to None.
    """
    parent_init = super(CollectionFsAccess, self).__init__
    parent_init(basedir)
    # Kept for get_collection(), which looks collections up by PDH here.
    self.collection_cache = collection_cache
88 def get_collection(self, path):
89 sp = path.split("/", 1)
91 if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
93 return (self.collection_cache.get(pdh), urlparse.unquote(sp[1]) if len(sp) == 2 else None)
97 def _match(self, collection, patternsegments, parent):
98 if not patternsegments:
101 if not isinstance(collection, arvados.collection.RichCollectionBase):
105 # iterate over the files and subcollections in 'collection'
106 for filename in collection:
107 if patternsegments[0] == '.':
108 # Pattern contains something like "./foo" so just shift
110 ret.extend(self._match(collection, patternsegments[1:], parent))
111 elif fnmatch.fnmatch(filename, patternsegments[0]):
112 cur = os.path.join(parent, filename)
113 if len(patternsegments) == 1:
116 ret.extend(self._match(collection[filename], patternsegments[1:], cur))
119 def glob(self, pattern):
120 collection, rest = self.get_collection(pattern)
121 if collection is not None and not rest:
123 patternsegments = rest.split("/")
124 return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))
def open(self, fn, mode):
    """Open ``fn`` with ``mode`` and return the resulting file object.

    If ``get_collection`` resolves ``fn`` to a collection, the remainder of
    the path (``rest``) is opened inside that collection.  Otherwise the path
    is made absolute with ``self._abs`` and opened by the cwltool StdFsAccess
    parent implementation.
    """
    collection, rest = self.get_collection(fn)
    if collection is None:
        # No collection for this path: defer to the local-filesystem parent.
        return super(CollectionFsAccess, self).open(self._abs(fn), mode)
    return collection.open(rest, mode)
133 def exists(self, fn):
135 collection, rest = self.get_collection(fn)
136 except HttpError as err:
137 if err.resp.status == 404:
141 if collection is not None:
143 return collection.exists(rest)
147 return super(CollectionFsAccess, self).exists(fn)
149 def size(self, fn): # type: (unicode) -> bool
150 collection, rest = self.get_collection(fn)
151 if collection is not None:
153 arvfile = collection.find(rest)
154 if isinstance(arvfile, arvados.arvfile.ArvadosFile):
155 return arvfile.size()
156 raise IOError(errno.EINVAL, "Not a path to a file %s" % (fn))
158 return super(CollectionFsAccess, self).size(fn)
160 def isfile(self, fn): # type: (unicode) -> bool
161 collection, rest = self.get_collection(fn)
162 if collection is not None:
164 return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
168 return super(CollectionFsAccess, self).isfile(fn)
170 def isdir(self, fn): # type: (unicode) -> bool
171 collection, rest = self.get_collection(fn)
172 if collection is not None:
174 return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
178 return super(CollectionFsAccess, self).isdir(fn)
180 def listdir(self, fn): # type: (unicode) -> List[unicode]
181 collection, rest = self.get_collection(fn)
182 if collection is not None:
184 dir = collection.find(rest)
188 raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
189 if not isinstance(dir, arvados.collection.RichCollectionBase):
190 raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
191 return [abspath(l, fn) for l in dir.keys()]
193 return super(CollectionFsAccess, self).listdir(fn)
195 def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
196 if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
198 return os.path.join(path, *paths)
200 def realpath(self, path):
201 if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
203 collection, rest = self.get_collection(path)
204 if collection is not None:
207 return os.path.realpath(path)
209 class CollectionFetcher(DefaultFetcher):
def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
    """Create a schema-salad fetcher that can also read keep: and arvwf: URLs.

    :param cache: forwarded unchanged to DefaultFetcher.
    :param session: forwarded unchanged to DefaultFetcher.
    :param api_client: Arvados API client; fetch_text uses it to load
        ``arvwf:`` workflow records.
    :param fs_access: filesystem accessor used to open/test ``keep:`` URLs.
    :param num_retries: retry count passed to API ``execute()`` calls.
    """
    base = super(CollectionFetcher, self)
    base.__init__(cache, session)
    # Note the stored attribute is ``fsaccess`` (no underscore).
    self.api_client, self.fsaccess, self.num_retries = api_client, fs_access, num_retries
216 def fetch_text(self, url):
217 if url.startswith("keep:"):
218 with self.fsaccess.open(url, "r") as f:
220 if url.startswith("arvwf:"):
221 record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
222 definition = record["definition"] + ('\nlabel: "%s"\n' % record["name"].replace('"', '\\"'))
224 return super(CollectionFetcher, self).fetch_text(url)
226 def check_exists(self, url):
228 if url.startswith("http://arvados.org/cwl"):
230 if url.startswith("keep:"):
231 return self.fsaccess.exists(url)
232 if url.startswith("arvwf:"):
233 if self.fetch_text(url):
235 except arvados.errors.NotFoundError:
238 logger.exception("Got unexpected exception checking if file exists:")
240 return super(CollectionFetcher, self).check_exists(url)
242 def urljoin(self, base_url, url):
246 urlsp = urlparse.urlsplit(url)
247 if urlsp.scheme or not base_url:
250 basesp = urlparse.urlsplit(base_url)
251 if basesp.scheme in ("keep", "arvwf"):
253 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
255 baseparts = basesp.path.split("/")
256 urlparts = urlsp.path.split("/") if urlsp.path else []
258 pdh = baseparts.pop(0)
260 if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
261 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
263 if urlsp.path.startswith("/"):
267 if baseparts and urlsp.path:
270 path = "/".join([pdh] + baseparts + urlparts)
271 return urlparse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))
273 return super(CollectionFetcher, self).urljoin(base_url, url)
# URI schemes this fetcher handles.  "keep" and "arvwf" are the Arvados-specific
# additions that fetch_text, check_exists and urljoin special-case above; the
# others are left to the schema-salad DefaultFetcher behaviour.
schemes = [u"file", u"http", u"https", u"mailto", u"keep", u"arvwf"]
277 def supported_schemes(self): # type: () -> List[Text]
# Arvados UUID shapes recognised by collectionResolver: a 5-character cluster
# prefix, a type infix ("7fd4e" for workflow records, "p5p6p" for pipeline
# templates) and a 15-character identifier.
# NOTE(review): neither pattern is anchored with '$', and re.match() anchors
# only at the start, so any string that merely *begins* with such a UUID also
# matches.  Consider fullmatch()/'$' if callers can pass suffixed strings.
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')
284 def collectionResolver(api_client, document_loader, uri, num_retries=4):
285 if uri.startswith("keep:") or uri.startswith("arvwf:"):
288 if workflow_uuid_pattern.match(uri):
289 return "arvwf:%s#main" % (uri)
291 if pipeline_template_uuid_pattern.match(uri):
292 pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
293 return "keep:" + pt["components"].values()[0]["script_parameters"]["cwl:tool"]
296 if arvados.util.keep_locator_pattern.match(p[0]):
297 return "keep:%s" % (uri)
299 if arvados.util.collection_uuid_pattern.match(p[0]):
300 return "keep:%s%s" % (api_client.collections().
301 get(uuid=p[0]).execute()["portable_data_hash"],
304 return cwltool.resolver.tool_resolver(document_loader, uri)