add transparent bz2 decompression, some tests and fixes
[arvados.git] / sdk / python / arvados.py
import gflags
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2

from apiclient import errors
from apiclient.discovery import build

class CredentialsFromEnv:
    # http_request is declared @staticmethod so it stays a plain function:
    # authorize() rebinds it to an httplib2.Http instance, at which point
    # "self" refers to that Http object rather than to CredentialsFromEnv.
    @staticmethod
    def http_request(self, uri, **kwargs):
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

url = ('https://%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
credentials = CredentialsFromEnv()
http = httplib2.Http()
http = credentials.authorize(http)
# Note: this disables SSL certificate validation, so the identity of the
# API host is not verified.
http.disable_ssl_certificate_validation = True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)

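# Example (a sketch, not part of this module): the API endpoint and
# credentials are read from the environment at import time, so the host
# and token below (made-up placeholders) must be set before importing:
#
#     export ARVADOS_API_HOST=api.example.com
#     export ARVADOS_API_TOKEN=your_api_token_here
#     python my_crunch_script.py
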
def task_set_output(self, s):
    service.job_tasks().update(uuid=self['uuid'],
                               job_task=json.dumps({
                                   'output': s,
                                   'success': True,
                                   'progress': 1.0
                               })).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    _current_task = t
    return t

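# Example (a sketch; task_output_locator is a placeholder): a crunch
# script records its result by setting the current task's output, which
# marks the task successful and complete:
#
#     this_task = current_task()
#     this_task.set_output(task_output_locator)
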
_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
    _current_job = t
    return t

def api():
    return service

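# Example (a sketch; some_job_uuid is a placeholder): api() exposes the
# discovery-built client directly, for calls this module doesn't wrap:
#
#     job = api().jobs().get(uuid=some_job_uuid).execute()
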
class JobTask:
    def __init__(self, parameters=dict(), resource_limits=dict()):
        print "init jobtask %s %s" % (parameters, resource_limits)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                    }
                }
                service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            sys.exit(0)

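# Example (a sketch of a typical crunch script; output_locator is a
# placeholder): queue one task per input file at sequence 0, then let the
# per-file tasks at sequence 1 do the real work:
#
#     job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#     this_task = current_task()
#     input_locator = this_task['parameters']['input']
#     # ... process input_locator, then:
#     this_task.set_output(output_locator)
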
class DataReader:
    def __init__(self, data_locator):
        self.data_locator = data_locator
        self.p = subprocess.Popen(["whget", "-r", self.data_locator, "-"],
                                  stdout=subprocess.PIPE,
                                  stdin=None, stderr=subprocess.PIPE,
                                  shell=False, close_fds=True)
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
    def read(self, size, **kwargs):
        return self.p.stdout.read(size, **kwargs)
    def close(self):
        self.p.stdout.close()
        if not self.p.stderr.closed:
            for err in self.p.stderr:
                print >> sys.stderr, err
            self.p.stderr.close()
        self.p.wait()
        if self.p.returncode != 0:
            raise Exception("whget subprocess exited %d" % self.p.returncode)

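# Example (a sketch; some_locator is a placeholder): DataReader supports
# the "with" protocol, so close() -- and the whget exit-status check --
# runs even if the body raises:
#
#     with DataReader(some_locator) as reader:
#         chunk = reader.read(2**20)
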
class StreamFileReader:
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        """Decompress this file on the fly, yielding chunks of
        decompressed data as they become available."""
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data:
                yield data
    def readlines(self, decompress=True):
        """Yield one line at a time. If decompress is True and the file
        name ends in .bz2, decompress the data transparently first."""
        self._stream.seek(self._pos + self._filepos)
        if decompress and re.search(r'\.bz2$', self._name):
            datasource = self.bunzip2(2**10)
        else:
            datasource = self.readall(2**10)
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = data.find("\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
                    % (self._stream.name(), self.name()))
        return ' '.join(self._stream.tokens_for_range(self._pos, self._size)) + "\n"

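# Example (a sketch; the collection contents are made up): readlines()
# decompresses *.bz2 files transparently, so the same loop handles
# compressed and uncompressed inputs:
#
#     for f in CollectionReader(job_input).all_files():
#         for line in f.readlines():   # bunzip2 applied if name ends in .bz2
#             process(line)
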
class StreamReader:
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name == None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise Exception("Invalid manifest format")
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                # Without a size hint we can't tell which blocks are in
                # range, so return this and all remaining locators.
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(1))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data is None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search(r'\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(1))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos != None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data

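# Example (uses the real md5 of "foo"; assumes that block is already in
# Keep, e.g. written by Keep.put("foo")): a StreamReader is built from the
# tokens of one manifest line, and reads fetch blocks via Keep.get():
#
#     tokens = ". acbd18db4cc2f85cedef654fccc4a4d8+3 0:3:foo.txt".split()
#     s = StreamReader(tokens)
#     s.seek(0)
#     s.read(3)     # => "foo"
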
class CollectionReader:
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def _populate(self):
        if self._streams != None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f

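# Example (a sketch; collection_locator is a placeholder): iterate every
# file in a collection without caring how it is split into streams:
#
#     cr = CollectionReader(collection_locator)
#     for f in cr.all_files():
#         print >> sys.stderr, "%s/%s: %d bytes" % (f.stream_name(), f.name(), f.size())
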
class CollectionWriter:
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = ''
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.finish()
    def write(self, newdata):
        self._data_buffer += newdata
        self._current_stream_length += len(newdata)
        while len(self._data_buffer) >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        if self._data_buffer != '':
            self._current_stream_locators += [Keep.put(self._data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = self._data_buffer[self.KEEP_BLOCK_SIZE:]
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        if re.search(r'[ \t\n]', newfilename):
            raise AssertionError("Manifest filenames cannot contain whitespace")
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name == None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname=None):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise AssertionError("Manifest stream names cannot contain whitespace")
        self._current_stream_name = newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name == None:
            raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
        else:
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            manifest += stream[0]
            if len(stream[1]) == 0:
                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
            else:
                for locator in stream[1]:
                    manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest

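# Example (a sketch; the file contents are made up): write two files into
# the default "." stream and store the result in Keep; finish() returns
# the locator of the stored manifest:
#
#     cw = CollectionWriter()
#     cw.start_new_file('foo.txt')
#     cw.write("foo")
#     cw.start_new_file('bar.txt')
#     cw.write("bar\n")
#     output_locator = cw.finish()
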
class Keep:
    @staticmethod
    def put(data):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_put(data)
        p = subprocess.Popen(["whput", "-"],
                             stdout=subprocess.PIPE,
                             stdin=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(data)
        if p.returncode != 0:
            raise Exception("whput subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        return stdoutdata.rstrip()
    @staticmethod
    def get(locator):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_get(locator)
        p = subprocess.Popen(["whget", locator, "-"],
                             stdout=subprocess.PIPE,
                             stdin=None,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("whget subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        m = hashlib.new('md5')
        m.update(stdoutdata)
        if not locator.startswith(m.hexdigest()):
            raise Exception("md5 checksum mismatch: md5(get(%s)) == %s" % (locator, m.hexdigest()))
        return stdoutdata
    # When KEEP_LOCAL_STORE is set in the environment, blocks are written
    # to and read from files in that directory instead of going through
    # whput/whget -- handy for tests.
    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
    @staticmethod
    def local_store_get(locator):
        r = re.search(r'^([0-9a-f]{32,})', locator)
        if not r:
            raise Exception("Keep.get: invalid data locator '%s'" % locator)
        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()
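
# Example (a sketch, for tests): with KEEP_LOCAL_STORE pointing at a
# scratch directory (which must already exist), put()/get() round-trip
# without any Keep servers; the hash below is the real md5 of "foo":
#
#     os.environ['KEEP_LOCAL_STORE'] = '/tmp/keep-test'
#     locator = Keep.put("foo")      # => "acbd18db4cc2f85cedef654fccc4a4d8+3"
#     Keep.get(locator)              # => "foo"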