package connector

import (
	"bytes"
	"context"
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"net"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/fleeting/fleeting/provider"
	"golang.org/x/crypto/ssh"
)

// TestConnectionFailed verifies that DialSSH fails, rather than hanging,
// when the remote end accepts TCP connections but never speaks SSH.
// Nothing ever calls Accept on the listener, so no SSH server answers.
func TestConnectionFailed(t *testing.T) {
	t.Parallel()

	// Bind to loopback only: a wildcard bind (":0") listens on every
	// interface, which this test does not need and which some hosts'
	// firewalls prompt about.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	// The subtests below deliberately do not call t.Parallel, so they have
	// finished by the time this deferred Close runs.
	defer listener.Close()

	t.Run("reach deadline", func(t *testing.T) {
		// Only the connector timeout bounds the dial here, so it is expected
		// to fail while reading the server's side of the SSH handshake.
		_, err := DialSSH(context.Background(), provider.ConnectInfo{
			ConnectorConfig: provider.ConnectorConfig{
				Timeout: time.Second,
			},
			InternalAddr: listener.Addr().String(),
		}, DialOptions{})

		require.ErrorContains(t, err, "ssh: handshake failed: read tcp")
	})

	t.Run("context deadline exceeded", func(t *testing.T) {
		// The context expires long before the one-hour connector timeout,
		// so the context's deadline error is what DialSSH must report.
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		_, err := DialSSH(ctx, provider.ConnectInfo{
			ConnectorConfig: provider.ConnectorConfig{
				Timeout: time.Hour,
			},
			InternalAddr: listener.Addr().String(),
		}, DialOptions{})

		require.ErrorIs(t, err, context.DeadlineExceeded)
	})
}

func TestSSHServerConnection(t *testing.T) {
	// Generate a test host key
	hostPrivateKeySigner, err := generateHostKey()
	require.NoError(t, err)

	// Generate a test user key pair
	userPrivateKey, userPublicKey, err := generateUserKeyPair()
	require.NoError(t, err)

	// Set up SSH server config
	serverConfig := &ssh.ServerConfig{
		PublicKeyCallback: func(conn ssh.ConnMetadata, auth ssh.PublicKey) (*ssh.Permissions, error) {
			if bytes.Equal(auth.Marshal(), userPublicKey.Marshal()) {
				return &ssh.Permissions{}, nil
			}
			return nil, fmt.Errorf("unknown public key for %q", conn.User())
		},
	}
	serverConfig.AddHostKey(hostPrivateKeySigner)

	// Start SSH server
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()

	serverDone := make(chan struct{})
	go func() {
		defer close(serverDone)
		conn, err := listener.Accept()
		if err != nil {
			return
		}
		defer conn.Close()

		// Perform SSH handshake
		sshConn, chans, reqs, err := ssh.NewServerConn(conn, serverConfig)
		if err != nil {
			return
		}
		defer sshConn.Close()

		// Handle incoming channels and requests
		go ssh.DiscardRequests(reqs)
		go func() {
			for newChannel := range chans {
				newChannel.Reject(ssh.UnknownChannelType, "test server")
			}
		}()

		// Keep connection alive briefly for test
		time.Sleep(100 * time.Millisecond)
	}()

	t.Parallel()

	t.Run("successful connection to SSH server", func(t *testing.T) {
		keyBytes := encodePrivateKey(userPrivateKey)

		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		client, err := DialSSH(ctx, provider.ConnectInfo{
			ConnectorConfig: provider.ConnectorConfig{
				Timeout:  5 * time.Second,
				Username: "testuser",
				Key:      keyBytes,
			},
			InternalAddr: listener.Addr().String(),
		}, DialOptions{})

		require.NoError(t, err)
		require.NotNil(t, client)

		// Verify we can close the connection
		err = client.Close()
		require.NoError(t, err)
	})

	// Wait for server goroutine to finish
	select {
	case <-serverDone:
	case <-time.After(15 * time.Second):
		t.Fatal("server did not finish in time")
	}
}

// generateHostKey creates a fresh 2048-bit RSA private key and wraps it in
// an ssh.Signer, suitable for registering as the test SSH server's host key.
func generateHostKey() (ssh.Signer, error) {
	rsaKey, genErr := rsa.GenerateKey(rand.Reader, 2048)
	if genErr != nil {
		return nil, genErr
	}

	signer, signErr := ssh.NewSignerFromKey(rsaKey)
	return signer, signErr
}

// generateUserKeyPair creates a 2048-bit RSA key for client authentication.
// It returns the private key (an *rsa.PrivateKey held in a crypto.PrivateKey)
// together with its public half converted to an ssh.PublicKey.
func generateUserKeyPair() (crypto.PrivateKey, ssh.PublicKey, error) {
	rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, err
	}

	sshPub, err := ssh.NewPublicKey(rsaKey.Public())
	if err != nil {
		return nil, nil, err
	}
	return rsaKey, sshPub, nil
}

// encodePrivateKey PEM-encodes privateKey as a PKCS #1 "RSA PRIVATE KEY"
// block, the form these tests place in ConnectorConfig.Key.
// It returns nil when privateKey is not an *rsa.PrivateKey.
func encodePrivateKey(privateKey crypto.PrivateKey) []byte {
	switch k := privateKey.(type) {
	case *rsa.PrivateKey:
		return pem.EncodeToMemory(&pem.Block{
			Type:  "RSA PRIVATE KEY",
			Bytes: x509.MarshalPKCS1PrivateKey(k),
		})
	default:
		return nil
	}
}
