add CollectionWriter and Keep classes to Python SDK
[arvados.git] / sdk / python / arvados.py
import gflags
import httplib2
import logging
import os
import pprint
import sys
import types
import subprocess
import json
import UserDict
import re
import hashlib

from apiclient import errors
from apiclient.discovery import build

class CredentialsFromEnv:
    # Note: http_request is installed on an httplib2.Http object via
    # types.MethodType in authorize() below, so `self` here is the Http
    # instance (with orig_http_request attached), not a CredentialsFromEnv.
    @staticmethod
    def http_request(self, uri, **kwargs):
        from httplib import BadStatusLine
        if 'headers' not in kwargs:
            kwargs['headers'] = {}
        kwargs['headers']['Authorization'] = 'OAuth2 %s' % os.environ['ARVADOS_API_TOKEN']
        try:
            return self.orig_http_request(uri, **kwargs)
        except BadStatusLine:
            # This is how httplib tells us that it tried to reuse an
            # existing connection but it was already closed by the
            # server. In that case, yes, we would like to retry.
            # Unfortunately, we are not absolutely certain that the
            # previous call did not succeed, so this is slightly
            # risky.
            return self.orig_http_request(uri, **kwargs)
    def authorize(self, http):
        http.orig_http_request = http.request
        http.request = types.MethodType(self.http_request, http)
        return http

# Note: building the API client at import time requires ARVADOS_API_HOST
# and ARVADOS_API_TOKEN to be present in the environment.
url = ('https://%s/discovery/v1/apis/'
       '{api}/{apiVersion}/rest' % os.environ['ARVADOS_API_HOST'])
credentials = CredentialsFromEnv()
http = httplib2.Http()
http = credentials.authorize(http)
http.disable_ssl_certificate_validation = True
service = build("arvados", "v1", http=http, discoveryServiceUrl=url)

def task_set_output(self, s):
    service.job_tasks().update(uuid=self['uuid'],
                               job_task=json.dumps({'output': s,
                                                    'success': True,
                                                    'progress': 1.0})).execute()

_current_task = None
def current_task():
    global _current_task
    if _current_task:
        return _current_task
    t = service.job_tasks().get(uuid=os.environ['TASK_UUID']).execute()
    t = UserDict.UserDict(t)
    t.set_output = types.MethodType(task_set_output, t)
    _current_task = t
    return t

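# A minimal usage sketch (assumes the script runs inside a job, with
# TASK_UUID and JOB_UUID set by the dispatcher; output_locator is a
# placeholder for a real Keep locator):
#
#     this_task = current_task()
#     input_locator = this_task['parameters']['input']
#     # ... do the work ...
#     this_task.set_output(output_locator)
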
_current_job = None
def current_job():
    global _current_job
    if _current_job:
        return _current_job
    t = service.jobs().get(uuid=os.environ['JOB_UUID']).execute()
    _current_job = t
    return t

def api():
    return service

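# Sketch: api() returns the discovery-built client, so callers can issue
# arbitrary Arvados API requests directly (some_job_uuid is a placeholder):
#
#     job = api().jobs().get(uuid=some_job_uuid).execute()
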
class JobTask:
    def __init__(self, parameters=dict(), resource_limits=dict()):
        print "init jobtask %s %s" % (parameters, resource_limits)

class job_setup:
    @staticmethod
    def one_task_per_input_file(if_sequence=0, and_end_task=True):
        if if_sequence != current_task()['sequence']:
            return
        job_input = current_job()['script_parameters']['input']
        p = subprocess.Popen(["whls", job_input],
                             stdout=subprocess.PIPE,
                             stdin=None, stderr=None,
                             shell=False, close_fds=True)
        for f in p.stdout.read().split("\n"):
            if f != '':
                task_input = job_input + '/' + re.sub(r'^\./', '', f)
                new_task_attrs = {
                    'job_uuid': current_job()['uuid'],
                    'created_by_job_task': current_task()['uuid'],
                    'sequence': if_sequence + 1,
                    'parameters': {
                        'input': task_input
                        }
                    }
                service.job_tasks().create(job_task=json.dumps(new_task_attrs)).execute()
        p.stdout.close()
        p.wait()
        if p.returncode != 0:
            raise Exception("whls exited %d" % p.returncode)
        if and_end_task:
            service.job_tasks().update(uuid=current_task()['uuid'],
                                       job_task=json.dumps({'success': True})
                                       ).execute()
            sys.exit(0)

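# Sketch of intended use in a job script: the sequence-0 task calls
# one_task_per_input_file() to queue one sequence-1 task per file in the
# input collection, then marks itself successful and exits; each queued
# task then sees a single file path as its 'input' parameter.
#
#     job_setup.one_task_per_input_file(if_sequence=0, and_end_task=True)
#     input_file = current_task()['parameters']['input']
#     # ... process input_file ...
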
class DataReader:
    def __init__(self, data_locator):
        self.data_locator = data_locator
        self.p = subprocess.Popen(["whget", "-r", self.data_locator, "-"],
                                  stdout=subprocess.PIPE,
                                  stdin=None, stderr=subprocess.PIPE,
                                  shell=False, close_fds=True)
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
    def read(self, size, **kwargs):
        return self.p.stdout.read(size, **kwargs)
    def close(self):
        self.p.stdout.close()
        if not self.p.stderr.closed:
            for err in self.p.stderr:
                print >> sys.stderr, err
            self.p.stderr.close()
        self.p.wait()
        if self.p.returncode != 0:
            raise Exception("whget subprocess exited %d" % self.p.returncode)

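# Sketch: with __enter__/__exit__ defined as above, DataReader works as a
# context manager (the locator and process_chunk() are placeholders):
#
#     with DataReader('acbd18db4cc2f85cedef654fccc4a4d8+3') as reader:
#         while True:
#             buf = reader.read(2**20)
#             if buf == '':
#                 break
#             process_chunk(buf)
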
class CollectionWriter:
    KEEP_BLOCK_SIZE = 2**26
    def __init__(self):
        self.data_buffer = ''
        self.current_stream_files = []
        self.current_stream_length = 0
        self.current_stream_locators = []
        self.current_stream_name = '.'
        self.current_file_name = None
        self.current_file_pos = 0
        self.finished_streams = []
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        # This class defines finish(), not commit(), so call that here.
        self.finish()
    def write(self, newdata):
        self.data_buffer += newdata
        self.current_stream_length += len(newdata)
        while len(self.data_buffer) >= self.KEEP_BLOCK_SIZE:
            self.flush_data()
    def flush_data(self):
        # Store (up to) one full block of buffered data in Keep and record
        # its locator for the current stream.
        if self.data_buffer != '':
            self.current_stream_locators += [Keep.put(self.data_buffer[0:self.KEEP_BLOCK_SIZE])]
            self.data_buffer = self.data_buffer[self.KEEP_BLOCK_SIZE:]
    def start_new_file(self, newfilename=None):
        self.finish_current_file()
        self.current_file_name = newfilename
    def set_current_file_name(self, newfilename):
        self.current_file_name = newfilename
    def finish_current_file(self):
        if self.current_file_name is None:
            if self.current_file_pos == self.current_stream_length:
                return
            raise Exception("cannot finish an unnamed file (%d bytes at offset %d in '%s' stream)" %
                            (self.current_stream_length - self.current_file_pos,
                             self.current_file_pos,
                             self.current_stream_name))
        self.current_stream_files += [[self.current_file_pos,
                                       self.current_stream_length - self.current_file_pos,
                                       self.current_file_name]]
        self.current_file_pos = self.current_stream_length
    def start_new_stream(self, newstreamname=None):
        self.finish_current_stream()
        self.current_stream_name = newstreamname
    def set_current_stream_name(self, newstreamname):
        self.current_stream_name = newstreamname
    def finish_current_stream(self):
        self.finish_current_file()
        self.flush_data()
        if len(self.current_stream_files) == 0:
            pass
        elif self.current_stream_name is None:
            raise Exception("cannot finish an unnamed stream (%d bytes in %d files)" %
                            (self.current_stream_length, len(self.current_stream_files)))
        else:
            self.finished_streams += [[self.current_stream_name,
                                       self.current_stream_locators,
                                       self.current_stream_files]]
        self.current_stream_files = []
        self.current_stream_length = 0
        self.current_stream_locators = []
        self.current_stream_name = None
        self.current_file_pos = 0
        self.current_file_name = None
    def finish(self):
        return Keep.put(self.manifest_text())
    def manifest_text(self):
        self.finish_current_stream()
        manifest = ''
        for stream in self.finished_streams:
            manifest += stream[0]
            if len(stream[1]) == 0:
                # A stream with no data blocks still needs a locator: use
                # the well-known md5 of the empty string.
                manifest += " d41d8cd98f00b204e9800998ecf8427e+0"
            else:
                for locator in stream[1]:
                    manifest += " %s" % locator
            for sfile in stream[2]:
                manifest += " %d:%d:%s" % (sfile[0], sfile[1], sfile[2])
            manifest += "\n"
        return manifest

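# Sketch of typical CollectionWriter use (file names and data are made up):
#
#     writer = CollectionWriter()
#     writer.start_new_file('foo.txt')
#     writer.write('foo')
#     writer.start_new_file('bar.txt')
#     writer.write('bar')
#     locator = writer.finish()   # stores the manifest itself in Keep
#
# manifest_text() for the above would be one stream line, roughly:
#
#     . 3858f62230ac3c915f300c664312c63f+6 0:3:foo.txt 3:3:bar.txt
#
# i.e. stream name, data block locators, then position:length:name triples.
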
class Keep:
    @staticmethod
    def put(data):
        p = subprocess.Popen(["whput", "-"],
                             stdout=subprocess.PIPE,
                             stdin=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(data)
        if p.returncode != 0:
            raise Exception("whput subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        return stdoutdata.rstrip()
    @staticmethod
    def get(locator):
        p = subprocess.Popen(["whget", locator, "-"],
                             stdout=subprocess.PIPE,
                             stdin=None,
                             stderr=subprocess.PIPE,
                             shell=False, close_fds=True)
        stdoutdata, stderrdata = p.communicate(None)
        if p.returncode != 0:
            raise Exception("whget subprocess exited %d - stderr:\n%s" % (p.returncode, stderrdata))
        # Verify the retrieved data against the md5 hash embedded at the
        # start of the locator.
        m = hashlib.new('md5')
        m.update(stdoutdata)
        if locator.startswith(m.hexdigest()):
            return stdoutdata
        raise Exception("md5 checksum mismatch: md5(get(%s)) == %s" % (locator, m.hexdigest()))
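
# Sketch of a Keep round trip (requires whput/whget on PATH and a working
# Keep configuration):
#
#     locator = Keep.put("hello")
#     assert Keep.get(locator) == "hello"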