Merge branch '11950-stretch'
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import fnmatch
6 import os
7 import errno
8 import urlparse
9 import re
10 import logging
11 import threading
12
13 import ruamel.yaml as yaml
14
15 import cwltool.stdfsaccess
16 from cwltool.pathmapper import abspath
17 import cwltool.resolver
18
19 import arvados.util
20 import arvados.collection
21 import arvados.arvfile
22 import arvados.errors
23
24 from schema_salad.ref_resolver import DefaultFetcher
25
# Module logger; messages from this file are emitted under the
# 'arvados.cwl-runner' logger name.
logger = logging.getLogger(name='arvados.cwl-runner')
27
class CollectionCache(object):
    """Lock-protected cache of CollectionReader objects keyed by the locator
    string passed to get().

    Entries are never evicted: the cache is a plain dict that only grows.
    """

    def __init__(self, api_client, keep_client, num_retries):
        """
        :param api_client: API client handed to every CollectionReader.
        :param keep_client: Keep client handed to every CollectionReader.
        :param num_retries: retry count handed to every CollectionReader
            (previously accepted but silently ignored).
        """
        self.api_client = api_client
        self.keep_client = keep_client
        self.num_retries = num_retries
        self.collections = {}
        self.lock = threading.Lock()

    def get(self, pdh):
        """Return the CollectionReader for ``pdh``, creating it on first use.

        The lock is held while the reader is created, so concurrent callers
        asking for the same ``pdh`` end up sharing a single reader.
        """
        with self.lock:
            if pdh not in self.collections:
                logger.debug("Creating collection reader for %s", pdh)
                self.collections[pdh] = arvados.collection.CollectionReader(
                    pdh,
                    api_client=self.api_client,
                    keep_client=self.keep_client,
                    num_retries=self.num_retries)
            return self.collections[pdh]
42
43
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<locator>[/<path>]`` are resolved inside the
    collection returned by ``collection_cache``; every other path is
    delegated to cwltool's StdFsAccess.

    Collections are compared against ``None`` explicitly rather than by
    truthiness: collection objects define ``__len__``, so an empty collection
    is falsy and would otherwise be mistaken for a local path.
    """

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        # NOTE(review): get_collection() calls self.collection_cache.get()
        # unconditionally for keep: paths, so this must not be None by then.
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split ``path`` into ``(collection, path_within_collection)``.

        For ``keep:<locator>[/rest]`` returns the cached collection and
        ``rest`` (None when there is no "/" after the locator).  For any
        other path returns ``(None, path)``.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
            pdh = p[5:]
            return (self.collection_cache.get(pdh), sp[1] if len(sp) == 2 else None)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Return paths under ``collection`` matching ``patternsegments``.

        :param collection: collection or subcollection to search; anything
            that is not a RichCollectionBase yields no matches.
        :param patternsegments: glob pattern split on "/", each segment
            matched with fnmatch against entry names.
        :param parent: prefix joined onto each matched entry name.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo" so just shift past the
            # "./".  Done once here, not once per entry, so results are not
            # duplicated for every item in the collection.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Expand ``pattern``.

        keep: patterns are matched against collection contents; a pattern
        naming only the collection returns ``[pattern]``.  Non-keep
        patterns are delegated to StdFsAccess.glob.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            return super(CollectionFsAccess, self).glob(pattern)
        if not rest:
            return [pattern]
        patternsegments = rest.split("/")
        return self._match(collection, patternsegments, "keep:" + collection.manifest_locator())

    def open(self, fn, mode):
        """Open ``fn`` inside its collection, or on the local filesystem."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return whether ``fn`` exists; a bare collection reference always does."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return True if ``fn`` names an ArvadosFile; a collection root is not a file."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return True if ``fn`` names a (sub)collection; a collection root is a directory."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """List the entries of directory ``fn`` as absolute paths.

        :raises IOError: ENOENT if a keep: path does not exist or is not a
            directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                target = collection.find(rest)
            else:
                target = collection
            if target is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(target, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(l, fn) for l in target.keys()]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """os.path.join, except a trailing keep:<locator> component is
        returned as-is (it is already an absolute reference)."""
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return ``path`` unchanged for "$(task.tmpdir)"/"$(task.outdir)"
        placeholders and keep: paths, otherwise os.path.realpath(path)."""
        if path.startswith("$(task.tmpdir)") or path.startswith("$(task.outdir)"):
            return path
        collection, _ = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
