diff --git a/ChartExtractor/utilities/annotations.py b/ChartExtractor/utilities/annotations.py index cf4ea84..8a49c1b 100644 --- a/ChartExtractor/utilities/annotations.py +++ b/ChartExtractor/utilities/annotations.py @@ -5,6 +5,7 @@ # Built-in Imports from dataclasses import dataclass +import json from typing import Dict, List, Tuple import warnings @@ -29,8 +30,17 @@ def __eq__(self, other): return self.x == other.x and self.y == other.y def __repr__(self): - """Returns a string representation of this Point object.""" + """Returns a string representation of this `Point` object.""" return f"Point({self.x}, {self.y})" + + @classmethod + def from_dict(cls, point_dict: Dict[str, float]): + """Creates a `Point` from a dictionary.""" + return Point(**point_dict) + + def to_dict(self) -> str: + """Returns a json serialized version of the point.""" + return vars(self) @dataclass @@ -53,10 +63,13 @@ class BoundingBox: Constructors : `from_yolo(yolo_line: str, image_width: int, image_height: int, int_to_category: Dict[int, str])`: - Constructs a `BoundingBox` from a line in a YOLO formatted labels file. It requires the original image dimensions and a dictionary mapping category IDs to category names. + Constructs a `BoundingBox` from a line in a YOLO formatted labels file. + It requires the original image dimensions and a dictionary mapping category IDs to + category names. `from_coco(coco_annotation: Dict, categories: List[Dict])`: - Constructs a `BoundingBox` from an annotation in a COCO data JSON file. It requires the annotation dictionary and a list of category dictionaries. + Constructs a `BoundingBox` from an annotation in a COCO data JSON file. + It requires the annotation dictionary and a list of category dictionaries. Properties : @@ -91,7 +104,7 @@ def __init__( self.top = top self.right = right self.bottom = bottom - + @staticmethod def from_yolo( yolo_line: str, @@ -160,7 +173,20 @@ def from_coco(coco_annotation: Dict, categories: List[Dict]): f"Category {int(coco_annotation['category_id'])} not found in the categories list." ) return BoundingBox(category, left, top, right, bottom) + + @staticmethod + def from_dict(bbox_dict: Dict[str, float]): + """Constructs a `BoundingBox` from a dictionary of arguments. + + Args: + `bbox_dict` (Dict[str, float]): + A dictionary containing entries corresponding to the four bounding box sides. + Returns: + A `BoundingBox` object containing the data from the dictionary. + """ + return BoundingBox(**bbox_dict) + @classmethod def validate_box_values( cls, left: float, top: float, right: float, bottom: float @@ -282,6 +308,10 @@ def to_yolo( h = (self.bottom - self.top) / image_height return f"{c} {x:.{precision}f} {y:.{precision}f} {w:.{precision}f} {h:.{precision}f}" + def to_dict(self) -> dict: + """Returns a dictionary with all the attributes of this """ + return vars(self) + @dataclass class Keypoint: @@ -289,14 +319,17 @@ class Keypoint: Attributes : `keypoint` (Tuple[float]): - A tuple containing the (x, y) coordinates of the keypoint relative to the top-left corner of the image. + A tuple containing the (x, y) coordinates of the keypoint relative to the top-left + corner of the image. `bounding_box` (BoundingBox): A `BoundingBox` object that defines the bounding box around the object containing the keypoint. Constructors : `from_yolo(yolo_line: str, image_width: int, image_height: int, id_to_category: Dict[int, str])`: - Constructs a Keypoint from a line in a YOLO formatted labels file. It requires the original image dimensions and a dictionary mapping category IDs to category names. + Constructs a Keypoint from a line in a YOLO formatted labels file. + It requires the original image dimensions and a dictionary mapping category IDs to + category names. **Note:** This method ignores the "visibility" information (denoted by 'v') in the YOLO format. @@ -306,14 +339,18 @@ class Keypoint: `center` (Tuple[float]): The (x, y) coordinates of the bounding box's center (inherited from the `bounding_box`). `box` (Tuple[float]): - A list containing the bounding box coordinates as [left, top, right, bottom] (inherited from the `bounding_box`). + A list containing the bounding box coordinates as [left, top, right, bottom] + (inherited from the `bounding_box`). Methods : `to_yolo(self, image_width: int, image_height: int, category_to_id: Dict[str, int]) -> str`: - Generates a YOLO formatted string representation of this `Keypoint` object. It requires the image dimensions and a dictionary mapping category strings to integer labels. + Generates a YOLO formatted string representation of this `Keypoint` object. + It requires the image dimensions and a dictionary mapping category strings to integer + labels. `validate_keypoint(cls, bounding_box: BoundingBox, keypoint: Point) -> None`: - Validates that a keypoint lies within the specified bounding box. Raises a ValueError if the keypoint is outside the bounding box. + Validates that a keypoint lies within the specified bounding box. + Raises a ValueError if the keypoint is outside the bounding box. """ keypoint: Point @@ -370,6 +407,25 @@ def from_yolo( keypoint = Point(keypoint_x * image_width, keypoint_y * image_height) return Keypoint(keypoint, bounding_box, do_keypoint_validation) + @staticmethod + def from_dict(keypoint_dict: Dict[str, Dict[str, float]], do_validation: bool = True): + """Constructs a `Keypoint` from a dictionary of arguments. + + Args: + `keypoint_dict` (Dict[str, float]): + A dictionary containing entries that are passed in turn to the `BoundingBox` and + `Point` class' `from_dict` methods. + `do_validation` (bool): + A boolean encoding whether or not to do validation to make sure the keypoint is in + the bounding box. + + Returns: + A `Keypoint` object containing the data from the dictionary. + """ + keypoint: Point = Point.from_dict(keypoint_dict["keypoint"]) + bounding_box: BoundingBox = BoundingBox.from_dict(keypoint_dict["bounding_box"]) + return Keypoint(keypoint, bounding_box) + @classmethod def validate_keypoint(cls, bounding_box: BoundingBox, keypoint: Point) -> None: """Validates that a keypoint lies within the specified bounding box. @@ -392,9 +448,9 @@ def validate_keypoint(cls, bounding_box: BoundingBox, keypoint: Point) -> None: in_bounds_y: bool = bounding_box.top <= keypoint.y <= bounding_box.bottom in_bounds: bool = in_bounds_x and in_bounds_y if not in_bounds: - raise ValueError( - f"Keypoint is not in the bounding box intended to enclose it (Keypoint:{(keypoint.x, keypoint.y)}, BoundingBox:{str(bounding_box)})" - ) + err_msg: str = "Keypoint is not in the bounding box intended to enclose it " + err_msg += f"(Keypoint:{(keypoint.x, keypoint.y)}, BoundingBox:{str(bounding_box)})" + raise ValueError(err_msg) @property def category(self) -> str: @@ -601,3 +657,10 @@ def to_yolo( else: yolo_line += f" {keypoint_x:.{precision}f} {keypoint_y:.{precision}f}" return yolo_line + + def to_dict(self) -> dict: + """Converts this keypoint to a dictionary of its variables.""" + return { + "bounding_box": self.bounding_box.to_dict(), + "keypoint": self.keypoint.to_dict(), + } diff --git a/ChartExtractor/utilities/detections.py b/ChartExtractor/utilities/detections.py index ccb2783..63a807f 100644 --- a/ChartExtractor/utilities/detections.py +++ b/ChartExtractor/utilities/detections.py @@ -8,7 +8,7 @@ # Built-in Imports from dataclasses import dataclass -from typing import Union +from typing import Any, Dict, Union # Internal Imports from ..utilities.annotations import BoundingBox, Keypoint @@ -29,3 +29,29 @@ class Detection: annotation: Union[BoundingBox, Keypoint] confidence: float + + @staticmethod + def from_dict(detection_dict: Dict[str, Any], annotation_type: Union[BoundingBox, Keypoint]): + """Creates a `Detection` from a dictionary of data. + + Args: + detection_dict (Dict[str, Any]): + The dictionary with the `Detection` data. + annotation_type (Union[BoundingBox, Keypoint]): + The type of annotation in the + + Returns: + A `Detection` object with the data from the dictionary. + """ + annotation: Union[BoundingBox, Keypoint] = annotation_type.from_dict( + detection_dict["annotation"] + ) + confidence: float = detection_dict["confidence"] + return Detection(annotation, confidence) + + def to_dict(self) -> Dict[str, Any]: + """Converts this detection to a dictionary.""" + return { + "annotation": self.annotation.to_dict(), + "confidence": self.confidence + } diff --git a/tests/unit_tests/test_annotations.py b/tests/unit_tests/test_annotations.py index 7590d80..72b781b 100644 --- a/tests/unit_tests/test_annotations.py +++ b/tests/unit_tests/test_annotations.py @@ -14,6 +14,31 @@ class TestBoundingBox: def test_init(self): """Tests the init function with valid parameters.""" BoundingBox("Test", 0, 0, 1, 1) + + def test_from_dict(self): + """Tests the from_dict constructor.""" + bb_dict = { + "left": 1, + "right": 2, + "top": 3, + "bottom": 4, + "category": "Test" + } + true_bbox = BoundingBox("Test", 1, 3, 2, 4) + assert BoundingBox.from_dict(bb_dict) == true_bbox + + def test_from_dict_fails(self): + """Tests the from_dict constructor when the dictionary contains an erroneous entry.""" + bb_dict = { + "left": 1, + "right": 2, + "top": 3, + "bottom": 4, + "category": "Test", + "other": "thing" + } + with pytest.raises(TypeError): + BoundingBox.from_dict(bb_dict) # from_yolo def test_from_yolo(self): @@ -102,6 +127,18 @@ def test_box(self): """Tests the 'box' property.""" bbox = BoundingBox("Test", 0, 0, 1, 1) assert [0, 0, 1, 1] == bbox.box + + def test_to_dict(self): + """Tests the to_dict method.""" + bbox = BoundingBox("Test", 0, 2, 1, 3) + true_dict = { + "left": 0, + "right": 1, + "top": 2, + "bottom": 3, + "category": "Test" + } + assert bbox.to_dict() == true_dict # to_yolo def test_to_yolo(self): @@ -123,7 +160,27 @@ def test_init(self): kp = Point(0.25, 0.25) bbox = BoundingBox("Test", 0, 0, 1, 1) Keypoint(kp, bbox) - + + def test_from_dict(self): + """Test the from_dict constructor.""" + keypoint_dict = { + "keypoint": { + "x": 0.5, + "y": 2.25, + }, + "bounding_box": { + "left": 0, + "right": 1, + "top": 2, + "bottom": 3, + "category": "Test" + }, + } + true_point = Point(0.5, 2.25) + true_bounding_box = BoundingBox("Test", 0, 2, 1, 3) + true_keypoint = Keypoint(true_point, true_bounding_box) + assert Keypoint.from_dict(keypoint_dict) == true_keypoint + # from_yolo def test_from_yolo(self): """Tests the from_yolo constructor.""" @@ -155,6 +212,27 @@ def test_validate_keypoint_out_of_bounds_y(self): # Below box with pytest.raises(ValueError): Keypoint(Point(1, 4), BoundingBox("Test", 0, 2, 2, 3)) + + def test_to_dict(self): + """Tests the to_dict method.""" + point = Point(0.5, 2.25) + bbox = BoundingBox("Test", 0, 2, 1, 3) + kp = Keypoint(point, bbox) + kp_dict = kp.to_dict() + true_dict = { + "keypoint": { + "x": 0.5, + "y": 2.25, + }, + "bounding_box": { + "left": 0, + "top": 2, + "right": 1, + "bottom": 3, + "category": "Test", + }, + } + assert kp_dict == true_dict # to_yolo def test_to_yolo(self): diff --git a/tests/unit_tests/test_detections.py b/tests/unit_tests/test_detections.py new file mode 100644 index 0000000..a0bced1 --- /dev/null +++ b/tests/unit_tests/test_detections.py @@ -0,0 +1,98 @@ +"""Tests the detections module's Detection class.""" + +# External Imports +import pytest + +# Internal Imports +from ChartExtractor.utilities.annotations import BoundingBox, Keypoint, Point +from ChartExtractor.utilities.detections import Detection + + +class TestDetection: + """Tests the Detection class.""" + + def test_from_dict_bounding_box(self): + """Tests the from_dict constructor with a bounding box.""" + det_dict = { + "annotation": { + "left": 0, + "right": 1, + "top": 2, + "bottom": 3, + "category": "Test", + }, + "confidence": 0.8, + } + true_det = Detection( + annotation=BoundingBox("Test", 0, 2, 1, 3), + confidence=0.8, + ) + assert Detection.from_dict(det_dict, BoundingBox) == true_det + + def test_from_dict_keypoint(self): + """Tests the from_dict constructor with a keypoint.""" + det_dict = { + "annotation": { + "bounding_box": { + "left": 0, + "right": 1, + "top": 2, + "bottom": 3, + "category": "Test", + }, + "keypoint": { + "x": 0.5, + "y": 2.25, + }, + }, + "confidence": 0.8, + } + true_det = Detection( + annotation=Keypoint(Point(0.5, 2.25), BoundingBox("Test", 0, 2, 1, 3)), + confidence=0.8, + ) + assert Detection.from_dict(det_dict, Keypoint) == true_det + + def test_from_dict_bounding_box(self): + """Tests the to_dict method with a bounding box.""" + det = Detection( + annotation=BoundingBox("Test", 0, 2, 1, 3), + confidence=0.8, + ) + true_dict = { + "annotation": { + "left": 0, + "right": 1, + "top": 2, + "bottom": 3, + "category": "Test", + }, + "confidence": 0.8, + } + assert det.to_dict() == true_dict + + def test_to_dict_keypoint(self): + """Tests the to_dict method with a keypoint.""" + det = Detection( + annotation=Keypoint(Point(0.5, 2.25), BoundingBox("Test", 0, 2, 1, 3)), + confidence=0.8, + ) + true_dict = { + "annotation": { + "bounding_box": { + "left": 0, + "right": 1, + "top": 2, + "bottom": 3, + "category": "Test", + }, + "keypoint": { + "x": 0.5, + "y": 2.25, + }, + }, + "confidence": 0.8, + } + + assert det.to_dict() == true_dict +