Skip to content

Commit fc295e3

Browse files
authored
Merge pull request #27 from cceckman/slongfield/start_parse
Create a better start-line parsing module.
2 parents 39a1689 + b0f5a00 commit fc295e3

6 files changed

Lines changed: 397 additions & 25 deletions

File tree

http_server/parse_start.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
from amaranth import Module
2+
from amaranth.lib.wiring import In, Out, Component, connect, flipped
3+
from amaranth.lib import stream
4+
5+
from string_contains_match import StringContainsMatch
6+
from stream_demux import StreamDemux
7+
import stream_utils
8+
9+
class ParseStart(Component):
    r"""
    Parser for the start-line of an HTTP request.

    The start-line is consumed in three space-separated fields (method, path,
    protocol) followed by the "\r\n" end-of-line sequence. Each field is fed
    to its own set of StringContainsMatch submodules, and the match results
    are exposed as bitfields.

    Parameters
    ----------
    paths: list[str]
        Valid paths to match

    Attributes
    ----------
    input: Stream(8), in
        Data stream to match
    reset: Signal(1), in
        Abandon the current parse, clear ``done``, reset the matchers and
        await new input. The parser stays idle while this is asserted.
    done: Signal(1), out
        Indicates that the "\r\n" end-of-line sequence was seen.
    method: list(Signal(1)), out
        Bitfield of matched methods. The 0th field indicates no match.
        METHOD_* constants can be used for decode.
    path: list(Signal(1)), out
        Bitfield of matched paths. The 0th field indicates no match.
        Other matches are in the order from the paths parameter.
    protocol: list(Signal(1)), out
        Bitfield of matched protocol. The 0th field indicates no match.
        PROTOCOL_* constants can be used for decode.
    """

    # Bit indices into the `method` output.
    METHOD_NO_MATCH = 0
    METHOD_GET = 1
    METHOD_POST = 2

    # Bit indices into the `protocol` output.
    PROTOCOL_NO_MATCH = 0
    PROTOCOL_HTTP1_0 = 1

    def __init__(self, paths):
        super().__init__({
            "input": In(stream.Signature(8)),
            "reset": In(1),
            "done": Out(1),
            "method": Out(3),
            # One bit per path, plus bit 0 for "no match".
            "path": Out(len(paths)+1),
            "protocol": Out(2),
        })
        self._paths = paths

    def elaborate(self, _platform):
        m = Module()

        # Reset inputs of every matcher; pulsed together from the "reset" state.
        resets = []

        # TODO: #4 - Evaluate using StringMatch instead of StringContainsMatch. Here and below.
        #   Check https://en.wikipedia.org/wiki/HTTP_request_smuggling cases.
        method_stream = stream.Signature(8).create()
        get_matcher = m.submodules.get_matcher = StringContainsMatch("GET")
        resets.append(get_matcher.reset)
        m.d.comb += self.method[self.METHOD_GET].eq(get_matcher.accepted)

        post_matcher = m.submodules.post_matcher = StringContainsMatch("POST")
        resets.append(post_matcher.reset)
        m.d.comb += self.method[self.METHOD_POST].eq(post_matcher.accepted)

        any_method_match = stream_utils.tree_or(m, [get_matcher.accepted, post_matcher.accepted])
        m.d.comb += self.method[0].eq(~any_method_match)

        stream_utils.fanout_stream(m, method_stream, [get_matcher.input, post_matcher.input])

        path_stream = stream.Signature(8).create()
        path_streams = []
        path_match = []
        for i, path in enumerate(self._paths):
            matcher = m.submodules[f"path_matchers_{i}"] = StringContainsMatch(path)
            path_streams.append(matcher.input)
            resets.append(matcher.reset)
            # Bit 0 is reserved for "no match", so path i reports on bit i+1.
            m.d.comb += self.path[i+1].eq(matcher.accepted)
            path_match.append(matcher.accepted)
        stream_utils.fanout_stream(m, path_stream, path_streams)
        any_path_match = stream_utils.tree_or(m, path_match)
        m.d.comb += self.path[0].eq(~any_path_match)

        # TODO: #4 - If we want to get out of the stone age, should match more than HTTP/1.0
        #       That being said, silicon is kind of like a stone, right?
        protocol_match = m.submodules.protocol_match = StringContainsMatch("HTTP/1.0")
        # The protocol matcher must be cleared between requests like every
        # other matcher, otherwise a stale result leaks into the next parse.
        resets.append(protocol_match.reset)
        m.d.comb += self.protocol[self.PROTOCOL_NO_MATCH].eq(~protocol_match.accepted)
        m.d.comb += self.protocol[self.PROTOCOL_HTTP1_0].eq(protocol_match.accepted)

        def restart_on_reset():
            # Emitted last in every state so that it takes priority over the
            # state's own assignments to m.next and self.done.
            with m.If(self.reset):
                m.d.sync += self.done.eq(0)
                m.next = "reset"

        with m.FSM():
            with m.State("reset"):
                for r in resets:
                    m.d.sync += r.eq(1)
                m.d.sync += self.done.eq(0)
                m.next = "match_method"
                # Hold here for as long as the external reset is asserted.
                restart_on_reset()
            with m.State("match_method"):
                m.next = "match_method"
                for r in resets:
                    m.d.sync += r.eq(0)
                m.d.comb += [
                    method_stream.valid.eq(self.input.valid),
                    method_stream.payload.eq(self.input.payload),
                    self.input.ready.eq(method_stream.ready),
                ]
                # A space terminates the method field; always consume it.
                with m.If(self.input.valid & (self.input.payload == ord(' '))):
                    m.d.comb += self.input.ready.eq(1)
                    m.next = "match_path"
                restart_on_reset()
            with m.State("match_path"):
                m.next = "match_path"
                m.d.comb += [
                    path_stream.valid.eq(self.input.valid),
                    path_stream.payload.eq(self.input.payload),
                    self.input.ready.eq(path_stream.ready),
                ]
                # A space terminates the path field; always consume it.
                with m.If(self.input.valid & (self.input.payload == ord(' '))):
                    m.d.comb += self.input.ready.eq(1)
                    m.next = "match_protocol"
                restart_on_reset()
            with m.State("match_protocol"):
                m.next = "match_protocol"
                # connect results in warning about combinatorial signals
                m.d.comb += [
                    protocol_match.input.valid.eq(self.input.valid),
                    protocol_match.input.payload.eq(self.input.payload),
                    self.input.ready.eq(protocol_match.input.ready),
                ]
                # A carriage return terminates the protocol field.
                with m.If(self.input.valid & (self.input.payload == ord('\r'))):
                    m.d.comb += self.input.ready.eq(1)
                    m.next = "match_end"
                restart_on_reset()
            with m.State("match_end"):
                # Everything up to the line feed is consumed and discarded.
                m.d.comb += self.input.ready.eq(1)
                m.next = "match_end"
                # TODO: #4 - Should error if this isn't \n, and setup to return a
                #       HTTP 400 Bad Request error.
                with m.If(self.input.valid & (self.input.payload == ord('\n'))):
                    m.next = "done"
                restart_on_reset()
            with m.State("done"):
                m.next = "done"
                m.d.sync += self.done.eq(1)
                # Without this the parser could never handle a second request.
                restart_on_reset()

        return m
