# BERT bfloat16 training problems for parameter batch = 56, seq_len = 512
# including all the relevant post-ops and data type propagation
#
# In total, there are 24 identical Encoder fragments in the topology:
#         ____|____ -----.
#        /    |    \     :
#      MM_1  MM_2  MM_3  :
#       |     |   /      :
#       |    MM_4 -------:
#        \   /           :
#         MM_5           :
#           |            :
#         MM_6 ----------`
#           |
#     Layer_norm ---.
#           |       :
#         MM_7      :
#           |       :
#         MM_8 -----`
#           |
#     Layer_norm
#
#
# Plus training-specific part
#

# =============================================================================
# FWD
# =============================================================================

# -----------------------------------------------------------------------------
# Encoder:
#   2d problems - M = seq_len * batch
#   4d problems - B = batch x 16, M = seq_len
# -----------------------------------------------------------------------------
# FWD 2D projection, M = 28672 (seq_len * batch), K = N = 1024, with bf16 bias.
# The same shape also stands in for MM_2, MM_3 and MM_6, hence the "*96"
# multiplier in the name. MM_6 by default additionally carries a binary
# post-op, which is not applied to this problem.
--reset --skip-impl=ref
--dt=bf16
--stag=ab --wtag=ab --dtag=ab
--bia_dt=bf16 --bia_mask=2
28672x1024:1024x1024n"BERT:FWD,Encoder_MM_1*96"

# FWD 4D batched problem: B = 56 x 16 (batch x 16), M = N = seq_len = 512,
# K = 64. The spatial size must match the seq_len used by the 2D problems
# (28672 = 56 * 512). Weights use the abdc tag (last two dims transposed).
--reset
--skip-impl=ref
--dt=bf16 --stag=abcd --wtag=abdc --dtag=abcd
# Optional binary post-op, disabled by default:
#--attr-post-ops=add:bf16:13
56x16x512x64:56x16x64x512n"BERT:FWD,Encoder_MM_4*24"

# FWD 4D batched problem: B = 56 x 16 (batch x 16), M = K = seq_len = 512,
# N = 64. The spatial size must match the seq_len used by the 2D problems
# (28672 = 56 * 512). All tensors are in plain abcd layout.
--reset
--skip-impl=ref
--dt=bf16 --stag=abcd --wtag=abcd --dtag=abcd
56x16x512x512:56x16x512x64n"BERT:FWD,Encoder_MM_5*24"

#--reset
#--skip-impl=ref
#--dt=bf16 --stag=ab --wtag=ab --dtag=ab
#--bia_dt=bf16 --bia_mask=2
#--attr-post-ops=add:bf16:per_tensor
#28672x1024:1024x1024n"BERT:FWD,Encoder_MM_6"

# FWD 2D problem expanding 1024 -> 4096 features, M = 28672, with bf16 bias.
# One instance per encoder fragment ("*24").
--reset --skip-impl=ref
--dt=bf16
--stag=ab --wtag=ab --dtag=ab
--bia_dt=bf16 --bia_mask=2
28672x1024:1024x4096n"BERT:FWD,Encoder_MM_7*24"

# FWD 2D problem reducing 4096 -> 1024 features, M = 28672, with bf16 bias
# and a per-tensor bf16 binary add post-op. One instance per encoder
# fragment ("*24").
--reset --skip-impl=ref
--dt=bf16
--stag=ab --wtag=ab --dtag=ab
--bia_dt=bf16 --bia_mask=2
--attr-post-ops=add:bf16:per_tensor
28672x4096:4096x1024n"BERT:FWD,Encoder_MM_8*24"

# -----------------------------------------------------------------------------
# Training-specific part:
#    Masked -  M = batch * num_mask_tokens, num_mask_tokens = 20
#    Pooler & prediction - M = batch
#    Embedding - M = seq_len * batch
# -----------------------------------------------------------------------------
# Settings shared by all FWD training-specific problems: bf16 data with
# bf16 bias.
--reset --skip-impl=ref
--dt=bf16
--bia_dt=bf16 --bia_mask=2

# Masked_1: plain weights.
--stag=ab --wtag=ab --dtag=ab
1120x1024:1024x1024n"BERT-L:FWD,Masked_1*1"

# Masked_2: weights use the transposed (ba) tag.
--stag=ab --wtag=ba --dtag=ab
1120x1024:1024x30522n"BERT-L:FWD,Masked_2*1"

# Pooler, prediction and embedding: back to plain weights.
--stag=ab --wtag=ab --dtag=ab
56x1024:1024x1024n"BERT-L:FWD,Pooler*1"
56x1024:1024x2n"BERT-L:FWD,Prediction*1"
28672x2:2x1024n"BERT-L:FWD,Embedding*1"

