#!/usr/bin/env python
# encoding: utf-8
"""
Drum transcription with recurrent (RNN), convolutional (CNN), or
convolutional recurrent neural networks (CRNN).

"""

from __future__ import absolute_import, division, print_function

import argparse

from madmom.features import ActivationsProcessor
from madmom.features.drums import CRNNDrumProcessor, DrumPeakPickingProcessor
from madmom.features.notes import write_midi, write_notes
from madmom.processors import IOProcessor, io_arguments


def main():
    """
    Run the DrumTranscriptor command line program.

    The command line is parsed, an input and an output processor are chosen
    depending on the given options, both are chained into an `IOProcessor`,
    and this processor is passed to the processing function selected on the
    command line (`args.func`).

    Raises
    ------
    ValueError
        If the requested output format is neither the default one nor
        'midi'.

    """
    # set up the command line parser
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter, description='''
    Drum transcription with a convolutional recurrent neural network (CRNN).
    ''')
    # version information
    parser.add_argument('--version', action='version',
                        version='DrumTranscriptor.2017')
    # options for input/output handling
    io_arguments(parser, output_suffix='.drum_transcriptor.txt')
    ActivationsProcessor.add_arguments(parser)
    # options for the peak picking stage (with program specific defaults)
    DrumPeakPickingProcessor.add_arguments(
        parser, threshold=0.15, smooth=0, pre_avg=0.1, post_avg=0.01,
        pre_max=0.02, post_max=0.01, combine=0.02)
    # option to switch the output to MIDI; `output_format` stays None if
    # the flag is not given
    parser.add_argument('--midi', dest='output_format', action='store_const',
                        const='midi', help='save as MIDI')
    # options of the CRNN drum processor
    CRNNDrumProcessor.add_arguments(parser)

    # parse the command line
    args = parser.parse_args()

    # the frame rate is fixed and can not be changed by the user
    args.fps = 100

    # MIDI output gets its own file suffix
    if args.output_format == 'midi':
        args.output_suffix = '.mid'

    # show the effective arguments if requested
    # NOTE: `verbose`, `load`, `save` and `func` are not defined in this
    #       function, presumably they are added by the argument helpers above
    if args.verbose:
        print(args)

    # all processors are configured with the complete set of arguments;
    # `vars` returns the namespace's own dictionary, so it is bound only
    # after the defaults above have been set
    kwargs = vars(args)

    # input: either read the activations from file or compute them with the
    # CRNN drum processor
    if args.load:
        in_processor = ActivationsProcessor(mode='r', **kwargs)
    else:
        in_processor = CRNNDrumProcessor(**kwargs)

    # output: either write the activations to file or pick the peaks of the
    # activations and write them in the requested format
    if args.save:
        out_processor = ActivationsProcessor(mode='w', **kwargs)
    else:
        peak_picking = DrumPeakPickingProcessor(**kwargs)
        # map the output formats to the functions writing them
        writers = {None: write_notes, 'midi': write_midi}
        if args.output_format not in writers:
            raise ValueError('unknown output format: %s' % args.output_format)
        out_processor = [peak_picking, writers[args.output_format]]

    # chain input and output processing and run the selected function
    processor = IOProcessor(in_processor, out_processor)
    args.func(processor, **kwargs)


# run the program only if executed as a script, not when imported
if __name__ == '__main__':
    main()
