Add user-defined types #177

Merged
merged 19 commits on Apr 10, 2022

6 changes: 0 additions & 6 deletions README.md
@@ -180,9 +180,3 @@ m = gb.io.to_scipy_sparse_matrix(m, format='csr')
A = gb.io.from_networkx(g)
g = gb.io.to_networkx(A)
```

## Attribution
This library borrows some great ideas from [pygraphblas](https://github.com/michelp/pygraphblas),
especially around parsing operator names from SuiteSparse and the concept of a Scalar which the backend
implementation doesn't need to know about.

8 changes: 0 additions & 8 deletions docs/index.rst
@@ -219,14 +219,6 @@ Import/Export connectors to the Python ecosystem
A = gb.io.from_networkx(g)
g = gb.io.to_networkx(A)

Attribution
-----------

This library borrows some great ideas from `pygraphblas <https://github.com/michelp/pygraphblas>`_,
especially around parsing operator names from SuiteSparse and the concept of a Scalar which the backend
implementation doesn't need to know about.


Indices and tables
==================

44 changes: 29 additions & 15 deletions grblas/_agg.py
@@ -4,8 +4,8 @@
import numpy as np

from . import agg, binary, monoid, semiring, unary
from .dtypes import lookup_dtype, unify
from .operator import _normalize_type
from .dtypes import INT64, lookup_dtype
from .operator import get_typed_op
from .scalar import Scalar


@@ -14,8 +14,8 @@ def _get_types(ops, initdtype):
if initdtype is None:
prev = dict(ops[0].types)
else:
initdtype = lookup_dtype(initdtype)
prev = {key: unify(lookup_dtype(val), initdtype).name for key, val in ops[0].types.items()}
op = ops[0]
prev = {key: get_typed_op(op, key, initdtype).return_type for key in op.types}
for op in ops[1:]:
cur = {}
types = op.types
@@ -43,11 +43,12 @@ def __init__(
composite=None,
custom=None,
types=None,
any_dtype=None,
):
self.name = name
self._initval_orig = initval
self._initval = False if initval is None else initval
self._initdtype = lookup_dtype(type(self._initval))
self._initdtype = lookup_dtype(type(self._initval), self._initval)
self._monoid = monoid
self._semiring = semiring
self._semiring2 = semiring2
@@ -68,6 +69,7 @@ def __init__(
self._types_orig = types
self._types = None
self._typed_ops = {}
self._any_dtype = any_dtype

@property
def types(self):
@@ -82,16 +84,16 @@ def types(self):
return self._types

def __getitem__(self, dtype):
dtype = _normalize_type(dtype)
if dtype not in self.types:
dtype = lookup_dtype(dtype)
if not self._any_dtype and dtype not in self.types:
raise KeyError(f"{self.name} does not work with {dtype}")
if dtype not in self._typed_ops:
self._typed_ops[dtype] = TypedAggregator(self, dtype)
return self._typed_ops[dtype]

def __contains__(self, dtype):
dtype = _normalize_type(dtype)
return dtype in self.types
dtype = lookup_dtype(dtype)
return self._any_dtype or dtype in self.types

def __repr__(self):
return f"agg.{self.name}"
@@ -107,7 +109,12 @@ def __init__(self, agg, dtype):
self.name = agg.name
self.parent = agg
self.type = dtype
self.return_type = agg.types[dtype]
if dtype in agg.types:
self.return_type = agg.types[dtype]
elif agg._any_dtype is True:
self.return_type = dtype
else:
self.return_type = agg._any_dtype

def __repr__(self):
return f"agg.{self.name}[{self.type}]"
@@ -160,8 +167,7 @@ def _new(self, updater, expr, *, in_composite=False):
if agg._custom is not None:
return agg._custom(self, updater, expr, in_composite=in_composite)

dtype = unify(lookup_dtype(self.type), lookup_dtype(agg._initdtype))
semiring = agg._semiring[dtype]
semiring = get_typed_op(agg._semiring, self.type, agg._initdtype)
if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator":
# Matrix -> Vector
A = expr.args[0]
@@ -242,13 +248,15 @@ def __reduce__(self):
agg.any = Aggregator("any", monoid=monoid.lor)
agg.min = Aggregator("min", monoid=monoid.min)
agg.max = Aggregator("max", monoid=monoid.max)
agg.any_value = Aggregator("any_value", monoid=monoid.any)
agg.any_value = Aggregator("any_value", monoid=monoid.any, any_dtype=True)
agg.bitwise_all = Aggregator("bitwise_all", monoid=monoid.band)
agg.bitwise_any = Aggregator("bitwise_any", monoid=monoid.bor)
# Other monoids: bxnor bxor eq lxnor lxor

# Semiring-only
agg.count = Aggregator("count", semiring=semiring.plus_pair, semiring2=semiring.plus_first)
agg.count = Aggregator(
"count", semiring=semiring.plus_pair, semiring2=semiring.plus_first, any_dtype=INT64
)
agg.count_nonzero = Aggregator(
"count_nonzero", semiring=semiring.plus_isne, semiring2=semiring.plus_first
)
Expand All @@ -264,7 +272,9 @@ def __reduce__(self):
semiring=semiring.plus_pow,
semiring2=semiring.plus_first,
)
agg.exists = Aggregator("exists", semiring=semiring.any_pair, semiring2=semiring.any_pair)
agg.exists = Aggregator(
"exists", semiring=semiring.any_pair, semiring2=semiring.any_pair, any_dtype=INT64
)

# Semiring and finalize
agg.hypot = Aggregator(
@@ -564,11 +574,13 @@ def _first_last(agg, updater, expr, *, in_composite, semiring_):
"first",
custom=partial(_first_last, semiring_=semiring.min_secondi),
types=[binary.first],
any_dtype=True,
)
agg.last = Aggregator(
"last",
custom=partial(_first_last, semiring_=semiring.max_secondi),
types=[binary.second],
any_dtype=True,
)


@@ -601,9 +613,11 @@ def _first_last_index(agg, updater, expr, *, in_composite, semiring):
"first_index",
custom=partial(_first_last_index, semiring=semiring.min_secondi),
types=[semiring.min_secondi],
any_dtype=INT64,
)
agg.last_index = Aggregator(
"last_index",
custom=partial(_first_last_index, semiring=semiring.max_secondi),
types=[semiring.min_secondi],
any_dtype=INT64,
)
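
The `any_dtype` flag above controls how a `TypedAggregator` resolves its return type. A minimal standalone sketch of that branch (the helper name is illustrative, not part of the library):

```python
# Stand-in for the return-type branch added to TypedAggregator.__init__ above.
def resolve_return_type(agg_types, any_dtype, dtype):
    if dtype in agg_types:
        return agg_types[dtype]  # dtype has a precomputed entry in the type table
    if any_dtype is True:
        return dtype             # e.g. any_value, first, last: pass the input dtype through
    return any_dtype             # e.g. count, exists: fixed INT64, even for user-defined types

assert resolve_return_type({}, "INT64", "MyUDT") == "INT64"  # count-like aggregator
assert resolve_return_type({}, True, "MyUDT") == "MyUDT"     # any_value-like aggregator
```
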
32 changes: 18 additions & 14 deletions grblas/_ss/matrix.py
@@ -786,7 +786,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
parent = self._parent
else:
parent = self._parent.dup(name=f"M_{method}")
dtype = np.dtype(parent.dtype.np_type)
dtype = parent.dtype.np_type
index_dtype = np.dtype(np.uint64)

nrows = parent._nrows
@@ -847,25 +847,27 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
parent.gb_obj = ffi.NULL
else:
parent.clear()
return rv
elif format == "coor":
info = self._export(
rv = self._export(
"csr", sort=sort, give_ownership=give_ownership, raw=False, method=method
)
info["rows"] = indptr_to_indices(info.pop("indptr"))
info["cols"] = info.pop("col_indices")
info["sorted_rows"] = True
info["format"] = "coor"
return info
rv["rows"] = indptr_to_indices(rv.pop("indptr"))
rv["cols"] = rv.pop("col_indices")
rv["sorted_rows"] = True
rv["format"] = "coor"
elif format == "cooc":
info = self._export(
rv = self._export(
"csc", sort=sort, give_ownership=give_ownership, raw=False, method=method
)
info["cols"] = indptr_to_indices(info.pop("indptr"))
info["rows"] = info.pop("row_indices")
info["sorted_cols"] = True
info["format"] = "cooc"
return info
rv["cols"] = indptr_to_indices(rv.pop("indptr"))
rv["rows"] = rv.pop("row_indices")
rv["sorted_cols"] = True
rv["format"] = "cooc"
else:
raise ValueError(f"Invalid format: {format}")
if parent.dtype._is_udt:
rv["dtype"] = parent.dtype
return rv

if method == "export":
mhandle = ffi_new("GrB_Matrix*", parent._carg)
@@ -1175,6 +1177,8 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
rv["values"] = values
if method == "export":
parent.gb_obj = ffi.NULL
if parent.dtype._is_udt:
rv["dtype"] = parent.dtype
return rv

@classmethod
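
The UDT branch added to `_export` means the export dict now records the user-defined dtype, so an import can reconstruct it. A hedged usage sketch, assuming the dtype-registration helper (`dtypes.register_anonymous`) and the `ss.export`/`ss.import_any` methods behave as in the library's public API; none of these calls appear in this diff:

```python
import numpy as np
import grblas as gb

# Assumed helper for registering a user-defined type (hypothetical usage).
udt = gb.dtypes.register_anonymous(np.dtype([("x", np.float64), ("y", np.float64)]))

A = gb.Matrix.new(udt, nrows=2, ncols=2)
info = A.ss.export("csr")
assert info["dtype"] == udt          # new: UDT exports carry their dtype
B = gb.Matrix.ss.import_any(**info)  # round-trips with the user-defined type intact
assert B.dtype == udt
```
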
4 changes: 3 additions & 1 deletion grblas/_ss/vector.py
@@ -362,7 +362,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
parent = self._parent
else:
parent = self._parent.dup(name=f"v_{method}")
dtype = np.dtype(parent.dtype.np_type)
dtype = parent.dtype.np_type
index_dtype = np.dtype(np.uint64)

if format is None:
@@ -481,6 +481,8 @@ def _export(
)
if method == "export":
parent.gb_obj = ffi.NULL
if parent.dtype._is_udt:
rv["dtype"] = parent.dtype
return rv

@classmethod
3 changes: 1 addition & 2 deletions grblas/base.py
@@ -318,7 +318,6 @@ def update(self, expr):
return self._update(expr)

def _update(self, expr, mask=None, accum=None, replace=False, input_mask=None):
# TODO: check expected output type (now included in Expression object)
if not isinstance(expr, BaseExpression):
if isinstance(expr, AmbiguousAssignOrExtract):
if expr._is_scalar and self._is_scalar:
@@ -557,7 +556,7 @@ def __init__(
raise ValueError(f"No default expr_repr for len(args) == {len(args)}")
self.expr_repr = expr_repr
if dtype is None:
self.dtype = lookup_dtype(op.return_type)
self.dtype = op.return_type
else:
self.dtype = dtype
self._value = None
1 change: 1 addition & 0 deletions grblas/binary/numpy.py
@@ -89,6 +89,7 @@
"subtract": "minus",
"true_divide": "truediv",
}
# _graphblas_to_numpy = {val: key for key, val in _numpy_to_graphblas.items()} # Soon...
# Not included: maximum, minimum, gcd, hypot, logaddexp, logaddexp2
# lcm, left_shift, nextafter, right_shift