146+

http_server/parse_start_test.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import sys
2+
from amaranth.sim import Simulator
3+
4+
from parse_start import ParseStart
5+
from stream_fixtures import StreamSender
6+
7+
paths = ["/led", "/count"]
8+
PATH_NO_MATCH = 0
9+
PATH_LED = 1
10+
PATH_COUNT = 2
11+
12+
def run_test(send_line, check):
    """Simulate ParseStart on one start-line and verify its outputs.

    Parameters
    ----------
    send_line: str
        Characters to stream into the parser, one byte per character.
    check: callable(ctx, dut)
        Called once the sender reports that the whole line has been sent;
        expected to assert on the parser outputs.

    Raises
    ------
    AssertionError
        If ``check`` fails, or if the simulation ends before ``check`` ran.
    """
    # Use the module-level path list so the PATH_* indices stay in sync with it.
    dut = ParseStart(paths)
    sender = StreamSender(stream=dut.input)
    sim = Simulator(dut)
    sim.add_clock(1e-6)
    checked = False

    async def driver(ctx):
        nonlocal checked
        ctx.set(dut.reset, 1)
        # The tick must be awaited; otherwise no time passes and the reset
        # pulse is never seen by the design.
        await ctx.tick()
        ctx.set(dut.reset, 0)

        while not sender.done:
            await ctx.tick()
        check(ctx, dut)
        checked = True

    sim.add_testbench(driver)
    sim.add_process(sender.send_passive(map(ord, send_line)))

    with sim.write_vcd(sys.stdout):
        sim.run_until(0.0001)
    # Guards against the simulation timing out before the check executed.
    assert checked
