Improve example test testPatcher and increase caching speed

This commit is contained in:
oSumAtrIX 2022-03-20 03:06:23 +01:00
parent 81e0220d15
commit 5d146c362f
No known key found for this signature in database
GPG key ID: A9B3094ACDB604B4
6 changed files with 71 additions and 60 deletions

View file

@ -7,7 +7,6 @@ import net.revanced.patcher.signature.Signature
import net.revanced.patcher.util.Jar2ASM
import java.io.InputStream
import java.io.OutputStream
import java.util.jar.JarOutputStream
/**
* The patcher. (docs WIP)
@ -20,12 +19,12 @@ class Patcher(
private val input: InputStream,
signatures: Array<Signature>,
) {
val cache = Cache()
var cache: Cache
private val patches: MutableList<Patch> = mutableListOf()
init {
cache.classes.putAll(Jar2ASM.jar2asm(input))
cache.methods.putAll(MethodResolver(cache.classes.values.toList(), signatures).resolve())
val classes = Jar2ASM.jar2asm(input);
cache = Cache(classes, MethodResolver(classes, signatures).resolve())
}
fun addPatches(vararg patches: Patch) {

View file

@ -2,14 +2,14 @@ package net.revanced.patcher.cache
import org.objectweb.asm.tree.ClassNode
class Cache {
val classes: MutableMap<String, ClassNode> = mutableMapOf()
val methods: MethodMap = MethodMap()
}
class Cache (
val classes: List<ClassNode>,
val methods: MethodMap
)
class MethodMap : LinkedHashMap<String, PatchData>() {
override fun get(key: String): PatchData {
return super.get(key) ?: throw MethodNotFoundException("Method $key not found in method cache")
return super.get(key) ?: throw MethodNotFoundException("Method $key was not found in the method cache")
}
}

View file

@ -4,12 +4,12 @@ import org.objectweb.asm.tree.ClassNode
import org.objectweb.asm.tree.MethodNode
data class PatchData(
val cls: ClassNode,
val declaringClass: ClassNode,
val method: MethodNode,
val sd: ScanData
val scanData: PatternScanData
)
data class ScanData(
data class PatternScanData(
val startIndex: Int,
val endIndex: Int
)

View file

@ -1,8 +1,9 @@
package net.revanced.patcher.resolver
import mu.KotlinLogging
import net.revanced.patcher.cache.MethodMap
import net.revanced.patcher.cache.PatchData
import net.revanced.patcher.cache.ScanData
import net.revanced.patcher.cache.PatternScanData
import net.revanced.patcher.signature.Signature
import net.revanced.patcher.util.ExtraTypes
import org.objectweb.asm.Type
@ -13,13 +14,13 @@ import org.objectweb.asm.tree.MethodNode
private val logger = KotlinLogging.logger("MethodResolver")
internal class MethodResolver(private val classList: List<ClassNode>, private val signatures: Array<Signature>) {
fun resolve(): MutableMap<String, PatchData> {
val patchData = mutableMapOf<String, PatchData>()
fun resolve(): MethodMap {
val methodMap = MethodMap()
for ((classNode, methods) in classList) {
for (method in methods) {
for (signature in signatures) {
if (patchData.containsKey(signature.name)) { // method already found for this sig
if (methodMap.containsKey(signature.name)) { // method already found for this sig
logger.debug { "Sig ${signature.name} already found, skipping." }
continue
}
@ -30,10 +31,10 @@ internal class MethodResolver(private val classList: List<ClassNode>, private va
continue
}
logger.debug { "Method for sig ${signature.name} found!" }
patchData[signature.name] = PatchData(
methodMap[signature.name] = PatchData(
classNode,
method,
ScanData(
PatternScanData(
// sadly we cannot create contracts for a data class, so we must assert
sr.startIndex!!,
sr.endIndex!!
@ -44,11 +45,11 @@ internal class MethodResolver(private val classList: List<ClassNode>, private va
}
for (signature in signatures) {
if (patchData.containsKey(signature.name)) continue
if (methodMap.containsKey(signature.name)) continue
logger.error { "Could not find method for sig ${signature.name}!" }
}
return patchData
return methodMap
}
private fun cmp(method: MethodNode, signature: Signature): Pair<Boolean, ScanResult?> {

View file

@ -10,21 +10,20 @@ import java.util.jar.JarInputStream
import java.util.jar.JarOutputStream
object Jar2ASM {
fun jar2asm(input: InputStream): Map<String, ClassNode> {
return buildMap {
val jar = JarInputStream(input)
fun jar2asm(input: InputStream) = mutableListOf<ClassNode>().apply {
val jar = JarInputStream(input)
while (true) {
val e = jar.nextJarEntry ?: break
if (e.name.endsWith(".class")) {
val classNode = ClassNode()
ClassReader(jar.readAllBytes()).accept(classNode, ClassReader.EXPAND_FRAMES)
this[e.name] = classNode
this.add(classNode)
}
jar.closeEntry()
}
}
}
fun asm2jar(input: InputStream, output: OutputStream, structure: Map<String, ClassNode>) {
fun asm2jar(input: InputStream, output: OutputStream, classes: List<ClassNode>) {
val jis = JarInputStream(input)
val jos = JarOutputStream(output)
@ -33,10 +32,13 @@ object Jar2ASM {
val next = jis.nextJarEntry ?: break
val e = JarEntry(next) // clone it, to not modify the input (if possible)
jos.putNextEntry(e)
if (structure.containsKey(e.name)) {
val clazz = classes.singleOrNull {
clazz -> clazz.name == e.name
};
if (clazz != null) {
val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES)
val cn = structure[e.name]!!
cn.accept(cw)
clazz.accept(cw)
jos.write(cw.toByteArray())
} else {
jos.write(jis.readAllBytes())

View file

@ -7,11 +7,8 @@ import net.revanced.patcher.util.ExtraTypes
import net.revanced.patcher.writer.ASMWriter.setAt
import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type
import org.objectweb.asm.tree.LdcInsnNode
import java.io.ByteArrayOutputStream
import org.objectweb.asm.tree.*
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal class PatcherTest {
private val testSigs: Array<Signature> = arrayOf(
@ -46,14 +43,24 @@ internal class PatcherTest {
patcher.addPatches(
Patch ("TestPatch") {
// Get the method from the resolver cache
val main = patcher.cache.methods["mainMethod"]
val mainMethod = patcher.cache.methods["mainMethod"]
// Get the instruction list
val insn = main.method.instructions!!
val instructions = mainMethod.method.instructions!!
// Let's modify it, so it prints "Hello, ReVanced!"
// Get the start index of our signature
// Get the start index of our opcode pattern
// This will be the index of the LDC instruction
val startIndex = main.sd.startIndex
insn.setAt(startIndex, LdcInsnNode("Hello, ReVanced!"))
val startIndex = mainMethod.scanData.startIndex
// Create a new Ldc node and replace the LDC instruction
val stringNode = LdcInsnNode("Hello, ReVanced!");
instructions.setAt(startIndex, stringNode)
// Now lets print our string to the console output
// First create a list of instructions
val printCode = InsnList();
printCode.add(LdcInsnNode("Hello, ReVanced!"))
printCode.add(MethodInsnNode(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V"))
// Add the list after the second instruction by our pattern
instructions.insert(instructions[startIndex + 1], printCode)
// Finally, tell the patcher that this patch was a success.
// You can also return PatchResultError with a message.
// If an exception is thrown inside this function,
@ -62,7 +69,9 @@ internal class PatcherTest {
}
)
// Apply all patches loaded in the patcher
val result = patcher.applyPatches()
// You can check if an error occurred
for ((s, r) in result) {
if (r.isFailure) {
throw Exception("Patch $s failed", r.exceptionOrNull()!!)
@ -70,30 +79,30 @@ internal class PatcherTest {
}
// TODO Doesn't work, needs to be fixed.
// val out = ByteArrayOutputStream()
// patcher.saveTo(out)
// assertTrue(
// // 8 is a random value, it's just weird if it's any lower than that
// out.size() > 8,
// "Output must be at least 8 bytes"
// )
//
// out.close()
//val out = ByteArrayOutputStream()
//patcher.saveTo(out)
//assertTrue(
// // 8 is a random value, it's just weird if it's any lower than that
// out.size() > 8,
// "Output must be at least 8 bytes"
//)
//
//out.close()
testData.close()
}
// TODO Doesn't work, needs to be fixed.
// @Test
// fun noChanges() {
// val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!!
// val available = testData.available()
// val patcher = Patcher(testData, testSigs)
//
// val out = ByteArrayOutputStream()
// patcher.saveTo(out)
// assertEquals(available, out.size())
//
// out.close()
// testData.close()
// }
//@Test
//fun noChanges() {
// val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!!
// val available = testData.available()
// val patcher = Patcher(testData, testSigs)
//
// val out = ByteArrayOutputStream()
// patcher.saveTo(out)
// assertEquals(available, out.size())
//
// out.close()
// testData.close()
//}
}