#!/usr/bin/python
# -*- coding: utf-8 -*-

import sys
import os
from lxml import etree

class AnchorWordHit:
	"""Record of one word that matched an anchor word list entry.

	index: position of the matched entry in the anchor word list.
	elementNumber: number of the element the word was found in.
	pos: index of the word within that element's word list.
	word: the matched word as it appears in the text.
	"""

	def __init__(self, index, elementNumber, pos, word):
		self.index, self.elementNumber = index, elementNumber
		self.pos, self.word = pos, word

	def getIndex(self):
		"""Return the anchor word list entry index."""
		return self.index

	def getWord(self):
		"""Return the matched word."""
		return self.word

	def getPos(self):
		"""Return the word position within its element."""
		return self.pos

	def getElementNumber(self):
		"""Return the number of the element containing the word."""
		return self.elementNumber

	def setPos(self, pos):
		"""Replace the stored word position."""
		self.pos = pos

class AnchorWordListEntry:
	"""One line of the anchor word file.

	The line holds two sides separated by '/', and each side is a
	','-separated list of synonyms, e.g. "a, b / c" -> [['a', 'b'], ['c']].
	An empty line yields an empty entry. Any other number of sides prints
	a message and terminates the program with exit status 1.
	"""

	def __init__(self, anchorWordListEntryText):
		self.entry = []
		if len(anchorWordListEntryText) == 0:
			return

		sides = anchorWordListEntryText.strip().split('/')
		if len(sides) != 2:
			# Same output as a Python 2 `print "Faulty anchor entry", text`.
			sys.stdout.write("%s %s\n" % ("Faulty anchor entry", anchorWordListEntryText))
			sys.exit(1)

		self.entry = [[synonym.strip() for synonym in side.split(',')] for side in sides]

	def getEntry(self):
		"""Return the list of the two synonym lists (or [] for an empty line)."""
		return self.entry

class AnchorWordList:
	"""List of anchor word entries loaded from a text file, one entry per line."""

	def __init__(self, filename = ''):
		self.anchorWordEntries = []
		if filename != '':
			self.loadFromFile(filename)

	def loadFromFile(self, anchorWordListFile):
		"""Append one AnchorWordListEntry per line of the file.

		A missing file is silently ignored. Every line, including blank ones,
		produces an entry, so entry indices match line numbers (0-based).
		"""
		if os.path.exists(anchorWordListFile):
			# `with` guarantees the file is closed, even if an entry aborts the load.
			with open(anchorWordListFile) as anchorfile:
				for line in anchorfile:
					self.anchorWordEntries.append(AnchorWordListEntry(line.strip()))

	def getEntry(self, id):
		"""Return the entry at position id."""
		return self.anchorWordEntries[id]

	def getSize(self):
		"""Return the number of loaded entries."""
		return len(self.anchorWordEntries)

	def getAnchorWordHits(self, words, index, elementNumber):
		"""Return the AnchorWordHits for words against side `index` of every entry.

		For each entry (in list order) and each word, at most one hit is
		recorded: the first synonym that anchorMatch() accepts wins.

		Args:
			words: the words of one element.
			index: which side (0 or 1) of each entry to match against.
			elementNumber: element number stored in each hit.
		"""
		hits = []
		for entryIndex, anchorEntry in enumerate(self.anchorWordEntries):
			sides = anchorEntry.getEntry()
			# Blank lines in the anchor file produce entries without sides;
			# skip them instead of raising IndexError.
			if index >= len(sides):
				continue
			synonyms = sides[index]
			for position, word in enumerate(words):
				for synonym in synonyms:
					if anchorMatch(synonym, word):
						hits.append(AnchorWordHit(entryIndex, elementNumber, position, word))
						break
		return hits

	def getProperNames(self, words):
		"""Return the words that start with an upper-case letter (empty words are skipped)."""
		return [word for word in words if word[:1].isupper()]

	def getScoringCharacters(self, text):
		"""Return the characters of text that are one of '?', '%' or '!', in order."""
		scoringCharacters = '?%!'
		return ''.join([c for c in text if c in scoringCharacters])

class AlignmentElement:
	"""A sentence read from an input file, with its id and sequence number.

	alignmentNumber starts at -1 and is changed with setAlignmentNumber().
	"""

	def __init__(self, text, id, elementNumber):
		self.text, self.id, self.elementNumber = text, id, elementNumber
		self.alignmentNumber = -1

	def getText(self):
		"""Return the sentence text."""
		return self.text

	def getId(self):
		"""Return the sentence id."""
		return self.id

	def getElementNumber(self):
		"""Return the sequence number of the sentence in its file."""
		return self.elementNumber

	def getAlignmentNumber(self):
		"""Return the alignment number (-1 until set)."""
		return self.alignmentNumber

	def setAlignmentNumber(self, number):
		"""Store the alignment number."""
		self.alignmentNumber = number

	def getLength(self):
		"""Return the length of the sentence text."""
		text = self.text
		return len(text)

