001package cpw.mods.fml.common.asm.transformers;
002
003import java.io.BufferedOutputStream;
004import java.io.BufferedReader;
005import java.io.ByteArrayOutputStream;
006import java.io.DataInputStream;
007import java.io.File;
008import java.io.FileInputStream;
009import java.io.FileNotFoundException;
010import java.io.FileOutputStream;
011import java.io.IOException;
012import java.io.InputStream;
013import java.io.InputStreamReader;
014import java.util.ArrayList;
015import java.util.Collections;
016import java.util.Enumeration;
017import java.util.HashSet;
018import java.util.Hashtable;
019import java.util.LinkedHashSet;
020import java.util.List;
021import java.util.Map.Entry;
022import java.util.zip.ZipEntry;
023import java.util.zip.ZipFile;
024import java.util.zip.ZipOutputStream;
025
026import org.objectweb.asm.ClassReader;
027import org.objectweb.asm.ClassWriter;
028import org.objectweb.asm.Type;
029import org.objectweb.asm.tree.AnnotationNode;
030import org.objectweb.asm.tree.ClassNode;
031import org.objectweb.asm.tree.FieldNode;
032import org.objectweb.asm.tree.MethodNode;
033
034import com.google.common.base.Objects;
035import com.google.common.collect.Lists;
036import com.google.common.collect.Sets;
037
038import cpw.mods.fml.relauncher.Side;
039import cpw.mods.fml.relauncher.SideOnly;
040
041public class MCPMerger
042{
043    private static Hashtable<String, ClassInfo> clients = new Hashtable<String, ClassInfo>();
044    private static Hashtable<String, ClassInfo> shared  = new Hashtable<String, ClassInfo>();
045    private static Hashtable<String, ClassInfo> servers = new Hashtable<String, ClassInfo>();
046    private static HashSet<String> copyToServer = new HashSet<String>();
047    private static HashSet<String> copyToClient = new HashSet<String>();
048    private static HashSet<String> dontAnnotate = new HashSet<String>();
049    private static final boolean DEBUG = false;
050
051    public static void main(String[] args)
052    {
053        if (args.length != 3)
054        {
055            System.out.println("Usage: MCPMerger <MapFile> <minecraft.jar> <minecraft_server.jar>");
056            System.exit(1);
057        }
058
059        File map_file = new File(args[0]);
060        File client_jar = new File(args[1]);
061        File server_jar = new File(args[2]);
062        File client_jar_tmp = new File(args[1] + ".MergeBack");
063        File server_jar_tmp = new File(args[2] + ".MergeBack");
064
065
066        if (client_jar_tmp.exists() && !client_jar_tmp.delete())
067        {
068            System.out.println("Could not delete temp file: " + client_jar_tmp);
069        }
070
071        if (server_jar_tmp.exists() && !server_jar_tmp.delete())
072        {
073            System.out.println("Could not delete temp file: " + server_jar_tmp);
074        }
075
076        if (!client_jar.exists())
077        {
078            System.out.println("Could not find minecraft.jar: " + client_jar);
079            System.exit(1);
080        }
081
082        if (!server_jar.exists())
083        {
084            System.out.println("Could not find minecraft_server.jar: " + server_jar);
085            System.exit(1);
086        }
087
088        if (!client_jar.renameTo(client_jar_tmp))
089        {
090            System.out.println("Could not rename file: " + client_jar + " -> " + client_jar_tmp);
091            System.exit(1);
092        }
093
094        if (!server_jar.renameTo(server_jar_tmp))
095        {
096            System.out.println("Could not rename file: " + server_jar + " -> " + server_jar_tmp);
097            System.exit(1);
098        }
099
100        if (!readMapFile(map_file))
101        {
102            System.out.println("Could not read map file: " + map_file);
103            System.exit(1);
104        }
105
106        try
107        {
108            processJar(client_jar_tmp, server_jar_tmp, client_jar, server_jar);
109        }
110        catch (IOException e)
111        {
112            e.printStackTrace();
113            System.exit(1);
114        }
115
116        if (!client_jar_tmp.delete())
117        {
118            System.out.println("Could not delete temp file: " + client_jar_tmp);
119        }
120
121        if (!server_jar_tmp.delete())
122        {
123            System.out.println("Could not delete temp file: " + server_jar_tmp);
124        }
125    }
126
127    private static boolean readMapFile(File mapFile)
128    {
129        try
130        {
131            FileInputStream fstream = new FileInputStream(mapFile);
132            DataInputStream in = new DataInputStream(fstream);
133            BufferedReader br = new BufferedReader(new InputStreamReader(in));
134
135            String line;
136            while ((line = br.readLine()) != null)
137            {
138                line = line.split("#")[0];
139                char cmd = line.charAt(0);
140                line = line.substring(1).trim();
141                
142                switch (cmd)
143                {
144                    case '!': dontAnnotate.add(line); break;
145                    case '<': copyToClient.add(line); break;
146                    case '>': copyToServer.add(line); break; 
147                }
148            }
149
150            in.close();
151            return true;
152        }
153        catch (Exception e)
154        {
155            System.err.println("Error: " + e.getMessage());
156            return false;
157        }
158    }
159
160    public static void processJar(File clientInFile, File serverInFile, File clientOutFile, File serverOutFile) throws IOException
161    {
162        ZipFile cInJar = null;
163        ZipFile sInJar = null;
164        ZipOutputStream cOutJar = null;
165        ZipOutputStream sOutJar = null;
166
167        try
168        {
169            try
170            {
171                cInJar = new ZipFile(clientInFile);
172                sInJar = new ZipFile(serverInFile);
173            }
174            catch (FileNotFoundException e)
175            {
176                throw new FileNotFoundException("Could not open input file: " + e.getMessage());
177            }
178            try
179            {
180                cOutJar = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(clientOutFile)));
181                sOutJar = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(serverOutFile)));
182            }
183            catch (FileNotFoundException e)
184            {
185                throw new FileNotFoundException("Could not open output file: " + e.getMessage());
186            }
187            Hashtable<String, ZipEntry> cClasses = getClassEntries(cInJar, cOutJar);
188            Hashtable<String, ZipEntry> sClasses = getClassEntries(sInJar, sOutJar);
189            HashSet<String> cAdded = new HashSet<String>();
190            HashSet<String> sAdded = new HashSet<String>();
191
192            for (Entry<String, ZipEntry> entry : cClasses.entrySet())
193            {
194                String name = entry.getKey();
195                ZipEntry cEntry = entry.getValue();
196                ZipEntry sEntry = sClasses.get(name);
197
198                if (sEntry == null)
199                {
200                    if (!copyToServer.contains(name))
201                    {
202                        copyClass(cInJar, cEntry, cOutJar, null, true);
203                        cAdded.add(name);
204                    }
205                    else
206                    {
207                        if (DEBUG)
208                        {
209                            System.out.println("Copy class c->s : " + name);
210                        }
211                        copyClass(cInJar, cEntry, cOutJar, sOutJar, true);
212                        cAdded.add(name);
213                        sAdded.add(name);
214                    }
215                    continue;
216                }
217
218                sClasses.remove(name);
219                ClassInfo info = new ClassInfo(name);
220                shared.put(name, info);
221
222                byte[] cData = readEntry(cInJar, entry.getValue());
223                byte[] sData = readEntry(sInJar, sEntry);
224                byte[] data = processClass(cData, sData, info);
225
226                ZipEntry newEntry = new ZipEntry(cEntry.getName());
227                cOutJar.putNextEntry(newEntry);
228                cOutJar.write(data);
229                sOutJar.putNextEntry(newEntry);
230                sOutJar.write(data);
231                cAdded.add(name);
232                sAdded.add(name);
233            }
234
235            for (Entry<String, ZipEntry> entry : sClasses.entrySet())
236            {
237                if (DEBUG)
238                {
239                    System.out.println("Copy class s->c : " + entry.getKey());
240                }
241                copyClass(sInJar, entry.getValue(), cOutJar, sOutJar, false);
242            }
243
244            for (String name : new String[]{SideOnly.class.getName(), Side.class.getName()})
245            {
246                String eName = name.replace(".", "/");
247                byte[] data = getClassBytes(name);
248                ZipEntry newEntry = new ZipEntry(name.replace(".", "/").concat(".class"));
249                if (!cAdded.contains(eName))
250                {
251                    cOutJar.putNextEntry(newEntry);
252                    cOutJar.write(data);
253                }
254                if (!sAdded.contains(eName))
255                {
256                    sOutJar.putNextEntry(newEntry);
257                    sOutJar.write(data);
258                }
259            }
260
261        }
262        finally
263        {
264            if (cInJar != null)
265            {
266                try { cInJar.close(); } catch (IOException e){}
267            }
268
269            if (sInJar != null)
270            {
271                try { sInJar.close(); } catch (IOException e) {}
272            }
273            if (cOutJar != null)
274            {
275                try { cOutJar.close(); } catch (IOException e){}
276            }
277
278            if (sOutJar != null)
279            {
280                try { sOutJar.close(); } catch (IOException e) {}
281            }
282        }
283    }
284
285    private static void copyClass(ZipFile inJar, ZipEntry entry, ZipOutputStream outJar, ZipOutputStream outJar2, boolean isClientOnly) throws IOException
286    {
287        ClassReader reader = new ClassReader(readEntry(inJar, entry));
288        ClassNode classNode = new ClassNode();
289
290        reader.accept(classNode, 0);
291
292        if (!dontAnnotate.contains(classNode.name))
293        {
294            if (classNode.visibleAnnotations == null) classNode.visibleAnnotations = new ArrayList<AnnotationNode>();
295            classNode.visibleAnnotations.add(getSideAnn(isClientOnly));
296        }
297
298        ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_MAXS);
299        classNode.accept(writer);
300        byte[] data = writer.toByteArray();
301
302        ZipEntry newEntry = new ZipEntry(entry.getName());
303        if (outJar != null)
304        {
305            outJar.putNextEntry(newEntry);
306            outJar.write(data);
307        }
308        if (outJar2 != null)
309        {
310            outJar2.putNextEntry(newEntry);
311            outJar2.write(data);
312        }
313    }
314
315    private static AnnotationNode getSideAnn(boolean isClientOnly)
316    {
317        AnnotationNode ann = new AnnotationNode(Type.getDescriptor(SideOnly.class));
318        ann.values = new ArrayList<Object>();
319        ann.values.add("value");
320        ann.values.add(new String[]{ Type.getDescriptor(Side.class), (isClientOnly ? "CLIENT" : "SERVER")});
321        return ann;
322    }
323
324    @SuppressWarnings("unchecked")
325    private static Hashtable<String, ZipEntry> getClassEntries(ZipFile inFile, ZipOutputStream outFile) throws IOException
326    {
327        Hashtable<String, ZipEntry> ret = new Hashtable<String, ZipEntry>();
328        for (ZipEntry entry : Collections.list((Enumeration<ZipEntry>)inFile.entries()))
329        {
330            if (entry.isDirectory())
331            {
332                outFile.putNextEntry(entry);
333                continue;
334            }
335            String entryName = entry.getName();
336            if (!entryName.endsWith(".class") || entryName.startsWith("."))
337            {
338                ZipEntry newEntry = new ZipEntry(entry.getName());
339                outFile.putNextEntry(newEntry);
340                outFile.write(readEntry(inFile, entry));
341            }
342            else
343            {
344                ret.put(entryName.replace(".class", ""), entry);
345            }
346        }
347        return ret;
348    }
349    private static byte[] readEntry(ZipFile inFile, ZipEntry entry) throws IOException
350    {
351        return readFully(inFile.getInputStream(entry));
352    }
353    private static byte[] readFully(InputStream stream) throws IOException
354    {
355        byte[] data = new byte[4096];
356        ByteArrayOutputStream entryBuffer = new ByteArrayOutputStream();
357        int len;
358        do
359        {
360            len = stream.read(data);
361            if (len > 0)
362            {
363                entryBuffer.write(data, 0, len);
364            }
365        } while (len != -1);
366
367        return entryBuffer.toByteArray();
368    }
369    private static class ClassInfo
370    {
371        public String name;
372        public ArrayList<FieldNode> cField = new ArrayList<FieldNode>();
373        public ArrayList<FieldNode> sField = new ArrayList<FieldNode>();
374        public ArrayList<MethodNode> cMethods = new ArrayList<MethodNode>();
375        public ArrayList<MethodNode> sMethods = new ArrayList<MethodNode>();
376        public ClassInfo(String name){ this.name = name; }
377        public boolean isSame() { return (cField.size() == 0 && sField.size() == 0 && cMethods.size() == 0 && sMethods.size() == 0); }
378    }
379
380    public static byte[] processClass(byte[] cIn, byte[] sIn, ClassInfo info)
381    {
382        ClassNode cClassNode = getClassNode(cIn);
383        ClassNode sClassNode = getClassNode(sIn);
384
385        processFields(cClassNode, sClassNode, info);
386        processMethods(cClassNode, sClassNode, info);
387
388        ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_MAXS);
389        cClassNode.accept(writer);
390        return writer.toByteArray();
391    }
392
393    private static ClassNode getClassNode(byte[] data)
394    {
395        ClassReader reader = new ClassReader(data);
396        ClassNode classNode = new ClassNode();
397        reader.accept(classNode, 0);
398        return classNode;
399    }
400
401    @SuppressWarnings("unchecked")
402    private static void processFields(ClassNode cClass, ClassNode sClass, ClassInfo info)
403    {
404        List<FieldNode> cFields = cClass.fields;
405        List<FieldNode> sFields = sClass.fields;
406
407        int sI = 0;
408        for (int x = 0; x < cFields.size(); x++)
409        {
410            FieldNode cF = cFields.get(x);
411            if (sI < sFields.size())
412            {
413                if (!cF.name.equals(sFields.get(sI).name))
414                {
415                    boolean serverHas = false;
416                    for (int y = sI + 1; y < sFields.size(); y++)
417                    {
418                        if (cF.name.equals(sFields.get(y).name))
419                        {
420                            serverHas = true;
421                            break;
422                        }
423                    }
424                    if (serverHas)
425                    {
426                        boolean clientHas = false;
427                        FieldNode sF = sFields.get(sI);
428                        for (int y = x + 1; y < cFields.size(); y++)
429                        {
430                            if (sF.name.equals(cFields.get(y).name))
431                            {
432                                clientHas = true;
433                                break;
434                            }
435                        }
436                        if (!clientHas)
437                        {
438                            if  (sF.visibleAnnotations == null) sF.visibleAnnotations = new ArrayList<AnnotationNode>();
439                            sF.visibleAnnotations.add(getSideAnn(false));
440                            cFields.add(x++, sF);
441                            info.sField.add(sF);
442                        }
443                    }
444                    else
445                    {
446                        if  (cF.visibleAnnotations == null) cF.visibleAnnotations = new ArrayList<AnnotationNode>();
447                        cF.visibleAnnotations.add(getSideAnn(true));
448                        sFields.add(sI, cF);
449                        info.cField.add(cF);
450                    }
451                }
452            }
453            else
454            {
455                if  (cF.visibleAnnotations == null) cF.visibleAnnotations = new ArrayList<AnnotationNode>();
456                cF.visibleAnnotations.add(getSideAnn(true));
457                sFields.add(sI, cF);
458                info.cField.add(cF);
459            }
460            sI++;
461        }
462        if (sFields.size() != cFields.size())
463        {
464            for (int x = cFields.size(); x < sFields.size(); x++)
465            {
466                FieldNode sF = sFields.get(x);
467                if  (sF.visibleAnnotations == null) sF.visibleAnnotations = new ArrayList<AnnotationNode>();
468                sF.visibleAnnotations.add(getSideAnn(true));
469                cFields.add(x++, sF);
470                info.sField.add(sF);
471            }
472        }
473    }
474
475    private static class MethodWrapper
476    {
477        private MethodNode node;
478        public boolean client;
479        public boolean server;
480        public MethodWrapper(MethodNode node)
481        {
482            this.node = node;
483        }
484        @Override
485        public boolean equals(Object obj)
486        {
487            if (obj == null || !(obj instanceof MethodWrapper)) return false;
488            MethodWrapper mw = (MethodWrapper) obj;
489            boolean eq = Objects.equal(node.name, mw.node.name) && Objects.equal(node.desc, mw.node.desc);
490            if (eq)
491            {
492                mw.client = this.client | mw.client;
493                mw.server = this.server | mw.server;
494                this.client = this.client | mw.client;
495                this.server = this.server | mw.server;
496                if (DEBUG)
497                {
498                    System.out.printf(" eq: %s %s\n", this, mw);
499                }
500            }
501            return eq;
502        }
503
504        @Override
505        public int hashCode()
506        {
507            return Objects.hashCode(node.name, node.desc);
508        }
509        @Override
510        public String toString()
511        {
512            return Objects.toStringHelper(this).add("name", node.name).add("desc",node.desc).add("server",server).add("client",client).toString();
513        }
514    }
515    @SuppressWarnings("unchecked")
516    private static void processMethods(ClassNode cClass, ClassNode sClass, ClassInfo info)
517    {
518        List<MethodNode> cMethods = (List<MethodNode>)cClass.methods;
519        List<MethodNode> sMethods = (List<MethodNode>)sClass.methods;
520        LinkedHashSet<MethodWrapper> allMethods = Sets.newLinkedHashSet();
521
522        int cPos = 0;
523        int sPos = 0;
524        int cLen = cMethods.size();
525        int sLen = sMethods.size();
526        String clientName = "";
527        String lastName = clientName;
528        String serverName = "";
529        while (cPos < cLen || sPos < sLen)
530        {
531            do
532            {
533                if (sPos>=sLen)
534                {
535                    break;
536                }
537                MethodNode sM = sMethods.get(sPos);
538                serverName = sM.name;
539                if (!serverName.equals(lastName) && cPos != cLen)
540                {
541                    if (DEBUG)
542                    {
543                        System.out.printf("Server -skip : %s %s %d (%s %d) %d [%s]\n", sClass.name, clientName, cLen - cPos, serverName, sLen - sPos, allMethods.size(), lastName);
544                    }
545                    break;
546                }
547                MethodWrapper mw = new MethodWrapper(sM);
548                mw.server = true;
549                allMethods.add(mw);
550                if (DEBUG)
551                {
552                    System.out.printf("Server *add* : %s %s %d (%s %d) %d [%s]\n", sClass.name, clientName, cLen - cPos, serverName, sLen - sPos, allMethods.size(), lastName);
553                }
554                sPos++;
555            }
556            while (sPos < sLen);
557            do
558            {
559                if (cPos>=cLen)
560                {
561                    break;
562                }
563                MethodNode cM = cMethods.get(cPos);
564                lastName = clientName;
565                clientName = cM.name;
566                if (!clientName.equals(lastName) && sPos != sLen)
567                {
568                    if (DEBUG)
569                    {
570                        System.out.printf("Client -skip : %s %s %d (%s %d) %d [%s]\n", cClass.name, clientName, cLen - cPos, serverName, sLen - sPos, allMethods.size(), lastName);
571                    }
572                    break;
573                }
574                MethodWrapper mw = new MethodWrapper(cM);
575                mw.client = true;
576                allMethods.add(mw);
577                if (DEBUG)
578                {
579                    System.out.printf("Client *add* : %s %s %d (%s %d) %d [%s]\n", cClass.name, clientName, cLen - cPos, serverName, sLen - sPos, allMethods.size(), lastName);
580                }
581                cPos++;
582            }
583            while (cPos < cLen);
584        }
585
586        cMethods.clear();
587        sMethods.clear();
588
589        for (MethodWrapper mw : allMethods)
590        {
591            if (DEBUG)
592            {
593                System.out.println(mw);
594            }
595            cMethods.add(mw.node);
596            sMethods.add(mw.node);
597            if (mw.server && mw.client)
598            {
599                // no op
600            }
601            else
602            {
603                if (mw.node.visibleAnnotations == null) mw.node.visibleAnnotations = Lists.newArrayListWithExpectedSize(1);
604                mw.node.visibleAnnotations.add(getSideAnn(mw.client));
605                if (mw.client)
606                {
607                    info.sMethods.add(mw.node);
608                }
609                else
610                {
611                    info.cMethods.add(mw.node);
612                }
613            }
614        }
615    }
616
617    public static byte[] getClassBytes(String name) throws IOException
618    {
619        InputStream classStream = null;
620        try
621        {
622            classStream = MCPMerger.class.getResourceAsStream("/" + name.replace('.', '/').concat(".class"));
623            return readFully(classStream);
624        }
625        finally
626        {
627            if (classStream != null)
628            {
629                try
630                {
631                    classStream.close();
632                }
633                catch (IOException e){}
634            }
635        }
636    }
637}