15397: Fix improperly quoted regexps.
[arvados.git] / sdk / python / tests / test_arv_get.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 import io
6 import logging
7 import os
8 import re
9 import shutil
10 import tempfile
11
12 from unittest import mock
13
14 import arvados
15 import arvados.collection as collection
16 import arvados.commands.get as arv_get
17 from . import run_test_server
18
19 from . import arvados_testutil as tutil
20 from .arvados_testutil import ArvadosBaseTestCase
21
22 class ArvadosGetTestCase(run_test_server.TestCaseWithServers,
23                          tutil.VersionChecker,
24                          ArvadosBaseTestCase):
25     MAIN_SERVER = {}
26     KEEP_SERVER = {}
27
28     def setUp(self):
29         super(ArvadosGetTestCase, self).setUp()
30         self.tempdir = tempfile.mkdtemp()
31         self.col_loc, self.col_pdh, self.col_manifest = self.write_test_collection()
32
33         self.stdout = tutil.BytesIO()
34         self.stderr = tutil.StringIO()
35         self.loggingHandler = logging.StreamHandler(self.stderr)
36         self.loggingHandler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
37         logging.getLogger().addHandler(self.loggingHandler)
38
39     def tearDown(self):
40         logging.getLogger().removeHandler(self.loggingHandler)
41         super(ArvadosGetTestCase, self).tearDown()
42         shutil.rmtree(self.tempdir)
43
44     def write_test_collection(self,
45                               strip_manifest=False,
46                               contents = {
47                                   'foo.txt' : 'foo',
48                                   'bar.txt' : 'bar',
49                                   'subdir/baz.txt' : 'baz',
50                               }):
51         api = arvados.api()
52         c = collection.Collection(api_client=api)
53         for path, data in contents.items():
54             with c.open(path, 'wb') as f:
55                 f.write(data)
56         c.save_new()
57
58         api.close_connections()
59
60         return (c.manifest_locator(),
61                 c.portable_data_hash(),
62                 c.manifest_text(strip=strip_manifest))
63
64     def run_get(self, args):
65         self.stdout.seek(0, 0)
66         self.stdout.truncate(0)
67         self.stderr.seek(0, 0)
68         self.stderr.truncate(0)
69         return arv_get.main(args, self.stdout, self.stderr)
70
71     def test_version_argument(self):
72         with tutil.redirected_streams(
73                 stdout=tutil.StringIO, stderr=tutil.StringIO) as (out, err):
74             with self.assertRaises(SystemExit):
75                 self.run_get(['--version'])
76         self.assertVersionOutput(out, err)
77
78     def test_get_single_file(self):
79         # Get the file using the collection's locator
80         r = self.run_get(["{}/subdir/baz.txt".format(self.col_loc), '-'])
81         self.assertEqual(0, r)
82         self.assertEqual(b'baz', self.stdout.getvalue())
83         # Then, try by PDH
84         r = self.run_get(["{}/subdir/baz.txt".format(self.col_pdh), '-'])
85         self.assertEqual(0, r)
86         self.assertEqual(b'baz', self.stdout.getvalue())
87
88     def test_get_block(self):
89         # Get raw data using a block locator
90         blk = re.search(r' (acbd18\S+\+A\S+) ', self.col_manifest).group(1)
91         r = self.run_get([blk, '-'])
92         self.assertEqual(0, r)
93         self.assertEqual(b'foo', self.stdout.getvalue())
94
95     def test_get_multiple_files(self):
96         # Download the entire collection to the temp directory
97         r = self.run_get(["{}/".format(self.col_loc), self.tempdir])
98         self.assertEqual(0, r)
99         with open(os.path.join(self.tempdir, "foo.txt"), "r") as f:
100             self.assertEqual("foo", f.read())
101         with open(os.path.join(self.tempdir, "bar.txt"), "r") as f:
102             self.assertEqual("bar", f.read())
103         with open(os.path.join(self.tempdir, "subdir", "baz.txt"), "r") as f:
104             self.assertEqual("baz", f.read())
105
106     def test_get_collection_unstripped_manifest(self):
107         dummy_token = "+Axxxxxxx"
108         # Get the collection manifest by UUID
109         r = self.run_get([self.col_loc, self.tempdir])
110         self.assertEqual(0, r)
111         m_from_collection = re.sub(r"\+A[0-9a-f@]+", dummy_token, self.col_manifest)
112         with open(os.path.join(self.tempdir, self.col_loc), "r") as f:
113             # Replace manifest tokens before comparison to avoid races
114             m_from_file = re.sub(r"\+A[0-9a-f@]+", dummy_token, f.read())
115             self.assertEqual(m_from_collection, m_from_file)
116         # Get the collection manifest by PDH
117         r = self.run_get([self.col_pdh, self.tempdir])
118         self.assertEqual(0, r)
119         with open(os.path.join(self.tempdir, self.col_pdh), "r") as f:
120             # Replace manifest tokens before comparison to avoid races
121             m_from_file = re.sub(r"\+A[0-9a-f@]+", dummy_token, f.read())
122             self.assertEqual(m_from_collection, m_from_file)
123
124     def test_get_collection_stripped_manifest(self):
125         col_loc, col_pdh, col_manifest = self.write_test_collection(
126             strip_manifest=True)
127         # Get the collection manifest by UUID
128         r = self.run_get(['--strip-manifest', col_loc, self.tempdir])
129         self.assertEqual(0, r)
130         with open(os.path.join(self.tempdir, col_loc), "r") as f:
131             self.assertEqual(col_manifest, f.read())
132         # Get the collection manifest by PDH
133         r = self.run_get(['--strip-manifest', col_pdh, self.tempdir])
134         self.assertEqual(0, r)
135         with open(os.path.join(self.tempdir, col_pdh), "r") as f:
136             self.assertEqual(col_manifest, f.read())
137
138     def test_invalid_collection(self):
139         # Asking for an invalid collection should generate an error.
140         r = self.run_get(['this-uuid-seems-to-be-fake', self.tempdir])
141         self.assertNotEqual(0, r)
142
143     def test_invalid_file_request(self):
144         # Asking for an inexistant file within a collection should generate an error.
145         r = self.run_get(["{}/im-not-here.txt".format(self.col_loc), self.tempdir])
146         self.assertNotEqual(0, r)
147
148     def test_invalid_destination(self):
149         # Asking to place the collection's files on a non existant directory
150         # should generate an error.
151         r = self.run_get([self.col_loc, "/fake/subdir/"])
152         self.assertNotEqual(0, r)
153
154     def test_preexistent_destination(self):
155         # Asking to place a file with the same path as a local one should
156         # generate an error and avoid overwrites.
157         with open(os.path.join(self.tempdir, "foo.txt"), "w") as f:
158             f.write("another foo")
159         r = self.run_get(["{}/foo.txt".format(self.col_loc), self.tempdir])
160         self.assertNotEqual(0, r)
161         with open(os.path.join(self.tempdir, "foo.txt"), "r") as f:
162             self.assertEqual("another foo", f.read())
163
164     def test_no_progress_when_stderr_not_a_tty(self):
165         # Create a collection with a big file (>64MB) to force the progress
166         # to be printed
167         c = collection.Collection()
168         with c.open('bigfile.txt', 'wb') as f:
169             for _ in range(65):
170                 f.write("x" * 1024 * 1024)
171         c.save_new()
172         tmpdir = self.make_tmpdir()
173         # Simulate a TTY stderr
174         stderr = mock.MagicMock()
175         stdout = tutil.BytesIO()
176
177         # Confirm that progress is written to stderr when is a tty
178         stderr.isatty.return_value = True
179         r = arv_get.main(['{}/bigfile.txt'.format(c.manifest_locator()),
180                           '{}/bigfile.txt'.format(tmpdir)],
181                          stdout, stderr)
182         self.assertEqual(0, r)
183         self.assertEqual(b'', stdout.getvalue())
184         self.assertTrue(stderr.write.called)
185
186         # Clean up and reset stderr mock
187         os.remove('{}/bigfile.txt'.format(tmpdir))
188         stderr = mock.MagicMock()
189         stdout = tutil.BytesIO()
190
191         # Confirm that progress is not written to stderr when isn't a tty
192         stderr.isatty.return_value = False
193         r = arv_get.main(['{}/bigfile.txt'.format(c.manifest_locator()),
194                           '{}/bigfile.txt'.format(tmpdir)],
195                          stdout, stderr)
196         self.assertEqual(0, r)
197         self.assertEqual(b'', stdout.getvalue())
198         self.assertFalse(stderr.write.called)
199
200     request_id_regex = r'INFO: X-Request-Id: req-[a-z0-9]{20}\n'
201
202     def test_request_id_logging_on(self):
203         r = self.run_get(["-v", "{}/".format(self.col_loc), self.tempdir])
204         self.assertEqual(0, r)
205         self.assertRegex(self.stderr.getvalue(), self.request_id_regex)
206
207     def test_request_id_logging_off(self):
208         r = self.run_get(["{}/".format(self.col_loc), self.tempdir])
209         self.assertEqual(0, r)
210         self.assertNotRegex(self.stderr.getvalue(), self.request_id_regex)