feat: added protected dial support, removed multi-IO support for simplicity
This commit is contained in:
@@ -15,7 +15,7 @@ var _ Engine = (*engine)(nil)
|
||||
|
||||
type engine struct {
|
||||
logger Logger
|
||||
ioList []io.PacketIO
|
||||
io io.PacketIO
|
||||
workers []*worker
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func NewEngine(config Config) (Engine, error) {
|
||||
}
|
||||
return &engine{
|
||||
logger: config.Logger,
|
||||
ioList: config.IOs,
|
||||
io: config.IO,
|
||||
workers: workers,
|
||||
}, nil
|
||||
}
|
||||
@@ -58,27 +58,24 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
||||
|
||||
func (e *engine) Run(ctx context.Context) error {
|
||||
ioCtx, ioCancel := context.WithCancel(ctx)
|
||||
defer ioCancel() // Stop workers & IOs
|
||||
defer ioCancel() // Stop workers & IO
|
||||
|
||||
// Start workers
|
||||
for _, w := range e.workers {
|
||||
go w.Run(ioCtx)
|
||||
}
|
||||
|
||||
// Register callbacks
|
||||
errChan := make(chan error, len(e.ioList))
|
||||
for _, i := range e.ioList {
|
||||
ioEntry := i // Make sure dispatch() uses the correct ioEntry
|
||||
err := ioEntry.Register(ioCtx, func(p io.Packet, err error) bool {
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return false
|
||||
}
|
||||
return e.dispatch(ioEntry, p)
|
||||
})
|
||||
// Register IO callback
|
||||
errChan := make(chan error, 1)
|
||||
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
||||
if err != nil {
|
||||
return err
|
||||
errChan <- err
|
||||
return false
|
||||
}
|
||||
return e.dispatch(p)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Block until IO errors or context is cancelled
|
||||
@@ -91,8 +88,7 @@ func (e *engine) Run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// dispatch dispatches a packet to a worker.
|
||||
// This must be safe for concurrent use, as it may be called from multiple IOs.
|
||||
func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
|
||||
func (e *engine) dispatch(p io.Packet) bool {
|
||||
data := p.Data()
|
||||
ipVersion := data[0] >> 4
|
||||
var layerType gopacket.LayerType
|
||||
@@ -102,7 +98,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
|
||||
layerType = layers.LayerTypeIPv6
|
||||
} else {
|
||||
// Unsupported network layer
|
||||
_ = ioEntry.SetVerdict(p, io.VerdictAcceptStream, nil)
|
||||
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
|
||||
return true
|
||||
}
|
||||
// Load balance by stream ID
|
||||
@@ -112,7 +108,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
|
||||
StreamID: p.StreamID(),
|
||||
Packet: packet,
|
||||
SetVerdict: func(v io.Verdict, b []byte) error {
|
||||
return ioEntry.SetVerdict(p, v, b)
|
||||
return e.io.SetVerdict(p, v, b)
|
||||
},
|
||||
})
|
||||
return true
|
||||
|
||||
Reference in New Issue
Block a user