#!/usr/bin/env python
#
# This utility is used to load Netlogger .bp files containing
# Pegasus events into SQLite databases containing the Stampede
# schema. The databases can be mined for information about the
# workflow events loaded.
#
import sys
import os
import re
import time
import calendar
from datetime import datetime

BATCHSIZE = 1000	# Default number of queued parameter sets at which a Batch reports itself ready to execute

# The format of a netlogger event is a series of key=value or key="value"
# pairs. Capture groups: 1 = key, 4 = contents of a quoted value (without
# the quotes), 5 = unquoted value.
re_parse_netlogger_event = re.compile(r'([^ =]+)[ ]*=[ ]*(("([^"]*)")|([^ ]+))')

def parse_netlogger_event(line):
	"""Convert a string containing a netlogger event into a dict.

	Every key=value or key="value" pair found in line becomes one entry.
	All values are returned as strings; the double quotes surrounding a
	quoted value are not part of the returned value. If a key occurs more
	than once the last occurrence wins. A line containing no pairs yields
	an empty dict.
	"""
	rec = {}
	for m in re_parse_netlogger_event.finditer(line):
		# Group 4 is None only when the value was not quoted; an empty
		# quoted value ("") gives the empty string, which must be kept.
		value = m.group(4)
		if value is None:
			value = m.group(5)
		rec[m.group(1)] = value
	return rec
	
def parse_ts(ts):
	"""Parse a netlogger timestamp into seconds since the epoch.

	ts is expected to look like 2010-11-03T18:35:01.123456Z: a date and
	time followed by an optional fractional part and an optional trailing
	Z. The date and time are interpreted as UTC (calendar.timegm).

	Returns a float. Raises ValueError if ts does not match the format.
	"""
	ts = ts.strip()
	# Drop the UTC designator, if there is one
	if ts[-1:] in ('Z', 'z'):
		ts = ts[:-1]
	whole, _, frac = ts.partition('.')
	seconds = calendar.timegm(time.strptime(whole, r'%Y-%m-%dT%H:%M:%S'))
	if frac:
		# The digits after the dot are a fraction of a second, so they
		# have to be parsed as 0.<digits> and not as a whole number
		return seconds + float('0.' + frac)
	return float(seconds)
	
class Batch:
	"""Stores and executes a batch of sql statements.

	The statement is given with named parameters of the form :key. The
	names are remembered, in order of appearance, and the statement is
	rewritten to use positional ? placeholders so that records (dicts)
	can be queued and later executed together with cursor.executemany().
	"""

	def __init__(self, sql, batchsize=BATCHSIZE):
		"""Create an empty batch for sql.

		sql       -- SQL statement using :key style parameters
		batchsize -- number of queued records at which batch_ready()
		             starts returning True
		"""
		self.items = []
		self.batchsize = batchsize

		re_param = re.compile(":([a-zA-Z0-9_]+)")

		# Convert parameters of type :key into a list of keys
		self.keys = re_param.findall(sql)

		# Convert parameters of type :key into question marks
		self.sql = re_param.sub("?", sql)

	def add_batch(self, rec):
		"""Add a record to the batch.

		Only the values of the statement's parameters are taken from rec;
		any other entries are ignored. Raises KeyError if rec does not
		contain one of the parameters.
		"""
		self.items.append([rec[k] for k in self.keys])

	def batch_ready(self):
		"Return True if the batch is ready to execute"
		return len(self.items) >= self.batchsize

	def execute_batch(self, cursor):
		"""Execute the batch if there are any parameters.

		The queue is emptied only when the statement succeeded. If it
		fails, the SQL text is appended to the exception's args to make
		the failing statement identifiable and the exception is re-raised.
		"""
		if not self.items:
			return
		try:
			cursor.executemany(self.sql, self.items)
		except Exception as err:
			err.args = list(err.args) + [self.sql]
			raise
		self.items = []
		
