Merge pull request #2 from robertkaplow/master
[arvados.git] / examples / prediction / prediction.rb
1 #!/usr/bin/ruby1.8 
2 # -*- coding: utf-8 -*-
3
4 # Copyright:: Copyright 2011 Google Inc.
5 # License:: All Rights Reserved.
6 # Original Author:: Bob Aman, Winton Davies, Robert Kaplow
7 # Maintainer:: Robert Kaplow (mailto:rkaplow@google.com)
8
9 $:.unshift('lib')
10 require 'rubygems'
11 require 'sinatra'
12 require 'datamapper'
13 require 'google/api_client'
14 require 'yaml'
15
16 use Rack::Session::Pool, :expire_after => 86400 # 1 day
17
18 # Set up our token store
19 DataMapper.setup(:default, 'sqlite::memory:')
20 class TokenPair
21   include DataMapper::Resource
22
23   property :id, Serial
24   property :refresh_token, String
25   property :access_token, String
26   property :expires_in, Integer
27   property :issued_at, Integer
28
29   def update_token!(object)
30     self.refresh_token = object.refresh_token
31     self.access_token = object.access_token
32     self.expires_in = object.expires_in
33     self.issued_at = object.issued_at
34   end
35
36   def to_hash
37     return {
38       :refresh_token => refresh_token,
39       :access_token => access_token,
40       :expires_in => expires_in,
41       :issued_at => Time.at(issued_at)
42     }
43   end
44 end
45 TokenPair.auto_migrate!
46
47 before do
48
49   # FILL IN THIS SECTION
50   # This will work if your yaml file is stored as ./google-api.yaml
51   # ------------------------
52   oauth_yaml = YAML.load_file('.google-api.yaml')
53   @client = Google::APIClient.new
54   @client.authorization.client_id = oauth_yaml["client_id"]
55   @client.authorization.client_secret = oauth_yaml["client_secret"]
56   @client.authorization.scope = oauth_yaml["scope"]
57   @client.authorization.refresh_token = oauth_yaml["refresh_token"]
58   @client.authorization.access_token = oauth_yaml["access_token"]
59   # -----------------------
60
61   @client.authorization.redirect_uri = to('/oauth2callback')
62
63   # Workaround for now as expires_in may be nil, but when converted to int it becomes 0.
64   @client.authorization.expires_in = Time.now + 1800 if @client.authorization.expires_in.to_i == 0
65
66   if session[:token_id]
67     # Load the access token here if it's available
68     token_pair = TokenPair.get(session[:token_id])
69     @client.authorization.update_token!(token_pair.to_hash)
70   end
71   if @client.authorization.refresh_token && @client.authorization.expired?
72     @client.authorization.fetch_access_token!
73   end
74
75
76   @prediction = @client.discovered_api('prediction', 'v1.3')
77   unless @client.authorization.access_token || request.path_info =~ /^\/oauth2/
78     redirect to('/oauth2authorize')
79   end
80 end
81
82 get '/oauth2authorize' do
83   redirect @client.authorization.authorization_uri.to_s, 303
84 end
85
86 get '/oauth2callback' do
87   @client.authorization.fetch_access_token!
88   # Persist the token here
89   token_pair = if session[:token_id]
90     TokenPair.get(session[:token_id])
91   else
92     TokenPair.new
93   end
94   token_pair.update_token!(@client.authorization)
95   token_pair.save()
96   session[:token_id] = token_pair.id
97   redirect to('/')
98 end
99
100 get '/' do
101   # FILL IN DATAFILE:
102   # ----------------------------------------
103   datafile = "BUCKET/OBJECT"
104   # ----------------------------------------
105   # Train a predictive model.
106   train(datafile)
107   # Check to make sure the training has completed.
108   if (is_done?(datafile))
109     # Do a prediction.
110     # FILL IN DESIRED INPUT:
111     # -------------------------------------------------------------------------------
112     prediction,score = get_prediction(datafile, ["Alice noticed with some surprise."])
113     # -------------------------------------------------------------------------------
114
115     # We currently just dump the results to output, but you can display them on the page if desired.
116     puts prediction
117     puts score    
118   end
119 end
120
121 ##
122 # Trains a predictive model.
123 #
124 # @param [String] filename The name of the file in Google Storage. NOTE: this do *not*
125 #                 include the gs:// part. If the Google Storage path is gs://bucket/object,
126 #                 then the correct string is "bucket/object"
127 def train(datafile)
128   input = "{\"id\" : \"#{datafile}\"}"
129   puts "training input: #{input}"
130   status, headers, body = @client.execute(@prediction.training.insert,
131                                           {},
132                                           input,
133                                           {'Content-Type' => 'application/json'})
134 end
135
136 ##
137 # Returns the current training status
138 #
139 # @param [String] filename The name of the file in Google Storage. NOTE: this do *not*
140 #                 include the gs:// part. If the Google Storage path is gs://bucket/object,
141 #                 then the correct string is "bucket/object"
142 # @return [Integer] status The HTTP status code of the training job.
143 def get_training_status(datafile)
144   status, headers, body = @client.execute(@prediction.training.get,
145                                           {'data' => datafile})
146   return status
147 end
148
149
150 ##
151 # Checks the training status until a model exists (will loop forever).
152 #
153 # @param [String] filename The name of the file in Google Storage. NOTE: this do *not*
154 #                 include the gs:// part. If the Google Storage path is gs://bucket/object,
155 #                 then the correct string is "bucket/object"
156 # @return [Bool] exists True if model exists and can be used for predictions.
157
158 def is_done?(datafile)
159   status = get_training_status(datafile)
160   while true do
161     puts "Attempting to check model #{datafile} - Status: #{status} "
162     return true if status == 200
163     sleep 10
164     status = get_training_status(datafile)
165   end
166   return false
167 end
168
169
170
171 ##
172 # Returns the prediction and most most likely class score if categorization.
173 #
174 # @param [String] filename The name of the file in Google Storage. NOTE: this do *not*
175 #                 include the gs:// part. If the Google Storage path is gs://bucket/object,
176 #                 then the correct string is "bucket/object"
177 # @param [List] input_features A list of input features.
178 #
179 # @return [String or Double] prediction The returned prediction, String if categorization,
180 #                            Double if regression
181 # @return [Double] trueclass_score The numeric score of the most likely label. (Categorical only).
182
183 def get_prediction(datafile,input_features)
184   # We take the input features and put it in the right input (json) format.
185   input="{\"input\" : { \"csvInstance\" :  #{input_features}}}"
186   puts "Prediction Input: #{input}"
187   status, headers, body = @client.execute(@prediction.training.predict,
188                                                      {'data' => datafile},
189                                                      input,
190                                                      {'Content-Type' => 'application/json'})
191   prediction_data = JSON.parse(body[0])
192   
193   # Categorical
194   if prediction_data["outputLabel"] != nil
195     # Pull the most likely label.
196     prediction = prediction_data["outputLabel"]
197     # Pull the class probabilities.
198     probs = prediction_data["outputMulti"]
199     puts probs
200     # Verify we are getting a value result.
201     puts ["ERROR", input_features].join("\t")  if probs.nil?
202     return "error", -1.0 if probs.nil?
203
204     # Extract the score for the most likely class.
205     trueclass_score = probs.select{|hash|
206       hash["label"] ==  prediction
207     }[0]["score"]
208
209     # Regression.
210   else
211     prediction = prediction_data["outputValue"]
212     # Class core unused.
213     trueclass_score = -1
214   end
215
216   puts [prediction,trueclass_score,input_features].join("\t") 
217   return prediction,trueclass_score
218 end
219