Merging fp8_gemm_tuner.py into gemm_tuner.py (#66)
* adding input type

* merge gradlib_fp8 to gradlib

* using fp8

* fix lint

* fix lint
charlifu authored Jun 25, 2024
1 parent 3200953 commit 367aa5a
Showing 6 changed files with 70 additions and 329 deletions.
ROCm_performance.md: 2 changes (1 addition & 1 deletion)
@@ -48,7 +48,7 @@ To obtain all the shapes of gemms during the execution of the model, set the env
Next, run gradlib to obtain the best solutions for these shapes:

```
python3 gradlib/gradlib/fp8_gemm_tuner.py --input_file /tmp/fp8_shapes.csv --tuned_file /tmp/tuned_fp8_16.csv
python3 gradlib/gradlib/gemm_tuner.py --input_file /tmp/fp8_shapes.csv --tuned_file /tmp/tuned_fp8_16.csv --indtype fp8 --outdtype f16
```
where `/tmp/tuned_fp8_16.csv` will be used by our fp8 gemm linear layer.
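
For reference, a minimal sketch (not part of this commit) of how the resulting tuned file can be inspected: the column names come from `GemmTuner.find_best_sols` later in this diff, and the path matches the command above.

```python
import pandas as pd

# Assumes the tuning command above has already been run on a ROCm machine.
df = pd.read_csv('/tmp/tuned_fp8_16.csv')

# Columns written by GemmTuner.find_best_sols:
# M, N, K, libtype, solidx, soltimems, indtype, outdtype
print(df[['M', 'N', 'K', 'libtype', 'solidx', 'soltimems']].to_string(index=False))
```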

gradlib/csrc/hipbsolgemm.cu: 18 changes (12 additions & 6 deletions)
@@ -208,12 +208,18 @@ hipblasStatus_t hipblasLtMatmul_sol_wrapper(
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t)));
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t)));
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scaleA, sizeof(scaleA)));
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scaleB, sizeof(scaleB)));
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scaleC, sizeof(scaleC)));
if (scaleA != nullptr) {
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scaleA, sizeof(scaleA)));
}
if (scaleB != nullptr) {
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scaleB, sizeof(scaleB)));
}
if (scaleC != nullptr) {
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scaleC, sizeof(scaleC)));
}
// nvtxRangePop();
// if heuristic does not exist in the map, do search and push into the map
// auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } };
gradlib/gradlib/GemmTuner.py: 70 changes (41 additions & 29 deletions)
@@ -12,33 +12,33 @@

rtol = 1e-5
atol = 1
dtype = torch.float16


class Gemm:

def __init__(self, m, n, k, dtype, rocblas_decode=False):
def __init__(self, m, n, k, indtype, outdtype, rocblas_decode=False):
self.m = m
self.k = k
self.n = n
self.dtype = dtype
self.indtype = indtype
self.outdtype = outdtype
self.use_rocblas = (indtype == outdtype
and indtype is not torch.float8_e4m3fnuz)
self.nb = 37
self.inp = torch.randn((self.n, self.k),
dtype=self.dtype,
device='cuda')
device='cuda').to(self.indtype)
self.weights = torch.randn((self.m, self.k),
dtype=self.dtype,
device='cuda')
device='cuda').to(self.indtype)
# weights2 is used in measurement/warm iters to ensure
# HBM fetch for weight tensors
self.weights2 = torch.randn((self.nb, self.m, self.k),
dtype=self.dtype,
device='cuda')
device='cuda').to(self.indtype)
self.blob = torch.ones(128 * 1024 * 1024,
dtype=torch.float32,
device='cuda')
self.topn = 20 #number of top solutions from each source
self.hipb_sols = []
self.rocb_sols = []
self.rtol = 1e-5
self.atol = 1
self.start = torch.cuda.Event(enable_timing=True)
@@ -49,7 +49,8 @@ def __init__(self, m, n, k, dtype, rocblas_decode=False):
self.rocblas_decode = rocblas_decode

def find_hipblas_sols(self):
sols = hipbsolidxgemm.hipb_findallsols(self.inp, self.weights.t())
sols = hipbsolidxgemm.hipb_findallsols(self.inp, self.weights.t(),
self.outdtype)
print('M N K',
self.m,
self.n,
@@ -61,34 +61,37 @@ def find_hipblas_sols(self):
self.hipb_sols = sols

def check_gemm_ref(self, libtype, solidx):
ref = F.linear(self.inp, self.weights)
ref = F.linear(self.inp.to(torch.float32),
self.weights.to(torch.float32)).to(self.outdtype)
if libtype == 'hipblaslt':
c = hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx)
c = hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx,
self.outdtype)
elif libtype == 'rocblas':
c = rocsolidxgemm.rocb_mm(self.inp, self.weights.t(), solidx)
if torch.allclose(c, ref, atol=self.atol, rtol=self.rtol):
#print('>>>',libtype,'Solidx',solidx,'passed reference test')
return True
else:
print('>>>',
libtype,
'Solidx',
solidx,
'FAILED reference test',
flush=True)
print(ref, flush=True)
print(c, flush=True)
return False

print('>>>',
libtype,
'Solidx',
solidx,
'FAILED reference test',
flush=True)
print(ref, flush=True)
print(c, flush=True)
return False

def hipb_time_sol(self, solidx, cold_iters=2, warm_iters=10):
#print('>>>hipbtime',solidx)
for i in range(cold_iters):
hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx)
hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx,
self.outdtype)
self.start.record()
for i in range(warm_iters):
hipbsolidxgemm.hipb_mm(
self.inp, self.weights2[random.randint(0, self.nb - 1)].t(),
solidx)
solidx, self.outdtype)
self.end.record()
torch.cuda.synchronize()
gtime = self.start.elapsed_time(self.end) / warm_iters
@@ -178,7 +182,8 @@ def functional_check_topn_fastest(self):
self.hipb_top_sols = hipb_topn

def find_fastest_solution(self):
self.find_rocblas_sols()
if self.use_rocblas:
self.find_rocblas_sols()
if not (self.rocblas_decode and self.n == 1):
self.find_hipblas_sols()
self.warmup()
@@ -228,9 +233,14 @@ def find_fastest_solution(self):

class GemmTuner:

def __init__(self, dtype, tuned_file=None, rocblas_decode=False):
def __init__(self,
indtype,
outdtype,
tuned_file=None,
rocblas_decode=False):
self.gemm_problems = pd.DataFrame(columns=['M', 'N', 'K'])
self.dtype = dtype
self.indtype = indtype
self.outdtype = outdtype
self.rocblas_decode = rocblas_decode
self.tuned_file = tuned_file
if Path(tuned_file).is_file():
@@ -259,13 +269,15 @@ def find_best_sols(self):
gemmobj = Gemm(ds['M'],
ds['N'],
ds['K'],
dtype=self.dtype,
indtype=self.indtype,
outdtype=self.outdtype,
rocblas_decode=self.rocblas_decode)
gemmobj.find_fastest_solution()
soldf.loc[i, 'libtype'] = gemmobj.best_libtype
soldf.loc[i, 'solidx'] = gemmobj.best_solidx
soldf.loc[i, 'soltimems'] = gemmobj.best_soltime
soldf['dtype'] = self.dtype
soldf['indtype'] = self.indtype
soldf['outdtype'] = self.outdtype
finaldf = pd.concat([self.gemm_problems, soldf], axis=1)
finaldf = pd.concat([finaldf, self.gdf])
finaldf.to_csv(self.tuned_file, index=False)
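
As a quick orientation for the new `indtype`/`outdtype` arguments, here is a minimal usage sketch (not part of the commit). It assumes gradlib's hipBLASLt/rocBLAS extensions are built, a ROCm GPU is available, and that `GemmTuner` is importable from this module; `gemm_problems` is filled directly purely for illustration, whereas `gemm_tuner.py` normally populates it from the `--input_file` CSV.

```python
import pandas as pd
import torch

from GemmTuner import GemmTuner  # import path assumed; adjust to your layout

# fp8 inputs, fp16 outputs, mirroring the README command above.
tuner = GemmTuner(indtype=torch.float8_e4m3fnuz,
                  outdtype=torch.float16,
                  tuned_file='/tmp/tuned_fp8_16.csv')

# One (M, N, K) shape to tune; gemm_tuner.py would parse these from its input CSV.
tuner.gemm_problems = pd.DataFrame([{'M': 8192, 'N': 1, 'K': 8192}])

# Benchmarks hipBLASLt (and rocBLAS where applicable) solutions for each shape
# and writes the best libtype/solidx per shape to tuned_file.
tuner.find_best_sols()
```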
