def main(arguments=None, stdout=sys.stdout, stderr=sys.stderr):
global api_client
+ if stdout is sys.stdout and hasattr(stdout, 'buffer'):
+ # in Python 3, write to stdout as binary
+ stdout = stdout.buffer
+
args = parse_arguments(arguments, stdout, stderr)
if api_client is None:
api_client = arvados.api('v1')
open_flags |= os.O_EXCL
try:
if args.destination == "-":
- stdout.write(reader.manifest_text())
+ stdout.write(reader.manifest_text().encode())
else:
out_fd = os.open(args.destination, open_flags)
with os.fdopen(out_fd, 'wb') as out_file:
- out_file.write(reader.manifest_text())
+ out_file.write(reader.manifest_text().encode())
except (IOError, OSError) as error:
logger.error("can't write to '{}': {}".format(args.destination, error))
return 1
if args.hash:
digestor = hashlib.new(args.hash)
try:
- with s.open(f.name, 'r') as file_reader:
+ with s.open(f.name, 'rb') as file_reader:
for data in file_reader.readall():
if outfile:
outfile.write(data)
import arvados.commands.get as arv_get
from . import run_test_server
-from .arvados_testutil import redirected_streams
+from . import arvados_testutil as tutil
-class ArvadosGetTestCase(run_test_server.TestCaseWithServers):
+class ArvadosGetTestCase(run_test_server.TestCaseWithServers,
+ tutil.VersionChecker):
MAIN_SERVER = {}
KEEP_SERVER = {}
}):
c = collection.Collection()
for path, data in listitems(contents):
- with c.open(path, 'w') as f:
+ with c.open(path, 'wb') as f:
f.write(data)
c.save_new()
return (c.manifest_locator(), c.portable_data_hash(), c.manifest_text())
def run_get(self, args):
    """Run ``arv_get.main`` against fresh in-memory output streams.

    The stdout buffer (a ``tutil.BytesIO``) and the stderr buffer (a
    ``tutil.StringIO``) are stored on ``self.stdout`` / ``self.stderr``
    so that tests can inspect what the command wrote.

    :param args: argument list forwarded to ``arv_get.main``.
    :returns: the value returned by ``arv_get.main`` (tests compare it
        against 0 for success).
    """
    out_buffer = tutil.BytesIO()
    err_buffer = tutil.StringIO()
    self.stdout, self.stderr = out_buffer, err_buffer
    return arv_get.main(args, out_buffer, err_buffer)
def test_version_argument(self):
    # `--version` must terminate via SystemExit; stdout/stderr are
    # captured into StringIO buffers and then checked by the
    # assertVersionOutput helper (provided by tutil.VersionChecker).
    capture = tutil.redirected_streams(
        stdout=tutil.StringIO, stderr=tutil.StringIO)
    # Single `with` enters `capture` first, then assertRaises, and exits in
    # reverse order -- exactly like the nested form.
    with capture as (out, err), self.assertRaises(SystemExit):
        self.run_get(['--version'])
    self.assertVersionOutput(out, err)
def test_get_single_file(self):
    # Fetch subdir/baz.txt to stdout ('-'), addressing the collection
    # first by its locator and then by its portable data hash; both
    # must succeed and emit the bytes b'baz'.
    for collection_ref in (self.col_loc, self.col_pdh):
        exit_code = self.run_get(
            ["{}/subdir/baz.txt".format(collection_ref), '-'])
        self.assertEqual(0, exit_code)
        self.assertEqual(b'baz', self.stdout.getvalue())
def test_get_multiple_files(self):
# Download the entire collection to the temp directory