raise Exception("Output source is not in keep or a literal")
sp = k.split("/")
srccollection = sp[0][5:]
- reader = self.collection_cache.get(srccollection)
try:
+ reader = self.collection_cache.get(srccollection)
srcpath = "/".join(sp[1:]) if len(sp) > 1 else "."
final.copy(srcpath, v.target, source_collection=reader, overwrite=False)
+ except arvados.errors.ArgumentError as e:
+ logger.error("Creating CollectionReader for '%s' '%s': %s", k, v, e)
+ raise
except IOError as e:
logger.warn("While preparing output collection: %s", e)
import urlparse
import re
import logging
+import threading
import ruamel.yaml as yaml
self.api_client = api_client
self.keep_client = keep_client
self.collections = {}
+ self.lock = threading.Lock()
def get(self, pdh):
    """Return the CollectionReader for ``pdh``, building it on first request.

    Readers are memoized in ``self.collections`` keyed by ``pdh``. The whole
    lookup-or-create sequence runs under ``self.lock``, so concurrent callers
    asking for the same ``pdh`` end up sharing one reader instead of each
    constructing their own.

    NOTE(review): the lock stays held while the reader is constructed, so a
    cache miss for one ``pdh`` also blocks lookups for every other ``pdh``
    until construction finishes.
    """
    with self.lock:
        try:
            # Fast path: this pdh has already been materialized.
            return self.collections[pdh]
        except KeyError:
            pass

        logger.debug("Creating collection reader for %s", pdh)
        reader = arvados.collection.CollectionReader(
            pdh,
            api_client=self.api_client,
            keep_client=self.keep_client)
        self.collections[pdh] = reader
        return reader
class CollectionFsAccess(cwltool.stdfsaccess.StdFsAccess):
--- /dev/null
+import functools
+import mock
+import sys
+import unittest
+import json
+import logging
+import os
+
+import arvados
+import arvados.keep
+import arvados.collection
+import arvados_cwl
+
+from cwltool.pathmapper import MapperEnt
+from .mock_discovery import get_rootDesc
+
+from arvados_cwl.fsaccess import CollectionCache
+
class TestFsAccess(unittest.TestCase):
    """Unit tests for ``arvados_cwl.fsaccess.CollectionCache``."""

    @mock.patch("arvados.collection.CollectionReader")
    def test_collection_cache(self, cr):
        """One PDH reuses a single reader; a different PDH builds a new one."""
        pdh1 = "99999999999999999999999999999991+99"
        pdh2 = "99999999999999999999999999999992+99"
        cache = CollectionCache(mock.MagicMock(), mock.MagicMock(), 4)

        c1 = cache.get(pdh1)
        c2 = cache.get(pdh1)
        # The second lookup must be served from the cache rather than by
        # constructing another reader.
        self.assertIs(c1, c2)
        self.assertEqual(1, cr.call_count)
        # The PDH is passed as the first positional argument to the reader.
        self.assertEqual(pdh1, cr.call_args[0][0])

        c3 = cache.get(pdh2)
        # A new PDH is a cache miss and must trigger a second construction
        # for that specific PDH.
        self.assertEqual(2, cr.call_count)
        self.assertEqual(pdh2, cr.call_args[0][0])
        # ``cr`` is a MagicMock, so every call returns the same
        # ``cr.return_value``. c3 therefore cannot be told apart from c1 by
        # identity; check instead that it came from the patched constructor.
        self.assertIs(cr.return_value, c3)