2411: Add copyright notices to everything.
[arvados.git] / sdk / python / tests / performance / performance_profiler.py
index 3810f92f18872b78f204fc5308178dc08efb68f7..3be00c4546264a20c3fd55ee4cc5ece6f9000626 100644 (file)
@@ -1,35 +1,49 @@
-# Use the PerformanceProfiler class to write your performance tests.
+# Copyright (C) The Arvados Authors. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# Use the "profiled" decorator on a test to get profiling data.
 #
 # Usage:
-#   from performance_profiler import PerformanceProfiler
-#   self.run_profiler(...
+#   from performance_profiler import profiled
+#
+#   # Report is written to tmp/profile/foobar (relative to the current directory)
+#   @profiled
+#   def foobar():
+#       baz = 1
 #
 #   See "test_a_sample.py" for a working example.
 #
-# To run performance tests:
-#     cd arvados/sdk/python
+# Performance tests run as part of the regular test suite.
+# You can also run only the performance tests using one of the following:
 #     python -m unittest discover tests.performance
-#
-#     Alternatively, using run-tests.sh
-#         ./run-tests.sh WORKSPACE=~/arvados --only sdk/python sdk/python_test="--test-suite=tests.performance"
-#
+#     ./run-tests.sh WORKSPACE=~/arvados --only sdk/python sdk/python_test="--test-suite=tests.performance"
 
+import functools
 import os
-import unittest
+import pstats
 import sys
-from datetime import datetime
+import unittest
 try:
     import cProfile as profile
 except ImportError:
     import profile
 
-class PerformanceProfiler(unittest.TestCase):
-    def run_profiler(self, function, test_name):
-        filename = os.getcwd()+'/tmp/performance/'+ datetime.now().strftime('%Y-%m-%d-%H-%M-%S') +'-' +test_name
-
-        directory = os.path.dirname(filename)
-        if not os.path.exists(directory):
-            os.makedirs(directory)
+output_dir = os.path.join(os.path.abspath('tmp'), 'profile')  # profiling reports go here
+if not os.path.exists(output_dir):  # created at import time if missing
+    os.makedirs(output_dir)
 
-        sys.stdout = open(filename, 'w')
-        profile.runctx(function, globals(), locals())
+def profiled(function):
+    """Decorator: profile each call and (over)write output_dir/<function name>."""
+    @functools.wraps(function)
+    def profiled_function(*args, **kwargs):
+        # Context manager guarantees the report is flushed and closed, even on error.
+        with open(os.path.join(output_dir, function.__name__), "w") as outfile:
+            pr = profile.Profile()
+            pr.enable()
+            try:
+                return function(*args, **kwargs)
+            finally:
+                pr.disable()
+                pstats.Stats(pr, stream=outfile).sort_stats('time').print_stats()
+    return profiled_function