29 changes: 15 additions & 14 deletions medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
@@ -212,6 +212,9 @@ def __init__(self, base_tokenizer: BaseTokenizer,
def _init_data_paths(self):
doc_cls = self.base_tokenizer.get_doc_class()
doc_cls.register_addon_path('relations', def_val=[], force=True)
entity_cls = self.base_tokenizer.get_entity_class()
entity_cls.register_addon_path('start', def_val=None, force=True)
entity_cls.register_addon_path('end', def_val=None, force=True)

def save(self, save_path: str = "./") -> None:
self.component.save(save_path=save_path)
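For readers unfamiliar with the addon-path mechanism the new `register_addon_path` / `set_addon_data` lines rely on, here is a toy sketch of the pattern in plain Python (not the MedCAT tokenizer or document classes): named slots registered at class level with defaults, then per-instance addon data such as the new entity `start`/`end` offsets set and read back.

```python
# Toy illustration of the addon-path pattern; not the MedCAT API.
class Entity:
    _addon_paths: dict = {}

    @classmethod
    def register_addon_path(cls, name, def_val=None, force=False):
        # Register a named addon slot with a default; `force` allows re-registration.
        if name in cls._addon_paths and not force:
            raise ValueError(f"addon path {name!r} already registered")
        cls._addon_paths[name] = def_val

    def __init__(self):
        self._addon_data = {}

    def set_addon_data(self, name, value):
        if name not in self._addon_paths:
            raise KeyError(f"unregistered addon path {name!r}")
        self._addon_data[name] = value

    def get_addon_data(self, name):
        # Fall back to the registered default when no value was set.
        return self._addon_data.get(name, self._addon_paths[name])


Entity.register_addon_path("start", def_val=None, force=True)
Entity.register_addon_path("end", def_val=None, force=True)
ent = Entity()
ent.set_addon_data("start", 27)
print(ent.get_addon_data("start"), ent.get_addon_data("end"))  # 27 None
```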
@@ -833,27 +836,23 @@ def pipe(self, stream: Iterable[MutableDocument], *args, **kwargs

relations: list = doc.get_addon_data( # type: ignore
"relations")
out_rels = predict_rel_dataset.dataset[
"output_relations"][rel_idx]
relations.append(
{
"relation": rc_cnf.general.idx2labels[
predicted_label_id],
"label_id": predicted_label_id,
"ent1_text": predict_rel_dataset.dataset[
"output_relations"][rel_idx][
2],
"ent2_text": predict_rel_dataset.dataset[
"output_relations"][rel_idx][
3],
"ent1_text": out_rels[2],
"ent2_text": out_rels[3],
"confidence": float("{:.3f}".format(
confidence[0])),
"start_ent_pos": "",
"end_ent_pos": "",
"start_entity_id":
predict_rel_dataset.dataset[
"output_relations"][rel_idx][8],
"end_entity_id":
predict_rel_dataset.dataset[
"output_relations"][rel_idx][9]
"start_ent1_char_pos": out_rels[18],
"end_ent1_char_pos": out_rels[19],
"start_ent2_char_pos": out_rels[20],
"end_ent2_char_pos": out_rels[21],
"start_entity_id": out_rels[8],
"end_entity_id": out_rels[9],
})
pbar.update(len(token_ids))
pbar.close()
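To make the new record shape easier to see at a glance, below is a minimal, self-contained sketch of the relation dict the refactored `pipe()` appends, assuming the `output_relations` row layout implied by the diff (entity texts at indices 2/3, entity ids at 8/9, the newly returned character offsets at 18-21). The helper name, sample row and label mapping are illustrative only.

```python
# Sketch of the relation record built per prediction; indices follow the diff.
def build_relation_record(out_rels, idx2labels, predicted_label_id, confidence):
    return {
        "relation": idx2labels[predicted_label_id],
        "label_id": predicted_label_id,
        "ent1_text": out_rels[2],
        "ent2_text": out_rels[3],
        "confidence": float("{:.3f}".format(confidence)),
        "start_ent1_char_pos": out_rels[18],
        "end_ent1_char_pos": out_rels[19],
        "start_ent2_char_pos": out_rels[20],
        "end_ent2_char_pos": out_rels[21],
        "start_entity_id": out_rels[8],
        "end_entity_id": out_rels[9],
    }


# Made-up sample row, just to show the expected positions of each field.
sample_row = [None] * 22
sample_row[2], sample_row[3] = "aspirin", "headache"
sample_row[8], sample_row[9] = 0, 1
sample_row[18:22] = [0, 7, 16, 24]
print(build_relation_record(sample_row, {1: "treats"}, 1, 0.8765))
```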
@@ -901,6 +900,8 @@ def predict_text_with_anns(self, text: str, annotations: list[dict]
entity = base_tokenizer.create_entity(
doc, min(tkn_idx), max(tkn_idx) + 1, label=ann["value"])
entity.cui = ann["cui"]
entity.set_addon_data('start', ann['start'])
entity.set_addon_data('end', ann['end'])
doc.ner_ents.append(entity)

doc = self(doc)
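For context, a hedged usage sketch of the annotation dicts this method now consumes: the key names (`value`, `cui`, `start`, `end`) are taken from the diff, while the text, CUIs, character offsets and the `rel_cat` handle are made-up placeholders.

```python
# Illustrative annotation payload for predict_text_with_anns; values are samples.
text = "Patient started on metformin for type 2 diabetes."
annotations = [
    {"value": "metformin", "cui": "C0025598", "start": 19, "end": 28},
    {"value": "type 2 diabetes", "cui": "C0011860", "start": 33, "end": 48},
]
# doc = rel_cat.predict_text_with_anns(text, annotations)
# print(doc.get_addon_data("relations"))
```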
@@ -246,6 +246,9 @@ def _create_relation_validation(self,
ent2_token: Union[str, MutableEntity] = tmp_doc_text[
ent2_start_char_pos: ent2_end_char_pos]

annotation_token_text = self.tokenizer.hf_tokenizers.convert_ids_to_tokens(
self.config.general.annotation_schema_tag_ids)

if (abs(ent2_start_char_pos - ent1_start_char_pos
) <= self.config.general.window_size and
ent1_token != ent2_token):
@@ -281,24 +284,20 @@

if is_spacy_doc or is_mct_export:
tmp_doc_text = text
_pre_e1 = tmp_doc_text[0: (ent1_start_char_pos)]
_e1_s2 = (
tmp_doc_text[ent1_end_char_pos: ent2_start_char_pos - 1])
_e2_end = tmp_doc_text[ent2_end_char_pos + 1: text_length]
ent2_token_end_pos = (ent2_token_end_pos + 2)

annotation_token_text = (
self.tokenizer.hf_tokenizers.convert_ids_to_tokens(
self.config.general.annotation_schema_tag_ids))
s1, e1, s2, e2 = annotation_token_text

tmp_doc_text = (
str(_pre_e1) + " " +
annotation_token_text[0] + " " +
str(ent1_token) + " " +
annotation_token_text[1] + " " + str(_e1_s2) + " " +
annotation_token_text[2] + " " + str(ent2_token) + " " +
annotation_token_text[3] + " " + str(_e2_end)
)
tmp_doc_text[:ent2_end_char_pos] +
e2 + tmp_doc_text[ent2_end_char_pos:])
tmp_doc_text = (
tmp_doc_text[:ent2_start_char_pos] +
s2 + tmp_doc_text[ent2_start_char_pos:])
tmp_doc_text = (
tmp_doc_text[:ent1_end_char_pos] +
e1 + tmp_doc_text[ent1_end_char_pos:])
tmp_doc_text = (
tmp_doc_text[:ent1_start_char_pos] +
s1 + tmp_doc_text[ent1_start_char_pos:])

ann_tag_token_len = len(annotation_token_text[0])
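The rewritten block above splices the four annotation tags into the text by character position instead of rebuilding the string by concatenation. A minimal sketch (plain Python, assuming ent1 precedes ent2 and the spans do not overlap) of why the insertions run right to left: each splice only shifts text to its right, so the smaller offsets that still need to be used stay valid.

```python
# Splice entity tags into text right-to-left so earlier offsets are unaffected.
def insert_entity_tags(text, ent1_span, ent2_span,
                       tags=("[s1]", "[e1]", "[s2]", "[e2]")):
    s1, e1, s2, e2 = tags
    (ent1_start, ent1_end), (ent2_start, ent2_end) = ent1_span, ent2_span
    # Order: ent2 end, ent2 start, ent1 end, ent1 start.
    text = text[:ent2_end] + e2 + text[ent2_end:]
    text = text[:ent2_start] + s2 + text[ent2_start:]
    text = text[:ent1_end] + e1 + text[ent1_end:]
    text = text[:ent1_start] + s1 + text[ent1_start:]
    return text


print(insert_entity_tags("aspirin reduces fever", (0, 7), (16, 21)))
# -> "[s1]aspirin[e1] reduces [s2]fever[e2]"
```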

@@ -309,11 +308,10 @@

_right_context_start_end_pos = (  # 8 for spaces
right_context_end_char_pos + (ann_tag_token_len * 4) + 8)
right_context_end_char_pos = (
len(tmp_doc_text) + 1 if
right_context_end_char_pos >= len(tmp_doc_text) or
_right_context_start_end_pos >= len(tmp_doc_text)
else _right_context_start_end_pos)
right_context_end_char_pos = len(tmp_doc_text) if (
right_context_end_char_pos >= len(tmp_doc_text)
or _right_context_start_end_pos >= len(tmp_doc_text)
) else _right_context_start_end_pos

# reassign the new text with added tags
text_length = len(tmp_doc_text)
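A small worked sketch of the clamping above, under the assumption stated in the code comment (four inserted tags plus 8 surrounding spaces); the function name and sample numbers are illustrative.

```python
# Extend the right-context boundary past the inserted tags, capped at text length.
def clamp_right_context(right_end: int, tag_len: int, text_len: int) -> int:
    shifted_end = right_end + tag_len * 4 + 8
    if right_end >= text_len or shifted_end >= text_len:
        return text_len
    return shifted_end


print(clamp_right_context(right_end=50, tag_len=4, text_len=200))   # 74
print(clamp_right_context(right_end=190, tag_len=4, text_len=200))  # 200
```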
@@ -363,16 +361,20 @@ def _create_relation_validation(self,
ent2_token_start_pos += ent1_token_start_pos

ent1_ent2_new_start = (ent1_token_start_pos, ent2_token_start_pos)
en1_start, en1_end = window_tokenizer_data[
"offset_mapping"][ent1_token_start_pos]
en2_start, en2_end = window_tokenizer_data[
"offset_mapping"][ent2_token_start_pos]
os_map = window_tokenizer_data["offset_mapping"]
s1_start, s1_end = os_map[ent1_token_start_pos]
e1_start, e1_end = os_map[_ent1_token_end_pos]

s2_start, s2_end = os_map[ent2_token_start_pos]
e2_start, e2_end = os_map[_ent2_token_end_pos]

return [window_tokenizer_data["input_ids"], ent1_ent2_new_start,
ent1_token, ent2_token, "UNK",
self.config.model.padding_idx,
None, None, None, None, None, None, doc_id, "",
en1_start, en1_end, en2_start, en2_end]
s1_start, e1_end, s2_start, e2_end,
ent1_start_char_pos, ent1_end_char_pos,
ent2_start_char_pos, ent2_end_char_pos]
return []
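The extended return list reports entity character boundaries recovered from `offset_mapping`. The standalone sketch below shows how a Hugging Face fast tokenizer's `offset_mapping` links token indices back to character spans in the input; `bert-base-uncased` is only an example checkpoint, not necessarily the one RelCAT is configured with.

```python
# How offset_mapping maps each token index to its character span (fast tokenizer).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
enc = tokenizer("Aspirin reduces fever", return_offsets_mapping=True,
                add_special_tokens=False)
tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"])
for idx, (tok, (char_start, char_end)) in enumerate(
        zip(tokens, enc["offset_mapping"])):
    # Given an entity's first/last token index, these pairs give its char start/end.
    print(idx, tok, (char_start, char_end))
```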

def _get_token_type_and_start_end(