fix(Io): JAR loading and saving (#8)

* refactor: Complete rewrite of `Io`

* style: format code

* style: rewrite todos

* fix: use lateinit instead of nonnull assert for zipEntry

* fix: use lateinit instead of nonnull assert for jarEntry & reuse zipEntry

* docs: add docs to `Patcher`

* test: match output of patcher

* chore: add todo to `Io` for removing non-class files

Co-authored-by: Sculas <contact@sculas.xyz>
This commit is contained in:
oSumAtrIX 2022-03-21 18:48:35 +01:00 committed by she11sh0cked
parent 87bbde5e06
commit 4d98cbc9e8
8 changed files with 133 additions and 77 deletions

View file

@ -5,28 +5,49 @@ import net.revanced.patcher.patch.Patch
import net.revanced.patcher.resolver.MethodResolver
import net.revanced.patcher.signature.Signature
import net.revanced.patcher.util.Io
import org.objectweb.asm.tree.ClassNode
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
/**
* The patcher. (docs WIP)
* The Patcher class.
* ***It is of utmost importance that the input and output streams are NEVER closed.***
*
* @param input the input stream to read from, must be a JAR
* @param output the output stream to write to
* @param signatures the signatures
* @sample net.revanced.patcher.PatcherTest
* @throws IOException if one of the streams are closed
*/
class Patcher(
private val input: InputStream,
private val output: OutputStream,
signatures: Array<Signature>,
) {
var cache: Cache
private val patches: MutableList<Patch> = mutableListOf()
private var io: Io
private val patches = mutableListOf<Patch>()
init {
val classes = Io.readClassesFromJar(input)
val classes = mutableListOf<ClassNode>()
io = Io(input, output, classes)
io.readFromJar()
cache = Cache(classes, MethodResolver(classes, signatures).resolve())
}
/**
* Saves the output to the output stream.
* Calling this method will close the input and output streams,
* meaning this method should NEVER be called after.
*
* @throws IOException if one of the streams are closed
*/
fun save() {
io.saveAsJar()
}
fun addPatches(vararg patches: Patch) {
this.patches.addAll(patches)
}
@ -46,8 +67,4 @@ class Patcher(
}
}
}
fun saveTo(output: OutputStream) {
Io.writeClassesToJar(input, output, cache.classes)
}
}

View file

