add util functions, fix up tmp dirs
[arvados.git] / sdk / python / arvados.py
1 import gflags
2 import httplib2
3 import logging
4 import os
5 import pprint
6 import sys
7 import types
8 import subprocess
9 import json
10 import UserDict
11 import re
12 import hashlib
13 import string
14 import bz2
15 import zlib
16 import fcntl
17
18 from apiclient import errors
19 from apiclient.discovery import build
20
class CredentialsFromEnv:
    """Credentials helper that signs every request with the token in
    the ARVADOS_API_TOKEN environment variable, instead of running an
    interactive OAuth2 flow.
    """
    @staticmethod
    def http_request(self, uri, **kwargs):
        # Despite the @staticmethod decorator, authorize() below binds
        # this function to an httplib2.Http object with
        # types.MethodType, so `self` is that Http instance at call
        # time.
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        # Inject the API token into the Authorization header.
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        """Monkey-patch `http` so its request() goes through
        http_request() above; return the patched object."""
        # Stash the original request method, then replace it with the
        # token-injecting wrapper bound to this Http instance.
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http
42
# Module-level API client setup, executed at import time.
# NOTE(review): importing this module contacts the API server (the
# discovery document fetch below) and requires ARVADOS_API_HOST and
# ARVADOS_API_TOKEN in the environment -- confirm this side effect is
# acceptable for all importers.
url = ('https://%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
credentials = CredentialsFromEnv()
http = httplib2.Http()
http = credentials.authorize(http)
# NOTE(review): SSL certificate validation is disabled, presumably for
# clusters with self-signed certificates; this permits MITM -- verify
# before production use.
http.disable_ssl_certificate_validation=True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)
50
def task_set_output(self, s):
    """Record `s` as this task's output and mark the task finished
    (success, progress 100%).

    Bound onto the current_task() UserDict as .set_output(), so `self`
    is the task record.
    """
    attrs = {
        'output': s,
        'success': True,
        'progress': 1.0,
        }
    service.job_tasks().update(uuid=self['uuid'],
                               job_task=json.dumps(attrs)).execute()
58
_current_task = None
def current_task():
    """Return this process's job task record, fetched once from the API
    (uuid from $TASK_UUID) and cached in the module global.

    The record is a UserDict with two extras attached: a bound
    .set_output() method and .tmpdir (from $TASK_WORK).
    """
    global _current_task
    if not _current_task:
        task = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
        task = UserDict.UserDict(task)
        task.set_output = types.MethodType(task_set_output, task)
        task.tmpdir = os.environ['TASK_WORK']
        _current_task = task
    return _current_task
70
_current_job = None
def current_job():
    """Return this process's job record, fetched once from the API
    (uuid from $JOB_UUID) and cached in the module global.

    The record is a UserDict with .tmpdir attached (from $JOB_WORK).
    """
    global _current_job
    if not _current_job:
        job = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
        job = UserDict.UserDict(job)
        job.tmpdir = os.environ['JOB_WORK']
        _current_job = job
    return _current_job
81
def api():
    """Return the shared module-level Arvados API service client."""
    return service
84
class JobTask:
    # Debugging stub: just echoes its constructor arguments.
    # NOTE(review): mutable default arguments (dict()) are normally an
    # anti-pattern; harmless here only because they are never mutated.
    def __init__(self, parameters=dict(), resource_limits=dict()):
        print "init jobtask %s %s" % (parameters, resource_limits)
88
class job_setup:
    """Helpers for fanning a job out into multiple tasks."""
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        """Queue one new job task (at sequence if_sequence+1) for each
        file in this job's input collection.

        Runs only when the current task's sequence equals if_sequence;
        otherwise returns immediately, so later-sequence tasks fall
        through to do the real work.  If and_end_task is true, marks
        the current task successful and exits the process.
        """
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        cr = CollectionReader(job_input)
        for s in cr.all_streams():
            for f in s.all_files():
                # Each new task's input is a one-file manifest.
                task_input = f.as_manifest()
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task_uuid': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input':task_input
                        }
                    }
                service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success':True})
                                       ).execute()
            # Terminate: the work continues in the tasks queued above.
            exit(0)
