// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by consensys/gnark-crypto DO NOT EDIT

package poseidon2

import (
	"testing"

	fr "github.com/consensys/gnark-crypto/field/koalabear"
	"github.com/consensys/gnark-crypto/utils/cpu"
	"github.com/stretchr/testify/require"
)

// TestMulMulInternalInPlaceWidth16 checks matMulInternalInPlace for width 16
// against a direct computation: output[k] = state[k]*diag16[k] + Σ_j state[j],
// i.e. multiplication by (all-ones matrix + diag(diag16)).
func TestMulMulInternalInPlaceWidth16(t *testing.T) {
	var state [16]fr.Element
	for k := range state {
		state[k].MustSetRandom()
	}

	// Arrays are values, so got is an independent copy of state.
	got := state

	perm := NewPermutation(16, 6, 21)
	perm.matMulInternalInPlace(got[:])

	// Sum of all input coordinates, shared by every output coordinate.
	var total fr.Element
	for k := 0; k < perm.params.Width; k++ {
		total.Add(&total, &state[k])
	}

	for k := 0; k < perm.params.Width; k++ {
		var want fr.Element
		want.Mul(&state[k], &diag16[k])
		want.Add(&want, &total)
		if !want.Equal(&got[k]) {
			t.Fatal("mat mul internal w/ diagonal doesn't match hand calculated")
		}
	}
}

// TestMulMulInternalInPlaceWidth24 checks matMulInternalInPlace for width 24
// against a direct computation: output[k] = state[k]*diag24[k] + Σ_j state[j],
// i.e. multiplication by (all-ones matrix + diag(diag24)).
func TestMulMulInternalInPlaceWidth24(t *testing.T) {
	var state [24]fr.Element
	for k := range state {
		state[k].MustSetRandom()
	}

	// Arrays are values, so got is an independent copy of state.
	got := state

	perm := NewPermutation(24, 6, 21)
	perm.matMulInternalInPlace(got[:])

	// Sum of all input coordinates, shared by every output coordinate.
	var total fr.Element
	for k := 0; k < perm.params.Width; k++ {
		total.Add(&total, &state[k])
	}

	for k := 0; k < perm.params.Width; k++ {
		var want fr.Element
		want.Mul(&state[k], &diag24[k])
		want.Add(&want, &total)
		if !want.Equal(&got[k]) {
			t.Fatal("mat mul internal w/ diagonal doesn't match hand calculated")
		}
	}
}

// TestAVX512Width16 cross-checks the width-16 permutation. The same random
// state is permuted once with the default code path (AVX-512 when the
// permutation enables it) and once after disableAVX512. Both outputs must
// agree element-wise. The test is skipped on CPUs without AVX-512.
func TestAVX512Width16(t *testing.T) {
	if !cpu.SupportAVX512 {
		t.Skip("AVX512 not supported")
	}
	assert := require.New(t)

	var fast [16]fr.Element
	for i := range fast {
		fast[i].MustSetRandom()
	}
	// Arrays are values: slow is an independent copy of the same input.
	slow := fast

	perm := NewPermutation(16, 6, 21)
	assert.NoError(perm.Permutation(fast[:]))

	perm.disableAVX512()
	assert.NoError(perm.Permutation(slow[:]))

	for i := 0; i < perm.params.Width; i++ {
		assert.True(fast[i].Equal(&slow[i]), "avx512 result don't match purego")
	}
}

// TestAVX512Width24 cross-checks the width-24 permutation. The same random
// state is permuted once with the default code path (AVX-512 when the
// permutation enables it) and once after disableAVX512. Both outputs must
// agree element-wise. The test is skipped on CPUs without AVX-512.
func TestAVX512Width24(t *testing.T) {
	if !cpu.SupportAVX512 {
		t.Skip("AVX512 not supported")
	}
	assert := require.New(t)

	var fast [24]fr.Element
	for i := range fast {
		fast[i].MustSetRandom()
	}
	// Arrays are values: slow is an independent copy of the same input.
	slow := fast

	perm := NewPermutation(24, 6, 21)
	assert.NoError(perm.Permutation(fast[:]))

	perm.disableAVX512()
	assert.NoError(perm.Permutation(slow[:]))

	for i := 0; i < perm.params.Width; i++ {
		assert.True(fast[i].Equal(&slow[i]), "avx512 result don't match purego")
	}
}

// TestAVX512Permutation16x24 cross-checks Permutation16x24 on a random
// [24][16] block. It compares the default code path against the one taken
// after disableAVX512. The test is skipped on CPUs without AVX-512.
func TestAVX512Permutation16x24(t *testing.T) {
	if !cpu.SupportAVX512 {
		t.Skip("AVX512 not supported")
	}
	assert := require.New(t)

	got := make([][16]fr.Element, 24)
	for r := range got {
		for c := range got[r] {
			got[r][c].MustSetRandom()
		}
	}

	// Each row is a [16]fr.Element array (a value type), so copy duplicates
	// every element rather than aliasing rows.
	want := make([][16]fr.Element, len(got))
	copy(want, got)

	perm := NewPermutation(24, 6, 21)
	perm.Permutation16x24((*[24][16]fr.Element)(got))

	perm.disableAVX512()
	perm.Permutation16x24((*[24][16]fr.Element)(want))

	for r := range got {
		for c := range got[r] {
			assert.True(got[r][c].Equal(&want[r][c]), "avx512 result don't match purego")
		}
	}
}

