diff --git a/connector.go b/connector.go index 3cef7963f..29b6dc1b7 100644 --- a/connector.go +++ b/connector.go @@ -80,7 +80,16 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { dialsLock.RLock() dial, ok := dials[mc.cfg.Net] dialsLock.RUnlock() - if ok { + + if c.cfg.DialFunc != nil { + dctx := ctx + if mc.cfg.Timeout > 0 { + var cancel context.CancelFunc + dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout) + defer cancel() + } + mc.netConn, err = c.cfg.DialFunc(dctx, mc.cfg.Net, mc.cfg.Addr) + } else if ok { dctx := ctx if mc.cfg.Timeout > 0 { var cancel context.CancelFunc diff --git a/dsn.go b/dsn.go index ef0608636..3e986fca9 100644 --- a/dsn.go +++ b/dsn.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "context" "crypto/rsa" "crypto/tls" "errors" @@ -65,6 +66,15 @@ type Config struct { MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections + + // DialFunc specifies the dial function for creating connections. + // If DialFunc is nil, the connector will attempt to find a dial function from the global registry (registered with RegisterDialContext). + // If no dial function is found even after checking the global registry, the net.Dialer will be used as a fallback. + // + // The dial function is responsible for establishing connections. By providing a custom dial function, + // users can flexibly control the process of connection establishment. Custom dial functions can be registered in the global registry + // to tailor connection behavior according to specific requirements. + DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) } // NewConfig creates a new Config and sets default values.