113
class util:
    """Grab-bag of helpers for crunch job scripts: running commands,
    git checkouts, and extracting Keep collections to local disk."""
    @staticmethod
    def run_command(execargs, **kwargs):
        """Run execargs (an argv list) to completion and return the
        child's (stdout, stderr).  Raises Exception on non-zero exit.

        kwargs are passed through to subprocess.Popen.
        """
        p = subprocess.Popen(execargs, close_fds=True, shell=False,
                             stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             **kwargs)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("run_command %s exit %d:\n%s" %
                            (execargs, p.returncode, stderrdata))
        return stdoutdata, stderrdata

    @staticmethod
    def git_checkout(url, version, path):
        """Clone `url` into `path` (if not already present) and check
        out `version`.  A relative path is taken under the job's temp
        dir.  Returns the checkout path."""
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        if not os.path.exists(path):
            util.run_command(["git", "clone", url, path],
                             cwd=os.path.dirname(path))
        util.run_command(["git", "checkout", version],
                         cwd=path)
        return path

    @staticmethod
    def tarball_extract(tarball, path):
        """Retrieve a tarball from Keep and extract it to a local
        directory.  Return the absolute path where the tarball was
        extracted. If the top level of the tarball contained just one
        file or directory, return the absolute path of that single
        item.

        tarball -- collection locator
        path -- where to extract the tarball: absolute, or relative to job tmp
        """
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        # Serialize concurrent extractions of the same path with an
        # exclusive flock on a sibling .lock file.
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        # A '.locator' symlink inside path records which collection was
        # extracted there last, making repeat extractions no-ops.
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == tarball:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:

            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(tarball).all_files():
                # Choose tar's decompression flag from the extension.
                decompress_flag = ''
                if re.search('\.(tbz|tar.bz2)$', f.name()):
                    decompress_flag = 'j'
                elif re.search('\.(tgz|tar.gz)$', f.name()):
                    decompress_flag = 'z'
                p = subprocess.Popen(["tar",
                                      "-C", path,
                                      ("-x%sf" % decompress_flag),
                                      "-"],
                                     stdout=None,
                                     stdin=subprocess.PIPE, stderr=sys.stderr,
                                     shell=False, close_fds=True)
                # Stream the tarball from Keep into tar's stdin in
                # 1 MiB chunks.
                while True:
                    buf = f.read(2**20)
                    if len(buf) == 0:
                        break
                    p.stdin.write(buf)
                p.stdin.close()
                p.wait()
                if p.returncode != 0:
                    lockfile.close()
                    raise Exception("tar exited %d" % p.returncode)
            os.symlink(tarball, os.path.join(path, '.locator'))
        tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
        lockfile.close()
        if len(tld_extracts) == 1:
            return os.path.join(path, tld_extracts[0])
        return path

    @staticmethod
    def collection_extract(collection, path, files=[]):
        """Retrieve a collection from Keep and extract it to a local
        directory.  Return the absolute path where the collection was
        extracted.

        collection -- collection locator
        path -- where to extract: absolute, or relative to job tmp
        files -- if non-empty, only extract files whose names appear here
        """
        # NOTE(review): mutable default `files=[]` is safe here because
        # it is only read, never mutated.
        if not re.search('^/', path):
            path = os.path.join(current_job().tmpdir, path)
        # Same locking and '.locator' caching scheme as tarball_extract.
        lockfile = open(path + '.lock', 'w')
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            os.stat(path)
        except OSError:
            os.mkdir(path)
        already_have_it = False
        try:
            if os.readlink(os.path.join(path, '.locator')) == collection:
                already_have_it = True
        except OSError:
            pass
        if not already_have_it:
            # emulate "rm -f" (i.e., if the file does not exist, we win)
            try:
                os.unlink(os.path.join(path, '.locator'))
            except OSError:
                if os.path.exists(os.path.join(path, '.locator')):
                    os.unlink(os.path.join(path, '.locator'))

            for f in CollectionReader(collection).all_files():
                if files == [] or f.name() in files:
                    outfile = open(os.path.join(path, f.name()), 'w')
                    while True:
                        buf = f.read(2**20)
                        if len(buf) == 0:
                            break
                        outfile.write(buf)
                    outfile.close()
            os.symlink(collection, os.path.join(path, '.locator'))
        lockfile.close()
        return path