class SentenceHandler:
	"""Load the <s> elements of an XML sentence file as AlignmentElements."""

	def __init__(self, filename):
		self.filename = filename
		self.elements = []
		self.setElements()

	def setElements(self):
		"""Parse self.filename and append one AlignmentElement per <s>, numbered from 0."""
		root = etree.parse(self.filename).getroot()
		for number, sentence in enumerate(root.findall(".//s")):
			self.elements.append(AlignmentElement(sentence.text, sentence.attrib['id'], number))

	def getLength(self):
		"""Return the number of loaded sentences."""
		return len(self.elements)

	def getSentence(self, index):
		"""Return the text of the sentence at index."""
		element = self.elements[index]
		return element.getText()

	def getElementNumber(self, index):
		"""Return the element number of the sentence at index."""
		element = self.elements[index]
		return element.getElementNumber()

	def getSentences(self):
		"""Return the list of AlignmentElements (not a copy)."""
		return self.elements

	def getFilename(self):
		"""Return the file the sentences were loaded from."""
		return self.filename

class AlignerHandler:
	"""Load an anchor word list and the sentence files to align, and align them.

	NOTE(review): align() is unfinished. It calls getMaxPathLength(), which is
	not defined in the visible source, and relies on QueueList providing
	remove() and removeForReal(); confirm those exist before running it.
	"""

	def __init__(self, anchorfile, files):
		self.anchorList = AnchorWordList(anchorfile)
		self.compare = Compare()
		self.unaligned = []
		# Expected length ratio between the texts; exposed via getLengthRatio().
		self.lengthRatio = 1.1
		for filename in files:
			self.unaligned.append(SentenceHandler(filename))

	def getLengthRatio(self):
		"""Return the expected length ratio between the texts."""
		return self.lengthRatio

	def align(self):
		"""Search for the best path from the start of the unaligned texts.

		Returns 1.
		"""
		stepList = self.compare.getStepList()
		doneAligning = False
		while not doneAligning:
			self.compare.resetBestPathScores()
			position = []
			for sentenceHandler in self.unaligned:
				if sentenceHandler.getLength() > 0:
					position.append(sentenceHandler.getElementNumber(0))
				else:
					position.append(sentenceHandler.getLength() - 1)

			queueList = QueueList(position)
			stepCount = 0
			doneLengthening = False

			while not doneLengthening:
				# A fresh list every pass. Reusing one object made queueList and
				# nextQueueList the same list after the first step.
				nextQueueList = QueueList()
				for queueEntry in queueList.getEntries():
					if queueEntry.getRemoved() or queueEntry.getEnd():
						continue
					for step in stepList:
						newQueueEntry = queueEntry.makeLongerPath(step)
						if newQueueEntry.getPath() is not None:
							pos = newQueueEntry.getPath().getPosition()
							# Keep only the newest path reaching pos.
							nextQueueList.remove(pos)
							queueList.remove(pos)
							nextQueueList.append(newQueueEntry)

				nextQueueList.removeForReal()
				if nextQueueList.empty():
					doneLengthening = True
				else:
					queueList = nextQueueList
					stepCount = stepCount + 1
					doneLengthening = (stepCount >= getMaxPathLength())

			entries = queueList.getEntries()
			if len(entries) == 0 or (len(entries) == 1 and len(entries[0].getPath().getSteps()) == 0):
				doneAligning = True
			else:
				print("kvukk")
				# TODO: commit the best path and drop the aligned sentences from
				# self.unaligned. Until then every pass starts from the same
				# position, so stop here instead of looping forever.
				doneAligning = True

		return 1

class BestPathScore:
	"""Score stored for one cell of the compare matrix (-1.0 by default)."""

	def __init__(self, score = -1.0):
		self.score = score

	def getScore(self):
		"""Return the stored score."""
		result = self.score
		return result

