Skip to content

Commit

Permalink
[AMD] Add missing i16 for wmma and disable some tests (#4843)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexAUT authored Oct 3, 2024
1 parent 33c0c1c commit 1495116
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3330,7 +3330,8 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_
if is_hip():
# hip does not support tf32 precision, so use ieee for all tests
input_precision = "ieee"
if "gfx11" in triton.runtime.driver.active.get_current_target().arch:
arch = triton.runtime.driver.active.get_current_target().arch
if "gfx11" in arch or "gfx12" in arch:
if in_dtype_str == "float32":
pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d")
if out_dtype_str == "float16":
Expand Down
2 changes: 2 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ std::string getTypeStr(Type ty) {
scalarName = "bf16";
} else if (ty.isInteger(32)) {
scalarName = "i32";
} else if (ty.isInteger(16)) {
scalarName = "i16";
} else if (ty.isInteger(8)) {
scalarName = "iu8";
} else if (ty.isInteger(4)) {
Expand Down

0 comments on commit 1495116

Please sign in to comment.