class Loader:
	"""Database loader for stampede events.

	Events are dispatched by name to handler methods (see load_event),
	which queue rows in Batch objects. The batches are executed in a fixed
	order by execute_batches().

	This class never assigns self.conn but uses it in load_bpfile() and
	get_wf_id(); a subclass has to provide the database connection.
	"""

	def __init__(self):
		# Caches for IDs
		self.workflows = {}	# wf_uuid -> wf_id
		self.jobs = {}		# wf_uuid -> {job submit sequence -> job_id}
		self.hosts = {}		# (site_name, hostname, ip_address) -> host_id

		# Stores the number of times we skipped each event
		self.skips = {}

		# Next ID number to use
		self.next_wf_id = 1
		self.next_job_id = 1
		self.next_task_id = 1
		self.next_host_id = 1

		# Batched SQL statements
		self.workflow_batch = Batch("INSERT INTO workflow (wf_id, wf_uuid, dax_label, timestamp, submit_hostname, submit_dir, planner_arguments, username, grid_dn, planner_version, parent_workflow_id) VALUES (:wf_id, :wf_uuid, :dax_label, :timestamp, :submit_hostname, :submit_dir, :planner_arguments, :user, :grid_dn, :planner_version, :parent_workflow_id)")
		self.workflowstate_batch = Batch("INSERT INTO workflowstate (wf_id, state, timestamp) VALUES (:wf_id, :state, :timestamp)")
		self.host_batch = Batch("INSERT INTO host (host_id, site_name, hostname, ip_address, uname, total_ram) VALUES (:host_id, :site_name, :hostname, :ip_address, :uname, :total_ram)")
		self.edge_static_batch = Batch("INSERT INTO edge_static (wf_uuid, parent, child) VALUES (:wf_uuid, :parent, :child)")
		self.pre_batch = Batch("INSERT INTO job (job_id, wf_id, job_submit_seq, name, jobtype) VALUES (:job_id, :wf_id, :job_submit_seq, :name, :jobtype)")
		self.job_batch = Batch("REPLACE INTO job (job_id, wf_id, job_submit_seq, name, condor_id, jobtype) VALUES (:job_id, :wf_id, :job_submit_seq, :name, :condor_id, :jobtype)")
		self.job_host_batch = Batch("UPDATE job SET host_id=:host_id WHERE job_id=:job_id")
		self.job_update_batch = Batch("UPDATE job SET cluster_duration=:cluster_duration, remote_user=:remote_user, site_name=:site_name, remote_working_dir=:remote_working_dir, clustered=:clustered, cluster_start_time=:cluster_start_time WHERE job_id=:job_id")
		self.jobstate_batch = Batch("INSERT INTO jobstate (job_id, state, timestamp, jobstate_submit_seq) VALUES (:job_id, :state, :timestamp, :jobstate_submit_seq)")
		self.task_batch = Batch("INSERT INTO task (task_id, job_id, task_submit_seq, start_time, duration, exitcode, transformation, executable, arguments) VALUES (:task_id, :job_id, :task_submit_seq, :start_time, :duration, :exitcode, :transformation, :executable, :arguments)")
		# This query assumes that the parent job with the largest job_id is the one we want
		self.edge_batch = Batch("INSERT INTO edge SELECT MAX(p.job_id) parent_id, c.job_id child_id FROM edge_static e, job c, job p WHERE e.wf_uuid=:wf_uuid AND e.child=:name AND c.job_id=:job_id AND p.wf_id=:wf_id AND p.name=e.parent GROUP BY p.name")

		# Order in which batches should be executed
		self.batches = [
			self.workflow_batch,
			self.workflowstate_batch,
			self.host_batch,
			self.edge_static_batch,
			self.pre_batch,
			self.job_batch,
			self.job_host_batch,
			self.job_update_batch,
			self.jobstate_batch,
			self.task_batch,
			self.edge_batch
		]

	def load_bpfile(self, bpfile):
		"""Load a netlogger bpfile.

		Every non-blank line is parsed and handed to load_event(); events
		without a handler are counted with skip_event(). The transaction is
		committed once the whole file has been processed. On any error the
		transaction is rolled back and the error is re-raised. The input
		file is closed in either case.
		"""
		log = None
		try:
			tty = sys.stdout.isatty()
			cursor = self.conn.cursor()
			i = 0
			log = open(bpfile)
			for line in log:
				line = line.strip()
				if not line:
					# A blank line carries no event; without this check
					# it would abort the load with a KeyError
					continue
				rec = parse_netlogger_event(line)
				if self.load_event(rec):
					i += 1
					# Only check batches once every 100 records
					if (i%100) == 0 and self.check_batches():
						self.execute_batches(cursor)
					if tty and (i%73) == 0:
						sys.stdout.write("Processed %d events\r" % i)
						sys.stdout.flush()
				else:
					self.skip_event(rec)
			self.execute_batches(cursor)
			print("Processed %d events" % i)
			cursor.close()
			self.conn.commit()
		except:
			# Deliberately catches everything (including KeyboardInterrupt)
			# so that a partial load is never left in the transaction
			self.conn.rollback()
			raise
		finally:
			if log is not None:
				log.close()

	def skip_event(self, rec):
		"Skip a parsed event and count it in self.skips by event name"
		event = rec['event']
		self.skips[event] = self.skips.get(event, 0) + 1

	def load_event(self, rec):
		"""Load a parsed event.

		The handler is the attribute of this object named after the event
		with the stampede. prefix removed and dots replaced by underscores
		(stampede.job.state -> job_state).

		Returns True if a handler was found and ran, False if there is no
		handler for the event. Raises Exception if the event name does not
		start with stampede. -- if the handler fails, rec is appended to
		the args of its exception and the exception is re-raised.
		"""
		event = rec['event']

		if not event.startswith('stampede.'):
			raise Exception("Invalid event: %s" % event)

		# NOTE(review): any attribute of this object can be selected this
		# way, not only the event handlers -- confirm that the event names
		# in the input can be trusted
		name = event.replace('stampede.','').replace('.','_')

		handler = getattr(self, name, None)
		if not handler:
			return False

		try:
			handler(rec)
		except Exception as err:
			err.args = list(err.args) + [rec]
			raise
		return True

	def check_batches(self):
		"Check to see if any batch SQL statements are ready to run"
		return any(b.batch_ready() for b in self.batches)

	def execute_batches(self, cursor):
		"Execute all batch SQL statements in the order given by self.batches"
		for b in self.batches:
			b.execute_batch(cursor)

	def workflow_plan(self, rec):
		"""Handle a stampede.workflow.plan event.

		Registers a new workflow, queues its row for the workflow table
		and queues a 'parse' row for the workflowstate table. Raises
		Exception if the workflow is a duplicate or if it names a parent
		workflow that is not in the cache.
		"""
		wf_id = self.new_workflow(rec)

		# Optional attributes are stored as NULL when they are missing
		for key in ('submit_hostname', 'submit_dir', 'planner_arguments',
				'user', 'grid_dn', 'planner_version', 'parent_workflow_id'):
			rec.setdefault(key, None)

		# A missing parent.wf.id is treated like the literal 'None':
		# the workflow has no parent
		parent = rec.get('parent.wf.id')
		if parent is not None and parent != 'None':
			if parent not in self.workflows:
				raise Exception("Unknown parent workflow: %s" % parent)
			rec['parent_workflow_id'] = self.workflows[parent]

		rec['wf_id'] = wf_id
		rec['wf_uuid'] = rec['wf.id']
		rec['timestamp'] = parse_ts(rec['ts'])

		self.workflow_batch.add_batch(rec)

		# Add to workflowstate table
		self.workflowstate_batch.add_batch({
			'wf_id': wf_id,
			'timestamp': rec['timestamp'],
			'state': 'parse'
		})

	def workflow_start(self, rec):
		"""Handle a stampede.workflow.start event.

		A workflow that is not in the cache is looked up in the database
		and re-added to the cache (see old_workflow).
		"""
		wf_id = self.lookup_workflow(rec)

		if wf_id is None:
			wf_id = self.old_workflow(rec)

		rec['wf_id'] = wf_id
		rec['timestamp'] = parse_ts(rec['ts'])
		rec['state'] = 'start'
		self.workflowstate_batch.add_batch(rec)

	def workflow_end(self, rec):
		"""Handle a stampede.workflow.end event.

		Queues an 'end' row for the workflowstate table and removes the
		workflow and its jobs from the caches. Raises Exception if the
		workflow is not in the cache.
		"""
		wf_id = self.lookup_workflow(rec)

		if wf_id is None:
			raise Exception("Unknown workflow: %s" % rec['wf.id'])

		rec['wf_id'] = wf_id
		rec['timestamp'] = parse_ts(rec['ts'])
		rec['state'] = 'end'
		self.workflowstate_batch.add_batch(rec)

		# Clean up workflow map
		wf_uuid = rec['wf.id']
		del self.workflows[wf_uuid]
		del self.jobs[wf_uuid]

	def host(self, rec):
		"""Handle a stampede.host event.

		A host not seen before gets a new host_id and a row in the host
		table. In every case the job the event refers to is updated with
		the host_id.
		"""
		key = (rec['site_name'], rec['hostname'], rec['ip_address'])
		if key not in self.hosts:
			# If it is not in the host cache, add it
			host_id = self.next_host_id
			self.next_host_id += 1
			self.hosts[key] = host_id

			rec['host_id'] = host_id
			rec.setdefault('uname',None)
			rec.setdefault('total_ram',None)

			# Update host table
			self.host_batch.add_batch(rec)

		# Update job table
		self.job_host_batch.add_batch({
			'host_id': self.hosts[key],
			'job_id': self.lookup_job(rec)
		})

	def edge(self, rec):
		"Handle a stampede.edge event"

		# Update edge_static table
		self.edge_static_batch.add_batch({
			'wf_uuid': rec['wf.id'],
			'parent': rec['parent'],
			'child': rec['child']
		})

	def job_prescript_start(self, rec):
		"""Handle a stampede.job.prescript.start event.

		Registers a new job and queues its row for the job table. Raises
		KeyError if the workflow is not in the cache.
		"""
		# Get wf_id
		rec['wf_id'] = self.workflows[rec['wf.id']]

		# Get job_id
		rec['job_id'] = self.new_job(rec)

		rec['job_submit_seq'] = rec['job.id']

		# Update the job table
		self.pre_batch.add_batch(rec)

	def job_mainjob_start(self, rec):
		"""Handle a stampede.job.mainjob.start event.

		Reuses the job_id registered by an earlier prescript event for the
		same job, or registers a new job. Queues a REPLACE of the job row
		and the statement that derives the job's rows of the edge table.
		"""
		# Get wf_id
		wf_uuid = rec['wf.id']
		wf_id = self.workflows[wf_uuid]
		rec['wf_id'] = wf_id

		# Get job_id
		job_seq = rec['job.id']
		if job_seq in self.jobs[wf_uuid]:
			job_id = self.jobs[wf_uuid][job_seq]
		else:
			job_id = self.new_job(rec)
		rec['job_id'] = job_id

		rec['job_submit_seq'] = job_seq

		# Sometimes we don't get a condor ID
		rec['condor_id'] = rec.get('condor.id')

		# Update job table
		self.job_batch.add_batch(rec)

		# Update edge table
		self.edge_batch.add_batch({
			'name': rec['name'],
			'job_id': job_id,
			'wf_uuid': wf_uuid,
			'wf_id': wf_id
		})

	def job_mainjob_end(self, rec):
		"Handle a stampede.job.mainjob.end event"

		rec['job_id'] = self.lookup_job(rec)

		rec.setdefault('remote_user',None)
		rec.setdefault('remote_working_dir',None)
		rec.setdefault('clustered',0)
		rec.setdefault('cluster_start_time',None)
		rec.setdefault('cluster_duration',None)

		# Update job table
		self.job_update_batch.add_batch(rec)

	def job_state(self, rec):
		"""Handle a stampede.job.state event.

		An event for a job that is not in the cache is counted as skipped.
		Because this handler still returns normally, load_event() reports
		such an event as loaded as well.
		"""
		job_id = self.lookup_job(rec)

		if job_id is None:
			# This usually occurs when we have a job.state before the mainjob.start
			# The only thing we can do is skip the event
			self.skip_event(rec)
			return

		rec['job_id'] = job_id
		rec['timestamp'] = parse_ts(rec['ts'])
		rec['jobstate_submit_seq'] = rec['js.id']

		# Update jobstate table
		self.jobstate_batch.add_batch(rec)

	def task_mainjob(self, rec):
		"""Handle a stampede.task.mainjob event.

		Also used for stampede.task.prescript and stampede.task.postscript.
		"""
		job_id = self.lookup_job(rec)

		# If the job hasn't been loaded yet, we just have to skip it
		if job_id is None:
			self.skip_event(rec)
			return

		rec['job_id'] = job_id
		rec['task_id'] = self.new_task()
		rec['task_submit_seq'] = rec['task.id']
		rec.setdefault('executable',None)
		rec.setdefault('arguments',None)

		self.task_batch.add_batch(rec)

	task_prescript = task_mainjob
	task_postscript = task_mainjob

	def new_workflow(self, rec):
		"""Create a new workflow and add it to the cache.

		Returns the new wf_id. Raises Exception if the workflow is already
		in the cache.
		"""
		wf_uuid = rec['wf.id']

		if wf_uuid in self.workflows:
			raise Exception("Duplicate workflow: %s" % wf_uuid)

		wf_id = self.next_wf_id
		self.next_wf_id += 1

		self.workflows[wf_uuid] = wf_id
		self.jobs[wf_uuid] = {}

		return wf_id

	def old_workflow(self, rec):
		"""Re-add a workflow to the workflow cache.

		The wf_id is read from the database. The workflow's job cache
		starts out empty.
		"""
		wf_uuid = rec['wf.id']
		wf_id = self.get_wf_id(wf_uuid)

		self.workflows[wf_uuid] = wf_id
		self.jobs[wf_uuid] = {}

		return wf_id

	def get_wf_id(self, wf_uuid):
		"""Get a wf_id from the database using wf_uuid.

		Returns the wf_id converted to a string. Raises Exception if the
		database has no workflow with that wf_uuid.
		"""
		cur = self.conn.cursor()
		try:
			cur.execute("SELECT wf_id FROM workflow WHERE wf_uuid=?",(wf_uuid,))
			row = cur.fetchone()
		finally:
			cur.close()
		if row is None:
			raise Exception("Unknown workflow: %s" % wf_uuid)
		return str(row[0])

	def new_job(self, rec):
		"""Create a new job and add it to the cache.

		Returns the new job_id. Raises Exception if the workflow already
		has a job with the same submit sequence and KeyError if the
		workflow is not in the cache.
		"""
		jobs = self.jobs[rec['wf.id']]

		seq_no = rec['job.id']
		if seq_no in jobs:
			# seq_no comes from the parsed event and is a string, so it
			# must not be formatted with %d
			raise Exception("Duplicate job: %s" % seq_no)

		job_id = self.next_job_id
		self.next_job_id += 1

		jobs[seq_no] = job_id

		return job_id

	def new_task(self):
		"Create a new task and return its task_id"
		task_id = self.next_task_id
		self.next_task_id += 1
		return task_id

	def lookup_workflow(self, rec):
		"Find a workflow in the cache and return its wf_id, or None"
		return self.workflows.get(rec['wf.id'])

	def lookup_job(self, rec):
		"""Find a job in the cache and return its job_id.

		Returns None if the workflow has no such job. Raises KeyError if
		the workflow itself is not in the cache.
		"""
		jobs = self.jobs[rec['wf.id']]
		return jobs.get(rec['job.id'])
	
