fixes and docs for testing crunch jobs locally
import gflags
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib

from apiclient import errors
from apiclient.discovery import build

class CredentialsFromEnv:
    @staticmethod
    def http_request(self, uri, **kwargs):
        # Declared @staticmethod, but authorize() binds it to an
        # httplib2.Http instance with types.MethodType, so `self` here is
        # the Http object, not a CredentialsFromEnv.
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

url = ('https://%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
credentials = CredentialsFromEnv()
http = httplib2.Http()
http = credentials.authorize(http)
# Certificate validation is disabled so API calls work against
# development/test clusters with self-signed certificates.
http.disable_ssl_certificate_validation = True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)
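
# This module talks to the API server as a side effect of being imported, so
# ARVADOS_API_HOST and ARVADOS_API_TOKEN must be set first. A minimal sketch
# for local testing (values are placeholders):
#
#   export ARVADOS_API_HOST=api.example.com
#   export ARVADOS_API_TOKEN=your-api-token
#
# Crunch also provides JOB_UUID, TASK_UUID, TASK_TMPDIR and CRUNCH_WORK to
# task scripts; current_task() and current_job() below rely on them.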

def task_set_output(self, s):
    service.job_tasks().update(uuid=self['uuid'],
                               job_task=json.dumps({'output': s,
                                                    'success': True,
                                                    'progress': 1.0
                                                    })).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_TMPDIR']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    job = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
    job = UserDict.UserDict(job)
    job.tmpdir = os.environ['CRUNCH_WORK']
    _current_job = job
    return job

def api():
    return service
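
# Hypothetical shape of a crunch task script built on the helpers above:
#
#   this_job = current_job()
#   this_task = current_task()
#   input_manifest = this_task['parameters']['input']
#   ... do work ...
#   this_task.set_output(output_locator)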

class JobTask:
    def __init__(self, parameters=dict(), resource_limits=dict()):
        print "init jobtask %s %s" % (parameters, resource_limits)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                        }
                    }
                service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            sys.exit(0)
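
# Sketch of the usual fan-out pattern (sequence numbers are illustrative):
# the sequence-0 task queues one sequence-1 task per input file and exits;
# each sequence-1 task then finds its own file in its task parameters.
#
#   job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#   manifest_for_this_task = current_task()['parameters']['input']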

class util:
    @staticmethod
    def run_command(execargs, **kwargs):
        p = subprocess.Popen(execargs, close_fds=True, shell=False,
                             stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("run_command %s exit %d:\n%s" %
                            (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

    @staticmethod
    def git_checkout(url, version, path):
        if not re.search(r'^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path
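
# Hypothetical usage (a relative path is placed under the job's work
# directory; `version` can be any committish):
#
#   src_dir = util.git_checkout("git://example.com/repo.git", "master", "repo")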

class DataReader:
    def __init__(self, data_locator):
        self.data_locator = data_locator
        self.p = subprocess.Popen(["whget", "-r", self.data_locator, "-"],
                                  stdout=subprocess.PIPE,
                                  stdin=None, stderr=subprocess.PIPE,
                                  shell=False, close_fds=True)
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
    def read(self, size, **kwargs):
        return self.p.stdout.read(size, **kwargs)
    def close(self):
        self.p.stdout.close()
        if not self.p.stderr.closed:
            for err in self.p.stderr:
                print >> sys.stderr, err
            self.p.stderr.close()
        self.p.wait()
        if self.p.returncode != 0:
            raise Exception("whget subprocess exited %d" % self.p.returncode)

class StreamFileReader:
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub(r'\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        # 16+MAX_WBITS tells zlib to expect (and skip) a gzip header/trailer.
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readlines(self, decompress=True):
        self._stream.seek(self._pos + self._filepos)
        if decompress and re.search(r'\.bz2$', self._name):
            datasource = self.bunzip2(2**10)
        elif decompress and re.search(r'\.gz$', self._name):
            datasource = self.gunzip(2**10)
        else:
            datasource = self.readall(2**10)
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            # A zero-length file has no byte range in any data block; use the
            # locator for the empty block (md5 of the empty string).
            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
                    % (self._stream.name(), self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"
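
# Each stream in a manifest is one line of space-separated tokens: the stream
# name, one or more data block locators (an md5 hash, usually followed by a
# +size hint), and one or more pos:size:filename triples giving each file's
# byte range within the concatenated blocks. A hypothetical one-file stream:
#
#   . acbd18db4cc2f85cedef654fccc4a4d8+3 0:3:foo.txt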

class StreamReader:
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name == None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise Exception("Invalid manifest format")
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                # Without a size hint we can't tell which blocks overlap the
                # requested range, so fall back to returning every locator.
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(1))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data == None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search(r'\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(1))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos != None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data

class CollectionReader:
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def _populate(self):
        if self._streams != None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            # Skip the empty string produced by the trailing newline.
            if stream_line != '':
                stream_tokens = stream_line.split()
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
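
# Hypothetical usage, reading every line of every file in a collection:
#
#   cr = CollectionReader(manifest_locator)
#   for f in cr.all_files():
#       for line in f.readlines():
#           process(line)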

class CollectionWriter:
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.finish()
    def write(self, newdata):
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            # Keep the cached length in sync with the remainder; otherwise
            # the flush loop in write() never terminates.
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        if re.search(r'[ \t\n]', newfilename):
            raise AssertionError("Manifest filenames cannot contain whitespace")
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name == None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" %
                            (self._current_stream_length - self._current_file_pos,
                             self._current_file_pos,
                             self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname=None):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise AssertionError("Manifest stream names cannot contain whitespace")
        self._current_stream_name = newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name == None:
            raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" %
                            (self._current_stream_length, len(self._current_stream_files)))
        else:
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            manifest += stream[0]
            if len(stream[1]) == 0:
                # No data blocks: use the locator for the empty block.
                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
            else:
                for locator in stream[1]:
                    manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest
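
# Hypothetical usage, writing a small two-file collection and getting its
# manifest locator back:
#
#   cw = CollectionWriter()
#   cw.start_new_file('foo.txt')
#   cw.write('foo')
#   cw.start_new_file('bar.txt')
#   cw.write('bar')
#   locator = cw.finish()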

class Keep:
    @staticmethod
    def put(data):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_put(data)
        p = subprocess.Popen(["whput", "-"],
                             stdout=subprocess.PIPE,
                             stdin=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(data)
        if p.returncode != 0:
            raise Exception("whput subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        return stdoutdata.rstrip()
    @staticmethod
    def get(locator):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_get(locator)
        p = subprocess.Popen(["whget", locator, "-"],
                             stdout=subprocess.PIPE,
                             stdin=None,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("whget subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        # Verify that the data we got back actually hashes to the locator we
        # asked for before returning it.
        m = hashlib.new('md5')
        m.update(stdoutdata)
        try:
            if locator.index(m.hexdigest()) == 0:
                return stdoutdata
        except ValueError:
            pass
        raise Exception("md5 checksum mismatch: md5(get(%s)) == %s" % (locator, m.hexdigest()))
    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        # Write to a temporary file and rename, so a concurrent reader never
        # sees a partially written block.
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
    @staticmethod
    def local_store_get(locator):
        r = re.search(r'^([0-9a-f]{32,})', locator)
        if not r:
            raise Exception("Keep.get: invalid data locator '%s'" % locator)
        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()
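
# KEEP_LOCAL_STORE is what makes it possible to test crunch jobs locally:
# when it is set, put() and get() read and write blocks in that directory
# instead of shelling out to whput/whget. A minimal sketch:
#
#   $ mkdir /tmp/keep
#   $ export KEEP_LOCAL_STORE=/tmp/keep
#   ...then run your job/task script with the usual environment variables set.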