X-Git-Url: https://git.arvados.org/arvados.git/blobdiff_plain/fee8873d0c5eeec1bd838161357679de1a3fe0cb..65622f423c2ee35250856657b06118481d53edc8:/sdk/python/tests/test_arv_ls.py diff --git a/sdk/python/tests/test_arv_ls.py b/sdk/python/tests/test_arv_ls.py index ed03c124f4..e3f6c128aa 100644 --- a/sdk/python/tests/test_arv_ls.py +++ b/sdk/python/tests/test_arv_ls.py @@ -1,21 +1,20 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import io +from __future__ import absolute_import +from builtins import str +from builtins import range import os import random import sys import mock import tempfile -import multiprocessing import arvados.errors as arv_error import arvados.commands.ls as arv_ls -import run_test_server +from . import run_test_server -from arvados_testutil import str_keep_locator +from . import arvados_testutil as tutil +from .arvados_testutil import str_keep_locator, redirected_streams, StringIO -class ArvLsTestCase(run_test_server.TestCaseWithServers): +class ArvLsTestCase(run_test_server.TestCaseWithServers, tutil.VersionChecker): FAKE_UUID = 'zzzzz-4zz18-12345abcde12345' def newline_join(self, seq): @@ -36,29 +35,10 @@ class ArvLsTestCase(run_test_server.TestCaseWithServers): api_client.collections().get().execute.return_value = coll_info return coll_info, api_client - def run_ls(self, args, api_client): - self.stdout = io.BytesIO() - self.stderr = io.BytesIO() - return arv_ls.main(args, self.stdout, self.stderr, api_client) - - def run_ls_process(self, args=[], api_client=None): - _, stdout_path = tempfile.mkstemp() - _, stderr_path = tempfile.mkstemp() - def wrap(): - def wrapper(*args, **kwargs): - sys.stdout = open(stdout_path, 'w') - sys.stderr = open(stderr_path, 'w') - arv_ls.main(*args, **kwargs) - return wrapper - p = multiprocessing.Process(target=wrap(), - args=(args, sys.stdout, sys.stderr, api_client)) - p.start() - p.join() - out = open(stdout_path, 'r').read() - err = open(stderr_path, 'r').read() - os.unlink(stdout_path) - os.unlink(stderr_path) - return p.exitcode, out, err + def run_ls(self, args, api_client, logger=None): + self.stdout = StringIO() + self.stderr = StringIO() + return arv_ls.main(args, self.stdout, self.stderr, api_client, logger) def test_plain_listing(self): collection, api_client = self.mock_api_for_manifest( @@ -96,15 +76,16 @@ class ArvLsTestCase(run_test_server.TestCaseWithServers): def test_locator_failure(self): api_client = mock.MagicMock(name='mock_api_client') + error_mock = mock.MagicMock() + logger = mock.MagicMock() + logger.error = error_mock api_client.collections().get().execute.side_effect = ( arv_error.NotFoundError) - self.assertNotEqual(0, self.run_ls([self.FAKE_UUID], api_client)) - self.assertNotEqual('', self.stderr.getvalue()) + self.assertNotEqual(0, self.run_ls([self.FAKE_UUID], api_client, logger)) + self.assertEqual(1, error_mock.call_count) def test_version_argument(self): - _, api_client = self.mock_api_for_manifest(['']) - exitcode, out, err = self.run_ls_process(['--version']) - self.assertEqual(0, exitcode) - self.assertEqual('', out) - self.assertNotEqual('', err) - self.assertRegexpMatches(err, "[0-9]+\.[0-9]+\.[0-9]+") + with redirected_streams(stdout=StringIO, stderr=StringIO) as (out, err): + with self.assertRaises(SystemExit): + self.run_ls(['--version'], None) + self.assertVersionOutput(out, err)