11162: Smarter http downloads.
[arvados.git] / sdk / cwl / arvados_cwl / http.py
1 import requests
2 import email.utils
3 import time
4 import datetime
5 import re
6 import arvados
7 import arvados.collection
8 import urlparse
9 import logging
10
11 logger = logging.getLogger('arvados.cwl-runner')
12
def my_formatdate(dt):
    """Format a naive UTC datetime ``dt`` as an RFC 2822 date string (GMT).

    Fixes two defects in the original: it referenced an undefined name
    ``now`` instead of the ``dt`` parameter (NameError on every call), and
    it converted through time.mktime(), which interprets the time tuple in
    *local* time even though callers pass UTC (datetime.utcnow()).
    calendar.timegm() is the UTC counterpart of mktime().
    """
    import calendar
    return email.utils.formatdate(timeval=calendar.timegm(dt.timetuple()),
                                  localtime=False, usegmt=True)
15
def my_parsedate(text):
    """Parse an RFC 2822 date string into a naive datetime.

    Returns the Unix epoch (1970-01-01) for unparseable input. The
    original computed that fallback but never returned it, so callers
    received None and crashed when doing timedelta arithmetic on it.
    """
    parsed = email.utils.parsedate(text)
    if parsed:
        # parsedate() yields a 9-tuple; the first six fields are
        # year..second, exactly what the datetime constructor wants.
        return datetime.datetime(*parsed[:6])
    else:
        return datetime.datetime(1970, 1, 1)
22
def fresh_cache(url, properties):
    """Return True if the copy of ``url`` cached under ``properties`` is
    still fresh according to its stored HTTP caching headers.

    Freshness checks, in order: an ``immutable`` Cache-Control directive,
    a ``max-age``/``s-maxage`` directive relative to the stored Date, an
    Expires header, and finally a 24-hour default from the stored Date.
    ``properties[url]["Date"]`` is expected to exist because
    remember_headers() always stores one.
    """
    pr = properties[url]
    expires = None

    logger.debug("Checking cache freshness for %s using %s", url, pr)

    if "Cache-Control" in pr:
        # re.search, not re.match: Cache-Control is a comma-separated
        # directive list ("public, max-age=3600", "public, immutable"),
        # so an anchored match misses any directive that isn't first.
        if re.search(r"immutable", pr["Cache-Control"]):
            return True

        g = re.search(r"(s-maxage|max-age)=(\d+)", pr["Cache-Control"])
        if g:
            expires = my_parsedate(pr["Date"]) + datetime.timedelta(seconds=int(g.group(2)))

    if expires is None and "Expires" in pr:
        expires = my_parsedate(pr["Expires"])

    if expires is None:
        # Use a default cache time of 24 hours if upstream didn't set
        # any cache headers, to reduce redundant downloads.
        expires = my_parsedate(pr["Date"]) + datetime.timedelta(hours=24)

    if not expires:
        return False

    return (datetime.datetime.utcnow() < expires)
49
def remember_headers(url, properties, headers):
    """Record the cache-relevant response ``headers`` for ``url`` inside
    the ``properties`` dict, creating the per-URL entry if needed.

    Only the headers that matter for cache-freshness decisions are kept.
    A Date is always stored: when the server omitted one, the current UTC
    time is substituted so later max-age arithmetic has a base point.
    """
    entry = properties.setdefault(url, {})
    for field in ("Cache-Control", "ETag", "Expires", "Date", "Content-Length"):
        if field in headers:
            entry[field] = headers[field]
    if "Date" not in headers:
        entry["Date"] = my_formatdate(datetime.datetime.utcnow())