36+
37+
def test_good_parse():
    """A POST to /led over HTTP/1.0 sets exactly the POST, /led and HTTP/1.0 bits."""
    def check(ctx, dut):
        expected = (
            (dut.method[dut.METHOD_NO_MATCH], 0),
            (dut.method[dut.METHOD_GET], 0),
            (dut.method[dut.METHOD_POST], 1),
            (dut.path[PATH_NO_MATCH], 0),
            (dut.path[PATH_LED], 1),
            (dut.path[PATH_COUNT], 0),
            (dut.protocol[dut.PROTOCOL_NO_MATCH], 0),
            (dut.protocol[dut.PROTOCOL_HTTP1_0], 1),
        )
        for bit, want in expected:
            assert ctx.get(bit) == want

    run_test("POST /led HTTP/1.0\r\n", check)
49+
50+
51+
def test_no_method_match():
    """An unknown method sets the method no-match bit; path and protocol still match."""
    def check(ctx, dut):
        expected = (
            (dut.method[dut.METHOD_NO_MATCH], 1),
            (dut.method[dut.METHOD_GET], 0),
            (dut.method[dut.METHOD_POST], 0),
            (dut.path[PATH_NO_MATCH], 0),
            (dut.path[PATH_LED], 1),
            (dut.path[PATH_COUNT], 0),
            (dut.protocol[dut.PROTOCOL_NO_MATCH], 0),
            (dut.protocol[dut.PROTOCOL_HTTP1_0], 1),
        )
        for bit, want in expected:
            assert ctx.get(bit) == want

    run_test("REQUEST /led HTTP/1.0\r\n", check)
63+
64+
def test_no_protocol_match():
    """An unsupported protocol sets the protocol no-match bit; method and path still match."""
    def check(ctx, dut):
        expected = (
            (dut.method[dut.METHOD_NO_MATCH], 0),
            (dut.method[dut.METHOD_GET], 1),
            (dut.method[dut.METHOD_POST], 0),
            (dut.path[PATH_NO_MATCH], 0),
            (dut.path[PATH_LED], 0),
            (dut.path[PATH_COUNT], 1),
            (dut.protocol[dut.PROTOCOL_NO_MATCH], 1),
            (dut.protocol[dut.PROTOCOL_HTTP1_0], 0),
        )
        for bit, want in expected:
            assert ctx.get(bit) == want

    run_test("GET /count HTTP/3.0\r\n", check)
76+
77+
78+
def test_no_path_match():
    """An unknown path sets the path no-match bit; method and protocol still match."""
    def check(ctx, dut):
        expected = (
            (dut.method[dut.METHOD_NO_MATCH], 0),
            (dut.method[dut.METHOD_GET], 1),
            (dut.method[dut.METHOD_POST], 0),
            (dut.path[PATH_NO_MATCH], 1),
            (dut.path[PATH_LED], 0),
            (dut.path[PATH_COUNT], 0),
            (dut.protocol[dut.PROTOCOL_NO_MATCH], 0),
            (dut.protocol[dut.PROTOCOL_HTTP1_0], 1),
        )
        for bit, want in expected:
            assert ctx.get(bit) == want

    run_test("GET /index.html HTTP/1.0\r\n", check)
