Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 0 additions & 24 deletions .claude/settings.local.json

This file was deleted.

1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,4 @@ customers/
*.ipynb

.idea/
.claude/
132 changes: 1 addition & 131 deletions src/rapidata/rapidata_client/order/_rapidata_order_builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal, Optional, Sequence, get_args
import random
import secrets

from rapidata.rapidata_client.datapoints._datapoint import Datapoint
Expand Down Expand Up @@ -30,11 +29,7 @@
from rapidata.rapidata_client.referee._naive_referee import NaiveReferee
from rapidata.rapidata_client.selection._base_selection import RapidataSelection
from rapidata.rapidata_client.settings import RapidataSetting
from rapidata.rapidata_client.workflow import (
Workflow,
CompareWorkflow,
ClassifyWorkflow,
)
from rapidata.rapidata_client.workflow import Workflow
from rapidata.rapidata_client.selection import (
ConditionalValidationSelection,
LabelingSelection,
Expand Down Expand Up @@ -75,7 +70,6 @@ def __init__(
self._priority: int | None = None
self._datapoints: list[Datapoint] = []
self._sticky_state_value: StickyStateLiteral | None = None
self._temporary_sticky_enabled: bool = False
self._validation_set_manager: ValidationSetManager = ValidationSetManager(
self._openapi_service
)
Expand All @@ -98,11 +92,6 @@ def _to_model(self) -> CreateOrderModel:
self._referee = NaiveReferee()

sticky_state = self._sticky_state_value
if not sticky_state and self._temporary_sticky_enabled:
sticky_state = "Temporary"
logger.debug(
"Setting sticky state to Temporary due to temporary sticky enabled."
)

validation_set_id = (
self._validation_set.id
Expand Down Expand Up @@ -137,112 +126,6 @@ def _generate_id(self, length=9):
alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
return "".join(secrets.choice(alphabet) for _ in range(length))

def _attach_validation_set_id(self) -> None:
"""
Sets the validation set for the order.
"""
assert self._workflow is not None

required_amount = min(int(len(self._datapoints) * 0.01) or 3, 15)

if self._validation_set is None:
try:
with suppress_rapidata_error_logging():
val_set_id = (
(
self._openapi_service.validation_api.validation_set_recommended_get(
asset_type=[self._datapoints[0].get_asset_type()],
modality=[self._workflow.modality],
instruction=self._workflow._get_instruction(),
prompt_type=[
t.value
for t in self._datapoints[0].get_prompt_type()
],
)
)
.validation_sets[0]
.id
)
self._validation_set = (
self._validation_set_manager.get_validation_set_by_id(
val_set_id
)
)

except Exception as e:
logger.debug("No recommended validation set found, error: %s", e)

sufficient_rapids_count = False
if self._validation_set is not None:
sufficient_rapids_count = (
self._validation_set_manager._get_total_and_labeled_rapids_count(
self._validation_set.id
)[1]
>= required_amount
)

if self._validation_set is None or not sufficient_rapids_count:
if (
len(self._datapoints)
< rapidata_config.order.minOrderDatapointsForValidation
):
logger.debug(
"No recommended validation set found, dataset too small to create one."
)
return

logger.info("No recommended validation set found, creating new one.")
managed_print()
managed_print(
f"No recommended validation set found, new one will be created."
)

new_dimension = self._generate_id()
logger.debug("New dimension created: %s", new_dimension)
rng = random.Random(42)
self._validation_set = (
self._validation_set_manager._create_order_validation_set(
workflow=self._workflow,
name=self._name,
datapoints=rng.sample(
self._datapoints, len(self._datapoints)
), # shuffle the datapoints with a specific seed
required_amount=required_amount,
settings=self._settings,
dimensions=[new_dimension],
)
)

self._validation_set.update_should_alert(False)
self._validation_set.update_can_be_flagged(False)

logger.debug("New validation set created for order: %s", self._validation_set)

self._selections = [
CappedSelection(
selections=[
ConditionalValidationSelection(
validation_set_id=self._validation_set.id,
dimensions=self._validation_set.dimensions,
thresholds=[0, 0.5, 0.7],
chances=[1, 1, 0.2],
rapid_counts=[10, 1, 1],
),
LabelingSelection(amount=1),
],
max_rapids=3,
)
]

for dimension in self._validation_set.dimensions:
self._user_filters.append(
UserScoreFilter(
lower_bound=0.3,
upper_bound=1,
dimension=dimension,
)
)

def _create(self) -> RapidataOrder:
"""
Create the Rapidata order by making the necessary API calls based on the builder's configuration.
Expand All @@ -254,22 +137,9 @@ def _create(self) -> RapidataOrder:
Returns:
RapidataOrder: The created RapidataOrder instance.
"""
if (
rapidata_config.order.autoValidationSetCreation
and isinstance(self._workflow, (CompareWorkflow, ClassifyWorkflow))
and not self._selections
and rapidata_config.enableBetaFeatures
):
self._attach_validation_set_id()
self._temporary_sticky_enabled = True
logger.debug("Temporary sticky enabled for order creation.")

order_model = self._to_model()
logger.debug("Creating order with model: %s", order_model)

self._temporary_sticky_enabled = False
logger.debug("Disabling temporary sticky after order creation.")

result = self._openapi_service.order_api.order_post(
create_order_model=order_model
)
Expand Down
104 changes: 0 additions & 104 deletions src/rapidata/rapidata_client/validation/validation_set_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,110 +630,6 @@ def _submit(

return validation_set

def _create_order_validation_set(
self,
workflow: Workflow,
name: str,
datapoints: list[Datapoint],
required_amount: int,
settings: Sequence[RapidataSetting] | None = None,
dimensions: list[str] = [],
) -> RapidataValidationSet:
with tracer.start_as_current_span(
"ValidationSetManager._create_order_validation_set"
):
rapids: list[Rapid] = []
for datapoint in workflow._format_datapoints(datapoints):
rapids.append(
Rapid(
asset=datapoint.asset,
payload=workflow._to_payload(datapoint),
context=datapoint.context,
media_context=datapoint.media_context,
data_type=datapoint.data_type,
settings=settings,
)
)
validation_set = RapidataValidationSet(
validation_set_id=self._openapi_service.validation_api.validation_set_post(
create_validation_set_model=CreateValidationSetModel(name=name)
).validation_set_id,
name=name,
dimensions=dimensions,
openapi_service=self._openapi_service,
)

managed_print()
managed_print(
Fore.YELLOW
+ f"A new validation set was created. Please annotate {required_amount} datapoint{('s' if required_amount != 1 else '')} before the order can run."
+ Fore.RESET
)

link = f"https://app.{self._openapi_service.environment}/validation-set/detail/{validation_set.id}/annotate?maxSize={len(datapoints)}&required={required_amount}"
could_open_browser = webbrowser.open(link)
if not could_open_browser:
encoded_url = urllib.parse.quote(link, safe="%/:=&?~#+!$,;'@()*[]")
managed_print(
Fore.RED
+ f"Please open this URL in your browser to annotate the validation set: '{encoded_url}'"
+ Fore.RESET
)
else:
managed_print(
Fore.YELLOW
+ f"Please annotate the validation set. \n'{link}'"
+ Fore.RESET
)

with tracer.start_as_current_span("Annotating validation set"):
progress_bar = tqdm(
total=required_amount,
desc="Annotate the validation set",
disable=rapidata_config.logging.silent_mode,
)

rapid_index = 0
while True:
total_rapids, labeled_rapids = (
self._get_total_and_labeled_rapids_count(validation_set.id)
)

progress_bar.n = labeled_rapids
progress_bar.refresh()

if labeled_rapids >= required_amount:
break

if total_rapids < required_amount and rapid_index >= len(rapids):
managed_print(
Fore.RED
+ f"""Warning: An order can only be started with at least {required_amount} annotated validation tasks. But only {labeled_rapids}/{required_amount} were annotated.
Either add clearer examples or turn off the 'autoValidationSetCreation' with:

from rapidata import rapidata_config
rapidata_config.order.autoValidationSetCreation = False"""
+ Fore.RESET
)
raise RuntimeError(
f"Not enough rapids annotated. Required: {required_amount}, Annotated: {labeled_rapids}"
)

if (
rapid_index < len(rapids)
and total_rapids - labeled_rapids <= required_amount * 2
):
validation_set.add_rapid(rapids[rapid_index])
rapid_index += 1

time.sleep(2)

progress_bar.close()

validation_set.update_dimensions(dimensions)

return validation_set

def get_validation_set_by_id(self, validation_set_id: str) -> RapidataValidationSet:
"""Get a validation set by ID.

Expand Down