Merge branch '18004-cached-token-race-condition' into main. Closes #18004
[arvados.git] / sdk / cwl / arvados_cwl / http.py
1 # Copyright (C) The Arvados Authors. All rights reserved.
2 #
3 # SPDX-License-Identifier: Apache-2.0
4
5 from __future__ import division
6 from future import standard_library
7 standard_library.install_aliases()
8
9 import requests
10 import email.utils
11 import time
12 import datetime
13 import re
14 import arvados
15 import arvados.collection
16 import urllib.parse
17 import logging
18 import calendar
19 import urllib.parse
20
# Log under the arvados-cwl-runner logger name (not this module's __name__).
logger = logging.getLogger("arvados.cwl-runner")
22
def my_formatdate(dt):
    """Render a naive datetime, interpreted as UTC, as an HTTP date string.

    ``calendar.timegm`` treats the fields of ``dt`` as UTC. Sub-second
    precision is dropped. The result uses the ``GMT`` suffix, e.g.
    ``Wed, 01 Jan 2020 10:00:00 GMT``.
    """
    epoch_seconds = calendar.timegm(dt.timetuple())
    return email.utils.formatdate(epoch_seconds, localtime=False, usegmt=True)
26
def my_parsedate(text):
    """Parse an RFC 2822 / HTTP date string into a naive UTC datetime.

    A numeric zone offset in ``text`` is removed so the result is in UTC. A
    zero or missing offset is taken to mean the time is already UTC.

    Text that ``email.utils.parsedate_tz`` cannot parse yields the Unix epoch
    (1970-01-01 00:00:00), as do parsed fields that are out of range for
    ``datetime``. The epoch compares as far in the past.
    """
    epoch = datetime.datetime(1970, 1, 1)
    parsed = email.utils.parsedate_tz(text)
    if not parsed:
        return epoch
    try:
        local = datetime.datetime(*parsed[:6])
        # parsed[9] is the offset in seconds *east* of UTC
        # (e.g. +0100 -> 3600), so UTC = local time - offset.
        return local - datetime.timedelta(seconds=parsed[9] or 0)
    except (ValueError, OverflowError):
        # Syntactically parsed but not a representable date (e.g. day 32).
        return epoch
38
def fresh_cache(url, properties, now):
    """Decide whether the cached response headers for ``url`` are still fresh.

    ``properties[url]`` is a dict of response headers (as recorded by
    ``remember_headers``; it must contain ``Date``). The first rule that
    applies decides freshness:

    * ``Cache-Control`` containing ``immutable``: always fresh.
    * ``Cache-Control`` containing ``max-age=N`` or ``s-maxage=N``: fresh
      until ``Date + N`` seconds.
    * ``Expires``: fresh until that time.
    * Otherwise: fresh for 24 hours after ``Date``.

    :param url: URL whose cached headers are checked.
    :param properties: mapping of URL -> recorded header dict.
    :param now: current time as a naive UTC datetime.
    :returns: True if the cached copy may be used without revalidation.
    """
    pr = properties[url]
    expires = None

    logger.debug("Checking cache freshness for %s using %s", url, pr)

    cache_control = pr.get("Cache-Control")
    if cache_control:
        # Cache-Control is a comma-separated list of directives in any order
        # (e.g. "public, max-age=3600"), so search the whole value instead of
        # only matching at the start.
        if re.search(r"\bimmutable\b", cache_control):
            return True

        g = re.search(r"\b(s-maxage|max-age)=(\d+)", cache_control)
        if g:
            expires = my_parsedate(pr["Date"]) + datetime.timedelta(seconds=int(g.group(2)))

    if expires is None and "Expires" in pr:
        expires = my_parsedate(pr["Expires"])

    if expires is None:
        # Use a default cache time of 24 hours if upstream didn't set
        # any cache headers, to reduce redundant downloads.
        expires = my_parsedate(pr["Date"]) + datetime.timedelta(hours=24)

    return now < expires
65
def remember_headers(url, properties, headers, now):
    """Record the cache-relevant response headers for ``url`` in ``properties``.

    ``properties[url]`` is created if missing. Only Cache-Control, ETag,
    Expires, Date and Content-Length are copied; other headers are ignored.
    If the response carried no Date header, ``now`` (formatted as an HTTP
    date) is stored as the Date instead, so later freshness checks have a
    reference time.
    """
    saved = properties.setdefault(url, {})
    wanted = ("Cache-Control", "ETag", "Expires", "Date", "Content-Length")
    for name in wanted:
        if name in headers:
            saved[name] = headers[name]
    if "Date" not in headers:
        saved["Date"] = my_formatdate(now)
71     if "Date" not in headers:
72         properties[url]["Date"] = my_formatdate(now)
73
74
def changed(url, properties, now):
    """Ask the server (via HEAD) whether the content at ``url`` has changed.

    The ETag previously recorded in ``properties[url]`` is compared with the
    ETag the server returns now. On a 200 response the new headers are
    merged into ``properties[url]`` with ``remember_headers``, so the caller
    can persist the refreshed values.

    :returns: False if a cached ETag exists and equals the server's current
        ETag; True otherwise (content may have changed).
    :raises Exception: if the HEAD request does not return status 200.
    """
    req = requests.head(url, allow_redirects=True)

    if req.status_code != 200:
        raise Exception("Got status %s" % req.status_code)

    # Read the cached ETag *before* remember_headers() overwrites it with the
    # server's current value; comparing afterwards would always match and a
    # changed resource would never be re-downloaded.
    old_etag = properties.get(url, {}).get("ETag")

    remember_headers(url, properties, req.headers, now)

    if old_etag is not None and old_etag == req.headers.get("ETag"):
        return False

    return True