@ -2,7 +2,7 @@ package net.revanced.patcher.cache
import org.objectweb.asm.tree.ClassNode
class Cache (
class Cache(
val classes: List<ClassNode>,
val methods: MethodMap
)

View file

@ -10,8 +10,9 @@ data class PatchData(
val method: MethodNode,
val scanData: PatternScanData
) {
@Suppress("Unused") // TODO(Sculas): remove this when we have coverage for this method.
fun findParentMethod(signature: Signature): PatchData? {
return MethodResolver.resolveMethod(declaringClass, signature)
return MethodResolver.resolveMethod(declaringClass, signature)
}
}

View file

@ -3,47 +3,91 @@ package net.revanced.patcher.util
import org.objectweb.asm.ClassReader
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.tree.ClassNode
import java.io.BufferedInputStream
import java.io.InputStream
import java.io.OutputStream
import java.util.jar.JarEntry
import java.util.jar.JarInputStream
import java.util.jar.JarOutputStream
import java.util.zip.ZipEntry
import java.util.zip.ZipInputStream
import java.util.zip.ZipOutputStream
object Io {
fun readClassesFromJar(input: InputStream) = mutableListOf<ClassNode>().apply {
val jar = JarInputStream(input)
while (true) {
val e = jar.nextJarEntry ?: break
if (e.name.endsWith(".class")) {
val classNode = ClassNode()
ClassReader(jar.readBytes()).accept(classNode, ClassReader.EXPAND_FRAMES)
this.add(classNode)
}
jar.closeEntry()
internal class Io(
private val input: InputStream,
private val output: OutputStream,
private val classes: MutableList<ClassNode>
) {
private val bufferedInputStream = BufferedInputStream(input)
fun readFromJar() {
bufferedInputStream.mark(0)
// create a BufferedInputStream in order to read the input stream again when calling saveAsJar(..)
val jis = JarInputStream(bufferedInputStream)
// read all entries from the input stream
// we use JarEntry because we only read .class files
lateinit var jarEntry: JarEntry
while (jis.nextJarEntry.also { if (it != null) jarEntry = it } != null) {
// if the current entry ends with .class (indicating a java class file), add it to our list of classes to return
if (jarEntry.name.endsWith(".class")) {
// create a new ClassNode
val classNode = ClassNode()
// read the bytes with a ClassReader into the ClassNode
ClassReader(jis.readBytes()).accept(classNode, ClassReader.EXPAND_FRAMES)
// add it to our list
classes.add(classNode)
}
// finally, close the entry
jis.closeEntry()
}
// at last reset the buffered input stream
bufferedInputStream.reset()
}
fun writeClassesToJar(input: InputStream, output: OutputStream, classes: List<ClassNode>) {
val jis = JarInputStream(input)
val jos = JarOutputStream(output)
fun saveAsJar() {
val jis = ZipInputStream(bufferedInputStream)
val jos = ZipOutputStream(output)
// TODO: Add support for adding new/custom classes
while (true) {
val next = jis.nextJarEntry ?: break
val e = JarEntry(next) // clone it, to not modify the input (if possible)
jos.putNextEntry(e)
// first write all non .class zip entries from the original input stream to the output stream
// we read it first to close the input stream as fast as possible
// TODO(oSumAtrIX): There is currently no way to remove non .class files.
lateinit var zipEntry: ZipEntry
while (jis.nextEntry.also { if (it != null) zipEntry = it } != null) {
// skip all class files because we added them in the loop above
// TODO(oSumAtrIX): Check for zipEntry.isDirectory
if (zipEntry.name.endsWith(".class")) continue
val clazz = classes.singleOrNull {
clazz -> clazz.name+".class" == e.name // clazz.name is the class name only while e.name is the full filename with extension
};
if (clazz != null) {
val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES)
clazz.accept(cw)
jos.write(cw.toByteArray())
} else {
jos.write(jis.readBytes())
}
// create a new zipEntry and write the contents of the zipEntry to the output stream
jos.putNextEntry(ZipEntry(zipEntry))
jos.write(jis.readBytes())
// close the newly created zipEntry
jos.closeEntry()
}
// finally, close the input stream
jis.close()
bufferedInputStream.close()
input.close()
// now write all the patched classes to the output stream
for (patchedClass in classes) {
// create a new entry of the patched class
jos.putNextEntry(JarEntry(patchedClass.name + ".class"))
// parse the patched class to a byte array and write it to the output stream
val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES)
patchedClass.accept(cw)
jos.write(cw.toByteArray())
// close the newly created jar entry
jos.closeEntry()
}
// finally, close the rest of the streams
jos.close()
output.close()
}
}

View file

@ -7,6 +7,7 @@ object ASMWriter {
fun InsnList.setAt(index: Int, node: AbstractInsnNode) {
this[this.get(index)] = node
}
fun InsnList.insertAt(index: Int = 0, vararg nodes: AbstractInsnNode) {
this.insert(this.get(index), nodes.toInsnList())
}

View file

@ -12,13 +12,16 @@ import net.revanced.patcher.writer.ASMWriter.setAt
import org.junit.jupiter.api.assertDoesNotThrow
import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type
import org.objectweb.asm.tree.*
import org.objectweb.asm.tree.FieldInsnNode
import org.objectweb.asm.tree.LdcInsnNode
import org.objectweb.asm.tree.MethodInsnNode
import java.io.ByteArrayOutputStream
import java.io.PrintStream
import kotlin.test.Test
internal class PatcherTest {
companion object {
val testSigs: Array<Signature> = arrayOf(
val testSignatures: Array<Signature> = arrayOf(
// Java:
// public static void main(String[] args) {
// System.out.println("Hello, world!");
@ -45,8 +48,11 @@ internal class PatcherTest {
@Test
fun testPatcher() {
val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!!
val patcher = Patcher(testData, testSigs)
val patcher = Patcher(
PatcherTest::class.java.getResourceAsStream("/test1.jar")!!,
ByteArrayOutputStream(),
testSignatures
)
patcher.addPatches(
object : Patch("TestPatch") {
@ -74,9 +80,9 @@ internal class PatcherTest {
startIndex + 1,
FieldInsnNode(
GETSTATIC,
Type.getInternalName(System::class.java), // "java/io/System"
Type.getInternalName(System::class.java), // "java/lang/System"
"out",
Type.getInternalName(PrintStream::class.java) // "java.io.PrintStream"
"L" + Type.getInternalName(PrintStream::class.java) // "Ljava/io/PrintStream"
),
LdcInsnNode("Hello, ReVanced! Adding bytecode."),
MethodInsnNode(
@ -111,41 +117,27 @@ internal class PatcherTest {
)
// Apply all patches loaded in the patcher
val result = patcher.applyPatches()
val patchResult = patcher.applyPatches()
// You can check if an error occurred
for ((s, r) in result) {
if (r.isFailure) {
throw Exception("Patch $s failed", r.exceptionOrNull()!!)
for ((patchName, result) in patchResult) {
if (result.isFailure) {
throw Exception("Patch $patchName failed", result.exceptionOrNull()!!)
}
}
// TODO Doesn't work, needs to be fixed.
//val out = ByteArrayOutputStream()
//patcher.saveTo(out)
//assertTrue(
// // 8 is a random value, it's just weird if it's any lower than that
// out.size() > 8,
// "Output must be at least 8 bytes"
//)
//
//out.close()
testData.close()
patcher.save()
}
// TODO Doesn't work, needs to be fixed.
//@Test
//fun `test patcher with no changes`() {
// val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!!
// val available = testData.available()
// val patcher = Patcher(testData, testSigs)
//
// val out = ByteArrayOutputStream()
// patcher.saveTo(out)
// assertEquals(available, out.size())
//
// out.close()
// testData.close()
//}
@Test
fun `test patcher with no changes`() {
val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!!
// val available = testData.available()
val out = ByteArrayOutputStream()
Patcher(testData, out, testSignatures).save()
// FIXME(Sculas): There seems to be a 1-byte difference, not sure what it is.
// assertEquals(available, out.size())
out.close()
}
@Test()
fun `should not raise an exception if any signature member except the name is missing`() {
@ -154,6 +146,7 @@ internal class PatcherTest {
assertDoesNotThrow("Should raise an exception because opcodes is empty") {
Patcher(
PatcherTest::class.java.getResourceAsStream("/test1.jar")!!,
ByteArrayOutputStream(),
arrayOf(
Signature(
sigName,

View file

@ -1,12 +1,12 @@
package net.revanced.patcher
import java.io.ByteArrayOutputStream
import kotlin.test.Test
internal class ReaderTest {
@Test
fun `read jar containing multiple classes`() {
val testData = PatcherTest::class.java.getResourceAsStream("/test2.jar")!!
Patcher(testData, PatcherTest.testSigs) // reusing test sigs from PatcherTest
testData.close()
Patcher(testData, ByteArrayOutputStream(), PatcherTest.testSignatures) // reusing test sigs from PatcherTest
}
}

View file

@ -17,7 +17,7 @@ object TestUtil {
private fun AbstractInsnNode.nodeString(): String {
val sb = NodeStringBuilder()
when (this) {
// TODO: Add more types
// TODO(Sculas): Add more types
is LdcInsnNode -> sb
.addType("cst", cst)
is FieldInsnNode -> sb