class AnchorWordHit:
	"""Record of one word that matched an anchor word list entry.

	NOTE(review): this module-level definition rebinds the name of the
	identical AnchorWordHit class defined earlier in the file.

	index: position of the matched entry in the anchor word list.
	elementNumber: number of the element the word was found in.
	pos: index of the word within that element's word list.
	word: the matched word as it appears in the text.
	"""

	def __init__(self, index, elementNumber, pos, word):
		self.index = index
		self.elementNumber = elementNumber
		self.word = word
		self.pos = pos

	def getIndex(self):
		"""Return the anchor word list entry index."""
		return self.index

	def getWord(self):
		"""Return the matched word."""
		return self.word

	def getPos(self):
		"""Return the word position within its element."""
		return self.pos

	def getElementNumber(self):
		"""Return the number of the element containing the word."""
		return self.elementNumber

	def setPos(self, pos):
		"""Replace the stored word position."""
		self.pos = pos

class ElementInfo:
	"""Word-level information about one element (sentence) of a text.

	Args:
		anchorWordList: provides getAnchorWordHits(), getProperNames() and
			getScoringCharacters() (an AnchorWordList in this module).
		text: the element text; None is treated as an empty string.
		index: which side (0 or 1) of each anchor word entry to match.
		elementNumber: number of the element within its text.
	"""

	def __init__(self, anchorWordList, text, index, elementNumber):
		if text is None:
			text = ''
		self.text = text
		self.elementNumber = elementNumber
		# Tokenise on whitespace. Previously the raw string was used, so every
		# character counted as a word.
		# NOTE(review): plain split; punctuation stays attached to words.
		self.words = text.split()
		self.anchorWordHits = anchorWordList.getAnchorWordHits(self.words, index, elementNumber)
		self.properNames = anchorWordList.getProperNames(self.words)
		self.scoringCharacters = anchorWordList.getScoringCharacters(text)

	def getLength(self):
		"""Return the length of the element text in characters."""
		return len(self.text)

	def getWords(self):
		"""Return the whitespace-separated words of the element."""
		return self.words

	def getElementNumber(self):
		"""Return the number of the element within its text."""
		return self.elementNumber

	def getScoringCharacters(self):
		"""Return the scoring characters found in the element text."""
		return self.scoringCharacters

