Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions sparsity/sparse_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,18 +206,27 @@ def copy(self, *args, deep=True, **kwargs):

def multiply(self, other, axis='columns'):
"""
To multiply row-wise 'other' should be of shape: (self.shape[0], 1)
To multiply col-wise 'other should be of shape: (1, self.shape[1])
Multiply SparseFrame row-wise or column-wise.

Parameters
----------
other: array-like
Vector of numbers to multiply columns/rows by.
axis: int | str
- 1 or 'columns' to multiply column-wise (default)
- 0 or 'index' to multiply row-wise
"""
try:
other = other.toarray()
except AttributeError:
pass

if axis in [0, 'index']:
other = np.asarray(other).reshape(-1, 1)
elif axis in [1, 'columns']:
other = np.asarray(other).reshape(1, -1)
else:
other = np.asarray(other).reshape(-1, 1)
raise ValueError("Axis should be one of 0, 1, 'index', 'columns'.")

data = self.data.multiply(other)
assert data.shape == self.data.shape, \
Expand Down
44 changes: 34 additions & 10 deletions sparsity/test/test_sparse_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,20 +778,28 @@ def test_multiply_rowwise():
other = np.arange(5)
msg = "Row wise multiplication failed"

# nd.array
other = other.reshape(1, -1)
# list
res = sf.multiply(list(other), axis=0)
assert np.all(res.sum(axis=1).T == 5 * other), msg

# 1D array
res = sf.multiply(other, axis=0)
assert np.all(res.sum(axis=0) == 5 * other), msg
assert np.all(res.sum(axis=1).T == 5 * other), msg

# 2D array
_other = other.reshape(-1, 1)
res = sf.multiply(_other, axis=0)
assert np.all(res.sum(axis=1).T == 5 * other), msg

# SparseFrame
_other = SparseFrame(other)
res = sf.multiply(_other, axis=0)
assert np.all(res.sum(axis=0) == 5 * other), msg
assert np.all(res.sum(axis=1).T == 5 * other), msg

# csr_matrix
_other = _other.data
res = sf.multiply(_other, axis=0)
assert np.all(res.sum(axis=0) == 5 * other), msg
assert np.all(res.sum(axis=1).T == 5 * other), msg


def test_multiply_colwise():
Expand All @@ -800,21 +808,37 @@ def test_multiply_colwise():
other = np.arange(5)
msg = "Column wise multiplication failed"

# nd.array
other = other.reshape(-1, 1)
# list
res = sf.multiply(list(other), axis=1)
assert np.all(res.sum(axis=0) == 5 * other), msg

# 1D array
res = sf.multiply(other, axis=1)
assert np.all(res.sum(axis=1) == 5 * other), msg
assert np.all(res.sum(axis=0) == 5 * other), msg

# 2D array
_other = other.reshape(1, -1)
res = sf.multiply(_other, axis=1)
assert np.all(res.sum(axis=0) == 5 * other), msg

# SparseFrame
_other = SparseFrame(other)
res = sf.multiply(_other, axis=1)
assert np.all(res.sum(axis=1) == 5 * other), msg
assert np.all(res.sum(axis=0) == 5 * other), msg

# csr_matrix
_other = _other.data
_other.toarray()
res = sf.multiply(_other, axis=1)
assert np.all(res.sum(axis=1) == 5 * other), msg
assert np.all(res.sum(axis=0) == 5 * other), msg


def test_multiply_wrong_axis():
sf = SparseFrame(np.ones((5, 5)))
other = np.arange(5)

with pytest.raises(ValueError):
sf.multiply(other, axis=2)


def test_drop_single_label():
Expand Down