diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index e1df108c..25d68324 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -457,13 +457,12 @@ def scalars(draw, dtypes, finite=False, **kwds): dtypes should be one of the shared_* dtypes strategies. """ dtype = draw(dtypes) - mM = kwds.pop('mM', None) if dh.is_int_dtype(dtype): - if mM is None: - m, M = dh.dtype_ranges[dtype] - else: - m, M = mM - return draw(integers(m, M)) + m, M = dh.dtype_ranges[dtype] + min_value = kwds.get('min_value', m) + max_value = kwds.get('max_value', M) + + return draw(integers(min_value, max_value)) elif dtype == bool_dtype: return draw(booleans()) elif dtype == float64: @@ -593,20 +592,32 @@ def two_mutual_arrays( @composite -def array_and_py_scalar(draw, dtypes, mM=None, positive=False): +def array_and_py_scalar(draw, dtypes, **kwds): """Draw a pair: (array, scalar) or (scalar, array).""" dtype = draw(sampled_from(dtypes)) - scalar_var = draw(scalars(just(dtype), finite=True, mM=mM)) - if positive: - assume (scalar_var > 0) + # draw the scalar: for float arrays, draw a float or an int + if dtype in dh.real_float_dtypes: + scalar_strategy = sampled_from([xp.int32, dtype]) + else: + scalar_strategy = just(dtype) + scalar_var = draw(scalars(scalar_strategy, finite=True, **kwds)) + # draw the array. + # XXX artificially limit the range of values for floats, otherwise value testing is flaky elements={} if dtype in dh.real_float_dtypes: - elements = {'allow_nan': False, 'allow_infinity': False, - 'min_value': 1.0 / (2<<5), 'max_value': 2<<5} - if positive: - elements = {'min_value': 0} + elements = { + 'allow_nan': False, + 'allow_infinity': False, + 'min_value': kwds.get('min_value', 1.0 / (2<<5)), + 'max_value': kwds.get('max_value', 2<<5) + } + elif dtype in dh.int_dtypes: + elements = { + 'min_value': kwds.get('min_value', None), + 'max_value': kwds.get('max_value', None) + } array_var = draw(arrays(dtype, shape=shapes(min_dims=1), elements=elements)) if draw(booleans()): diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index f074bd8e..1652a9d9 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1067,14 +1067,20 @@ def test_clip(x, data): base_shape=x.shape), label="min.shape, max.shape") + # for min,max being scalars: clip(float_array, min=int_scalar, max=int_scalar) + if x.dtype in dh.real_float_dtypes: + scalar_strategy = st.sampled_from([xp.int32, x.dtype]) + else: + scalar_strategy = st.just(x.dtype) + min = data.draw(st.one_of( st.none(), - hh.scalars(dtypes=st.just(x.dtype)), + hh.scalars(dtypes=scalar_strategy), hh.arrays(dtype=st.just(x.dtype), shape=shape1), ), label="min") max = data.draw(st.one_of( st.none(), - hh.scalars(dtypes=st.just(x.dtype)), + hh.scalars(dtypes=scalar_strategy), hh.arrays(dtype=st.just(x.dtype), shape=shape2), ), label="max") @@ -2246,7 +2252,7 @@ def test_binary_with_scalars_bitwise(func_data, x1x2): ], ids=lambda func_data: func_data[0] # use names for test IDs ) -@given(x1x2=hh.array_and_py_scalar([xp.int32], positive=True, mM=(1, 3))) +@given(x1x2=hh.array_and_py_scalar([xp.int32], min_value=1, max_value=3)) def test_binary_with_scalars_bitwise_shifts(func_data, x1x2): func_name, refimpl, kwargs, expected = func_data # repack the refimpl