diff --git a/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/BundleDownloader.kt b/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/BundleDownloader.kt index c7b814a4299f6f..42a321560f500e 100644 --- a/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/BundleDownloader.kt +++ b/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/BundleDownloader.kt @@ -26,7 +26,6 @@ import okhttp3.Headers import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.Response -import okio.Buffer import okio.BufferedSource import okio.Okio import org.json.JSONException @@ -120,14 +119,10 @@ public class BundleDownloader public constructor(private val client: OkHttpClien downloadBundleFromURLCall = null val url = resp.request().url().toString() - // Make sure the result is a multipart response and parse the boundary. - var contentType = resp.header("content-type") - if (contentType == null) { - // fallback to empty string for nullability - contentType = "" - } + val contentType = resp.header("content-type") ?: "" val regex = Pattern.compile("multipart/mixed;.*boundary=\"([^\"]+)\"") val match = regex.matcher(contentType) + if (contentType.isNotEmpty() && match.find()) { val boundary = Assertions.assertNotNull(match.group(1)) processMultipartResponse(url, resp, boundary, outputFile, bundleInfo, callback) @@ -135,18 +130,23 @@ public class BundleDownloader public constructor(private val client: OkHttpClien // In case the server doesn't support multipart/mixed responses, fallback to // normal // download. - resp.body().use { body -> - if (body != null) { + val body = resp.body() + if (body != null) { + body.use { processBundleResult( url, resp.code(), resp.headers(), - body.source(), + it.source(), outputFile, bundleInfo, callback, ) } + } else { + callback.onFailure( + makeGeneric(url, "Development server response body was empty.", "URL: $url", null) + ) } } } @@ -164,44 +164,40 @@ public class BundleDownloader public constructor(private val client: OkHttpClien bundleInfo: BundleInfo?, callback: DevBundleDownloadListener, ) { - if (response.body() == null) { + val responseBody = response.body() + if (responseBody == null) { callback.onFailure( DebugServerException( (""" Error while reading multipart response. - + Response body was empty: ${response.code()} - + URL: $url - - + + """ .trimIndent()) ) ) return } - val source = checkNotNull(response.body()?.source()) + + val source = responseBody.source() val bodyReader = MultipartStreamReader(source, boundary) val completed = bodyReader.readAllParts( object : ChunkListener { - @Throws(IOException::class) override fun onChunkComplete( headers: Map, - body: Buffer, + body: BufferedSource, isLastChunk: Boolean, ) { - // This will get executed for every chunk of the multipart response. The last chunk - // (isLastChunk = true) will be the JS bundle, the other ones will be progress - // events - // encoded as JSON. if (isLastChunk) { - // The http status code for each separate chunk is in the X-Http-Status header. - var status = response.code() - if (headers.containsKey("X-Http-Status")) { - status = headers.getOrDefault("X-Http-Status", "0").toInt() - } + val status = + headers["X-Http-Status"]?.toIntOrNull() + ?: response.code() + processBundleResult( url, status, @@ -211,39 +207,29 @@ public class BundleDownloader public constructor(private val client: OkHttpClien bundleInfo, callback, ) - } else { - if ( - !headers.containsKey("Content-Type") || - headers["Content-Type"] != "application/json" - ) { - return - } + return + } - try { - val progress = JSONObject(body.readUtf8()) - val status = - if (progress.has("status")) progress.getString("status") else "Bundling" - var done: Int? = null - if (progress.has("done")) { - done = progress.getInt("done") - } - var total: Int? = null - if (progress.has("total")) { - total = progress.getInt("total") - } - callback.onProgress(status, done, total) - } catch (e: JSONException) { - FLog.e(ReactConstants.TAG, "Error parsing progress JSON. $e") - } + val contentType = headers["Content-Type"] ?: return + if (!isJsonContentType(contentType)) { + return + } + + try { + // Body is already bounded to this part; safe to read fully. + val progress = JSONObject(body.readUtf8()) + val status = if (progress.has("status")) progress.getString("status") else "Bundling" + val done: Int? = if (progress.has("done")) progress.getInt("done") else null + val total: Int? = if (progress.has("total")) progress.getInt("total") else null + callback.onProgress(status, done, total) + } catch (e: JSONException) { + FLog.e(ReactConstants.TAG, "Error parsing progress JSON.", e) } } - override fun onChunkProgress( - headers: Map, - loaded: Long, - total: Long, - ) { - if ("application/javascript" == headers["Content-Type"]) { + override fun onChunkProgress(headers: Map, loaded: Long, total: Long) { + val contentType = headers["Content-Type"] ?: return + if (isJavaScriptContentType(contentType)) { callback.onProgress( "Downloading", (loaded / 1024).toInt(), @@ -253,17 +239,18 @@ public class BundleDownloader public constructor(private val client: OkHttpClien } } ) + if (!completed) { callback.onFailure( DebugServerException( (""" Error while reading multipart response. - + Response code: ${response.code()} - + URL: $url - - + + """ .trimIndent()) ) @@ -309,7 +296,6 @@ public class BundleDownloader public constructor(private val client: OkHttpClien val tmpFile = File(outputFile.path + ".tmp") if (storePlainJSInFile(body, tmpFile)) { - // If we have received a new bundle from the server, move it to its final destination. if (!tmpFile.renameTo(outputFile)) { throw IOException("Couldn't rename $tmpFile to $outputFile") } @@ -326,7 +312,9 @@ public class BundleDownloader public constructor(private val client: OkHttpClien @Throws(IOException::class) private fun storePlainJSInFile(body: BufferedSource, outputFile: File): Boolean { - Okio.sink(outputFile).use { it -> body.readAll(it) } + Okio.sink(outputFile).use { sink -> + body.readAll(sink) + } return true } @@ -343,5 +331,11 @@ public class BundleDownloader public constructor(private val client: OkHttpClien } } } + + private fun isJsonContentType(value: String): Boolean = + value.startsWith("application/json") + + private fun isJavaScriptContentType(value: String): Boolean = + value.startsWith("application/javascript") } -} +} \ No newline at end of file diff --git a/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/MultipartStreamReader.kt b/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/MultipartStreamReader.kt index 5ff3dc94532bae..e0a63989195f01 100644 --- a/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/MultipartStreamReader.kt +++ b/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/MultipartStreamReader.kt @@ -14,6 +14,10 @@ import kotlin.math.max import okio.Buffer import okio.BufferedSource import okio.ByteString +import okio.Source +import okio.Timeout +import okio.Okio +import java.util.TreeMap /** Utility class to parse the body of a response of type multipart/mixed. */ internal class MultipartStreamReader( @@ -25,7 +29,7 @@ internal class MultipartStreamReader( interface ChunkListener { /** Invoked when a chunk of a multipart response is fully downloaded. */ @Throws(IOException::class) - fun onChunkComplete(headers: Map, body: Buffer, isLastChunk: Boolean) + fun onChunkComplete(headers: Map, body: BufferedSource, isLastChunk: Boolean) /** Invoked as bytes of the current chunk are read. */ @Throws(IOException::class) @@ -48,8 +52,9 @@ internal class MultipartStreamReader( var chunkStart: Long = 0 var bytesSeen: Long = 0 val content = Buffer() + var currentHeaders: Map? = null - var currentHeadersLength: Long = 0 + var currentBodyStartIndexInContent: Long = -1 while (true) { var isCloseDelimiter = false @@ -58,6 +63,7 @@ internal class MultipartStreamReader( // to allow for the edge case when the delimiter is cut by read call. val searchStart = max((bytesSeen - closeDelimiter.size()).toDouble(), chunkStart.toDouble()).toLong() + var indexOfDelimiter = content.indexOf(delimiter, searchStart) if (indexOfDelimiter == -1L) { isCloseDelimiter = true @@ -68,16 +74,16 @@ internal class MultipartStreamReader( bytesSeen = content.size() if (currentHeaders == null) { - val indexOfHeaders = content.indexOf(headersDelimiter, searchStart) - if (indexOfHeaders >= 0) { - source.read(content, indexOfHeaders) + val indexOfHeadersDelimiter = content.indexOf(headersDelimiter, searchStart) + if (indexOfHeadersDelimiter >= 0) { val headers = Buffer() - content.copyTo(headers, searchStart, indexOfHeaders - searchStart) - currentHeadersLength = headers.size() + headersDelimiter.size() + content.copyTo(headers, searchStart, indexOfHeadersDelimiter - searchStart) currentHeaders = parseHeaders(headers) + currentBodyStartIndexInContent = indexOfHeadersDelimiter + headersDelimiter.size().toLong() } } else { - emitProgress(currentHeaders, content.size() - currentHeadersLength, false, listener) + val loaded = max(0L, content.size() - currentBodyStartIndexInContent) + emitProgress(currentHeaders, loaded, false, listener) } val bytesRead = source.read(content, bufferLen.toLong()) @@ -92,26 +98,30 @@ internal class MultipartStreamReader( // Ignore preamble if (chunkStart > 0) { - val chunk = Buffer() + if (currentHeaders != null && currentBodyStartIndexInContent >= 0) { + val loadedFinal = max(0L, chunkEnd - currentBodyStartIndexInContent) + emitProgress(currentHeaders, loadedFinal, true, listener) + } content.skip(chunkStart) - content.read(chunk, length) - emitProgress(currentHeaders, chunk.size() - currentHeadersLength, true, listener) - emitChunk(chunk, isCloseDelimiter, listener) + emitChunk(content, length, isCloseDelimiter, listener) + currentHeaders = null - currentHeadersLength = 0 + currentBodyStartIndexInContent = -1 } else { content.skip(chunkEnd) } if (isCloseDelimiter) { return true } + chunkStart = delimiter.size().toLong() bytesSeen = chunkStart } } private fun parseHeaders(data: Buffer): Map { - val headers: MutableMap = mutableMapOf() + // Header names are case-insensitive + val headers: MutableMap = TreeMap(String.CASE_INSENSITIVE_ORDER) val text = data.readUtf8() val lines = text.split(CRLF.toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray() for (line in lines) { @@ -126,20 +136,81 @@ internal class MultipartStreamReader( return headers } + /** + * Emits a chunk to the listener. The `body` passed to the listener is bounded to the chunk body + * bytes, so the listener cannot accidentally read into the next boundary. + * + * Also drains any unread body bytes after the callback to keep parsing in sync. + */ @Throws(IOException::class) - private fun emitChunk(chunk: Buffer, done: Boolean, listener: ChunkListener) { + private fun emitChunk( + content: Buffer, + chunkLength: Long, + done: Boolean, + listener: ChunkListener, + ) { val marker: ByteString = ByteString.encodeUtf8(CRLF + CRLF) - val indexOfMarker = chunk.indexOf(marker) - if (indexOfMarker == -1L) { - listener.onChunkComplete(emptyMap(), chunk, done) - } else { - val headers = Buffer() - val body = Buffer() - chunk.read(headers, indexOfMarker) - chunk.skip(marker.size().toLong()) - chunk.readAll(body) - listener.onChunkComplete(parseHeaders(headers), body, done) + val indexOfMarker = content.indexOf(marker, 0) + + if (indexOfMarker == -1L || indexOfMarker >= chunkLength) { + // No headers marker found inside the chunk. Treat the entire chunk as body. + val bodyLength = chunkLength + val body = Okio.buffer(FixedLengthSource(content, bodyLength)) + try { + listener.onChunkComplete(emptyMap(), body, done) + } finally { + drainFully(body) + } + return } + + // Headers exist. + val headersBuf = Buffer() + content.read(headersBuf, indexOfMarker) + content.skip(marker.size().toLong()) + val headers = parseHeaders(headersBuf) + + val maxBodyLength = chunkLength - indexOfMarker - marker.size().toLong() + val body = Okio.buffer(FixedLengthSource(content, maxBodyLength)) + try { + listener.onChunkComplete(headers, body, done) + } finally { + drainFully(body) + } + } + + private fun drainFully(body: BufferedSource) { + // Drain remaining bytes from this part body (if listener didn't). + // Use small reusable buffer to avoid unbounded memory. + val tmp = Buffer() + try { + while (true) { + val r = body.read(tmp, 8 * 1024L) + if (r == -1L) break + tmp.clear() + } + } catch (_: IOException) { + // Best-effort drain; parsing will likely fail upstream anyway. + } + } + + private class FixedLengthSource( + private val upstream: Buffer, + private var remaining: Long, + ) : Source { + override fun read(sink: Buffer, byteCount: Long): Long { + if (byteCount == 0L) return 0L + if (remaining == 0L) return -1L + val toRead = minOf(byteCount, remaining) + val read = upstream.read(sink, toRead) + if (read == -1L) return -1L + remaining -= read + return read + } + + override fun timeout(): Timeout = Timeout.NONE + + override fun close() = Unit } @Throws(IOException::class) diff --git a/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/MultipartStreamReaderTest.kt b/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/MultipartStreamReaderTest.kt index c0753a45ca04e5..76dcc8fff67c2e 100644 --- a/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/MultipartStreamReaderTest.kt +++ b/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/MultipartStreamReaderTest.kt @@ -8,6 +8,7 @@ package com.facebook.react.devsupport import okio.Buffer +import okio.BufferedSource import okio.ByteString import org.assertj.core.api.Assertions.assertThat import org.junit.Test @@ -34,18 +35,19 @@ class MultipartStreamReaderTest { val callback: CallCountTrackingChunkCallback = object : CallCountTrackingChunkCallback() { - override fun onChunkComplete(headers: Map, body: Buffer, done: Boolean) { + override fun onChunkComplete(headers: Map, body: BufferedSource, done: Boolean) { super.onChunkComplete(headers, body, done) - assertThat(done).isTrue + assertThat(done).isTrue() assertThat(headers["Content-Type"]).isEqualTo("application/json; charset=utf-8") assertThat(body.readUtf8()).isEqualTo("{}") } } + val success = reader.readAllParts(callback) assertThat(callback.callCount).isEqualTo(1) - assertThat(success).isTrue + assertThat(success).isTrue() } @Test @@ -70,7 +72,7 @@ class MultipartStreamReaderTest { val callback: CallCountTrackingChunkCallback = object : CallCountTrackingChunkCallback() { - override fun onChunkComplete(headers: Map, body: Buffer, done: Boolean) { + override fun onChunkComplete(headers: Map, body: BufferedSource, done: Boolean) { super.onChunkComplete(headers, body, done) assertThat(done).isEqualTo(callCount == 3) @@ -80,7 +82,7 @@ class MultipartStreamReaderTest { val success = reader.readAllParts(callback) assertThat(callback.callCount).isEqualTo(3) - assertThat(success).isTrue + assertThat(success).isTrue() } @Test @@ -96,7 +98,7 @@ class MultipartStreamReaderTest { val success = reader.readAllParts(callback) assertThat(callback.callCount).isEqualTo(0) - assertThat(success).isFalse + assertThat(success).isFalse() } @Test @@ -120,15 +122,89 @@ class MultipartStreamReaderTest { val callback = CallCountTrackingChunkCallback() val success = reader.readAllParts(callback) + // First part was complete, then stream ended without a close delimiter. + assertThat(callback.callCount).isEqualTo(1) + assertThat(success).isFalse() + } + + @Test + fun testListenerDoesNotNeedToFullyReadBody() { + val response: ByteString = + encodeUtf8( + "preamble\r\n" + + "--sample_boundary\r\n" + + "Content-Type: text/plain\r\n" + + "Content-Length: 4\r\n\r\n" + + "ABCD\r\n" + + "--sample_boundary\r\n" + + "Content-Type: text/plain\r\n" + + "Content-Length: 1\r\n\r\n" + + "Z\r\n" + + "--sample_boundary--\r\n" + ) + + val source = Buffer().apply { write(response) } + val reader = MultipartStreamReader(source, "sample_boundary") + + val parts = mutableListOf() + val callback = + object : MultipartStreamReader.ChunkListener { + override fun onChunkComplete(headers: Map, body: BufferedSource, isLastChunk: Boolean) { + if (parts.isEmpty()) { + // Intentionally only read 1 byte from the first part. + parts.add(body.readUtf8(1)) + return + } + parts.add(body.readUtf8()) + } + + override fun onChunkProgress(headers: Map, loaded: Long, total: Long) = Unit + } + + val success = reader.readAllParts(callback) + + assertThat(success).isTrue() + assertThat(parts).containsExactly("A", "Z") + } + + @Test + fun testHeaderNamesAreCaseInsensitive() { + val response: ByteString = + encodeUtf8( + "preamble\r\n" + + "--sample_boundary\r\n" + + "content-type: application/json\r\n" + + "content-length: 2\r\n\r\n" + + "{}\r\n" + + "--sample_boundary--\r\n" + ) + + val source = Buffer().apply { write(response) } + val reader = MultipartStreamReader(source, "sample_boundary") + + val callback = + object : CallCountTrackingChunkCallback() { + override fun onChunkComplete(headers: Map, body: BufferedSource, done: Boolean) { + super.onChunkComplete(headers, body, done) + + // Lookup using canonical case should still work. + assertThat(headers["Content-Type"]).isEqualTo("application/json") + assertThat(headers["Content-Length"]).isEqualTo("2") + assertThat(body.readUtf8()).isEqualTo("{}") + } + } + + val success = reader.readAllParts(callback) + + assertThat(success).isTrue() assertThat(callback.callCount).isEqualTo(1) - assertThat(success).isFalse } internal open class CallCountTrackingChunkCallback : MultipartStreamReader.ChunkListener { var callCount = 0 private set - override fun onChunkComplete(headers: Map, body: Buffer, isLastChunk: Boolean) { + override fun onChunkComplete(headers: Map, body: BufferedSource, isLastChunk: Boolean) { callCount++ }