246
class DataReader:
    """Stream the contents of a Keep collection through a
    `whget -r <locator> -` subprocess.

    The child process is started in the constructor; read() consumes
    its stdout and close() reaps it, raising on non-zero exit.
    """
    def __init__(self, data_locator):
        self.data_locator = data_locator
        self.p = subprocess.Popen(["whget", "-r", self.data_locator, "-"],
                                  stdout=subprocess.PIPE,
                                  stdin=None, stderr=subprocess.PIPE,
                                  shell=False, close_fds=True)
    def __enter__(self):
        # BUG FIX: previously returned None, so `with DataReader(x) as r`
        # bound r to None.
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        # BUG FIX: the context-manager protocol calls __exit__ with
        # three exception arguments; the old zero-argument signature
        # raised TypeError on every `with` block exit.
        self.close()
    def read(self, size, **kwargs):
        """Read up to `size` bytes from the subprocess's stdout."""
        return self.p.stdout.read(size, **kwargs)
    def close(self):
        """Close pipes, relay any stderr output, and reap the child.
        Raises Exception if whget exited non-zero."""
        self.p.stdout.close()
        if not self.p.stderr.closed:
            for err in self.p.stderr:
                # BUG FIX: was `print >> sys.stderr, err`, which is
                # Python-2-only syntax and appended an extra newline to
                # each (already newline-terminated) stderr line.
                sys.stderr.write(err)
            self.p.stderr.close()
        self.p.wait()
        if self.p.returncode != 0:
            raise Exception("whget subprocess exited %d" % self.p.returncode)
269
class StreamFileReader:
    """Read one file out of a StreamReader, tracking its own position
    within the file independently of the underlying stream cursor."""
    def __init__(self, stream, pos, size, name):
        self._stream = stream    # parent StreamReader
        self._pos = pos          # file's start offset within the stream
        self._size = size        # file length in bytes
        self._name = name
        self._filepos = 0        # current read position within this file
    def name(self):
        return self._name
    def decompressed_name(self):
        # Filename with any trailing .bz2/.gz extension stripped.
        return re.sub('\.(bz2|gz)$', '', self._name)
    def size(self):
        return self._size
    def stream_name(self):
        return self._stream.name()
    def read(self, size, **kwargs):
        """Read up to `size` bytes, clamped to the end of this file."""
        self._stream.seek(self._pos + self._filepos)
        data = self._stream.read(min(size, self._size - self._filepos))
        self._filepos += len(data)
        return data
    def readall(self, size, **kwargs):
        """Yield chunks of up to `size` bytes until end of file."""
        while True:
            data = self.read(size, **kwargs)
            if data == '':
                break
            yield data
    def bunzip2(self, size):
        """Yield decompressed chunks of a bzip2-compressed file, reading
        `size` compressed bytes at a time."""
        decompressor = bz2.BZ2Decompressor()
        for chunk in self.readall(size):
            data = decompressor.decompress(chunk)
            if data and data != '':
                yield data
    def gunzip(self, size):
        """Yield decompressed chunks of a gzip-compressed file, reading
        `size` compressed bytes at a time."""
        # 16+MAX_WBITS tells zlib to expect a gzip header/trailer.
        decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        for chunk in self.readall(size):
            # NOTE(review): unconsumed_tail is only populated when
            # decompress() is given a max_length argument, so this
            # prefix is presumably always empty here -- confirm.
            data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
            if data and data != '':
                yield data
    def readlines(self, decompress=True):
        """Yield lines (newline-terminated except possibly the last),
        transparently decompressing .bz2/.gz files when `decompress`
        is true."""
        self._stream.seek(self._pos + self._filepos)
        if decompress and re.search('\.bz2$', self._name):
            datasource = self.bunzip2(2**10)
        elif decompress and re.search('\.gz$', self._name):
            datasource = self.gunzip(2**10)
        else:
            datasource = self.readall(2**10)
        # Accumulate chunks, emitting one line at a time and carrying
        # any partial trailing line over to the next chunk.
        data = ''
        for newdata in datasource:
            data += newdata
            sol = 0
            while True:
                eol = string.find(data, "\n", sol)
                if eol < 0:
                    break
                yield data[sol:eol+1]
                sol = eol+1
            data = data[sol:]
        if data != '':
            yield data
    def as_manifest(self):
        """Return a one-file manifest text describing just this file."""
        if self.size() == 0:
            # Zero-length files get the canonical empty-block locator.
            return ("%s d41d8cd98f00b204e9800998ecf8427e+0 0:0:%s\n"
                    % (self._stream.name(), self.name()))
        return string.join(self._stream.tokens_for_range(self._pos, self._size),
                           " ") + "\n"
