Skip to content

Enables array tests for asarray #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion _unittests/onnx-numpy-skips.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# API failures
# see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt
array_api_tests/test_creation_functions.py::test_arange
array_api_tests/test_creation_functions.py::test_asarray_scalars
array_api_tests/test_creation_functions.py::test_asarray_arrays
array_api_tests/test_creation_functions.py::test_empty
array_api_tests/test_creation_functions.py::test_empty_like
Expand Down
4 changes: 3 additions & 1 deletion _unittests/test_array_api.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_arrays || exit 1
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1
6 changes: 5 additions & 1 deletion _unittests/ut_array_api/test_onnx_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@


class TestOnnxNumpy(ExtTestCase):
def test_abs(self):
def test_empty(self):
c = EagerTensor(np.array([4, 5], dtype=np.int64))
self.assertRaise(lambda: xp.empty(c, dtype=xp.int64), RuntimeError)

def test_zeros(self):
c = EagerTensor(np.array([4, 5], dtype=np.int64))
mat = xp.zeros(c, dtype=xp.int64)
matnp = mat.numpy()
Expand Down
12 changes: 9 additions & 3 deletions _unittests/ut_npx/test_npx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2501,7 +2501,7 @@ def test_numpy_all(self):
got = ref.run(None, {"A": data})
self.assertEqualArray(y, got[0])

def test_numpy_all_empty(self):
def test_numpy_all_zeros(self):
data = np.zeros((0,), dtype=np.bool_)
y = np.all(data)

Expand All @@ -2513,7 +2513,7 @@ def test_numpy_all_empty(self):
self.assertEqualArray(y, got[0])

@unittest.skipIf(True, reason="ReduceMin does not support shape[axis] == 0")
def test_numpy_all_empty_axis_0(self):
def test_numpy_all_zeros_axis_0(self):
data = np.zeros((0, 1), dtype=np.bool_)
y = np.all(data, axis=0)

Expand All @@ -2535,7 +2535,13 @@ def test_numpy_all_empty_axis_1(self):
got = ref.run(None, {"A": data})
self.assertEqualArray(y, got[0])

@unittest.skipIf(True, reason="Fails to follow Array API")
def test_get_item(self):
a = EagerNumpyTensor(np.array([True], dtype=np.bool_))
i = a[0]
self.assertEqualArray(i.numpy(), a.numpy()[0])


if __name__ == "__main__":
# TestNpx().test_numpy_all_empty_axis_0()
# TestNpx().test_get_item()
unittest.main(verbosity=2)
4 changes: 2 additions & 2 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ jobs:
- script: |
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
cd array-api-tests
python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt -v
python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt --hypothesis-explain
displayName: "numpy test_creation_functions.py"
- script: |
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort
cd array-api-tests
python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-ort-skips.txt -v
python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-ort-skips.txt --hypothesis-explain
displayName: "ort test_creation_functions.py"
#- script: |
# export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
Expand Down
10 changes: 9 additions & 1 deletion onnx_array_api/array_api/_onnx_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,15 @@ def template_asarray(
return a.astype(dtype=dtype)

if isinstance(a, int):
v = TEagerTensor(np.array(a, dtype=np.int64))
if a is False:
v = TEagerTensor(np.array(False, dtype=np.bool_))
elif a is True:
v = TEagerTensor(np.array(True, dtype=np.bool_))
else:
try:
v = TEagerTensor(np.asarray(a, dtype=np.int64))
except OverflowError:
v = TEagerTensor(np.asarray(a, dtype=np.uint64))
elif isinstance(a, float):
v = TEagerTensor(np.array(a, dtype=np.float32))
elif isinstance(a, bool):
Expand Down
12 changes: 12 additions & 0 deletions onnx_array_api/array_api/onnx_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"all",
"asarray",
"astype",
"empty",
"equal",
"isdtype",
"isfinite",
Expand Down Expand Up @@ -73,6 +74,17 @@ def ones(
return generic_ones(shape, dtype=dtype, order=order)


def empty(
shape: TensorType[ElemType.int64, "I", (None,)],
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
order: OptParType[str] = "C",
) -> TensorType[ElemType.numerics, "T"]:
raise RuntimeError(
"ONNX assumes there is no inplace implementation. "
"empty function is only used in that case."
)


def zeros(
shape: TensorType[ElemType.int64, "I", (None,)],
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
Expand Down
56 changes: 55 additions & 1 deletion onnx_array_api/npx/npx_numpy_tensors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Callable, List, Optional, Tuple
import numpy as np
from onnx import ModelProto
from onnx import ModelProto, TensorProto
from onnx.reference import ReferenceEvaluator
from .._helpers import np_dtype_to_tensor_dtype
from .npx_numpy_tensors_ops import ConstantOfShape
Expand Down Expand Up @@ -183,6 +183,60 @@ def __array_namespace__(self, api_version: Optional[str] = None):
f"Unable to return an implementation for api_version={api_version!r}."
)

def __bool__(self):
"Implicit conversion to bool."
if self.dtype != DType(TensorProto.BOOL):
raise TypeError(
f"Conversion to bool only works for bool scalar, not for {self!r}."
)
if self.shape == (0,):
return False
if len(self.shape) != 0:
raise ValueError(
f"Conversion to bool only works for scalar, not for {self!r}."
)
return bool(self._tensor)

def __int__(self):
"Implicit conversion to bool."
if len(self.shape) != 0:
raise ValueError(
f"Conversion to bool only works for scalar, not for {self!r}."
)
if self.dtype not in {
DType(TensorProto.INT64),
DType(TensorProto.INT32),
DType(TensorProto.INT16),
DType(TensorProto.INT8),
DType(TensorProto.UINT64),
DType(TensorProto.UINT32),
DType(TensorProto.UINT16),
DType(TensorProto.UINT8),
}:
raise TypeError(
f"Conversion to int only works for int scalar, "
f"not for dtype={self.dtype}."
)
return int(self._tensor)

def __float__(self):
"Implicit conversion to bool."
if len(self.shape) != 0:
raise ValueError(
f"Conversion to bool only works for scalar, not for {self!r}."
)
if self.dtype not in {
DType(TensorProto.FLOAT),
DType(TensorProto.DOUBLE),
DType(TensorProto.FLOAT16),
DType(TensorProto.BFLOAT16),
}:
raise TypeError(
f"Conversion to int only works for float scalar, "
f"not for dtype={self.dtype}."
)
return float(self._tensor)


class JitNumpyTensor(NumpyTensor, JitTensor):
"""
Expand Down
4 changes: 4 additions & 0 deletions onnx_array_api/npx/npx_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,7 @@ def __getitem__(self, index: Any) -> "Var":

if isinstance(index, Var):
# scenario 2
# TODO: fix this when index is an integer
new_shape = cst(np.array([-1], dtype=np.int64))
new_self = self.reshape(new_shape)
new_index = index.reshape(new_shape)
Expand All @@ -973,6 +974,9 @@ def __getitem__(self, index: Any) -> "Var":

if not isinstance(index, tuple):
index = (index,)
elif len(index) == 0:
# The array contains a scalar and it needs to be returned.
return var(self, op="Identity")

# only one integer?
ni = None
Expand Down