001    package cpw.mods.fml.relauncher;
002    
003    import java.io.ByteArrayOutputStream;
004    import java.io.IOException;
005    import java.io.InputStream;
006    import java.net.URL;
007    import java.net.URLClassLoader;
008    import java.util.ArrayList;
009    import java.util.Arrays;
010    import java.util.Collections;
011    import java.util.HashMap;
012    import java.util.HashSet;
013    import java.util.List;
014    import java.util.Map;
015    import java.util.Set;
016    import java.util.logging.Level;
017    
018    public class RelaunchClassLoader extends URLClassLoader
019    {
020        // Left behind for CCC/NEI compatibility
021        private static String[] excludedPackages = new String[0];
022        // Left behind for CCC/NEI compatibility
023        private static String[] transformerExclusions = new String[0];
024    
025        private List<URL> sources;
026        private ClassLoader parent;
027    
028        private List<IClassTransformer> transformers;
029        private Map<String, Class> cachedClasses;
030    
031        private Set<String> classLoaderExceptions = new HashSet<String>();
032        private Set<String> transformerExceptions = new HashSet<String>();
033    
034        public RelaunchClassLoader(URL[] sources)
035        {
036            super(sources, null);
037            this.sources = new ArrayList<URL>(Arrays.asList(sources));
038            this.parent = getClass().getClassLoader();
039            this.cachedClasses = new HashMap<String,Class>(1000);
040            this.transformers = new ArrayList<IClassTransformer>(2);
041    //        ReflectionHelper.setPrivateValue(ClassLoader.class, null, this, "scl");
042            Thread.currentThread().setContextClassLoader(this);
043    
044            // standard classloader exclusions
045            addClassLoaderExclusion("java.");
046            addClassLoaderExclusion("sun.");
047            addClassLoaderExclusion("cpw.mods.fml.relauncher.");
048            addClassLoaderExclusion("net.minecraftforge.classloading.");
049    
050            // standard transformer exclusions
051            addTransformerExclusion("javax.");
052            addTransformerExclusion("cpw.mods.fml.");
053            addTransformerExclusion("org.objectweb.asm.");
054            addTransformerExclusion("com.google.common.");
055        }
056    
057        public void registerTransformer(String transformerClassName)
058        {
059            try
060            {
061                transformers.add((IClassTransformer) loadClass(transformerClassName).newInstance());
062            }
063            catch (Exception e)
064            {
065                FMLRelaunchLog.log(Level.SEVERE, e, "A critical problem occured registering the ASM transformer class %s", transformerClassName);
066            }
067        }
068        @Override
069        public Class<?> findClass(String name) throws ClassNotFoundException
070        {
071            // NEI/CCC compatibility code
072            if (excludedPackages.length != 0)
073            {
074                classLoaderExceptions.addAll(Arrays.asList(excludedPackages));
075                excludedPackages = new String[0];
076            }
077            if (transformerExclusions.length != 0)
078            {
079                transformerExceptions.addAll(Arrays.asList(transformerExclusions));
080                transformerExclusions = new String[0];
081            }
082    
083            for (String st : classLoaderExceptions)
084            {
085                if (name.startsWith(st))
086                {
087                    return parent.loadClass(name);
088                }
089            }
090    
091            if (cachedClasses.containsKey(name))
092            {
093                return cachedClasses.get(name);
094            }
095    
096            for (String st : transformerExceptions)
097            {
098                if (name.startsWith(st))
099                {
100                    Class<?> cl = super.findClass(name);
101                    cachedClasses.put(name, cl);
102                    return cl;
103                }
104            }
105    
106            try
107            {
108                int lastDot = name.lastIndexOf('.');
109                if (lastDot > -1)
110                {
111                    String pkgname = name.substring(0, lastDot);
112                    if (getPackage(pkgname)==null)
113                    {
114                        definePackage(pkgname, null, null, null, null, null, null, null);
115                    }
116                }
117                byte[] basicClass = getClassBytes(name);
118                byte[] transformedClass = runTransformers(name, basicClass);
119                Class<?> cl = defineClass(name, transformedClass, 0, transformedClass.length);
120                cachedClasses.put(name, cl);
121                return cl;
122            }
123            catch (Throwable e)
124            {
125                throw new ClassNotFoundException(name, e);
126            }
127        }
128    
129        /**
130         * @param name
131         * @return
132         * @throws IOException
133         */
134        public byte[] getClassBytes(String name) throws IOException
135        {
136            InputStream classStream = null;
137            try
138            {
139                URL classResource = findResource(name.replace('.', '/').concat(".class"));
140                if (classResource == null)
141                {
142                    return null;
143                }
144                classStream = classResource.openStream();
145                return readFully(classStream);
146            }
147            finally
148            {
149                if (classStream != null)
150                {
151                    try
152                    {
153                        classStream.close();
154                    }
155                    catch (IOException e)
156                    {
157                        // Swallow the close exception
158                    }
159                }
160            }
161        }
162    
163        private byte[] runTransformers(String name, byte[] basicClass)
164        {
165            for (IClassTransformer transformer : transformers)
166            {
167                basicClass = transformer.transform(name, basicClass);
168            }
169            return basicClass;
170        }
171    
172        @Override
173        public void addURL(URL url)
174        {
175            super.addURL(url);
176            sources.add(url);
177        }
178    
179        public List<URL> getSources()
180        {
181            return sources;
182        }
183    
184    
185        private byte[] readFully(InputStream stream)
186        {
187            try
188            {
189                ByteArrayOutputStream bos = new ByteArrayOutputStream(stream.available());
190                int r;
191                while ((r = stream.read()) != -1)
192                {
193                    bos.write(r);
194                }
195    
196                return bos.toByteArray();
197            }
198            catch (Throwable t)
199            {
200                /// HMMM
201                return new byte[0];
202            }
203        }
204    
205        public List<IClassTransformer> getTransformers()
206        {
207            return Collections.unmodifiableList(transformers);
208        }
209    
210        private void addClassLoaderExclusion(String toExclude)
211        {
212            classLoaderExceptions.add(toExclude);
213        }
214    
215        void addTransformerExclusion(String toExclude)
216        {
217            transformerExceptions.add(toExclude);
218        }
219    }