Newer
Older
# This file is part of the mantidqt package
#
# Copyright (C) 2017 mantidproject
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import (absolute_import, unicode_literals)
# std imports
import unittest
# 3rdparty imports
from qtpy.QtCore import QCoreApplication, QObject
from mantidqt.utils.qt.testing import requires_qapp
from mantidqt.widgets.codeeditor.execution import PythonCodeExecution
class Receiver(QObject):
    """Records which executor callbacks fired and what an error carried.

    The methods are intended to be connected to the executor's success and
    error signals; the tests then inspect the recorded attributes.
    """
    # Class-level defaults; the slots below rebind them on the instance.
    success_cb_called = False
    error_cb_called = False
    # Populated by on_error from the task result it receives.
    task_exc = None
    error_stack = None

    def on_success(self):
        """Note that the success callback was invoked."""
        self.success_cb_called = True

    def on_error(self, task_result):
        """Note that the error callback was invoked and keep its details.

        :param task_result: Object providing ``exc_value`` and ``stack``
            attributes describing the failure
        """
        self.error_cb_called = True
        self.task_exc = task_result.exc_value
        self.error_stack = task_result.stack
class ReceiverWithProgress(Receiver):
    """Receiver that additionally records every progress notification."""

    def __init__(self):
        super(ReceiverWithProgress, self).__init__()
        # Line numbers in the order they were reported
        self.lines_received = list()

    # The name (including its spelling) is what the test helper connects
    # the progress signal to, so it is kept as is.
    def on_progess_update(self, lineno):
        """Store the line number delivered with a progress notification.

        :param lineno: Line number reported by the executor
        """
        self.lines_received.append(lineno)
