/*
 * Copyright (c) Forge Development LLC
 * SPDX-License-Identifier: LGPL-2.1-only
 */
package net.minecraftforge.testing.aggregate;

import org.codehaus.groovy.runtime.StringGroovyMethods;
import org.gradle.api.Project;
import org.gradle.api.file.Directory;
import org.gradle.api.file.ProjectLayout;
import org.gradle.api.model.ObjectFactory;
import org.gradle.api.plugins.jvm.JvmTestSuite;
import org.gradle.api.provider.Provider;
import org.gradle.api.tasks.TaskProvider;
import org.gradle.api.tasks.testing.Test;
import org.gradle.jvm.toolchain.JavaLanguageVersion;
import org.gradle.jvm.toolchain.JavaToolchainService;
import org.gradle.jvm.toolchain.JvmImplementation;
import org.gradle.jvm.toolchain.JvmVendorSpec;
import org.gradle.testing.base.TestingExtension;
import org.jetbrains.annotations.UnmodifiableView;

import javax.inject.Inject;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * Default implementation of {@link AggregateTestExtensionInternal}.
 * <p>
 * Once the project has been evaluated, each parent {@link Test} task gets extra tasks. A parent is every target test
 * task of every {@link JvmTestSuite} when the {@code jvm-test-suite} plugin is applied, and otherwise every
 * {@link Test} task. For each parent, this registers one {@link Test} task per configured vendor/version pair, which
 * runs the parent's tests on a matching toolchain. It also registers an {@code aggregate<ParentName>} task that
 * depends on all of them.
 */
abstract class AggregateTestExtensionImpl implements AggregateTestExtensionInternal {
    /** Vendor name mapped to the Java language versions to test with; kept sorted by vendor name. */
    private final Map<String, List<Integer>> jvms = new TreeMap<>(Comparator.naturalOrder());

    private final AggregateTestProblems problems = this.getObjects().newInstance(AggregateTestProblems.class);

    protected abstract @Inject ObjectFactory getObjects();

    protected abstract @Inject ProjectLayout getLayout();

    protected abstract @Inject JavaToolchainService getJavaToolchains();

    @Inject
    public AggregateTestExtensionImpl(Project project) {
        // Tasks are created only after evaluation so that the build script has already populated the JVM map.
        project.afterEvaluate(this::finish);
    }

    /** Returns a read-only view of the configured vendor-to-versions map. */
    @Override
    public @UnmodifiableView Map<String, List<Integer>> getJvms() {
        return Collections.unmodifiableMap(this.jvms);
    }

    /**
     * Adds the given entries, replacing any existing entry with the same vendor name.
     * Keys are converted to strings with {@code toString()}.
     */
    @Override
    public void jvms(Map<? extends CharSequence, List<Integer>> jvms) {
        for (var entry : jvms.entrySet()) {
            this.jvms.put(entry.getKey().toString(), entry.getValue());
        }
    }

    /** Replaces all configured entries with the given ones. */
    @Override
    public void setJvms(Map<? extends CharSequence, List<Integer>> jvms) {
        this.jvms.clear();
        this.jvms(jvms);
    }

    /** Registers the per-JVM and aggregate tasks for every parent test task of the project. */
    private void finish(Project project) {
        // Nothing to register; also avoids realizing test tasks for no reason.
        if (this.jvms.isEmpty()) return;

        if (project.getPluginManager().hasPlugin("jvm-test-suite")) {
            var testing = project.getExtensions().getByType(TestingExtension.class);
            var suites = testing.getSuites().withType(JvmTestSuite.class);

            for (var suite : suites) {
                for (var target : suite.getTargets()) {
                    this.configure(project, target.getTestTask().get());
                }
            }
        } else {
            // Snapshot first: configure() registers new Test tasks into this same container. Iterating it live
            // could modify it mid-iteration, or feed the freshly registered per-JVM tasks back into configure().
            List<Test> tests = List.copyOf(project.getTasks().withType(Test.class));
            for (var test : tests) {
                this.configure(project, test);
            }
        }
    }

    /**
     * Registers the {@code aggregate<TestName>} task for {@code test}, plus one per-JVM task per configured
     * vendor/version pair. The aggregate task depends on each per-JVM task and declares its report
     * directory as an input.
     */
    private void configure(Project project, Test test) {
        TaskProvider<? extends AggregateTest> testAll = project.getTasks().register("aggregate" + StringGroovyMethods.capitalize(test.getName()), AggregateTestImpl.class);

        for (var entry : this.jvms.entrySet()) {
            var vendor = entry.getKey();
            var versions = entry.getValue();

            for (int version : versions) {
                // NOTE(review): this directory depends only on vendor/version. Several parent test tasks therefore
                // share it. Confirm how AggregateTestImpl reads it before adding the parent task name here.
                var output = this.getLayout().getBuildDirectory().dir(AggregateTest.TEST_RESULTS_DIRECTORY + "/%s-%d".formatted(vendor, version)).map(this.problems.ensureFileLocation());
                var task = this.register(project, test, vendor, version, output);

                testAll.configure(testAllTask -> {
                    testAllTask.getInputs().dir(output);
                    testAllTask.dependsOn(task);
                });
            }
        }
    }

    /**
     * Registers a {@link Test} task that runs the tests of {@code test} with a toolchain matching
     * {@code vendor} and {@code version}. Its HTML and JUnit XML reports are written to {@code output}.
     */
    private TaskProvider<Test> register(Project project, Test test, String vendor, int version, Provider<Directory> output) {
        var vendorSpec = JvmVendorSpec.of(vendor);

        // Prefix with the parent task name so several parent test tasks don't register the same task name.
        // For the parent task named 'test', this yields 'testUsing<Vendor><Version>'.
        var taskName = test.getName() + "Using" + StringGroovyMethods.capitalize(vendor) + version;

        return project.getTasks().register(taskName, Test.class, task -> {
            task.setGroup("Aggregate Testing");
            var description = test.getDescription();
            task.setDescription(description != null
                ? "%s (using %s %s)".formatted(description, vendorSpec, version)
                : "Runs '%s' using %s %s".formatted(test.getName(), vendorSpec, version));
            task.notCompatibleWithConfigurationCache(
                "JVM-specific tests need the test framework from the parent test task, which cannot be serialized."
            );

            // Reuse the parent task's framework, classpath and compiled test classes; only the launcher differs.
            task.getTestFrameworkProperty().set(test.getTestFrameworkProperty());
            task.setClasspath(test.getClasspath());
            task.setTestClassesDirs(test.getTestClassesDirs());
            task.getJavaLauncher().set(this.getJavaToolchains().launcherFor(spec -> {
                spec.getVendor().set(vendorSpec);
                spec.getLanguageVersion().set(JavaLanguageVersion.of(version));
                spec.getImplementation().set(JvmImplementation.VENDOR_SPECIFIC);
            }));
            task.reports(reports -> {
                reports.getHtml().getOutputLocation().set(output);
                reports.getJunitXml().getOutputLocation().set(output);
            });
        });
    }
}
