Merge branch 'master' into 11840-unique-constraint-untrash-coll
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
1 import fnmatch
2 import os
3 import errno
4 import urlparse
5 import re
6 import logging
7 import threading
8
9 import ruamel.yaml as yaml
10
11 import cwltool.stdfsaccess
12 from cwltool.pathmapper import abspath
13 import cwltool.resolver
14
15 import arvados.util
16 import arvados.collection
17 import arvados.arvfile
18 import arvados.errors
19
20 from schema_salad.ref_resolver import DefaultFetcher
21
# Module-level logger.  It deliberately uses the runner-wide name
# 'arvados.cwl-runner' instead of __name__.
logger = logging.getLogger(name='arvados.cwl-runner')
23
class CollectionCache(object):
    """Lock-protected cache of CollectionReader objects keyed by portable data hash.

    A reader is created lazily the first time a PDH is requested and is then
    returned for every later request of the same PDH.  Entries are never
    evicted.
    """

    def __init__(self, api_client, keep_client, num_retries):
        """
        :param api_client: Arvados API client passed to each CollectionReader.
        :param keep_client: Keep client passed to each CollectionReader.
        :param num_retries: retry count passed to each CollectionReader.
        """
        self.api_client = api_client
        self.keep_client = keep_client
        # This argument used to be accepted and then dropped; keep it so the
        # readers created in get() honour the caller's retry setting.
        self.num_retries = num_retries
        self.collections = {}  # pdh -> CollectionReader
        self.lock = threading.Lock()

    def get(self, pdh):
        """Return the reader for ``pdh``, creating and caching it on first use."""
        with self.lock:
            if pdh not in self.collections:
                logger.debug("Creating collection reader for %s", pdh)
                self.collections[pdh] = arvados.collection.CollectionReader(
                    pdh,
                    api_client=self.api_client,
                    keep_client=self.keep_client,
                    num_retries=self.num_retries)
            return self.collections[pdh]
