오늘의 인기 글
최근 글
최근 댓글
Today
Total
12-31 00:05
관리 메뉴

우노

[Spark] Spark SparseMatrix * DenseMatrix 분석 본문

Data/Spark

[Spark] Spark SparseMatrix * DenseMatrix 분석

운호(Noah) 2021. 3. 18. 14:03

Spark gemm

알고리즘 순서

  • 우측행렬 컬럼 기준으로 좌측행렬 컬럼이 이중포문을 이루며 곱해집니다.
  • 우측행렬 컬럼은 Dense하게 모든 요소를 곱셈에 사용하지만, 좌측행렬 컬럼은 존재하는 요소만 곱셈에 사용합니다.
  • 곱셈 결과값은 아래와 같은 형태로 결과 데이터 배열에 삽입됩니다.
    • 결과데이터배열[결과행렬 데이터배열의 시작 index + 좌측행렬 요소값의 행번호] += (좌측행렬 요소값 * 우측 행렬 요소값)

Code

/**
 * C := alpha * A * B + beta * C
 * For `SparseMatrix` A.
 */
private def gemm(
    alpha: Double,
    A: SparseMatrix,
    B: DenseMatrix,
    beta: Double,
    C: DenseMatrix): Unit = {
  val mA: Int = A.numRows
  val nB: Int = B.numCols
  val kA: Int = A.numCols
  val kB: Int = B.numRows

  require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB")
  require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA")
  require(nB == C.numCols,
    s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB")

    // 좌측 행렬의 데이터 배열
  val Avals = A.values
  // 우측 행렬의 데이터 배열
    val Bvals = B.values
  // 결과 행렬의 데이터 배열 (0으로 초기화 되어있음)
    val Cvals = C.values
  // 좌측 행렬의 rowIndices 배열
    val ArowIndices = A.rowIndices
  // 좌측 행렬의 colPtrs 배열
    val AcolPtrs = A.colPtrs

    // 좌측 행렬이 전치된 경우
  if (A.isTransposed) {
    var colCounterForB = 0
        // 우측 행렬이 전치되지 않은 경우
    if (!B.isTransposed) { 
      while (colCounterForB < nB) {
        var rowCounterForA = 0
        val Cstart = colCounterForB * mA
        val Bstart = colCounterForB * kA
        while (rowCounterForA < mA) {
          var i = AcolPtrs(rowCounterForA)
          val indEnd = AcolPtrs(rowCounterForA + 1)
          var sum = 0.0
          while (i < indEnd) {
            sum += Avals(i) * Bvals(Bstart + ArowIndices(i))
            i += 1
          }
          val Cindex = Cstart + rowCounterForA
          Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
          rowCounterForA += 1
        }
        colCounterForB += 1
      }
    } 
            // 우측 행렬이 전치된 경우
            else {
      while (colCounterForB < nB) {
        var rowCounterForA = 0
        val Cstart = colCounterForB * mA
        while (rowCounterForA < mA) {
          var i = AcolPtrs(rowCounterForA)
          val indEnd = AcolPtrs(rowCounterForA + 1)
          var sum = 0.0
          while (i < indEnd) {
            sum += Avals(i) * B(ArowIndices(i), colCounterForB)
            i += 1
          }
          val Cindex = Cstart + rowCounterForA
          Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
          rowCounterForA += 1
        }
        colCounterForB += 1
      }
    }
  } 
        // 좌측 행렬이 전치되지 않은 경우
        else {
          // 'beta'가 1이 아니라면, 행렬 크기 조정
          if (beta != 1.0) {
            getBLAS(C.values.length).dscal(C.values.length, beta, C.values, 1)
          }

          // 우측 행렬의 col index
          var colCounterForB = 0 

          // 우측 행렬이 전치되지 않은 경우
          if (!B.isTransposed) { 

              // 우측 행렬의 col 순서대로 진행
              while (colCounterForB < nB) {

                  // 좌측 행렬의 col index
                  var colCounterForA = 0 

                  // 우측 행렬에서 어떤 col을 다루고 있느냐에 따라 달라지는, 우측 행렬 데이터 배열의 시작 index
                  val Bstart = colCounterForB * kB
                  // 우측 행렬에서 어떤 col을 다루고 있느냐에 따라 달라지는, 결과 데이터 배열의 시작 index
                  val Cstart = colCounterForB * mA

                  // 좌측 행렬의 col 순서대로 진행
                  while (colCounterForA < kA) {

                      // i, indEnd = 좌측 행렬 각 col에 있는 요소의 행과 값을 알 수 있도록, AcolPtrs을 통해 index 제공
                      var i = AcolPtrs(colCounterForA)
                      val indEnd = AcolPtrs(colCounterForA + 1)

                      // Bval = 우측 행렬의 데이터 배열에서 현재 곱셈에 사용될 값
                      // Bstart는 고정 된 상태로 colCounterForA가 증가하므로, 우측 행렬의 각 col에 있는 모든 요소에 접근하게 된다.
                      val Bval = Bvals(Bstart + colCounterForA) * alpha

                      // i, indEnd = 좌측 행렬 각 col에 있는 요소의 행과 값을 알 수 있도록, AcolPtrs을 통해 index 제공
                      // 따라서, 좌측 행렬에선 존재하는 요소에만 접근하게 된다.
                      while (i < indEnd) {

                          // 결과데이터배열(결과데이터배열의 시작 index + 좌측행렬 요소값의 행번호) += (좌측행렬 요소값 * 우측 행렬 요소값)
                          Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
                          i += 1
                      }
                      // 좌측 행렬의 col index 증가
                      colCounterForA += 1
                  }
                  // 우측 행렬의 col index 증가
                  colCounterForB += 1
                }
            } 
            // 우측 행렬이 전치된 경우
            else {
      while (colCounterForB < nB) {
        var colCounterForA = 0 // The column of A to multiply with the row of B
        val Cstart = colCounterForB * mA
        while (colCounterForA < kA) {
          var i = AcolPtrs(colCounterForA)
          val indEnd = AcolPtrs(colCounterForA + 1)
          val Bval = B(colCounterForA, colCounterForB) * alpha
          while (i < indEnd) {
            Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
            i += 1
          }
          colCounterForA += 1
        }
        colCounterForB += 1
      }
    }
  }
}
Comments