# cuda_test_numba.py
  1. """
  2. From numba documentation:
  3. https://numba.pydata.org/numba-doc/latest/cuda/examples.html#matrix-multiplication
  4. """
  5. from numba import cuda, float32
  6. import numpy as np
  7. from timeit import default_timer as timer
  8. # Controls threads per block and shared memory usage.
  9. # The computation will be done on blocks of TPBxTPB elements.
  10. TPB = 16
  11. @cuda.jit
  12. def fast_matmul(A, B, C):
  13. # Define an array in the shared memory
  14. # The size and type of the arrays must be known at compile time
  15. sA = cuda.shared.array(shape=(TPB, TPB), dtype=float32)
  16. sB = cuda.shared.array(shape=(TPB, TPB), dtype=float32)
  17. x, y = cuda.grid(2)
  18. tx = cuda.threadIdx.x
  19. ty = cuda.threadIdx.y
  20. bpg = cuda.gridDim.x # blocks per grid
  21. if x >= C.shape[0] and y >= C.shape[1]:
  22. # Quit if (x, y) is outside of valid C boundary
  23. return
  24. # Each thread computes one element in the result matrix.
  25. # The dot product is chunked into dot products of TPB-long vectors.
  26. tmp = 0.
  27. for i in range(bpg):
  28. # Preload data into shared memory
  29. sA[tx, ty] = A[x, ty + i * TPB]
  30. sB[tx, ty] = B[tx + i * TPB, y]
  31. # Wait until all threads finish preloading
  32. cuda.syncthreads()
  33. # Computes partial product on the shared memory
  34. for j in range(TPB):
  35. tmp += sA[tx, j] * sB[j, ty]
  36. # Wait until all threads finish computing
  37. cuda.syncthreads()
  38. C[x, y] = tmp
  39. # run it
  40. if __name__ == '__main__':
  41. # Initialize the data arrays
  42. A = np.full((TPB*20, TPB*20), 3, np.float32)
  43. B = np.full((TPB*20, TPB*20), 4, np.float32)
  44. # Configure the blocks
  45. threadsperblock = (TPB, TPB)
  46. blockspergrid_x = int(np.ceil(A.shape[0] / threadsperblock[0]))
  47. blockspergrid_y = int(np.ceil(B.shape[1] / threadsperblock[1]))
  48. blockspergrid = (blockspergrid_x, blockspergrid_y)
  49. # Start the kernel
  50. C = np.zeros_like(A)
  51. start = timer()
  52. fast_matmul[blockspergrid, threadsperblock](A, B, C)
  53. cuda.synchronize()
  54. print("Time taken: %f" % (timer() - start))
  55. # Print the result
  56. print(C)