33aa098845f4f45561f2768b14e3cd0c17ae9131
[arvados.git] / sdk / cwl / arvados_cwl / http.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from __future__ import division
6 from future import standard_library
7 standard_library.install_aliases()
8
9 import requests
10 import email.utils
11 import time
12 import datetime
13 import re
14 import arvados
15 import arvados.collection
16 import urllib.parse
17 import logging
18 import calendar
19 import urllib.parse
20
21 logger = logging.getLogger('arvados.cwl-runner')
22
23 def my_formatdate(dt):
24     return email.utils.formatdate(timeval=calendar.timegm(dt.timetuple()),
25                                   localtime=False, usegmt=True)
26
27 def my_parsedate(text):
28     parsed = email.utils.parsedate_tz(text)
29     if parsed:
30         if parsed[9]:
31             # Adjust to UTC
32             return datetime.datetime(*parsed[:6]) + datetime.timedelta(seconds=parsed[9])
33         else:
34             # TZ is zero or missing, assume UTC.
35             return datetime.datetime(*parsed[:6])
36     else:
37         return datetime.datetime(1970, 1, 1)
38
39 def fresh_cache(url, properties, now):
40     pr = properties[url]
41     expires = None
42
43     logger.debug("Checking cache freshness for %s using %s", url, pr)
44
45     if "Cache-Control" in pr:
46         if re.match(r"immutable", pr["Cache-Control"]):
47             return True
48
49         g = re.match(r"(s-maxage|max-age)=(\d+)", pr["Cache-Control"])
50         if g:
51             expires = my_parsedate(pr["Date"]) + datetime.timedelta(seconds=int(g.group(2)))
52
53     if expires is None and "Expires" in pr:
54         expires = my_parsedate(pr["Expires"])
55
56     if expires is None:
57         # Use a default cache time of 24 hours if upstream didn't set
58         # any cache headers, to reduce redundant downloads.
59         expires = my_parsedate(pr["Date"]) + datetime.timedelta(hours=24)
60
61     if not expires:
62         return False
63
64     return (now < expires)
65
66 def remember_headers(url, properties, headers, now):
67     properties.setdefault(url, {})
68     for h in ("Cache-Control", "ETag", "Expires", "Date", "Content-Length"):
69         if h in headers:
70             properties[url][h] = headers[h]
71     if "Date" not in headers:
72         properties[url]["Date"] = my_formatdate(now)
73
74
75 def changed(url, properties, now):
76     req = requests.head(url, allow_redirects=True)
77     remember_headers(url, properties, req.headers, now)
78
79     if req.status_code != 200:
80         # Sometimes endpoints are misconfigured and will deny HEAD but
81         # allow GET so instead of failing here, we'll try GET If-None-Match
82         return True
83
84     pr = properties[url]
85     if "ETag" in pr and "ETag" in req.headers:
86         if pr["ETag"] == req.headers["ETag"]:
87             return False
88
89     return True
90
91 def etag_quote(etag):
92     # if it already has leading and trailing quotes, do nothing
93     if etag[0] == '"' and etag[-1] == '"':
94         return etag
95     else:
96         # Add quotes.
97         return '"' + etag + '"'
98
99 def http_to_keep(api, project_uuid, url, utcnow=datetime.datetime.utcnow):
100     r = api.collections().list(filters=[["properties", "exists", url]]).execute()
101
102     now = utcnow()
103
104     etags = {}
105
106     for item in r["items"]:
107         properties = item["properties"]
108         if fresh_cache(url, properties, now):
109             # Do nothing
110             cr = arvados.collection.CollectionReader(item["portable_data_hash"], api_client=api)
111             return "keep:%s/%s" % (item["portable_data_hash"], list(cr.keys())[0])
112
113         if not changed(url, properties, now):
114             # ETag didn't change, same content, just update headers
115             api.collections().update(uuid=item["uuid"], body={"collection":{"properties": properties}}).execute()
116             cr = arvados.collection.CollectionReader(item["portable_data_hash"], api_client=api)
117             return "keep:%s/%s" % (item["portable_data_hash"], list(cr.keys())[0])
118
119         if "ETag" in properties and len(properties["ETag"]) > 2:
120             etags[properties["ETag"]] = item
121
122     properties = {}
123     headers = {}
124     if etags:
125         headers['If-None-Match'] = ', '.join([etag_quote(k) for k,v in etags.items()])
126     req = requests.get(url, stream=True, allow_redirects=True, headers=headers)
127
128     if req.status_code not in (200, 304):
129         raise Exception("Failed to download '%s' got status %s " % (url, req.status_code))
130
131     remember_headers(url, properties, req.headers, now)
132
133     if req.status_code == 304 and "ETag" in req.headers and req.headers["ETag"] in etags:
134         item = etags[req.headers["ETag"]]
135         item["properties"].update(properties)
136         api.collections().update(uuid=item["uuid"], body={"collection":{"properties": item["properties"]}}).execute()
137         cr = arvados.collection.CollectionReader(item["portable_data_hash"], api_client=api)
138         return "keep:%s/%s" % (item["portable_data_hash"], list(cr.keys())[0])
139
140     if "Content-Length" in properties[url]:
141         cl = int(properties[url]["Content-Length"])
142         logger.info("Downloading %s (%s bytes)", url, cl)
143     else:
144         cl = None
145         logger.info("Downloading %s (unknown size)", url)
146
147     c = arvados.collection.Collection()
148
149     if req.headers.get("Content-Disposition"):
150         grp = re.search(r'filename=("((\"|[^"])+)"|([^][()<>@,;:\"/?={} ]+))', req.headers["Content-Disposition"])
151         if grp.group(2):
152             name = grp.group(2)
153         else:
154             name = grp.group(4)
155     else:
156         name = urllib.parse.urlparse(url).path.split("/")[-1]
157
158     count = 0
159     start = time.time()
160     checkpoint = start
161     with c.open(name, "wb") as f:
162         for chunk in req.iter_content(chunk_size=1024):
163             count += len(chunk)
164             f.write(chunk)
165             loopnow = time.time()
166             if (loopnow - checkpoint) > 20:
167                 bps = count / (loopnow - start)
168                 if cl is not None:
169                     logger.info("%2.1f%% complete, %3.2f MiB/s, %1.0f seconds left",
170                                 ((count * 100) / cl),
171                                 (bps // (1024*1024)),
172                                 ((cl-count) // bps))
173                 else:
174                     logger.info("%d downloaded, %3.2f MiB/s", count, (bps / (1024*1024)))
175                 checkpoint = loopnow
176
177     logger.info("Download complete")
178
179     collectionname = "Downloaded from %s" % urllib.parse.quote(url, safe='')
180
181     # max length - space to add a timestamp used by ensure_unique_name
182     max_name_len = 254 - 28
183
184     if len(collectionname) > max_name_len:
185         over = len(collectionname) - max_name_len
186         split = int(max_name_len/2)
187         collectionname = collectionname[0:split] + "…" + collectionname[split+over:]
188
189     c.save_new(name=collectionname, owner_uuid=project_uuid, ensure_unique_name=True)
190
191     api.collections().update(uuid=c.manifest_locator(), body={"collection":{"properties": properties}}).execute()
192
193     return "keep:%s/%s" % (c.portable_data_hash(), name)