1 # Copyright (C) The Arvados Authors. All rights reserved.
3 # SPDX-License-Identifier: Apache-2.0
12 from collections import OrderedDict
14 import ruamel.yaml as yaml
16 import cwltool.stdfsaccess
17 from cwltool.pathmapper import abspath
18 import cwltool.resolver
21 import arvados.collection
22 import arvados.arvfile
25 from googleapiclient.errors import HttpError
27 from schema_salad.ref_resolver import DefaultFetcher
# Module-level logger, registered under the 'arvados.cwl-runner' logger name.
29 logger = logging.getLogger('arvados.cwl-runner')
# CollectionCache: cache of arvados.collection.CollectionReader objects keyed
# by portable data hash (pdh). Entries live in an OrderedDict whose iteration
# order runs oldest -> newest; hits are moved to the newest end and eviction
# walks from the oldest end, i.e. least-recently-used entries go first.
# NOTE(review): this is a partial listing -- original line numbers are fused
# into each line and several lines (the rest of the __init__ signature, the
# method headers, the break/size bookkeeping) are not shown. The comments
# below describe only the visible lines.
31 class CollectionCache(object):
32 def __init__(self, api_client, keep_client, num_retries,
# API client, Keep client and retry count reused for every reader created.
35 self.api_client = api_client
36 self.keep_client = keep_client
37 self.num_retries = num_retries
# pdh -> (CollectionReader, size charged against the cap); insertion order
# doubles as recency order.
38 self.collections = OrderedDict()
# NOTE(review): presumably serializes access to self.collections; the code
# that acquires it is not visible here -- confirm.
39 self.lock = threading.Lock()
# min_entries comes from one of the elided signature lines; it bounds how far
# eviction below may shrink the cache.
42 self.min_entries = min_entries
# Eviction (enclosing method header elided): only runs once the tracked total
# size exceeds self.cap.
45 if self.total > self.cap:
46 # ordered list iterates from oldest to newest
# NOTE(review): entries are deleted from self.collections inside this loop.
# That is safe on Python 2, where items() returns a list copy, but raises
# RuntimeError on Python 3 -- iterate over list(self.collections.items()) if
# this is ever ported.
47 for pdh, v in self.collections.items():
# Stop condition; its body is elided, presumably a break. If so, because the
# length is checked before each deletion, the cache can end up holding
# min_entries - 1 entries -- confirm that is intended.
48 if self.total < self.cap or len(self.collections) < self.min_entries:
51 logger.debug("Evicting collection reader %s from cache", pdh)
52 del self.collections[pdh]
# Lookup (enclosing method header elided). On a miss, build a new reader that
# shares this cache's clients and retry count.
57 if pdh not in self.collections:
58 logger.debug("Creating collection reader for %s", pdh)
59 cr = arvados.collection.CollectionReader(pdh, api_client=self.api_client,
60 keep_client=self.keep_client,
61 num_retries=self.num_retries)
# Size charged for the entry: manifest text length * 128.
# NOTE(review): the factor is unexplained; presumably a rough estimate of the
# reader's in-memory footprint relative to its manifest -- confirm.
62 sz = len(cr.manifest_text()) * 128
63 self.collections[pdh] = (cr, sz)
# Cache-hit path (its branch header is elided): delete and re-insert so the
# entry moves to the newest end of the OrderedDict and is evicted last.
67 cr, sz = self.collections[pdh]
69 del self.collections[pdh]
70 self.collections[pdh] = (cr, sz)
# CollectionFsAccess: cwltool FsAccess implementation that serves
# "keep:<locator>/..." paths from Arvados collections (obtained through
# collection_cache) and defers every other path to
# cwltool.stdfsaccess.StdFsAccess.
# NOTE(review): partial listing -- original line numbers are fused into each
# line and several lines (returns, else branches, try headers) are not shown.
# The comments below describe only the visible lines.
74 class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
75 """Implement the cwltool FsAccess interface for Arvados Collections."""
77 def __init__(self, basedir, collection_cache=None):
78 super(CollectionFsAccess, self).__init__(basedir)
# Object whose get(pdh) returns a collection reader (presumably a
# CollectionCache). Despite the None default, get_collection() calls .get()
# on it for keep: paths without checking for None.
79 self.collection_cache = collection_cache
# Split `path` at its first "/". When the first component (p, bound on an
# elided line) is "keep:" followed by a valid Keep locator, return
# (collection reader from the cache, remainder or None if there is no "/").
# pdh is bound on an elided line (presumably p[5:]). The non-Keep return is
# elided; the callers below treat a None collection as "not a Keep path".
81 def get_collection(self, path):
82 sp = path.split("/", 1)
84 if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
86 return (self.collection_cache.get(pdh), sp[1] if len(sp) == 2 else None)
# Recursively collect paths under `parent` whose segments match
# `patternsegments` (one fnmatch pattern per path segment). The early returns
# for an exhausted pattern / a non-collection argument and the final return
# of `ret` are elided.
90 def _match(self, collection, patternsegments, parent):
91 if not patternsegments:
94 if not isinstance(collection, arvados.collection.RichCollectionBase):
98 # iterate over the files and subcollections in 'collection'
99 for filename in collection:
# NOTE(review): this "." check sits inside the per-entry loop, so the same
# recursive call runs once per entry of `collection`. That yields duplicate
# matches when the collection has several entries, and no matches when it is
# empty (unless handled on elided lines). Consider handling "." before the
# loop.
100 if patternsegments[0] == '.':
101 # Pattern contains something like "./foo" so just shift
103 ret.extend(self._match(collection, patternsegments[1:], parent))
104 elif fnmatch.fnmatch(filename, patternsegments[0]):
105 cur = os.path.join(parent, filename)
# Last segment: the elided body presumably records `cur` as a match.
# Otherwise descend into the matched entry with the remaining segments.
106 if len(patternsegments) == 1:
109 ret.extend(self._match(collection[filename], patternsegments[1:], cur))
# Glob a keep: pattern; matches come back sorted and rooted at
# "keep:" + collection.manifest_locator(). The return for a pattern with no
# remainder after the locator is elided.
# NOTE(review): if get_collection() returns a None collection for non-Keep
# patterns, the last line would call manifest_locator() on None -- confirm
# glob is only reached with keep: patterns.
112 def glob(self, pattern):
113 collection, rest = self.get_collection(pattern)
114 if collection is not None and not rest:
116 patternsegments = rest.split("/")
117 return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))
# Keep paths are opened through the collection. Other paths fall back to
# StdFsAccess.open on self._abs(fn); the separating else line is elided.
119 def open(self, fn, mode):
120 collection, rest = self.get_collection(fn)
121 if collection is not None:
122 return collection.open(rest, mode)
124 return super(CollectionFsAccess, self).open(self._abs(fn), mode)
# Keep paths: collection.exists(rest). Other paths: StdFsAccess.exists(fn).
# An HttpError raised while resolving the collection is caught (the try
# header is elided); a 404 status gets special handling, presumably reporting
# "does not exist" -- the handler bodies are elided.
126 def exists(self, fn):
128 collection, rest = self.get_collection(fn)
129 except HttpError as err:
130 if err.resp.status == 404:
134 if collection is not None:
136 return collection.exists(rest)
140 return super(CollectionFsAccess, self).exists(fn)
# True when `rest` resolves to an arvados.arvfile.ArvadosFile inside the
# collection; non-Keep paths use StdFsAccess.isfile.
142 def isfile(self, fn): # type: (unicode) -> bool
143 collection, rest = self.get_collection(fn)
144 if collection is not None:
146 return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
150 return super(CollectionFsAccess, self).isfile(fn)
# True when `rest` resolves to a RichCollectionBase (a subdirectory) inside
# the collection; non-Keep paths use StdFsAccess.isdir.
152 def isdir(self, fn): # type: (unicode) -> bool
153 collection, rest = self.get_collection(fn)
154 if collection is not None:
156 return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
160 return super(CollectionFsAccess, self).isdir(fn)
# List a Keep directory as full paths built with cwltool's abspath(name, fn).
# Raises IOError(ENOENT) when the directory is missing (the check guarding
# that raise is elided) or when `rest` is not a directory. Non-Keep paths use
# StdFsAccess.listdir. (`dir` shadows the builtin of the same name.)
162 def listdir(self, fn): # type: (unicode) -> List[unicode]
163 collection, rest = self.get_collection(fn)
164 if collection is not None:
166 dir = collection.find(rest)
170 raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
171 if not isinstance(dir, arvados.collection.RichCollectionBase):
172 raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
173 return [abspath(l, fn) for l in dir.keys()]
175 return super(CollectionFsAccess, self).listdir(fn)
# When the last component is itself a keep: locator, an elided line handles
# it (presumably returning it unchanged -- confirm). Otherwise os.path.join.
177 def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
178 if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
180 return os.path.join(path, *paths)
# "$(task.tmpdir)" / "$(task.outdir)" placeholders and Keep paths are handled
# on elided lines (presumably returned unchanged -- confirm). Everything else
# goes through os.path.realpath.
182 def realpath(self, path):
183 if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
185 collection, rest = self.get_collection(path)
186 if collection is not None:
189 return os.path.realpath(path)
# CollectionFetcher: schema_salad DefaultFetcher that also understands
# "keep:" URLs (read through fs_access) and "arvwf:<uuid>" URLs (workflow
# records fetched from the Arvados API).
# NOTE(review): partial listing -- original line numbers are fused into each
# line and several lines are not shown. The comments below describe only the
# visible lines.
191 class CollectionFetcher(DefaultFetcher):
192 def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4):
193 super(CollectionFetcher, self).__init__(cache, session)
194 self.api_client = api_client
195 self.fsaccess = fs_access
196 self.num_retries = num_retries
# Return the text behind `url`:
#  - keep: URLs are opened via self.fsaccess.open (the read/return line is
#    elided);
#  - arvwf: URLs load the workflow record and append a "label:" line built
#    from the record name (the return, presumably of `definition`, is elided);
#  - anything else defers to DefaultFetcher.fetch_text.
198 def fetch_text(self, url):
199 if url.startswith("keep:"):
200 with self.fsaccess.open(url, "r") as f:
202 if url.startswith("arvwf:"):
203 record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
# NOTE(review): only double quotes are escaped in this YAML double-quoted
# scalar. A workflow name containing a backslash or newline would still
# produce an altered or invalid label.
204 definition = record["definition"] + ('\nlabel: "%s"\n' % record["name"].replace('"', '\\"'))
206 return super(CollectionFetcher, self).fetch_text(url)
# Existence check:
#  - "http://arvados.org/cwl" URLs are special-cased on an elided line;
#  - keep: URLs go to fsaccess.exists;
#  - arvwf: URLs exist when fetch_text returns non-empty text.
# arvados.errors.NotFoundError is caught, and another handler logs unexpected
# exceptions via logger.exception. The try header, that handler's except line
# and the return values are elided. Other URLs defer to
# DefaultFetcher.check_exists.
208 def check_exists(self, url):
210 if url.startswith("http://arvados.org/cwl"):
212 if url.startswith("keep:"):
213 return self.fsaccess.exists(url)
214 if url.startswith("arvwf:"):
215 if self.fetch_text(url):
217 except arvados.errors.NotFoundError:
220 logger.exception("Got unexpected exception checking if file exists:")
222 return super(CollectionFetcher, self).check_exists(url)
# Resolve `url` against `base_url`. A `url` with its own scheme, or an empty
# base, takes an elided early return. For keep:/arvwf: bases, the leading
# component (pdh or uuid) is split off, path parts are re-joined and url's
# fragment is kept. A keep: base whose leading component is not a valid Keep
# locator raises IOError(EINVAL). How absolute url paths and the base's last
# component are treated is on elided lines. Other bases use
# DefaultFetcher.urljoin.
# NOTE(review): `urlparse` is the Python 2 module name (urllib.parse on
# Python 3).
224 def urljoin(self, base_url, url):
228 urlsp = urlparse.urlsplit(url)
229 if urlsp.scheme or not base_url:
232 basesp = urlparse.urlsplit(base_url)
233 if basesp.scheme in ("keep", "arvwf"):
235 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
237 baseparts = basesp.path.split("/")
238 urlparts = urlsp.path.split("/") if urlsp.path else []
240 pdh = baseparts.pop(0)
242 if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
243 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
245 if urlsp.path.startswith("/"):
249 if baseparts and urlsp.path:
252 path = "/".join([pdh] + baseparts + urlparts)
253 return urlparse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))
255 return super(CollectionFetcher, self).urljoin(base_url, url)
# Arvados object UUID shapes: a 5-character lowercase/digit prefix, a type
# infix ("7fd4e" for workflows, "p5p6p" for pipeline templates, per the
# variable names) and a 15-character identifier.
# NOTE(review): collectionResolver applies these with .match(), which anchors
# only at the start, so trailing text after the identifier is still accepted.
257 workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
258 pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')
# Resolve a tool reference `uri` to a loadable URI:
#  - keep: and arvwf: URIs are handled on elided lines (presumably returned
#    as-is -- confirm);
#  - a workflow UUID becomes "arvwf:<uuid>#main";
#  - a pipeline template UUID resolves to "keep:" + the "cwl:tool" script
#    parameter of one of its components;
#  - if the first path component (p, bound on an elided line) is a Keep
#    locator, the URI is prefixed with "keep:"; if it is a collection UUID,
#    it is replaced by that collection's portable data hash (the remaining
#    format argument is elided);
#  - anything else goes to cwltool.resolver.tool_resolver.
# NOTE(review): partial listing -- original line numbers are fused into each
# line and several lines are not shown.
260 def collectionResolver(api_client, document_loader, uri, num_retries=4):
261 if uri.startswith("keep:") or uri.startswith("arvwf:"):
264 if workflow_uuid_pattern.match(uri):
265 return "arvwf:%s#main" % (uri)
267 if pipeline_template_uuid_pattern.match(uri):
268 pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
# NOTE(review): indexing .values() only works on Python 2 (dict views are not
# indexable on Python 3). With more than one component, which one is "first"
# presumably depends on dict ordering.
269 return "keep:" + pt["components"].values()[0]["script_parameters"]["cwl:tool"]
272 if arvados.util.keep_locator_pattern.match(p[0]):
273 return "keep:%s" % (uri)
275 if arvados.util.collection_uuid_pattern.match(p[0]):
# NOTE(review): unlike the pipeline_templates() lookup above, this execute()
# call does not pass num_retries.
276 return "keep:%s%s" % (api_client.collections().
277 get(uuid=p[0]).execute()["portable_data_hash"],
280 return cwltool.resolver.tool_resolver(document_loader, uri)