우노
[Spark] Spark SparseMatrix * DenseMatrix 분석 본문
Spark gemm
- Spark의 gemm 함수는 SparseMatrix * DenseMatrix를 지원합니다.
- 즉, Sparse Matrix와 Dense Matirx 간 곱셈을 지원하는 함수입니다.
- github
알고리즘 순서
- 우측행렬 컬럼 기준으로 좌측행렬 컬럼이 이중포문을 이루며 곱해집니다.
- 우측행렬 컬럼은 Dense하게 모든 요소를 곱셈에 사용하지만, 좌측행렬 컬럼은 존재하는 요소만 곱셈에 사용합니다.
- 곱셈 결과값은 아래와 같은 형태로 결과 데이터 배열에 삽입됩니다.
- 결과데이터배열[결과행렬 데이터배열의 시작 index + 좌측행렬 요소값의 행번호] += (좌측행렬 요소값 * 우측 행렬 요소값)
Code
/**
* C := alpha * A * B + beta * C
* For `SparseMatrix` A.
*/
private def gemm(
alpha: Double,
A: SparseMatrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix): Unit = {
val mA: Int = A.numRows
val nB: Int = B.numCols
val kA: Int = A.numCols
val kB: Int = B.numRows
require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB")
require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA")
require(nB == C.numCols,
s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB")
// 좌측 행렬의 데이터 배열
val Avals = A.values
// 우측 행렬의 데이터 배열
val Bvals = B.values
// 결과 행렬의 데이터 배열 (0으로 초기화 되어있음)
val Cvals = C.values
// 좌측 행렬의 rowIndices 배열
val ArowIndices = A.rowIndices
// 좌측 행렬의 colPtrs 배열
val AcolPtrs = A.colPtrs
// 좌측 행렬이 전치된 경우
if (A.isTransposed) {
var colCounterForB = 0
// 우측 행렬이 전치되지 않은 경우
if (!B.isTransposed) {
while (colCounterForB < nB) {
var rowCounterForA = 0
val Cstart = colCounterForB * mA
val Bstart = colCounterForB * kA
while (rowCounterForA < mA) {
var i = AcolPtrs(rowCounterForA)
val indEnd = AcolPtrs(rowCounterForA + 1)
var sum = 0.0
while (i < indEnd) {
sum += Avals(i) * Bvals(Bstart + ArowIndices(i))
i += 1
}
val Cindex = Cstart + rowCounterForA
Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
rowCounterForA += 1
}
colCounterForB += 1
}
}
// 우측 행렬이 전치된 경우
else {
while (colCounterForB < nB) {
var rowCounterForA = 0
val Cstart = colCounterForB * mA
while (rowCounterForA < mA) {
var i = AcolPtrs(rowCounterForA)
val indEnd = AcolPtrs(rowCounterForA + 1)
var sum = 0.0
while (i < indEnd) {
sum += Avals(i) * B(ArowIndices(i), colCounterForB)
i += 1
}
val Cindex = Cstart + rowCounterForA
Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
rowCounterForA += 1
}
colCounterForB += 1
}
}
}
// 좌측 행렬이 전치되지 않은 경우
else {
// 'beta'가 1이 아니라면, 행렬 크기 조정
if (beta != 1.0) {
getBLAS(C.values.length).dscal(C.values.length, beta, C.values, 1)
}
// 우측 행렬의 col index
var colCounterForB = 0
// 우측 행렬이 전치되지 않은 경우
if (!B.isTransposed) {
// 우측 행렬의 col 순서대로 진행
while (colCounterForB < nB) {
// 좌측 행렬의 col index
var colCounterForA = 0
// 우측 행렬에서 어떤 col을 다루고 있느냐에 따라 달라지는, 우측 행렬 데이터 배열의 시작 index
val Bstart = colCounterForB * kB
// 우측 행렬에서 어떤 col을 다루고 있느냐에 따라 달라지는, 결과 데이터 배열의 시작 index
val Cstart = colCounterForB * mA
// 좌측 행렬의 col 순서대로 진행
while (colCounterForA < kA) {
// i, indEnd = 좌측 행렬 각 col에 있는 요소의 행과 값을 알 수 있도록, AcolPtrs을 통해 index 제공
var i = AcolPtrs(colCounterForA)
val indEnd = AcolPtrs(colCounterForA + 1)
// Bval = 우측 행렬의 데이터 배열에서 현재 곱셈에 사용될 값
// Bstart는 고정 된 상태로 colCounterForA가 증가하므로, 우측 행렬의 각 col에 있는 모든 요소에 접근하게 된다.
val Bval = Bvals(Bstart + colCounterForA) * alpha
// i, indEnd = 좌측 행렬 각 col에 있는 요소의 행과 값을 알 수 있도록, AcolPtrs을 통해 index 제공
// 따라서, 좌측 행렬에선 존재하는 요소에만 접근하게 된다.
while (i < indEnd) {
// 결과데이터배열(결과데이터배열의 시작 index + 좌측행렬 요소값의 행번호) += (좌측행렬 요소값 * 우측 행렬 요소값)
Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
i += 1
}
// 좌측 행렬의 col index 증가
colCounterForA += 1
}
// 우측 행렬의 col index 증가
colCounterForB += 1
}
}
// 우측 행렬이 전치된 경우
else {
while (colCounterForB < nB) {
var colCounterForA = 0 // The column of A to multiply with the row of B
val Cstart = colCounterForB * mA
while (colCounterForA < kA) {
var i = AcolPtrs(colCounterForA)
val indEnd = AcolPtrs(colCounterForA + 1)
val Bval = B(colCounterForA, colCounterForB) * alpha
while (i < indEnd) {
Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
}
colCounterForB += 1
}
}
}
}
'Data > Spark' 카테고리의 다른 글
[Spark] spark-shell 에서 HttpClient를 사용해 post request 보내기 (0) | 2021.05.28 |
---|---|
[Spark] Spark 실행 시간 측정 (0) | 2021.03.29 |
[Spark] Breeze CSCMatrix * DenseMatrix 분석 (0) | 2021.03.18 |
[Spark] Breeze CSCMatrix * CSCMatrix 분석 (0) | 2021.03.17 |
[Spark] Spark SparseMatrix를 Breeze CSCMatrix로 변환하는 방법 (0) | 2021.03.04 |
Comments