add CollectionReader.manifest_text()
[arvados.git] / sdk / python / arvados.py
import gflags
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib
import string
import bz2
import zlib
import fcntl

from apiclient import errors
from apiclient.discovery import build
class CredentialsFromEnv:
    @staticmethod
    def http_request(self, uri, **kwargs):
        # authorize() binds this function to an httplib2.Http object,
        # so here "self" is that Http instance, not a CredentialsFromEnv.
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

url = ('https://%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
credentials = CredentialsFromEnv()
http = httplib2.Http()
http = credentials.authorize(http)
http.disable_ssl_certificate_validation = True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)
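
# Example (a sketch, not executed here): the service object exposes the
# discovered Arvados REST API, e.g. fetching a job record ("job_uuid" is
# a hypothetical variable):
#
#   job = service.jobs().get(uuid=job_uuid).execute()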

def task_set_output(self, s):
    service.job_tasks().update(uuid=self['uuid'],
                               job_task=json.dumps({
                                   'output': s,
                                   'success': True,
                                   'progress': 1.0,
                               })).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    t.tmpdir = os.environ['TASK_WORK']
    _current_task = t
    return t

_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
    t = UserDict.UserDict(t)
    t.tmpdir = os.environ['JOB_WORK']
    _current_job = t
    return t

def api():
    return service

class JobTask:
    def __init__(self, parameters=dict(), resource_limits=dict()):
        print "init jobtask %s %s" % (parameters, resource_limits)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                    }
                }
                service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            sys.exit(0)

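# Typical crunch script usage (a sketch, not executed here; it assumes
# the JOB_UUID, TASK_UUID, JOB_WORK and TASK_WORK environment variables
# are set, as the functions above require):
#
#   job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#   this_task = current_task()
#   input_manifest = this_task['parameters']['input']
#   for f in CollectionReader(input_manifest).all_files():
#       ...  # process this task's single input file
#   this_task.set_output(output_locator)  # "output_locator" is hypothetical
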
class util:
    @staticmethod
    def run_command(execargs, **kwargs):
        p = subprocess.Popen(execargs, close_fds=True, shell=False,
                             stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("run_command %s exit %d:\n%s" %
                            (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

    @staticmethod
    def git_checkout(url, version, path):
        """Clone url (unless a clone already exists at path) and check
        out the given version.  Return the absolute path of the working
        tree.

        path -- where to clone: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted.  If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                decompress_flag = ''
                if re.search(r'\.(tbz|tar\.bz2)$', f.name()):
                    decompress_flag = 'j'
                elif re.search(r'\.(tgz|tar\.gz)$', f.name()):
                    decompress_flag = 'z'
                p = subprocess.Popen(["tar",
                                      "-C", path,
                                      ("-x%sf" % decompress_flag),
                                      "-"],
                                     stdout=None,
                                     stdin=subprocess.PIPE, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

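    # Example (a sketch, not executed here): extract a tarball collection
    # under the job's temporary directory ("tarball_locator" is a
    # hypothetical value):
    #
    #   src_dir = util.tarball_extract(tarball_locator, 'src')
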
    @staticmethod
    def collection_extract(collection, path, files=[], decompress=True):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        files -- names of files to extract (default: extract all files)
        decompress -- if True, write .bz2/.gz files under their
        decompressed names, and match requested names against the
        decompressed names too
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection:
                already_have_it = True
        except OSError:
            pass

        files_got = []

        # emulate "rm -f" (i.e., if the file does not exist, we win)
        try:
            os.unlink(os.path.join(path, '.locator'))
        except OSError:
            if os.path.exists(os.path.join(path, '.locator')):
                os.unlink(os.path.join(path, '.locator'))

        for f in CollectionReader(collection).all_files():
            if (files == [] or
                ((f.name() not in files_got) and
                 (f.name() in files or
                  (decompress and f.decompressed_name() in files)))):
                outname = f.decompressed_name() if decompress else f.name()
                files_got += [outname]
                if os.path.exists(os.path.join(path, outname)):
                    continue
                outfile = open(os.path.join(path, outname), 'w')
                for buf in (f.readall_decompressed() if decompress
                            else f.readall()):
                    outfile.write(buf)
                outfile.close()
        if len(files_got) < len(files):
            raise Exception("Wanted files %s but only got %s from %s" %
                            (files, files_got,
                             map(lambda z: z.name(),
                                 list(CollectionReader(collection).all_files()))))
        os.symlink(collection, os.path.join(path, '.locator'))

        lockfile.close()
        return path

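# Example (a sketch, not executed here): fetch two named files from a
# collection ("collection_locator" is a hypothetical value):
#
#   refdir = util.collection_extract(collection_locator, 'refdata',
#                                    files=['genome.fa', 'genome.fa.fai'])
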
class DataReader:
    def __init__(self, data_locator):
        self.data_locator = data_locator
        self.p = subprocess.Popen(["whget", "-r", self.data_locator, "-"],
                                  stdout=subprocess.PIPE,
                                  stdin=None, stderr=subprocess.PIPE,
                                  shell=False, close_fds=True)
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
    def read(self, size, **kwargs):
        return self.p.stdout.read(size, **kwargs)
    def close(self):
        self.p.stdout.close()
        if not self.p.stderr.closed:
            for err in self.p.stderr:
                print >> sys.stderr, err
            self.p.stderr.close()
        self.p.wait()
        if self.p.returncode != 0:
            raise Exception("whget subprocess exited %d" % self.p.returncode)

class StreamFileReader:
    def __init__(self, stream, pos, size, name):
        self._stream = stream
        self._pos = pos
        self._size = size
        self._name = name
        self._filepos = 0
    def name(self):
        return self._name
    def decompressed_name(self):
        return re.sub(r'\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size=2**20, **kwargs):
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        # 16+MAX_WBITS tells zlib to expect and skip a gzip header
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readall_decompressed(self, size=2**20):
        self._stream.seek(self._pos + self._filepos)
        if re.search(r'\.bz2$', self._name):
            return self.bunzip2(size)
        elif re.search(r'\.gz$', self._name):
            return self.gunzip(size)
        else:
            return self.readall(size)
    def readlines(self, decompress=True):
        if decompress:
            datasource = self.readall_decompressed()
        else:
            self._stream.seek(self._pos + self._filepos)
            datasource = self.readall()
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        if self.size() == 0:
            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
                    % (self._stream.name(), self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"

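# Example (a sketch, not executed here): as_manifest() returns a one-file
# manifest fragment; a zero-length file "foo" in the "." stream becomes:
#
#   . d41d8cd98f00b204e9800998ecf8427e+0 0:0:foo
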
class StreamReader:
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name == None:
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':', 2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise Exception("Invalid manifest format")
    def tokens_for_range(self, range_start, range_size):
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                continue
            blocksize = int(sizehint.group(1))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        if self._current_datablock_data == None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search(r'\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(1))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos != None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data

class CollectionReader:
    def __init__(self, manifest_locator_or_text):
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def _populate(self):
        if self._streams != None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            stream_tokens = stream_line.split()
            # skip empty lines (e.g., the one after the trailing newline)
            if stream_tokens:
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        for s in self.all_streams():
            for f in s.all_files():
                yield f
    def manifest_text(self):
        self._populate()
        return self._manifest_text

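# Example (a sketch, not executed here): CollectionReader accepts either
# a manifest locator or literal manifest text ("input_locator" is a
# hypothetical variable):
#
#   cr = CollectionReader(input_locator)
#   for f in cr.all_files():
#       print f.stream_name(), f.name(), f.size()
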
class CollectionWriter:
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.finish()
    def write(self, newdata):
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        if re.search(r'[ \t\n]', newfilename):
            raise AssertionError("Manifest filenames cannot contain whitespace")
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        if self._current_file_name == None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" %
                            (self._current_stream_length - self._current_file_pos,
                             self._current_file_pos,
                             self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                        self._current_stream_length - self._current_file_pos,
                                        self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname=None):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        if re.search(r'[ \t\n]', newstreamname):
            raise AssertionError("Manifest stream names cannot contain whitespace")
        self._current_stream_name = newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name == None:
            raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" %
                            (self._current_stream_length, len(self._current_stream_files)))
        else:
            self._finished_streams += [[self._current_stream_name,
                                        self._current_stream_locators,
                                        self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            manifest += stream[0]
            if len(stream[1]) == 0:
                # a stream with no data: use the locator of the empty block
                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
            else:
                for locator in stream[1]:
                    manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest

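# Example (a sketch, not executed here): build a two-file collection in
# the default "." stream and store its manifest in Keep:
#
#   cw = CollectionWriter()
#   cw.start_new_file('a.txt')
#   cw.write('alpha\n')
#   cw.start_new_file('b.txt')
#   cw.write('beta\n')
#   manifest_locator = cw.finish()
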
class Keep:
    @staticmethod
    def put(data):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_put(data)
        p = subprocess.Popen(["whput", "-"],
                             stdout=subprocess.PIPE,
                             stdin=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(data)
        if p.returncode != 0:
            raise Exception("whput subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        return stdoutdata.rstrip()
    @staticmethod
    def get(locator):
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_get(locator)
        p = subprocess.Popen(["whget", locator, "-"],
                             stdout=subprocess.PIPE,
                             stdin=None,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("whget subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        m = hashlib.new('md5')
        m.update(stdoutdata)
        try:
            if locator.index(m.hexdigest()) == 0:
                return stdoutdata
        except ValueError:
            pass
        raise Exception("md5 checksum mismatch: md5(get(%s)) == %s" % (locator, m.hexdigest()))
    @staticmethod
    def local_store_put(data):
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
    @staticmethod
    def local_store_get(locator):
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise Exception("Keep.get: invalid data locator '%s'" % locator)
        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()
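
# Example (a sketch, not executed here): Keep.put/get round-trip through
# the local store used for testing.  "/tmp/keep" is a hypothetical path
# that must already exist:
#
#   os.environ['KEEP_LOCAL_STORE'] = '/tmp/keep'
#   locator = Keep.put('hello')    # "5d41402abc4b2a76b9719d911017c592+5"
#   assert Keep.get(locator) == 'hello'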