오늘의 인기 글
최근 글
최근 댓글
Today
Total
05-06 00:00
관리 메뉴

우노

[Spark] Breeze CSCMatrix * CSCMatrix 분석 본문

Data/Spark

[Spark] Breeze CSCMatrix * CSCMatrix 분석

운호(Noah) 2021. 3. 17. 14:45

Breeze canMulM_M

Code

implicit def canMulM_M[@expand.args(Int, Float, Double, Long) T]
    : breeze.linalg.operators.OpMulMatrix.Impl2[CSCMatrix[T], CSCMatrix[T], CSCMatrix[T]] =
    new breeze.linalg.operators.OpMulMatrix.Impl2[CSCMatrix[T], CSCMatrix[T], CSCMatrix[T]] {

      def apply(a: CSCMatrix[T], b: CSCMatrix[T]): CSCMatrix[T] = {

                // 양쪽 행렬의 차원이 안 맞을 경우
                require(a.cols == b.rows, "Dimension Mismatch")

                // 결과 행렬의 row 길이와 동일한 임시 데이터 배열을 생성
                // 해당 임시 Data 배열은, 각 계산 마다 결과 행렬에 업데이트 됨
                val workData = new Array[T](a.rows)

                // workData 배열의 업데이트 여부를 확인하기 위한, 임시 Index 배열 생성
                val workIndex = new Array[Int](a.rows)
                // workIndex 배열요소를 -1로 초기화
                util.Arrays.fill(workIndex, -1)

                // 결과 행렬의 NNZ 계산
                val totalNnz = computeNnz(a, b, workIndex)
                // 결과 행렬의 NNZ 계산이 끝난 뒤, workIndex 배열을 다시 -1 로 초기화
                util.Arrays.fill(workIndex, -1)

                // 결과 CSCMatrix 생성
                val res = CSCMatrix.zeros[T](a.rows, b.cols)
                // 결과 CSCMatrix 의 rowIndices, data 배열을 totalNnz 개수만큼 0 으로 초기화 (colPtrs 은 이미 0으로 초기화 되어있음)
                res.reserve(totalNnz)

                // 결과 행렬의 rowIndices 배열 (totalNnz 개수만큼 0 으로 초기화 되어있음)
                val resRows = res.rowIndices
                // 결과 행렬의 Data 배열 (totalNnz 개수만큼 0 으로 초기화 되어있음)
                val resData = res.data

                // 왼쪽 행렬의 rowIndices
                val aRows = a.rowIndices
                // 왼쪽 행렬의 Data
                val aData = a.data
                // 왼쪽 행렬의 colPtrs
                val aPtrs = a.colPtrs

                // col = 우측 행렬의 col index 순서대로 진행
                cforRange(0 until b.cols) { col =>

                    // nnz = 결과 행렬의 NNZ
                    var nnz = res.used

                    // bOff = 우측 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, B rowIndices 배열의 index 제공
                    cforRange(b.colPtrs(col) until b.colPtrs(col + 1)) { bOff =>

                        // bRow = 우측 행렬의 각 col에 있는 요소의 row index
                        // bVal = 우측 행렬의 각 col에 있는 요소값
                        val bRow = b.rowIndices(bOff)
                        val bVal = b.data(bOff)

                        // a0ff = 우측 행렬의 각 col 에 있는 요소의 bRow 와 동일한 index 를 가지는 좌측 행렬의 col 을 확인하며,
                        // 좌측 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, A rowIndices 배열의 index 제공
                        cforRange(aPtrs(bRow) until aPtrs(bRow + 1)) { aOff =>

                            // aRow = 좌측 행렬의 각 col에 있는 요소의 row index  
                            // aVal = 좌측 행렬의 각 col에 있는 요소값
                            val aRow = aRows(aOff)
                            val aVal = aData(aOff)

                            // 반복문 시, 우측 행렬의 col index 는 고정된 상태에서 좌측 행렬의 col 이 움직이며 요소를 확인하는데,
                            // 동일한 우측 행렬 col index 기준으로, 좌측 행렬 col 의 내부 요소가 새로운 row 에서 참조될때만, 조건문 실행
                            if (workIndex(aRow) < col) {

                                // 임시 Data 배열의 aRow index 값을 0 으로 초기화
                                workData(aRow) = 0
                                // 임시 Index 배열의 aRow index 값을 col index 로 할당
                                workIndex(aRow) = col
                                // 결과 행렬의 rowIndices 배열에서 nnz index 값을 aRow 로 할당 (먼저 계산되는 aRow 순서대로 할당 되며, 정렬은 아래에서 진행)
                                resRows(nnz) = aRow
                                // nnz 증가
                                nnz += 1
                            }

                            // workData(aRow) += (좌측 행렬의 각 col에 있는 요소 값) * (우측 행렬의 각 col에 있는 요소 값)
                            workData(aRow) += aVal * bVal

                        }
                    }

                    // 임시 Data 배열이 다 채워지면, 결과 행렬의 colPtrs 배열 요소를 nnz 로 할당
                    res.colPtrs(col + 1) = nnz
                    // 결과 행렬의 nnz 값을, 현재까지 계산한 nnz 값으로 수정
                    res.used = nnz

                    // 결과 행렬의 rowIndices 배열에서 현재 계산한 rowIndices 부분만 오름차순으로 정렬
                    util.Arrays.sort(resRows, res.colPtrs(col), res.colPtrs(col + 1))

                    // resOff = 결과 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, Result rowIndices 배열의 index 제공
                    cforRange(res.colPtrs(col) until res.colPtrs(col + 1)) { resOff =>

                        // row = 결과 행렬의 각 col에 있는 요소의 row index  
                        val row = resRows(resOff)
                        // 임시 데이터 배열의 row에 위치한 값을, 결과 행렬의 데이터 배열에서 해당하는 위치에 할당
                        resData(resOff) = workData(row)
                    }

                    // 조건을 만족하지 못하면 에러 발생
                    assert(nnz <= totalNnz)
                }

                // 결과 CSCMatrix 반환
                res.compact()
                res
            }

            // 결과 Matrix 의 NNZ 를 미리 계산
            private def computeNnz(a: CSCMatrix[T], b: CSCMatrix[T], workIndex: Array[Int]) = {

                // nnz 를 0 으로 초기화
                var nnz = 0

                // col = 우측 행렬의 col index 순서대로 진행
                cforRange(0 until b.cols) { col =>

                    // bOff = 우측 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, B rowIndices 배열의 index 제공
                    cforRange(b.colPtrs(col) until b.colPtrs(col + 1)) { bOff =>

                        // bRow = 우측 행렬의 각 col에 있는 요소의 row index
                        val bRow = b.rowIndices(bOff)

                        // a0ff = 우측 행렬의 각 col 에 있는 요소의 bRow 와 동일한 index 를 가지는 좌측 행렬의 col 을 확인하며,
                        // 좌측 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, A rowIndices 배열의 index 제공
                        cforRange(a.colPtrs(bRow) until a.colPtrs(bRow + 1)) { aOff =>

                            // aRow = 좌측 행렬의 각 col에 있는 요소의 row index  
                            val aRow = a.rowIndices(aOff)

                            // 단순히 결과 행렬의 nnz 를 확인하기 위한 트래킹 용도로 사용된다.
                            // 반복문 시, 우측 행렬의 col index 는 고정된 상태에서, 좌측 행렬의 col 이 움직이며 nnz 가 계산되는데,
                            // 동일한 우측 행렬 col index 기준으로, 좌측 행렬 col 의 내부 요소가 새로운 row 에서 참조될때만 nnz 증가
                            if (workIndex(aRow) < col) {
                                workIndex(aRow) = col
                                nnz += 1
                            }
                        }
                    }
                }
                // nnz 반환
                nnz
            }

      implicitly[BinaryRegistry[Matrix[T], Matrix[T], OpMulMatrix.type, Matrix[T]]].register(this)

    }
Comments