335
class StreamReader:
    """Read the files described by one manifest stream: a stream name
    token, one or more data block locators, and one or more
    pos:size:name file tokens.
    """
    def __init__(self, tokens):
        self._tokens = tokens
        self._current_datablock_data = None
        self._current_datablock_pos = 0
        self._current_datablock_index = -1
        self._pos = 0

        self._stream_name = None
        self.data_locators = []
        self.files = []

        for tok in self._tokens:
            if self._stream_name == None:
                # The first token names the stream.
                self._stream_name = tok
            elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
                self.data_locators += [tok]
            elif re.search(r'^\d+:\d+:\S+', tok):
                pos, size, name = tok.split(':',2)
                self.files += [[int(pos), int(size), name]]
            else:
                raise Exception("Invalid manifest format")
    def tokens_for_range(self, range_start, range_size):
        """Return manifest tokens covering the byte range
        [range_start, range_start+range_size): the stream name, the
        locators overlapping the range, and all file tokens that
        intersect it (offsets shifted past any skipped blocks)."""
        resp = [self._stream_name]
        return_all_tokens = False
        block_start = 0
        token_bytes_skipped = 0
        for locator in self.data_locators:
            sizehint = re.search(r'\+(\d+)', locator)
            if not sizehint:
                # Without a +size hint we cannot tell which blocks
                # overlap the range, so include every remaining locator.
                return_all_tokens = True
            if return_all_tokens:
                resp += [locator]
                # BUG FIX: this was a bare `next`, which is a no-op
                # expression in Python -- control fell through and
                # crashed on sizehint.group() whenever sizehint was
                # None. `continue` is what was intended.
                continue
            blocksize = int(sizehint.group(1))
            if range_start + range_size <= block_start:
                break
            if range_start < block_start + blocksize:
                resp += [locator]
            else:
                token_bytes_skipped += blocksize
            block_start += blocksize
        for f in self.files:
            if ((f[0] < range_start + range_size)
                and
                (f[0] + f[1] > range_start)
                and
                f[1] > 0):
                resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
        return resp
    def name(self):
        return self._stream_name
    def all_files(self):
        """Yield a StreamFileReader for each file token in the stream."""
        for f in self.files:
            pos, size, name = f
            yield StreamFileReader(self, pos, size, name)
    def nextdatablock(self):
        """Advance the block cursor to the next data block; the block's
        bytes are fetched lazily by current_datablock_data()."""
        if self._current_datablock_index < 0:
            self._current_datablock_pos = 0
            self._current_datablock_index = 0
        else:
            self._current_datablock_pos += self.current_datablock_size()
            self._current_datablock_index += 1
        self._current_datablock_data = None
    def current_datablock_data(self):
        """Fetch (and cache) the current data block's bytes from Keep."""
        if self._current_datablock_data == None:
            self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
        return self._current_datablock_data
    def current_datablock_size(self):
        """Return the current block's size, trusting the +N size hint
        in its locator when present to avoid fetching the data."""
        if self._current_datablock_index < 0:
            self.nextdatablock()
        sizehint = re.search(r'\+(\d+)', self.data_locators[self._current_datablock_index])
        if sizehint:
            return int(sizehint.group(1))
        return len(self.current_datablock_data())
    def seek(self, pos):
        """Set the position of the next read operation."""
        self._pos = pos
    def really_seek(self):
        """Find and load the appropriate data block, so the byte at
        _pos is in memory.
        """
        if self._pos == self._current_datablock_pos:
            return True
        if (self._current_datablock_pos != None and
            self._pos >= self._current_datablock_pos and
            self._pos <= self._current_datablock_pos + self.current_datablock_size()):
            return True
        if self._pos < self._current_datablock_pos:
            # Seeking backwards: restart the block cursor from the top.
            self._current_datablock_index = -1
            self.nextdatablock()
        while (self._pos > self._current_datablock_pos and
               self._pos > self._current_datablock_pos + self.current_datablock_size()):
            self.nextdatablock()
    def read(self, size):
        """Read no more than size bytes -- but at least one byte,
        unless _pos is already at the end of the stream.
        """
        if size == 0:
            return ''
        self.really_seek()
        while self._pos >= self._current_datablock_pos + self.current_datablock_size():
            self.nextdatablock()
            if self._current_datablock_index >= len(self.data_locators):
                return None
        data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
        self._pos += len(data)
        return data
