/*
 * Decompiled with CFR 0.152.
 */
package net.neoforged.fml.classloading;

import java.io.IOException;
import java.io.InputStream;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.module.Configuration;
import java.lang.module.ModuleDescriptor;
import java.lang.module.ModuleReader;
import java.lang.module.ModuleReference;
import java.lang.module.ResolvedModule;
import java.lang.reflect.Field;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.security.AllPermission;
import java.security.CodeSigner;
import java.security.CodeSource;
import java.security.Permissions;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.Function;
import net.neoforged.fml.classloading.JarModuleFinder;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.VisibleForTesting;

/**
 * Class loader that defines the classes of every module in a {@link Configuration}; each module must be backed by a
 * {@link JarModuleFinder.JarModuleReference}.
 *
 * <p>Classes are resolved based on their package {@code p}:
 * <ol>
 * <li>if a module of the configuration contains {@code p}, the class is defined from that module;</li>
 * <li>otherwise, if a module read by one of the configuration's modules exports {@code p} to it (automatic modules
 * count as exporting every package), loading is delegated to that module's class loader;</li>
 * <li>otherwise loading is delegated to the fallback class loader (see {@link #setFallbackClassLoader}).</li>
 * </ol>
 * Resources are looked up in the local modules first and then in the fallback class loader.
 *
 * <p>Subclasses can rewrite class bytes before they are defined by overriding {@link #maybeTransformClassBytes}.
 */
