/*
 * Copyright 2010-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license
 * that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlin.idea.intentions

import com.intellij.codeInsight.intention.LowPriorityAction
import com.intellij.openapi.editor.Editor
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.descriptors.impl.AnonymousFunctionDescriptor
import org.jetbrains.kotlin.idea.caches.resolve.analyze
import org.jetbrains.kotlin.idea.core.ShortenReferences
import org.jetbrains.kotlin.idea.core.moveInsideParentheses
import org.jetbrains.kotlin.idea.core.replaced
import org.jetbrains.kotlin.idea.intentions.branchedTransformations.BranchedFoldingUtils
import org.jetbrains.kotlin.idea.search.usagesSearch.descriptor
import org.jetbrains.kotlin.idea.util.IdeDescriptorRenderers
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.psi.psiUtil.collectDescendantsOfType
import org.jetbrains.kotlin.psi.psiUtil.endOffset
import org.jetbrains.kotlin.psi.psiUtil.getStrictParentOfType
import org.jetbrains.kotlin.resolve.bindingContextUtil.getTargetFunctionDescriptor
import org.jetbrains.kotlin.resolve.lazy.BodyResolveMode
import org.jetbrains.kotlin.types.isFlexible
import org.jetbrains.kotlin.types.typeUtil.isUnit
import org.jetbrains.kotlin.types.typeUtil.makeNotNullable

/**
 * Intention that rewrites a lambda passed as a call argument into an equivalent anonymous function.
 *
 * It is offered only when:
 *  - the lambda is nested inside a [KtValueArgument];
 *  - the nearest enclosing [KtFunction] does not carry the `inline` modifier;
 *  - the lambda resolves to an [AnonymousFunctionDescriptor] none of whose parameters has a special name;
 *  - the caret is on the lambda header, i.e. at or before the end of `->` (or of `{` when there is no arrow).
 */
class LambdaToAnonymousFunctionIntention : SelfTargetingIntention<KtLambdaExpression>(
    KtLambdaExpression::class.java,
    "Convert to anonymous function",
    "Convert lambda expression to anonymous function"
), LowPriorityAction {
    override fun isApplicableTo(element: KtLambdaExpression, caretOffset: Int): Boolean {
        val insideArgument = element.getStrictParentOfType<KtValueArgument>() != null
        if (!insideArgument) return false

        val enclosingFunction = element.getStrictParentOfType<KtFunction>()
        val insideInlineFunction = enclosingFunction != null && enclosingFunction.hasModifier(KtTokens.INLINE_KEYWORD)
        if (insideInlineFunction) return false

        val literal = element.functionLiteral
        val literalDescriptor = literal.descriptor as? AnonymousFunctionDescriptor ?: return false
        // convertLambdaToFunction writes each parameter as `name.asString()`; special names
        // (presumably compiler-generated ones) would not produce a valid parameter declaration.
        if (literalDescriptor.valueParameters.any { parameter -> parameter.name.isSpecial }) return false

        // Only offer the intention while the caret is on the lambda header.
        val headerEnd = (literal.arrow ?: literal.lBrace).endOffset
        return caretOffset <= headerEnd
    }

    override fun applyTo(element: KtLambdaExpression, editor: Editor?) {
        val literalDescriptor = element.functionLiteral.descriptor as? AnonymousFunctionDescriptor ?: return
        val converted = convertLambdaToFunction(element, literalDescriptor) ?: return

        // When the lambda was a trailing argument, the new function now sits in a KtLambdaArgument;
        // move it inside the call's parentheses.
        val trailingArgument = converted.parent as? KtLambdaArgument ?: return
        trailingArgument.moveInsideParentheses(trailingArgument.analyze(BodyResolveMode.PARTIAL))
    }

    companion object {
        private val sourceTypeRenderer = IdeDescriptorRenderers.SOURCE_CODE_TYPES

        /**
         * Replaces [lambda] with a function declaration generated from [functionDescriptor].
         *
         * The lambda body is edited in place before its text is copied:
         *  - labels are removed from `return@…` expressions whose target is [functionDescriptor] or the
         *    lambda's own descriptor;
         *  - for a non-`Unit` return type, the last statement is prefixed with `return` unless it already is a
         *    return expression or [BranchedFoldingUtils.getFoldableReturns] reports foldable returns for it.
         *
         * Flexible parameter types are emitted in their not-null form. A `Unit` or unknown (`null`) return type
         * produces no return-type clause.
         *
         * @param lambda the lambda expression to convert.
         * @param functionDescriptor descriptor supplying the receiver, parameters and return type.
         * @param functionName name of the generated function; the default empty name is what [applyTo] passes.
         * @param replaceElement inserts the generated function into the tree; by default it replaces [lambda].
         * @return the element returned by [replaceElement], after reference shortening, or `null` if the
         *   lambda has no body.
         */
        fun convertLambdaToFunction(
            lambda: KtLambdaExpression,
            functionDescriptor: FunctionDescriptor,
            functionName: String = "",
            replaceElement: (KtNamedFunction) -> KtExpression = { lambda.replaced(it) }
        ): KtExpression? {
            val literal = lambda.functionLiteral
            val body = literal.bodyExpression ?: return null

            // Inside an anonymous function an unlabeled `return` already exits that function,
            // so labels pointing back at the lambda are dropped.
            val bindingContext = body.analyze(BodyResolveMode.PARTIAL)
            // Resolved lazily: only needed when the first comparison does not match.
            val literalDescriptor by lazy { literal.descriptor }
            for (returnExpression in body.collectDescendantsOfType<KtReturnExpression>()) {
                val target = returnExpression.getTargetFunctionDescriptor(bindingContext)
                if (target == functionDescriptor || target == literalDescriptor) {
                    returnExpression.labeledExpression?.delete()
                }
            }

            val factory = KtPsiFactory(lambda)
            val builder = KtPsiFactory.CallableBuilder(KtPsiFactory.CallableBuilder.Target.FUNCTION)
            builder.typeParams()
            val receiverType = functionDescriptor.extensionReceiverParameter?.type
            if (receiverType != null) {
                builder.receiver(sourceTypeRenderer.renderType(receiverType))
            }
            builder.name(functionName)
            functionDescriptor.valueParameters.forEach { parameter ->
                // Flexible (platform) types are written out without nullability.
                val declaredType = parameter.type
                val writtenType = if (declaredType.isFlexible()) declaredType.makeNotNullable() else declaredType
                builder.param(parameter.name.asString(), sourceTypeRenderer.renderType(writtenType))
            }

            val resultType = functionDescriptor.returnType
            if (resultType == null || resultType.isUnit()) {
                builder.noReturnType()
            } else {
                // A lambda yields its last expression implicitly; a block-bodied function needs an explicit `return`.
                val lastStatement = body.statements.lastOrNull()
                if (lastStatement != null && lastStatement !is KtReturnExpression) {
                    val foldableReturns = BranchedFoldingUtils.getFoldableReturns(lastStatement)
                    // NOTE(review): presumably a non-empty result means the statement's branches already return.
                    val hasFoldableReturns = foldableReturns != null && foldableReturns.isNotEmpty()
                    if (!hasFoldableReturns) {
                        lastStatement.replace(factory.createExpressionByPattern("return $0", lastStatement))
                    }
                }
                builder.returnType(sourceTypeRenderer.renderType(resultType))
            }
            builder.blockBody(" ${body.text}")

            val replacement = replaceElement(factory.createFunction(builder.asString()))
            ShortenReferences.DEFAULT.process(replacement)
            return replacement
        }
    }
}