38
39
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
    """Implement the cwltool FsAccess interface for Arvados Collections.

    Paths of the form ``keep:<portable data hash>[/<path>]`` are resolved
    against collections obtained from ``collection_cache``; every other path
    is delegated to cwltool's ``StdFsAccess``.

    Methods compare the collection returned by ``get_collection`` against
    ``None`` rather than testing its truthiness: collection objects are
    iterable containers (see ``_match``), so an empty collection may evaluate
    false and would otherwise be mistaken for a non-Keep path.
    """

    def __init__(self, basedir, collection_cache=None):
        super(CollectionFsAccess, self).__init__(basedir)
        self.collection_cache = collection_cache

    def get_collection(self, path):
        """Split ``path`` into ``(collection, rest)``.

        For ``keep:<locator>[/<rest>]`` returns the cached collection and the
        path inside it (``None`` when ``path`` names the collection root).
        For any other path returns ``(None, path)``.
        """
        sp = path.split("/", 1)
        p = sp[0]
        if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
            pdh = p[5:]
            return (self.collection_cache.get(pdh), sp[1] if len(sp) == 2 else None)
        else:
            return (None, path)

    def _match(self, collection, patternsegments, parent):
        """Return the paths under ``collection`` that match ``patternsegments``.

        :param collection: collection or subcollection to search; anything
            that is not a RichCollectionBase (e.g. a file) yields no matches.
        :param patternsegments: remaining glob pattern, already split on "/".
        :param parent: path prefix joined onto each matching entry name.
        """
        if not patternsegments:
            return []

        if not isinstance(collection, arvados.collection.RichCollectionBase):
            return []

        if patternsegments[0] == '.':
            # Pattern contains something like "./foo", so just shift past the
            # "./".  This must happen once, before the loop: doing it per
            # entry returned every match once for each item in the collection.
            return self._match(collection, patternsegments[1:], parent)

        ret = []
        # iterate over the files and subcollections in 'collection'
        for filename in collection:
            if fnmatch.fnmatch(filename, patternsegments[0]):
                cur = os.path.join(parent, filename)
                if len(patternsegments) == 1:
                    ret.append(cur)
                else:
                    ret.extend(self._match(collection[filename], patternsegments[1:], cur))
        return ret

    def glob(self, pattern):
        """Expand a glob ``pattern``.

        Keep matches are returned as ``keep:<manifest locator>/<path>``.  A
        pattern naming only a collection root is returned unchanged.
        Non-Keep patterns are handled by ``StdFsAccess.glob``.
        """
        collection, rest = self.get_collection(pattern)
        if collection is None:
            # Not a Keep path.  Previously this fell through and crashed on
            # None.manifest_locator().
            return super(CollectionFsAccess, self).glob(pattern)
        if not rest:
            return [pattern]
        patternsegments = rest.split("/")
        return self._match(collection, patternsegments, "keep:" + collection.manifest_locator())

    def open(self, fn, mode):
        """Open ``fn`` inside its collection, or on the local filesystem."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            return collection.open(rest, mode)
        else:
            return super(CollectionFsAccess, self).open(self._abs(fn), mode)

    def exists(self, fn):
        """Return whether ``fn`` exists; a bare collection root always exists."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return collection.exists(rest)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).exists(fn)

    def isfile(self, fn):  # type: (unicode) -> bool
        """Return whether ``fn`` names a file (a collection root is not a file)."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.arvfile.ArvadosFile)
            else:
                return False
        else:
            return super(CollectionFsAccess, self).isfile(fn)

    def isdir(self, fn):  # type: (unicode) -> bool
        """Return whether ``fn`` names a directory (a collection root is one)."""
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                return isinstance(collection.find(rest), arvados.collection.RichCollectionBase)
            else:
                return True
        else:
            return super(CollectionFsAccess, self).isdir(fn)

    def listdir(self, fn):  # type: (unicode) -> List[unicode]
        """List the entries of directory ``fn`` as paths joined onto ``fn``.

        :raises IOError: (ENOENT) if a Keep path is missing or not a directory.
        """
        collection, rest = self.get_collection(fn)
        if collection is not None:
            if rest:
                target = collection.find(rest)
            else:
                target = collection
            if target is None:
                raise IOError(errno.ENOENT, "Directory '%s' in '%s' not found" % (rest, collection.portable_data_hash()))
            if not isinstance(target, arvados.collection.RichCollectionBase):
                raise IOError(errno.ENOENT, "Path '%s' in '%s' is not a Directory" % (rest, collection.portable_data_hash()))
            return [abspath(name, fn) for name in target.keys()]
        else:
            return super(CollectionFsAccess, self).listdir(fn)

    def join(self, path, *paths): # type: (unicode, *unicode) -> unicode
        """Join path components; a trailing ``keep:<locator>`` is already absolute."""
        if paths and paths[-1].startswith("keep:") and arvados.util.keep_locator_pattern.match(paths[-1][5:]):
            return paths[-1]
        return os.path.join(path, *paths)

    def realpath(self, path):
        """Return ``path`` unchanged for task placeholders and Keep paths,
        otherwise ``os.path.realpath(path)``."""
        if path.startswith(("$(task.tmpdir)", "$(task.outdir)")):
            return path
        collection, _ = self.get_collection(path)
        if collection is not None:
            return path
        else:
            return os.path.realpath(path)
150
class CollectionFetcher(DefaultFetcher):
    """schema-salad fetcher that also understands ``keep:`` and ``arvwf:`` URLs.

    ``keep:`` documents are read through ``fs_access``; ``arvwf:<uuid>``
    documents are fetched from the Arvados workflows API.  ``overrides`` maps
    URLs to document text that is returned verbatim instead of being fetched.
    """

    def __init__(self, cache, session, api_client=None, fs_access=None, num_retries=4, overrides=None):
        super(CollectionFetcher, self).__init__(cache, session)
        self.api_client = api_client
        self.fsaccess = fs_access
        self.num_retries = num_retries
        self.overrides = overrides if overrides else {}

    def fetch_text(self, url):
        """Return the document text at ``url``, consulting ``overrides`` first."""
        if url in self.overrides:
            return self.overrides[url]
        if url.startswith("keep:"):
            with self.fsaccess.open(url, "r") as f:
                return f.read()
        if url.startswith("arvwf:"):
            record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
            # Append the stored workflow's name as a CWL "label" field,
            # escaping double quotes so the appended line stays valid.
            definition = record["definition"] + ('\nlabel: "%s"\n' % record["name"].replace('"', '\\"'))
            return definition
        return super(CollectionFetcher, self).fetch_text(url)

    def check_exists(self, url):
        """Return whether ``url`` can be fetched.

        Overridden URLs and URLs under ``http://arvados.org/cwl`` always
        "exist".  Lookup errors for keep:/arvwf: URLs are reported as False.
        """
        if url in self.overrides:
            return True
        try:
            if url.startswith("http://arvados.org/cwl"):
                return True
            if url.startswith("keep:"):
                return self.fsaccess.exists(url)
            if url.startswith("arvwf:"):
                if self.fetch_text(url):
                    return True
        except arvados.errors.NotFoundError:
            return False
        except Exception:
            # Best-effort check: log unexpected failures and report the URL as
            # missing.  Catch Exception (not a bare except) so that
            # KeyboardInterrupt and SystemExit still propagate.
            logger.exception("Got unexpected exception checking if file exists:")
            return False
        return super(CollectionFetcher, self).check_exists(url)

    def urljoin(self, base_url, url):
        """Resolve ``url`` relative to ``base_url``.

        For ``keep:`` and ``arvwf:`` bases, the first path segment of the base
        (the collection PDH or workflow UUID) is always kept.  An absolute
        ``url`` path ("/x") replaces everything after it; a relative path
        replaces the base's last segment.  The result carries ``url``'s
        fragment.  Other schemes are handled by ``DefaultFetcher``.

        :raises IOError: (EINVAL) if a keep:/arvwf: base has no path, or a
            keep: base does not start with a valid Keep locator.
        """
        if not url:
            return base_url

        urlsp = urlparse.urlsplit(url)
        if urlsp.scheme or not base_url:
            return url

        basesp = urlparse.urlsplit(base_url)
        if basesp.scheme in ("keep", "arvwf"):
            if not basesp.path:
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            baseparts = basesp.path.split("/")
            urlparts = urlsp.path.split("/") if urlsp.path else []

            pdh = baseparts.pop(0)

            if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
                raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)

            if urlsp.path.startswith("/"):
                # Absolute path: resolve from the collection/workflow root.
                baseparts = []
                urlparts.pop(0)

            if baseparts and urlsp.path:
                # Drop the base document's own name before appending.
                baseparts.pop()

            path = "/".join([pdh] + baseparts + urlparts)
            return urlparse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))

        return super(CollectionFetcher, self).urljoin(base_url, url)
221
# UUID shapes for Arvados workflow ("7fd4e") and pipeline template ("p5p6p")
# records.  They are used with re.match, so they anchor at the start only.
workflow_uuid_pattern = re.compile(r'[a-z0-9]{5}-7fd4e-[a-z0-9]{15}')
pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')

def collectionResolver(api_client, document_loader, uri, num_retries=4):
    """Resolve an Arvados-specific tool reference to a loadable URI.

    Forms recognized, checked in this order:

    * workflow UUID -> ``arvwf:<uuid>#main``
    * pipeline template UUID -> ``keep:`` + the ``cwl:tool`` script
      parameter of the template's first component
    * ``<portable data hash>[/path]`` -> ``keep:<pdh>[/path]``
    * ``<collection uuid>[/path]`` -> ``keep:<collection's pdh>[/path]``

    Anything else is passed to cwltool's default ``tool_resolver``.
    ``num_retries`` is forwarded to every Arvados API call made here.
    """
    if workflow_uuid_pattern.match(uri):
        return "arvwf:%s#main" % (uri)

    if pipeline_template_uuid_pattern.match(uri):
        pt = api_client.pipeline_templates().get(uuid=uri).execute(num_retries=num_retries)
        # next(iter(...)) works on both Python 2 and 3; values()[0] is
        # Python 2 only.  "First" is the dict's iteration order.
        first_component = next(iter(pt["components"].values()))
        return "keep:" + first_component["script_parameters"]["cwl:tool"]

    p = uri.split("/")
    if arvados.util.keep_locator_pattern.match(p[0]):
        return "keep:%s" % (uri)

    if arvados.util.collection_uuid_pattern.match(p[0]):
        # Forward num_retries here too, consistent with the template lookup.
        record = api_client.collections().get(uuid=p[0]).execute(num_retries=num_retries)
        return "keep:%s%s" % (record["portable_data_hash"], uri[len(p[0]):])

    return cwltool.resolver.tool_resolver(document_loader, uri)
241
242     return cwltool.resolver.tool_resolver(document_loader, uri)