Skip to content

Commit 9b2a5aa

Browse files
committed
rename matrix args in BreezeUtil to upper
1 parent 9ab32c2 commit 9b2a5aa

1 file changed

Lines changed: 17 additions & 16 deletions

File tree

mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,38 +26,39 @@ import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}
2626
private[ann] object BreezeUtil {
2727

2828
// TODO: switch to MLlib BLAS interface
29-
private def transposeString(a: BDM[Double]): String = if (a.isTranspose) "T" else "N"
29+
private def transposeString(A: BDM[Double]): String = if (A.isTranspose) "T" else "N"
3030

3131
/**
3232
* DGEMM: C := alpha * A * B + beta * C
3333
* @param alpha alpha
34-
* @param a A
35-
* @param b B
34+
* @param A A
35+
* @param B B
3636
* @param beta beta
37-
* @param c C
37+
* @param C C
3838
*/
39-
def dgemm(alpha: Double, a: BDM[Double], b: BDM[Double], beta: Double, c: BDM[Double]): Unit = {
39+
def dgemm(alpha: Double, A: BDM[Double], B: BDM[Double], beta: Double, C: BDM[Double]): Unit = {
4040
// TODO: add code if matrices isTranspose!!!
41-
require(a.cols == b.rows, "A & B Dimension mismatch!")
42-
require(a.rows == c.rows, "A & C Dimension mismatch!")
43-
require(b.cols == c.cols, "A & C Dimension mismatch!")
44-
NativeBLAS.dgemm(transposeString(a), transposeString(b), c.rows, c.cols, a.cols,
45-
alpha, a.data, a.offset, a.majorStride, b.data, b.offset, b.majorStride,
46-
beta, c.data, c.offset, c.rows)
41+
require(A.cols == B.rows, "A & B Dimension mismatch!")
42+
require(A.rows == C.rows, "A & C Dimension mismatch!")
43+
require(B.cols == C.cols, "A & C Dimension mismatch!")
44+
NativeBLAS.dgemm(transposeString(A), transposeString(B), C.rows, C.cols, A.cols,
45+
alpha, A.data, A.offset, A.majorStride, B.data, B.offset, B.majorStride,
46+
beta, C.data, C.offset, C.rows)
4747
}
4848

4949
/**
5050
* DGEMV: y := alpha * A * x + beta * y
5151
* @param alpha alpha
52-
* @param a A
52+
* @param A A
5353
* @param x x
5454
* @param beta beta
5555
* @param y y
5656
*/
57-
def dgemv(alpha: Double, a: BDM[Double], x: BDV[Double], beta: Double, y: BDV[Double]): Unit = {
58-
require(a.cols == x.length, "A & x Dimension mismatch!")
59-
NativeBLAS.dgemv(transposeString(a), a.rows, a.cols,
60-
alpha, a.data, a.offset, a.majorStride, x.data, x.offset, x.stride,
57+
def dgemv(alpha: Double, A: BDM[Double], x: BDV[Double], beta: Double, y: BDV[Double]): Unit = {
58+
require(A.cols == x.length, "A & x Dimension mismatch!")
59+
require(A.rows == y.length, "A & y Dimension mismatch!")
60+
NativeBLAS.dgemv(transposeString(A), A.rows, A.cols,
61+
alpha, A.data, A.offset, A.majorStride, x.data, x.offset, x.stride,
6162
beta, y.data, y.offset, y.stride)
6263
}
6364
}

0 commit comments

Comments
 (0)