feat: added protected dial support, removed multi-IO support for simplicity

This commit is contained in:
Toby
2024-04-06 14:42:45 -07:00
parent ae34b4856a
commit 9c0893c512
8 changed files with 88 additions and 72 deletions

View File

@@ -15,7 +15,7 @@ var _ Engine = (*engine)(nil)
type engine struct {
logger Logger
ioList []io.PacketIO
io io.PacketIO
workers []*worker
}
@@ -42,7 +42,7 @@ func NewEngine(config Config) (Engine, error) {
}
return &engine{
logger: config.Logger,
ioList: config.IOs,
io: config.IO,
workers: workers,
}, nil
}
@@ -58,27 +58,24 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
func (e *engine) Run(ctx context.Context) error {
ioCtx, ioCancel := context.WithCancel(ctx)
defer ioCancel() // Stop workers & IOs
defer ioCancel() // Stop workers & IO
// Start workers
for _, w := range e.workers {
go w.Run(ioCtx)
}
// Register callbacks
errChan := make(chan error, len(e.ioList))
for _, i := range e.ioList {
ioEntry := i // Make sure dispatch() uses the correct ioEntry
err := ioEntry.Register(ioCtx, func(p io.Packet, err error) bool {
if err != nil {
errChan <- err
return false
}
return e.dispatch(ioEntry, p)
})
// Register IO callback
errChan := make(chan error, 1)
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
if err != nil {
return err
errChan <- err
return false
}
return e.dispatch(p)
})
if err != nil {
return err
}
// Block until IO errors or context is cancelled
@@ -91,8 +88,7 @@ func (e *engine) Run(ctx context.Context) error {
}
// dispatch dispatches a packet to a worker.
// This must be safe for concurrent use, as it may be called from multiple IOs.
func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
func (e *engine) dispatch(p io.Packet) bool {
data := p.Data()
ipVersion := data[0] >> 4
var layerType gopacket.LayerType
@@ -102,7 +98,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
layerType = layers.LayerTypeIPv6
} else {
// Unsupported network layer
_ = ioEntry.SetVerdict(p, io.VerdictAcceptStream, nil)
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
return true
}
// Load balance by stream ID
@@ -112,7 +108,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
StreamID: p.StreamID(),
Packet: packet,
SetVerdict: func(v io.Verdict, b []byte) error {
return ioEntry.SetVerdict(p, v, b)
return e.io.SetVerdict(p, v, b)
},
})
return true