class SQLiteLoader(Loader):
	"""Loader that stores stampede events in an SQLite database file.

	Opening the loader creates the schema if the database contains no
	tables. Otherwise the ID counters and the host cache are initialized
	from the existing contents so that further files can be appended.
	"""

	def __init__(self, dbfile):
		"""Open (or create) the database dbfile and prepare it for loading."""
		Loader.__init__(self)
		self.dbfile = dbfile

		import sqlite3

		self.conn = sqlite3.connect(dbfile,isolation_level="EXCLUSIVE")

		self.conn.execute("PRAGMA locking_mode = EXCLUSIVE")

		# Count the number of tables in the schema
		cur = self.conn.cursor()
		cur.execute("SELECT count(*) tables FROM sqlite_master WHERE type='table'")
		tables = cur.fetchone()[0]
		cur.close()

		if tables == 0:
			# If there are no tables, then we need to create the database
			self.conn.execute("PRAGMA page_size = 4096")
			self.create_schema()
		else:
			# If the tables already exist, we need to figure out what the next
			# ID numbers need to be, and repopulate the host cache
			self.set_id_numbers()
			self.populate_host_cache()

	def create_schema(self):
		"""Create tables and indexes in database.

		The statements are read from the file schema.sql, which is looked
		up in the directory of the running script (sys.argv[0]).
		"""
		print("Creating Stampede schema in %s" % self.dbfile)

		basedir = os.path.abspath(os.path.dirname(sys.argv[0]))
		schemafile = os.path.join(basedir, "schema.sql")
		with open(schemafile) as f:
			script = f.read()

		self.conn.executescript(script)
		self.conn.commit()

	def set_id_numbers(self):
		"Set the next ID numbers for the wf, job, task, and host entities"
		self.next_wf_id = self.next_id("wf_id", "workflow")
		self.next_job_id = self.next_id("job_id", "job")
		self.next_task_id = self.next_id("task_id", "task")
		self.next_host_id = self.next_id("host_id", "host")

	def next_id(self, idfield, table):
		"""Retrieve max(table.idfield)+1 from the database.

		Returns 1 for an empty table. idfield and table are inserted into
		the statement as they are, so they must never come from input.
		"""
		cur = self.conn.cursor()
		cur.execute("SELECT coalesce(max(%s),0)+1 FROM %s" % (idfield, table))
		next_id = cur.fetchone()[0]
		cur.close()
		return next_id

	def populate_host_cache(self):
		"""Retrieve all the hosts from the database and place them in the host cache.

		Both the key fields and the host_id are converted with str().
		"""
		cur = self.conn.cursor()
		cur.execute("SELECT host_id, site_name, hostname, ip_address FROM host")
		for host_id, site_name, hostname, ip_address in cur:
			key = (str(site_name),str(hostname),str(ip_address))
			self.hosts[key] = str(host_id)
		cur.close()

