2 # -*- coding: utf-8 -*-
4 # Copyright:: Copyright 2011 Google Inc.
5 # License:: All Rights Reserved.
6 # Original Author:: Bob Aman, Winton Davies, Robert Kaplow
7 # Maintainer:: Robert Kaplow (mailto:rkaplow@google.com)
12 require 'google/api_client'
15 use Rack::Session::Pool, :expire_after => 86400 # 1 day
17 # Set up our token store
18 DataMapper.setup(:default, 'sqlite::memory:')
20 include DataMapper::Resource
23 property :refresh_token, String
24 property :access_token, String
25 property :expires_in, Integer
26 property :issued_at, Integer
# Copies the four token fields (refresh_token, access_token, expires_in,
# issued_at) from the given object onto this record.
#
# @param [Object] object Any object that responds to refresh_token,
#   access_token, expires_in and issued_at.
28 def update_token!(object)
29 self.refresh_token = object.refresh_token
30 self.access_token = object.access_token
31 self.expires_in = object.expires_in
32 self.issued_at = object.issued_at
37 :refresh_token => refresh_token,
38 :access_token => access_token,
39 :expires_in => expires_in,
40 :issued_at => Time.at(issued_at)
44 TokenPair.auto_migrate!
# Builds the Google API client: loads the OAuth settings from YAML, applies
# the token pair referenced by the session, refreshes an expired token, and
# redirects to /oauth2authorize when no access token is available.
48 # FILL IN THIS SECTION
49 # This will work if your yaml file is stored as ./.google-api.yaml
50 # ------------------------
51 oauth_yaml = YAML.load_file('.google-api.yaml')
52 @client = Google::APIClient.new
53 @client.authorization.client_id = oauth_yaml["client_id"]
54 @client.authorization.client_secret = oauth_yaml["client_secret"]
55 @client.authorization.scope = oauth_yaml["scope"]
56 @client.authorization.refresh_token = oauth_yaml["refresh_token"]
57 @client.authorization.access_token = oauth_yaml["access_token"]
58 # -----------------------
60 @client.authorization.redirect_uri = to('/oauth2callback')
62 # Workaround for now as expires_in may be nil, but when converted to int it becomes 0.
63 @client.authorization.expires_in = 1800 if @client.authorization.expires_in.to_i == 0
66 # Load the access token here if it's available
67 token_pair = TokenPair.get(session[:token_id])
68 @client.authorization.update_token!(token_pair.to_hash)
# A refresh is only attempted when a refresh token exists and the current
# token reports itself as expired.
70 if @client.authorization.refresh_token && @client.authorization.expired?
71 @client.authorization.fetch_access_token!
# Resolve the 'prediction' API, version 'v1.3', through discovered_api.
75 @prediction = @client.discovered_api('prediction', 'v1.3')
# Any path that does not start with /oauth2 needs an access token; without
# one the browser is sent to the authorization route.
76 unless @client.authorization.access_token || request.path_info =~ /^\/oauth2/
77 redirect to('/oauth2authorize')
# Redirects the browser (HTTP 303) to the authorization URI produced by
# @client.authorization.
81 get '/oauth2authorize' do
82 redirect @client.authorization.authorization_uri.to_s, 303
# Fetches an access token through @client.authorization, copies the
# resulting token fields onto a TokenPair, and records that TokenPair's id in
# the session.
85 get '/oauth2callback' do
86 @client.authorization.fetch_access_token!
87 # Persist the token here
# Reuse the TokenPair already referenced by the session when one is set.
88 token_pair = if session[:token_id]
89 TokenPair.get(session[:token_id])
93 token_pair.update_token!(@client.authorization)
# Remember which TokenPair belongs to this session.
95 session[:token_id] = token_pair.id
# Checks that training for datafile has completed and, if so, requests a
# prediction for a sample input.
101 # ----------------------------------------
# Google Storage path of the training data in "bucket/object" form;
# NOTE(review): BUCKET/OBJECT looks like a placeholder to replace - confirm.
102 datafile = "BUCKET/OBJECT"
103 # ----------------------------------------
104 # Train a predictive model.
106 # Check to make sure the training has completed.
107 if (is_done?(datafile))
109 # FILL IN DESIRED INPUT:
110 # -------------------------------------------------------------------------------
111 # Note, the input features should match the features of the dataset.
112 prediction,score = get_prediction(datafile, ["Alice noticed with some surprise."])
113 # -------------------------------------------------------------------------------
115 # We currently just dump the results to output, but you can display them on the page if desired.
122 # Trains a predictive model.
124 # @param [String] datafile The name of the file in Google Storage. NOTE: this does *not*
125 # include the gs:// part. If the Google Storage path is gs://bucket/object,
126 # then the correct string is "bucket/object"
# NOTE(review): datafile is interpolated into the JSON body without escaping;
# confirm it can never contain quotes or backslashes.
128 input = "{\"id\" : \"#{datafile}\"}"
129 puts "training input: #{input}"
# Submit the training request with the hand-built JSON as the request body.
130 result = @client.execute(:api_method => @prediction.training.insert,
131 :merged_body => input,
132 :headers => {'Content-Type' => 'application/json'}
134 status, headers, body = result.response
138 # Returns the current training status.
140 # @param [String] datafile The name of the file in Google Storage. NOTE: this does *not*
141 # include the gs:// part. If the Google Storage path is gs://bucket/object,
142 # then the correct string is "bucket/object"
143 # @return [Integer] status The HTTP status code of the training job.
144 def get_training_status(datafile)
# Query the training resource, passing datafile as the 'data' parameter.
145 result = @client.execute(:api_method => @prediction.training.get,
146 :parameters => {'data' => datafile})
# result.response is unpacked as status, headers and body.
147 status, headers, body = result.response
153 # Checks the training status repeatedly, returning true once the status is 200.
# NOTE(review): the loop below is guarded by test_counter < 10, so polling
# appears to be bounded rather than endless - confirm.
155 # @param [String] datafile The name of the file in Google Storage. NOTE: this does *not*
156 # include the gs:// part. If the Google Storage path is gs://bucket/object,
157 # then the correct string is "bucket/object"
158 # @return [Boolean] exists True if model exists and can be used for predictions.
160 def is_done?(datafile)
161 status = get_training_status(datafile)
162 # The wait between checks grows linearly: 5 * (test_counter + 1) seconds.
164 while test_counter < 10 do
165 puts "Attempting to check model #{datafile} - Status: #{status} "
166 return true if status == 200
167 sleep 5 * (test_counter + 1)
168 status = get_training_status(datafile)
177 # Returns the prediction and, for categorization, the score of the most likely class.
179 # @param [String] datafile The name of the file in Google Storage. NOTE: this does *not*
180 # include the gs:// part. If the Google Storage path is gs://bucket/object,
181 # then the correct string is "bucket/object"
182 # @param [Array] input_features A list of input features.
184 # @return [String or Double] prediction The returned prediction, String if categorization,
185 # Double if regression
186 # @return [Double] trueclass_score The numeric score of the most likely label. (Categorical only).
188 def get_prediction(datafile,input_features)
189 # We take the input features and put them in the right input (json) format.
# NOTE(review): input_features is interpolated using Ruby's own string
# conversion rather than a JSON encoder - confirm the result is valid JSON for
# every input (quotes, non-ASCII text, older Ruby versions).
190 input="{\"input\" : { \"csvInstance\" : #{input_features}}}"
191 puts "Prediction Input: #{input}"
192 result = @client.execute(:api_method => @prediction.training.predict,
193 :parameters => {'data' => datafile},
194 :merged_body => input,
195 :headers => {'Content-Type' => 'application/json'})
196 status, headers, body = result.response
197 prediction_data = result.data
# A present outputLabel is handled as a categorical result: take the label and
# look up its score among the class probabilities.
202 if prediction_data["outputLabel"] != nil
203 # Pull the most likely label.
204 prediction = prediction_data["outputLabel"]
205 # Pull the class probabilities.
206 probs = prediction_data["outputMulti"]
208 # Verify we are getting a valid result; otherwise log the input and return the pair "error", -1.0.
209 puts ["ERROR", input_features].join("\t") if probs.nil?
210 return "error", -1.0 if probs.nil?
212 # Extract the score for the most likely class.
213 trueclass_score = probs.select{|hash|
214 hash["label"] == prediction
# NOTE(review): presumably the non-categorical (regression) path, where
# outputValue is used as the prediction - confirm.
219 prediction = prediction_data["outputValue"]
# Log the tab-separated result and hand both values back to the caller.
224 puts [prediction,trueclass_score,input_features].join("\t")
225 return prediction,trueclass_score