# =============================================================================
# BWD/D
# =============================================================================
# -----------------------------------------------------------------------------
# Encoder:
#   2d problems - M = seq_len * batch
#   4d problems - B = batch x 16, M = seq_len
# -----------------------------------------------------------------------------

# BWD_D settings: bf16 everywhere. The reset drops the bias and post-op
# settings left over from the FWD problems.
--reset --skip-impl=ref
--dt=bf16

# A - plain, B - transformed.
# The same shape also stands in for MM_2, MM_3 and MM_6 ("*96").
--stag=ab --wtag=ba --dtag=ab
28672x1024:1024x1024n"BERT:BWD_D,Encoder_MM_1*96"

# BWD_D 4D batched problems: B = 56 x 16 (batch x 16). The spatial size is
# seq_len = 512, matching the 2D problems (28672 = 56 * 512).

# Gradient w.r.t. MM_4 input A.
--stag=abcd --wtag=abcd --dtag=abcd # A - plain, B - plain
56x16x512x512:56x16x512x64n"BERT:BWD_D,Encoder_MM_4_A*24"

# Gradient w.r.t. MM_4 input B.
--stag=abdc --wtag=abcd --dtag=abcd # A - transformed, B - plain
# MM_5 B gradient is the same, hence "*48"
56x16x512x512:56x16x512x64n"BERT:BWD_D,Encoder_MM_4_B*48"

# Gradient w.r.t. MM_5 input A.
--stag=abcd --wtag=abdc --dtag=abcd # A - plain, B - transformed
56x16x512x64:56x16x64x512n"BERT:BWD_D,Encoder_MM_5_A*24"

# BWD_D for the two feed-forward matmuls. A - plain, B - transformed.
# K and N are swapped relative to the matching FWD problems.
--stag=ab
--wtag=ba
--dtag=ab
28672x4096:4096x1024n"BERT:BWD_D,Encoder_MM_7*24"
28672x1024:1024x4096n"BERT:BWD_D,Encoder_MM_8*24"

# -----------------------------------------------------------------------------
# Training-specific part:
#    Masked -  M = batch * num_mask_tokens, num_mask_tokens = 20
#    Pooler & prediction - M = batch
#    Embedding - M = seq_len * batch
# -----------------------------------------------------------------------------
# Source and destination stay plain (ab) for every problem below; only the
# weights tag changes.
--stag=ab --dtag=ab

# Masked_1: transposed weights.
--wtag=ba
1120x1024:1024x1024n"BERT-L:BWD_D,Masked_1*1"

# Masked_2: plain weights.
--wtag=ab
1120x30522:30522x1024n"BERT-L:BWD_D,Masked_2*1"

# Pooler, prediction and embedding: transposed weights.
--wtag=ba
56x1024:1024x1024n"BERT-L:BWD_D,Pooler*1"
56x2:2x1024n"BERT-L:BWD_D,Prediction*1"
28672x1024:1024x2n"BERT-L:BWD_D,Embedding*1"

# =============================================================================
# BWD/W
# =============================================================================
# -----------------------------------------------------------------------------
# Encoder:
#   2d problems - K = seq_len * batch
# -----------------------------------------------------------------------------

# BWD_W: bf16 inputs with f32 destination. There is no --reset here, so
# --skip-impl=ref from the BWD_D section stays in effect.
--dt=bf16:bf16:f32

# A - transformed, B - plain. The reduction dimension is
# K = 28672 (seq_len * batch).
--stag=ba
--wtag=ab
--dtag=ab

# MM_1 also stands in for MM_2, MM_3 and MM_6 ("*96").
1024x28672:28672x1024n"BERT:BWD_W,Encoder_MM_1*96"
1024x28672:28672x4096n"BERT:BWD_W,Encoder_MM_7*24"
4096x28672:28672x1024n"BERT:BWD_W,Encoder_MM_8*24"

# -----------------------------------------------------------------------------
# Training-specific part:
#    Masked -  K = batch * num_mask_tokens, num_mask_tokens = 20
#    Pooler & prediction - K = batch
#    Embedding - K = seq_len * batch
# -----------------------------------------------------------------------------
# Same data types and layouts as the BWD_W encoder problems, restated here
# so the group reads on its own: A - transformed, B - plain, f32 destination.
--dt=bf16:bf16:f32
--stag=ba --wtag=ab --dtag=ab

# NOTE(review): there is no Masked_2 entry in this group, although FWD and
# BWD_D both have one -- confirm this is intentional.
1024x1120:1120x1024n"BERT-L:BWD_W,Masked_1*1"
1024x56:56x1024n"BERT-L:BWD_W,Pooler*1"
1024x56:56x2n"BERT-L:BWD_W,Prediction*1"
2x28672:28672x1024n"BERT-L:BWD_W,Embedding*1"
