# Plain cases
--reset
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32
--bia_dt=f32
--bia_mask=2
--batch=shapes_2d_ci
--batch=shapes_2d
--bia_mask=4
--batch=shapes_3d

# Post-ops check for different data types
--reset
--dt=f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32
--attr-post-ops=sum+relu:0.5+add:f32
--batch=shapes_2d_ci

--attr-post-ops=sum:0.5,\
                linear:2:1,\
                add:f32,\
                add:u8:per_oc,\
                prelu:per_tensor,\
                prelu:per_oc,\
                sum+relu:0.5+add:f32
--batch=shapes_2d
--batch=shapes_3d

--reset
--dt=bf16
--attr-post-ops=mul:bf16,div:bf16
30x40:40x50_n"bf16_binary_po_special_kinds"

# Different tags
--reset
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32
--stag=ab,ba
--wtag=ab,ba
--dtag=ab,ba
--batch=shapes_2d
--stag=abc,acb
--wtag=abc,acb
--dtag=abc,acb
--batch=shapes_3d

# Sum with different data type
--reset
--dt=f64,f32
--attr-post-ops=sum:0.25:0:s32
--batch=shapes_2d
--batch=shapes_3d
--dt=u8:s8:s8
--attr-post-ops=sum:0.25:0:u8
--batch=shapes_2d
--batch=shapes_3d

# Arg scales check
--reset
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:u8,s8:s8:f32
--attr-scales=src:common:0.25+wei:common:0.5+dst:common:2,wei:per_oc
--batch=shapes_2d_ci

--attr-scales=src:common:0.25,\
              wei:common:0.5,\
              src:common:0.25+wei:common:0.5,\
              src:common:0.25+dst:common:2,\
              src:common:0.25+wei:common:0.5+dst:common:2
--batch=shapes_2d
--batch=shapes_3d

# Zero-points check
--reset
--dt=s8:s8:s8,u8:s8:f32,u8:s8:bf16
--attr-zero-points=src:common:1+wei:common:-1+dst:common:2
--batch=shapes_2d_ci

--attr-zero-points=src:common:1,\
                   wei:common:-1,\
                   dst:common:2,\
                   src:common:-1+wei:common:1,\
                   src:common:1+dst:common:-2,\
                   src:common:-1+wei:common:-1+dst:common:2
--batch=shapes_2d
--batch=shapes_3d

# Run-time dimensions check
--reset
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3
--runtime_dims_masks=15:15
--stag=ab,ba
--wtag=ab,ba
--dtag=ab
--batch=shapes_2d_ci

--stag=abc,acb
--wtag=abc,acb
--dtag=abc,acb
--batch=shapes_3d

--dt=s8:s8:s8,u8:s8:f32,u8:s8:bf16
--stag=ab
--wtag=ab
--dtag=ab
--batch=shapes_2d_ci

### int8 wei decomp
--reset
--stag=ab,ba
--dt=bf16:s8:bf16,bf16:u8:bf16
--attr-scales=wei:common:2,wei:per_oc:bf16
--attr-fpmath=bf16:true
--batch=shapes_2d

--reset
--stag=ab,ba
--dt=f16:s8:f16,f16:u8:f16
--attr-scales=wei:common:2,wei:per_oc,wei:per_ocic:f16
--attr-zero-points=,wei:common:2,wei:per_oc,wei:per_ocic:s8
--attr-fpmath=f16:true
--batch=shapes_2d_ci

--stag=abc,acb
--dt=bf16:s8:bf16,bf16:u8:bf16
--attr-scales=wei:per_ocic:bf16:2x1
--attr-zero-points=,wei:per_ocic:u8:4x1
--attr-fpmath=bf16:true
1x5x12:1x12x17
3x5x12:3x12x17
3x5x12:1x12x17

### int4 wei decomp
--reset
--dt=bf16:s4:bf16,bf16:u4:bf16
--stag=abc,acb
--attr-scales=wei:common:2,wei:per_oc:bf16,wei:per_ocic:bf16:32x1
# There is no support for a common u4/s4 zero point
--attr-zero-points=wei:per_oc:u4,wei:per_ocic:s4,wei:per_ocic:s4:32x1
--attr-fpmath=bf16:true
7x6x32:7x32x64
7x6x32:1x32x64
3x6x96:3x96x64
3x6x96:1x96x64

