diff --git a/src/main/kotlin/net/revanced/patcher/Patcher.kt b/src/main/kotlin/net/revanced/patcher/Patcher.kt index 6c0d3b3..a52e263 100644 --- a/src/main/kotlin/net/revanced/patcher/Patcher.kt +++ b/src/main/kotlin/net/revanced/patcher/Patcher.kt @@ -7,7 +7,6 @@ import net.revanced.patcher.signature.Signature import net.revanced.patcher.util.Jar2ASM import java.io.InputStream import java.io.OutputStream -import java.util.jar.JarOutputStream /** * The patcher. (docs WIP) @@ -20,12 +19,12 @@ class Patcher( private val input: InputStream, signatures: Array, ) { - val cache = Cache() + var cache: Cache private val patches: MutableList = mutableListOf() init { - cache.classes.putAll(Jar2ASM.jar2asm(input)) - cache.methods.putAll(MethodResolver(cache.classes.values.toList(), signatures).resolve()) + val classes = Jar2ASM.jar2asm(input); + cache = Cache(classes, MethodResolver(classes, signatures).resolve()) } fun addPatches(vararg patches: Patch) { diff --git a/src/main/kotlin/net/revanced/patcher/cache/Cache.kt b/src/main/kotlin/net/revanced/patcher/cache/Cache.kt index fbff7f5..3b995ea 100644 --- a/src/main/kotlin/net/revanced/patcher/cache/Cache.kt +++ b/src/main/kotlin/net/revanced/patcher/cache/Cache.kt @@ -2,14 +2,14 @@ package net.revanced.patcher.cache import org.objectweb.asm.tree.ClassNode -class Cache { - val classes: MutableMap = mutableMapOf() - val methods: MethodMap = MethodMap() -} +class Cache ( + val classes: List, + val methods: MethodMap +) class MethodMap : LinkedHashMap() { override fun get(key: String): PatchData { - return super.get(key) ?: throw MethodNotFoundException("Method $key not found in method cache") + return super.get(key) ?: throw MethodNotFoundException("Method $key was not found in the method cache") } } diff --git a/src/main/kotlin/net/revanced/patcher/cache/PatchData.kt b/src/main/kotlin/net/revanced/patcher/cache/PatchData.kt index 5cc2c16..93e94c1 100644 --- a/src/main/kotlin/net/revanced/patcher/cache/PatchData.kt +++ b/src/main/kotlin/net/revanced/patcher/cache/PatchData.kt @@ -4,12 +4,12 @@ import org.objectweb.asm.tree.ClassNode import org.objectweb.asm.tree.MethodNode data class PatchData( - val cls: ClassNode, + val declaringClass: ClassNode, val method: MethodNode, - val sd: ScanData + val scanData: PatternScanData ) -data class ScanData( +data class PatternScanData( val startIndex: Int, val endIndex: Int ) diff --git a/src/main/kotlin/net/revanced/patcher/resolver/MethodResolver.kt b/src/main/kotlin/net/revanced/patcher/resolver/MethodResolver.kt index aeaad48..6533da1 100644 --- a/src/main/kotlin/net/revanced/patcher/resolver/MethodResolver.kt +++ b/src/main/kotlin/net/revanced/patcher/resolver/MethodResolver.kt @@ -1,8 +1,9 @@ package net.revanced.patcher.resolver import mu.KotlinLogging +import net.revanced.patcher.cache.MethodMap import net.revanced.patcher.cache.PatchData -import net.revanced.patcher.cache.ScanData +import net.revanced.patcher.cache.PatternScanData import net.revanced.patcher.signature.Signature import net.revanced.patcher.util.ExtraTypes import org.objectweb.asm.Type @@ -13,13 +14,13 @@ import org.objectweb.asm.tree.MethodNode private val logger = KotlinLogging.logger("MethodResolver") internal class MethodResolver(private val classList: List, private val signatures: Array) { - fun resolve(): MutableMap { - val patchData = mutableMapOf() + fun resolve(): MethodMap { + val methodMap = MethodMap() for ((classNode, methods) in classList) { for (method in methods) { for (signature in signatures) { - if (patchData.containsKey(signature.name)) { // method already found for this sig + if (methodMap.containsKey(signature.name)) { // method already found for this sig logger.debug { "Sig ${signature.name} already found, skipping." } continue } @@ -30,10 +31,10 @@ internal class MethodResolver(private val classList: List, private va continue } logger.debug { "Method for sig ${signature.name} found!" } - patchData[signature.name] = PatchData( + methodMap[signature.name] = PatchData( classNode, method, - ScanData( + PatternScanData( // sadly we cannot create contracts for a data class, so we must assert sr.startIndex!!, sr.endIndex!! @@ -44,11 +45,11 @@ internal class MethodResolver(private val classList: List, private va } for (signature in signatures) { - if (patchData.containsKey(signature.name)) continue + if (methodMap.containsKey(signature.name)) continue logger.error { "Could not find method for sig ${signature.name}!" } } - return patchData + return methodMap } private fun cmp(method: MethodNode, signature: Signature): Pair { diff --git a/src/main/kotlin/net/revanced/patcher/util/Jar2ASM.kt b/src/main/kotlin/net/revanced/patcher/util/Jar2ASM.kt index d9b59ee..271d9ad 100644 --- a/src/main/kotlin/net/revanced/patcher/util/Jar2ASM.kt +++ b/src/main/kotlin/net/revanced/patcher/util/Jar2ASM.kt @@ -10,21 +10,20 @@ import java.util.jar.JarInputStream import java.util.jar.JarOutputStream object Jar2ASM { - fun jar2asm(input: InputStream): Map { - return buildMap { - val jar = JarInputStream(input) + fun jar2asm(input: InputStream) = mutableListOf().apply { + val jar = JarInputStream(input) while (true) { val e = jar.nextJarEntry ?: break if (e.name.endsWith(".class")) { val classNode = ClassNode() ClassReader(jar.readAllBytes()).accept(classNode, ClassReader.EXPAND_FRAMES) - this[e.name] = classNode + this.add(classNode) } jar.closeEntry() } - } } - fun asm2jar(input: InputStream, output: OutputStream, structure: Map) { + + fun asm2jar(input: InputStream, output: OutputStream, classes: List) { val jis = JarInputStream(input) val jos = JarOutputStream(output) @@ -33,10 +32,13 @@ object Jar2ASM { val next = jis.nextJarEntry ?: break val e = JarEntry(next) // clone it, to not modify the input (if possible) jos.putNextEntry(e) - if (structure.containsKey(e.name)) { + + val clazz = classes.singleOrNull { + clazz -> clazz.name == e.name + }; + if (clazz != null) { val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES) - val cn = structure[e.name]!! - cn.accept(cw) + clazz.accept(cw) jos.write(cw.toByteArray()) } else { jos.write(jis.readAllBytes()) diff --git a/src/test/kotlin/net/revanced/patcher/PatcherTest.kt b/src/test/kotlin/net/revanced/patcher/PatcherTest.kt index 4428962..d199f48 100644 --- a/src/test/kotlin/net/revanced/patcher/PatcherTest.kt +++ b/src/test/kotlin/net/revanced/patcher/PatcherTest.kt @@ -7,11 +7,8 @@ import net.revanced.patcher.util.ExtraTypes import net.revanced.patcher.writer.ASMWriter.setAt import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Type -import org.objectweb.asm.tree.LdcInsnNode -import java.io.ByteArrayOutputStream +import org.objectweb.asm.tree.* import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertTrue internal class PatcherTest { private val testSigs: Array = arrayOf( @@ -46,14 +43,24 @@ internal class PatcherTest { patcher.addPatches( Patch ("TestPatch") { // Get the method from the resolver cache - val main = patcher.cache.methods["mainMethod"] + val mainMethod = patcher.cache.methods["mainMethod"] // Get the instruction list - val insn = main.method.instructions!! + val instructions = mainMethod.method.instructions!! // Let's modify it, so it prints "Hello, ReVanced!" - // Get the start index of our signature + // Get the start index of our opcode pattern // This will be the index of the LDC instruction - val startIndex = main.sd.startIndex - insn.setAt(startIndex, LdcInsnNode("Hello, ReVanced!")) + val startIndex = mainMethod.scanData.startIndex + // Create a new Ldc node and replace the LDC instruction + val stringNode = LdcInsnNode("Hello, ReVanced!"); + instructions.setAt(startIndex, stringNode) + // Now lets print our string to the console output + // First create a list of instructions + val printCode = InsnList(); + printCode.add(LdcInsnNode("Hello, ReVanced!")) + printCode.add(MethodInsnNode(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V")) + // Add the list after the second instruction by our pattern + instructions.insert(instructions[startIndex + 1], printCode) + // Finally, tell the patcher that this patch was a success. // You can also return PatchResultError with a message. // If an exception is thrown inside this function, @@ -62,7 +69,9 @@ internal class PatcherTest { } ) + // Apply all patches loaded in the patcher val result = patcher.applyPatches() + // You can check if an error occurred for ((s, r) in result) { if (r.isFailure) { throw Exception("Patch $s failed", r.exceptionOrNull()!!) @@ -70,30 +79,30 @@ internal class PatcherTest { } // TODO Doesn't work, needs to be fixed. -// val out = ByteArrayOutputStream() -// patcher.saveTo(out) -// assertTrue( -// // 8 is a random value, it's just weird if it's any lower than that -// out.size() > 8, -// "Output must be at least 8 bytes" -// ) -// -// out.close() + //val out = ByteArrayOutputStream() + //patcher.saveTo(out) + //assertTrue( + // // 8 is a random value, it's just weird if it's any lower than that + // out.size() > 8, + // "Output must be at least 8 bytes" + //) + // + //out.close() testData.close() } // TODO Doesn't work, needs to be fixed. -// @Test -// fun noChanges() { -// val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!! -// val available = testData.available() -// val patcher = Patcher(testData, testSigs) -// -// val out = ByteArrayOutputStream() -// patcher.saveTo(out) -// assertEquals(available, out.size()) -// -// out.close() -// testData.close() -// } + //@Test + //fun noChanges() { + // val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!! + // val available = testData.available() + // val patcher = Patcher(testData, testSigs) + // + // val out = ByteArrayOutputStream() + // patcher.saveTo(out) + // assertEquals(available, out.size()) + // + // out.close() + // testData.close() + //} } \ No newline at end of file