154
class CollectionFetcher(DefaultFetcher):
    """schema-salad fetcher that also understands ``keep:`` and ``arvwf:`` URLs.

    URLs present in ``overrides`` are served from that mapping.  ``keep:``
    URLs are read through ``fs_access``; ``arvwf:<uuid>`` URLs are loaded
    from the workflows API.  Everything else goes to DefaultFetcher.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4, overrides=None):
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        # NOTE(review): keep: URLs require fs_access; it defaults to None.
        self.fsaccess = fs_access
        self.num_retries = num_retries
        self.overrides = overrides if overrides else {}

    def fetch_text(self, url):
        """Return the text at ``url``.

        For ``arvwf:`` URLs this is the workflow record's ``definition``
        followed by a ``label:`` line holding the record's (quote-escaped)
        name.
        """
        if url in self.overrides:
            return self.overrides[url]
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            definition = record["definition"] + ('\nlabel: "%s"\n' % record["name"].replace('"', '\\"'))
            return definition
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return True if ``url`` can be fetched.

        URLs under "http://arvados.org/cwl" are always reported as existing
        (presumably the Arvados CWL extension namespace).  Arvados
        NotFoundError yields False; any other Exception raised while
        checking is logged and also yields False.
        """
        if url in self.overrides:
            return True
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            if url.startswith("keep:"):
                return self.fsaccess.exists(url)
            if url.startswith("arvwf:"):
                if self.fetch_text(url):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            # Best-effort check, but unlike a bare ``except:`` this lets
            # KeyboardInterrupt / SystemExit propagate.
            logger.exception("Got unexpected exception checking if file exists:")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` relative to ``base_url``.

        URLs with a scheme are returned unchanged.  For ``keep:`` and
        ``arvwf:`` bases, the first path segment (locator / UUID) is kept;
        the base's last segment is replaced by ``url``'s path, or the whole
        sub-path is replaced when ``url`` starts with "/".  A fragment-only
        ``url`` keeps the base path.  ``url``'s fragment is carried over.
        Other bases use DefaultFetcher.urljoin.

        :raises IOError: EINVAL if the base has no path, or if a ``keep:``
            base's first segment is not a Keep locator.
        """
        if not url:
            return base_url

        urlsp = urlparse.urlsplit(url)
        if urlsp.scheme or not base_url:
            return url

        basesp = urlparse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            pdh = baseparts.pop(0)

            if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # Absolute within the collection: discard the base sub-path
                # and the empty segment produced by the leading "/".
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Drop the base's final segment (the referring document).
                baseparts.pop()

            path = "/".join([pdh] + baseparts + urlparts)
            return urlparse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)
225
# UUID shapes: <5 chars>-<type infix>-<15 chars>.  Used with .match(), so
# they are anchored at the start only (trailing text is not rejected).
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')

def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Resolve a tool reference ``uri`` to a URI the CWL loader can fetch.

    - workflow UUID -> ``arvwf:<uuid>#main``
    - pipeline template UUID -> ``keep:`` + the ``cwl:tool`` script
      parameter of the template's first component
    - ``<keep locator>[/path]`` -> ``keep:<locator>[/path]``
    - ``<collection uuid>[/path]`` -> ``keep:<portable_data_hash>[/path]``
    - anything else -> cwltool.resolver.tool_resolver

    :param num_retries: retry count used for every API request made here.
    """
    if workflow_uuid_pattern.match(uri):
        return "arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # list() works whether values() returns a list or a view.
        # NOTE(review): with several components, "first" depends on dict order.
        first_component = list(pt["components"].values())[0]
        return "keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return "keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        # Pass num_retries here too, consistent with the lookup above.
        record = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return "keep:%s%s" % (record["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)