class ElementInfoToBeCompared:
	"""Elements from the two texts that one path step would align, plus their score.

	info[t] holds the ElementInfo objects taken from text t (t = 0 or 1).
	getScore() computes the score lazily via toList(); -1.0 marks "not
	computed yet".

	NOTE(review): toList() relies on helpers not defined in the visible
	source (getLargeClusterScorePercentage, countWords, the get*MatchWeight
	and getDiceMin* functions) and on Clusters providing add(), addRef() and
	getScore(); confirm these exist before relying on the scores.
	"""

	# Expected length ratio between the two texts (same value as AlignerHandler.lengthRatio).
	LENGTH_RATIO = 1.1

	def __init__(self):
		self.score = -1.0
		self.anchorWordHits = [[], []]
		self.info = [[], []]

	def add(self, elementInfo, index):
		"""Add elementInfo to the elements taken from text index (0 or 1)."""
		self.info[index].append(elementInfo)

	def empty(self):
		"""Return True when either text contributes no elements."""
		for infos in self.info:
			if len(infos) == 0:
				return True
		return False

	def getScore(self):
		"""Return the score, computing it on first use."""
		if self.score == -1.0:
			self.toList()
		return self.score

	def toList(self):
		"""Compute self.score for the collected elements.

		The score is 0.0 when either side is empty and -99999.0 when the
		total lengths correlate badly. Otherwise it is the common cluster
		score (anchor words, proper names, dice and numbers) plus the
		scoring-character cluster score, passed through
		adjustForLengthCorrelation().
		"""
		self.score = 0.0
		if self.empty():
			return

		length = [0, 0]
		elementCount = [0, 0]
		for t in range(0, 2):
			for info in self.info[t]:
				# Total length of all elements taken from text t.
				length[t] = length[t] + info.getLength()
			elementCount[t] = len(self.info[t])

		if badLengthCorrelation(length[0], length[1], elementCount[0], elementCount[1], self.LENGTH_RATIO):
			self.score = -99999.0
			return

		largeClusterScorePercentage = getLargeClusterScorePercentage()

		properNameClusters, diceClusters, numberClusters = self._findWordClusters()
		commonClusters = Clusters()
		commonClusters.add(self._findAnchorWordClusters())
		commonClusters.add(properNameClusters)
		commonClusters.add(diceClusters)
		commonClusters.add(numberClusters)
		self.score = self.score + commonClusters.getScore(largeClusterScorePercentage)

		scoringCharacterClusters = self._findScoringCharacterClusters()
		self.score = self.score + scoringCharacterClusters.getScore(largeClusterScorePercentage)

		self.score = adjustForLengthCorrelation(self.score, length[0], length[1], elementCount[0], elementCount[1], self.LENGTH_RATIO)

	def _findAnchorWordClusters(self):
		"""Return Clusters of anchor-word hits whose anchor entry occurs in both texts."""
		anchorWordClusters = Clusters()
		hits = [[], []]
		for t in range(0, 2):
			for info in self.info[t]:
				hits[t].extend(info.anchorWordHits or [])
			# The merge below walks both lists in anchor-index order. Hits from
			# several elements are concatenated, so sort them (stable) first.
			hits[t].sort(key=lambda hit: hit.getIndex())

		current = [0, 0]
		while True:
			# Smallest anchor index at the heads of the lists, and in how many
			# texts it occurs.
			smallest = None
			smallestCount = 0
			for t in range(0, 2):
				if current[t] < len(hits[t]):
					hitIndex = hits[t][current[t]].getIndex()
					if smallest is None or hitIndex < smallest:
						smallest = hitIndex
						smallestCount = 1
					elif hitIndex == smallest:
						smallestCount = smallestCount + 1
			if smallest is None:
				break
			presentInAllTexts = (smallestCount == 2)

			# Consume every hit with that index. Only anchors present in both
			# texts become refs.
			for t in range(0, 2):
				while current[t] < len(hits[t]) and hits[t][current[t]].getIndex() == smallest:
					hit = hits[t][current[t]]
					if presentInAllTexts:
						word = hit.getWord()
						wordLength = countWords(word)
						if wordLength > 1:
							weight = getAnchorWordPhraseMatchWeight()
						else:
							weight = getAnchorWordMatchWeight()
						anchorWordClusters.addRef(Ref(smallest, weight, t, hit.getElementNumber(), hit.getPos(), wordLength, word))
					current[t] = current[t] + 1
		return anchorWordClusters

	def _findWordClusters(self):
		"""Return (properNameClusters, diceClusters, numberClusters) for word pairs across texts."""
		properNameClusters = Clusters()
		diceClusters = Clusters()
		numberClusters = Clusters()
		minWordLength = getDiceMinWordLength()
		minDiceScore = getDiceMinCountingScore()

		# Compare each text only with the later one; comparing a text with
		# itself matched every word against itself.
		for t in range(0, 2):
			for tt in range(t + 1, 2):
				for info1 in self.info[t]:
					words1 = info1.getWords()
					elementNumber1 = info1.getElementNumber()
					for x, word1 in enumerate(words1):
						nextWord1 = words1[x + 1] if x + 1 < len(words1) else ''
						for info2 in self.info[tt]:
							words2 = info2.getWords()
							elementNumber2 = info2.getElementNumber()
							for y, word2 in enumerate(words2):
								nextWord2 = words2[y + 1] if y + 1 < len(words2) else ''

								# Match type -1: both words capitalised (proper names).
								if word1[:1].isupper() and word2[:1].isupper():
									properNameClusters.add(-1, getProperNameMatchWeight(), t, tt, elementNumber1, elementNumber2, x, y, 1, 1, word1, word2)

								# Match type -2: dice similarity of words or two-word phrases.
								longEnough1 = len(word1) >= minWordLength
								longEnough2 = len(word2) >= minWordLength
								if longEnough1 and longEnough2 and diceMatch(word1, word2) >= minDiceScore:
									diceClusters.add(-2, getDiceMatchWeight(), t, tt, elementNumber1, elementNumber2, x, y, 1, 1, word1, word2)
								# NOTE(review): a phrase is assumed to match when both of its
								# words reach the dice threshold against the single word — confirm.
								if nextWord1 != '' and longEnough1 and longEnough2 and len(nextWord1) >= minWordLength:
									if min(diceMatches(word1, nextWord1, word2, '2-1')) >= minDiceScore:
										diceClusters.add(-2, getDicePhraseMatchWeight(), t, tt, elementNumber1, elementNumber2, x, y, 2, 1, word1 + ' ' + nextWord1, word2)
								if nextWord2 != '' and longEnough1 and longEnough2 and len(nextWord2) >= minWordLength:
									if min(diceMatches(word1, word2, nextWord2, '1-2')) >= minDiceScore:
										diceClusters.add(-2, getDicePhraseMatchWeight(), t, tt, elementNumber1, elementNumber2, x, y, 1, 2, word1, word2 + ' ' + nextWord2)

								# Match type -3: equal numbers.
								try:
									sameNumber = float(word1) == float(word2)
								except ValueError:
									sameNumber = False
								if sameNumber:
									numberClusters.add(-3, getNumberMatchWeight(), t, tt, elementNumber1, elementNumber2, x, y, 1, 1, word1, word2)
		return properNameClusters, diceClusters, numberClusters

	def _findScoringCharacterClusters(self):
		"""Return Clusters of equal scoring characters ('?', '%', '!') across texts."""
		scoringCharacterClusters = Clusters()
		for t in range(0, 2):
			for tt in range(t + 1, 2):
				for info1 in self.info[t]:
					for x, char1 in enumerate(info1.getScoringCharacters()):
						for info2 in self.info[tt]:
							for y, char2 in enumerate(info2.getScoringCharacters()):
								if char1 == char2:
									# Match type -11: scoring character.
									scoringCharacterClusters.add(-11, getScoringCharacterMatchWeight(), t, tt, info1.getElementNumber(), info2.getElementNumber(), x, y, 1, 1, char1, char2)
		return scoringCharacterClusters