88
def _cached_file_reference(api, portable_data_hash):
    """Return ``keep:<pdh>/<name>`` for the first top-level entry of a collection."""
    cr = arvados.collection.CollectionReader(portable_data_hash, api_client=api)
    return "keep:%s/%s" % (portable_data_hash, list(cr.keys())[0])


def _download_filename(url, headers):
    """Choose the file name under which a download is stored.

    Prefer the ``filename`` parameter of a Content-Disposition header. Fall
    back to the last path component of ``url`` when that header is absent or
    has no parseable ``filename=`` parameter.
    """
    disposition = headers.get("Content-Disposition")
    if disposition:
        grp = re.search(r'filename=("((\"|[^"])+)"|([^][()<>@,;:\"/?={} ]+))', disposition)
        if grp:
            # group(2) is the quoted form, group(4) the bare-token form.
            return grp.group(2) if grp.group(2) else grp.group(4)
    return urllib.parse.urlparse(url).path.split("/")[-1]


def _log_download_progress(count, cl, elapsed):
    """Log the bytes downloaded so far and the average transfer rate.

    :param count: bytes received so far.
    :param cl: expected total size from Content-Length, or None if unknown.
    :param elapsed: seconds since the download started (> 0 at every call site).
    """
    bps = count / elapsed
    mibps = bps / (1024 * 1024)
    if cl is not None:
        percent = (count * 100) / cl if cl else 100.0
        # Avoid dividing by a zero rate, and don't report negative time if
        # more bytes arrive than Content-Length announced.
        seconds_left = max(cl - count, 0) / bps if bps else 0
        logger.info("%2.1f%% complete, %3.2f MiB/s, %1.0f seconds left",
                    percent, mibps, seconds_left)
    else:
        logger.info("%d downloaded, %3.2f MiB/s", count, mibps)


def http_to_keep(api, project_uuid, url, utcnow=datetime.datetime.utcnow):
    """Return a ``keep:`` reference to the content of ``url``, downloading if needed.

    Collections whose properties contain ``url`` act as a download cache.
    A cached collection is reused when either of these holds:

    * its recorded headers are still fresh (``fresh_cache``);
    * a HEAD request shows the ETag is unchanged (``changed``). In that case
      the collection's stored headers are updated.

    Otherwise the URL is fetched and saved as a new collection owned by
    ``project_uuid``, with the response headers stored in its properties.

    :param api: Arvados API client.
    :param project_uuid: owner of the new collection if a download happens.
    :param url: URL to fetch.
    :param utcnow: callable returning the current naive UTC datetime
        (overridable for testing).
    :returns: ``"keep:<portable data hash>/<file name>"``.
    :raises Exception: if the GET (or a revalidating HEAD) does not return 200.
    """
    r = api.collections().list(filters=[["properties", "exists", url]]).execute()

    now = utcnow()

    for item in r["items"]:
        properties = item["properties"]
        if fresh_cache(url, properties, now):
            # Recorded headers say the cached copy is still fresh.
            return _cached_file_reference(api, item["portable_data_hash"])

        if not changed(url, properties, now):
            # ETag didn't change, same content, just update headers
            api.collections().update(uuid=item["uuid"], body={"collection":{"properties": properties}}).execute()
            return _cached_file_reference(api, item["portable_data_hash"])

    properties = {}
    req = requests.get(url, stream=True, allow_redirects=True)
    try:
        if req.status_code != 200:
            raise Exception("Failed to download '%s' got status %s " % (url, req.status_code))

        remember_headers(url, properties, req.headers, now)

        if "Content-Length" in properties[url]:
            cl = int(properties[url]["Content-Length"])
            logger.info("Downloading %s (%s bytes)", url, cl)
        else:
            cl = None
            logger.info("Downloading %s (unknown size)", url)

        c = arvados.collection.Collection()
        name = _download_filename(url, req.headers)

        count = 0
        start = time.time()
        checkpoint = start
        with c.open(name, "wb") as f:
            for chunk in req.iter_content(chunk_size=1024):
                count += len(chunk)
                f.write(chunk)
                loopnow = time.time()
                # Report progress at most every 20 seconds.
                if (loopnow - checkpoint) > 20:
                    _log_download_progress(count, cl, loopnow - start)
                    checkpoint = loopnow
    finally:
        # stream=True keeps the connection open until the body is consumed
        # or the response is closed; release it even if the download fails.
        req.close()

    collectionname = "Downloaded from %s" % urllib.parse.quote(url, safe='')
    c.save_new(name=collectionname, owner_uuid=project_uuid, ensure_unique_name=True)

    api.collections().update(uuid=c.manifest_locator(), body={"collection":{"properties": properties}}).execute()

    return "keep:%s/%s" % (c.portable_data_hash(), name)