Merge branch 'main' into 18842-arv-mount-disk-config
[arvados.git] / sdk / cwl / arvados_cwl / fsaccess.py
index fc0bbb80b9ea05715770c2d3142e7208ca19dcb2..5c09e671fa21eac1952c417e10580d332e3612be 100644 (file)
@@ -5,6 +5,7 @@
 from future import standard_library
 standard_library.install_aliases()
 from builtins import object
+from builtins import str
 from future.utils import viewvalues
 
 import fnmatch
@@ -62,24 +63,27 @@ class CollectionCache(object):
             del self.collections[pdh]
             self.total -= v[1]
 
-    def get(self, pdh):
+    def get(self, locator):
         with self.lock:
-            if pdh not in self.collections:
-                m = pdh_size.match(pdh)
+            if locator not in self.collections:
+                m = pdh_size.match(locator)
                 if m:
                     self.cap_cache(int(m.group(2)) * 128)
-                logger.debug("Creating collection reader for %s", pdh)
-                cr = arvados.collection.CollectionReader(pdh, api_client=self.api_client,
-                                                         keep_client=self.keep_client,
-                                                         num_retries=self.num_retries)
+                logger.debug("Creating collection reader for %s", locator)
+                try:
+                    cr = arvados.collection.CollectionReader(locator, api_client=self.api_client,
+                                                             keep_client=self.keep_client,
+                                                             num_retries=self.num_retries)
+                except arvados.errors.ApiError as ap:
+                    raise IOError(errno.ENOENT, "Could not access collection '%s': %s" % (locator, str(ap._get_reason())))
                 sz = len(cr.manifest_text()) * 128
-                self.collections[pdh] = (cr, sz)
+                self.collections[locator] = (cr, sz)
                 self.total += sz
             else:
-                cr, sz = self.collections[pdh]
+                cr, sz = self.collections[locator]
                 # bump it to the back
-                del self.collections[pdh]
-                self.collections[pdh] = (cr, sz)
+                del self.collections[locator]
+                self.collections[locator] = (cr, sz)
             return cr
 
 
@@ -93,9 +97,11 @@ class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
     def get_collection(self, path):
         sp = path.split("/", 1)
         p = sp[0]
-        if p.startswith("keep:") and arvados.util.keep_locator_pattern.match(p[5:]):
-            pdh = p[5:]
-            return (self.collection_cache.get(pdh), urllib.parse.unquote(sp[1]) if len(sp) == 2 else None)
+        if p.startswith("keep:") and (arvados.util.keep_locator_pattern.match(p[5:]) or
+                                      arvados.util.collection_uuid_pattern.match(p[5:])):
+            locator = p[5:]
+            rest = os.path.normpath(urllib.parse.unquote(sp[1])) if len(sp) == 2 else None
+            return (self.collection_cache.get(locator), rest)
         else:
             return (None, path)
 
@@ -123,15 +129,15 @@ class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
 
     def glob(self, pattern):
         collection, rest = self.get_collection(pattern)
-        if collection is not None and not rest:
+        if collection is not None and rest in (None, "", "."):
             return [pattern]
         patternsegments = rest.split("/")
         return sorted(self._match(collection, patternsegments, "keep:" + collection.manifest_locator()))
 
-    def open(self, fn, mode):
+    def open(self, fn, mode, encoding=None):
         collection, rest = self.get_collection(fn)
         if collection is not None:
-            return collection.open(rest, mode)
+            return collection.open(rest, mode, encoding=encoding)
         else:
             return super(CollectionFsAccess, self).open(self._abs(fn), mode)
 
@@ -143,6 +149,11 @@ class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
                 return False
             else:
                 raise
+        except IOError as err:
+            if err.errno == errno.ENOENT:
+                return False
+            else:
+                raise
         if collection is not None:
             if rest:
                 return collection.exists(rest)
@@ -218,29 +229,31 @@ class CollectionFetcher(DefaultFetcher):
         self.fsaccess = fs_access
         self.num_retries = num_retries
 
-    def fetch_text(self, url):
+    def fetch_text(self, url, content_types=None):
         if url.startswith("keep:"):
-            with self.fsaccess.open(url, "r") as f:
+            with self.fsaccess.open(url, "r", encoding="utf-8") as f:
                 return f.read()
         if url.startswith("arvwf:"):
             record = self.api_client.workflows().get(uuid=url[6:]).execute(num_retries=self.num_retries)
-            definition = record["definition"] + ('\nlabel: "%s"\n' % record["name"].replace('"', '\\"'))
-            return definition
+            definition = yaml.round_trip_load(record["definition"])
+            definition["label"] = record["name"]
+            return yaml.round_trip_dump(definition)
         return super(CollectionFetcher, self).fetch_text(url)
 
     def check_exists(self, url):
         try:
             if url.startswith("http://arvados.org/cwl"):
                 return True
-            if url.startswith("keep:"):
-                return self.fsaccess.exists(url)
-            if url.startswith("arvwf:"):
-                if self.fetch_text(url):
+            urld, _ = urllib.parse.urldefrag(url)
+            if urld.startswith("keep:"):
+                return self.fsaccess.exists(urld)
+            if urld.startswith("arvwf:"):
+                if self.fetch_text(urld):
                     return True
         except arvados.errors.NotFoundError:
             return False
-        except:
-            logger.exception("Got unexpected exception checking if file exists:")
+        except Exception:
+            logger.exception("Got unexpected exception checking if file exists")
             return False
         return super(CollectionFetcher, self).check_exists(url)
 
@@ -260,9 +273,11 @@ class CollectionFetcher(DefaultFetcher):
             baseparts = basesp.path.split("/")
             urlparts = urlsp.path.split("/") if urlsp.path else []
 
-            pdh = baseparts.pop(0)
+            locator = baseparts.pop(0)
 
-            if basesp.scheme == "keep" and not arvados.util.keep_locator_pattern.match(pdh):
+            if (basesp.scheme == "keep" and
+                (not arvados.util.keep_locator_pattern.match(locator)) and
+                (not arvados.util.collection_uuid_pattern.match(locator))):
                 raise IOError(errno.EINVAL, "Invalid Keep locator", base_url)
 
             if urlsp.path.startswith("/"):
@@ -272,7 +287,7 @@ class CollectionFetcher(DefaultFetcher):
             if baseparts and urlsp.path:
                 baseparts.pop()
 
-            path = "/".join([pdh] + baseparts + urlparts)
+            path = "/".join([locator] + baseparts + urlparts)
             return urllib.parse.urlunsplit((basesp.scheme, "", path, "", urlsp.fragment))
 
         return super(CollectionFetcher, self).urljoin(base_url, url)
@@ -288,7 +303,7 @@ pipeline_template_uuid_pattern = re.compile(r'[a-z0-9]{5}-p5p6p-[a-z0-9]{15}')
 
 def collectionResolver(api_client, document_loader, uri, num_retries=4):
     if uri.startswith("keep:") or uri.startswith("arvwf:"):
-        return uri.encode("utf-8").decode()
+        return str(uri)
 
     if workflow_uuid_pattern.match(uri):
         return u"arvwf:%s#main" % (uri)