class Clusters:
	"""Container for Cluster objects; currently only holds the list.

	NOTE(review): callers in this module invoke add(), addRef() and
	getScore() on Clusters instances, but none of them are defined here yet.
	"""

	def __init__(self):
		self.clusters = list()

	
class Cluster:
	"""A set of Refs that are considered to match each other."""

	def __init__(self):
		self.refs = []

	def getRefs(self):
		"""Return the list of refs (not a copy)."""
		return self.refs

	def addRef(self, otherRef):
		"""Add otherRef unless a ref that exactly matches it is already present."""
		for ref in self.refs:
			if ref.exactlyMatches(otherRef):
				return
		self.refs.append(otherRef)

	def matchesRef(self, otherRef):
		"""Return True when any ref in the cluster matches otherRef."""
		for ref in self.refs:
			if ref.matches(otherRef):
				return True
		return False

	def matchesCluster(self, otherCluster):
		"""Return True when any ref of otherCluster matches this cluster."""
		for otherRef in otherCluster.getRefs():
			if self.matchesRef(otherRef):
				return True
		return False

	def addCluster(self, otherCluster):
		"""Add every ref of otherCluster (duplicates are skipped)."""
		for otherRef in otherCluster.getRefs():
			self.addRef(otherRef)

	def getScore(self, largeClusterScorePercentage):
		"""Return the cluster score.

		For each text, count the distinct positions referenced in it. The
		score is the largest ref weight times (smallest count - 1) scaled by
		largeClusterScorePercentage / 100.
		"""
		clusterWeight = 0.0
		counts = []
		for t in range(0, 2):
			positions = set()
			for ref in self.refs:
				if ref.isInText(t):
					positions.add(ref.getPosition())
					clusterWeight = max(clusterWeight, ref.getWeight())
			# Count once per text, after all refs have been seen.
			counts.append(len(positions))
		low = min(counts)
		return clusterWeight * ((low - 1) * largeClusterScorePercentage / 100.0)

class Ref:
	"""Reference to a matched word or phrase in one text.

	matchType: >= 0 for anchor word entries (the entry index); negative
		values are used for other match kinds in this module.
	weight: weight of the match.
	index: the text (0 or 1) the reference points into.
	elementNumber, position, length: element, word position and word count.
	word: the matched text.
	"""

	def __init__(self, matchType, weight, index, elementNumber, position, length, word):
		self.matchType = matchType
		self.weight = weight
		self.index = index
		self.elementNumber = elementNumber
		self.position = position
		self.length = length
		self.word = word

	def matches(self, otherRef):
		"""Return True for overlapping refs in the same element, or equal anchor match types."""
		if self.index == otherRef.index and self.elementNumber == otherRef.elementNumber and overlaps(self.position, self.length, otherRef.position, otherRef.length):
			return True
		if otherRef.matchType >= 0 and self.matchType == otherRef.matchType:
			return True
		return False

	def exactlyMatches(self, otherRef):
		"""Return True when both refs describe the same match at the same place.

		The full comparison had been commented out (it contained a `=` typo),
		leaving only matchType. That made every ref of the same type a
		duplicate, so clusters could never hold more than one ref per type.
		"""
		return (self.matchType == otherRef.matchType
			and self.index == otherRef.index
			and self.elementNumber == otherRef.elementNumber
			and self.position == otherRef.position
			and self.length == otherRef.length)

	def isInText(self, index):
		"""Return True when this ref points into text index."""
		return self.index == index

	def getPosition(self):
		"""Return the word position."""
		return self.position

	def getWeight(self):
		"""Return the match weight."""
		return self.weight

