Added pip requirements.txt to automatically install Python dependencies for Arvados...
[arvados.git] / sdk / python / arvados / __init__.py
1 import gflags
2 import httplib
3 import httplib2
4 import logging
5 import os
6 import pprint
7 import sys
8 import types
9 import subprocess
10 import json
11 import UserDict
12 import re
13 import hashlib
14 import string
15 import bz2
16 import zlib
17 import fcntl
18 import time
19 import threading
20
21 from apiclient import errors
22 from apiclient.discovery import build
23
24 if 'ARVADOS_DEBUG' in os.environ:
25     logging.basicConfig(level=logging.DEBUG)
26
27 EMPTY_BLOCK_LOCATOR = 'd41d8cd98f00b204e9800998ecf8427e+0'
28
29 services = {}
30
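# Note: this local exception hierarchy shadows the `errors` module imported
# from apiclient above; within this file, `errors.*` refers to the classes
# defined here.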
31 class errors:
32     class SyntaxError(Exception):
33         pass
34     class AssertionError(Exception):
35         pass
36     class NotFoundError(Exception):
37         pass
38     class CommandFailedError(Exception):
39         pass
40     class KeepWriteError(Exception):
41         pass
42     class NotImplementedError(Exception):
43         pass
44
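# CredentialsFromEnv mimics just enough of the oauth2client credentials
# interface for apiclient's build(): authorize() rebinds http_request() onto
# the httplib2.Http object, so the `self` parameter below is that Http
# instance (not a CredentialsFromEnv), and every request gets an
# Authorization header built from ARVADOS_API_TOKEN.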
45 class CredentialsFromEnv(object):
46     @staticmethod
47     def http_request(self, uri, **kwargs):
48         from httplib import BadStatusLine
49         if 'headers' not in kwargs:
50             kwargs['headers'] = {}
51         kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
52         try:
53             return self.orig_http_request(uri, **kwargs)
54         except BadStatusLine:
55             # This is how httplib tells us that it tried to reuse an
56             # existing connection but it was already closed by the
57             # server. In that case, yes, we would like to retry.
58             # Unfortunately, we are not absolutely certain that the
59             # previous call did not succeed, so this is slightly
60             # risky.
61             return self.orig_http_request(uri, **kwargs)
62     def authorize(self, http):
63         http.orig_http_request = http.request
64         http.request = types.MethodType(self.http_request, http)
65         return http
66
67 def task_set_output(self,s):
68     api('v1').job_tasks().update(uuid=self['uuid'],
69                                  body={
70             'output':s,
71             'success':True,
72             'progress':1.0
73             }).execute()
74
75 _current_task = None
76 def current_task():
77     global _current_task
78     if _current_task:
79         return _current_task
80     t = api('v1').job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
81     t = UserDict.UserDict(t)
82     t.set_output = types.MethodType(task_set_output, t)
83     t.tmpdir = os.environ['TASK_WORK']
84     _current_task = t
85     return t
86
87 _current_job = None
88 def current_job():
89     global _current_job
90     if _current_job:
91         return _current_job
92     t = api('v1').jobs().get(uuid=os.environ['JOB_UUID']).execute()
93     t = UserDict.UserDict(t)
94     t.tmpdir = os.environ['JOB_WORK']
95     _current_job = t
96     return t
97
98 def getjobparam(*args):
99     return current_job()['script_parameters'].get(*args)
100
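# A minimal sketch of how a crunch script might use the helpers above. The
# 'input' parameter name and the result data are placeholders, not anything
# this SDK requires.
#
#     import arvados
#     this_job = arvados.current_job()
#     this_task = arvados.current_task()
#     input_locator = arvados.getjobparam('input')
#     # ... read the input from Keep, compute something ...
#     out_locator = arvados.Keep.put(result_bytes)
#     this_task.set_output(out_locator)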
101 def api(version=None):
102     global services
103     if not services.get(version):
104         apiVersion = version
105         if not version:
106             apiVersion = 'v1'
107             logging.info("Using default API version. " +
108                          "Call arvados.api('%s') instead." %
109                          apiVersion)
110         if 'ARVADOS_API_HOST' not in os.environ:
111             raise Exception("ARVADOS_API_HOST is not set. Aborting.")
112         url = ('https://%s/discovery/v1/apis/{api}/{apiVersion}/rest' %
113                os.environ['ARVADOS_API_HOST'])
114         credentials = CredentialsFromEnv()
115
116         # Use system's CA certificates (if we find them) instead of httplib2's
117         ca_certs = '/etc/ssl/certs/ca-certificates.crt'
118         if not os.path.exists(ca_certs):
119             ca_certs = None             # use httplib2 default
120
121         http = httplib2.Http(ca_certs=ca_certs)
122         http = credentials.authorize(http)
123         if re.match(r'(?i)^(true|1|yes)$',
124                     os.environ.get('ARVADOS_API_HOST_INSECURE', '')):
125             http.disable_ssl_certificate_validation=True
126         services[version] = build(
127             'arvados', apiVersion, http=http, discoveryServiceUrl=url)
128     return services[version]
129
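# A minimal sketch of connecting to an Arvados API server (host and token
# values are placeholders). The discovery-built client exposes the resources
# used elsewhere in this module, e.g. jobs(), job_tasks() and keep_disks():
#
#     import os, arvados
#     os.environ['ARVADOS_API_HOST'] = 'arvados.example.com'
#     os.environ['ARVADOS_API_TOKEN'] = 'your-api-token-here'
#     arv = arvados.api('v1')
#     job = arv.jobs().get(uuid='your-job-uuid').execute()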
130 class JobTask(object):
131     def __init__(self, parameters=dict(), runtime_constraints=dict()):
132         print "init jobtask %s %s" % (parameters, runtime_constraints)
133
134 class job_setup:
135     @staticmethod
136     def one_task_per_input_file(if_sequence=0, and_end_task=True):
137         if if_sequence != current_task()['sequence']:
138             return
139         job_input = current_job()['script_parameters']['input']
140         cr = CollectionReader(job_input)
141         for s in cr.all_streams():
142             for f in s.all_files():
143                 task_input = f.as_manifest()
144                 new_task_attrs = {
145                     'job_uuid': current_job()['uuid'],
146                     'created_by_job_task_uuid': current_task()['uuid'],
147                     'sequence': if_sequence + 1,
148                     'parameters': {
149                         'input':task_input
150                         }
151                     }
152                 api('v1').job_tasks().create(body=new_task_attrs).execute()
153         if and_end_task:
154             api('v1').job_tasks().update(uuid=current_task()['uuid'],
155                                        body={'success':True}
156                                        ).execute()
157             exit(0)
158
159     @staticmethod
160     def one_task_per_input_stream(if_sequence=0, and_end_task=True):
161         if if_sequence != current_task()['sequence']:
162             return
163         job_input = current_job()['script_parameters']['input']
164         cr = CollectionReader(job_input)
165         for s in cr.all_streams():
166             task_input = s.tokens()
167             new_task_attrs = {
168                 'job_uuid': current_job()['uuid'],
169                 'created_by_job_task_uuid': current_task()['uuid'],
170                 'sequence': if_sequence + 1,
171                 'parameters': {
172                     'input':task_input
173                     }
174                 }
175             api('v1').job_tasks().create(body=new_task_attrs).execute()
176         if and_end_task:
177             api('v1').job_tasks().update(uuid=current_task()['uuid'],
178                                        body={'success':True}
179                                        ).execute()
180             exit(0)
181
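# Typical use, sketched below: on its first invocation (task sequence 0) a
# crunch script calls one_task_per_input_file(), which queues one new task
# per input file and exits; the queued tasks (sequence 1) then find their own
# input in the task's 'parameters', as set up above.
#
#     import arvados
#     arvados.job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#     # Only sequence-1 tasks get past this point:
#     task_input = arvados.current_task()['parameters']['input']
#     for f in arvados.CollectionReader(task_input).all_files():
#         pass  # process one input file per task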
182 class util:
183     @staticmethod
184     def clear_tmpdir(path=None):
185         """
186         Ensure the given directory (or TASK_TMPDIR if none given)
187         exists and is empty.
188         """
189         if path == None:
190             path = current_task().tmpdir
191         if os.path.exists(path):
192             p = subprocess.Popen(['rm', '-rf', path])
193             stdout, stderr = p.communicate(None)
194             if p.returncode != 0:
195                 raise Exception('rm -rf %s: %s' % (path, stderr))
196         os.mkdir(path)
197
198     @staticmethod
199     def run_command(execargs, **kwargs):
200         kwargs.setdefault('stdin', subprocess.PIPE)
201         kwargs.setdefault('stdout', subprocess.PIPE)
202         kwargs.setdefault('stderr', sys.stderr)
203         kwargs.setdefault('close_fds', True)
204         kwargs.setdefault('shell', False)
205         p = subprocess.Popen(execargs, **kwargs)
206         stdoutdata, stderrdata = p.communicate(None)
207         if p.returncode != 0:
208             raise errors.CommandFailedError(
209                 "run_command %s exit %d:\n%s" %
210                 (execargs, p.returncode, stderrdata))
211         return stdoutdata, stderrdata
212
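    # A small usage sketch: run_command() raises errors.CommandFailedError on
    # a nonzero exit status and returns (stdout, stderr); with the defaults
    # above, stderr is passed through to sys.stderr, so the second value is
    # None.
    #
    #     stdout, _ = util.run_command(['ls', '-l', '/tmp'])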
213     @staticmethod
214     def git_checkout(url, version, path):
215         if not re.search('^/', path):
216             path = os.path.join(current_job().tmpdir, path)
217         if not os.path.exists(path):
218             util.run_command(["git", "clone", url, path],
219                              cwd=os.path.dirname(path))
220         util.run_command(["git", "checkout", version],
221                          cwd=path)
222         return path
223
224     @staticmethod
225     def tar_extractor(path, decompress_flag):
226         return subprocess.Popen(["tar",
227                                  "-C", path,
228                                  ("-x%sf" % decompress_flag),
229                                  "-"],
230                                 stdout=None,
231                                 stdin=subprocess.PIPE, stderr=sys.stderr,
232                                 shell=False, close_fds=True)
233
234     @staticmethod
235     def tarball_extract(tarball, path):
236         """Retrieve a tarball from Keep and extract it to a local
237         directory.  Return the absolute path where the tarball was
238         extracted. If the top level of the tarball contained just one
239         file or directory, return the absolute path of that single
240         item.
241
242         tarball -- collection locator
243         path -- where to extract the tarball: absolute, or relative to job tmp
244         """
245         if not re.search('^/', path):
246             path = os.path.join(current_job().tmpdir, path)
247         lockfile = open(path + '.lock', 'w')
248         fcntl.flock(lockfile, fcntl.LOCK_EX)
249         try:
250             os.stat(path)
251         except OSError:
252             os.mkdir(path)
253         already_have_it = False
254         try:
255             if os.readlink(os.path.join(path, '.locator')) == tarball:
256                 already_have_it = True
257         except OSError:
258             pass
259         if not already_have_it:
260
261             # emulate "rm -f" (i.e., if the file does not exist, we win)
262             try:
263                 os.unlink(os.path.join(path, '.locator'))
264             except OSError:
265                 if os.path.exists(os.path.join(path, '.locator')):
266                     os.unlink(os.path.join(path, '.locator'))
267
268             for f in CollectionReader(tarball).all_files():
269                 if re.search('\.(tbz|tar.bz2)$', f.name()):
270                     p = util.tar_extractor(path, 'j')
271                 elif re.search('\.(tgz|tar.gz)$', f.name()):
272                     p = util.tar_extractor(path, 'z')
273                 elif re.search('\.tar$', f.name()):
274                     p = util.tar_extractor(path, '')
275                 else:
276                     raise errors.AssertionError(
277                         "tarball_extract cannot handle filename %s" % f.name())
278                 while True:
279                     buf = f.read(2**20)
280                     if len(buf) == 0:
281                         break
282                     p.stdin.write(buf)
283                 p.stdin.close()
284                 p.wait()
285                 if p.returncode != 0:
286                     lockfile.close()
287                     raise errors.CommandFailedError(
288                         "tar exited %d" % p.returncode)
289             os.symlink(tarball, os.path.join(path, '.locator'))
290         tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
291         lockfile.close()
292         if len(tld_extracts) == 1:
293             return os.path.join(path, tld_extracts[0])
294         return path
295
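    # Hedged example (the locator is a placeholder): extract a tarball stored
    # in Keep into 'src' under the job's temporary directory and get back the
    # extracted path.
    #
    #     src_dir = util.tarball_extract(tarball_locator, 'src')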
296     @staticmethod
297     def zipball_extract(zipball, path):
298         """Retrieve a zip archive from Keep and extract it to a local
299         directory.  Return the absolute path where the archive was
300         extracted. If the top level of the archive contained just one
301         file or directory, return the absolute path of that single
302         item.
303
304         zipball -- collection locator
305         path -- where to extract the archive: absolute, or relative to job tmp
306         """
307         if not re.search('^/', path):
308             path = os.path.join(current_job().tmpdir, path)
309         lockfile = open(path + '.lock', 'w')
310         fcntl.flock(lockfile, fcntl.LOCK_EX)
311         try:
312             os.stat(path)
313         except OSError:
314             os.mkdir(path)
315         already_have_it = False
316         try:
317             if os.readlink(os.path.join(path, '.locator')) == zipball:
318                 already_have_it = True
319         except OSError:
320             pass
321         if not already_have_it:
322
323             # emulate "rm -f" (i.e., if the file does not exist, we win)
324             try:
325                 os.unlink(os.path.join(path, '.locator'))
326             except OSError:
327                 if os.path.exists(os.path.join(path, '.locator')):
328                     os.unlink(os.path.join(path, '.locator'))
329
330             for f in CollectionReader(zipball).all_files():
331                 if not re.search('\.zip$', f.name()):
332                     raise errors.NotImplementedError(
333                         "zipball_extract cannot handle filename %s" % f.name())
334                 zip_filename = os.path.join(path, os.path.basename(f.name()))
335                 zip_file = open(zip_filename, 'wb')
336                 while True:
337                     buf = f.read(2**20)
338                     if len(buf) == 0:
339                         break
340                     zip_file.write(buf)
341                 zip_file.close()
342
343                 p = subprocess.Popen(["unzip",
344                                       "-q", "-o",
345                                       "-d", path,
346                                       zip_filename],
347                                      stdout=None,
348                                      stdin=None, stderr=sys.stderr,
349                                      shell=False, close_fds=True)
350                 p.wait()
351                 if p.returncode != 0:
352                     lockfile.close()
353                     raise errors.CommandFailedError(
354                         "unzip exited %d" % p.returncode)
355                 os.unlink(zip_filename)
356             os.symlink(zipball, os.path.join(path, '.locator'))
357         tld_extracts = filter(lambda f: f != '.locator', os.listdir(path))
358         lockfile.close()
359         if len(tld_extracts) == 1:
360             return os.path.join(path, tld_extracts[0])
361         return path
362
363     @staticmethod
364     def collection_extract(collection, path, files=[], decompress=True):
365         """Retrieve a collection from Keep and extract it to a local
366         directory.  Return the absolute path where the collection was
367         extracted.
368
369         collection -- collection locator
370         path -- where to extract: absolute, or relative to job tmp
371         """
372         matches = re.search(r'^([0-9a-f]+)(\+[\w@]+)*$', collection)
373         if matches:
374             collection_hash = matches.group(1)
375         else:
376             collection_hash = hashlib.md5(collection).hexdigest()
377         if not re.search('^/', path):
378             path = os.path.join(current_job().tmpdir, path)
379         lockfile = open(path + '.lock', 'w')
380         fcntl.flock(lockfile, fcntl.LOCK_EX)
381         try:
382             os.stat(path)
383         except OSError:
384             os.mkdir(path)
385         already_have_it = False
386         try:
387             if os.readlink(os.path.join(path, '.locator')) == collection_hash:
388                 already_have_it = True
389         except OSError:
390             pass
391
392         # emulate "rm -f" (i.e., if the file does not exist, we win)
393         try:
394             os.unlink(os.path.join(path, '.locator'))
395         except OSError:
396             if os.path.exists(os.path.join(path, '.locator')):
397                 os.unlink(os.path.join(path, '.locator'))
398
399         files_got = []
400         for s in CollectionReader(collection).all_streams():
401             stream_name = s.name()
402             for f in s.all_files():
403                 if (files == [] or
404                     ((f.name() not in files_got) and
405                      (f.name() in files or
406                       (decompress and f.decompressed_name() in files)))):
407                     outname = f.decompressed_name() if decompress else f.name()
408                     files_got += [outname]
409                     if os.path.exists(os.path.join(path, stream_name, outname)):
410                         continue
411                     util.mkdir_dash_p(os.path.dirname(os.path.join(path, stream_name, outname)))
412                     outfile = open(os.path.join(path, stream_name, outname), 'wb')
413                     for buf in (f.readall_decompressed() if decompress
414                                 else f.readall()):
415                         outfile.write(buf)
416                     outfile.close()
417         if len(files_got) < len(files):
418             raise errors.AssertionError(
419                 "Wanted files %s but only got %s from %s" %
420                 (files, files_got,
421                  [z.name() for z in CollectionReader(collection).all_files()]))
422         os.symlink(collection_hash, os.path.join(path, '.locator'))
423
424         lockfile.close()
425         return path
426
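    # Hedged example (locator and file names are placeholders): fetch two
    # named files from a collection into 'ref' under the job's temporary
    # directory, decompressing .gz/.bz2 files along the way.
    #
    #     ref_dir = util.collection_extract(collection=ref_locator,
    #                                       path='ref',
    #                                       files=['chr1.fa.gz', 'chr2.fa.gz'],
    #                                       decompress=True)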
427     @staticmethod
428     def mkdir_dash_p(path):
429         if not os.path.exists(path):
430             util.mkdir_dash_p(os.path.dirname(path))
431             try:
432                 os.mkdir(path)
433             except OSError:
434                 if not os.path.exists(path):
435                     os.mkdir(path)
436
437     @staticmethod
438     def stream_extract(stream, path, files=[], decompress=True):
439         """Retrieve a stream from Keep and extract it to a local
440         directory.  Return the absolute path where the stream was
441         extracted.
442
443         stream -- StreamReader object
444         path -- where to extract: absolute, or relative to job tmp
445         """
446         if not re.search('^/', path):
447             path = os.path.join(current_job().tmpdir, path)
448         lockfile = open(path + '.lock', 'w')
449         fcntl.flock(lockfile, fcntl.LOCK_EX)
450         try:
451             os.stat(path)
452         except OSError:
453             os.mkdir(path)
454
455         files_got = []
456         for f in stream.all_files():
457             if (files == [] or
458                 ((f.name() not in files_got) and
459                  (f.name() in files or
460                   (decompress and f.decompressed_name() in files)))):
461                 outname = f.decompressed_name() if decompress else f.name()
462                 files_got += [outname]
463                 if os.path.exists(os.path.join(path, outname)):
464                     os.unlink(os.path.join(path, outname))
465                 util.mkdir_dash_p(os.path.dirname(os.path.join(path, outname)))
466                 outfile = open(os.path.join(path, outname), 'wb')
467                 for buf in (f.readall_decompressed() if decompress
468                             else f.readall()):
469                     outfile.write(buf)
470                 outfile.close()
471         if len(files_got) < len(files):
472             raise errors.AssertionError(
473                 "Wanted files %s but only got %s from %s" %
474                 (files, files_got, [z.name() for z in stream.all_files()]))
475         lockfile.close()
476         return path
477
478     @staticmethod
479     def listdir_recursive(dirname, base=None):
480         allfiles = []
481         for ent in sorted(os.listdir(dirname)):
482             ent_path = os.path.join(dirname, ent)
483             ent_base = os.path.join(base, ent) if base else ent
484             if os.path.isdir(ent_path):
485                 allfiles += util.listdir_recursive(ent_path, ent_base)
486             else:
487                 allfiles += [ent_base]
488         return allfiles
489
490 class StreamFileReader(object):
491     def __init__(self, stream, pos, size, name):
492         self._stream = stream
493         self._pos = pos
494         self._size = size
495         self._name = name
496         self._filepos = 0
497     def name(self):
498         return self._name
499     def decompressed_name(self):
500         return re.sub('\.(bz2|gz)$', '', self._name)
501     def size(self):
502         return self._size
503     def stream_name(self):
504         return self._stream.name()
505     def read(self, size, **kwargs):
506         self._stream.seek(self._pos + self._filepos)
507         data = self._stream.read(min(size, self._size - self._filepos))
508         self._filepos += len(data)
509         return data
510     def readall(self, size=2**20, **kwargs):
511         while True:
512             data = self.read(size, **kwargs)
513             if data == '':
514                 break
515             yield data
516     def bunzip2(self, size):
517         decompressor = bz2.BZ2Decompressor()
518         for chunk in self.readall(size):
519             data = decompressor.decompress(chunk)
520             if data and data != '':
521                 yield data
522     def gunzip(self, size):
523         decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
524         for chunk in self.readall(size):
525             data = decompressor.decompress(decompressor.unconsumed_tail + chunk)
526             if data and data != '':
527                 yield data
528     def readall_decompressed(self, size=2**20):
529         self._stream.seek(self._pos + self._filepos)
530         if re.search('\.bz2$', self._name):
531             return self.bunzip2(size)
532         elif re.search('\.gz$', self._name):
533             return self.gunzip(size)
534         else:
535             return self.readall(size)
536     def readlines(self, decompress=True):
537         if decompress:
538             datasource = self.readall_decompressed()
539         else:
540             self._stream.seek(self._pos + self._filepos)
541             datasource = self.readall()
542         data = ''
543         for newdata in datasource:
544             data += newdata
545             sol = 0
546             while True:
547                 eol = string.find(data, "\n", sol)
548                 if eol < 0:
549                     break
550                 yield data[sol:eol+1]
551                 sol = eol+1
552             data = data[sol:]
553         if data != '':
554             yield data
555     def as_manifest(self):
556         if self.size() == 0:
557             return ("%s %s 0:0:%s\n"
558                     % (self._stream.name(), EMPTY_BLOCK_LOCATOR, self.name()))
559         return string.join(self._stream.tokens_for_range(self._pos, self._size),
560                            " ") + "\n"
561
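# Each manifest stream is one line of whitespace-separated tokens: a stream
# name (e.g. "."), one or more data block locators of the form
# <md5>+<size>[+hints], and one or more file tokens of the form
# <position>:<size>:<filename>, where position and size are byte offsets into
# the concatenation of the stream's data blocks. For example (the locator is
# illustrative only):
#
#     . 930625b054ce894ac40596c3f5a0d947+33 0:22:foo.txt 22:11:bar.txt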
562 class StreamReader(object):
563     def __init__(self, tokens):
564         self._tokens = tokens
565         self._current_datablock_data = None
566         self._current_datablock_pos = 0
567         self._current_datablock_index = -1
568         self._pos = 0
569
570         self._stream_name = None
571         self.data_locators = []
572         self.files = []
573
574         for tok in self._tokens:
575             if self._stream_name == None:
576                 self._stream_name = tok
577             elif re.search(r'^[0-9a-f]{32}(\+\S+)*$', tok):
578                 self.data_locators += [tok]
579             elif re.search(r'^\d+:\d+:\S+', tok):
580                 pos, size, name = tok.split(':',2)
581                 self.files += [[int(pos), int(size), name]]
582             else:
583                 raise errors.SyntaxError("Invalid manifest format")
584
585     def tokens(self):
586         return self._tokens
587     def tokens_for_range(self, range_start, range_size):
588         resp = [self._stream_name]
589         return_all_tokens = False
590         block_start = 0
591         token_bytes_skipped = 0
592         for locator in self.data_locators:
593             sizehint = re.search(r'\+(\d+)', locator)
594             if not sizehint:
595                 return_all_tokens = True
596             if return_all_tokens:
597                 resp += [locator]
598                 continue
599             blocksize = int(sizehint.group(0))
600             if range_start + range_size <= block_start:
601                 break
602             if range_start < block_start + blocksize:
603                 resp += [locator]
604             else:
605                 token_bytes_skipped += blocksize
606             block_start += blocksize
607         for f in self.files:
608             if ((f[0] < range_start + range_size)
609                 and
610                 (f[0] + f[1] > range_start)
611                 and
612                 f[1] > 0):
613                 resp += ["%d:%d:%s" % (f[0] - token_bytes_skipped, f[1], f[2])]
614         return resp
615     def name(self):
616         return self._stream_name
617     def all_files(self):
618         for f in self.files:
619             pos, size, name = f
620             yield StreamFileReader(self, pos, size, name)
621     def nextdatablock(self):
622         if self._current_datablock_index < 0:
623             self._current_datablock_pos = 0
624             self._current_datablock_index = 0
625         else:
626             self._current_datablock_pos += self.current_datablock_size()
627             self._current_datablock_index += 1
628         self._current_datablock_data = None
629     def current_datablock_data(self):
630         if self._current_datablock_data == None:
631             self._current_datablock_data = Keep.get(self.data_locators[self._current_datablock_index])
632         return self._current_datablock_data
633     def current_datablock_size(self):
634         if self._current_datablock_index < 0:
635             self.nextdatablock()
636         sizehint = re.search('\+(\d+)', self.data_locators[self._current_datablock_index])
637         if sizehint:
638             return int(sizehint.group(0))
639         return len(self.current_datablock_data())
640     def seek(self, pos):
641         """Set the position of the next read operation."""
642         self._pos = pos
643     def really_seek(self):
644         """Find and load the appropriate data block, so the byte at
645         _pos is in memory.
646         """
647         if self._pos == self._current_datablock_pos:
648             return True
649         if (self._current_datablock_pos != None and
650             self._pos >= self._current_datablock_pos and
651             self._pos <= self._current_datablock_pos + self.current_datablock_size()):
652             return True
653         if self._pos < self._current_datablock_pos:
654             self._current_datablock_index = -1
655             self.nextdatablock()
656         while (self._pos > self._current_datablock_pos and
657                self._pos > self._current_datablock_pos + self.current_datablock_size()):
658             self.nextdatablock()
659     def read(self, size):
660         """Read no more than size bytes -- but at least one byte,
661         unless _pos is already at the end of the stream.
662         """
663         if size == 0:
664             return ''
665         self.really_seek()
666         while self._pos >= self._current_datablock_pos + self.current_datablock_size():
667             self.nextdatablock()
668             if self._current_datablock_index >= len(self.data_locators):
669                 return None
670         data = self.current_datablock_data()[self._pos - self._current_datablock_pos : self._pos - self._current_datablock_pos + size]
671         self._pos += len(data)
672         return data
673
674 class CollectionReader(object):
675     def __init__(self, manifest_locator_or_text):
676         if re.search(r'^\S+( [a-f0-9]{32,}(\+\S+)*)+( \d+:\d+:\S+)+\n', manifest_locator_or_text):
677             self._manifest_text = manifest_locator_or_text
678             self._manifest_locator = None
679         else:
680             self._manifest_locator = manifest_locator_or_text
681             self._manifest_text = None
682         self._streams = None
683     def __enter__(self):
684         return self
685     def __exit__(self, exc_type, exc_value, traceback):
686         pass
687     def _populate(self):
688         if self._streams != None:
689             return
690         if not self._manifest_text:
691             self._manifest_text = Keep.get(self._manifest_locator)
692         self._streams = []
693         for stream_line in self._manifest_text.split("\n"):
694             if stream_line != '':
695                 stream_tokens = stream_line.split()
696                 self._streams += [stream_tokens]
697     def all_streams(self):
698         self._populate()
699         resp = []
700         for s in self._streams:
701             resp += [StreamReader(s)]
702         return resp
703     def all_files(self):
704         for s in self.all_streams():
705             for f in s.all_files():
706                 yield f
707     def manifest_text(self):
708         self._populate()
709         return self._manifest_text
710
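# A minimal reading sketch (collection_locator and process() are
# placeholders): all_files() yields StreamFileReader objects, so data can be
# pulled with read(), readall(), readlines() or readall_decompressed() as
# defined above.
#
#     cr = CollectionReader(collection_locator)
#     for f in cr.all_files():
#         print f.stream_name(), f.name(), f.size()
#         for chunk in f.readall():
#             process(chunk)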
711 class CollectionWriter(object):
712     KEEP_BLOCK_SIZE = 2**26
713     def __init__(self):
714         self._data_buffer = []
715         self._data_buffer_len = 0
716         self._current_stream_files = []
717         self._current_stream_length = 0
718         self._current_stream_locators = []
719         self._current_stream_name = '.'
720         self._current_file_name = None
721         self._current_file_pos = 0
722         self._finished_streams = []
723     def __enter__(self):
724         return self
725     def __exit__(self, exc_type, exc_value, traceback):
726         self.finish()
727     def write_directory_tree(self,
728                              path, stream_name='.', max_manifest_depth=-1):
729         self.start_new_stream(stream_name)
730         todo = []
731         if max_manifest_depth == 0:
732             dirents = sorted(util.listdir_recursive(path))
733         else:
734             dirents = sorted(os.listdir(path))
735         for dirent in dirents:
736             target = os.path.join(path, dirent)
737             if os.path.isdir(target):
738                 todo += [[target,
739                           os.path.join(stream_name, dirent),
740                           max_manifest_depth-1]]
741             else:
742                 self.start_new_file(dirent)
743                 with open(target, 'rb') as f:
744                     while True:
745                         buf = f.read(2**26)
746                         if len(buf) == 0:
747                             break
748                         self.write(buf)
749         self.finish_current_stream()
750         map(lambda x: self.write_directory_tree(*x), todo)
751
752     def write(self, newdata):
753         if hasattr(newdata, '__iter__'):
754             for s in newdata:
755                 self.write(s)
756             return
757         self._data_buffer += [newdata]
758         self._data_buffer_len += len(newdata)
759         self._current_stream_length += len(newdata)
760         while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
761             self.flush_data()
762     def flush_data(self):
763         data_buffer = ''.join(self._data_buffer)
764         if data_buffer != '':
765             self._current_stream_locators += [Keep.put(data_buffer[0:self.KEEP_BLOCK_SIZE])]
766             self._data_buffer = [data_buffer[self.KEEP_BLOCK_SIZE:]]
767             self._data_buffer_len = len(self._data_buffer[0])
768     def start_new_file(self, newfilename=None):
769         self.finish_current_file()
770         self.set_current_file_name(newfilename)
771     def set_current_file_name(self, newfilename):
772         newfilename = re.sub(r' ', '\\\\040', newfilename)
773         if re.search(r'[ \t\n]', newfilename):
774             raise errors.AssertionError(
775                 "Manifest filenames cannot contain whitespace: %s" %
776                 newfilename)
777         self._current_file_name = newfilename
778     def current_file_name(self):
779         return self._current_file_name
780     def finish_current_file(self):
781         if self._current_file_name == None:
782             if self._current_file_pos == self._current_stream_length:
783                 return
784             raise errors.AssertionError(
785                 "Cannot finish an unnamed file " +
786                 "(%d bytes at offset %d in '%s' stream)" %
787                 (self._current_stream_length - self._current_file_pos,
788                  self._current_file_pos,
789                  self._current_stream_name))
790         self._current_stream_files += [[self._current_file_pos,
791                                        self._current_stream_length - self._current_file_pos,
792                                        self._current_file_name]]
793         self._current_file_pos = self._current_stream_length
794     def start_new_stream(self, newstreamname='.'):
795         self.finish_current_stream()
796         self.set_current_stream_name(newstreamname)
797     def set_current_stream_name(self, newstreamname):
798         if re.search(r'[ \t\n]', newstreamname):
799             raise errors.AssertionError(
800                 "Manifest stream names cannot contain whitespace")
801         self._current_stream_name = '.' if newstreamname=='' else newstreamname
802     def current_stream_name(self):
803         return self._current_stream_name
804     def finish_current_stream(self):
805         self.finish_current_file()
806         self.flush_data()
807         if len(self._current_stream_files) == 0:
808             pass
809         elif self._current_stream_name == None:
810             raise errors.AssertionError(
811                 "Cannot finish an unnamed stream (%d bytes in %d files)" %
812                 (self._current_stream_length, len(self._current_stream_files)))
813         else:
814             if len(self._current_stream_locators) == 0:
815                 self._current_stream_locators += [EMPTY_BLOCK_LOCATOR]
816             self._finished_streams += [[self._current_stream_name,
817                                        self._current_stream_locators,
818                                        self._current_stream_files]]
819         self._current_stream_files = []
820         self._current_stream_length = 0
821         self._current_stream_locators = []
822         self._current_stream_name = None
823         self._current_file_pos = 0
824         self._current_file_name = None
825     def finish(self):
826         return Keep.put(self.manifest_text())
827     def manifest_text(self):
828         self.finish_current_stream()
829         manifest = ''
830         for stream in self._finished_streams:
831             if not re.search(r'^\.(/.*)?$', stream[0]):
832                 manifest += './'
833             manifest += stream[0]
834             for locator in stream[1]:
835                 manifest += " %s" % locator
836             for sfile in stream[2]:
837                 manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
838             manifest += "\n"
839         return manifest
840     def data_locators(self):
841         ret = []
842         for name, locators, files in self._finished_streams:
843             ret += locators
844         return ret
845
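# A minimal writing sketch: build a one-file collection and store its
# manifest in Keep. finish() returns the locator of the stored manifest.
#
#     cw = CollectionWriter()
#     cw.start_new_stream('.')
#     cw.start_new_file('hello.txt')
#     cw.write('hello world\n')
#     manifest_locator = cw.finish()
#
# write_directory_tree(path) does the same for a whole directory hierarchy.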
846 global_client_object = None
847
848 class Keep:
849     @staticmethod
850     def global_client_object():
851         global global_client_object
852         if global_client_object == None:
853             global_client_object = KeepClient()
854         return global_client_object
855
856     @staticmethod
857     def get(locator, **kwargs):
858         return Keep.global_client_object().get(locator, **kwargs)
859
860     @staticmethod
861     def put(data, **kwargs):
862         return Keep.global_client_object().put(data, **kwargs)
863
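# A quick round-trip sketch. With KEEP_LOCAL_STORE set, blocks are written to
# and read from an existing local directory instead of real Keep servers,
# which is convenient for tests (the path below is a placeholder):
#
#     os.environ['KEEP_LOCAL_STORE'] = '/tmp/keep'    # directory must exist
#     locator = Keep.put('some data')                 # returns '<md5>+9'
#     assert Keep.get(locator) == 'some data'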
864 class KeepClient(object):
865
866     class ThreadLimiter(object):
867         """
868         Limit the number of threads running at a given time to
869         {desired successes} minus {successes reported}. When successes
870         reported == desired, wake up the remaining threads and tell
871         them to quit.
872
873         Should be used in a "with" block.
874         """
875         def __init__(self, todo):
876             self._todo = todo
877             self._done = 0
878             self._todo_lock = threading.Semaphore(todo)
879             self._done_lock = threading.Lock()
880         def __enter__(self):
881             self._todo_lock.acquire()
882             return self
883         def __exit__(self, type, value, traceback):
884             self._todo_lock.release()
885         def shall_i_proceed(self):
886             """
887             Return true if the current thread should do stuff. Return
888             false if the current thread should just stop.
889             """
890             with self._done_lock:
891                 return (self._done < self._todo)
892         def increment_done(self):
893             """
894             Report that the current thread was successful.
895             """
896             with self._done_lock:
897                 self._done += 1
898         def done(self):
899             """
900             Return how many successes were reported.
901             """
902             with self._done_lock:
903                 return self._done
904
905     class KeepWriterThread(threading.Thread):
906         """
907         Write a blob of data to the given Keep server. Call
908         increment_done() of the given ThreadLimiter if the write
909         succeeds.
910         """
911         def __init__(self, **kwargs):
912             super(KeepClient.KeepWriterThread, self).__init__()
913             self.args = kwargs
914         def run(self):
915             with self.args['thread_limiter'] as limiter:
916                 if not limiter.shall_i_proceed():
917                     # My turn arrived, but the job has been done without
918                     # me.
919                     return
920                 logging.debug("KeepWriterThread %s proceeding %s %s" %
921                               (str(threading.current_thread()),
922                                self.args['data_hash'],
923                                self.args['service_root']))
924                 h = httplib2.Http()
925                 url = self.args['service_root'] + self.args['data_hash']
926                 api_token = os.environ['ARVADOS_API_TOKEN']
927                 headers = {'Authorization': "OAuth2 %s" % api_token}
928                 try:
929                     resp, content = h.request(url.encode('utf-8'), 'PUT',
930                                               headers=headers,
931                                               body=self.args['data'])
932                     if (resp['status'] == '401' and
933                         re.match(r'Timestamp verification failed', content)):
934                         body = KeepClient.sign_for_old_server(
935                             self.args['data_hash'],
936                             self.args['data'])
937                         h = httplib2.Http()
938                         resp, content = h.request(url.encode('utf-8'), 'PUT',
939                                                   headers=headers,
940                                                   body=body)
941                     if re.match(r'^2\d\d$', resp['status']):
942                         logging.debug("KeepWriterThread %s succeeded %s %s" %
943                                       (str(threading.current_thread()),
944                                        self.args['data_hash'],
945                                        self.args['service_root']))
946                         return limiter.increment_done()
947                     logging.warning("Request fail: PUT %s => %s %s" %
948                                     (url, resp['status'], content))
949                 except (httplib2.HttpLib2Error, httplib.HTTPException) as e:
950                     logging.warning("Request fail: PUT %s => %s: %s" %
951                                     (url, type(e), str(e)))
952
953     def __init__(self):
954         self.lock = threading.Lock()
955         self.service_roots = None
956
957     def shuffled_service_roots(self, hash):
958         if self.service_roots == None:
959             self.lock.acquire()
960             keep_disks = api().keep_disks().list().execute()['items']
961             roots = (("http%s://%s:%d/" %
962                       ('s' if f['service_ssl_flag'] else '',
963                        f['service_host'],
964                        f['service_port']))
965                      for f in keep_disks)
966             self.service_roots = sorted(set(roots))
967             logging.debug(str(self.service_roots))
968             self.lock.release()
969         seed = hash
970         pool = self.service_roots[:]
971         pseq = []
972         while len(pool) > 0:
973             if len(seed) < 8:
974                 if len(pseq) < len(hash) / 4: # first time around
975                     seed = hash[-4:] + hash
976                 else:
977                     seed += hash
978             probe = int(seed[0:8], 16) % len(pool)
979             pseq += [pool[probe]]
980             pool = pool[:probe] + pool[probe+1:]
981             seed = seed[8:]
982         logging.debug(str(pseq))
983         return pseq
984
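    # shuffled_service_roots() turns a block hash into a deterministic
    # permutation of the Keep service roots: successive 8-hex-digit slices of
    # the hash are used as probes (probe = int(slice, 16) % len(pool)), so
    # every client with the same service list visits the servers in the same
    # order for a given block, and readers naturally look first where writers
    # wrote.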
985     def get(self, locator):
986         if 'KEEP_LOCAL_STORE' in os.environ:
987             return KeepClient.local_store_get(locator)
988         expect_hash = re.sub(r'\+.*', '', locator)
989         for service_root in self.shuffled_service_roots(expect_hash):
990             h = httplib2.Http()
991             url = service_root + expect_hash
992             api_token = os.environ['ARVADOS_API_TOKEN']
993             headers = {'Authorization': "OAuth2 %s" % api_token,
994                        'Accept': 'application/octet-stream'}
995             try:
996                 resp, content = h.request(url.encode('utf-8'), 'GET',
997                                           headers=headers)
998                 if re.match(r'^2\d\d$', resp['status']):
999                     m = hashlib.new('md5')
1000                     m.update(content)
1001                     md5 = m.hexdigest()
1002                     if md5 == expect_hash:
1003                         return content
1004                     logging.warning("Checksum fail: md5(%s) = %s" % (url, md5))
1005             except (httplib2.HttpLib2Error, httplib.ResponseNotReady) as e:
1006                 logging.info("Request fail: GET %s => %s: %s" %
1007                              (url, type(e), str(e)))
1008         raise errors.NotFoundError("Block not found: %s" % expect_hash)
1009
1010     def put(self, data, **kwargs):
1011         if 'KEEP_LOCAL_STORE' in os.environ:
1012             return KeepClient.local_store_put(data)
1013         m = hashlib.new('md5')
1014         m.update(data)
1015         data_hash = m.hexdigest()
1016         have_copies = 0
1017         want_copies = kwargs.get('copies', 2)
1018         if not (want_copies > 0):
1019             return data_hash
1020         threads = []
1021         thread_limiter = KeepClient.ThreadLimiter(want_copies)
1022         for service_root in self.shuffled_service_roots(data_hash):
1023             t = KeepClient.KeepWriterThread(data=data,
1024                                             data_hash=data_hash,
1025                                             service_root=service_root,
1026                                             thread_limiter=thread_limiter)
1027             t.start()
1028             threads += [t]
1029         for t in threads:
1030             t.join()
1031         have_copies = thread_limiter.done()
1032         if have_copies == want_copies:
1033             return (data_hash + '+' + str(len(data)))
1034         raise errors.KeepWriteError(
1035             "Write fail for %s: wanted %d but wrote %d" %
1036             (data_hash, want_copies, have_copies))
1037
1038     @staticmethod
1039     def sign_for_old_server(data_hash, data):
1040         return (("-----BEGIN PGP SIGNED MESSAGE-----\n\n\n%d %s\n-----BEGIN PGP SIGNATURE-----\n\n-----END PGP SIGNATURE-----\n" % (int(time.time()), data_hash)) + data)
1041
1042
1043     @staticmethod
1044     def local_store_put(data):
1045         m = hashlib.new('md5')
1046         m.update(data)
1047         md5 = m.hexdigest()
1048         locator = '%s+%d' % (md5, len(data))
1049         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'), 'w') as f:
1050             f.write(data)
1051         os.rename(os.path.join(os.environ['KEEP_LOCAL_STORE'], md5 + '.tmp'),
1052                   os.path.join(os.environ['KEEP_LOCAL_STORE'], md5))
1053         return locator
1054     @staticmethod
1055     def local_store_get(locator):
1056         r = re.search('^([0-9a-f]{32,})', locator)
1057         if not r:
1058             raise errors.NotFoundError(
1059                 "Invalid data locator: '%s'" % locator)
1060         if r.group(0) == EMPTY_BLOCK_LOCATOR.split('+')[0]:
1061             return ''
1062         with open(os.path.join(os.environ['KEEP_LOCAL_STORE'], r.group(0)), 'r') as f:
1063             return f.read()
1064
1065 # We really shouldn't do this but some clients still use
1066 # arvados.service.* directly instead of arvados.api().*
1067 service = api()