90+
91+
def test_double_start_line():
    """A second start-line after a bare carriage return must not change the parse result."""
    def check(ctx, dut):
        expected = (
            (dut.method[dut.METHOD_NO_MATCH], 0),
            (dut.method[dut.METHOD_GET], 1),
            (dut.method[dut.METHOD_POST], 0),
            (dut.path[PATH_NO_MATCH], 1),
            (dut.path[PATH_LED], 0),
            (dut.path[PATH_COUNT], 0),
            (dut.protocol[dut.PROTOCOL_NO_MATCH], 0),
            (dut.protocol[dut.PROTOCOL_HTTP1_0], 1),
        )
        for bit, want in expected:
            assert ctx.get(bit) == want

    # TODO: #4 - This should raise an error so the HTTP responder can return a 400 Bad Request.
    run_test("GET /index.html HTTP/1.0\rPOST /help HTTP/1.1\r\n", check)
104+
105+
if __name__ == "__main__":
    # Run every test case in sequence when executed as a script.
    for case in (
        test_good_parse,
        test_no_method_match,
        test_no_path_match,
        test_no_protocol_match,
        test_double_start_line,
    ):
        case()

http_server/simple_led_http.py

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from amaranth.lib.wiring import In, Out, Component, connect
33

44
from printer import Printer
5+
from parse_start import ParseStart
56
from stream_mux import StreamMux
67
from stream_demux import StreamDemux
78
from string_match import StringMatch
@@ -43,14 +44,11 @@ def elaborate(self, _platform):
4344
parser_demux = m.submodules.parser_demux = StreamDemux(mux_width=4, stream_width=8)
4445
connect(m, self.session.inbound.data, parser_demux.input)
4546

46-
# Match the header for an HTTP/1.0 request to the LED path.
47-
# TODO: If we want to match more than one path, could probably have some common
48-
# matching for the method and protocol. Also, if we want to get out of the
49-
# stone age, this could be HTTP/1.1.
50-
led_start = "POST /led HTTP/1.0\r\n"
51-
led_start_matcher = m.submodules.led_start_matcher = StringMatch(led_start)
52-
HTTP_PARSER_LED = 0
53-
connect(m, led_start_matcher.input, parser_demux.outs[HTTP_PARSER_LED])
47+
# TODO: #4 - Add packet count and RFC2324 endpoints
48+
MATCHED_LED_PATH = 1 # start_matcher path match is in the order the paths are connected.
49+
start_matcher = m.submodules.start_matcher = ParseStart(["/led"])
50+
HTTP_PARSER_START = 0
51+
connect(m, start_matcher.input, parser_demux.outs[HTTP_PARSER_START])
5452

5553
HTTP_PARSER_HEADERS = 1
5654
skip_headers = m.submodules.end_of_header_matcher = StringContainsMatch("\r\n\r\n")
@@ -70,7 +68,7 @@ def elaborate(self, _platform):
7068
m.d.comb += parser_demux.outs[HTTP_PARSER_SINK].ready.eq(1)
7169

7270
## Responders
73-
response_mux = m.submodules.response_mux = StreamMux(mux_width=2, stream_width=8)
71+
response_mux = m.submodules.response_mux = StreamMux(mux_width=3, stream_width=8)
7472
connect(m, response_mux.out, self.session.outbound.data)
7573

7674
ok_response = "\r\n".join(
@@ -107,46 +105,67 @@ def elaborate(self, _platform):
107105
not_found_printer.en.eq(1),
108106
]
109107