class CompareMatrix:
	"""Best path scores, keyed by alignment position."""

	def __init__(self):
		self.purge()

	def purge(self):
		"""Forget all cells and best path scores."""
		self.cells = {}
		self.bestPathScores = {}

	def _makeKey(self, position):
		"""Return the dictionary key for a position, e.g. [3, 4] -> '3,4'."""
		return ','.join([str(p) for p in position])

	def getScore(self, position):
		"""Return the best path score stored at position, as a float.

		Returns -1.0 when any coordinate is negative. Also returns -1.0 when
		no score has been stored for the position, after printing a message.
		"""
		for p in position:
			if p < 0:
				return -1.0

		key = self._makeKey(position)
		bestPathScore = self.bestPathScores.get(key)
		if bestPathScore is None:
			print("Program error? Cell doesn't exist. Position = " + key)
			return -1.0
		return bestPathScore.getScore()

	def setScore(self, position, score):
		"""Store score as the best path score at position."""
		self.bestPathScores[self._makeKey(position)] = BestPathScore(score)

	def resetBestPathScores(self):
		"""Reset every stored best path score to -1.0."""
		for key in list(self.bestPathScores.keys()):
			self.bestPathScores[key] = BestPathScore(-1.0)

class PathStep:
	"""A move in the alignment search: increment[t] elements are taken from text t."""

	def __init__(self, increment):
		self.increment = increment

	def is11(self):
		"""Return True when every increment is 1 (vacuously True when empty)."""
		return all(step == 1 for step in self.increment)

class Path:
	"""A sequence of PathSteps taken from a starting position."""

	def __init__(self, initialPosition):
		self.steps = []
		# Copy the position: extend() updates it in place, and callers pass
		# other paths' position lists, which must not change with this path.
		self.position = list(initialPosition)

	def getSteps(self):
		"""Return the list of steps taken."""
		return self.steps

	def getPosition(self):
		"""Return the current position (one coordinate per text)."""
		return self.position

	def setSteps(self, steps):
		"""Replace the list of steps (the position is not recomputed)."""
		self.steps = steps

	def extend(self, step):
		"""Append step and advance the first two coordinates by its increments."""
		self.steps.append(step)
		for t in range(0, 2):
			self.position[t] = self.position[t] + step.increment[t]

	def getLengthInSentences(self):
		"""Return the total number of elements consumed by all steps in both texts."""
		return sum([step.increment[t] for step in self.steps for t in range(0, 2)])

class QueueEntry:
	"""A candidate path in the alignment search, with its accumulated score.

	NOTE(review): makeLongerPath() and tryStep() use a module-level name
	`compare` that is not defined in the visible source. They expect it to
	provide getScore(position, score) and getCellValues(position, step);
	confirm what object this should be.
	"""

	def __init__(self, position, score):
		self.path = Path(position)
		self.score = score
		self.removed = False
		self.ended = False

	def getPath(self):
		"""Return the path, or None when makeLongerPath() rejected it."""
		return self.path

	def setPath(self, path):
		self.path = path

	def getScore(self):
		"""Return the accumulated score of the path."""
		return self.score

	def getRemoved(self):
		return self.removed

	def setRemoved(self):
		self.removed = True

	def getEnd(self):
		return self.ended

	def setEnd(self):
		self.ended = True

	def makeLongerPath(self, newStep):
		"""Return a new entry whose path is this path followed by newStep.

		The returned entry's path is set to None when its score does not beat
		the score compare reports for the new position.
		"""
		# Copy the position so extending the new path leaves this one untouched.
		retQueueEntry = QueueEntry(list(self.getPath().getPosition()), self.score)
		# Carry over the steps taken so far; the new path is this path plus newStep.
		retQueueEntry.path.setSteps(list(self.getPath().getSteps()))
		retQueueEntry.score = self.tryStep(newStep)
		retQueueEntry.path.extend(newStep)

		if retQueueEntry.score > compare.getScore(retQueueEntry.getPath().getPosition(), retQueueEntry.getScore()):
			return retQueueEntry
		retQueueEntry.path = None
		return retQueueEntry

	def tryStep(self, newStep):
		"""Return this entry's score plus the score of taking newStep from its position."""
		return self.score + self.getStepScore(compare, self.path.position, newStep)

	def getStepScore(self, compare, position, newStep):
		"""Return the score of the elements newStep covers from position."""
		compareCells = compare.getCellValues(position, newStep)
		return compareCells.elementInfoToBeCompared.getScore()