def main():
	"""Command line entry point: load BPFILEs into the database DBFILE.

	Usage: PROGRAM DBFILE [BPFILE]...

	Events are read from /dev/stdin if no BPFILE is given. Exits with
	status 1 if the arguments are missing, if DBFILE exists but does not
	start with the SQLite file header, or if a BPFILE does not exist.
	"""
	if len(sys.argv) < 2:
		print("Usage: %s DBFILE [BPFILE]..." % os.path.basename(sys.argv[0]))
		sys.exit(1)

	print("Loader starting with PID %d" % os.getpid())

	# Check dbfile
	dbfile = sys.argv[1]
	if os.path.exists(dbfile):
		# The header has to be compared as bytes: the rest of a database
		# file is binary data and cannot be read in text mode
		with open(dbfile, "rb") as db:
			magic = db.read(13)
		if magic != b"SQLite format":
			print("Invalid SQLite database: %s" % dbfile)
			sys.exit(1)

	# Create database directory if it doesn't exist
	dbdir = os.path.abspath(os.path.dirname(dbfile))
	if not os.path.isdir(dbdir):
		os.makedirs(dbdir)

	# Check bpfiles
	bpfiles = sys.argv[2:]
	for bpfile in bpfiles:
		if not os.path.isfile(bpfile):
			print("BPFILE does not exist: %s" % bpfile)
			sys.exit(1)

	# Read from stdin if no input file specified
	if not bpfiles:
		bpfiles.append("/dev/stdin")

	# Load data
	loader = SQLiteLoader(dbfile)
	start = datetime.now()
	for bpfile in bpfiles:
		print("Loading file %s" % bpfile)
		fs = datetime.now()
		loader.load_bpfile(bpfile)
		fe = datetime.now()
		print("Loaded %s in %s" % (bpfile,fe-fs))
	end = datetime.now()
	print("Loaded all files in %s" % (end-start))

	for event, count in loader.skips.items():
		print("Skipped event %s %d times" % (event, count))
		
if __name__ == '__main__':
	try:
		main()
	except KeyboardInterrupt:
		# Exit quietly, without a traceback, when the load is interrupted
		pass