hashformers.beamsearch package
Submodules
hashformers.beamsearch.algorithm module
- class hashformers.beamsearch.algorithm.Beamsearch(model_name_or_path=None, model_type=None, device='cuda', gpu_batch_size=1)
Bases: hashformers.beamsearch.model_lm.ModelLM
- flatten_list(list_)
- next_step(list_of_candidates)
- reshape_tree(tree, measure)
- run(dataset, topk=20, steps=13)
- trim_tree(tree, prob_dict, topk)
- update_probabilities(tree, prob_dict)
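The method names above (`next_step`, `trim_tree`, and `run` with `topk` and `steps`) suggest a beam search over space insertions in a hashtag. The following is a minimal, self-contained sketch of that general technique under stated assumptions. It is not the `Beamsearch` implementation. The names `VOCAB`, `score`, `expand`, and `beam_search` are hypothetical, and the toy `score` function, which rewards dictionary words, stands in for a language model. Higher scores are assumed to be better.

```python
from typing import Callable, List

# Hypothetical toy vocabulary; a real system would score with a language model.
VOCAB = {"we", "love", "nlp", "hash", "tag", "tags"}


def score(candidate: str) -> float:
    """Toy stand-in for an LM score (higher is better, an assumption here)."""
    words = candidate.split(" ")
    known = sum(len(w) for w in words if w in VOCAB)
    # Reward characters covered by known words, lightly penalise extra splits.
    return known - 0.1 * len(words)


def expand(candidate: str) -> List[str]:
    """Return every string obtained by inserting one more space into candidate."""
    out = []
    for i in range(1, len(candidate)):
        # Skip positions next to an existing space so no empty words appear.
        if candidate[i] != " " and candidate[i - 1] != " ":
            out.append(candidate[:i] + " " + candidate[i:])
    return out


def beam_search(hashtag: str, score_fn: Callable[[str], float] = score,
                topk: int = 20, steps: int = 13) -> List[str]:
    """Grow segmentations one space per step, keeping the topk best each time."""
    beam = [hashtag]
    seen = {hashtag}
    for _ in range(steps):
        new = [c for cand in beam for c in expand(cand) if c not in seen]
        if not new:
            break  # nothing left to split
        seen.update(new)
        # Keep the current beam in the pool so fewer splits can still win.
        pool = list(dict.fromkeys(beam + new))  # dedupe, preserve order
        beam = sorted(pool, key=score_fn, reverse=True)[:topk]
    return beam
```

Under this toy scorer, `beam_search("welovenlp")[0]` is `"we love nlp"`. Keeping the previous beam in each step's pool lets an unsplit or lightly split candidate survive when further splits only lower its score.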
hashformers.beamsearch.bert_lm module
hashformers.beamsearch.data_structures module
- class hashformers.beamsearch.data_structures.Node(hypothesis: str, characters: str, score: float)
Bases: object
- characters: str
- hypothesis: str
- score: float
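The typed signature of `Node` reads like a Python dataclass. The sketch below defines a hypothetical look-alike, `NodeSketch`, to show how such a record might be built and compared by score. Two points are assumptions that this reference does not state: that `characters` holds the unsegmented input and `hypothesis` a candidate segmentation, and that higher scores are better.

```python
from dataclasses import dataclass


@dataclass
class NodeSketch:
    """Hypothetical stand-in mirroring the fields listed for Node."""
    hypothesis: str   # assumed: a candidate segmentation, e.g. "hash tag"
    characters: str   # assumed: the unsegmented input, e.g. "hashtag"
    score: float      # score assigned to the hypothesis


nodes = [
    NodeSketch("hasht ag", "hashtag", 0.5),
    NodeSketch("hash tag", "hashtag", 3.9),
]
# Assumes higher scores are better; use min() instead if lower is better.
best = max(nodes, key=lambda n: n.score)
print(best.hypothesis)  # hash tag
```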
- class hashformers.beamsearch.data_structures.ProbabilityDictionary(dictionary: dict)
Bases: object
- dictionary: dict
- get_segmentations(astype='dict', gold_array=None)
- get_top_k(k=2, characters_field='characters', segmentation_field='segmentation', score_field='score', return_dataframe=False, fill=False)
- to_csv(filename, characters_field='characters', segmentation_field='segmentation', score_field='score')
- to_dataframe(characters_field='characters', segmentation_field='segmentation', score_field='score')
- to_json(filepath)
- hashformers.beamsearch.data_structures.enforce_prob_dict(dictionary, score_field='score', segmentation_field='segmentation')
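`ProbabilityDictionary.get_top_k` takes field names for characters, segmentation, and score. This suggests a table of scored segmentations grouped by their unsegmented form. The stdlib-only sketch below illustrates that kind of grouping and top-k selection. Several details are assumptions for illustration rather than the library's API: the name `top_k_sketch`, the `{segmentation: score}` input layout, and the `lower_is_better` flag. The scores in `example` are made up.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def top_k_sketch(scores: Dict[str, float], k: int = 2,
                 lower_is_better: bool = True) -> List[dict]:
    """Group {segmentation: score} by unsegmented characters; keep k per group.

    Hypothetical helper. The score direction is an assumption: lower-is-better
    suits negative log-likelihoods, for example.
    """
    groups: Dict[str, List[Tuple[str, float]]] = defaultdict(list)
    for segmentation, score in scores.items():
        characters = segmentation.replace(" ", "")
        groups[characters].append((segmentation, score))

    rows = []
    for characters, items in groups.items():
        items.sort(key=lambda item: item[1], reverse=not lower_is_better)
        for segmentation, score in items[:k]:
            rows.append({"characters": characters,
                         "segmentation": segmentation,
                         "score": score})
    return rows


# Made-up scores for illustration only.
example = {"hash tag": 12.1, "hashtag": 14.0, "ha shtag": 19.7,
           "we love nlp": 20.3, "welove nlp": 25.8}
for row in top_k_sketch(example, k=2):
    print(row)
```

Each returned row uses the default field names from the signatures above (`characters`, `segmentation`, `score`), so the rows could be passed straight to a tabular library if one is available.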
hashformers.beamsearch.gpt2_lm module
- class hashformers.beamsearch.gpt2_lm.GPT2LM(model_name_or_path, device='cuda', gpu_batch_size=20)
Bases: object
- get_probs(list_of_candidates)
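`GPT2LM` takes a `gpu_batch_size` argument and exposes `get_probs(list_of_candidates)`. One plausible reading, which this reference does not confirm, is that candidates are scored in chunks of `gpu_batch_size`. The sketch below shows only that batching pattern. `score_in_batches` and the `score_batch` callable are hypothetical, and a real scorer would run a GPT-2 model where the toy function is used.

```python
from typing import Callable, List, Sequence


def score_in_batches(candidates: Sequence[str],
                     score_batch: Callable[[List[str]], List[float]],
                     gpu_batch_size: int = 20) -> List[float]:
    """Score candidates chunk by chunk, preserving input order."""
    if gpu_batch_size < 1:
        raise ValueError("gpu_batch_size must be at least 1")
    scores: List[float] = []
    for start in range(0, len(candidates), gpu_batch_size):
        batch = list(candidates[start:start + gpu_batch_size])
        batch_scores = score_batch(batch)  # e.g. one forward pass per batch
        if len(batch_scores) != len(batch):
            raise RuntimeError("scorer returned the wrong number of scores")
        scores.extend(batch_scores)
    return scores


calls = []


def toy_scorer(batch: List[str]) -> List[float]:
    # Hypothetical stand-in for a language model: records each batch size
    # and scores every string by its number of spaces.
    calls.append(len(batch))
    return [float(text.count(" ")) for text in batch]


result = score_in_batches(["a", "a b", "a b c", "ab c", "abc"], toy_scorer,
                          gpu_batch_size=2)
print(result, calls)
```

With five candidates and a batch size of 2, the toy scorer is called three times, with batches of sizes 2, 2, and 1.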
- class hashformers.beamsearch.gpt2_lm.PaddedGPT2LMScorer(*args, **kwargs)
Bases: lm_scorer.models.gpt2.GPT2LMScorer
hashformers.beamsearch.model_lm module
- class hashformers.beamsearch.model_lm.ModelLM(model_name_or_path=None, model_type=None, device=None, gpu_batch_size=None, gpu_id=0)
Bases: object
hashformers.beamsearch.reranker module
- class hashformers.beamsearch.reranker.Reranker(model_name_or_path='bert-base-uncased', model_type='bert', gpu_batch_size=1, gpu_id=0)
Bases: hashformers.beamsearch.model_lm.ModelLM
- rerank(data)
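`Reranker` defaults to `bert-base-uncased` and exposes `rerank(data)`. A common reranking pattern re-scores a first-pass shortlist with a second model and combines the two scores. It is offered here as an assumption, not as a description of this class. In the sketch, `rerank_sketch`, the `alpha` interpolation weight, and the toy second-pass scorer are all hypothetical, and both scores are assumed to be higher-is-better.

```python
from typing import Callable, Dict, List, Tuple


def rerank_sketch(first_pass: Dict[str, float],
                  second_pass: Callable[[str], float],
                  alpha: float = 0.5) -> List[Tuple[str, float]]:
    """Linearly combine two higher-is-better scores and sort descending."""
    combined = {
        cand: alpha * score + (1 - alpha) * second_pass(cand)
        for cand, score in first_pass.items()
    }
    return sorted(combined.items(), key=lambda item: item[1], reverse=True)


shortlist = {"hasht ag": 1.0, "hash tag": 0.8}
known = {"hash", "tag"}


def toy_second(cand: str) -> float:
    # Toy second-pass scorer: fraction of words found in a tiny vocabulary.
    words = cand.split()
    return sum(w in known for w in words) / len(words)


print(rerank_sketch(shortlist, toy_second))
```

Here the second pass overturns the first-pass ranking: "hash tag" scores 0.5 * 0.8 + 0.5 * 1.0 = 0.9, ahead of "hasht ag" at 0.5 * 1.0 + 0.5 * 0.0 = 0.5.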