class QueueList:
	"""Queue of QueueEntry objects explored by the alignment search.

	Entries are removed in two phases. remove() only marks entries, so lists
	can be marked while they are being iterated. removeForReal() then drops
	the marked entries.
	"""

	def __init__(self, position = None):
		# None instead of a shared mutable [] default.
		self.entries = []
		if position:
			self.entries.append(QueueEntry(position, 0))

	def getEntries(self):
		"""Return the list of entries (not a copy)."""
		return self.entries

	def empty(self):
		"""Return True when the queue holds no entries."""
		return len(self.entries) == 0

	def append(self, queueEntry):
		"""Add queueEntry at the end of the queue."""
		self.entries.append(queueEntry)

	def add(self, queueEntry):
		"""Alias of append()."""
		self.append(queueEntry)

	def contains(self, queueEntry):
		"""Return True when an entry has a path equal to queueEntry's path.

		NOTE(review): relies on Path.equals(), which is not defined in the
		visible source.
		"""
		for entry in self.entries:
			if entry.getPath().equals(queueEntry.getPath()):
				return True
		return False

	def remove(self, position):
		"""Mark every entry whose path ends at position as removed."""
		target = list(position)
		for entry in self.entries:
			path = entry.getPath()
			if path is not None and list(path.getPosition()) == target:
				entry.setRemoved()

	def removeForReal(self):
		"""Drop all entries previously marked by remove()."""
		self.entries = [entry for entry in self.entries if not entry.getRemoved()]
			
class CompareCells:
	"""The elements a step covers from a position, collected for scoring.

	NOTE(review): reads element info from a module-level `compare` object
	(compare.elementsInfo[t].getElementInfo(x, t)), which is not defined in
	the visible source.
	"""

	def __init__(self, position, step):
		self.elementInfoToBeCompared = ElementInfoToBeCompared()

		for t in range(0, 2):
			# Elements position[t] + 1 .. position[t] + step.increment[t] of text t.
			for x in range(position[t] + 1, position[t] + step.increment[t] + 1):
				info = compare.elementsInfo[t].getElementInfo(x, t)
				self.elementInfoToBeCompared.add(info, t)

	def getScore(self):
		"""Return the score of the collected elements."""
		return self.elementInfoToBeCompared.getScore()

class Compare:
	"""Holds per-text element info, the compare matrix and the allowed path steps.

	NOTE(review): QueueEntry.getStepScore() calls compare.getCellValues(),
	which is not defined here.
	"""

	def __init__(self):
		# CompareCells reads `compare.elementsInfo`; the attribute used to be
		# misspelled "elemementsInfo", which is kept as an alias to the same list.
		self.elementsInfo = []
		self.elemementsInfo = self.elementsInfo
		self.matrix = CompareMatrix()
		self.stepList = []
		self.createStepList()

	def createStepList(self):
		"""Fill self.stepList with the allowed PathSteps.

		Every increment pair (i0, i1) with 0 <= i <= 2 is considered in
		lexicographic order. A pair is kept when some text advances, the two
		increments differ by at most one, and together they cover at most
		three elements. This yields (0,1), (1,0), (1,1), (1,2), (2,1).
		"""
		import itertools
		maxIncrement = 2
		maxTotal = 3
		for combination in itertools.product(range(maxIncrement + 1), repeat=2):
			increment = list(combination)
			if max(increment) > 0 and max(increment) - min(increment) <= 1 and sum(increment) <= maxTotal:
				self.stepList.append(PathStep(increment))

	def getStepList(self):
		"""Return the allowed PathSteps."""
		return self.stepList

	def str_base(self, num, base, numerals = '0123456789abcdefghijklmnopqrstuvwxyz'):
		"""Return num written in the given base (2..len(numerals)), with a '-' for negatives."""
		if base < 2 or base > len(numerals):
			raise ValueError("str_base: base must be between 2 and %i" % len(numerals))

		if num == 0:
			return '0'

		if num < 0:
			sign = '-'
			num = -num
		else:
			sign = ''

		result = ''
		while num:
			result = numerals[num % (base)] + result
			num //= base

		return sign + result

	def resetBestPathScores(self):
		"""Reset all best path scores in the compare matrix."""
		self.matrix.resetBestPathScores()

def diceMatch(word1, word2):
	"""Return the Dice coefficient of the character bigram sets of two words.

	dice = 2 * |bigrams(word1) & bigrams(word2)| / (|bigrams(word1)| + |bigrams(word2)|)

	Returns 0.0 when neither word has a bigram (both shorter than two
	characters), instead of dividing by zero.
	"""
	a_bigrams = set([word1[i:i + 2] for i in range(len(word1) - 1)])
	b_bigrams = set([word2[i:i + 2] for i in range(len(word2) - 1)])

	total = len(a_bigrams) + len(b_bigrams)
	if total == 0:
		return 0.0

	overlap = len(a_bigrams & b_bigrams)
	return overlap * 2.0 / total

