Merge branch '19982-spot-instance' refs #19982
[arvados.git] / sdk / python / tests / test_cmd_util.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
import contextlib
import copy
import itertools
import json
import os
import re
import tempfile
import unittest

from pathlib import Path

from parameterized import parameterized

import arvados.commands._util as cmd_util
18
# Path to this test module. The JSON argument tests use it as a readable file
# whose contents are not JSON, and derive a directory and an invalid path from it.
FILE_PATH = Path(__file__)
20
class ValidateFiltersTestCase(unittest.TestCase):
    """Tests for cmd_util.validate_filters.

    Each negative test checks that a ValueError is raised whose message
    begins with a specific description of the problem.
    """
    # Values that are neither strings nor lists.
    NON_FIELD_TYPES = [
        None,
        123,
        ('name', '=', 'tuple'),
        {'filters': ['name', '=', 'object']},
    ]
    # Values that are not lists (a bare string is not a valid filters value).
    NON_FILTER_TYPES = NON_FIELD_TYPES + ['string']
    # Accepted filters: list form [field, operator, operand] and string form
    # '(field op field)', with or without spaces around the operator.
    VALID_FILTERS = [
        ['owner_uuid', '=', 'zzzzz-tpzed-12345abcde67890'],
        ['name', 'in', ['foo', 'bar']],
        '(replication_desired > replication_confirmed)',
        '(replication_confirmed>=replication_desired)',
    ]

    @parameterized.expand(itertools.combinations(VALID_FILTERS, 2))
    def test_valid_filters(self, f1, f2):
        """Every pair of valid filters validates and comes back equal to the input."""
        expected = [f1, f2]
        # Pass a copy so an in-place change by the validator cannot also
        # change the expected value.
        actual = cmd_util.validate_filters(copy.deepcopy(expected))
        self.assertEqual(actual, expected)

    @parameterized.expand([(t,) for t in NON_FILTER_TYPES])
    def test_filters_wrong_type(self, value):
        """The filters value itself must be a list."""
        with self.assertRaisesRegex(ValueError, r'^filters are not a list\b'):
            cmd_util.validate_filters(value)

    @parameterized.expand([(t,) for t in NON_FIELD_TYPES])
    def test_single_filter_wrong_type(self, value):
        """Each filter must be a string or a list."""
        with self.assertRaisesRegex(ValueError, r'^filter at index 0 is not a string or list\b'):
            cmd_util.validate_filters([value])

    @parameterized.expand([
        ([],),
        (['owner_uuid'],),
        (['owner_uuid', 'zzzzz-tpzed-12345abcde67890'],),
        (['name', 'not in', 'foo', 'bar'],),
        (['name', 'in', 'foo', 'bar', 'baz'],),
    ])
    def test_filters_wrong_arity(self, value):
        """A list filter must have exactly three items."""
        with self.assertRaisesRegex(ValueError, r'^filter at index 0 does not have three items\b'):
            cmd_util.validate_filters([value])

    @parameterized.expand(itertools.product(
        [0, 1],
        NON_FIELD_TYPES,
    ))
    def test_filter_definition_wrong_type(self, index, bad_value):
        """The field name (index 0) and operator (index 1) must be strings."""
        value = ['owner_uuid', '=', 'zzzzz-tpzed-12345abcde67890']
        value[index] = bad_value
        name = ('field name', 'operator')[index]
        with self.assertRaisesRegex(ValueError, rf'^filter at index 0 {name} is not a string\b'):
            cmd_util.validate_filters([value])

    @parameterized.expand([
        # Not enclosed in parentheses
        'foo = bar',
        '(foo) < bar',
        'foo > (bar)',
        # Not exactly one operator
        '(a >= b >= c)',
        '(foo)',
        '(file_count version)',
        # Invalid field identifiers
        '(version = 1)',
        '(2 = file_count)',
        '(replication.desired <= replication.confirmed)',
        # Invalid whitespace
        '(file_count\t=\tversion)',
        '(file_count >= version\n)',
    ])
    def test_invalid_string_filter(self, value):
        """Malformed string filters are rejected as invalid syntax."""
        with self.assertRaisesRegex(ValueError, r'^filter at index 0 has invalid syntax\b'):
            cmd_util.validate_filters([value])
