123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
from typing import Tuple
import random
import math
import os
import sys
import operator
def lower_bound(a, x, lo=0, hi=None):
"""Return the index where to insert item x in list a, assuming a is sorted.
The return value i is such that all e in a[:i] have e < x, and all e in
a[i:] have e >= x. So if x already appears in the list, a.insert(x) will
insert just before the leftmost x already there.
Optional args lo (default 0) and hi (default len(a)) bound the
slice of a to be searched.
"""
L=len(x)
if lo < 0:
raise ValueError('lo must be non-negative')
if hi is None:
hi = len(a)
while lo < hi:
mid = (lo+hi)//2
if a[mid][:L] < x: lo = mid+1
else: hi = mid
return lo
def upper_bound(a, x, lo=0, hi=None):
"""Return the index where to insert item x in list a, assuming a is sorted.
The return value i is such that all e in a[:i] have e <= x, and all e in
a[i:] have e > x. So if x already appears in the list, a.insert(x) will
insert just after the rightmost x already there.
Optional args lo (default 0) and hi (default len(a)) bound the
slice of a to be searched.
"""
L=len(x)
if lo < 0:
raise ValueError('lo must be non-negative')
if hi is None:
hi = len(a)
while lo < hi:
mid = (lo+hi)//2
if x < a[mid][:L]: hi = mid
else: lo = mid+1
return lo
class Tree(object):
"""
Our tree implementation.
"""
def __init__(self, scanWindowSz):
self.scanWindowSz = scanWindowSz
self.lenRecords = 0 # number of records
self.records = [] # all hypo records made from entailData for fast search of scanWindow or verification
self.candidates = []
def build(self, text:list):
print("Building Tree...")
def make_record(word_list):
res = ""
for word in word_list:
res += word.lower() + " "
return res.strip()
length = len(text)
self.records = []
for i in range(0, length):
windowSize= min(self.scanWindowSz, length - i)
shouldInsert = True
for word in text[i:i+windowSize]:
if not(word.isalpha() or (word[:-1].isalpha() and (word[-1]==',' or word[-1]=='.'))):
shouldInsert = False
break
if shouldInsert == True:
self.records.append(make_record(text[i:i+windowSize]))
self.records.sort()
self.lenRecords = len(self.records)
# print(self.records)
# print("Building Tree done")
def find_interval(self, searchStr, lo=0, hi=None):
if (hi == None):
hi = self.lenRecords
fro = lower_bound(self.records, searchStr, lo, hi)
to = upper_bound(self.records, searchStr, lo, hi)
return fro, to
def mergeCandidates(self, candidates, idx):
# print("before", self.candidates)
self.candidates = []
if (len(candidates) == 0):
return
def isSameCluster(a, b, verifyLen):
a = a.split()
b = b.split()
if len(a) < verifyLen + 1 or len(b) < verifyLen + 1:
return False
for i in range (0, verifyLen+1):
if a[i] != b[i]:
return False
return True
candidates.sort(key = operator.itemgetter(0, 1))
self.candidates.append(candidates[0])
for candy in candidates:
if candy[0] > self.candidates[-1][1]:
self.candidates.append(candy)
else:
if isSameCluster(self.records[self.candidates[-1][0]], self.records[candy[0]], idx):
self.candidates[-1] = (self.candidates[-1][0], max(self.candidates[-1][1], candy[1]))
else:
self.candidates.append(candy)
def update_words(self, idx:int, word_list:list):
# print(idx, word_list)
if idx == 0:
candidates = []
for searchStr in word_list:
fro, to = self.find_interval(searchStr + " ")
if fro < to:
candidates.append((fro, to))
else:
candidates = []
for (fro, to) in self.candidates:
for word in word_list:
searchStr = ' '.join(self.records[fro].split()[:idx]) + " " + word + " "
froNew, toNew = self.find_interval(searchStr, fro, to)
if froNew < toNew:
candidates.append((froNew, toNew))
# print(candidates)
self.mergeCandidates(candidates, idx)
# print(self.candidates)
def dist(self, a, b):
L = min(len(a), len(b))
nSame = 0
pattern = 0
for i in range(0, L):
if a[i] == b[i]:
nSame = nSame + 1
pattern = pattern + (1<<i)
return (nSame, pattern)
def get_words(self, idx:int):
"""
return word list of vanila serarch ex. [[down the street], [up the street], [down the stair], ...]
here predicted next word is [street, stair, ...]
"""
resultSet = []
for (fro, to) in self.candidates:
for sequence in self.records[fro:to]:
word_list = sequence.split()
if (len(word_list) > idx):
resultSet.append(' '.join(word_list[:idx+1]))
result = []
for hp in resultSet:
words = hp.split()
if words[-1][-1]==',' or words[-1][-1]=='.':
words[-1] = words[-1][:-1]
if words[-1] == "":
continue
result.append(words)
return result
def is_contain(self, searchStr:str) -> bool:
"""
check if a tree contain searchStr
"""
fro, to = self.find_interval(searchStr)
if fro < to:
return True
else:
return False
def calc_frequency(self, searchStr:str) -> int:
"""
calculate the number of occurances of search str in entailData
"""
fro, to = self.find_interval(searchStr)
if fro < to:
return to-fro
else:
return 0
def calc_most_frequent_next_word(self, searchStr:str):
"""
calcuate the most frequent next word of a SearchStr ex. "down the " -> "path, 15"
"""
L = len(searchStr.strip().split())
fro, to = self.find_interval(searchStr)
cur_word = ""
cur_cnt = 0
max_cnt = 0
max_word = ""
for sequence in self.records[fro:to]:
word_list = sequence.split()
if (len(word_list) <= L):
continue
word = word_list[L]
if word == cur_word:
cur_cnt += 1
else:
if cur_cnt > max_cnt:
max_cnt = cur_cnt
max_word = cur_word
cur_word = word
cur_cnt = 1
if cur_cnt > max_cnt:
max_cnt = cur_cnt
max_word = cur_word
return (max_word, max_cnt)
if __name__ == "__main__":
# tree = Tree(3)
# tree.build('hello wor*ld how are you how are How do yo-u do ')
# print(tree.is_contain("hello wOr*ld how ")) # True
# print(tree.is_contain("how do yo-u do ")) # True
# print(tree.is_contain("wor*LD how a")) # True
# print(tree.find_nextword("how do yo-u ")) # do
# print(tree.find_nextword("how do your ")) # None
# print(tree.find_nextword("how are ")) # you or how
# print(tree.find_nextword("how are ")) # you or how
import smart_open
import pickle
entailData = sys.argv[1] # entailData to make tree
tree = Tree(4, 3, "SplitedWindows")
words = []
with smart_open.smart_open(entailData, 'r') as entail:
for line in entail:
words.extend(line.strip().split())
tree.build(words)
print(tree.records)
print("Saving Tree to \"Tree\"")
with open("Tree", 'wb') as outfile:
pickle.dump(tree, outfile)
print("Saving Tree done")
tree.update_words(0, ["today", "Are", "Earth", "hit", "a"])
tree.update_words(1, ["meteor,", "Earth,"])
print(tree.get_words(2))
print(tree.is_contain("meteor, "))
print(tree.is_contain("a "))
print(tree.is_contain("a, "))
print(tree.calc_frequency("meteor, "))
print(tree.calc_frequency("a "))
print(tree.calc_frequency("a, "))
print("yes yes yes yes", tree.calc_frequency("yes yes yes yes"))