YATO model - jiesutd/YATO GitHub Wiki
YATO model
Various functions of the YATO model can be called to perform training, decoding, and visualization of the sequence labeling and sequence classification tasks.
YATO Function
from yato import YATO
yato = YATO(configuration file)
yato.get_config()
Get the Data class in yato model, Data class is the core class of yato model, which includes configuration parameters, labels.
def get_config(self): return self.data
yato.set_config_from_dset(dset)
If you want to call a configuration from another dset file, you can use this function.
Input:
dset:string, dset file path.
def set_config_from_dset(self, dset): self.data.load(dset)
yato.set_config_from_data(custom_data)
You are able to replace the Data class in yato, which is mainly used to pass Data class in the yato model.
Input:
custom_data:Data class
def set_config_from_data(self, custom_data): self.data = custom_data
yato.set_config_from_custom_configuration(custom_configuration)
You can replace the parameters in the Data class in yato, mainly used in parameter search and large batch experiments.
Input:
custom_configuration:dict, Replace the values in the Data class according to the key-value of the dictionary
def set_config_from_custom_configuration(self, custom_configuration): self.data.read_config(self.config, custom_configuration)
yato.train(log='test.log', metric='F')
Input:
log: string, The path of log file
mertric: string, Save the best model based on the metrics of interest 'F'=F1 score 'A'=Accuracydef train(self, log='test.log', metric='F'): status = self.data.status.lower() if status == 'train': print("MODEL: train") data_initialization(self.data) self.data.generate_instance('train') self.data.generate_instance('dev') self.data.generate_instance('test') self.data.build_pretrain_emb() self.data.summary() train(self.data, log, metric)
yato.decode(write_decode_file=True)
Input:
write_decode_file: boolean, Whether to write the results to a file
Output:
speed: decoding speed
accuracy: If the decoded file contains annotation results, accuracy means verifying the accuracy
precision: If the decoded file contains annotation results, precision means verifying the precision
recall: If the decoded file contains annotation results, recall means verifying the recall
predict_result: predicted result
nbest_predict_score: nbest scores of decoded prediction
label: Mapping between labels and indexesdef decode(self, write_decode_file=True): print("MODEL: decode") predict_lines = self.convert_file_to_predict_style() speed, acc, p, r, f, pred_results, pred_scores = self.predict(input_text=predict_lines, write_decode_file=write_decode_file) return {"speed": speed, "accuracy": acc, "precision": p, "recall": r, "predict_result": pred_results, "nbest_predict_score": pred_scores, 'label': self.data.label_alphabet}
yato.predict(input_text=None, predict_file=None, write_decode_file=True)
Input:
input_text: list, Direct input prediction array
predict_file: string, Files to be predicted
write_decode_file: boolean, Whether to write the results to a fileOutput:
speed: decoding speed
accuracy: If the decoded file contains annotation results, accuracy means verifying the accuracy
precision: If the decoded file contains annotation results, precision means verifying the precision
recall: If the decoded file contains annotation results, recall means verifying the recall
predict_result: predicted result
nbest_predict_score: nbest scores of decoded prediction
label: Mapping between labels and indexesdef predict(self, input_text=None, predict_file=None, write_decode_file=True): self.data.read_config(self.config) dset = self.data.dset_dir self.set_config_from_dset(dset) self.data.read_config(self.config) if predict_file is not None and input_text is None: input_text = open(predict_file, 'r', encoding="utf8").readlines() elif predict_file is not None and input_text is not None: print("Choose Predict Source") self.data.generate_instance('predict', input_text) print("nbest: {}".format(self.data.nbest)) speed, acc, p, r, f, pred_results, pred_scores = load_model_decode(self.data, 'predict') if write_decode_file and self.data.nbest > 0 and not self.data.sentence_classification: self.data.write_nbest_decoded_results(pred_results, pred_scores, 'predict') elif write_decode_file: self.data.write_decoded_results(pred_results, 'predict') return speed, acc, p, r, f, pred_results, pred_scores
yato.attention(input_text=None)
Input:
input_text: list, Arrays that need to be visualized by text
Output:
probs_ls: Label Probability
weights_ls: Weight of each worddef attention(self, input_text=None): self.data.read_config(self.config) dset = self.data.dset_dir self.set_config_from_dset(dset) self.data.read_config(self.config) print("MODEL: Attention Weight") self.data.generate_instance('predict', input_text) probs_ls, weights_ls = extract_attention_weight(self.data) return probs_ls, weights_ls
yato.set_seed(seed=42, hard = False)
Ensure the reproducibility of the model, hard mode will lead to some performance loss��
Input:
seed: int, random seed
def set_seed(seed=42, hard = False): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. np.random.seed(seed) # Numpy module. random.seed(seed) # Python random module. torch.backends.cudnn.deterministic = True if hard: torch.backends.cudnn.enabled = False torch.backends.cudnn.benchmark = False os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' os.environ['PYTHONHASHSEED'] = str(seed)
yato.decode_raw(raw_text_path, task, out_text_path='raw.out')
Split the original text segment into standard format data for decoding��
Input:
raw_text_path: string, the path of standard format data task: string, the task of the original text out_text_path: string, decode file path
def decode_raw(self, raw_text_path, task, out_text_path='raw.out'): """ :param raw_text_path:The path of raw text file :param task:choose the task :param out_text_path:The path of decode result file :return: """ raw_text = open(raw_text_path, 'r', encoding='utf-8').read() out_text = open(out_text_path, 'w', encoding='utf-8') if task.lower() == 'ner': sentences = self.para2sent(raw_text) for sentence in sentences: words = self.sent2word(sentence) for word in words: out_text.write(word + ' O\n') out_text.write('\n') elif task.lower() == 'classifier': sentences = self.para2sent(raw_text) for sentence in sentences: out_text.write(sentence + ' ||| 0\n') self.data.raw_dir = out_text_path self.decode()
Source Code
class YATO:
def __init__(self, config):
self.set_seed()
self.config = config
self.data = Data()
self.data.read_config(self.config)
def set_config_from_dset(self, dset):
self.data.load(dset)
def set_config_from_data(self, custom_data):
self.data = custom_data
def set_config_from_custom_configuration(self, custom_configuration):
self.data.read_config(self.config, custom_configuration)
def get_config(self):
return self.data
def train(self, log='test.log', metric='F'):
status = self.data.status.lower()
if status == 'train':
print("MODEL: train")
data_initialization(self.data)
self.data.generate_instance('train')
self.data.generate_instance('dev')
self.data.generate_instance('test')
self.data.build_pretrain_emb()
self.data.summary()
train(self.data, log, metric)
def decode(self, write_decode_file=True):
print("MODEL: decode")
predict_lines = self.convert_file_to_predict_style()
speed, acc, p, r, f, pred_results, pred_scores = self.predict(input_text=predict_lines,
write_decode_file=write_decode_file)
return {"speed": speed, "accuracy": acc, "precision": p, "recall": r, "predict_result": pred_results,
"nbest_predict_score": pred_scores, 'label': self.data.label_alphabet}
def predict(self, input_text=None, predict_file=None, write_decode_file=True):
self.data.read_config(self.config)
dset = self.data.dset_dir
self.set_config_from_dset(dset)
self.data.read_config(self.config)
if predict_file is not None and input_text is None:
input_text = open(predict_file, 'r', encoding="utf8").readlines()
elif predict_file is not None and input_text is not None:
print("Choose Predict Source")
self.data.generate_instance('predict', input_text)
print("nbest: {}".format(self.data.nbest))
speed, acc, p, r, f, pred_results, pred_scores = load_model_decode(self.data, 'predict')
if write_decode_file and self.data.nbest > 0 and not self.data.sentence_classification:
self.data.write_nbest_decoded_results(pred_results, pred_scores, 'predict')
elif write_decode_file:
self.data.write_decoded_results(pred_results, 'predict')
return speed, acc, p, r, f, pred_results, pred_scores
def attention(self, input_text=None):
self.data.read_config(self.config)
dset = self.data.dset_dir
self.set_config_from_dset(dset)
self.data.read_config(self.config)
print("MODEL: Attention Weight")
self.data.generate_instance('predict', input_text)
probs_ls, weights_ls = extract_attention_weight(self.data)
return probs_ls, weights_ls
def set_seed(self, seed=42, hard = False):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
np.random.seed(seed) # Numpy module.
random.seed(seed) # Python random module.
torch.backends.cudnn.deterministic = True
if hard:
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
os.environ['PYTHONHASHSEED'] = str(seed)
def convert_file_to_predict_style(self):
predict_lines = open(self.data.raw_dir, 'r', encoding="utf8").readlines()
return predict_lines
def para2sent(self, para):
"""
:param para:Dividing paragraphs into sentences
:return:
"""
para = re.sub('([.��!��?��\?])([^����])', r"\1\n\2", para)
para = re.sub('(\.{6})([^����])', r"\1\n\2", para)
para = re.sub('(\��{2})([^����])', r"\1\n\2", para)
para = re.sub('([������\?][����])([^��������\?])', r'\1\n\2', para)
para = para.rstrip()
return para.split("\n")
def sent2word(self, sentence):
"""
:param sentence:Dividing sentences into words or chars
:return:
"""
char_ls = list(sentence)
word_ls = [char_ls[0]]
for i in range(1, len(char_ls)):
if 65 <= ord(char_ls[i]) <= 122 and 65 <= ord(char_ls[i - 1]) <= 122:
word_ls[-1] = word_ls[-1] + char_ls[i]
else:
word_ls.append(char_ls[i])
return word_ls
def decode_raw(self, raw_text_path, task, out_text_path='raw.out'):
"""
:param raw_text_path:The path of raw text file
:param task:choose the task
:param out_text_path:The path of decode result file
:return:
"""
raw_text = open(raw_text_path, 'r', encoding='utf-8').read()
out_text = open(out_text_path, 'w', encoding='utf-8')
if task.lower() == 'ner':
sentences = self.para2sent(raw_text)
for sentence in sentences:
words = self.sent2word(sentence)
for word in words:
out_text.write(word + ' O\n')
out_text.write('\n')
elif task.lower() == 'classifier':
sentences = self.para2sent(raw_text)
for sentence in sentences:
out_text.write(sentence + ' ||| 0\n')
self.data.raw_dir = out_text_path
self.decode()
def get_gold_predict(self, golden_standard, predict_result, stoken):
"""
:param golden_standard:golden standard file path
:param predict_result:predict result file path
:param stoken:split token
:return:
"""
golden_data = open(golden_standard, 'r', encoding='utf-8').readlines()
predict_data = open(predict_result, 'r', encoding='utf-8').readlines()
golden_list = []
predict_list = []
tmp_gold = []
tmp_predict = []
for gold_idx, pre_idx in zip(golden_data, predict_data):
if gold_idx != '\n':
gentity_with_label = gold_idx.split(stoken)
glabel = gentity_with_label[1].replace('\n', '')
tmp_gold.append(glabel)
pentity_with_label = pre_idx.split(stoken)
plabel = pentity_with_label[1].replace('\n', '')
tmp_predict.append(plabel)
else:
golden_list.append(tmp_gold)
predict_list.append(tmp_predict)
tmp_gold = []
tmp_predict = []
return golden_list, predict_list
##The metric values that are output to the log are calculated using the yato internal function. Here, the metric calculated by seqeval is slightly different from the log metric
def report_f1(self, golden_standard, predict_result, split=" "):
"""
:param golden_standard:golden standard file path
:param predict_result:predict result file path
:param split:split token
:return:
"""
golden_list, predict_list = self.get_gold_predict(golden_standard, predict_result, split)
print(classification_report(golden_list, predict_list))
##The metric values that are output to the log are calculated using the yato internal function. Here, the metric calculated by seqeval is slightly different from the log metric
def report_acc(self, golden_standard, predict_result, split=' ||| '):
"""
:param golden_standard:golden standard file path
:param predict_result:predict result file path
:param split:split token
:return:
"""
golden_list, predict_list = self.get_gold_predict(golden_standard, predict_result, split)
print("Report accuracy: %0.2f" % accuracy_score(golden_list, predict_list))