# Test bf32, tf32 data type configuration
--reset
--skip-impl=ref,x64:gemm
--dt=f32
--attr-fpmath=bf16,tf32
77x133:133x117
15x24x16:15x16x32
7x16x24x8:7x16x8x24
--skip-impl=

# test all the supported data type configurations + bias data types
--reset
--dt=f64,f32
--bia_dt=undef,f32
--bia_mask=2,3  77x133:133x117
--bia_mask=4,6  15x24x16:15x16x32
--bia_mask=8,12 7x16x24x8:7x16x8x24

--dt=bf16,bf16:bf16:f32
--bia_dt=undef,f32,bf16
--bia_mask=2,3  77x133:133x117
--bia_mask=4,6  15x24x16:15x16x32
--bia_mask=8,12 7x16x24x8:7x16x8x24

--dt=f16,f16:f16:f32
--bia_dt=undef,f32,f16
--bia_mask=2,3  77x133:133x117
--bia_mask=4,6  15x24x16:15x16x32
--bia_mask=8,12 7x16x24x8:7x16x8x24

--dt=f8_e5m2,f8_e5m2:f8_e5m2:f32
--bia_dt=undef,f32,f8_e5m2
--bia_mask=2,3  77x133:133x117
--bia_mask=4,6  15x24x16:15x16x32
--bia_mask=8,12 7x16x24x8:7x16x8x24

--dt=f8_e4m3,f8_e4m3:f8_e4m3:f32
--bia_dt=undef,f32,f8_e4m3
--bia_mask=2,3  77x133:133x117
--bia_mask=4,6  15x24x16:15x16x32
--bia_mask=8,12 7x16x24x8:7x16x8x24

--dt=u8:s8:f32,u8:s8:s32,u8:s8:s8,u8:s8:u8,\
     s8:s8:f32,s8:s8:s32,s8:s8:s8,s8:s8:u8
--bia_dt=undef,f32,u8,s8,s32
--bia_mask=2,3  77x133:133x117
--bia_mask=4,6  15x24x16:15x16x32
--bia_mask=8,12 7x16x24x8:7x16x8x24

--dt=u8:s8:bf16,u8:s8:f16,\
     s8:s8:bf16,s8:s8:f16
--bia_dt=undef,f32
--bia_mask=2,3  77x133:133x117
--bia_mask=4,6  15x24x16:15x16x32
--bia_mask=8,12 7x16x24x8:7x16x8x24

--reset
--stag=abc --wtag=abc --dtag=abc
--bia_dt=f32 --bia_mask=0,1,2,3,4
2x32x64:1x64x32

# Basic post-ops with runtime dims 2D.
--reset
--dt=f32,s8:s8:s32,bf16
--stag=ab --wtag=ab --dtag=ab
--runtime_dims_masks=,1:0,0:2,2:1,1:2,3:1,2:3,3:3
--attr-post-ops=mul:f32,relu,sum,prelu,prelu:per_oc
3x20:20x4n"postops+runtime_dims_2d"

# Basic post-ops with runtime dims 3D.
--reset
--dt=f32,s8:s8:s32,bf16
--stag=abc --wtag=abc --dtag=abc
--runtime_dims_masks=,2:0,0:4,4:2,2:4,6:2,4:6,6:6
--attr-post-ops=mul:f32,relu,sum,prelu,prelu:per_oc
10x3x20:10x20x4n"postops+runtime_dims_3d"

# Basic post-ops with runtime dims 4D.
--reset
--dt=f32,s8:s8:s32,bf16
--stag=abcd --wtag=abcd --dtag=abcd
--runtime_dims_masks=,4:0,0:8,8:4,4:8,12:4,8:12,12:12
--attr-post-ops=mul:f32,relu,sum,prelu,prelu:per_oc
2x10x3x20:2x10x20x4n"postops+runtime_dims_4d"