public class ModuleClassLoader
extends ClassLoader
implements AutoCloseable {
    /** Handle to the JDK-internal instance method {@code ModuleLayer.bindToLoader(ClassLoader)}. */
    private static final MethodHandle LAYER_BIND_TO_LOADER;
    /** Local modules, keyed by module name. */
    private final Map<String, ModuleInfo> moduleInfoCache;
    /** Local modules, keyed by each package they contain. */
    private final Map<String, ModuleInfo> packageLookup;
    /** Loaders to delegate to, keyed by packages exported by modules that the local modules read. */
    private final Map<String, ClassLoader> parentLoaders;
    private final Configuration configuration;
    /** Loader used for anything not found locally or in {@link #parentLoaders}. */
    private ClassLoader fallbackClassLoader;
    private volatile boolean closed = false;

    /**
     * Invokes the JDK-internal {@code layer.bindToLoader(classLoader)}.
     *
     * @throws RuntimeException wrapping anything thrown by the invocation
     */
    private static void bindToLayer(ModuleClassLoader classLoader, ModuleLayer layer) {
        try {
            // invokeExact requires the call site's static argument types to match the handle type
            // (ModuleLayer, ClassLoader)void exactly. Without the cast the call site would be typed
            // (ModuleLayer, ModuleClassLoader)void and fail with WrongMethodTypeException.
            LAYER_BIND_TO_LOADER.invokeExact(layer, (ClassLoader) classLoader);
        } catch (Throwable t) {
            throw new RuntimeException(t);
        }
    }

    /**
     * Creates a loader for {@code configuration} whose parent is the bootstrap loader and whose fallback is the
     * platform class loader.
     *
     * @param name          the class loader name
     * @param configuration the configuration whose modules this loader defines
     * @param parentLayers  layers providing the modules read by the configuration's modules; this loader is also
     *                      bound to each of these layers and their parent layers
     * @throws IllegalArgumentException if a module is not backed by a {@code JarModuleFinder.JarModuleReference}
     */
    public ModuleClassLoader(String name, Configuration configuration, List<ModuleLayer> parentLayers) {
        this(name, configuration, parentLayers, null);
    }

    /**
     * Creates a loader for {@code configuration}.
     *
     * @param name          the class loader name
     * @param configuration the configuration whose modules this loader defines
     * @param parentLayers  layers providing the modules read by the configuration's modules; this loader is also
     *                      bound to each of these layers and their parent layers
     * @param parentLoader  the parent passed to {@link ClassLoader}, also used as the fallback loader; when
     *                      {@code null} the platform class loader is the fallback
     * @throws IllegalArgumentException if a module is not backed by a {@code JarModuleFinder.JarModuleReference}
     */
    @VisibleForTesting
    public ModuleClassLoader(String name, Configuration configuration, List<ModuleLayer> parentLayers, @Nullable ClassLoader parentLoader) {
        super(name, parentLoader);
        this.configuration = configuration;
        this.fallbackClassLoader = Objects.requireNonNullElse(parentLoader, ClassLoader.getPlatformClassLoader());

        // Index local modules by name, and count their packages so packageLookup can be presized.
        this.moduleInfoCache = HashMap.newHashMap(configuration.modules().size());
        int packageCount = 0;
        for (ResolvedModule m : configuration.modules()) {
            ModuleReference moduleReference = m.reference();
            if (!(moduleReference instanceof JarModuleFinder.JarModuleReference jarRef)) {
                throw new IllegalArgumentException("Unsupported module reference type: " + String.valueOf(m.reference().getClass()));
            }
            ModuleDescriptor localDescriptor = moduleReference.descriptor();
            String moduleName = localDescriptor.name();
            this.moduleInfoCache.put(moduleName, new ModuleInfo(this, moduleName, jarRef));
            packageCount += localDescriptor.packages().size();
        }

        this.packageLookup = HashMap.newHashMap(packageCount);
        for (ModuleInfo moduleInfo : this.moduleInfoCache.values()) {
            for (String pkg : moduleInfo.moduleReference.descriptor().packages()) {
                this.packageLookup.put(pkg, moduleInfo);
            }
        }

        this.parentLoaders = new HashMap<>();
        Set<ModuleDescriptor> processedAutomaticDescriptors = new HashSet<>();
        Map<ResolvedModule, ClassLoader> classLoaderMap = new HashMap<>();
        // Local modules map to this loader. Other modules map to the loader reported by a parent layer with the
        // same configuration, or to the platform loader if none reports one.
        Function<ResolvedModule, ClassLoader> findClassLoader = k -> {
            if (this.moduleInfoCache.containsKey(k.name())) {
                return this;
            }
            for (ModuleLayer parentLayer : parentLayers) {
                if (parentLayer.configuration() != k.configuration()) {
                    continue;
                }
                ClassLoader loader = parentLayer.findLoader(k.name());
                if (loader != null) {
                    return loader;
                }
            }
            return ClassLoader.getPlatformClassLoader();
        };
        for (ResolvedModule rm : configuration.modules()) {
            for (ResolvedModule other : rm.reads()) {
                ClassLoader cl = classLoaderMap.computeIfAbsent(other, findClassLoader);
                ModuleDescriptor descriptor = other.reference().descriptor();
                if (descriptor.isAutomatic()) {
                    // Automatic modules have all their packages registered, and each descriptor is processed once.
                    if (processedAutomaticDescriptors.add(descriptor)) {
                        descriptor.packages().forEach(pn -> this.parentLoaders.put(pn, cl));
                    }
                    continue;
                }
                // Unqualified exports, plus qualified exports targeting rm from within this configuration.
                descriptor.exports().stream()
                        .filter(e -> !e.isQualified() || (other.configuration() == configuration && e.targets().contains(rm.name())))
                        .map(ModuleDescriptor.Exports::source)
                        .forEach(pn -> this.parentLoaders.put(pn, cl));
            }
        }

        Set<ModuleLayer> visitedLayers = new HashSet<>();
        parentLayers.forEach(p -> forLayerAndParents(p, visitedLayers, l -> bindToLayer(this, l)));
    }

    /**
     * Applies {@code operation} once to {@code layer} and to each of its transitive parent layers not already in
     * {@code visited}. The parents of the boot layer are not walked.
     */
    private static void forLayerAndParents(ModuleLayer layer, Set<ModuleLayer> visited, Consumer<ModuleLayer> operation) {
        if (!visited.add(layer)) {
            return;
        }
        operation.accept(layer);
        if (layer != ModuleLayer.boot()) {
            layer.parents().forEach(parent -> forLayerAndParents(parent, visited, operation));
        }
    }

    /** Returns the URL of resource {@code name} in {@code moduleInfo}, or {@code null} if the module lacks it. */
    private URL readerToURL(ModuleInfo moduleInfo, String name) throws IOException {
        ModuleReader reader = moduleInfo.getReader();
        return toURL(reader.find(name));
    }

    /**
     * Converts an optional URI to a URL.
     *
     * @return the URL, or {@code null} if {@code uri} is empty
     * @throws IllegalArgumentException if the URI cannot be converted to a URL
     */
    private static URL toURL(Optional<URI> uri) {
        if (uri.isEmpty()) {
            return null;
        }
        try {
            return uri.get().toURL();
        } catch (MalformedURLException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Reads the untransformed bytes of class {@code name} from {@code moduleInfo}.
     *
     * @return the class bytes, or an empty array if the module has no such class file
     */
    private static byte[] getClassBytes(ModuleInfo moduleInfo, String name) throws IOException {
        String classFile = name.replace('.', '/') + ".class";
        Optional<InputStream> stream = moduleInfo.getReader().open(classFile);
        if (stream.isEmpty()) {
            return new byte[0];
        }
        try (InputStream in = stream.get()) {
            return in.readAllBytes();
        }
    }

    /**
     * Reads, transforms and defines class {@code name} from {@code moduleInfo}.
     *
     * @return the defined class, or {@code null} if the (possibly transformed) bytes are empty
     * @throws ClassNotFoundException if reading the module fails
     */
    @Nullable
    private Class<?> readerToClass(ModuleInfo moduleInfo, String name) throws ClassNotFoundException {
        byte[] bytes;
        try {
            bytes = getClassBytes(moduleInfo, name);
        } catch (IOException e) {
            throw new ClassNotFoundException(name, e);
        }
        bytes = this.maybeTransformClassBytes(bytes, name, null);
        if (bytes.length == 0) {
            return null;
        }
        return this.defineClass(name, bytes, 0, bytes.length, moduleInfo.protectionDomain);
    }

    /**
     * Hook for subclasses to rewrite class bytes before they are defined or returned.
     *
     * @param bytes   the raw class bytes; empty when no class file was found
     * @param name    the binary class name
     * @param context caller-supplied context; {@code null} when called while defining a class
     * @return the bytes to use; an empty array means "no class". The default returns {@code bytes} unchanged.
     */
    protected byte[] maybeTransformClassBytes(byte[] bytes, String name, @Nullable String context) {
        return bytes;
    }

    /**
     * Loads class {@code name} under its class-loading lock.
     *
     * <p>A class whose package belongs to a local module is always defined from that module and never delegated.
     * Other packaged classes go to the loader registered in {@link #parentLoaders}, or to the fallback loader. Names
     * without a package always go to the fallback loader.
     */
    @Override
    protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
        synchronized (this.getClassLoadingLock(name)) {
            Class<?> c = this.findLoadedClass(name);
            if (c == null) {
                String packageName = packageName(name);
                if (packageName == null) {
                    c = this.fallbackClassLoader.loadClass(name);
                } else {
                    ModuleInfo localModule = this.packageLookup.get(packageName);
                    if (localModule != null) {
                        c = this.readerToClass(localModule, name);
                    } else {
                        c = this.parentLoaders.getOrDefault(packageName, this.fallbackClassLoader).loadClass(name);
                    }
                }
            }
            if (c == null) {
                throw new ClassNotFoundException(name);
            }
            if (resolve) {
                this.resolveClass(c);
            }
            return c;
        }
    }

    /**
     * Returns the first local match for {@code name}, else the fallback loader's result. Returns {@code null} if
     * searching the local modules fails with an {@link IOException}.
     */
    @Override
    public URL getResource(String name) {
        try {
            Enumeration<URL> localUrls = this.enumerateResources(name);
            if (localUrls.hasMoreElements()) {
                return localUrls.nextElement();
            }
            return this.fallbackClassLoader.getResource(name);
        } catch (IOException e) {
            return null;
        }
    }

    /** Finds resource {@code name} in the local module {@code moduleName}; {@code null} if absent or not local. */
    @Override
    protected URL findResource(String moduleName, String name) throws IOException {
        ModuleInfo localModule = this.moduleInfoCache.get(moduleName);
        if (localModule == null) {
            return null;
        }
        return this.readerToURL(localModule, name);
    }

    /** Returns all local matches for {@code name}, followed by all matches from the fallback loader. */
    @Override
    public Enumeration<URL> getResources(String name) throws IOException {
        final Enumeration<URL> localUrls = this.enumerateResources(name);
        final Enumeration<URL> parentUrls = this.fallbackClassLoader.getResources(name);
        return new Enumeration<URL>() {
            @Override
            public boolean hasMoreElements() {
                return localUrls.hasMoreElements() || parentUrls.hasMoreElements();
            }

            @Override
            public URL nextElement() {
                return localUrls.hasMoreElements() ? localUrls.nextElement() : parentUrls.nextElement();
            }
        };
    }

    @Override
    protected Enumeration<URL> findResources(String name) throws IOException {
        return this.enumerateResources(name);
    }

    /**
     * Finds resource {@code name} in the local modules.
     *
     * <p>If the resource's directory maps to a package of a local module, only that module is consulted. Otherwise
     * every local module is searched.
     */
    private Enumeration<URL> enumerateResources(String name) throws IOException {
        int idx = name.lastIndexOf('/');
        // Names without a directory, or ending in '/', are treated as being in the unnamed package "".
        String pkgname = idx == -1 || idx == name.length() - 1 ? "" : name.substring(0, idx).replace('/', '.');
        ModuleInfo localModule = this.packageLookup.get(pkgname);
        if (localModule != null) {
            URL url = this.readerToURL(localModule, name);
            return url != null ? singletonEnumeration(url) : Collections.emptyEnumeration();
        }
        // Only allocate a list when more than one module has the resource.
        URL firstResult = null;
        List<URL> multipleResults = null;
        for (ModuleInfo moduleInfo : this.moduleInfoCache.values()) {
            URL url = this.readerToURL(moduleInfo, name);
            if (url == null) {
                continue;
            }
            if (firstResult == null) {
                firstResult = url;
            } else {
                if (multipleResults == null) {
                    multipleResults = new ArrayList<>();
                    multipleResults.add(firstResult);
                }
                multipleResults.add(url);
            }
        }
        if (multipleResults != null) {
            return Collections.enumeration(multipleResults);
        }
        return firstResult != null ? singletonEnumeration(firstResult) : Collections.emptyEnumeration();
    }

    /** Returns an enumeration yielding {@code url} exactly once. */
    private static Enumeration<URL> singletonEnumeration(final URL url) {
        return new Enumeration<URL>() {
            private boolean read = false;

            @Override
            public boolean hasMoreElements() {
                return !this.read;
            }

            @Override
            public URL nextElement() {
                if (this.read) {
                    throw new NoSuchElementException();
                }
                this.read = true;
                return url;
            }
        };
    }

    /** Defines class {@code name} from local module {@code moduleName}; {@code null} if it cannot be found there. */
    @Override
    protected Class<?> findClass(String moduleName, String name) {
        ModuleInfo localModule = this.moduleInfoCache.get(moduleName);
        if (localModule == null) {
            return null;
        }
        try {
            return this.readerToClass(localModule, name);
        } catch (ClassNotFoundException ignored) {
            // ClassLoader.findClass(String, String) signals "not found" by returning null rather than throwing.
            return null;
        }
    }

    /**
     * Defines class {@code name} from the local module owning its package.
     *
     * @throws ClassNotFoundException if no local module owns the package or the class cannot be read
     */
    @Override
    protected Class<?> findClass(String name) throws ClassNotFoundException {
        String packageName = packageName(name);
        ModuleInfo localModule = packageName == null ? null : this.packageLookup.get(packageName);
        Class<?> c = localModule == null ? null : this.readerToClass(localModule, name);
        if (c == null) {
            throw new ClassNotFoundException(name);
        }
        return c;
    }

    /**
     * Reads the raw bytes of class {@code name} and passes them through {@link #maybeTransformClassBytes}.
     *
     * <p>The raw bytes come from the local module owning the package, or otherwise from the loader registered for
     * the package in {@link #parentLoaders}, read as a resource. If neither applies, the raw bytes are empty.
     *
     * @throws ClassNotFoundException if the transformed bytes are empty; an {@link IOException} hit while reading is
     *                                attached as suppressed
     */
    protected byte[] getMaybeTransformedClassBytes(String name, String context) throws ClassNotFoundException {
        byte[] bytes = new byte[0];
        IOException suppressed = null;
        String pname = packageName(name);
        if (pname != null) {
            try {
                ModuleInfo localModule = this.packageLookup.get(pname);
                if (localModule != null) {
                    bytes = getClassBytes(localModule, name);
                } else {
                    ClassLoader parentLoader = this.parentLoaders.get(pname);
                    if (parentLoader != null) {
                        String cname = name.replace('.', '/') + ".class";
                        try (InputStream is = parentLoader.getResourceAsStream(cname)) {
                            if (is != null) {
                                bytes = is.readAllBytes();
                            }
                        }
                    }
                }
            } catch (IOException e) {
                suppressed = e;
            }
        }
        byte[] maybeTransformedBytes = this.maybeTransformClassBytes(bytes, name, context);
        if (maybeTransformedBytes.length == 0) {
            ClassNotFoundException e = new ClassNotFoundException(name);
            if (suppressed != null) {
                e.addSuppressed(suppressed);
            }
            throw e;
        }
        return maybeTransformedBytes;
    }

    /** Replaces the loader used for classes and resources not found locally or via {@link #parentLoaders}. */
    public void setFallbackClassLoader(ClassLoader fallbackClassLoader) {
        this.fallbackClassLoader = fallbackClassLoader;
    }

    /**
     * Closes the readers of all local modules. Later reads from those modules fail with an {@link IOException}.
     * Calls after the first return immediately.
     *
     * @throws IOException the first failure while closing a module, with any later failures suppressed
     */
    @Override
    public void close() throws IOException {
        if (this.closed) {
            return;
        }
        this.closed = true;
        IOException firstException = null;
        for (ModuleInfo moduleInfo : this.moduleInfoCache.values()) {
            try {
                moduleInfo.close();
            } catch (IOException e) {
                if (firstException == null) {
                    firstException = e;
                } else {
                    firstException.addSuppressed(e);
                }
            }
        }
        this.moduleInfoCache.clear();
        if (firstException != null) {
            throw firstException;
        }
    }

    /**
     * Returns the package part of a binary class name.
     *
     * @return the package name, or {@code null} if the name has no '.' or starts with one
     */
    @Nullable
    private static String packageName(String className) {
        int lastSeparator = className.lastIndexOf('.');
        return lastSeparator <= 0 ? null : className.substring(0, lastSeparator);
    }

    /** Returns the configuration whose modules this loader defines. */
    public Configuration getConfiguration() {
        return this.configuration;
    }

    static {
        ClassLoader.registerAsParallelCapable();
        try {
            // ModuleLayer.bindToLoader is not accessible to normal lookups, so the JDK's trusted IMPL_LOOKUP is
            // obtained reflectively.
            Field hackfield = MethodHandles.Lookup.class.getDeclaredField("IMPL_LOOKUP");
            hackfield.setAccessible(true);
            MethodHandles.Lookup hack = (MethodHandles.Lookup) hackfield.get(null);
            LAYER_BIND_TO_LOADER = hack.findSpecial(ModuleLayer.class, "bindToLoader", MethodType.methodType(Void.TYPE, ClassLoader.class), ModuleLayer.class);
        } catch (IllegalAccessException | NoSuchFieldException | NoSuchMethodException e) {
            throw new RuntimeException(e);
        }
    }

    /** A local module: its reference, protection domain, and a lazily opened, cached {@link ModuleReader}. */
    private static final class ModuleInfo
    implements AutoCloseable {
        private final String name;
        private final JarModuleFinder.JarModuleReference moduleReference;
        /** Guards opening and closing {@link #cachedReader}. */
        private final ReentrantLock lock = new ReentrantLock();
        /** Domain for classes defined from this module: the module's location as code source, with AllPermission. */
        private final ProtectionDomain protectionDomain;
        private volatile ModuleReader cachedReader;
        private volatile boolean closed = false;

        ModuleInfo(ClassLoader classLoader, String name, JarModuleFinder.JarModuleReference moduleReference) {
            this.name = name;
            this.moduleReference = moduleReference;
            CodeSource codeSource = new CodeSource(toURL(moduleReference.location()), (CodeSigner[]) null);
            Permissions perms = new Permissions();
            perms.add(new AllPermission());
            this.protectionDomain = new ProtectionDomain(codeSource, perms, classLoader, null);
        }

        /**
         * Returns the module's reader, opening it on first use. Uses double-checked locking on the volatile cache.
         *
         * @throws IOException if the module has been closed or opening the reader fails
         */
        ModuleReader getReader() throws IOException {
            if (this.closed) {
                throw new IOException("Module " + this.name + " has been closed");
            }
            ModuleReader reader = this.cachedReader;
            if (reader != null) {
                return reader;
            }
            this.lock.lock();
            try {
                if (this.closed) {
                    throw new IOException("Module " + this.name + " has been closed");
                }
                reader = this.cachedReader;
                if (reader == null) {
                    reader = this.moduleReference.open();
                    this.cachedReader = reader;
                }
                return reader;
            } finally {
                this.lock.unlock();
            }
        }

        /** Marks the module closed and closes its reader if one was opened. Calls after the first do nothing. */
        @Override
        public void close() throws IOException {
            this.lock.lock();
            try {
                if (this.closed) {
                    return;
                }
                this.closed = true;
                if (this.cachedReader != null) {
                    this.cachedReader.close();
                    this.cachedReader = null;
                }
            } finally {
                this.lock.unlock();
            }
        }
    }
}

