#!/usr/bin/env python
# Copyright 2011  Lars Wirzenius
# 
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


import cliapp
import glob
import logging
import os
import shutil
import sys
import tempfile
import time
import ttystatus
import unittest

import cmdtestlib
import yarnlib


# Key into the timings table for measurements that belong to the whole
# test suite rather than to one named test.
ALL_TESTS = ''
# Key, within one entry of the timings table, for the total elapsed time
# of that entry (the other keys in an entry are script names).
COMPLETE_TEST = ''


class TestFailure(Exception):

    '''One failed expectation of a test case.

    The human-readable message is built once, at construction time,
    from the ``name`` attribute of the test and a description of what
    went wrong. It is available as ``msg`` and via ``str()``.

    '''

    def __init__(self, test, msg):
        details = (test.name, msg)
        self.msg = 'FAIL: %s: %s' % details

    def __str__(self):
        message = self.msg
        return message


class CommandTester(cliapp.Application):

    def add_settings(self):
        '''Declare the command line settings of this application.'''

        settings = self.settings

        # 'command' is declared but not read anywhere in this class.
        settings.string(
            ['command', 'c'],
            'ignored for backwards compatibility')

        settings.string_list(
            ['test', 't'],
            'run only TEST (can be given many times)',
            metavar='TEST')

        settings.boolean(
            ['timings'],
            'report how long each test takes')

        settings.boolean(
            ['keep', 'k'],
            'keep temporary data on failure')

    def process_args(self, dirnames):
        '''Run all tests found in dirnames and report the outcome.

        The tests are loaded with load_tests, a temporary work area is
        created, and then the suite-wide and per-test setup/teardown
        scripts are run around each test. Every failure returned by
        run_test is logged and written to self.output. A summary line
        is written at the end, followed by the timing report if the
        'timings' setting is on.

        The temporary directory is removed unless there were failures
        and the 'keep' setting is on. If a script or test raises, the
        directory is likewise removed (unless 'keep' is on) before the
        exception propagates.

        Exits the process with status 1 (via sys.exit) if any failure
        was recorded.

        '''

        self.timings = {}
        self.timings[ALL_TESTS] = {}
        suite_started = time.time()

        td = self.load_tests(dirnames)
        self.setup_ttystatus(td)

        # run_test may return several failures for one test (stdout,
        # stderr, exit code), so failures and failed tests are counted
        # separately: the former is reported, the latter determines
        # how many tests were OK.
        errors = 0
        failed_tests = 0

        self.setup_tempdir()
        try:
            self.run_script(ALL_TESTS, td.setup_once)
            for test in td.tests:
                self.timings[test.name] = {}
                started = time.time()
                self.ts['test'] = test
                self.ts['test-name'] = test.name
                self.run_script(test.name, td.setup)
                test_failed = False
                for e in self.run_test(test):
                    logging.error(str(e))
                    self.ts.clear()
                    self.output.write('%s\n' % str(e))
                    errors += 1
                    test_failed = True
                if test_failed:
                    failed_tests += 1
                self.run_script(test.name, td.teardown)
                self.timings[test.name][COMPLETE_TEST] = (
                    time.time() - started)
            self.run_script(ALL_TESTS, td.teardown_once)
        except BaseException:
            # Do not leak the work area when the run is aborted, but
            # honour an explicit request to keep it for inspection.
            if not self.settings['keep']:
                self.cleanup_tempdir()
            raise

        keep_tempdir = errors and self.settings['keep']
        if not keep_tempdir:
            self.cleanup_tempdir()

        self.timings[ALL_TESTS][COMPLETE_TEST] = time.time() - suite_started

        ok = len(td.tests) - failed_tests
        self.ts.finish()
        self.output.write('%d/%d tests OK, %d failures\n' %
                            (ok, len(td.tests), errors))

        if keep_tempdir:
            self.output.write('Test work area: %s\n' % self.tempdir)

        if self.settings['timings']:
            self.report_timings()

        if errors:
            sys.exit(1)

    def setup_ttystatus(self, td):
        '''Create the terminal status display and store it in self.ts.

        The display is given the list of tests from td and a format
        that refers to the 'test', 'tests' and 'test-name' keys.

        '''
        progress = ttystatus.TerminalStatus(period=0.001)
        self.ts = progress
        progress['tests'] = td.tests
        progress.format('test %Index(test,tests): %String(test-name)')
        
    def load_tests(self, dirnames):
        '''Scan every directory in dirnames and return the TestDir.

        When the 'test' setting is non-empty, only the files selected
        by find_requested_tests are passed to the scan; otherwise the
        directory is scanned without a file list.

        '''
        testdir = cmdtestlib.TestDir()
        for dirname in dirnames:
            if not self.settings['test']:
                testdir.scan(dirname)
                continue
            wanted = self.find_requested_tests(dirname)
            testdir.scan(dirname, wanted)
        return testdir

    def find_requested_tests(self, dirname):
        '''Return basenames of the files in dirname for the requested tests.

        For each name in the 'test' setting, every file in dirname
        matching NAME.* is included. The shared setup/teardown scripts
        are appended afterwards if they exist in dirname.

        '''
        selected = []
        for name in self.settings['test']:
            pattern = os.path.join(dirname, name + '.*')
            selected.extend(os.path.basename(p) for p in glob.glob(pattern))

        # Shared scripts are needed regardless of which tests were chosen.
        for hook in ('setup_once', 'setup', 'teardown', 'teardown_once'):
            if os.path.exists(os.path.join(dirname, hook)):
                selected.append(hook)

        return selected

    def setup_tempdir(self):
        '''Create a new temporary directory with a 'data' subdirectory.

        The paths are stored in self.tempdir and self.datadir.

        '''
        workdir = tempfile.mkdtemp()
        self.tempdir = workdir
        logging.info('Temporary directory %s', workdir)

        datadir = os.path.join(workdir, 'data')
        self.datadir = datadir
        os.mkdir(datadir)
        
    def cleanup_tempdir(self):
        '''Remove the directory tree at self.tempdir.'''
        doomed = self.tempdir
        logging.info('Removing temporary directory %s', doomed)
        shutil.rmtree(doomed)

    def run_script(self, test_name, script_name):
        '''Run a script, if one is given, and record how long it took.

        Nothing is run when script_name is empty or None. Otherwise the
        script is executed with the environment from add_to_env, and
        the elapsed time is stored in self.timings[test_name] under
        the script's name.

        '''
        logging.debug('run_script: test_name=%s script_name=%s',
                      test_name, script_name)
        if not script_name:
            return

        began = time.time()
        env = self.add_to_env(test_name)
        self.runcmd([script_name], env=env)
        self.timings[test_name][script_name] = time.time() - began

    def add_to_env(self, test_name):
        '''Return a copy of os.environ extended with test variables.

        SRCDIR is the current working directory, DATADIR is
        self.datadir, and TESTNAME is test_name.

        '''
        extras = {
            'SRCDIR': os.getcwd(),
            'DATADIR': self.datadir,
            'TESTNAME': test_name,
        }
        env = dict(os.environ)
        env.update(extras)
        return env

    def run_test(self, test):
        '''Run one test case and return a list of its failures.

        The test's own setup script is run, then the test script with
        stdout and stderr captured into '<path_prefix>.stdout-actual'
        and '<path_prefix>.stderr-actual', then the test's own teardown
        script. The captured output is compared with the expected
        files using diff, and the exit code with the expected exit
        code (0 if none is given).

        Returns a list of TestFailure objects; the list is empty when
        the test passed, in which case the captured output and diff
        files are removed.

        Raises cliapp.AppException if the test has no script.

        '''
        logging.info('Test case: %s' % test.name)

        self.run_script(test.name, test.setup)

        if not test.script:
            raise cliapp.AppException('Must have a .script file for test')
        argv = [test.script]

        stdout_name = test.path_prefix + '.stdout-actual'
        stderr_name = test.path_prefix + '.stderr-actual'
        with open(stdout_name, 'wb') as stdout:
            with open(stderr_name, 'wb') as stderr:
                stdin = open(test.stdin, 'rb') if test.stdin else None
                try:
                    env = self.add_to_env(test.name)
                    exit_code, _out, _err = self.runcmd_unchecked(
                        argv,
                        env=env,
                        stdin=stdin,
                        stdout=stdout,
                        stderr=stderr)
                finally:
                    # Close the input file even if running the command
                    # raised, so the handle is not leaked.
                    if stdin is not None:
                        stdin.close()

        self.run_script(test.name, test.teardown)

        errors = []

        stdout_diff_name = test.path_prefix + '.stdout-diff'
        stdout_diff = self.diff(test.stdout or '/dev/null', stdout_name,
                                stdout_diff_name)
        if stdout_diff:
            errors.append(TestFailure(test, 'stdout diff:\n%s' % stdout_diff))

        stderr_diff_name = test.path_prefix + '.stderr-diff'
        stderr_diff = self.diff(test.stderr or '/dev/null', stderr_name,
                                stderr_diff_name)
        if stderr_diff:
            errors.append(TestFailure(test, 'stderr diff:\n%s' % stderr_diff))

        # A missing or empty exit file means the script must exit with 0.
        contents = cmdtestlib.cat(test.exit or '/dev/null')
        expected_exit = int(contents.strip() or '0')
        if exit_code != expected_exit:
            errors.append(TestFailure(test,
                                      'got exit code %s, expected %s' %
                                        (exit_code, expected_exit)))

        # Keep the output and diff files only when they help explain a
        # failure.
        if not errors:
            os.remove(stdout_name)
            os.remove(stderr_name)
            os.remove(stdout_diff_name)
            os.remove(stderr_diff_name)

        return errors

    def diff(self, expected_name, actual_name, diff_name):
        '''Compare two files with 'diff -u' and return the diff text.

        The output of diff is written to the file diff_name and also
        returned. An empty string means no difference was reported.

        If diff itself failed (exit code other than 0 or 1) and wrote
        no output, a short explanatory message is returned instead of
        an empty string, so that the failed comparison is not mistaken
        for identical files.

        '''
        argv = ['diff', '-u', expected_name, actual_name]
        with open(diff_name, 'w') as f:
            exit_code = self.runcmd_unchecked(argv, stdout=f)[0]
        with open(diff_name, 'r') as f:
            text = f.read()

        # diff exits with 0 when the files are identical and 1 when they
        # differ; anything else means the comparison could not be made.
        if exit_code not in (0, 1) and not text:
            text = ('diff failed with exit code %s comparing %s and %s\n' %
                    (exit_code, expected_name, actual_name))
        return text

    def report_timings(self):
        '''Write the collected timings to self.output.

        The whole-suite entry is reported first, followed by one entry
        per test in sorted order of name. Within an entry, the total
        time comes first and the individual scripts follow, sorted by
        elapsed time.

        '''

        write = self.output.write

        def report(key, title):
            entry = self.timings[key]
            write('  %s: %.1f s\n' % (title, entry[COMPLETE_TEST]))

            per_script = sorted(
                (secs, script)
                for script, secs in entry.items()
                if script != COMPLETE_TEST)
            for secs, script in per_script:
                write('    %4.1f %s\n' % (secs, script))

        write('Timings (in seconds):\n')
        report(ALL_TESTS, 'whole test suite')

        test_names = [name for name in self.timings if name != ALL_TESTS]
        test_names.sort()
        for name in test_names:
            report(name, name)


if __name__ == '__main__':
    # Build the application, labelled with yarnlib's version, and run it.
    app = CommandTester(version=yarnlib.__version__)
    app.run()