94
95
class JSONArgumentTestCase(unittest.TestCase):
    """Tests for cmd_util.JSONArgument constructed without a validator.

    The parser is called with either a literal JSON string or a path to a
    file containing JSON. These tests cover both forms and the error
    messages produced when an argument cannot be used.
    """
    # One sample of each JSON value kind, used as round-trip fixtures.
    JSON_OBJECTS = [
        None,
        123,
        456.789,
        'string',
        ['list', 1],
        {'object': True, 'yaml': False},
    ]

    @classmethod
    def setUpClass(cls):
        # A single scratch file is shared by every test in the class;
        # setUp empties it before each test.
        cls.json_file = tempfile.NamedTemporaryFile(
            'w+',
            encoding='utf-8',
            prefix='argtest',
            suffix='.json',
        )
        cls.parser = cmd_util.JSONArgument()

    @classmethod
    def tearDownClass(cls):
        cls.json_file.close()

    def setUp(self):
        # Start each test with an empty scratch file.
        self.json_file.seek(0)
        self.json_file.truncate()

    @parameterized.expand((obj,) for obj in JSON_OBJECTS)
    def test_valid_argument_string(self, obj):
        """A JSON string argument is decoded to an equal object."""
        actual = self.parser(json.dumps(obj))
        self.assertEqual(actual, obj)

    @parameterized.expand((obj,) for obj in JSON_OBJECTS)
    def test_valid_argument_path(self, obj):
        """A path to a file containing JSON is read and decoded."""
        json.dump(obj, self.json_file)
        # Flush so the data is on disk when the parser opens the path.
        self.json_file.flush()
        actual = self.parser(self.json_file.name)
        self.assertEqual(actual, obj)

    @parameterized.expand([
        '',
        '\0',
        None,
    ])
    def test_argument_not_json_or_path(self, value):
        """A value that is neither valid JSON nor a usable path is rejected.

        ``None`` is a placeholder for the name of a temporary file that has
        already been deleted.
        """
        if value is None:
            with tempfile.NamedTemporaryFile() as gone_file:
                value = gone_file.name
        with self.assertRaisesRegex(ValueError, r'\bnot a valid JSON string or file path\b'):
            self.parser(value)

    @parameterized.expand([
        FILE_PATH.parent,
        FILE_PATH / 'nonexistent.json',
        None,
    ])
    def test_argument_path_unreadable(self, path):
        """Paths that cannot be read as a file report a read error.

        Cases: a directory, a path nested under a regular file, and (``None``)
        a temporary file whose permissions are set to 0o000.
        """
        with contextlib.ExitStack() as cleanup:
            if path is None:
                bad_file = cleanup.enter_context(tempfile.NamedTemporaryFile())
                os.chmod(bad_file.fileno(), 0o000)
                # ExitStack unwinds in LIFO order, so permissions are restored
                # before the temporary file is closed and removed.
                cleanup.callback(os.chmod, bad_file.fileno(), 0o600)
                path = bad_file.name
                if os.access(path, os.R_OK):
                    # For example, when running as root, mode 0o000 does not
                    # prevent reads, so this case cannot be exercised.
                    self.skipTest("cannot make a file unreadable for the current user")
            # Escape the path: file names contain '.' and may contain other
            # regex metacharacters.
            expected = '^' + re.escape(f'error reading JSON file path {str(path)!r}: ')
            with self.assertRaisesRegex(ValueError, expected):
                self.parser(str(path))

    @parameterized.expand([
        FILE_PATH,
        None,
    ])
    def test_argument_path_not_json(self, path):
        """A readable file whose contents are not JSON reports a decode error.

        ``None`` stands for the shared scratch file, which setUp leaves empty.
        """
        if path is None:
            path = self.json_file.name
        expected = '^' + re.escape(f'error decoding JSON from file {str(path)!r}')
        with self.assertRaisesRegex(ValueError, expected):
            self.parser(str(path))
178
179
class JSONArgumentValidationTestCase(unittest.TestCase):
    """Tests for how cmd_util.JSONArgument uses an optional validator callable."""

    @parameterized.expand((obj,) for obj in JSONArgumentTestCase.JSON_OBJECTS)
    def test_object_returned_from_validator(self, value):
        """The parser's result is whatever the validator returns."""
        def validator(_decoded):
            # Return a fresh copy so the comparison below cannot pass merely
            # because both sides are the same object.
            return copy.deepcopy(value)

        json_parser = cmd_util.JSONArgument(validator)
        result = json_parser('{}')
        self.assertEqual(result, value)

    @parameterized.expand((obj,) for obj in JSONArgumentTestCase.JSON_OBJECTS)
    def test_exception_raised_from_validator(self, value):
        """A ValueError from the validator surfaces with the same args."""
        encoded = json.dumps(value)

        def failing_validator(_decoded):
            raise ValueError(encoded)

        json_parser = cmd_util.JSONArgument(failing_validator)
        with self.assertRaises(ValueError) as caught:
            json_parser(encoded)
        self.assertEqual(caught.exception.args, (encoded,))