#!/usr/bin/python
import sys
import os
import subprocess
from lxml import etree as ET
# Load the renouveau register database that the rest of the script rewrites.
tree = ET.parse("renouveau.xml")
root = tree.getroot()

# Directory receiving the generated files.
try:
	os.mkdir("renouveau")
except OSError:
	# An already existing directory is fine; any other failure (permissions,
	# a plain file in the way, ...) must not be silently ignored.
	if not os.path.isdir("renouveau"):
		raise

# Maps a parent object id (int) to the list of <object> elements that name
# that id in their "parent" attribute.
children = {}

for obj in root:
	if obj.tag != "object":
		continue

	# Swap the NV30TCL/NV34TCL names (assigning fixed ids/parents) and
	# re-parent NV35TCL before the hierarchy table is built.
	objname = obj.attrib.get("name")
	if objname == "NV34TCL":
		obj.attrib["name"] = "NV30TCL"
		obj.attrib["id"] = "0x0397"
	elif objname == "NV30TCL":
		obj.attrib["name"] = "NV34TCL"
		obj.attrib["id"] = "0x0697"
		obj.attrib["parent"] = "0x0497"
	elif objname == "NV35TCL":
		obj.attrib["parent"] = "0x0397"

	# Only objects carrying both an id and a parent take part in the hierarchy.
	# Base 0 lets int() accept "0x..." as well as decimal ids.
	if "id" in obj.attrib and "parent" in obj.attrib:
		children.setdefault(int(obj.attrib["parent"], 0), []).append(obj)

def get_all_children(obj):
	"""Return all descendants of *obj*, de-duplicated and sorted by name.

	Descendants are looked up through the module-level ``children`` table
	(parent id -> list of child elements). When several descendants share a
	"name", the one found last wins. An object whose id has no entry in the
	table yields an empty list.

	The table itself is never modified: the previous in-place ``+=`` extended
	the stored child list with grandchildren on every call.
	"""
	objid = int(obj.attrib.get("id"), 0)
	if objid not in children:
		return []

	direct = children[objid]
	# Work on a copy so the shared table entry keeps only direct children.
	found = list(direct)
	for child in direct:
		if "id" in child.attrib:
			found.extend(get_all_children(child))

	by_name = {}
	for c in found:
		by_name[c.attrib["name"]] = c
	return [by_name[n] for n in sorted(by_name)]