444
class CollectionReader:
    """Read the files in a Keep collection, given either a manifest
    locator (fetched lazily from Keep) or literal manifest text."""
    def __init__(self, manifest_locator_or_text):
        # Heuristic: a well-formed manifest line means we were handed
        # manifest text; anything else is treated as a locator.
        if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
            self._manifest_text = manifest_locator_or_text
            self._manifest_locator = None
        else:
            self._manifest_locator = manifest_locator_or_text
            self._manifest_text = None
        self._streams = None
    def __enter__(self):
        # BUG FIX: return self so `with CollectionReader(x) as cr` binds
        # cr to the reader instead of None.
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        # BUG FIX: accept the three exception arguments the context-
        # manager protocol passes; the old signature raised TypeError.
        pass
    def _populate(self):
        """Parse the manifest into per-stream token lists (fetching the
        manifest text from Keep first if we only have a locator)."""
        if self._streams != None:
            return
        if not self._manifest_text:
            self._manifest_text = Keep.get(self._manifest_locator)
        self._streams = []
        for stream_line in self._manifest_text.split("\n"):
            stream_tokens = stream_line.split()
            # BUG FIX: skip blank lines; the trailing newline of a
            # manifest previously produced an empty token list and
            # hence a bogus StreamReader in all_streams().
            if stream_tokens:
                self._streams += [stream_tokens]
    def all_streams(self):
        self._populate()
        resp = []
        for s in self._streams:
            resp += [StreamReader(s)]
        return resp
    def all_files(self):
        """Yield a StreamFileReader for every file in every stream."""
        for s in self.all_streams():
            for f in s.all_files():
                yield f
