19699: Report download done, don't try to stage deferred downloads
[arvados.git] / sdk / cwl / arvados_cwl / http.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from __future__ import division
6 from future import standard_library
7 standard_library.install_aliases()
8
9 import requests
10 import email.utils
11 import time
12 import datetime
13 import re
14 import arvados
15 import arvados.collection
16 import urllib.parse
17 import logging
18 import calendar
19 import urllib.parse
20
21 logger = logging.getLogger('arvados.cwl-runner')
22
23 def my_formatdate(dt):
24     return email.utils.formatdate(timeval=calendar.timegm(dt.timetuple()),
25                                   localtime=False, usegmt=True)
26
27 def my_parsedate(text):
28     parsed = email.utils.parsedate_tz(text)
29     if parsed:
30         if parsed[9]:
31             # Adjust to UTC
32             return datetime.datetime(*parsed[:6]) + datetime.timedelta(seconds=parsed[9])
33         else:
34             # TZ is zero or missing, assume UTC.
35             return datetime.datetime(*parsed[:6])
36     else:
37         return datetime.datetime(1970, 1, 1)
38
39 def fresh_cache(url, properties, now):
40     pr = properties[url]
41     expires = None
42
43     logger.debug("Checking cache freshness for %s using %s", url, pr)
44
45     if "Cache-Control" in pr:
46         if re.match(r"immutable", pr["Cache-Control"]):
47             return True
48
49         g = re.match(r"(s-maxage|max-age)=(\d+)", pr["Cache-Control"])
50         if g:
51             expires = my_parsedate(pr["Date"]) + datetime.timedelta(seconds=int(g.group(2)))
52
53     if expires is None and "Expires" in pr:
54         expires = my_parsedate(pr["Expires"])
55
56     if expires is None:
57         # Use a default cache time of 24 hours if upstream didn't set
58         # any cache headers, to reduce redundant downloads.
59         expires = my_parsedate(pr["Date"]) + datetime.timedelta(hours=24)
60
61     if not expires:
62         return False
63
64     return (now < expires)
65
66 def remember_headers(url, properties, headers, now):
67     properties.setdefault(url, {})
68     for h in ("Cache-Control", "ETag", "Expires", "Date", "Content-Length"):
69         if h in headers:
70             properties[url][h] = headers[h]
71     if "Date" not in headers:
72         properties[url]["Date"] = my_formatdate(now)
73
74
75 def changed(url, properties, now):
76     req = requests.head(url, allow_redirects=True)
77     remember_headers(url, properties, req.headers, now)
78
79     if req.status_code != 200:
80         # Sometimes endpoints are misconfigured and will deny HEAD but
81         # allow GET so instead of failing here, we'll try GET If-None-Match
82         return True
83
84     pr = properties[url]
85     if "ETag" in pr and "ETag" in req.headers:
86         if pr["ETag"] == req.headers["ETag"]:
87             return False
88
89     return True
90
91 def http_to_keep(api, project_uuid, url, utcnow=datetime.datetime.utcnow):
92     r = api.collections().list(filters=[["properties", "exists", url]]).execute()
93
94     now = utcnow()
95
96     etags = {}
97
98     for item in r["items"]:
99         properties = item["properties"]
100         if fresh_cache(url, properties, now):
101             # Do nothing
102             cr = arvados.collection.CollectionReader(item["portable_data_hash"], api_client=api)
103             return "keep:%s/%s" % (item["portable_data_hash"], list(cr.keys())[0])
104
105         if not changed(url, properties, now):
106             # ETag didn't change, same content, just update headers
107             api.collections().update(uuid=item["uuid"], body={"collection":{"properties": properties}}).execute()
108             cr = arvados.collection.CollectionReader(item["portable_data_hash"], api_client=api)
109             return "keep:%s/%s" % (item["portable_data_hash"], list(cr.keys())[0])
110
111         if "ETag" in properties:
112             etags[properties["ETag"]] = item
113
114     properties = {}
115     headers = {}
116     if etags:
117         headers['If-None-Match'] = ', '.join(['"%s"' % k for k,v in etags.items()])
118     req = requests.get(url, stream=True, allow_redirects=True, headers=headers)
119
120     if req.status_code not in (200, 304):
121         raise Exception("Failed to download '%s' got status %s " % (url, req.status_code))
122
123     remember_headers(url, properties, req.headers, now)
124
125     if req.status_code == 304 and "ETag" in req.headers and req.headers["ETag"] in etags:
126         item = etags[req.headers["ETag"]]
127         item["properties"].update(properties)
128         api.collections().update(uuid=item["uuid"], body={"collection":{"properties": item["properties"]}}).execute()
129         cr = arvados.collection.CollectionReader(item["portable_data_hash"], api_client=api)
130         return "keep:%s/%s" % (item["portable_data_hash"], list(cr.keys())[0])
131
132     if "Content-Length" in properties[url]:
133         cl = int(properties[url]["Content-Length"])
134         logger.info("Downloading %s (%s bytes)", url, cl)
135     else:
136         cl = None
137         logger.info("Downloading %s (unknown size)", url)
138
139     c = arvados.collection.Collection()
140
141     if req.headers.get("Content-Disposition"):
142         grp = re.search(r'filename=("((\"|[^"])+)"|([^][()<>@,;:\"/?={} ]+))', req.headers["Content-Disposition"])
143         if grp.group(2):
144             name = grp.group(2)
145         else:
146             name = grp.group(4)
147     else:
148         name = urllib.parse.urlparse(url).path.split("/")[-1]
149
150     count = 0
151     start = time.time()
152     checkpoint = start
153     with c.open(name, "wb") as f:
154         for chunk in req.iter_content(chunk_size=1024):
155             count += len(chunk)
156             f.write(chunk)
157             loopnow = time.time()
158             if (loopnow - checkpoint) > 20:
159                 bps = count / (loopnow - start)
160                 if cl is not None:
161                     logger.info("%2.1f%% complete, %3.2f MiB/s, %1.0f seconds left",
162                                 ((count * 100) / cl),
163                                 (bps // (1024*1024)),
164                                 ((cl-count) // bps))
165                 else:
166                     logger.info("%d downloaded, %3.2f MiB/s", count, (bps / (1024*1024)))
167                 checkpoint = loopnow
168
169     logger.info("Download complete")
170
171     collectionname = "Downloaded from %s" % urllib.parse.quote(url, safe='')
172
173     # max length - space to add a timestamp used by ensure_unique_name
174     max_name_len = 254 - 28
175
176     if len(collectionname) > max_name_len:
177         over = len(collectionname) - max_name_len
178         split = int(max_name_len/2)
179         collectionname = collectionname[0:split] + "…" + collectionname[split+over:]
180
181     c.save_new(name=collectionname, owner_uuid=project_uuid, ensure_unique_name=True)
182
183     api.collections().update(uuid=c.manifest_locator(), body={"collection":{"properties": properties}}).execute()
184
185     return "keep:%s/%s" % (c.portable_data_hash(), name)