Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/rapidata/rapidata_client/order/rapidata_order_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def _create_general_order(
responses_per_datapoint: int = 10,
validation_set_id: str | None = None,
confidence_threshold: float | None = None,
quorum_threshold: int | None = None,
filters: Sequence[RapidataFilter] | None = None,
settings: Sequence[RapidataSetting] | None = None,
selections: Sequence[RapidataSelection] | None = None,
Expand All @@ -70,15 +71,28 @@ def _create_general_order(
if selections is None:
selections = []

if not confidence_threshold:
if confidence_threshold is not None and quorum_threshold is not None:
raise ValueError(
"Cannot set both confidence_threshold and quorum_threshold. Choose one stopping strategy."
)

if confidence_threshold is None and quorum_threshold is None:
from rapidata.rapidata_client.referee._naive_referee import NaiveReferee

referee = NaiveReferee(responses=responses_per_datapoint)
elif quorum_threshold is not None:
from rapidata.rapidata_client.referee._quorum_referee import QuorumReferee

referee = QuorumReferee(
threshold=quorum_threshold,
max_votes=responses_per_datapoint,
)
else:
from rapidata.rapidata_client.referee._early_stopping_referee import (
EarlyStoppingReferee,
)

assert confidence_threshold is not None
referee = EarlyStoppingReferee(
threshold=confidence_threshold,
max_responses=responses_per_datapoint,
Expand Down Expand Up @@ -264,6 +278,7 @@ def create_compare_order(
a_b_names: list[str] | None = None,
validation_set_id: str | None = None,
confidence_threshold: float | None = None,
quorum_threshold: int | None = None,
filters: Sequence[RapidataFilter] | None = None,
settings: Sequence[RapidataSetting] | None = None,
selections: Sequence[RapidataSelection] | None = None,
Expand Down Expand Up @@ -299,6 +314,15 @@ def create_compare_order(
If provided, one validation task will be shown in front of the datapoints that will be labeled.
confidence_threshold (float, optional): The probability threshold for the comparison. Defaults to None.\n
If provided, the comparison datapoint will stop after the threshold is reached or at the number of responses, whatever happens first.
quorum_threshold (int, optional): The number of matching responses required to reach quorum. Defaults to None.\n
If provided, the comparison datapoint will stop when this many responses agree or that quorum can't be reached anymore or after responses_per_datapoint votes.
Cannot be used together with confidence_threshold.
Example:
```python
responses_per_datapoint = 10
quorum_threshold = 7
```
This will stop at 7 responses for one side or if both sides have at least 4 responses.
filters (Sequence[RapidataFilter], optional): The list of filters for the comparison. Defaults to []. Decides who the tasks should be shown to.
settings (Sequence[RapidataSetting], optional): The list of settings for the comparison. Defaults to []. Decides how the tasks should be shown.
selections (Sequence[RapidataSelection], optional): The list of selections for the comparison. Defaults to []. Decides in what order the tasks should be shown.
Expand Down Expand Up @@ -356,6 +380,7 @@ def create_compare_order(
responses_per_datapoint=responses_per_datapoint,
validation_set_id=validation_set_id,
confidence_threshold=confidence_threshold,
quorum_threshold=quorum_threshold,
filters=filters,
selections=selections,
settings=settings,
Expand Down
1 change: 1 addition & 0 deletions src/rapidata/rapidata_client/referee/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._base_referee import Referee
from ._naive_referee import NaiveReferee # as MaxVoteReferee
from ._early_stopping_referee import EarlyStoppingReferee
from ._quorum_referee import QuorumReferee
75 changes: 75 additions & 0 deletions src/rapidata/rapidata_client/referee/_quorum_referee.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations

from rapidata.rapidata_client.referee._base_referee import Referee
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from rapidata.api_client.models.i_referee_model import IRefereeModel


class QuorumReferee(Referee):
    """A referee that completes a task when a specified number of responses agree.

    This referee implements a quorum-based approach where a task is completed when:
    1. A minimum number of responses (threshold) agree on the same answer, OR
    2. Quorum becomes mathematically impossible to reach, OR
    3. The maximum number of votes is reached

    For example, with threshold=7 and max_votes=10 on a two-option task:
    - Task completes when 7 responses agree (quorum reached)
    - Task completes when both options have 4 responses (quorum impossible:
      8 votes are cast, only 2 remain, so neither option can exceed 4+2=6 < 7)
    - Task completes after 10 total votes if neither condition is met

    Args:
        threshold (int, optional): The number of matching responses required
            to reach quorum. Defaults to 3.
        max_votes (int, optional): The maximum number of votes allowed
            before stopping. Defaults to 5.

    Attributes:
        threshold (int): The number of matching responses required to reach quorum.
        max_votes (int): The maximum number of votes allowed before stopping.

    Raises:
        ValueError: If either argument is not an integer, if either is smaller
            than 1, or if ``threshold`` exceeds ``max_votes``.
    """

    def __init__(self, threshold: int = 3, max_votes: int = 5):
        # Validate the types first: ordering comparisons on non-numeric input
        # (e.g. a string) would otherwise raise an unrelated TypeError before
        # the intended ValueError could be produced.
        if not isinstance(threshold, int) or not isinstance(max_votes, int):
            raise ValueError(
                "The quorum threshold and responses_per_datapoint must be integers."
            )
        if threshold < 1:
            raise ValueError("The threshold must be greater than 0.")
        if max_votes < 1:
            raise ValueError("The number of max_votes must be greater than 0.")
        if threshold > max_votes:
            raise ValueError("The threshold cannot be greater than max_votes.")

        super().__init__()
        self.threshold = threshold
        self.max_votes = max_votes

    def _to_dict(self) -> dict:
        """Return the referee configuration as a plain dictionary."""
        return {
            "_t": "QuorumRefereeConfig",
            "maxVotes": self.max_votes,
            "threshold": self.threshold,
        }

    def _to_model(self) -> IRefereeModel:
        """Wrap the referee settings in the API client's ``IRefereeModel``."""
        # Imported lazily, mirroring the TYPE_CHECKING-only import at module level.
        from rapidata.api_client.models.i_referee_model_quorum_referee_model import (
            IRefereeModelQuorumRefereeModel,
        )
        from rapidata.api_client.models.i_referee_model import IRefereeModel

        return IRefereeModel(
            actual_instance=IRefereeModelQuorumRefereeModel(
                _t="QuorumReferee",
                maxVotes=self.max_votes,
                threshold=self.threshold,
            )
        )

    def __str__(self) -> str:
        return f"QuorumReferee(threshold={self.threshold}, max_votes={self.max_votes})"

    def __repr__(self) -> str:
        return self.__str__()