477
class CollectionWriter:
    """Build a Keep collection incrementally: write() file data, name
    files and streams, then emit manifest_text() or store it with
    finish()."""
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self._data_buffer = []
        self._data_buffer_len = 0
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = '.'
        self._current_file_name = None
        self._current_file_pos = 0
        self._finished_streams = []
    def __enter__(self):
        # BUG FIX: return self so `with CollectionWriter() as w` binds w.
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        # BUG FIX: accept the three exception arguments the context-
        # manager protocol passes; the old signature raised TypeError.
        self.finish()
    def write(self, newdata):
        """Append newdata to the current file, flushing full blocks to
        Keep as they accumulate."""
        self._data_buffer += [newdata]
        self._data_buffer_len += len(newdata)
        self._current_stream_length += len(newdata)
        while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        """Store one buffered block in Keep, keeping any remainder."""
        data_buffer = ''.join(self._data_buffer)
        if data_buffer != '':
            self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
            # BUG FIX: keep the cached length in sync with the buffer;
            # it was never decremented before, so write()'s flush loop
            # spun forever once a full block had been buffered.
            self._data_buffer_len = len(self._data_buffer[0])
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.set_current_file_name(newfilename)
    def set_current_file_name(self, newfilename):
        # BUG FIX: tolerate None (start_new_file() defaults to it);
        # re.search(pattern, None) raised TypeError.
        if newfilename is not None and re.search(r'[ \t\n]', newfilename):
            raise AssertionError("Manifest filenames cannot contain whitespace")
        self._current_file_name = newfilename
    def current_file_name(self):
        return self._current_file_name
    def finish_current_file(self):
        """Record the current file's (pos, size, name) in the stream."""
        if self._current_file_name == None:
            if self._current_file_pos == self._current_stream_length:
                return
            raise Exception("Cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" % (self._current_stream_length - self._current_file_pos, self._current_file_pos, self._current_stream_name))
        self._current_stream_files += [[self._current_file_pos,
                                       self._current_stream_length - self._current_file_pos,
                                       self._current_file_name]]
        self._current_file_pos = self._current_stream_length
    def start_new_stream(self, newstreamname=None):
        self.finish_current_stream()
        self.set_current_stream_name(newstreamname)
    def set_current_stream_name(self, newstreamname):
        # BUG FIX: tolerate None here too (start_new_stream() defaults
        # to it, and finish_current_stream() resets the name to None).
        if newstreamname is not None and re.search(r'[ \t\n]', newstreamname):
            raise AssertionError("Manifest stream names cannot contain whitespace")
        self._current_stream_name = newstreamname
    def current_stream_name(self):
        return self._current_stream_name
    def finish_current_stream(self):
        """Flush buffered data and move the current stream (if it has
        any files) onto the finished list, then reset stream state."""
        self.finish_current_file()
        self.flush_data()
        if len(self._current_stream_files) == 0:
            pass
        elif self._current_stream_name == None:
            raise Exception("Cannot finish an unnamed stream (%d bytes in %d files)" % (self._current_stream_length, len(self._current_stream_files)))
        else:
            self._finished_streams += [[self._current_stream_name,
                                       self._current_stream_locators,
                                       self._current_stream_files]]
        self._current_stream_files = []
        self._current_stream_length = 0
        self._current_stream_locators = []
        self._current_stream_name = None
        self._current_file_pos = 0
        self._current_file_name = None
    def finish(self):
        """Store the manifest itself in Keep; return its locator."""
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        """Return the manifest text for everything written so far."""
        self.finish_current_stream()
        manifest = ''
        for stream in self._finished_streams:
            manifest += stream[0]
            if len(stream[1]) == 0:
                # A stream with no data blocks gets the canonical
                # empty-block locator.
                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
            else:
                for locator in stream[1]:
                    manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest
565
class Keep:
    """Store and retrieve data blocks via the whput/whget CLI tools,
    with a local-filesystem fallback (used by tests) when
    KEEP_LOCAL_STORE is set in the environment."""
    @staticmethod
    def put(data):
        """Store `data` in Keep and return its locator."""
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_put(data)
        p = subprocess.Popen(["whput", "-"],
                             stdout=subprocess.PIPE,
                             stdin=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(data)
        if p.returncode != 0:
            raise Exception("whput subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        # whput prints the locator (plus a trailing newline) on stdout.
        return stdoutdata.rstrip()
    @staticmethod
    def get(locator):
        """Fetch the block named by `locator`, verifying that the
        locator begins with the md5 of the returned data."""
        if 'KEEP_LOCAL_STORE' in os.environ:
            return Keep.local_store_get(locator)
        p = subprocess.Popen(["whget", locator, "-"],
                             stdout=subprocess.PIPE,
                             stdin=None,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("whget subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        m = hashlib.new('md5')
        m.update(stdoutdata)
        try:
            # index() == 0 means the digest is a prefix of the locator.
            if locator.index(m.hexdigest()) == 0:
                return stdoutdata
        except ValueError:
            pass
        raise Exception("md5 checksum mismatch: md5(get(%s)) == %s" % (locator, m.hexdigest()))
    @staticmethod
    def local_store_put(data):
        """Store `data` under $KEEP_LOCAL_STORE; return md5+size locator."""
        m = hashlib.new('md5')
        m.update(data)
        md5 = m.hexdigest()
        locator = '%s+%d' % (md5, len(data))
        # Write to a .tmp file then rename, so readers never see a
        # partially written block.
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
            f.write(data)
        os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
                  os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
        return locator
    @staticmethod
    def local_store_get(locator):
        """Read the block named by `locator` from $KEEP_LOCAL_STORE."""
        r = re.search('^([0-9a-f]{32,})', locator)
        if not r:
            raise Exception("Keep.get: invalid data locator '%s'" % locator)
        # md5 of the empty string: answer without touching the disk.
        if r.group(0) == 'd41d8cd98f00b204e9800998ecf8427e':
            return ''
        with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
            return f.read()