From e426b3ccbd34724429be701f84367806d5d0becf Mon Sep 17 00:00:00 2001 From: Zlatin Balevsky Date: Wed, 12 Jun 2019 17:33:43 +0100 Subject: [PATCH] refactoring to enable hashlist uploads --- .../core/connection/ConnectionAcceptor.groovy | 14 ++- .../muwire/core/upload/ContentRequest.groovy | 5 + .../muwire/core/upload/ContentUploader.groovy | 54 ++++++++ .../muwire/core/upload/HashListRequest.groovy | 4 + .../core/upload/HashListUploader.groovy | 35 ++++++ .../com/muwire/core/upload/Request.groovy | 115 ++++++++++-------- .../muwire/core/upload/UploadManager.groovy | 47 ++++++- .../com/muwire/core/upload/Uploader.groovy | 45 +------ 8 files changed, 226 insertions(+), 93 deletions(-) create mode 100644 core/src/main/groovy/com/muwire/core/upload/ContentRequest.groovy create mode 100644 core/src/main/groovy/com/muwire/core/upload/ContentUploader.groovy create mode 100644 core/src/main/groovy/com/muwire/core/upload/HashListRequest.groovy create mode 100644 core/src/main/groovy/com/muwire/core/upload/HashListUploader.groovy diff --git a/core/src/main/groovy/com/muwire/core/connection/ConnectionAcceptor.groovy b/core/src/main/groovy/com/muwire/core/connection/ConnectionAcceptor.groovy index f6eb620f..6ee13984 100644 --- a/core/src/main/groovy/com/muwire/core/connection/ConnectionAcceptor.groovy +++ b/core/src/main/groovy/com/muwire/core/connection/ConnectionAcceptor.groovy @@ -108,6 +108,9 @@ class ConnectionAcceptor { case (byte)'G': processGET(e) break + case (byte)'H': + processHashList(e) + break case (byte)'P': processPOST(e) break @@ -178,9 +181,18 @@ class ConnectionAcceptor { dis.readFully(et) if (et != "ET ".getBytes(StandardCharsets.US_ASCII)) throw new IOException("Invalid GET connection") - uploadManager.processEndpoint(e) + uploadManager.processGET(e) } + private void processHashList(Endpoint e) { + byte[] ashList = new byte[8] + final DataInputStream dis = new DataInputStream(e.getInputStream()) + dis.readFully(ashList) + if (ashList != "ASHLIST ".getBytes(StandardCharsets.US_ASCII)) + throw new IOException("Invalid HASHLIST connection") + uploadManager.processHashList(e) + } + private void processPOST(final Endpoint e) throws IOException { byte [] ost = new byte[4] final DataInputStream dis = new DataInputStream(e.getInputStream()) diff --git a/core/src/main/groovy/com/muwire/core/upload/ContentRequest.groovy b/core/src/main/groovy/com/muwire/core/upload/ContentRequest.groovy new file mode 100644 index 00000000..45f03336 --- /dev/null +++ b/core/src/main/groovy/com/muwire/core/upload/ContentRequest.groovy @@ -0,0 +1,5 @@ +package com.muwire.core.upload + +class ContentRequest extends Request { + Range range +} diff --git a/core/src/main/groovy/com/muwire/core/upload/ContentUploader.groovy b/core/src/main/groovy/com/muwire/core/upload/ContentUploader.groovy new file mode 100644 index 00000000..1b5d13cb --- /dev/null +++ b/core/src/main/groovy/com/muwire/core/upload/ContentUploader.groovy @@ -0,0 +1,54 @@ +package com.muwire.core.upload + +import java.nio.ByteBuffer +import java.nio.channels.FileChannel +import java.nio.charset.StandardCharsets +import java.nio.file.Files +import java.nio.file.StandardOpenOption + +import com.muwire.core.connection.Endpoint + +class ContentUploader extends Uploader { + + private final File file + private final ContentRequest request + + ContentUploader(File file, ContentRequest request, Endpoint endpoint) { + super(endpoint) + this.file = file + this.request = request + } + + @Override + void respond() { + OutputStream os = endpoint.getOutputStream() + Range range = request.getRange() + if (range.start >= file.length() || range.end >= file.length()) { + os.write("416 Range Not Satisfiable\r\n\r\n".getBytes(StandardCharsets.US_ASCII)) + os.flush() + return + } + + os.write("200 OK\r\n".getBytes(StandardCharsets.US_ASCII)) + os.write("Content-Range: $range.start-$range.end\r\n\r\n".getBytes(StandardCharsets.US_ASCII)) + + FileChannel channel + try { + channel = Files.newByteChannel(file.toPath(), EnumSet.of(StandardOpenOption.READ)) + mapped = channel.map(FileChannel.MapMode.READ_ONLY, range.start, range.end - range.start + 1) + byte [] tmp = new byte[0x1 << 13] + while(mapped.hasRemaining()) { + int start = mapped.position() + synchronized(this) { + mapped.get(tmp, 0, Math.min(tmp.length, mapped.remaining())) + } + int read = mapped.position() - start + endpoint.getOutputStream().write(tmp, 0, read) + } + } finally { + try {channel?.close() } catch (IOException ignored) {} + endpoint.getOutputStream().flush() + } + } + +} diff --git a/core/src/main/groovy/com/muwire/core/upload/HashListRequest.groovy b/core/src/main/groovy/com/muwire/core/upload/HashListRequest.groovy new file mode 100644 index 00000000..f68baf19 --- /dev/null +++ b/core/src/main/groovy/com/muwire/core/upload/HashListRequest.groovy @@ -0,0 +1,4 @@ +package com.muwire.core.upload + +class HashListRequest extends Request { +} diff --git a/core/src/main/groovy/com/muwire/core/upload/HashListUploader.groovy b/core/src/main/groovy/com/muwire/core/upload/HashListUploader.groovy new file mode 100644 index 00000000..ac0d3e03 --- /dev/null +++ b/core/src/main/groovy/com/muwire/core/upload/HashListUploader.groovy @@ -0,0 +1,35 @@ +package com.muwire.core.upload + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets + +import com.muwire.core.InfoHash +import com.muwire.core.connection.Endpoint + +class HashListUploader extends Uploader { + private final InfoHash infoHash + private final HashListRequest request + + HashListUploader(Endpoint endpoint, InfoHash infoHash, HashListRequest request) { + super(endpoint) + this.infoHash = infoHash + mapped = ByteBuffer.wrap(infoHash.getHashList()) + } + + void respond() { + OutputStream os = endpoint.getOutputStream() + os.write("200 OK\r\n".getBytes(StandardCharsets.US_ASCII)) + os.write("Content-Range: 0-${mapped.remaining()}") + + byte[]tmp = new byte[0x1 << 13] + while(mapped.hasRemaining()) { + int start = mapped.position() + synchronized(this) { + mapped.get(tmp, 0, Math.min(tmp.length, mapped.remaining())) + } + int read = mapped.position() - start + endpoint.getOutputStream().write(tmp, 0, read) + } + endpoint.getOutputStream().flush() + } +} diff --git a/core/src/main/groovy/com/muwire/core/upload/Request.groovy b/core/src/main/groovy/com/muwire/core/upload/Request.groovy index c9a87954..a38b7d9d 100644 --- a/core/src/main/groovy/com/muwire/core/upload/Request.groovy +++ b/core/src/main/groovy/com/muwire/core/upload/Request.groovy @@ -16,57 +16,12 @@ class Request { private static final byte N = "\n".getBytes(StandardCharsets.US_ASCII)[0] InfoHash infoHash - Range range Persona downloader Map headers - static Request parse(InfoHash infoHash, InputStream is) throws IOException { - Map headers = new HashMap<>() - byte [] tmp = new byte[Constants.MAX_HEADER_SIZE] - while(headers.size() < Constants.MAX_HEADERS) { - boolean r = false - boolean n = false - int idx = 0 - while (true) { - byte read = is.read() - if (read == -1) - throw new IOException("Stream closed") - - if (!r && read == N) - throw new IOException("Received N before R") - if (read == R) { - if (r) - throw new IOException("double R") - r = true - continue - } - - if (r && !n) { - if (read != N) - throw new IOException("R not followed by N") - n = true - break - } - if (idx == 0x1 << 14) - throw new IOException("Header too long") - tmp[idx++] = read - } - - if (idx == 0) - break - - String header = new String(tmp, 0, idx, StandardCharsets.US_ASCII) - log.fine("Read header $header") - - int keyIdx = header.indexOf(":") - if (keyIdx < 1) - throw new IOException("Header key not found") - if (keyIdx == header.length()) - throw new IOException("Header value not found") - String key = header.substring(0, keyIdx) - String value = header.substring(keyIdx + 1) - headers.put(key, value) - } + static Request parseContentRequest(InfoHash infoHash, InputStream is) throws IOException { + + Map headers = parseHeaders(is) if (!headers.containsKey("Range")) throw new IOException("Range header not found") @@ -93,7 +48,69 @@ class Request { def decoded = Base64.decode(encoded) downloader = new Persona(new ByteArrayInputStream(decoded)) } - new Request( infoHash : infoHash, range : new Range(start, end), headers : headers, downloader : downloader) + new ContentRequest( infoHash : infoHash, range : new Range(start, end), + headers : headers, downloader : downloader) + } + + static Request parseHashListRequest(InfoHash infoHash, InputStream is) throws IOException { + Map headers = parseHeaders(is) + Persona downloader = null + if (headers.containsKey("X-Persona")) { + def encoded = headers["X-Persona"].trim() + def decoded = Base64.decode(encoded) + downloader = new Persona(new ByteArrayInputStream(decoded)) + } + new HashListRequest(infoHash : infoHash, headers : headers, downloader : downloader) + } + + private static Map parseHeaders(InputStream is) { + Map headers = new HashMap<>() + byte [] tmp = new byte[Constants.MAX_HEADER_SIZE] + while(headers.size() < Constants.MAX_HEADERS) { + boolean r = false + boolean n = false + int idx = 0 + while (true) { + byte read = is.read() + if (read == -1) + throw new IOException("Stream closed") + + if (!r && read == N) + throw new IOException("Received N before R") + if (read == R) { + if (r) + throw new IOException("double R") + r = true + continue + } + + if (r && !n) { + if (read != N) + throw new IOException("R not followed by N") + n = true + break + } + if (idx == 0x1 << 14) + throw new IOException("Header too long") + tmp[idx++] = read + } + + if (idx == 0) + break + + String header = new String(tmp, 0, idx, StandardCharsets.US_ASCII) + log.fine("Read header $header") + + int keyIdx = header.indexOf(":") + if (keyIdx < 1) + throw new IOException("Header key not found") + if (keyIdx == header.length()) + throw new IOException("Header value not found") + String key = header.substring(0, keyIdx) + String value = header.substring(keyIdx + 1) + headers.put(key, value) + } + headers } } diff --git a/core/src/main/groovy/com/muwire/core/upload/UploadManager.groovy b/core/src/main/groovy/com/muwire/core/upload/UploadManager.groovy index 9cda302b..177cd9e8 100644 --- a/core/src/main/groovy/com/muwire/core/upload/UploadManager.groovy +++ b/core/src/main/groovy/com/muwire/core/upload/UploadManager.groovy @@ -23,9 +23,9 @@ public class UploadManager { this.fileManager = fileManager } - public void processEndpoint(Endpoint e) throws IOException { + public void processGET(Endpoint e) throws IOException { byte [] infoHashStringBytes = new byte[44] - DataInputStream dis = new DataInputStream(e.getInputStream()) + DataInputStream dis = new DataInputStream(e.getInputStream()) boolean first = true while(true) { if (first) @@ -61,13 +61,13 @@ public class UploadManager { return } - Request request = Request.parse(new InfoHash(infoHashRoot), e.getInputStream()) + Request request = Request.parseContentRequest(new InfoHash(infoHashRoot), e.getInputStream()) if (request.downloader != null && request.downloader.destination != e.destination) { log.info("Downloader persona doesn't match their destination") e.close() return } - Uploader uploader = new Uploader(sharedFiles.iterator().next().file, request, e) + Uploader uploader = new ContentUploader(sharedFiles.iterator().next().file, request, e) eventBus.publish(new UploadEvent(uploader : uploader)) try { uploader.respond() @@ -75,7 +75,46 @@ public class UploadManager { eventBus.publish(new UploadFinishedEvent(uploader : uploader)) } } + } + + public void processHashList(Endpoint e) { + byte [] infoHashStringBytes = new byte[44] + DataInputStream dis = new DataInputStream(e.getInputStream()) + dis.readFully(infoHashStringBytes) + String infoHashString = new String(infoHashStringBytes, StandardCharsets.US_ASCII) + log.info("Responding to hashlist request for root $infoHashString") + byte [] infoHashRoot = Base64.decode(infoHashString) + Set sharedFiles = fileManager.getSharedFiles(infoHashRoot) + if (sharedFiles == null || sharedFiles.isEmpty()) { + log.info "file not found" + e.getOutputStream().write("404 File Not Found\r\n\r\n".getBytes(StandardCharsets.US_ASCII)) + e.getOutputStream().flush() + e.close() + return + } + + byte [] rn = new byte[2] + dis.readFully(rn) + if (rn != "\r\n".getBytes(StandardCharsets.US_ASCII)) { + log.warning("Malformed HASHLIST header") + e.close() + return + } + + Request request = Request.parseHashListRequest(new InfoHash(infoHashRoot), e.getInputStream()) + if (request.downloader != null && request.downloader.destination != e.destination) { + log.info("Downloader persona doesn't match their destination") + e.close() + return + } + Uploader uploader = new HashListUploader(e, sharedFiles.iterator().next().infoHash, request, request) + eventBus.publish(new UploadEvent(uploader : uploader)) + try { + uploader.respond() + } finally { + eventBus.publish(new UploadFinishedEvent(uploader : uploader)) + } } } diff --git a/core/src/main/groovy/com/muwire/core/upload/Uploader.groovy b/core/src/main/groovy/com/muwire/core/upload/Uploader.groovy index ea231ddd..d2e3b6df 100644 --- a/core/src/main/groovy/com/muwire/core/upload/Uploader.groovy +++ b/core/src/main/groovy/com/muwire/core/upload/Uploader.groovy @@ -8,49 +8,16 @@ import java.nio.file.StandardOpenOption import com.muwire.core.connection.Endpoint -class Uploader { - private final File file - private final Request request - private final Endpoint endpoint - private ByteBuffer mapped +abstract class Uploader { + protected final Endpoint endpoint + protected ByteBuffer mapped - Uploader(File file, Request request, Endpoint endpoint) { - this.file = file - this.request = request + Uploader(Endpoint endpoint) { this.endpoint = endpoint } - void respond() { - OutputStream os = endpoint.getOutputStream() - Range range = request.getRange() - if (range.start >= file.length() || range.end >= file.length()) { - os.write("416 Range Not Satisfiable\r\n\r\n".getBytes(StandardCharsets.US_ASCII)) - os.flush() - return - } - - os.write("200 OK\r\n".getBytes(StandardCharsets.US_ASCII)) - os.write("Content-Range: $range.start-$range.end\r\n\r\n".getBytes(StandardCharsets.US_ASCII)) - - FileChannel channel - try { - channel = Files.newByteChannel(file.toPath(), EnumSet.of(StandardOpenOption.READ)) - mapped = channel.map(FileChannel.MapMode.READ_ONLY, range.start, range.end - range.start + 1) - byte [] tmp = new byte[0x1 << 13] - while(mapped.hasRemaining()) { - int start = mapped.position() - synchronized(this) { - mapped.get(tmp, 0, Math.min(tmp.length, mapped.remaining())) - } - int read = mapped.position() - start - endpoint.getOutputStream().write(tmp, 0, read) - } - } finally { - try {channel?.close() } catch (IOException ignored) {} - endpoint.getOutputStream().flush() - } - } - + abstract void respond() + public synchronized int getPosition() { if (mapped == null) return -1