diff --git a/TCPBridge/bridge.lua b/TCPBridge/bridge.lua index 180ac6c..694a0e8 100644 --- a/TCPBridge/bridge.lua +++ b/TCPBridge/bridge.lua @@ -3,13 +3,20 @@ local imt = require "interminitel" local clients, coroutines, messages = {}, {}, {} -local function hasValidPacket(s) - local w, pi, pt, ds, sn, po, da = pcall(imt.decodePacket,s) - if w and pi and pt and ds and sn and po and da then - return true - end +local function spawn(f) + coroutines[#coroutines+1] = coroutine.create(function() + while true do + print(pcall(f)) + end + end) end +local function hasValidPacket(s) + local w, res = pcall(imt.decodePacket,s) + if res then return true end +end +hasValidPacket("") + function socketLoop() local server = socket.bind("*", 4096) server:settimeout(0) @@ -17,27 +24,29 @@ function socketLoop() local client,err = server:accept() if client then client:settimeout(0) - clients[#clients+1] = {["conn"]=client,last=os.time()} + clients[#clients+1] = {["conn"]=client,last=os.time(),buffer=""} + print("Gained client: "..client:getsockname()) end coroutine.yield() end end -coroutines[#coroutines+1]=coroutine.create(socketLoop) +spawn(socketLoop) function clientLoop() while true do for _,client in pairs(clients) do - local s=client.conn:receive(16384) + local s=client.conn:receive() if s then client.buffer = client.buffer .. s + print(s) end end coroutine.yield() end end -coroutines[#coroutines+1]=coroutine.create(clientLoop) +spawn(clientLoop) function pushLoop() while true do @@ -52,22 +61,23 @@ function pushLoop() end end -coroutines[#coroutines+1]=coroutine.create(pushLoop) +spawn(pushLoop) function bufferLoop() while true do for _,client in pairs(clients) do - if hasValidPacket(client.buffer) then - local tPacket = {imt.decodePacket(client.buffer)} - client.buffer = table.remove(tPacket,#tPacket) - messages[#messages+1] = imt.encodePacket(table.unpack(tPacket)) + if client.buffer:len() > 0 then + if hasValidPacket(client.buffer) then + messages[#messages+1] = imt.encodePacket(imt.decodePacket(client.buffer)) + client.buffer = imt.getRemainder(client.buffer) + end end end coroutine.yield() end end -coroutines[#coroutines+1]=coroutine.create(bufferLoop) +spawn(bufferLoop) while #coroutines > 0 do for k,v in pairs(coroutines) do diff --git a/TCPBridge/interminitel.lua b/TCPBridge/interminitel.lua index b504695..9dc5bca 100644 --- a/TCPBridge/interminitel.lua +++ b/TCPBridge/interminitel.lua @@ -1,5 +1,11 @@ local imt = {} +imt.ttypes = {} +imt.ttypes.string=1 +imt.ttypes.number=2 + +imt.ftypes = {tostring,tonumber} + function imt.to16bn(n) return string.char(math.floor(n/256))..string.char(math.floor(n%256)) end @@ -7,38 +13,49 @@ function imt.from16bn(s) return (string.byte(s,1,1)*256)+string.byte(s,2,2) end -function imt.encodePacket(packetID, packetType, destination, sender, port, data) - local rs = string.char(packetID:len()%256)..packetID..string.char(packetType) - rs=rs..string.char(destination:len()%256)..destination - rs=rs..string.char(sender:len()%256)..sender - rs=rs..to16bn(port) - rs=rs..to16bn(data:len())..data - return rs +function imt.encodePacket(...) + local tArgs = {...} + local packet = string.char(#tArgs%256) + for _,segment in ipairs(tArgs) do + local segtype = type(segment) + segment = tostring(segment) + packet = packet .. imt.to16bn(segment:len()) .. string.char(imt.ttypes[segtype]) .. tostring(segment) + end + packet = imt.to16bn(packet:len()) .. packet + return packet end function imt.decodePacket(s) - local pidlen, destlen, senderlen, datalen, packetID, packetType, destination, sender, port, data - pidlen = string.byte(s:sub(1,1)) - s=s:sub(2) - packetID = s:sub(1,pidlen) - s=s:sub(pidlen+1) - packetType = string.byte(s:sub(pidlen+1)) - s=s:sub(2) - destlen = string.byte(s:sub(1,1)) - s=s:sub(2) - destination = s:sub(1,destlen) - s=s:sub(destlen+1) - senderlen=string.byte(s:sub(1,1)) - s=s:sub(2) - sender = s:sub(1,senderlen) - s=s:sub(senderlen+1) - port=from16bn(s:sub(1,2)) - s=s:sub(3) - datalen=from16bn(s:sub(1,2)) - s=s:sub(3) - data=s:sub(1,datalen) - s=s:sub(datalen+1) - return packetID, packetType, destination, sender, port, data, s + local function getfirst(n) + local ns = s:sub(1,n) + s=s:sub(n+1) + return ns + end + local plen = imt.from16bn(getfirst(2)) + if s:len() < plen then return false end + local nsegments = string.byte(getfirst(1)) + local tsegments = {} + for i = 1, nsegments do + if s:len() < 1 then return false end + local seglen = imt.from16bn(getfirst(2)) + local segtype = string.byte(getfirst(1)) + local tempseg = getfirst(seglen) + tsegments[#tsegments+1] = imt.ftypes[segtype](tempseg) + end + return table.unpack(tsegments) +end +function imt.getRemainder(s) + local function getfirst(n) + local ns = s:sub(1,n) + s=s:sub(n+1) + return ns + end + local plen = imt.from16bn(getfirst(2)) + if s:len() > plen then + getfirst(plen) + return s + end + return nil end return imt