오늘의 인기 글
최근 글
최근 댓글
Today
Total
01-22 00:01
관리 메뉴

우노

[Spark] Breeze CSCMatrix Multiply 함수 구현 (메모리 정적 할당) 본문

Data/Spark

[Spark] Breeze CSCMatrix Multiply 함수 구현 (메모리 정적 할당)

운호(Noah) 2022. 3. 17. 16:18

들어가기 앞서,

  • Breeze 가 제공하는 CSCMatrix Multiply 는, 곱셈전에 결과 Matrix 의 NNZ 를 먼저 구합니다.
  • 결과 Matrix 의 NNZ 를 구한 이후에는, NNZ 에 따라 결과 Matrix 를 표현하기 위한 배열 공간을 정적 할당합니다.

함수 선언

  • 결과 Matrix NNZ 계산 함수

      // 결과 Matrix 의 NNZ 를 미리 계산
      def computeNnz(a: CSCMatrix[Double], b: CSCMatrix[Double], workIndex: Array[Double]) = {
    
          // nnz 를 0 으로 초기화
          var nnz = 0
    
              // col = 우측 행렬의 col index 순서대로 진행
              for (col <- 0 until b.cols){
    
                  // bOff = 우측 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, B rowIndices 배열의 index 제공
                  for (bOff <- b.colPtrs(col) until b.colPtrs(col + 1)){
    
                      // bRow = 우측 행렬의 각 col에 있는 요소의 row index
                      val bRow = b.rowIndices(bOff)
    
                      // a0ff = 우측 행렬의 각 col 에 있는 요소의 bRow 와 동일한 index 를 가지는 좌측 행렬의 col 을 확인하며,
                      // 좌측 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, A rowIndices 배열의 index 제공
                      for (aOff <- a.colPtrs(bRow) until a.colPtrs(bRow + 1)){
    
                          // aRow = 좌측 행렬의 각 col에 있는 요소의 row index  
                            val aRow = a.rowIndices(aOff)
    
                          // 단순히 결과 행렬의 nnz 를 확인하기 위한 트래킹 용도로 사용된다.
                          // 반복문 시, 우측 행렬의 col index 는 고정된 상태에서, 좌측 행렬의 col 이 움직이며 nnz 가 계산되는데,
                          // 동일한 우측 행렬 col index 기준으로, 좌측 행렬 col 의 내부 요소가 새로운 row 에서 참조될때만 nnz 증가
                          if (workIndex(aRow) < col) {
                              workIndex(aRow) = col
                              nnz += 1
                          }
                      }
                  }
              }
          // nnz 반환
          nnz
      }
  • Multiply 함수

      import java.util.Arrays
    
      def static_multiply(a : breeze.linalg.CSCMatrix[Double], b : breeze.linalg.CSCMatrix[Double]): breeze.linalg.CSCMatrix[Double] = {
    
          // 결과 행렬의 row 길이와 동일한 임시 데이터 배열을 생성
          // 해당 임시 Data 배열은, 각 계산 마다 결과 행렬에 업데이트 됨
          val workData = new Array[Double](a.rows)
    
          // workData 배열의 업데이트 여부를 확인하기 위한, 임시 Index 배열 생성
          val workIndex = new Array[Double](a.rows)
    
          // workIndex 배열요소를 -1로 초기화
          Arrays.fill(workIndex, -1)
    
          // 결과 행렬의 NNZ 계산
          val totalNnz = computeNnz(a, b, workIndex)
    
          // 결과 행렬의 NNZ 계산이 끝난 뒤, workIndex 배열을 다시 -1 로 초기화
          Arrays.fill(workIndex, -1)
    
          // 결과 CSCMatrix 생성
          var res = CSCMatrix.zeros[Double](a.rows, b.cols)
          // 결과 CSCMatrix 의 rowIndices, data 배열을 totalNnz 개수만큼 0 으로 초기화 (colPtrs 은 이미 0으로 초기화 되어있음)
          res.reserve(totalNnz)
    
          // 결과 행렬의 rowIndices 배열 (totalNnz 개수만큼 0 으로 초기화 되어있음)
          val resRows = res.rowIndices
          // 결과 행렬의 Data 배열 (totalNnz 개수만큼 0 으로 초기화 되어있음)
          val resData = res.data
    
          // 왼쪽 행렬의 rowIndices
          val aRows = a.rowIndices
          // 왼쪽 행렬의 Data
          val aData = a.data
          // 왼쪽 행렬의 colPtrs
          val aPtrs = a.colPtrs
    
          var nnz = 0
    
          // col = 우측 행렬의 col index 순서대로 진행
          for ( col <- 0 until b.cols) {
    
              // bOff = 우측 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, B rowIndices 배열의 index 제공
              for (bOff <- b.colPtrs(col) until b.colPtrs(col + 1)) {
    
                  // bRow = 우측 행렬의 각 col에 있는 요소의 row index
                  // bVal = 우측 행렬의 각 col에 있는 요소값
                  val bRow = b.rowIndices(bOff)
                  val bVal = b.data(bOff)
    
                  // a0ff = 우측 행렬의 각 col 에 있는 요소의 bRow 와 동일한 index 를 가지는 좌측 행렬의 col 을 확인하며,
                  // 좌측 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, A rowIndices 배열의 index 제공
                  for (aOff <- aPtrs(bRow) until aPtrs(bRow + 1)) {
    
                      // aRow = 좌측 행렬의 각 col에 있는 요소의 row index  
                      // aVal = 좌측 행렬의 각 col에 있는 요소값
                      val aRow = aRows(aOff)
                      val aVal = aData(aOff)
    
                      // 반복문 시, 우측 행렬의 col index 는 고정된 상태에서 좌측 행렬의 col 이 움직이며 요소를 확인하는데,
                      // 동일한 우측 행렬 col index 기준으로, 좌측 행렬 col 의 내부 요소가 새로운 row 에서 참조될때만, 조건문 실행
                      if (workIndex(aRow) < col) {
    
                          // 임시 Data 배열의 aRow index 값을 0 으로 초기화
                          workData(aRow) = 0
                          // 임시 Index 배열의 aRow index 값을 col index 로 할당
                          workIndex(aRow) = col
                          // 결과 행렬의 rowIndices 배열에서 nnz index 값을 aRow 로 할당 (먼저 계산되는 aRow 순서대로 할당 되며, 정렬은 아래에서 진행)
                          resRows(nnz) = aRow
                          // nnz 증가
                          nnz += 1
                      }    
                      // workData(aRow) += (좌측 행렬의 각 col에 있는 요소 값) * (우측 행렬의 각 col에 있는 요소 값)
                      workData(aRow) += aVal * bVal
                  }
              }
    
              // 임시 Data 배열이 다 채워지면, 결과 행렬의 colPtrs 배열 요소를 nnz 로 할당
              res.colPtrs(col + 1) = nnz
    
              // 결과 행렬의 rowIndices 배열에서 현재 계산한 rowIndices 부분만 오름차순으로 정렬
              Arrays.sort(resRows, res.colPtrs(col), res.colPtrs(col + 1))
    
                // resOff = 결과 행렬의 각 col 에 있는 요소의 row index 를 알 수 있도록, Result rowIndices 배열의 index 제공
                for (resOff <- res.colPtrs(col) until res.colPtrs(col + 1)) {
    
                    // row = 결과 행렬의 각 col에 있는 요소의 row index  
                    val row = resRows(resOff)
                    // 임시 데이터 배열의 row에 위치한 값을, 결과 행렬의 데이터 배열에서 해당하는 위치에 할당
                    resData(resOff) = workData(row)
                }
    
                // 조건을 만족하지 못하면 에러 발생
                assert(nnz <= totalNnz)
        }
    
          res = new CSCMatrix(res.data, res.rows, res.cols, res.colPtrs, res.rowIndices)
    
          return res
      }

함수 실행

// 라이브러리 호출
import org.apache.spark.ml.linalg.SparseMatrix
import breeze.linalg.CSCMatrix
import java.util.Random

//랜덤 설정
val rand = new Random()

// CSCMatrix 생성
val LR = 10000
val LC = 30000
val RC = 10000
val LD = 0.001
val RD = 0.001
val l_sm = SparseMatrix.sprand(LR, LC, LD, rand)
val r_sm = SparseMatrix.sprand(LC, RC, RD, rand)
val l_csc = new CSCMatrix(l_sm.values, l_sm.numRows, l_sm.numCols, l_sm.colPtrs, l_sm.rowIndices)
val r_csc = new CSCMatrix(r_sm.values, r_sm.numRows, r_sm.numCols, r_sm.colPtrs, r_sm.rowIndices)

// 함수 실행
static_multiply(l_csc, r_csc)
Comments