#!/usr/bin/python
# Copyright 2010, 2011  Lars Wirzenius
# 
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# Exercise my B-tree implementation, for simple benchmarking purposes.
# The benchmark gets a location and the number of keys to use as command
# line arguments --location=LOCATION and --keys=KEYS.

# To debug, one can create a tracing logfile by adding arguments like:
# --trace=refcount --log=refcount.logfile
#
# If the location is the empty string, an in-memory node store is used.
# Otherwise it must be a non-existent directory name.
#
# The benchmark will do the given number of insertions into the tree, and
# measure the speed of that. Then it will look up each of those, and measure
# the lookups.


import cliapp
import cProfile
import csv
import gc
import logging
import os
import random
import shutil
import subprocess
import sys
import time
import tracing

import larch


class SpeedTest(cliapp.Application):

    '''Benchmark basic operations on a larch B-tree.

    The application inserts, looks up, and removes --keys keys (and
    ranges of those keys) in a freshly created forest, timing each kind
    of operation in both CPU and wall-clock seconds. Results are printed
    to stdout and, if --csv is set, appended to a CSV file.

    '''

    def add_settings(self):
        '''Declare the command line settings the benchmark understands.'''
        self.settings.boolean(['profile'], 'profile with cProfile?')
        self.settings.boolean(['log-memory-use'], 'log VmRSS?')
        self.settings.string(['trace'],
                'code module in which to do trace logging')
        self.settings.integer(['keys'],
                'how many keys to test with (default is %default)',
                default=1000)
        self.settings.string(['location'],
                'where to store B-tree on disk (in-memory test if not set)')
        self.settings.string(['csv'],
                'append a CSV row to FILE',
                metavar='FILE')

    def process_args(self, args):
        '''Run all the benchmarks and report the results.

        args is unused; everything is controlled via settings.

        Raises Exception if --keys is unset or if --location names
        something that already exists. When --location is set, the
        directory is created here and removed again after a successful
        run.

        '''

        if self.settings['trace']:
            tracing.trace_add_pattern(self.settings['trace'])

        key_size = 19
        value_size = 128
        node_size = 64 * 1024

        n = self.settings['keys']
        location = self.settings['location']

        if n is None:
            raise Exception('You must set number of keys with --keys')

        if location:
            if os.path.exists(location):
                raise Exception('%s exists already' % location)
            os.mkdir(location)
            forest = larch.open_forest(
                allow_writes=True, key_size=key_size, node_size=node_size,
                dirname=location)
        else:
            forest = larch.open_forest(
                allow_writes=True, key_size=key_size, node_size=node_size,
                node_store=larch.NodeStoreMemory)

        tree = forest.new_tree()

        # Create list of keys: zero-padded decimal numbers, key_size wide.
        # The ranges are built while the keys are still in sorted order,
        # so each range spans range_len consecutive keys.
        keys = ['%0*d' % (key_size, i) for i in range(n)]
        range_len = 10
        ranges = [(keys[i], keys[i + range_len - 1])
                  for i in range(len(keys) - range_len)]

        value = 'x' * value_size

        # Helper functions.
        def nop(*args):
            return None

        def insert_key(key):
            tree.insert(key, value)

        def lookup_one_range(pair):
            # list() forces the whole result to be produced.
            list(tree.lookup_range(pair[0], pair[1]))

        def len_of_one_range(pair):
            len(list(tree.lookup_range(pair[0], pair[1])))

        def count_one_range(pair):
            tree.count_range(pair[0], pair[1])

        def remove_one_range(pair):
            tree.remove_range(pair[0], pair[1])

        # Calibrate: time an empty loop over the keys so the loop and
        # call overhead can be subtracted from the real measurements.
        looptime = self.measure(keys, nop, nop, 'calibrate')

        # Measure inserts.
        random.shuffle(keys)
        insert = self.measure(keys, insert_key, forest.commit, 'insert')

        # Measure lookups.
        random.shuffle(keys)
        lookup = self.measure(keys, tree.lookup, nop, 'lookup')

        # Measure range lookups.
        random.shuffle(ranges)
        lookup_range = self.measure(ranges, lookup_one_range, nop,
                                    'lookup_range')

        # Measure count of range lookup results.
        len_lookup_range = self.measure(ranges, len_of_one_range, nop,
                                        'len_lookup_range')

        # Measure count range.
        count_range = self.measure(ranges, count_one_range, nop,
                                   'count_range')

        # Measure inserts into existing tree.
        random.shuffle(keys)
        insert2 = self.measure(keys, insert_key, forest.commit, 'insert2')

        # Measure removes from tree.
        random.shuffle(keys)
        remove = self.measure(keys, tree.remove, forest.commit, 'remove')

        # Measure remove_range. This requires building a new tree.
        keys.sort()
        for key in keys:
            tree.insert(key, value)
        random.shuffle(ranges)
        remove_range = self.measure(ranges, remove_one_range,
                                    forest.commit, 'remove_range')

        # Report
        def speed(result, i):
            # Operations per second after subtracting the calibration
            # loop. i selects CPU time (0) or wall-clock time (1).
            # NOTE: the range benchmarks iterate over len(ranges) items,
            # which is fewer than n, but are still normalised by n.
            elapsed = result[i] - looptime[i]
            if elapsed <= 0:
                # computer too fast for the number of "keys" used; also
                # avoids reporting a negative speed when timer jitter
                # makes the measurement shorter than the calibration.
                return float("infinity")
            return n / elapsed

        def report(label, result):
            cpu, wall = result
            print('%-16s: %5.3f s (%8.1f/s) CPU; %5.3f s (%8.1f/s) wall' %
                  (label, cpu, speed(result, 0), wall, speed(result, 1)))

        print('location: %s' % (location if location else 'memory'))
        print('num_operations: %d' % n)
        for label, result in (('insert', insert),
                              ('lookup', lookup),
                              ('lookup_range', lookup_range),
                              ('len_lookup_range', len_lookup_range),
                              ('count_range', count_range),
                              ('insert2', insert2),
                              ('remove', remove),
                              ('remove_range', remove_range)):
            report(label, result)
        if self.settings['profile']:
            print('View *.prof with ./viewprof for profiling results.')

        if self.settings['csv']:
            # Only the CPU-time based speeds go into the CSV file.
            self.append_csv(n,
                            speed(insert, 0),
                            speed(insert2, 0),
                            speed(lookup, 0),
                            speed(lookup_range, 0),
                            speed(remove, 0),
                            speed(remove_range, 0))

        # Clean up
        if location:
            shutil.rmtree(location)

    def measure(self, items, func, finalize, profname):
        '''Time calling func on every item, followed by finalize().

        Returns a (cpu_seconds, wall_seconds) tuple. If the profile
        setting is on, the run happens under cProfile and the profiling
        data is written to PROFNAME.prof in the current directory. If
        the log-memory-use setting is on, memory statistics are logged
        at the start, after the calls, and after finalize; that logging
        happens inside the timed region.

        '''

        def log_memory_use(stage):
            if not self.settings['log-memory-use']:
                return
            logging.info('%s memory use: %s', profname, stage)
            logging.info('  VmRSS: %s KiB', self.vmrss())
            logging.info('  # objects: %d', len(gc.get_objects()))
            logging.info('  # garbage: %d', len(gc.garbage))

        def helper():
            log_memory_use('at start')
            for item in items:
                func(item)
            log_memory_use('after calls')
            finalize()
            log_memory_use('after finalize')

        print('measuring %s' % profname)
        start_time = time.time()
        # NOTE: time.clock was removed in Python 3.8; a port would need
        # a replacement such as time.process_time.
        start = time.clock()
        if self.settings['profile']:
            cProfile.runctx('helper()', globals(), {'helper': helper},
                            '%s.prof' % profname)
        else:
            helper()
        end = time.clock()
        end_time = time.time()
        return end - start, end_time - start_time

    def vmrss(self):
        '''Return the VmRSS value from /proc/self/status.

        The value is returned as the string found in the file, or as
        the integer 0 if the file has no VmRSS line.

        '''

        with open('/proc/self/status') as f:
            for line in f:
                if line.startswith('VmRSS'):
                    return line.split()[1]
        return 0

    def append_csv(self, keys, insert, insert2, lookup, lookup_range,
                    remove, remove_range):
        '''Append one row of results to the file named by the csv setting.

        keys is the number of keys used; the other arguments are speeds,
        which get rounded to whole numbers. A title row is written first
        if the file does not exist yet. The first column is the output
        of "bzr revno" when the current directory has a .bzr directory,
        and "?" otherwise.

        Raises cliapp.AppException if bzr fails.

        '''

        # Find the revision before touching the CSV file, so that a
        # failing bzr does not leave a title-only file behind.
        if os.path.exists('.bzr'):
            p = subprocess.Popen(['bzr', 'revno'], stdout=subprocess.PIPE)
            out, err = p.communicate()
            if p.returncode != 0:
                raise cliapp.AppException('bzr failed')
            revno = out.strip()
        else:
            revno = '?'

        filename = self.settings['csv']
        write_title = not os.path.exists(filename)
        speeds = (insert, insert2, lookup, lookup_range, remove,
                  remove_range)

        with open(filename, 'a') as f:
            writer = csv.writer(f, lineterminator='\n')
            if write_title:
                # NOTE: both insert benchmarks use shuffled keys; the
                # 'insert (seq)' title is kept unchanged so that rows
                # stay consistent with existing CSV files.
                writer.writerow(('revno',
                                 'keys',
                                 'insert (random)',
                                 'insert (seq)',
                                 'lookup',
                                 'lookup_range',
                                 'remove',
                                 'remove_range'))
            writer.writerow((revno, keys) +
                            tuple(self.format(s) for s in speeds))

    def format(self, value):
        '''Format a speed for the CSV file: rounded, no decimals.'''
        return '%.0f' % value


# Run the benchmark only when executed as a script, not when imported.
if __name__ == '__main__':
    app = SpeedTest()
    app.run()
