From ea6a79e2935c301b9c62f6a07723b37f1af73b71 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 17 Nov 2025 18:28:39 +0100 Subject: [PATCH 1/3] ENH: searchsorted: allow python scalars for `x2` cross-ref https://github.com/data-apis/array-api/pull/982 --- array_api_tests/hypothesis_helpers.py | 7 +++- array_api_tests/test_searching_functions.py | 46 +++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index e1df108c..cef157b6 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -454,9 +454,12 @@ def scalars(draw, dtypes, finite=False, **kwds): """ Strategy to generate a scalar that matches a dtype strategy - dtypes should be one of the shared_* dtypes strategies. + dtypes should be one of the shared_* dtypes strategies or a sequence of dtypes. """ - dtype = draw(dtypes) + if isinstance(dtypes, Sequence): + dtype = draw(sampled_from(dtypes)) + else: + dtype = draw(dtypes) mM = kwds.pop('mM', None) if dh.is_int_dtype(dtype): if mM is None: diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 26a348d1..4db1a5ac 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -291,3 +291,49 @@ def test_searchsorted(data): except Exception as exc: ph.add_note(exc, repro_snippet) raise + + +### @pytest.mark.min_version("2025.12") +@given(data=st.data()) +def test_searchsorted_with_scalars(data): + # 1. draw x1, sorter and side exactly the same as in test_searchsorted + x1_dtype = data.draw(st.sampled_from(dh.real_dtypes)) + _x1 = data.draw( + st.lists( + xps.from_dtype(x1_dtype, allow_nan=False, allow_infinity=False), + min_size=1, + unique=True + ), + label="_x1", + ) + x1 = xp.asarray(_x1, dtype=x1_dtype) + if data.draw(st.booleans(), label="use sorter?"): + sorter = xp.argsort(x1) + else: + sorter = None + x1 = xp.sort(x1) + + kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"]))) + + # 2. draw x2, a real-valued scalar + # TODO: draw x2 of promotion compatible dtype (int for float x1 etc) -- cf gh-364 + x2 = data.draw(hh.scalars(st.just(x1.dtype), finite=True)) + + # 3. testing: similar to test_searchsorted, modulo `out.shape == ()` + repro_snippet = ph.format_snippet( + f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r}, **kw) with {kw = }" + ) + try: + out = xp.searchsorted(x1, x2, sorter=sorter, **kw) + + ph.assert_dtype( + "searchsorted", + in_dtype=[x1.dtype], #, x2.dtype + out_dtype=out.dtype, + expected=xp.__array_namespace_info__().default_dtypes()["indexing"], + ) + # TODO: values testing + ph.assert_shape("searchsorted", out_shape=out.shape, expected=()) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise From 94861f74d66ee17d7db5438c155260f3815cc936 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 8 Jan 2026 19:52:03 +0100 Subject: [PATCH 2/3] SQUASH?: restore hh.scalars strategy --- array_api_tests/hypothesis_helpers.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index cef157b6..e1df108c 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -454,12 +454,9 @@ def scalars(draw, dtypes, finite=False, **kwds): """ Strategy to generate a scalar that matches a dtype strategy - dtypes should be one of the shared_* dtypes strategies or a sequence of dtypes. + dtypes should be one of the shared_* dtypes strategies. """ - if isinstance(dtypes, Sequence): - dtype = draw(sampled_from(dtypes)) - else: - dtype = draw(dtypes) + dtype = draw(dtypes) mM = kwds.pop('mM', None) if dh.is_int_dtype(dtype): if mM is None: From 390b7950c46296c426fdb2ee6f2fcf3773b2a226 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 19 Jan 2026 10:17:16 +0100 Subject: [PATCH 3/3] ENH: draw int scalars too for x2 when x1 is a float array --- array_api_tests/test_searching_functions.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 4db1a5ac..b83ff964 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -316,8 +316,9 @@ def test_searchsorted_with_scalars(data): kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"]))) # 2. draw x2, a real-valued scalar - # TODO: draw x2 of promotion compatible dtype (int for float x1 etc) -- cf gh-364 - x2 = data.draw(hh.scalars(st.just(x1.dtype), finite=True)) + # NB: for a float-dtype x1 array, draw python ints or floats + x2 = data.draw(hh.scalars(st.sampled_from([x1.dtype, xp.int32]), finite=True)) + # 3. testing: similar to test_searchsorted, modulo `out.shape == ()` repro_snippet = ph.format_snippet(