57
58
def changed(url, properties):
    """Issue a HEAD request for ``url`` and report whether its content
    appears to have changed since the headers stored in ``properties``.

    Returns False when both the stored and the fresh response carry an
    identical ETag; True otherwise. Raises Exception on non-200 status.

    Bug fix: the original called remember_headers() (which overwrites the
    stored ETag with the fresh one) *before* comparing, so it always
    compared the new ETag against itself and reported "unchanged" for any
    server that sends ETags. The previously stored ETag is now captured
    before the properties are refreshed.
    """
    old_etag = properties.get(url, {}).get("ETag")

    req = requests.head(url, allow_redirects=True)
    remember_headers(url, properties, req.headers)

    if req.status_code != 200:
        raise Exception("Got status %s" % req.status_code)

    if old_etag is not None and old_etag == req.headers.get("ETag"):
        return False
    return True
71
def http_to_keep(api, project_uuid, url):
    """Ensure the content of ``url`` is stored in Keep and return a
    ``keep:pdh/filename`` reference to it.

    Reuses an existing collection when the cached copy is still fresh
    (fresh_cache) or the server reports an unchanged ETag (changed);
    otherwise downloads into a new collection owned by ``project_uuid``.
    """
    r = api.collections().list(filters=[["properties", "exists", url]]).execute()

    for item in r["items"]:
        properties = item["properties"]
        if fresh_cache(url, properties):
            # Cache headers say the copy is still fresh; reuse it without
            # contacting the server at all.
            cr = arvados.collection.CollectionReader(item["portable_data_hash"], api_client=api)
            return "keep:%s/%s" % (item["portable_data_hash"], cr.keys()[0])

        if not changed(url, properties):
            # ETag didn't change, same content, just update headers
            api.collections().update(uuid=item["uuid"], body={"collection":{"properties": properties}}).execute()
            cr = arvados.collection.CollectionReader(item["portable_data_hash"], api_client=api)
            return "keep:%s/%s" % (item["portable_data_hash"], cr.keys()[0])

    properties = {}
    req = requests.get(url, stream=True, allow_redirects=True)

    if req.status_code != 200:
        # Bug fix: the original had url and status_code swapped in the
        # format arguments, producing a garbled error message.
        raise Exception("Failed to download '%s' got status %s " % (url, req.status_code))

    remember_headers(url, properties, req.headers)

    # Content-Length is optional (e.g. chunked responses); guard all uses.
    content_length = properties[url].get("Content-Length")
    logger.info("Downloading %s (%s bytes)", url, content_length)

    c = arvados.collection.Collection()

    # Prefer the server-suggested filename from Content-Disposition;
    # fall back to the last path segment of the URL.
    name = None
    if req.headers.get("Content-Disposition"):
        grp = re.search(r'filename=("((\"|[^"])+)"|([^][()<>@,;:\"/?={} ]+))', req.headers["Content-Disposition"])
        if grp:
            # Bug fix: the original called grp.groups(2)/grp.groups(3),
            # which returns the *tuple of all groups* (the int is only the
            # default for non-participating groups), so ``name`` became a
            # tuple.  Group 2 is the quoted filename, group 4 the unquoted
            # token.
            name = grp.group(2) or grp.group(4)
    if not name:
        name = urlparse.urlparse(url).path.split("/")[-1]

    count = 0
    start = time.time()
    checkpoint = start
    with c.open(name, "w") as f:
        for chunk in req.iter_content(chunk_size=1024):
            count += len(chunk)
            f.write(chunk)
            now = time.time()
            if (now - checkpoint) > 20:
                # Log progress at most every 20 seconds.
                bps = (float(count)/float(now - start))
                if content_length:
                    logger.info("%2.1f%% complete, %3.2f MiB/s, %1.0f seconds left",
                                float(count * 100) / float(content_length),
                                bps/(1024*1024),
                                (int(content_length)-count)/bps)
                else:
                    # No Content-Length: can't estimate progress, just rate.
                    logger.info("%d bytes downloaded, %3.2f MiB/s", count, bps/(1024*1024))
                checkpoint = now

    c.save_new(name="Downloaded from %s" % url, owner_uuid=project_uuid, ensure_unique_name=True)

    api.collections().update(uuid=c.manifest_locator(), body={"collection":{"properties": properties}}).execute()

    return "keep:%s/%s" % (c.portable_data_hash(), name)