#!/usr/bin/env python
# encoding: utf-8
"""
Evaluation script.

"""

from __future__ import absolute_import, division, print_function

import os
import sys
import argparse
import warnings

from madmom.utils import search_files, match_file
from madmom.evaluation import alignment, beats, notes, onsets, tempo


def main():
    """
    Evaluate pairs of annotation and detection files from the command line.

    One sub-command is registered per evaluation type (onsets, beats, tempo,
    notes, alignment). The annotation files are searched in `args.ann_dir`
    (falling back to `args.files`), the detection files in `args.det_dir`
    (same fallback). Each annotation file is matched with its detection file
    and evaluated with `args.eval`; the individual evaluation objects (only
    if verbose), the optional sum evaluation and the mean evaluation are
    finally formatted with `args.output_formatter` and written to
    `args.outfile`.

    Raises
    ------
    SystemExit
        If no evaluation option (sub-command) is given, if no annotation
        files are found, or if more than one detection file matches an
        annotation file.

    """

    # define parser
    p = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter, description='''
    This program evaluates pairs of files containing annotations and
    detections.

    Please note that all available evaluation options have their individual
    help messages, e.g.

      $ evaluate onsets -h

    ''')
    # add version
    p.add_argument('--version', action='version', version='evaluate')

    # add subparsers
    sub_parsers = p.add_subparsers(title='available evaluation options')
    # onset evaluation
    onsets.add_parser(sub_parsers)
    # beat evaluation
    beats.add_parser(sub_parsers)
    # tempo evaluation
    tempo.add_parser(sub_parsers)
    # note evaluation
    notes.add_parser(sub_parsers)
    # alignment evaluation
    alignment.add_parser(sub_parsers)
    # chord evaluation
    # chords.add_parser(sub_parsers)

    # parse the args
    args = p.parse_args()

    # Python 3 does not require a sub-command by default, in which case none
    # of the options used below exist; report a usage error instead of
    # failing with an AttributeError
    if not hasattr(args, 'eval'):
        p.error('no evaluation option given, use -h to list the available '
                'options')

    # print the arguments
    if args.verbose >= 2:
        print(args)
    if args.quiet:
        warnings.filterwarnings("ignore")

    # get detection and annotation files
    if args.det_dir is None:
        args.det_dir = args.files
    if args.ann_dir is None:
        args.ann_dir = args.files
    det_files = search_files(args.det_dir, args.det_suffix)
    ann_files = search_files(args.ann_dir, args.ann_suffix)
    # quit if no files are found
    if len(ann_files) == 0:
        print("no files to evaluate. exiting.")
        # use sys.exit() since the exit() builtin is only injected by the
        # site module and thus not always available
        sys.exit()

    # list to collect the individual evaluation objects
    eval_objects = []

    # progress
    progress = ''

    # evaluate all files
    num_files = len(ann_files)
    for num_file, ann_file in enumerate(ann_files):
        # print progress; pad to the length of the previous message so that
        # it gets overwritten completely after the carriage return
        progress_len = len(progress)
        if args.verbose >= 2:
            progress = 'evaluating %s' % os.path.basename(ann_file)
        else:
            progress = 'evaluating file %d of %d' % (num_file + 1, num_files)
        sys.stderr.write('\r%s' % progress.ljust(progress_len))
        sys.stderr.flush()

        #  get the matching detection files
        matches = match_file(ann_file, det_files,
                             args.ann_suffix, args.det_suffix)
        if len(matches) > 1:
            # exit if multiple detections were found
            raise SystemExit("multiple detections for %s found" % ann_file)
        elif len(matches) == 0:
            # ignore non-existing detections
            if args.ignore_non_existing:
                continue
            # output a warning if no detections were found
            warnings.warn(" can't find detections for %s" % ann_file)
            # but continue and assume no detections
            det_file = None
        else:
            # use the first (and only) matched detection file
            det_file = matches[0]

        # evaluate (progress is reported on stderr only, so stdout stays
        # reserved for the evaluation results)
        e = args.eval(det_file, ann_file, name=os.path.basename(ann_file),
                      **vars(args))

        # add this file's evaluation to the global evaluation list
        eval_objects.append(e)

    # clear progress
    sys.stderr.write('\r%s\r' % ' '.ljust(len(progress)))
    sys.stderr.flush()

    # output every evaluation object individually
    out_list = []
    if args.verbose:
        out_list.extend(eval_objects)
    # add sum/mean evaluation to output
    if args.sum_eval is not None:
        out_list.append(args.sum_eval(eval_objects))
    out_list.append(args.mean_eval(eval_objects))

    # output everything
    args.outfile.write(args.output_formatter(out_list, **vars(args)) + '\n')

# run the evaluation only when executed as a script, not when imported
if __name__ == '__main__':
    main()
