diff --git a/ChartExtractor/extraction/checkboxes.py b/ChartExtractor/extraction/checkboxes.py index eceb515..da5479b 100644 --- a/ChartExtractor/extraction/checkboxes.py +++ b/ChartExtractor/extraction/checkboxes.py @@ -38,30 +38,22 @@ def extract_checkboxes( - image: Image.Image, - detection_model: ObjectDetectionModel, + detections: List[Detection], side: Literal["intraoperative", "preoperative"], - slice_width: int, - slice_height: int, - horizontal_overlap_ratio: float = 0.5, - vertical_overlap_ratio: float = 0.5, + image_width: int, + image_height: int, ) -> Dict[str, str]: """Extracts checkbox data from an image of a chart. Args: - `image` (Image.Image): - The image to extract checkboxes from. - `detection_model` (ObjectDetectionModel): - An object that implements the ObjectDetectionModel interface. - `slice_height` (int): - The height of each slice. - `slice_width` (int): - The width of each slice. - `horizontal_overlap_ratio` (float): - The amount of left-right overlap between slices. - `vertical_overlap_ratio` (float): - The amount of top-bottom overlap between slices. - + detections (List[Detection]): + The detected checkboxes. + side (Literal["intraoperative", "preoperative"]): + The side of the chart. + image_width (int): + The original image's width. + image_height (int): + The original image's height. Returns: A dictionary mapping the name of checkboxes to "checked" or "unchecked". """ @@ -74,76 +66,21 @@ def extract_checkboxes( f'Invalid selection for side. 
Must be one of ["intraoperative", "preoperative"], value supplied was {side}' ) - checkbox_bboxes: List[BoundingBox] = detect_checkboxes( - image, - detection_model, - slice_width, - slice_height, - horizontal_overlap_ratio, - vertical_overlap_ratio, + checkbox_bboxes: List[BoundingBox] = [det.annotation for det in detections] + names: Dict[str, str] = find_checkbox_names( + checkbox_bboxes, + centroids, + image_width, + image_height ) - names: Dict[str, str] = find_checkbox_names(checkbox_bboxes, centroids, image.size) return names -def detect_checkboxes( - image: Image.Image, - detection_model: ObjectDetectionModel, - slice_width: int, - slice_height: int, - horizontal_overlap_ratio: float, - vertical_overlap_ratio: float, -) -> List[BoundingBox]: - """Uses an object detector to detect checkboxes and their state on an image. - - Args: - `image` (Image.Image): - The image to extract checkboxes from. - `detection_model` (ObjectDetectionModel): - An object that implements the ObjectDetectionModel interface. - `slice_height` (int): - The height of each slice. - `slice_width` (int): - The width of each slice. - `horizontal_overlap_ratio` (float): - The amount of left-right overlap between slices. - `vertical_overlap_ratio` (float): - The amount of top-bottom overlap between slices. - - Returns: - A list of Detection objects encoding the location and state of checkboxes. 
- """ - image_tiles: List[List[Image.Image]] = tile_image( - image, - slice_width, - slice_height, - horizontal_overlap_ratio, - vertical_overlap_ratio, - ) - detections: List[List[List[Detection]]] = [ - [detection_model(pil_to_cv2(tile))[0] for tile in row] - for row in image_tiles - ] - detections: List[Detection] = untile_detections( - detections, - slice_width, - slice_height, - horizontal_overlap_ratio, - vertical_overlap_ratio, - ) - detections: List[Detection] = non_maximum_suppression( - detections=detections, - threshold=0.8, - overlap_comparator=intersection_over_minimum, - sorting_fn=lambda det: det.annotation.area * det.confidence, - ) - return [det.annotation for det in detections] - - def find_checkbox_names( checkboxes: List[BoundingBox], centroids: Dict[str, Tuple[float, float]], - imsize: Tuple[int, int], + image_width: int, + image_height: int, threshold: float = 0.025, ) -> Dict[str, str]: """Finds the names of checkboxes. @@ -175,7 +112,7 @@ def distance(p1: Tuple[float, float], p2: Tuple[float, float]) -> float: checkbox_values: Dict[str, str] = dict() for ckbx in checkboxes: - center = ckbx.center[0] / imsize[0], ckbx.center[1] / imsize[1] + center = ckbx.center[0] / image_width, ckbx.center[1] / image_height distance_to_all_centroids: Dict[str, float] = { name: distance(center, centroid) for (name, centroid) in centroids.items() } @@ -188,19 +125,3 @@ def distance(p1: Tuple[float, float], p2: Tuple[float, float]) -> float: checkbox_values[closest_checkbox_centroid] = ckbx.category return checkbox_values - - -def find_interaoperative_checkbox_names( - intraoperative_checkboxes: List[BoundingBox], threshold: float = 0.025 -) -> Dict[str, str]: - """Finds the names of intraoperative checkboxes.""" - return find_checkbox_names(intraoperative_checkboxes, INTRAOP_CENTROIDS, threshold) - - -def find_preoperative_checkbox_names( - preoperative_checkboxes: List[BoundingBox], threshold: float = 0.025 -) -> Dict[str, str]: - """Finds the names of 
preoperative checkboxes.""" - return find_checkbox_names( - preoperative_checkboxes, PREOP_POSTOP_CENTROIDS, threshold - ) diff --git a/ChartExtractor/extraction/extraction.py b/ChartExtractor/extraction/extraction.py index f85050c..96b0a2c 100644 --- a/ChartExtractor/extraction/extraction.py +++ b/ChartExtractor/extraction/extraction.py @@ -1,10 +1,12 @@ """Consolidates all the functionality for extracting data from charts into one function.""" # Built-in imports +from functools import partial, reduce +from operator import concat import os from pathlib import Path from PIL import Image -from typing import Dict, List, Literal, Tuple +from typing import Any, Dict, List, Literal, Tuple # Internal Imports from ..extraction.blood_pressure_and_heart_rate import ( @@ -13,7 +15,7 @@ from ..extraction.checkboxes import extract_checkboxes from ..extraction.extraction_utilities import ( combine_dictionaries, - detect_numbers, + detect_objects_using_tiling, label_studio_to_bboxes, ) from ..extraction.inhaled_volatile import extract_inhaled_volatile @@ -39,6 +41,13 @@ from ..object_detection_models.onnx_yolov11_detection import OnnxYolov11Detection from ..object_detection_models.onnx_yolov11_pose_single import OnnxYolov11PoseSingle from ..object_detection_models.object_detection_model import ObjectDetectionModel +from ..point_registration.homography import ( + find_homography, + transform_point, + transform_box, + transform_keypoint, +) +from ..utilities.annotations import BoundingBox, Keypoint from ..utilities.detections import Detection from ..utilities.detection_reassembly import ( untile_detections, @@ -49,13 +58,10 @@ from ..utilities.read_config import read_config from ..utilities.tiling import tile_image +# External Imports +import numpy as np + -CORNER_LANDMARK_NAMES: List[str] = [ - "anesthesia_start", - "safety_checklist", - "lateral", - "units", -] PATH_TO_DATA: Path = (Path(os.path.dirname(__file__)) / ".." / ".." 
/ "data").resolve() PATH_TO_MODELS: Path = PATH_TO_DATA / "models" PATH_TO_MODEL_METADATA = PATH_TO_DATA / "model_metadata" @@ -116,6 +122,377 @@ def digitize_sheet(intraop_image: Image.Image, preop_postop_image: Image.Image) return data +def run_models( + intraop_image: Image.Image, + preop_postop_image: Image.Image +) -> Dict[str, List[Detection]]: + """Runs all the models and puts their output into a dictionary. + + Args: + `intraop_image` (Image.Image): + A smartphone photograph of the intraoperative side of the paper + anesthesia record. + `preop_postop_image` (Image.Image): + A smartphone photograph of the preoperative/postoperative side of the + paper anesthesia record. + + Returns: + A dictionary containing all the detections on both images. The structure of the dictionary + is set up as: + { + "intraoperative": { + "landmarks": [...], + "numbers": [...], + "checkboxes": [...], + "systolic": [...], + "diastoic": [...], + "heart_rate": [...], + }, + "preoperative_postoperative": { + "landmarks": [...], + "numbers": [...], + "checkboxes": [...], + } + } + """ + detections_dict: Dict[str, List[Detection]] = dict() + detections_dict["intraoperative"] = run_intraoperative_models(intraop_image) + detections_dict["preoperative_postoperative"] = ( + run_preoperative_postoperative_models(preop_postop_image) + ) + return detections_dict + + +def run_intraoperative_models(intraop_image: Image.Image) -> Dict[str, List[Detection]]: + """Runs all the models on the preoperative/postoperative image and outputs to a dictionary. + + Args: + `intraop_image` (Image.Image): + A smartphone photograph of the intraoperative side of the paper anesthesia record. + + Returns: + A dictionary containing all of the detections on the intraoperative image. 
+ The structure of the dictionary is set up as: + { + "landmarks": [...], + "numbers": [...], + "checkboxes": [...], + "systolic": [...], + "diastoic": [...], + "heart_rate": [...], + } + """ + detections_dict: Dict[str, List[Detection]] = dict() + + # landmarks + landmark_tile_size: int = compute_tile_size( + MODEL_CONFIG["intraoperative_document_landmarks"], + intraop_image.size + ) + detections_dict["landmarks"] = detect_objects_using_tiling( + intraop_image, + INTRAOP_DOC_MODEL, + landmark_tile_size, + landmark_tile_size, + MODEL_CONFIG["intraoperative_document_landmarks"]["horz_overlap_proportion"], + MODEL_CONFIG["intraoperative_document_landmarks"]["vert_overlap_proportion"], + ) + + # numbers + digit_tile_size: int = compute_tile_size(MODEL_CONFIG["numbers"], intraop_image.size) + detections_dict["numbers"] = detect_objects_using_tiling( + intraop_image, + NUMBERS_MODEL, + digit_tile_size, + digit_tile_size, + MODEL_CONFIG["numbers"]["horz_overlap_proportion"], + MODEL_CONFIG["numbers"]["vert_overlap_proportion"], + ) + + # checkboxes + tile_size = compute_tile_size(MODEL_CONFIG["checkboxes"], intraop_image.size) + detections_dict["checkboxes"] = detect_objects_using_tiling( + intraop_image, + CHECKBOXES_MODEL, + tile_size, + tile_size, + MODEL_CONFIG["checkboxes"]["horz_overlap_proportion"], + MODEL_CONFIG["checkboxes"]["vert_overlap_proportion"], + nms_threshold=0.8 + ) + + # systolic + sys_tile_size: int = compute_tile_size(MODEL_CONFIG["systolic"], intraop_image.size) + detections_dict["systolic"] = detect_objects_using_tiling( + intraop_image.copy(), + SYSTOLIC_MODEL, + sys_tile_size, + sys_tile_size, + MODEL_CONFIG["systolic"]["horz_overlap_proportion"], + MODEL_CONFIG["systolic"]["vert_overlap_proportion"], + ) + + # diastolic + dia_tile_size: int = compute_tile_size(MODEL_CONFIG["diastolic"], intraop_image.size) + detections_dict["diastolic"] = detect_objects_using_tiling( + intraop_image.copy(), + DIASTOLIC_MODEL, + dia_tile_size, + dia_tile_size, 
+ MODEL_CONFIG["diastolic"]["horz_overlap_proportion"], + MODEL_CONFIG["diastolic"]["vert_overlap_proportion"], + ) + + # heart rate + hr_tile_size: int = compute_tile_size(MODEL_CONFIG["heart_rate"], intraop_image.size) + detections_dict["heart_rate"] = detect_objects_using_tiling( + intraop_image.copy(), + HEART_RATE_MODEL, + hr_tile_size, + hr_tile_size, + MODEL_CONFIG["heart_rate"]["horz_overlap_proportion"], + MODEL_CONFIG["heart_rate"]["vert_overlap_proportion"], + ) + + return detections_dict + + +def run_preoperative_postoperative_models( + preop_postop_image: Image.Image +) -> Dict[str, List[Detection]]: + """Runs all the models on the preoperative/postoperative image and outputs to a dictionary. + + Args: + `preop_postop_image` (Image.Image): + A smartphone photograph of the preoperative/postoperative side of the + paper anesthesia record. + + Returns: + A dictionary containing all of the detections on the preoperative/postoperative image. + The structure of the dictionary is set up as: + { + "landmarks": [...], + "numbers": [...], + "checkboxes": [...], + } + """ + detections_dict: Dict[str, List[Detection]] = dict() + + # landmarks + landmark_tile_size: int = compute_tile_size( + MODEL_CONFIG["preop_postop_document_landmarks"], + preop_postop_image.size, + ) + detections_dict["landmarks"] = detect_objects_using_tiling( + preop_postop_image, + PREOP_POSTOP_DOC_MODEL, + landmark_tile_size, + landmark_tile_size, + MODEL_CONFIG["preop_postop_document_landmarks"]["horz_overlap_proportion"], + MODEL_CONFIG["preop_postop_document_landmarks"]["vert_overlap_proportion"], + ) + + # numbers + digit_tile_size: int = compute_tile_size(MODEL_CONFIG["numbers"], preop_postop_image.size) + detections_dict["numbers"] = detect_objects_using_tiling( + preop_postop_image, + NUMBERS_MODEL, + digit_tile_size, + digit_tile_size, + MODEL_CONFIG["numbers"]["horz_overlap_proportion"], + MODEL_CONFIG["numbers"]["vert_overlap_proportion"], + ) + + # checkboxes + tile_size = 
compute_tile_size(MODEL_CONFIG["checkboxes"], preop_postop_image.size) + detections_dict["checkboxes"] = detect_objects_using_tiling( + preop_postop_image, + CHECKBOXES_MODEL, + tile_size, + tile_size, + MODEL_CONFIG["checkboxes"]["horz_overlap_proportion"], + MODEL_CONFIG["checkboxes"]["vert_overlap_proportion"], + nms_threshold=0.8 + ) + + return detections_dict + + +def assign_meaning_to_detections(detections_dict: Dict[str, List[Detection]]) -> Dict[str, Any]: + """Imputes values to the detections to get the data encoded by the provider onto the chart. + + Examples of assigning meaning include getting mmHg and timestamp values for blood pressure + markers, assigning meaning to checked/unchecked checkbox detections, etc. + + Args: + detections_dict (Dict[str, List[Detection]]): + The detections from all models on both sides of the chart. Dictionary must match the + template that is output by run_models. + + Returns: + A dictionary with data that approximately matches the encoded meaning that the medical + provider wrote onto the chart. + """ + data: Dict[str, Any] = dict() + data["intraoperative"] = assign_meaning_to_intraoperative_detections( + detections_dict["intraoperative"] + ) + data["preoperative_postoperative"] = assign_meaning_to_preoperative_postoperative_detections( + detections_dict["preoperative_postoperative"] + ) + return data + + +def assign_meaning_to_intraoperative_detections( + intraop_detections_dict: Dict[str, List[Detection]], + image_size: Tuple[int, int] = (3300, 2550) +) -> Dict[str, Any]: + """Imputes values to the detections on the intraoperative side of the chart. + + Args: + intraop_detections_dict (Dict[str, List[Detection]]): + The detections from all models on the intraoperative side of the chart. + Must match the template that is output by run_intraoperative_models. + image_size (Tuple[int, int]): + The size of the image. 
+ + Returns: + A dictionary with data that approximately matches the encoded meaning that the medical + provider wrote onto the intraoperative side of the chart. + """ + h = create_intraoperative_homography_matrix(intraop_detections_dict["landmarks"]) + corrected_detections_dict: Dict[str, List[Detection]] = dict() + for (key, detections) in intraop_detections_dict.items(): + if len(detections) == 0: + continue + remap_func = ( + transform_box + if isinstance(detections[0].annotation, BoundingBox) + else transform_keypoint + ) + corrected_detections_dict[key] = [ + Detection(remap_func(det.annotation, h), det.confidence) for det in detections + ] + + extracted_data: Dict[str, Any] = dict() + + # extract drug code and surgical timing + extracted_data["codes"] = extract_drug_codes( + corrected_detections_dict["numbers"], + *image_size + ) + extracted_data["timing"] = extract_surgical_timing( + corrected_detections_dict["numbers"], + *image_size + ) + extracted_data["ett_size"] = extract_ett_size( + corrected_detections_dict["numbers"], + *image_size + ) + + # extract inhaled volatile drugs + time_boxes, mmhg_boxes = isolate_blood_pressure_legend_bounding_boxes( + [det.annotation for det in corrected_detections_dict["landmarks"]], *image_size + ) + time_clusters: List[Cluster] = cluster_boxes( + time_boxes, cluster_kmeans, "mins", possible_nclusters=[40, 41, 42] + ) + mmhg_clusters: List[Cluster] = cluster_boxes( + mmhg_boxes, cluster_kmeans, "mmhg", possible_nclusters=[18, 19, 20] + ) + + legend_locations: Dict[str, Tuple[float, float]] = find_legend_locations( + time_clusters + mmhg_clusters + ) + extracted_data["inhaled_volatile"] = extract_inhaled_volatile( + corrected_detections_dict["numbers"], + legend_locations, + corrected_detections_dict["landmarks"] + ) + + # extract bp and hr + bp_and_hr_dets = reduce( + concat, + [ + corrected_detections_dict["systolic"], + corrected_detections_dict["diastolic"], + corrected_detections_dict["heart_rate"], + ], + list() + 
) + + extracted_data["bp_and_hr"] = extract_heart_rate_and_blood_pressure( + bp_and_hr_dets, + time_clusters, + mmhg_clusters, + ) + + # extract physiological indicators + extracted_data["physiological_indicators"] = extract_physiological_indicators( + corrected_detections_dict["numbers"], + legend_locations, + corrected_detections_dict["landmarks"], + *image_size + ) + + # extract checkboxes + extracted_data["checkboxes"] = extract_checkboxes( + corrected_detections_dict["checkboxes"], + "intraoperative", + image_size[0], + image_size[1], + ) + + return extracted_data + + +def assign_meaning_to_preoperative_postoperative_detections( + preop_postop_detections_dict: Dict[str, List[Detection]], + image_size: Tuple[int, int] = (3300, 2550) +) -> Dict[str, Any]: + """Imputes values to the detections on the preoperative/postoperative side of the chart. + + Args: + intraop_detections_dict (Dict[str, List[Detection]]): + The detections from all models on the preoperative/postoperative side of the chart. + Must match the template that is output by run_intraoperative_models. + + Returns: + A dictionary with data that approximately matches the encoded meaning that the medical + provider wrote onto the preoperative/postoperative side of the chart. 
+ """ + h = create_preoperative_postoperative_homography_matrix( + preop_postop_detections_dict["landmarks"] + ) + corrected_detections_dict: Dict[str, List[Detection]] = dict() + for (key, detections) in preop_postop_detections_dict.items(): + if len(detections) == 0: + continue + remap_func = ( + transform_box + if isinstance(detections[0].annotation, BoundingBox) + else transform_keypoint + ) + corrected_detections_dict[key] = [ + Detection(remap_func(det.annotation, h), det.confidence) for det in detections + ] + + extracted_data: Dict[str, Any] = dict() + + extracted_data.update( + extract_preop_postop_digit_data( + corrected_detections_dict["numbers"], + *image_size + ) + ) + extracted_data["checkboxes"] = extract_checkboxes( + corrected_detections_dict["checkboxes"], + "preoperative", + *image_size + ) + return extracted_data + + def digitize_intraop_record(image: Image.Image) -> Dict: """Digitizes the intraoperative side of a paper anesthesia record. @@ -128,13 +505,40 @@ def digitize_intraop_record(image: Image.Image) -> Dict: A dictionary containing all the data from the intraoperative side of the paper anesthesia record. 
""" + landmark_tile_size: int = compute_tile_size( + MODEL_CONFIG["intraoperative_document_landmarks"], + image.size + ) + uncorrected_document_landmark_detections: List[Detection] = detect_objects_using_tiling( + image, + INTRAOP_DOC_MODEL, + landmark_tile_size, + landmark_tile_size, + MODEL_CONFIG["intraoperative_document_landmarks"]["horz_overlap_proportion"], + MODEL_CONFIG["intraoperative_document_landmarks"]["vert_overlap_proportion"], + ) image: Image.Image = homography_intraoperative_chart( - image, make_document_landmark_detections(image, "intraop") + image, + uncorrected_document_landmark_detections, + ) + document_landmark_detections: List[Detection] = detect_objects_using_tiling( + image, + INTRAOP_DOC_MODEL, + landmark_tile_size, + landmark_tile_size, + MODEL_CONFIG["intraoperative_document_landmarks"]["horz_overlap_proportion"], + MODEL_CONFIG["intraoperative_document_landmarks"]["vert_overlap_proportion"], ) - document_landmark_detections: List[Detection] = make_document_landmark_detections( - image, "intraop" + + digit_tile_size: int = compute_tile_size(MODEL_CONFIG["numbers"], image.size) + digit_detections: List[Detection] = detect_objects_using_tiling( + image, + NUMBERS_MODEL, + digit_tile_size, + digit_tile_size, + MODEL_CONFIG["numbers"]["horz_overlap_proportion"], + MODEL_CONFIG["numbers"]["vert_overlap_proportion"], ) - digit_detections: List[Detection] = make_digit_detections(image) # extract drug code and surgical timing codes: Dict = {"codes": extract_drug_codes(digit_detections, *image.size)} @@ -155,6 +559,7 @@ def digitize_intraop_record(image: Image.Image) -> Dict: legend_locations: Dict[str, Tuple[float, float]] = find_legend_locations( time_clusters + mmhg_clusters ) + inhaled_volatile: Dict = { "inhaled_volatile": extract_inhaled_volatile( digit_detections, legend_locations, document_landmark_detections @@ -206,11 +611,28 @@ def digitize_preop_postop_record(image: Image.Image) -> Dict: A dictionary containing all the data from the 
def create_homography_matrix(
    landmark_detections: List[Detection],
    corner_landmark_names: List[str],
    destination_landmarks: List[BoundingBox]
) -> np.ndarray:
    """Creates a homography matrix from the corner landmarks.

    Both point lists are sorted by category name so that the i-th source
    point corresponds to the i-th destination point; this relies on each
    corner landmark name appearing exactly once on each side.

    Args:
        landmark_detections (List[Detection]):
            The list of detected landmarks.
        corner_landmark_names (List[str]):
            The list of names that match categories from the landmark detections.
        destination_landmarks (List[BoundingBox]):
            The landmark locations on the perfect, scanned image.

    Returns:
        A homography matrix that linearly transforms points from the original image to the
        scanned, perfect image.
    """
    dest_points = [
        box.center
        for box in sorted(
            [box for box in destination_landmarks if box.category in corner_landmark_names],
            key=lambda box: box.category,
        )
    ]
    # Detections wrap the bounding box in .annotation (the original code named
    # these `bb`, which was misleading — they are Detection objects).
    src_points = [
        det.annotation.center
        for det in sorted(
            [
                det
                for det in landmark_detections
                if det.annotation.category in corner_landmark_names
            ],
            key=lambda det: det.annotation.category,
        )
    ]
    return find_homography(src_points, dest_points)


def create_intraoperative_homography_matrix(landmark_detections: List[Detection]) -> np.ndarray:
    """Creates a homography matrix for the intraoperative side of the chart.

    Args:
        landmark_detections (List[Detection]):
            The list of detected landmarks.

    Returns:
        A homography matrix that linearly transforms points from the original image to the
        scanned, perfect image.
    """
    corner_landmark_names: List[str] = [
        "anesthesia_start",
        "safety_checklist",
        "lateral",
        "units",
    ]
    dst_landmarks: List[BoundingBox] = label_studio_to_bboxes(
        str(PATH_TO_DATA / "intraop_document_landmarks.json")
    )["unified_intraoperative_preoperative_flowsheet_v1_1_front.png"]
    return create_homography_matrix(landmark_detections, corner_landmark_names, dst_landmarks)


def create_preoperative_postoperative_homography_matrix(
    landmark_detections: List[Detection]
) -> np.ndarray:
    """Creates a homography matrix for the preoperative/postoperative side of the chart.

    Args:
        landmark_detections (List[Detection]):
            The list of detected landmarks.

    Returns:
        A homography matrix that linearly transforms points from the original image to the
        scanned, perfect image.
    """
    corner_landmark_names: List[str] = [
        "patient_profile",
        "weight",
        "signature",
        "disposition",
    ]
    dst_landmarks: List[BoundingBox] = label_studio_to_bboxes(
        str(PATH_TO_DATA / "preoperative_document_landmarks.json")
    )["unified_intraoperative_preoperative_flowsheet_v1_1_back.png"]
    return create_homography_matrix(landmark_detections, corner_landmark_names, dst_landmarks)


def compute_tile_size(model_config: Dict, image_size: Tuple[int, int]) -> int:
    """Finds the tile size for a model based on how its training dataset was generated.

    Args:
        model_config (Dict):
            The model's config dictionary. Must contain a "tile_size_proportion" key.
        image_size (Tuple[int, int]):
            The (width, height) of the image that will be tiled.

    Returns:
        The side length in pixels of the square tiles: the configured proportion
        applied to the smaller image dimension, truncated to an int.
    """
    tile_size_proportion = model_config["tile_size_proportion"]
    tile_size: int = int(
        min(
            image_size[0] * tile_size_proportion,
            image_size[1] * tile_size_proportion,
        )
    )
    return tile_size
MODEL_CONFIG["heart_rate"]["tile_size_proportion"], - ) - ) - - sys_dets: List[Detection] = tile_predict( - SYSTOLIC_MODEL, + sys_tile_size: int = compute_tile_size(MODEL_CONFIG["systolic"], image.size) + dia_tile_size: int = compute_tile_size(MODEL_CONFIG["diastolic"], image.size) + hr_tile_size: int = compute_tile_size(MODEL_CONFIG["heart_rate"], image.size) + + sys_dets: List[Detection] = detect_objects_using_tiling( image.copy(), + SYSTOLIC_MODEL, sys_tile_size, sys_tile_size, MODEL_CONFIG["systolic"]["horz_overlap_proportion"], MODEL_CONFIG["systolic"]["vert_overlap_proportion"], ) - dia_dets: List[Detection] = tile_predict( - DIASTOLIC_MODEL, + dia_dets: List[Detection] = detect_objects_using_tiling( image.copy(), + DIASTOLIC_MODEL, dia_tile_size, dia_tile_size, MODEL_CONFIG["diastolic"]["horz_overlap_proportion"], MODEL_CONFIG["diastolic"]["vert_overlap_proportion"], ) - hr_dets: List[Detection] = tile_predict( - HEART_RATE_MODEL, + hr_dets: List[Detection] = detect_objects_using_tiling( image.copy(), + HEART_RATE_MODEL, hr_tile_size, hr_tile_size, MODEL_CONFIG["heart_rate"]["horz_overlap_proportion"], MODEL_CONFIG["heart_rate"]["vert_overlap_proportion"], ) - sys_dets: List[Detection] = non_maximum_suppression( - sys_dets, - 0.5, - intersection_over_minimum, - lambda det: det.annotation.area * det.confidence, - ) - dia_dets: List[Detection] = non_maximum_suppression( - dia_dets, - 0.5, - intersection_over_minimum, - lambda det: det.annotation.area * det.confidence, - ) - hr_dets: List[Detection] = non_maximum_suppression( - hr_dets, - 0.5, - intersection_over_minimum, - lambda det: det.annotation.area * det.confidence, - ) - dets: List[Detection] = sys_dets + dia_dets + hr_dets bp_and_hr = extract_heart_rate_and_blood_pressure( dets, time_clusters, mmhg_clusters ) - return bp_and_hr -def make_intraop_checkbox_detections( - image: Image.Image, -) -> Dict: +def make_intraop_checkbox_detections(image: Image.Image) -> Dict: """Finds checkboxes on the 
intraoperative form, then associates a meaning to them. Args: @@ -536,21 +915,26 @@ def make_intraop_checkbox_detections( Returns: A dictionary mapping names of checkboxes to a "checked" or "unchecked" state. """ - tile_size: int = int( - min( - image.size[0] * MODEL_CONFIG["checkboxes"]["tile_size_proportion"], - image.size[1] * MODEL_CONFIG["checkboxes"]["tile_size_proportion"], - ) + tile_size = compute_tile_size(MODEL_CONFIG["checkboxes"], image.size) + detections: List[Detection] = detect_objects_using_tiling( + image, + CHECKBOXES_MODEL, + tile_size, + tile_size, + MODEL_CONFIG["checkboxes"]["horz_overlap_proportion"], + MODEL_CONFIG["checkboxes"]["vert_overlap_proportion"], + nms_threshold=0.8 ) intraop_checkboxes = extract_checkboxes( - image, CHECKBOXES_MODEL, "intraoperative", tile_size, tile_size + detections, + "intraoperative", + image.size[0], + image.size[1], ) return intraop_checkboxes -def make_preop_postop_checkbox_detections( - image: Image.Image, -): +def make_preop_postop_checkbox_detections(image: Image.Image): """Finds checkboxes on the intraoperative form, then associates a meaning to them. Args: @@ -560,13 +944,20 @@ def make_preop_postop_checkbox_detections( Returns: A dictionary mapping names of checkboxes to a "checked" or "unchecked" state. 
""" - tile_size: int = int( - min( - image.size[0] * MODEL_CONFIG["checkboxes"]["tile_size_proportion"], - image.size[1] * MODEL_CONFIG["checkboxes"]["tile_size_proportion"], - ) + tile_size = compute_tile_size(MODEL_CONFIG["checkboxes"], image.size) + detections: List[Detection] = detect_objects_using_tiling( + image, + CHECKBOXES_MODEL, + tile_size, + tile_size, + MODEL_CONFIG["checkboxes"]["horz_overlap_proportion"], + MODEL_CONFIG["checkboxes"]["vert_overlap_proportion"], + nms_threshold=0.8 ) preop_postop_checkboxes = extract_checkboxes( - image, CHECKBOXES_MODEL, "preoperative", tile_size, tile_size + detections, + "preoperative", + image.size[0], + image.size[1], ) return preop_postop_checkboxes diff --git a/ChartExtractor/extraction/extraction_utilities.py b/ChartExtractor/extraction/extraction_utilities.py index 4c97eaa..7b95b30 100644 --- a/ChartExtractor/extraction/extraction_utilities.py +++ b/ChartExtractor/extraction/extraction_utilities.py @@ -5,14 +5,14 @@ import json from pathlib import Path from PIL import Image -from typing import Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union # External imports import numpy as np # Internal imports from ..object_detection_models.object_detection_model import ObjectDetectionModel -from ..utilities.annotations import BoundingBox +from ..utilities.annotations import BoundingBox, Keypoint from ..utilities.detections import Detection from ..utilities.detection_reassembly import ( intersection_over_minimum, @@ -98,33 +98,59 @@ def compute_digit_distances_to_centroids( return closest_boxes -def detect_numbers( +def detect_objects_using_tiling( image: Image.Image, detection_model: ObjectDetectionModel, slice_width: int, slice_height: int, horizontal_overlap_ratio: float, vertical_overlap_ratio: float, - conf: float = 0.5, + minimum_confidence: float = 0.5, + nms_threshold: float = 0.5, + overlap_comparator: Callable[[Detection, Detection], float] = 
intersection_over_minimum,
+    sorting_fn: Callable[[Detection], float] = lambda det: det.annotation.area * det.confidence,
 ) -> List[Detection]:
-    """Detects handwritten digits on an image.
+    """Detects objects, especially small ones, using image tiling.
+
+    Splits an image up into smaller tiles, runs the model on each tile, then untiles the detections
+    and performs non-maximum suppression on the result.
 
     Args:
         `image` (Image.Image):
             The image to detect on.
         `detection_model` (ObjectDetectionModel):
-            The digit detection model.
+            The detection model to use. Can be any object that implements the ObjectDetectionModel
+            protocol.
         `slice_height` (int):
             The height of each slice.
         `slice_width` (int):
             The width of each slice.
         `horizontal_overlap_ratio` (float):
             The amount of left-right overlap between slices.
+            (Ex: 0.2 results in 20% of a tile overlapping with the tile on the left and 20%
+            overlapping on the right.)
         `vertical_overlap_ratio` (float):
             The amount of top-bottom overlap between slices.
-
+            (Ex: 0.2 results in 20% of a tile overlapping with the tile on the top and 20%
+            overlapping on the bottom.)
+        `minimum_confidence` (float):
+            The minimum confidence level. Any detection with a confidence score below this will not
+            be added to the returned detections. Defaults to 0.5.
+        `nms_threshold` (float):
+            The threshold above which nms registers a 'match', and deletes all but the first
+            detection in a 'group'. A group is determined by the sorting_fn. Defaults to 0.5.
+        `overlap_comparator` (Callable[[Detection, Detection], float]):
+            The function that determines how much two detections overlap. Defaults to the
+            intersection of the detections divided by the minimum of the two detection's areas.
+            This default prevents partial detections from remaining inside the full detection.
+        `sorting_fn` (Callable[[Detection], float]):
+            The function that applies a 'score' to each detection to determine which has priority
+            when NMS deletes detections. 
Only the detection with the highest score in a group + remains. Defaults to the detection's confidence times its area. + Returns: - A list of handwritten digit detections on the image. + A list of detections showing objects on the image that the object detection model was + trained to identify. """ image_tiles: List[List[Image.Image]] = tile_image( image, @@ -134,7 +160,7 @@ def detect_numbers( vertical_overlap_ratio, ) detections: List[List[List[Detection]]] = [ - [detection_model(pil_to_cv2(tile), confidence=conf)[0] for tile in row] + [detection_model(pil_to_cv2(tile), confidence=minimum_confidence)[0] for tile in row] for row in image_tiles ] detections: List[Detection] = untile_detections( @@ -146,9 +172,9 @@ def detect_numbers( ) detections: List[Detection] = non_maximum_suppression( detections=detections, - threshold=0.5, - overlap_comparator=intersection_over_minimum, - sorting_fn=lambda det: det.annotation.area * det.confidence, + threshold=nms_threshold, + overlap_comparator=overlap_comparator, + sorting_fn=sorting_fn, ) return detections @@ -205,3 +231,57 @@ def label_studio_to_bboxes( ] for sheet_data in json_data } + + +def read_detections_from_json( + filepath: Path, + detection_type: Union[BoundingBox, Keypoint] +) -> List[Detection]: + """Deserializes detections from a json file. + + Args: + filepath (Path): + The filepath to the json detections. + detection_type (Union[BoundingBox, Keypoint]): + The type of detection that has been serialized. + Passed to Detection.from_dict directly. + + Returns: + A list of Detection objects from the encoded data. 
+    """
+    json_data: Dict[str, Any] = json.loads(open(str(filepath), 'r').read())
+    if not isinstance(json_data, list):
+        raise ValueError(f"Data at {filepath} is not a list of detections.")
+
+    detections: List[Detection] = [
+        Detection.from_dict(det_dict, detection_type)
+        for det_dict in json_data
+    ]
+    return detections
+
+
+
+def write_detections_to_json(
+    filepath: Path,
+    detections: List[Detection]
+) -> bool:
+    """Serializes detections to a json file.
+
+    Args:
+        filepath (Path):
+            The filepath to store the json detections at.
+        detections (List[Detection]):
+            The detections to serialize and save.
+
+    Returns:
+        True if the writing was a success, False otherwise.
+    """
+    detections_as_dicts: List[Dict[str, Any]] = [detection.to_dict() for detection in detections]
+    json_data: str = json.dumps(detections_as_dicts)
+    try:
+        with open(str(filepath), 'w') as f:
+            f.write(json_data)
+        return True
+    except Exception as e:
+        print(f"Writing detections to json generated the following error:\n{e}")
+        return False
diff --git a/ChartExtractor/point_registration/__init__.py b/ChartExtractor/point_registration/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/ChartExtractor/point_registration/homography.py b/ChartExtractor/point_registration/homography.py
new file mode 100644
index 0000000..0f61df7
--- /dev/null
+++ b/ChartExtractor/point_registration/homography.py
@@ -0,0 +1,147 @@
+"""Module for remapping points using a homography transform.
+
+This module exposes two functions, (1) find_homography, which is a thin wrapper around opencv's
+findHomography function that restricts the original function's usage to only 2d points, and
+provides more robust error messages for this library's usage, and (2) transform_point, which takes
+a point and a homography matrix and transforms the point.
+ +Functions: + find_homography(source_points: List[Tuple[int, int]], destination_points: List[Tuple[int, int]]) + -> np.ndarray: + Computes the homography transformation that maps the source_points array to the + destination_points array. A thin wrapper around opencv's findHomography function. + transform_point(point: Tuple[int, int], homography_matrix: np.ndarray) -> Tuple[int, int]: + Remaps a single point using the homography matrix. +""" + +# Built-in imports +from typing import List, Tuple + +# Internal imports +from ..utilities.annotations import BoundingBox, Keypoint, Point + +# External imports +from cv2 import findHomography +import numpy as np + + +def find_homography( + source_points: List[Tuple[int, int]], + destination_points: List[Tuple[int, int]], +) -> np.ndarray: + """A thin wrapper around opencv's findHomography function. + + Provides some additional checks and more informative errors. + + Args: + source_points (List[Tuple[int, int]]): + The points to move to match to destination points. + destination_points (List[Tuple[int, int]]): + The points that the source points are moved to match. + + Returns: + A numpy ndarray containing the homography matrix which can be used with transform_point + to transform points according to the transformation that remaps the source points to the + destination points. + """ + too_few_source_points: bool = len(source_points) < 4 + too_few_destination_points: bool = len(destination_points) < 4 + unequal_point_sets: bool = len(source_points) != len(destination_points) + source_points_not_two_dimensional: bool = set([len(p) for p in source_points]) != {2} + destination_points_not_two_dimensional: bool = set([len(p) for p in destination_points]) != {2} + + if too_few_source_points: + raise ValueError( + f"Too few points in source set (need at least 4, had {len(source_points)})." 
+ ) + if too_few_destination_points: + raise ValueError( + f"Too few points in destination set (need at least 4, had {len(destination_points)})." + ) + if unequal_point_sets: + err_msg: str = "Point sets were unequal in length. " + err_msg += f"(length of source: {len(source_points)}, " + err_msg += f"length of destination: {len(destination_points)})" + raise ValueError(err_msg) + if source_points_not_two_dimensional: + err_msg: str = "Source point set contains non two dimensional points. " + err_msg += f"(Included dimensions: {set([len(p) for p in source_points])})" + raise ValueError(err_msg) + if destination_points_not_two_dimensional: + err_msg: str = "Destination point set contains non two dimensional points. " + err_msg += f"(Included dimensions: {set([len(p) for p in destination_points])})" + raise ValueError(err_msg) + + return findHomography(np.array(source_points), np.array(destination_points))[0] + + +def transform_point(point: Tuple[int, int], homography_matrix: np.ndarray) -> Tuple[int, int]: + """Remaps a single point using the homography matrix. + + Args: + point (Tuple[int, int]): + The point to remap. + homography_matrix (np.ndarray): + A homography matrix. + + Returns: + A point which has been transformed by the homography. + """ + if len(point) != 2: + raise ValueError(f"Point is not two dimensional: {point}.") + + remapped_point = homography_matrix.dot(np.array([point[0], point[1], 1])) + remapped_point /= remapped_point[2] + return (remapped_point[0], remapped_point[1]) + + +def transform_box(box: BoundingBox, homography_matrix: np.ndarray) -> BoundingBox: + """Remaps a BoundingBox using the homography matrix. + + Args: + box (BoundingBox): + The bounding box to remap. + homography_matrix (np.ndarray): + A homography matrix + + Returns: + A BoundingBox that has been transformed by the homography. 
+    """
+    remapped_top_left: Tuple[float, float] = transform_point((box.left, box.top), homography_matrix)
+    remapped_top_right: Tuple[float, float] = transform_point(
+        (box.right, box.top),
+        homography_matrix,
+    )
+    remapped_bottom_left: Tuple[float, float] = transform_point(
+        (box.left, box.bottom),
+        homography_matrix,
+    )
+    remapped_bottom_right: Tuple[float, float] = transform_point(
+        (box.right, box.bottom),
+        homography_matrix,
+    )
+
+    left = min(remapped_top_left[0], remapped_bottom_left[0])
+    top = min(remapped_top_left[1], remapped_top_right[1])
+    right = max(remapped_top_right[0], remapped_bottom_right[0])
+    bottom = max(remapped_bottom_left[1], remapped_bottom_right[1])
+
+    return BoundingBox(box.category, left, top, right, bottom)
+
+
+def transform_keypoint(keypoint: Keypoint, homography_matrix: np.ndarray) -> Keypoint:
+    """Remaps a Keypoint using the homography matrix.
+
+    Args:
+        keypoint (Keypoint):
+            The keypoint to remap.
+        homography_matrix (np.ndarray):
+            A homography matrix.
+
+    Returns:
+        A Keypoint that has been transformed by the homography.
+    """
+    point = (keypoint.keypoint.x, keypoint.keypoint.y)
+    remapped_point = transform_point(point, homography_matrix)
+    remapped_box = transform_box(keypoint.bounding_box, homography_matrix)
+    return Keypoint(Point(*remapped_point), remapped_box, do_keypoint_validation=False)