123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
from __future__ import division
import sys
import gensim
import smart_open
import numpy
import random
from Tree import Tree
import pickle
import heapq
import math
import keyboard
DEBUG = True
class TextGenerator(object):
def __init__(self, window=4, verify=2, ofTranslates=2, ofMatches=1, ofOriginalsNearNextWord=0, CommasPerSentence=2, WordsPerCommaOrPeriod=(4,7),
commonWordCoefficient=(1, 100, 0.0011), contextualCoefficient=(1, 0.000001, 0.0002)):
self.Tree = None
self.w2vModel = None
self.window = window
self.verify = verify
self.ofMatches = ofMatches
self.ofTranslates = ofTranslates
self.ofOriginalsNearNextWord = ofOriginalsNearNextWord
self.CommasPerSentence = CommasPerSentence
self.WordsPerCommaOrPeriod = WordsPerCommaOrPeriod
self.commonWordCoefficient = commonWordCoefficient
self.contextualCoefficient = contextualCoefficient
self.TreeSaveName = "Tree"
self.treeScanwindow = max(window + 1, verify)
self.translatesBucket = {} # translatesbucket to contain all translates of words up to now.
self.pickfromTranslatesBucket = 1 # how many words we should pick from a bucket for softmax.
self.addToTranslatesBucket = 30 # how many gloves of a word we should insert to translatesBucket.
self.nofFrequencyCheckCandidates = 8 # how many last candidates we should check for frequency.
self.lenFrequencyCheckWindow = 2 # length of frequency check window.
self.fakeCommaChoiceBucket = []
self.seed_word_list = []
self.wordTransScore = {}
def load_Model(self, model_file):
print("Loading w2v model from \"{}\"...".format(model_file))
self.w2vModel = gensim.models.KeyedVectors.load_word2vec_format(model_file, binary=False) # GloVe Model
print("Loading w2v model done")
def save_Tree(self):
print("Saving Tree to \"{}\"".format(self.TreeSaveName))
with open(self.TreeSaveName, 'wb') as outfile:
pickle.dump(self.Tree, outfile)
print("Saving Tree done")
def load_Tree(self):
print("Loading Tree from \"{}\"...".format(self.TreeSaveName))
with open(self.TreeSaveName, 'rb') as infile:
self.Tree = pickle.load(infile)
print("Loading Tree done")
def init_Tree(self, entailData: str):
"""
Function to construct tree from entailData
"""
self.Tree = Tree(self.treeScanwindow)
words = []
with smart_open.smart_open(entailData, 'r') as entail:
for line in entail:
words.extend(line.strip().split())
self.Tree.build(words)
self.save_Tree()
# if DEBUG:
# print(self.Tree.records)
def load_CommonWord(self, wordFile: str):
self.wordTransScore = {}
cnt = 1
with smart_open.smart_open(wordFile, encoding='ansi') as f:
for line in f:
splited_line = line.strip().split()
word = splited_line[0].lower()
if len(splited_line) == 1 and word[0] != '#':
if not word in self.wordTransScore:
self.wordTransScore[word] = cnt
cnt = cnt + 1
"""
we are using following function for exponential curve according to word list.
y = k * exp(-a * x) + b
this curve will pass 2 points (1, M1) & (L, M2)
a will be the incline prameter.
"""
M1 = self.commonWordCoefficient[1]
M2 = self.commonWordCoefficient[0]
a = self.commonWordCoefficient[2]
L = len(self.wordTransScore)
k = (M1 - M2) / (math.exp(-a) - math.exp(-a*L))
b = (M2 * math.exp(-a) - M1 * math.exp(-a*L)) / (math.exp(-a) - math.exp(-a*L))
for word in self.wordTransScore:
self.wordTransScore[word] = M1 + M2 -(k * math.exp(-a*(self.wordTransScore[word])) + b)
# if DEBUG:
print(word, self.wordTransScore[word])
def calc_Transscore(self, word:str):
if word in self.wordTransScore:
return self.wordTransScore[word]
else:
return self.commonWordCoefficient[1]
def makeGlove(self, word, ofTranslates):
if word[-1] == "," or word[-1] == ".":
suffix = word[-1]
word = word[:-1]
else:
suffix = ''
word_list = [word + suffix]
try:
for tup in self.w2vModel.similar_by_word(word, ofTranslates):
word_list.append(tup[0] + suffix)
except:
pass
return word_list
def insertToTranslatesBucket(self, insert_word, seed_word):
insert_score = self.calc_Transscore(seed_word)
if insert_word in self.translatesBucket:
self.translatesBucket[insert_word] += insert_score
else:
self.translatesBucket[insert_word] = insert_score
def pickNextword(self, hypolist): #
predictedHypoList = []
IndexesofPredictedHypo = {}
for idx, hp in enumerate(hypolist):
word, cnt = hp[0][-1], hp[1]
if not (word in predictedHypoList):
predictedHypoList.append(word)
IndexesofPredictedHypo[word] = [idx]
else:
IndexesofPredictedHypo[word].append(idx)
def dist(hp, orig):
L = len(hp) - 1
ret = 0
for i in range(0, L):
if hp[-2-i] == orig[-1-i]:
ret += 1
return ret
if DEBUG:
print("Scores of each word")
print("{:<20}: {:<10}: {:<10}, {:<10}, {:<10}, {:<10}".format("word", "score", "TransScore", "LenScore", "Contextual", "Originals"))
relatedTranslatesBucket = []
M1 = self.contextualCoefficient[1]
M2 = self.contextualCoefficient[0]
a = self.contextualCoefficient[2]
L = self.window + 1
k = (M2 - M1) / (math.exp(L*a) - math.exp(a))
b = (M1 * math.exp(L*a) - M2 * math.exp(a)) / (math.exp(L*a) - math.exp(a))
for word in predictedHypoList:
maxLen = 0
contextualFrequency = 0
originals = 0
for idx in IndexesofPredictedHypo[word]:
curLen = len(hypolist[idx][0])
maxLen = max(maxLen, curLen)
exponentialCoeff = (k * math.exp(a*curLen) + b) * hypolist[idx][1]
contextualFrequency += exponentialCoeff
originals += exponentialCoeff * dist(hypolist[idx][0], self.seed_word_list)
if word in self.translatesBucket:
transScore = self.translatesBucket[word]
else:
transScore = 0
score = transScore + maxLen + contextualFrequency + originals
if DEBUG:
print("{:<20}: {:10.4f}: {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".format(word, score, transScore, maxLen, contextualFrequency, originals))
relatedTranslatesBucket.append((score, word))
ntopCandidates = min(self.pickfromTranslatesBucket, len(predictedHypoList))
topCandidates = heapq.nlargest(ntopCandidates, relatedTranslatesBucket)
predictedHypoList = []
softMaxVals = []
for (score, word) in topCandidates:
softMaxVals.append(score)
predictedHypoList.append(word)
return random.choices(predictedHypoList, weights=softMaxVals)[0]
def predict(self, input_seed: list) -> str:
"""
Function to predict next word of a sentence.
"""
self.seed_word_list = []
for word in input_seed:
self.seed_word_list.append(word.lower())
verified_hypo_list = []
total_verified = 0
windowLen = min(len(input_seed), self.window)
while windowLen > 0:
ofOriginalsNearNextWord = min(windowLen, self.ofOriginalsNearNextWord)
verifyLen = min(self.verify, windowLen + 1)
## get window list from seed word list
windowlist = self.seed_word_list[-windowLen:]
## make listoflists
listoflists = []
idx = -1
for i, word in enumerate(windowlist):
if word[-1] == ',' or word[-1] == '.':
idx = i
for i, word in enumerate(windowlist):
if i == idx:
listoflists.append(self.makeGlove(word, self.ofTranslates))
else:
listoflists.append(self.makeGlove(self.stem_Word(word), self.ofTranslates))
print("Lists of %d word window =" % windowLen, listoflists)
## do vanilla search
idx = 0
for word_list in listoflists:
self.Tree.update_words(idx, word_list)
idx = idx + 1
## get all hypothesis of next word
hypo_list = self.Tree.get_words(idx)
if DEBUG:
print("hypo_list = ", hypo_list)
def is_ofOriginalsNearNextWord(a, b, L):
"""
check if last L words of a[:-1] are same as b[-L:]
"""
for i in range(L):
if a[-i-2] != b[-i-1]:
return False
return True
## get ofOriginalNearNextWord list
ofOs_hypo_list = []
for hp in hypo_list:
if is_ofOriginalsNearNextWord(hp, windowlist, ofOriginalsNearNextWord) == True:
ofOs_hypo_list.append(hp)
if DEBUG:
print("ofOs_hypo_list = ", ofOs_hypo_list)
def add_to_verified_HypoList(hp):
kLen = 0 # sum of length that contain this hp in verified hypo list.
jdx = -1 # index fo current hp in verified hypo list.
lenVerified = len(verified_hypo_list)
for idx in range(0, lenVerified):
prev = verified_hypo_list[idx][0]
L = len(hp)
isContained = True
for i in range(L):
if hp[-i-1] != prev[-i-1]:
isContained = False
break
if isContained == True:
if len(prev) == L:
jdx = idx
else:
kLen += verified_hypo_list[idx][1]
ret = 1
if jdx == -1:
if kLen == 0:
verified_hypo_list.append((hp, 1))
else:
verified_hypo_list.append((hp, 1-kLen))
ret -= kLen
else:
verified_hypo_list[jdx] = (verified_hypo_list[jdx][0], verified_hypo_list[jdx][1]+1)
return ret
## do verification
prefix = ' '.join(windowlist[-verifyLen+1:])
for hp in ofOs_hypo_list:
search_str = prefix + " " + hp[-1] + " "
search_str_comma = prefix + " " + hp[-1] + ", "
search_str_stop = prefix + " " + hp[-1] + ". "
if self.Tree.is_contain(search_str) == True or self.Tree.is_contain(search_str_comma) == True or self.Tree.is_contain(search_str_stop) == True: # verified
total_verified += add_to_verified_HypoList(hp)
if DEBUG:
print("verified_hypo_list = ", verified_hypo_list)
if total_verified >= self.ofMatches:
break
else:
windowLen = windowLen - 1
if total_verified == 0:
print("Matchs =", [])
return None
def dist(hp, orig):
L = len(hp) - 1
LO = len(orig)
nSame = 0
pattern = 0
for i in range(0, L):
if hp[-2-i] == orig[-1-i]:
nSame = nSame + 1
pattern = pattern + (1<<(LO-i))
return (nSame, pattern)
windowlist = self.seed_word_list[-self.window:]
L = len(verified_hypo_list)
ofMatches = min(total_verified, self.ofMatches)
for i in range(0, L):
for j in range(i+1, L):
if dist(verified_hypo_list[i][0], windowlist) < dist(verified_hypo_list[j][0], windowlist):
verified_hypo_list[i], verified_hypo_list[j] = verified_hypo_list[j], verified_hypo_list[i]
curMatches = 0
for i in range(0, L):
curMatches += verified_hypo_list[i][1]
if curMatches >= ofMatches:
verified_hypo_list[i] = (verified_hypo_list[i][0], verified_hypo_list[i][1]-(curMatches-ofMatches))
verified_hypo_list = verified_hypo_list[:i+1]
break
matcheStr = []
for hp in verified_hypo_list:
matcheStr.append((' '.join(hp[0]), hp[1]))
print("Matchs =", matcheStr)
return self.pickNextword(verified_hypo_list)
def calcCommaLen(self, words):
L = len(words)
for i in range(L):
if words[-i-1][-1] == '.' or words[-i-1][-1] == ',':
return i
return L
def stem_Word(self, word):
res = word.lower()
if res[-1] == '.' or res[-1] == ',':
res = res[:-1]
return res
def make_FirstWord(self, word):
res = word[0].upper() + word[1:]
return res
def isLastCandidateMostFrequent(self, words):
"""
check if the last candidates is most frequent in entailData if we add comma/stop.
"""
if DEBUG:
print("Checking if last word is most frequent for hypo -> ", ' '.join(words))
L = len(words)
last_frequency = 0
most_frequency = 0
for i in range(self.nofFrequencyCheckCandidates):
if L-self.lenFrequencyCheckWindow-i<0:
break
candy = []
for j in range(self.lenFrequencyCheckWindow):
candy.append(self.stem_Word(words[L-self.lenFrequencyCheckWindow-i+j]))
comma_cnt = self.Tree.calc_frequency(' '.join(candy)+', ')
stop_cnt = self.Tree.calc_frequency(' '.join(candy)+'. ')
total_cnt = self.Tree.calc_frequency(' '.join(candy) + ' ') + comma_cnt + stop_cnt
# next_word, next_cnt = self.Tree.calc_most_frequent_next_word(' '.join(candy)+' ')
# if next_cnt == 0:
if total_cnt == 0:
freq=0
else:
# freq = (comma_cnt + stop_cnt)/next_cnt
freq = (comma_cnt + stop_cnt)/total_cnt
if DEBUG:
# print("\"{}\": {},\t\"{}\": {},\tMost freq next word: ({}, {}),\tFreq: {}".format(' '.join(candy)+",", comma_cnt, ' '.join(candy)+".", stop_cnt, next_word, next_cnt, freq))
print("\"{}\": {},\t\"{}\": {},\tTotal : {},\tFreq: {}".format(' '.join(candy)+",", comma_cnt, ' '.join(candy)+".", stop_cnt, total_cnt, freq))
if i == 0:
last_frequency = freq
else:
most_frequency = max(most_frequency, freq)
if last_frequency > most_frequency:
return True
else:
return False
def addNextWord(self, words, next_word):
if words[-1][-1] == '.':
words.append(self.make_FirstWord(next_word))
else:
words.append(next_word)
for glove_next_word in self.makeGlove(next_word, self.addToTranslatesBucket):
self.insertToTranslatesBucket(glove_next_word, next_word)
if self.isLastCandidateMostFrequent(words) == False: # we couldn't add any comma/stop
if DEBUG:
print("we can't add comma due to last word frequency rule.")
return words
commaLen = self.calcCommaLen(words)
bshouldAddComma = False
if commaLen >= self.WordsPerCommaOrPeriod[0]:
if commaLen >= self.WordsPerCommaOrPeriod[1]:
bshouldAddComma = True
else:
if random.choice([True, False]) == True:
bshouldAddComma = True
L = len(words)
if bshouldAddComma == True: # we should add real comma/stop
numCommas = 0 # calc number of commas in current sentence.
for i in range(L):
if words[-i-1][-1] == ',':
numCommas = numCommas + 1
if words[-i-1][-1] == '.':
break
"""
we have to add addCharacter(./,) after last word now.
"""
if DEBUG:
print("add real comma/stop")
if numCommas >= self.CommasPerSentence: # we have to add stop
words[-1] = words[-1] + '.'
else: # we have to add comma
words[-1] = words[-1] + ','
else:
if DEBUG:
print("we can't add comma due to WordsPerCommaOrPeriod rule.")
return words
def morph_seed(self, seed: str, length):
self.translatesBucket = {}
words = seed.strip().split()
for word in words:
for glove_word in self.makeGlove(self.stem_Word(word), self.addToTranslatesBucket):
self.insertToTranslatesBucket(glove_word, self.stem_Word(word))
for i in range(length):
print("Seed=" + " ".join(words))
word = self.predict(words)
if word == None:
print("can't predict next word and end morphing.")
return
if keyboard.is_pressed('q'):
return
words = self.addNextWord(words, word)
print("Seed=" + " ".join(words))
if __name__ == "__main__":
entailData = "entailData.txt" # entailData to make tree
gloveModel = "glove_model.txt" # glove model which is converted with glove2word2vector.py
commonWordFile = "wiki-100k.txt" # wiki common word file.
textgenerator = TextGenerator(window=7, verify=2, ofTranslates=140, ofMatches=150, ofOriginalsNearNextWord=2, CommasPerSentence=1, WordsPerCommaOrPeriod=(3,3),
commonWordCoefficient=(0.0001, 60, 0.001), contextualCoefficient=(1, 1, 2.3))
# window, verify, ofTranslates, ofMatches, ofOriginalsNearNextWord, CommasPerSentence, WordsPerCommaOrPeriod, commonWordCoefficient, contextualCoefficient
textgenerator.load_CommonWord(commonWordFile)
# textgenerator.init_Tree(entailData)
textgenerator.load_Tree()
textgenerator.load_Model(gloveModel)
while True: # infinite loop
seed = input("\nPlease input seed text here:\n")
textgenerator.morph_seed(seed, 15)