Use config parameter to fetch api token in KeepClient
[arvados.git] / sdk / java-v2 / src / main / java / org / arvados / client / logic / keep / KeepClient.java
1 /*
2  * Copyright (C) The Arvados Authors. All rights reserved.
3  *
4  * SPDX-License-Identifier: AGPL-3.0 OR Apache-2.0
5  *
6  */
7
8 package org.arvados.client.logic.keep;
9
10 import com.google.common.collect.Lists;
11 import org.apache.commons.codec.digest.DigestUtils;
12 import org.apache.commons.io.FileUtils;
13 import org.arvados.client.api.client.KeepServicesApiClient;
14 import org.arvados.client.api.model.KeepService;
15 import org.arvados.client.api.model.KeepServiceList;
16 import org.arvados.client.common.Characters;
17 import org.arvados.client.common.Headers;
18 import org.arvados.client.config.ConfigProvider;
19 import org.arvados.client.exception.ArvadosApiException;
20 import org.arvados.client.exception.ArvadosClientException;
21 import org.slf4j.Logger;
22
23 import java.io.File;
24 import java.io.IOException;
25 import java.util.ArrayList;
26 import java.util.HashMap;
27 import java.util.List;
28 import java.util.Map;
29 import java.util.Objects;
30 import java.util.concurrent.CompletableFuture;
31 import java.util.function.Function;
32 import java.util.stream.Collectors;
33 import java.util.stream.Stream;
34
35 public class KeepClient {
36
37     private final KeepServicesApiClient keepServicesApiClient;
38     private final Logger log = org.slf4j.LoggerFactory.getLogger(KeepClient.class);
39     private List<KeepService> keepServices;
40     private List<KeepService> writableServices;
41     private Map<String, KeepService> gatewayServices;
42     private Integer maxReplicasPerService;
43     private final ConfigProvider config;
44
45     public KeepClient(ConfigProvider config) {
46         this.config = config;
47         keepServicesApiClient = new KeepServicesApiClient(config);
48     }
49
50     public byte[] getDataChunk(KeepLocator keepLocator) {
51
52         Map<String, String> headers = new HashMap<>();
53         Map<String, FileTransferHandler> rootsMap = new HashMap<>();
54
55         List<String> sortedRoots = mapNewServices(rootsMap, keepLocator, false, false, headers);
56
57         byte[] dataChunk = sortedRoots
58                 .stream()
59                 .map(rootsMap::get)
60                 .map(r -> r.get(keepLocator))
61                 .filter(Objects::nonNull)
62                 .findFirst()
63                 .orElse(null);
64
65         if (dataChunk == null) {
66             throw new ArvadosClientException("No server responding. Unable to download data chunk.");
67         }
68
69         return dataChunk;
70     }
71
72     public String put(File data, int copies, int numRetries) {
73
74         byte[] fileBytes;
75         try {
76             fileBytes = FileUtils.readFileToByteArray(data);
77         } catch (IOException e) {
78             throw new ArvadosClientException("An error occurred while reading data chunk", e);
79         }
80
81         String dataHash = DigestUtils.md5Hex(fileBytes);
82         String locatorString = String.format("%s+%d", dataHash, data.length());
83
84         if (copies < 1) {
85             return locatorString;
86         }
87         KeepLocator locator = new KeepLocator(locatorString);
88
89         // Tell the proxy how many copies we want it to store
90         Map<String, String> headers = new HashMap<>();
91         headers.put(Headers.X_KEEP_DESIRED_REPLICAS, String.valueOf(copies));
92
93         Map<String, FileTransferHandler> rootsMap = new HashMap<>();
94         List<String> sortedRoots = mapNewServices(rootsMap, locator, false, true, headers);
95
96         int numThreads = 0;
97         if (maxReplicasPerService == null || maxReplicasPerService >= copies) {
98             numThreads = 1;
99         } else {
100             numThreads = ((Double) Math.ceil(1.0 * copies / maxReplicasPerService)).intValue();
101         }
102         log.debug("Pool max threads is {}", numThreads);
103
104         List<CompletableFuture<String>> futures = Lists.newArrayList();
105         for (int i = 0; i < numThreads; i++) {
106             String root = sortedRoots.get(i);
107             FileTransferHandler keepServiceLocal = rootsMap.get(root);
108             futures.add(CompletableFuture.supplyAsync(() -> keepServiceLocal.put(dataHash, data)));
109         }
110
111         @SuppressWarnings("unchecked")
112         CompletableFuture<String>[] array = futures.toArray(new CompletableFuture[0]);
113
114         return Stream.of(array)
115                 .map(CompletableFuture::join)
116                 .reduce((a, b) -> b)
117                 .orElse(null);
118     }
119
120     private List<String> mapNewServices(Map<String, FileTransferHandler> rootsMap, KeepLocator locator,
121                                         boolean forceRebuild, boolean needWritable, Map<String, String> headers) {
122
123         headers.putIfAbsent("Authorization", String.format("OAuth2 %s", config.getApiToken()));
124         List<String> localRoots = weightedServiceRoots(locator, forceRebuild, needWritable);
125         for (String root : localRoots) {
126             FileTransferHandler keepServiceLocal = new FileTransferHandler(root, headers, config);
127             rootsMap.putIfAbsent(root, keepServiceLocal);
128         }
129         return localRoots;
130     }
131
132     /**
133      * Return an array of Keep service endpoints, in the order in which they should be probed when reading or writing
134      * data with the given hash+hints.
135      */
136     private List<String> weightedServiceRoots(KeepLocator locator, boolean forceRebuild, boolean needWritable) {
137
138         buildServicesList(forceRebuild);
139
140         List<String> sortedRoots = new ArrayList<>();
141
142         // Use the services indicated by the given +K@... remote
143         // service hints, if any are present and can be resolved to a
144         // URI.
145         //
146         for (String hint : locator.getHints()) {
147             if (hint.startsWith("K@")) {
148                 if (hint.length() == 7) {
149                     sortedRoots.add(String.format("https://keep.%s.arvadosapi.com/", hint.substring(2)));
150                 } else if (hint.length() == 29) {
151                     KeepService svc = gatewayServices.get(hint.substring(2));
152                     if (svc != null) {
153                         sortedRoots.add(svc.getServiceRoot());
154                     }
155                 }
156             }
157         }
158
159         // Sort the available local services by weight (heaviest first)
160         // for this locator, and return their service_roots (base URIs)
161         // in that order.
162         List<KeepService> useServices = keepServices;
163         if (needWritable) {
164             useServices = writableServices;
165         }
166         anyNonDiskServices(useServices);
167
168         sortedRoots.addAll(useServices
169                 .stream()
170                 .sorted((ks1, ks2) -> serviceWeight(locator.getMd5sum(), ks2.getUuid())
171                         .compareTo(serviceWeight(locator.getMd5sum(), ks1.getUuid())))
172                 .map(KeepService::getServiceRoot)
173                 .collect(Collectors.toList()));
174
175         return sortedRoots;
176     }
177
178     private void buildServicesList(boolean forceRebuild) {
179         if (keepServices != null && !forceRebuild) {
180             return;
181         }
182         KeepServiceList keepServiceList;
183         try {
184             keepServiceList = keepServicesApiClient.accessible();
185         } catch (ArvadosApiException e) {
186             throw new ArvadosClientException("Cannot obtain list of accessible keep services");
187         }
188         // Gateway services are only used when specified by UUID,
189         // so there's nothing to gain by filtering them by
190         // service_type.
191         gatewayServices = keepServiceList.getItems().stream().collect(Collectors.toMap(KeepService::getUuid, Function.identity()));
192
193         if (gatewayServices.isEmpty()) {
194             throw new ArvadosClientException("No gateway services available!");
195         }
196
197         // Precompute the base URI for each service.
198         for (KeepService keepService : gatewayServices.values()) {
199             String serviceHost = keepService.getServiceHost();
200             if (!serviceHost.startsWith("[") && serviceHost.contains(Characters.COLON)) {
201                 // IPv6 URIs must be formatted like http://[::1]:80/...
202                 serviceHost = String.format("[%s]", serviceHost);
203             }
204
205             String protocol = keepService.getServiceSslFlag() ? "https" : "http";
206             String serviceRoot = String.format("%s://%s:%d/", protocol, serviceHost, keepService.getServicePort());
207             keepService.setServiceRoot(serviceRoot);
208         }
209
210         keepServices = gatewayServices.values().stream().filter(ks -> !ks.getServiceType().startsWith("gateway:")).collect(Collectors.toList());
211         writableServices = keepServices.stream().filter(ks -> !ks.getReadOnly()).collect(Collectors.toList());
212
213         // For disk type services, max_replicas_per_service is 1
214         // It is unknown (unlimited) for other service types.
215         if (anyNonDiskServices(writableServices)) {
216             maxReplicasPerService = null;
217         } else {
218             maxReplicasPerService = 1;
219         }
220     }
221
222     private Boolean anyNonDiskServices(List<KeepService> useServices) {
223         return useServices.stream().anyMatch(ks -> !ks.getServiceType().equals("disk"));
224     }
225
226     /**
227      * Compute the weight of a Keep service endpoint for a data block with a known hash.
228      * <p>
229      * The weight is md5(h + u) where u is the last 15 characters of the service endpoint's UUID.
230      */
231     private static String serviceWeight(String dataHash, String serviceUuid) {
232         String shortenedUuid;
233         if (serviceUuid != null && serviceUuid.length() >= 15) {
234             int substringIndex = serviceUuid.length() - 15;
235             shortenedUuid = serviceUuid.substring(substringIndex);
236         } else {
237             shortenedUuid = (serviceUuid == null) ? "" : serviceUuid;
238         }
239         return DigestUtils.md5Hex(dataHash + shortenedUuid);
240     }
241
242 }