ff054a960a90917edf97ac68e2950d9c1a408ba0
[arvados.git] / examples / prediction / prediction.rb
1 #!/usr/bin/ruby1.8 
2 # -*- coding: utf-8 -*-
3
4 # Copyright:: Copyright 2011 Google Inc.
5 # License:: All Rights Reserved.
6 # Original Author:: Bob Aman, Winton Davies, Robert Kaplow
7 # Maintainer:: Robert Kaplow (mailto:rkaplow@google.com)
8
9 require 'rubygems'
10 require 'sinatra'
11 require 'datamapper'
12 require 'google/api_client'
13 require 'yaml'
14
15 use Rack::Session::Pool, :expire_after => 86400 # 1 day
16
17 # Set up our token store
18 DataMapper.setup(:default, 'sqlite::memory:')
19 class TokenPair
20   include DataMapper::Resource
21
22   property :id, Serial
23   property :refresh_token, String
24   property :access_token, String
25   property :expires_in, Integer
26   property :issued_at, Integer
27
28   def update_token!(object)
29     self.refresh_token = object.refresh_token
30     self.access_token = object.access_token
31     self.expires_in = object.expires_in
32     self.issued_at = object.issued_at
33   end
34
35   def to_hash
36     return {
37       :refresh_token => refresh_token,
38       :access_token => access_token,
39       :expires_in => expires_in,
40       :issued_at => Time.at(issued_at)
41     }
42   end
43 end
44 TokenPair.auto_migrate!
45
46 before do
47
48   # FILL IN THIS SECTION
49   # This will work if your yaml file is stored as ./google-api.yaml
50   # ------------------------
51   oauth_yaml = YAML.load_file('.google-api.yaml')
52   @client = Google::APIClient.new
53   @client.authorization.client_id = oauth_yaml["client_id"]
54   @client.authorization.client_secret = oauth_yaml["client_secret"]
55   @client.authorization.scope = oauth_yaml["scope"]
56   @client.authorization.refresh_token = oauth_yaml["refresh_token"]
57   @client.authorization.access_token = oauth_yaml["access_token"]
58   # -----------------------
59
60   @client.authorization.redirect_uri = to('/oauth2callback')
61
62   # Workaround for now as expires_in may be nil, but when converted to int it becomes 0.
63   @client.authorization.expires_in = 1800 if @client.authorization.expires_in.to_i == 0
64
65   if session[:token_id]
66     # Load the access token here if it's available
67     token_pair = TokenPair.get(session[:token_id])
68     @client.authorization.update_token!(token_pair.to_hash)
69   end
70   if @client.authorization.refresh_token && @client.authorization.expired?
71     @client.authorization.fetch_access_token!
72   end
73
74
75   @prediction = @client.discovered_api('prediction', 'v1.3')
76   unless @client.authorization.access_token || request.path_info =~ /^\/oauth2/
77     redirect to('/oauth2authorize')
78   end
79 end
80
81 get '/oauth2authorize' do
82   redirect @client.authorization.authorization_uri.to_s, 303
83 end
84
85 get '/oauth2callback' do
86   @client.authorization.fetch_access_token!
87   # Persist the token here
88   token_pair = if session[:token_id]
89     TokenPair.get(session[:token_id])
90   else
91     TokenPair.new
92   end
93   token_pair.update_token!(@client.authorization)
94   token_pair.save()
95   session[:token_id] = token_pair.id
96   redirect to('/')
97 end
98
99 get '/' do
100   # FILL IN DATAFILE:
101   # ----------------------------------------
102   datafile = "BUCKET/OBJECT"
103   # ----------------------------------------
104   # Train a predictive model.
105   train(datafile)
106   # Check to make sure the training has completed.
107   if (is_done?(datafile))
108     # Do a prediction.
109     # FILL IN DESIRED INPUT:
110     # -------------------------------------------------------------------------------
111     # Note, the input features should match the features of the dataset.
112     prediction,score = get_prediction(datafile, ["Alice noticed with some surprise."])
113     # -------------------------------------------------------------------------------
114
115     # We currently just dump the results to output, but you can display them on the page if desired.
116     puts prediction
117     puts score    
118   end
119 end
120
121 ##
122 # Trains a predictive model.
123 #
124 # @param [String] filename The name of the file in Google Storage. NOTE: this do *not*
125 #                 include the gs:// part. If the Google Storage path is gs://bucket/object,
126 #                 then the correct string is "bucket/object"
127 def train(datafile)
128   input = "{\"id\" : \"#{datafile}\"}"
129   puts "training input: #{input}"
130   result = @client.execute(:api_method => @prediction.training.insert,
131                            :merged_body => input,
132                            :headers => {'Content-Type' => 'application/json'}
133                            )
134   status, headers, body = result.response
135 end
136
137 ##
138 # Returns the current training status
139 #
140 # @param [String] filename The name of the file in Google Storage. NOTE: this do *not*
141 #                 include the gs:// part. If the Google Storage path is gs://bucket/object,
142 #                 then the correct string is "bucket/object"
143 # @return [Integer] status The HTTP status code of the training job.
144 def get_training_status(datafile)
145   result = @client.execute(:api_method => @prediction.training.get,
146                            :parameters => {'data' => datafile})
147   status, headers, body = result.response
148   return status
149 end
150
151
152 ##
153 # Checks the training status until a model exists (will loop forever).
154 #
155 # @param [String] filename The name of the file in Google Storage. NOTE: this do *not*
156 #                 include the gs:// part. If the Google Storage path is gs://bucket/object,
157 #                 then the correct string is "bucket/object"
158 # @return [Bool] exists True if model exists and can be used for predictions.
159
160 def is_done?(datafile)
161   status = get_training_status(datafile)
162   # We use an exponential backoff approach here.
163   test_counter = 0
164   while test_counter < 10 do
165     puts "Attempting to check model #{datafile} - Status: #{status} "
166     return true if status == 200
167     sleep 5 * (test_counter + 1)
168     status = get_training_status(datafile)
169     test_counter += 1
170   end
171   return false
172 end
173
174
175
176 ##
177 # Returns the prediction and most most likely class score if categorization.
178 #
179 # @param [String] filename The name of the file in Google Storage. NOTE: this do *not*
180 #                 include the gs:// part. If the Google Storage path is gs://bucket/object,
181 #                 then the correct string is "bucket/object"
182 # @param [List] input_features A list of input features.
183 #
184 # @return [String or Double] prediction The returned prediction, String if categorization,
185 #                            Double if regression
186 # @return [Double] trueclass_score The numeric score of the most likely label. (Categorical only).
187
188 def get_prediction(datafile,input_features)
189   # We take the input features and put it in the right input (json) format.
190   input="{\"input\" : { \"csvInstance\" :  #{input_features}}}"
191   puts "Prediction Input: #{input}"
192   result = @client.execute(:api_method => @prediction.training.predict,
193                            :parameters => {'data' => datafile},
194                            :merged_body => input,
195                            :headers => {'Content-Type' => 'application/json'})
196   status, headers, body = result.response
197   prediction_data = result.data
198   puts status
199   puts body
200   puts prediction_data
201   # Categorical
202   if prediction_data["outputLabel"] != nil
203     # Pull the most likely label.
204     prediction = prediction_data["outputLabel"]
205     # Pull the class probabilities.
206     probs = prediction_data["outputMulti"]
207     puts probs
208     # Verify we are getting a value result.
209     puts ["ERROR", input_features].join("\t")  if probs.nil?
210     return "error", -1.0 if probs.nil?
211
212     # Extract the score for the most likely class.
213     trueclass_score = probs.select{|hash|
214       hash["label"] ==  prediction
215     }[0]["score"]
216
217     # Regression.
218   else
219     prediction = prediction_data["outputValue"]
220     # Class core unused.
221     trueclass_score = -1
222   end
223
224   puts [prediction,trueclass_score,input_features].join("\t") 
225   return prediction,trueclass_score
226 end
227