@requires_qapp
class PythonCodeExecutionTest(unittest.TestCase):
    """Exercise PythonCodeExecution in serial and asynchronous modes.

    Asynchronous tests go through ``_run_async_code``, which connects a
    receiver object to the executor's signals, waits for the task to finish
    and processes pending Qt events before returning.
    """

    def test_default_construction_yields_empty_context(self):
        """A newly constructed executor starts with an empty namespace."""
        executor = PythonCodeExecution()
        self.assertEqual(0, len(executor.globals_ns))

    def test_reset_context_clears_context(self):
        """reset_context() empties a namespace populated by execution."""
        executor = PythonCodeExecution()
        globals_len = len(executor.globals_ns)
        executor.execute("x = 1")
        # A two-argument assertTrue treats the second argument as the failure
        # message and never compares the values, so compare explicitly. Only
        # growth is checked: how many entries execution adds depends on the
        # executor (exec() may, for example, also insert __builtins__).
        self.assertGreater(len(executor.globals_ns), globals_len)
        executor.reset_context()
        self.assertEqual(0, len(executor.globals_ns))

    # ---------------------------------------------------------------------------
    # Successful execution tests
    # ---------------------------------------------------------------------------
    def test_execute_places_output_in_globals(self):
        """Variables assigned by executed code appear in the globals."""
        code = "_local=100"
        user_globals = self._verify_serial_execution_successful(code)
        self.assertEqual(100, user_globals['_local'])
        user_globals = self._verify_async_execution_successful(code)
        self.assertEqual(100, user_globals['_local'])

    def test_execute_async_calls_success_signal_on_completion(self):
        """Only the success callback fires for valid code."""
        code = "x=1+2"
        executor, recv = self._run_async_code(code)
        self.assertTrue(recv.success_cb_called)
        self.assertFalse(recv.error_cb_called)

    # ---------------------------------------------------------------------------
    # Error execution tests
    # ---------------------------------------------------------------------------
    def test_execute_raises_syntax_error_on_bad_code(self):
        """Serial execution propagates a SyntaxError to the caller."""
        code = "if:"
        self._verify_failed_serial_execute(SyntaxError, code)

    def test_execute_async_calls_error_cb_on_syntax_error(self):
        """Async execution reports a SyntaxError through the error callback."""
        code = "if:"
        executor, recv = self._run_async_code(code)
        self.assertFalse(recv.success_cb_called)
        self.assertTrue(recv.error_cb_called)
        self.assertIsInstance(recv.task_exc, SyntaxError,
                              msg="Unexpected exception found. "
                                  "SyntaxError expected, found {}".format(recv.task_exc.__class__.__name__))
        self.assertEqual(1, recv.task_exc.lineno)

    def test_execute_returns_failure_on_runtime_error_and_captures_exception(self):
        """Serial execution propagates a runtime NameError to the caller."""
        code = "x = _local + 1"
        self._verify_failed_serial_execute(NameError, code)

    def test_execute_async_returns_failure_on_runtime_error_and_captures_expected_stack(self):
        """Async execution reports a NameError with the expected stack."""
        # The snippet deliberately starts with an empty line so that the
        # statements sit on the line numbers asserted below
        # (foo() on 8, bar() on 7, the failing expression on 5).
        code = """
def foo():
    def bar():
        # raises a NameError
        y = _local + 1
    # call inner
    bar()
foo()
"""
        executor, recv = self._run_async_code(code, with_progress=True)
        self.assertFalse(recv.success_cb_called)
        self.assertTrue(recv.error_cb_called)
        self.assertIsInstance(recv.task_exc, NameError,
                              msg="Unexpected exception found. "
                                  "NameError expected, found {}".format(recv.task_exc.__class__.__name__))
        # Test the stack has been chopped as expected
        self.assertEqual(3, len(recv.error_stack))
        # check line numbers
        self.assertEqual(8, recv.error_stack[0][1])
        self.assertEqual(7, recv.error_stack[1][1])
        self.assertEqual(5, recv.error_stack[2][1])

    # ---------------------------------------------------------------------------
    # Progress tests
    # ---------------------------------------------------------------------------
    def test_progress_cb_is_not_called_for_empty_string(self):
        """No progress is reported when there is nothing to execute."""
        code = ""
        executor, recv = self._run_async_code(code, with_progress=True)
        self.assertEqual(0, len(recv.lines_received))

    def test_progress_cb_is_not_called_for_code_with_syntax_errors(self):
        """No progress is reported when the code fails with a syntax error."""
        # NOTE(review): the second line is an intentionally incomplete
        # statement; any syntactically invalid snippet serves the purpose.
        code = """x = 1
y =
"""
        executor, recv = self._run_async_code(code, with_progress=True)
        self.assertFalse(recv.success_cb_called)
        self.assertTrue(recv.error_cb_called)
        self.assertEqual(0, len(recv.lines_received))

    def test_progress_cb_is_called_for_single_line(self):
        """A single statement produces a single progress report."""
        code = "x = 1"
        executor, recv = self._run_async_code(code, with_progress=True)
        if not recv.success_cb_called:
            self.assertTrue(recv.error_cb_called)
            self.fail("Execution failed with error:\n" + str(recv.task_exc))
        self.assertEqual([1], recv.lines_received)

    def test_progress_cb_is_called_for_multiple_single_lines(self):
        """Each single-line statement produces its own progress report."""
        code = """x = 1
y = 2
"""
        executor, recv = self._run_async_code(code, with_progress=True)
        if not recv.success_cb_called:
            self.assertTrue(recv.error_cb_called)
            self.fail("Execution failed with error:\n" + str(recv.task_exc))
        self.assertEqual([1, 2], recv.lines_received)

    def test_progress_cb_is_called_for_mix_single_lines_and_blocks(self):
        """Progress is reported for a mix of statements and compound blocks."""
        # NOTE(review): the final assertion expects a report for line 9, so
        # the snippet presumably contains blank lines that are not visible
        # here - confirm the exact layout against version control.
        code = """x = 1
# comment line
sum = 0
for i in range(10):
    if i %2 == 0:
        sum += i
squared = sum*sum
"""
        executor, recv = self._run_async_code(code, with_progress=True)
        if not recv.success_cb_called:
            if recv.error_cb_called:
                self.fail("Unexpected error found: " + str(recv.task_exc))
            else:
                self.fail("No callback was called!")
        context = executor.globals_ns
        self.assertEqual(20, context['sum'])
        self.assertEqual(20*20, context['squared'])
        self.assertEqual(1, context['x'])
        self.assertEqual([1, 2, 3, 4, 9], recv.lines_received)

    # -------------------------------------------------------------------------
    # Filename checks
    # -------------------------------------------------------------------------
    def test_filename_included_in_traceback_if_supplied(self):
        """The filename given to execute_async appears in the error stack."""
        code = """raise RuntimeError"""
        filename = 'test.py'
        executor, recv = self._run_async_code(code, filename=filename)
        self.assertTrue(recv.error_cb_called)
        self.assertEqual(filename, recv.error_stack[0][0])

    # -------------------------------------------------------------------------
    # Helpers
    # -------------------------------------------------------------------------
    def _verify_serial_execution_successful(self, code):
        """Execute ``code`` serially on a new executor.

        :param code: Source code to execute
        :return: The executor's globals namespace after execution
        """
        executor = PythonCodeExecution()
        executor.execute(code)
        return executor.globals_ns

    def _verify_async_execution_successful(self, code):
        """Execute ``code`` asynchronously on a new executor.

        :param code: Source code to execute
        :return: The executor's globals namespace after the task has finished
        """
        executor = PythonCodeExecution()
        task = executor.execute_async(code)
        # Wait for completion so the namespace is not read mid-execution
        task.join()
        return executor.globals_ns

    def _verify_failed_serial_execute(self, expected_exc_type, code):
        """Assert that serially executing ``code`` raises the given exception.

        :param expected_exc_type: Exception class expected from execute()
        :param code: Source code to execute
        """
        executor = PythonCodeExecution()
        self.assertRaises(expected_exc_type, executor.execute, code)

    def _run_async_code(self, code, with_progress=False, filename=''):
        """Run ``code`` asynchronously with a receiver attached to the signals.

        :param code: Source code to execute
        :param with_progress: If True use a ReceiverWithProgress and also
            connect it to the progress signal
        :param filename: Filename passed through to execute_async
        :return: A tuple of (executor, receiver)
        """
        executor = PythonCodeExecution()
        if with_progress:
            recv = ReceiverWithProgress()
            executor.sig_exec_progress.connect(recv.on_progess_update)
        else:
            recv = Receiver()
        executor.sig_exec_success.connect(recv.on_success)
        executor.sig_exec_error.connect(recv.on_error)
        task = executor.execute_async(code, filename)
        task.join()
        # Process pending Qt events before the caller inspects the receiver
        QCoreApplication.processEvents()
        return executor, recv
# Run the tests in this module when the file is executed as a script
if __name__ == "__main__":
    unittest.main()