# Generate one output file per group of classes. Each entry pairs the output
# file name with the names of the "primary" TCL objects merged into that file.
for filename, tclnames in [("nv10-20tcl", ("NV10TCL", "NV20TCL")), ("nv30-40tcl", ("NV30TCL", "NV40TCL")), ("nv50tcl", ("NV50TCL",))]:
	# tcls[i] is the <object> element named tclnames[i] (None until found).
	tcls = [None for i in tclnames]
	for obj in root:
		if obj.tag != "object":
			continue
		for i in xrange(len(tclnames)):
			if obj.get("name") == tclnames[i]:
				# A primary class must not have a "parent" attribute.
				if obj.get("parent"):
					print "ERROR:", tclnames[i], "has a parent"
					sys.exit(1)
				tcls[i] = obj

	# NOTE(review): a name from tclnames that is missing in the XML leaves None
	# in tcls, and get_all_children(None) below then fails on `.attrib`.
	# subtcls[i]: primary class i followed by all of its descendants.
	subtcls = [[tcl] + get_all_children(tcl) for tcl in tcls]
	# alltcls: every class covered by this file (primaries and descendants).
	alltcls = sum(subtcls, [])
	# inftcls: only the descendants ("inferior" classes), primaries excluded.
	inftcls = sum([get_all_children(tcl) for tcl in tcls], [])

	# Variant label spanning the whole file: "FIRST-LAST", or just the name
	# when the file covers a single class.
	if len(alltcls) > 1:
		allvariants = alltcls[0].attrib["name"] + "-" + alltcls[-1].attrib["name"]
	else:
		allvariants = alltcls[0].attrib["name"]

	# Variant label per primary class. Unlike allvariants/infvariants this is
	# always "FIRST-LAST", even when the class has no descendants.
	variants = [st[0].attrib["name"] + "-" + st[-1].attrib["name"] for st in subtcls]
	# Variant label per inferior class, covering that class and its descendants.
	infvariants = []
	for tcl in inftcls:
		sit = [tcl] + get_all_children(tcl)
		if len(sit) > 1:
			infvariants.append(sit[0].attrib["name"] + "-" + sit[-1].attrib["name"])
		else:
			infvariants.append(sit[0].attrib["name"])

	def get_regs(tcl):
		"""Map register offset (int) to its <reg32> element for one class.

		Every element child of *tcl* is asserted to be a <reg32>; children
		without an "offset" attribute are left out of the result.
		"""
		by_offset = {}
		for reg in tcl.iterchildren(tag=ET.Element):
			assert reg.tag == "reg32"
			raw_offset = reg.attrib.get("offset")
			if raw_offset is None:
				continue
			# Base 0: accepts hexadecimal ("0x...") as well as decimal text.
			by_offset[int(raw_offset, 0)] = reg
		return by_offset

	def fix_field_name(fname):
		"""Return *fname* with "MAGNIFY" shortened to "MAG" and "MINIFY" to "MIN"."""
		shortened = fname.replace("MAGNIFY", "MAG")
		shortened = shortened.replace("MINIFY", "MIN")
		return shortened

	def get_fields(reg32):
		"""Map the "low" bit position (int, decimal) to its <bitfield> element.

		Every element child of *reg32* is asserted to be a <bitfield>.
		"""
		by_low = {}
		for bitfield in reg32.iterchildren(tag=ET.Element):
			assert bitfield.tag == "bitfield"
			low_bit = int(bitfield.get("low"))
			by_low[low_bit] = bitfield
		return by_low

	# Offset -> register lookup tables, one per primary and per inferior class.
	dicts = [get_regs(primary) for primary in tcls]
	infdicts = [get_regs(inferior) for inferior in inftcls]

	# Container collecting every register emitted for this output file.
	common = ET.Element("object")
	common.text = ""

	# Give every child of a primary class that already has trailing text a
	# uniform newline-plus-indent tail.
	for primary in tcls:
		for child in primary:
			if not child.tail:
				continue
			child.tail = "\n    "

	# Sorted list of the distinct register offsets used by any class of this file.
	offset_pool = set()
	for anytcl in alltcls:
		for child in anytcl.iterchildren(tag=ET.Element):
			if "offset" in child.attrib:
				offset_pool.add(int(child.attrib["offset"], 0))
	all_offsets = sorted(offset_pool)

	def rnn_fix(r):
		"""Rewrite the attributes of element *r* in place.

		- "enum_name" is renamed to "type" (overwriting any existing type);
		- a "boolean" type on a field whose "low" equals "high" is dropped;
		- "size" is renamed to "length";
		- a type of "enum", "bitfield" or "hexa" is dropped.
		"""
		attrs = r.attrib

		if "enum_name" in attrs:
			attrs["type"] = attrs.pop("enum_name")

		single_bit = "low" in attrs and "high" in attrs and attrs["low"] == attrs["high"]
		if single_bit and attrs.get("type") == "boolean":
			del attrs["type"]

		if "size" in attrs:
			attrs["length"] = attrs.pop("size")

		# Checked last so a type that came from "enum_name" is filtered as well.
		if attrs.get("type") in ("enum", "bitfield", "hexa"):
			del attrs["type"]

	# Walk every register offset of this file. The primary classes' registers
	# at that offset are moved into `common` and merged into one definition
	# when possible; the inferior classes' registers are moved over unmerged.
	for offset in all_offsets:
		if offset == 0x50: # REF_CNT
			continue

		# regs[i]: register of primary class i at this offset, or None.
		regs = [d.get(offset) for d in dicts]
		if any([r is not None for r in regs]):
			candnames = [r.get("name") for r in regs if r is not None and "name" in r.attrib]

			# Unnamed registers are skipped entirely (inferior classes included).
			if not candnames:
				continue

			# if name == "VIEWPORT_CLIP_HORIZ" or name == "VIEWPORT_CLIP_VERT":
			# these have bitfields defined on nv30 but not nv40
			# NOTE(review): none of the tclnames tuples iterated by the enclosing
			# loop starts with "NV34TCL", so this branch (and the similar test
			# further down) never fires as written - confirm whether "NV30TCL"
			# was meant.
			if tclnames[0] == "NV34TCL" and offset in (0x2c0, 0x2c4):
				regs[1] = ET.fromstring(ET.tostring(regs[0]))

			# Detach each register from its class, collect it in `common` and
			# normalize its attributes.
			for i in xrange(len(regs)):
				if regs[i] is not None:
					if regs[i] in tcls[i]:
						tcls[i].remove(regs[i])
					common.append(regs[i])
					rnn_fix(regs[i])

			unify = False
			# Used as a breakable block: every path through the body ends in
			# `break`, so it executes at most once.
			while len(regs) > 1:
				# Unification needs a register from every primary class.
				if not all([r is not None for r in regs]):
					break

				# these must not be unified (same offset, different meanings)
				if tclnames[0] == "NV34TCL" and offset in (0x370, 0x1840, 0x22c, 0x400):
					break

				diff = 0

				# "deep" unification merges bitfield by bitfield; it is used when
				# the two registers agree on type and stride.
				deep = False
				if regs[0].attrib.get("type") == regs[1].attrib.get("type") and regs[0].attrib.get("stride") == regs[1].attrib.get("stride"):
					deep = True

				if not deep:
					# Shallow comparison: the registers are unified only when
					# their serialized forms match once the name is blanked.
					s = None
					for i in xrange(len(regs)):
						r = ET.fromstring(ET.tostring(regs[i]))
						r.attrib["name"] = ""
						if s is None:
							s = ET.tostring(r).strip()
						elif ET.tostring(r).strip() != s:
							diff += 1
					if diff:
						break
				else:
					# field_dicts[i]: "low" bit -> bitfield for register i.
					field_dicts = [get_fields(reg) for reg in regs]
					field_list = list(set(sum([[int(f.attrib.get("low", -1)) for f in r] for r in regs], [])))
					field_list.sort()
					# The first register becomes the combined one.
					creg = regs[0]
					creg.text = "\n\t\t"
					creg.attrib["offset"] = str(offset)
					# NOTE(review): rnn_fix() above already renamed "size" to
					# "length" on every register, so this merge of the larger
					# size does not trigger - confirm the intended order.
					if "size" in regs[0].attrib or "size" in regs[1].attrib:
						size0 = int(regs[0].attrib["size"]) if "size" in regs[0].attrib else 0
						size1 = int(regs[1].attrib["size"]) if "size" in regs[1].attrib else 0
						creg.attrib["size"] = str(size1 if size1 > size0 else size0)
					if "stride" in regs[0].attrib or "stride" in regs[1].attrib:
						assert regs[0].attrib["stride"] == regs[1].attrib["stride"]
						creg.attrib["stride"] = regs[0].attrib["stride"]
					creg.tail = "\n    "
					for foffset in field_list:
						# fields[i]: bitfield of register i starting at this bit, or None.
						fields = [f.get(foffset) for f in field_dicts]

						# Re-home every bitfield under the combined register, in
						# ascending bit order.
						for i in xrange(len(fields)):
							if fields[i] is not None:
								regs[i].remove(fields[i])
								creg.append(fields[i])
								fields[i].tail ="\n\t\t"

						# Two bitfields are merged when their serialized forms are
						# identical after the MAGNIFY/MINIFY name normalization.
						funify = False
						if all([f is not None for f in fields]):
							s0 = fix_field_name(ET.tostring(fields[0]).strip())
							s1 = fix_field_name(ET.tostring(fields[1]).strip())
							if s0 == s1:
								funify = True
						# Tag each bitfield with the variants it applies to (set
						# after the comparison so it cannot influence it).
						for i in xrange(len(fields)):
							if fields[i] is not None:
								fields[i].attrib["variants"] = variants[i]

						if funify:
							fcandnames = [f.attrib["name"] for f in fields if f is not None and "name" in f.attrib]
							fname = None
							if fcandnames:
								fname = fcandnames[0]
							# NOTE(review): fname stays None when neither bitfield
							# has a name, and fix_field_name(None) would then fail.
							fields[0].attrib["name"] = fix_field_name(fname)
							# A merged bitfield applies to all variants of the register.
							del fields[0].attrib["variants"]
							creg.remove(fields[1])

					unify = True
					break

				unify = True
				break

			# Final per-register touches: variant tag, trailing whitespace and
			# attribute normalization of the register and its bitfields.
			for i in xrange(len(regs)):
				if regs[i] is not None:
					if variants[i] != allvariants:
						regs[i].attrib["variants"] = variants[i]
					regs[i].tail = "\n    "
					rnn_fix(regs[i])
					for f in regs[i].iterchildren(tag=ET.Element):
						rnn_fix(f)

			if unify:
				# Keep only the first register, named after the first candidate
				# and without a variants restriction.
				name = None
				if candnames:
					name = candnames[0]
				if name:
					regs[0].attrib["name"] = name
				del regs[0].attrib["variants"]
				if regs[1] is not None:
					common.remove(regs[1])
		# no unification for inftcls, it's not necessary and N-way unification is much more complicated
		infregs = [d.get(offset) for d in infdicts]

		# Iterate over the inferior classes themselves: infregs, inftcls and
		# infvariants all have len(inftcls) entries, which generally differs
		# from len(regs) (one entry per primary class).
		for i in xrange(len(infregs)):
			if infregs[i] is not None:
				if infregs[i] in inftcls[i]:
					inftcls[i].remove(infregs[i])
				infregs[i].attrib["variants"] = infvariants[i]
				common.append(infregs[i])
				rnn_fix(infregs[i])

	# The collected registers become a <stripe> selected by object class and
	# valid for every variant of this file.
	common.tag = "stripe"
	common.attrib["prefix"] = "obj-class"
	common.attrib["variants"] = allvariants

	# Placeholder standing in for ':' inside attribute names; it is swapped for
	# a real colon in the serialized text below (presumably because the
	# attribute API does not accept "xmlns:xsi"-style names directly).
	colon = "39584ac56e6d0976692778b80915f94b61767814a8704982"
	database = ET.Element("database")
	database.attrib["xmlns"] = "http://nouveau.freedesktop.org/"
	database.attrib["xmlns" + colon + "xsi"] = "http://www.w3.org/2001/XMLSchema-instance"
	database.attrib["xsi" + colon + "schemaLocation"] = "http://nouveau.freedesktop.org/ rules-ng.xsd"

	# Enum listing the id and name of every class covered by this file.
	objenum = ET.Element("enum")
	objenum.attrib["bare"] = "yes"
	objenum.attrib["name"] = "obj-class"

	for tcl in alltcls:
		objvalue = ET.Element("value")
		objvalue.attrib["value"] = tcl.attrib["id"]
		objvalue.attrib["name"] = tcl.attrib["name"]
		objenum.append(objvalue)

	database.append(objenum)

	# Domain wrapping the register stripe.
	domain = ET.Element("domain")
	domain.attrib["name"] = "NV01_SUBCHAN"
	domain.attrib["bare"] = "yes"
	domain.attrib["size"] = "0x2000"
	domain.append(common)
	database.append(domain)

	# Pipe the serialized database through ./xml2text into the output file.
	# The with-statement closes the file even when starting or feeding the
	# subprocess fails.
	with open("renouveau/" + filename, "w") as outfile:
		xml2text = subprocess.Popen("./xml2text", stdin = subprocess.PIPE, stdout = outfile)
		# communicate() sends the input, closes stdin and waits for the
		# process to exit, so no separate wait() is needed.
		xml2text.communicate(ET.tostring(database).replace(colon, ":"))
	if xml2text.returncode != 0:
		# Report a failed conversion instead of silently leaving a bad file.
		sys.stderr.write("WARNING: xml2text exited with status %d while writing renouveau/%s\n" % (xml2text.returncode, filename))