108+
not_allowed_response = "\r\n".join(
109+
["HTTP/1.0 405 Method Not Allowed",
110+
"Host: Fomu",
111+
"Content-Type: text/plain; charset=utf-8",
112+
"",
113+
"",
114+
'🛑']) + "\r\n"
115+
not_allowed_response = not_allowed_response.encode("utf-8")
116+
not_allowed_printer = m.submodules.not_allowed_printer = Printer(not_allowed_response)
117+
RESPONSE_405 = 2
118+
connect(m, not_allowed_printer.output, response_mux.input[RESPONSE_405])
119+
send_405 = [
120+
response_mux.select.eq(RESPONSE_405),
121+
parser_demux.select.eq(HTTP_PARSER_SINK),
122+
not_allowed_printer.en.eq(1),
123+
]
124+
125+
110126
with m.FSM():
111127
with m.State("reset"):
112128
m.d.comb += [
113-
led_start_matcher.reset.eq(1),
129+
start_matcher.reset.eq(1),
114130
skip_headers.reset.eq(1),
115131
led_body_handler.reset.eq(1),
116132
]
117133
m.next = "idle"
118134
with m.State("idle"):
119-
m.d.comb += led_start_matcher.reset.eq(0)
120-
m.d.sync += parser_demux.select.eq(HTTP_PARSER_LED)
135+
m.d.comb += start_matcher.reset.eq(0)
136+
m.d.sync += parser_demux.select.eq(HTTP_PARSER_START)
121137
m.d.sync += response_mux.select.eq(RESPONSE_OK)
122138
m.next = "idle"
123139
with m.If(self.session.inbound.active):
124140
m.next = "parsing_start"
125141
m.d.sync += self.session.outbound.active.eq(1)
126142
with m.State("parsing_start"):
127143
m.next = "parsing_start"
128-
# Input finished before header matched, or header failed to match
129-
with m.If(~self.session.inbound.active | led_start_matcher.rejected):
130-
m.next = "writing"
131-
m.d.sync += send_404
132144
# start line matched successfully
133-
with m.Elif(led_start_matcher.accepted):
134-
145+
with m.If(start_matcher.done):
135146
m.next = "parsing_header"
136147
m.d.sync += parser_demux.select.eq(HTTP_PARSER_HEADERS)
137148
with m.State("parsing_header"):
138149
m.next = "parsing_header"
139150
with m.If(skip_headers.accepted):
140-
m.next = "parsing_body"
141-
# TODO: #4 - Should pick the body parser based on the method+path from start
142-
m.d.sync += parser_demux.select.eq(HTTP_PARSER_LED_BODY)
151+
with m.If(start_matcher.method[start_matcher.METHOD_POST] &
152+
start_matcher.path[MATCHED_LED_PATH]):
153+
m.next = "parsing_led_body"
154+
m.d.sync += parser_demux.select.eq(HTTP_PARSER_LED_BODY)
155+
with m.Elif(start_matcher.method[start_matcher.METHOD_GET] &
156+
start_matcher.path[MATCHED_LED_PATH]):
157+
m.next = "writing"
158+
m.d.sync += send_405
159+
with m.Else():
160+
m.next = "writing"
161+
m.d.sync += send_404
143162
with m.Elif(~self.session.inbound.active):
144163
m.next = "writing"
145164
# TODO: #4 - Should send a different error code besides 404 if the
146165
# headers fail to parse before end-of-session.
147166
m.d.sync += send_404
148-
with m.State("parsing_body"): # TODO: #4 - Make the specific body depend on the path
149-
m.next = "parsing_body"
167+
with m.State("parsing_led_body"): # TODO: #4 - Make body parsing state more generic.
168+
m.next = "parsing_led_body"
150169
with m.If(led_body_handler.accepted):
151170
m.next = "writing"
152171
m.d.sync += send_ok
@@ -160,10 +179,12 @@ def elaborate(self, _platform):
160179
m.d.sync += [
161180
ok_printer.en.eq(0),
162181
not_found_printer.en.eq(0),
182+
not_allowed_printer.en.eq(0),
163183
self.session.outbound.active.eq(1),
164184
]
165185
with m.If( ((response_mux.select == RESPONSE_OK) & ok_printer.done)
166-
| ((response_mux.select == RESPONSE_404) & not_found_printer.done)):
186+
| ((response_mux.select == RESPONSE_404) & not_found_printer.done)
187+
| ((response_mux.select == RESPONSE_405) & not_allowed_printer.done)):
167188
m.d.sync += self.session.outbound.active.eq(0)
168189
# Can finish writing before all the input is collected,
169190
# since a bad request might trigger an early 404. Wait

0 commit comments

Comments
 (0)