우노
[Spark] Breeze CSCMatrix * CSCMatrix 분석 본문
Breeze canMulM_M
- Breeze의 canMulM_M 함수는 CSCMatrix * CSCMatrix를 지원합니다.
- 즉, Sparse Matrix 간 곱셈을 지원하는 함수입니다.
- github
Code
implicit def canMulM_M[@expand.args(Int, Float, Double, Long) T]
: breeze.linalg.operators.OpMulMatrix.Impl2[CSCMatrix[T], CSCMatrix[T], CSCMatrix[T]] =
new breeze.linalg.operators.OpMulMatrix.Impl2[CSCMatrix[T], CSCMatrix[T], CSCMatrix[T]] {
def apply(a: CSCMatrix[T], b: CSCMatrix[T]): CSCMatrix[T] = {
// 양쪽 행렬의 차원이 안 맞을 경우
require(a.cols == b.rows, "Dimension Mismatch")
// 결과 행렬의 row 길이와 동일한 임시 데이터 배열을 생성
// 해당 임시 Data 배열은, 각 계산 마다 결과 행렬에 업데이트 됨
val workData = new Array[T](a.rows)
// workData 배열의 업데이트 여부를 확인하기 위한, 임시 Index 배열 생성
val workIndex = new Array[Int](a.rows)
// workIndex 배열요소를 -1로 초기화
util.Arrays.fill(workIndex, -1)
// 결과 행렬의 NNZ 계산
val totalNnz = computeNnz(a, b, workIndex)
// 결과 행렬의 NNZ 계산이 끝난 뒤, workIndex 배열을 다시 -1 로 초기화
util.Arrays.fill(workIndex, -1)
// 결과 CSCMatrix 생성
val res = CSCMatrix.zeros[T](a.rows, b.cols)
// 결과 CSCMatrix 의 rowIndices, data 배열을 totalNnz 개수만큼 0 으로 초기화 (colPtrs 은 이미 0으로 초기화 되어있음)
res.reserve(totalNnz)
// 결과 행렬의 rowIndices 배열 (totalNnz 개수만큼 0 으로 초기화 되어있음)
val resRows = res.rowIndices
// 결과 행렬의 Data 배열 (totalNnz 개수만큼 0 으로 초기화 되어있음)
val resData = res.data
// 왼쪽 행렬의 rowIndices
val aRows = a.rowIndices
// 왼쪽 행렬의 Data
val aData = a.data
// 왼쪽 행렬의 colPtrs
val aPtrs = a.colPtrs
// col = 우측 행렬의 col index 순서대로 진행
cforRange(0 until b.cols) { col =>
// nnz = 결과 행렬의 NNZ
var nnz = res.used
// bOff = 우측 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, B rowIndices 배열의 index 제공
cforRange(b.colPtrs(col) until b.colPtrs(col + 1)) { bOff =>
// bRow = 우측 행렬의 각 col에 있는 요소의 row index
// bVal = 우측 행렬의 각 col에 있는 요소값
val bRow = b.rowIndices(bOff)
val bVal = b.data(bOff)
// a0ff = 우측 행렬의 각 col 에 있는 요소의 bRow 와 동일한 index 를 가지는 좌측 행렬의 col 을 확인하며,
// 좌측 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, A rowIndices 배열의 index 제공
cforRange(aPtrs(bRow) until aPtrs(bRow + 1)) { aOff =>
// aRow = 좌측 행렬의 각 col에 있는 요소의 row index
// aVal = 좌측 행렬의 각 col에 있는 요소값
val aRow = aRows(aOff)
val aVal = aData(aOff)
// 반복문 시, 우측 행렬의 col index 는 고정된 상태에서 좌측 행렬의 col 이 움직이며 요소를 확인하는데,
// 동일한 우측 행렬 col index 기준으로, 좌측 행렬 col 의 내부 요소가 새로운 row 에서 참조될때만, 조건문 실행
if (workIndex(aRow) < col) {
// 임시 Data 배열의 aRow index 값을 0 으로 초기화
workData(aRow) = 0
// 임시 Index 배열의 aRow index 값을 col index 로 할당
workIndex(aRow) = col
// 결과 행렬의 rowIndices 배열에서 nnz index 값을 aRow 로 할당 (먼저 계산되는 aRow 순서대로 할당 되며, 정렬은 아래에서 진행)
resRows(nnz) = aRow
// nnz 증가
nnz += 1
}
// workData(aRow) += (좌측 행렬의 각 col에 있는 요소 값) * (우측 행렬의 각 col에 있는 요소 값)
workData(aRow) += aVal * bVal
}
}
// 임시 Data 배열이 다 채워지면, 결과 행렬의 colPtrs 배열 요소를 nnz 로 할당
res.colPtrs(col + 1) = nnz
// 결과 행렬의 nnz 값을, 현재까지 계산한 nnz 값으로 수정
res.used = nnz
// 결과 행렬의 rowIndices 배열에서 현재 계산한 rowIndices 부분만 오름차순으로 정렬
util.Arrays.sort(resRows, res.colPtrs(col), res.colPtrs(col + 1))
// resOff = 결과 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, Result rowIndices 배열의 index 제공
cforRange(res.colPtrs(col) until res.colPtrs(col + 1)) { resOff =>
// row = 결과 행렬의 각 col에 있는 요소의 row index
val row = resRows(resOff)
// 임시 데이터 배열의 row에 위치한 값을, 결과 행렬의 데이터 배열에서 해당하는 위치에 할당
resData(resOff) = workData(row)
}
// 조건을 만족하지 못하면 에러 발생
assert(nnz <= totalNnz)
}
// 결과 CSCMatrix 반환
res.compact()
res
}
// 결과 Matrix 의 NNZ 를 미리 계산
private def computeNnz(a: CSCMatrix[T], b: CSCMatrix[T], workIndex: Array[Int]) = {
// nnz 를 0 으로 초기화
var nnz = 0
// col = 우측 행렬의 col index 순서대로 진행
cforRange(0 until b.cols) { col =>
// bOff = 우측 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, B rowIndices 배열의 index 제공
cforRange(b.colPtrs(col) until b.colPtrs(col + 1)) { bOff =>
// bRow = 우측 행렬의 각 col에 있는 요소의 row index
val bRow = b.rowIndices(bOff)
// a0ff = 우측 행렬의 각 col 에 있는 요소의 bRow 와 동일한 index 를 가지는 좌측 행렬의 col 을 확인하며,
// 좌측 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, A rowIndices 배열의 index 제공
cforRange(a.colPtrs(bRow) until a.colPtrs(bRow + 1)) { aOff =>
// aRow = 좌측 행렬의 각 col에 있는 요소의 row index
val aRow = a.rowIndices(aOff)
// 단순히 결과 행렬의 nnz 를 확인하기 위한 트래킹 용도로 사용된다.
// 반복문 시, 우측 행렬의 col index 는 고정된 상태에서, 좌측 행렬의 col 이 움직이며 nnz 가 계산되는데,
// 동일한 우측 행렬 col index 기준으로, 좌측 행렬 col 의 내부 요소가 새로운 row 에서 참조될때만 nnz 증가
if (workIndex(aRow) < col) {
workIndex(aRow) = col
nnz += 1
}
}
}
}
// nnz 반환
nnz
}
implicitly[BinaryRegistry[Matrix[T], Matrix[T], OpMulMatrix.type, Matrix[T]]].register(this)
}
'Data > Spark' 카테고리의 다른 글
[Spark] Spark SparseMatrix * DenseMatrix 분석 (0) | 2021.03.18 |
---|---|
[Spark] Breeze CSCMatrix * DenseMatrix 분석 (0) | 2021.03.18 |
[Spark] Spark SparseMatrix를 Breeze CSCMatrix로 변환하는 방법 (0) | 2021.03.04 |
[Spark] Spark 로그 확인 라이브러리 (0) | 2021.02.25 |
[Spark] Matrix를 일정 단위로 slice 한 뒤 각각의 nnz, density 구하기 (0) | 2020.11.30 |
Comments