def diceMatches(word1, word2, word3 = "", wordType = ""):
	"""Return a list of Dice scores between words or a word and a two-word phrase.

	Without word3 and wordType, returns [diceMatch(word1, word2)].
	With wordType '2-1', word1 and word2 form the phrase and word3 is the word.
	With wordType '1-2', word1 is the word and word2 and word3 form the phrase.
	In both phrase cases the result is
	[diceMatch(word, phraseWord1), diceMatch(word, phraseWord2)].

	Raises:
		ValueError: a phrase comparison was requested with an unknown wordType.
	"""
	if wordType == "" and word3 == "":
		return [diceMatch(word1, word2)]

	if wordType == '2-1':
		word, phraseWord1, phraseWord2 = word3, word1, word2
	elif wordType == '1-2':
		word, phraseWord1, phraseWord2 = word1, word2, word3
	else:
		raise ValueError("diceMatches: wordType must be '2-1' or '1-2' for a phrase comparison, got %r" % (wordType,))

	return [diceMatch(word, phraseWord1), diceMatch(word, phraseWord2)]

def anchorMatch(anchorWord, word):
	"""Return True when word starts with anchorWord, ignoring case.

	One trailing '*' on anchorWord is removed before matching.
	NOTE(review): matching is a prefix match whether or not the '*' is
	present, so the '*' currently makes no difference — confirm intent.
	"""
	pattern = anchorWord.lower()
	if pattern.endswith('*'):
		pattern = pattern[:-1]
	return word.lower().startswith(pattern)

def badLengthCorrelation(length1, length2, elementCount1, elementCount2, ratio):
	"""Return True when two differently-sized element groups have badly matching lengths.

	c = 2 * |ratio*length1 - length2| / (ratio*length1 + length2). The
	correlation is bad when both element counts are positive, the counts
	differ, and c exceeds 0.5. When both scaled lengths sum to zero, c is
	taken as 0.0 instead of dividing by zero.
	"""
	killLimit = 0.5
	total = ratio * length1 + length2
	if total == 0:
		c = 0.0
	else:
		c = (2.0 * abs(ratio * length1 - length2)) / total
	return elementCount1 > 0 and elementCount2 > 0 and elementCount1 != elementCount2 and c > killLimit

def overlaps(position, length, otherPosition, otherLength):
	"""Return True when [position, position+length) and [otherPosition, otherPosition+otherLength) share an index."""
	lastIndex = position + length - 1
	otherLastIndex = otherPosition + otherLength - 1
	return not (position > otherLastIndex or otherPosition > lastIndex)

def adjustForLengthCorrelation(score, length1, length2, elementCount1, elementCount2, ratio):
	"""Return score adjusted by how well the two lengths correlate.

	c = 2 * |ratio*length1 - length2| / (ratio*length1 + length2), or 0.0
	when the denominator is zero.

	When both element counts are positive and differ:
	    c < 0.2 -> score + 2, c < 0.4 -> score + 1, c > 0.5 -> -99999.0,
	    otherwise score.
	Otherwise:
	    c < 0.2 -> score + 2, c < 0.4 -> score + 1, c > 1.0 -> score / 3,
	    otherwise score.

	The second branch used `lowerLimit + 2`, which made every later
	condition unreachable (c never exceeds 2), and a misspelled
	`lowerlimit`. Both now mirror the first branch.
	"""
	lowerLimit = 0.4
	upperLimit = 1.0
	killLimit = 0.5

	total = ratio * length1 + length2
	if total == 0:
		c = 0.0
	else:
		c = (2.0 * abs(ratio * length1 - length2)) / total

	if elementCount1 > 0 and elementCount2 > 0 and elementCount1 != elementCount2:
		if c < lowerLimit * 0.5:
			return score + 2
		if c < lowerLimit:
			return score + 1
		if c > killLimit:
			return -99999.0
		return score

	if c < lowerLimit * 0.5:
		return score + 2
	if c < lowerLimit:
		return score + 1
	if c > upperLimit:
		return score / 3
	return score

def main():
	"""Build a SentenceHandler and an AlignerHandler for the hard-coded sample files.

	NOTE(review): no alignment is run here; AlignerHandler.align() is never called.
	"""
	sourceFile = '1999_1.doc.sent.xml'
	sentenceHandler = SentenceHandler(sourceFile)
	alignerHandler = AlignerHandler('anchor-nor-sme.txt', [sourceFile, '1999_1s.doc.sent.xml'])

# Run the sample driver only when executed as a script, not on import.
if __name__ == '__main__':
	main()