// BenchmarkPermutation16x24 measures Permutation16x24 on a random
// [24][16] block. It uses the default code path, since disableAVX512 is not
// called. The block is permuted in place, so each iteration starts from the
// previous output.
func BenchmarkPermutation16x24(b *testing.B) {
	rows := make([][16]fr.Element, 24)
	for r := range rows {
		for c := range rows[r] {
			rows[r][c].MustSetRandom()
		}
	}
	// The slice-to-array-pointer view is loop invariant; build it once.
	state := (*[24][16]fr.Element)(rows)
	perm := NewPermutation(24, 6, 21)

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		perm.Permutation16x24(state)
	}
}

// disableAVX512 clears both fast-path flags (width 16 and width 24 with
// parameters 6/21) on h. Tests use it to run the same permutation through
// the non-AVX-512 code path and compare results.
func (h *Permutation) disableAVX512() {
	h.params.hasFast16_6_21, h.params.hasFast24_6_21 = false, false
}

// TestPoseidon2Width16 is a known-answer test for the width-16 permutation
// (NewPermutation(16, 6, 21)). A fixed input vector is permuted in place and
// must equal the hard-coded expected output.
func TestPoseidon2Width16(t *testing.T) {
	// these are random values generated by MustSetRandom()
	inputValues := [16]uint64{
		595602690, 847709907, 543464918, 2007411168,
		388763785, 1476043928, 1217186791, 1009172579,
		1702185369, 831063788, 1937176007, 1631695539,
		1955714534, 1387220004, 567062513, 331325971,
	}
	expectedValues := [16]uint64{
		1693177489, 50767021, 1825750786, 1570512031,
		874586144, 1526919721, 2107355180, 1922897603,
		1518961114, 141284986, 900994878, 115984755,
		756527509, 1386241908, 525644973, 1531957077,
	}

	var input, expected [16]fr.Element
	for i := range input {
		input[i].SetUint64(inputValues[i])
		expected[i].SetUint64(expectedValues[i])
	}

	h := NewPermutation(16, 6, 21)
	// Permutation returns an error; previously it was silently discarded.
	if err := h.Permutation(input[:]); err != nil {
		t.Fatal(err)
	}
	for i := 0; i < h.params.Width; i++ {
		if !input[i].Equal(&expected[i]) {
			t.Fatalf("mismatch error at index %d: got %s, want %s", i, input[i].String(), expected[i].String())
		}
	}
}

// TestPoseidon2Width24 is a known-answer test for the width-24 permutation
// (NewPermutation(24, 6, 21)). A fixed input vector is permuted in place and
// must equal the hard-coded expected output.
func TestPoseidon2Width24(t *testing.T) {
	// these are random values generated by MustSetRandom()
	inputValues := [24]uint64{
		568554527, 1037389773, 974985042, 693745454,
		445115978, 247489969, 1800921402, 380223487,
		1663707776, 542110938, 1156833323, 2007942824,
		2068171589, 386387355, 407453015, 806215973,
		141351644, 129559919, 1565876180, 257799181,
		1038008269, 1353553525, 410540253, 1602372302,
	}
	expectedValues := [24]uint64{
		1053460531, 1671312670, 214628630, 1942298267,
		60214972, 347747608, 1401560933, 1851418915,
		1066873794, 544902884, 2129748883, 329899943,
		696093037, 1845838180, 932334704, 1648959581,
		1988761311, 1694101983, 2032844528, 1961776557,
		1649176607, 1828834386, 352206058, 1826445122,
	}

	var input, expected [24]fr.Element
	for i := range input {
		input[i].SetUint64(inputValues[i])
		expected[i].SetUint64(expectedValues[i])
	}

	h := NewPermutation(24, 6, 21)
	// Permutation returns an error; previously it was silently discarded.
	if err := h.Permutation(input[:]); err != nil {
		t.Fatal(err)
	}
	for i := 0; i < h.params.Width; i++ {
		if !input[i].Equal(&expected[i]) {
			t.Fatalf("mismatch error at index %d: got %s, want %s", i, input[i].String(), expected[i].String())
		}
	}
}

// BenchmarkPoseidon2Width16 measures one width-16 permutation per iteration.
// The state is permuted in place, so each iteration starts from the previous
// iteration's output.
func BenchmarkPoseidon2Width16(b *testing.B) {
	h := NewPermutation(16, 6, 21)

	var tmp [16]fr.Element
	for i := range tmp {
		tmp[i].MustSetRandom()
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// A failing permutation would make the timing meaningless, so stop.
		if err := h.Permutation(tmp[:]); err != nil {
			b.Fatal(err)
		}
	}
}

// BenchmarkPoseidon2Width24 measures one width-24 permutation per iteration.
// The state is permuted in place, so each iteration starts from the previous
// iteration's output.
func BenchmarkPoseidon2Width24(b *testing.B) {
	h := NewPermutation(24, 6, 21)

	var tmp [24]fr.Element
	for i := range tmp {
		tmp[i].MustSetRandom()
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// A failing permutation would make the timing meaningless, so stop.
		if err := h.Permutation(tmp[:]); err != nil {
			b.Fatal(err)
		}
	}
}
