diff --git a/cmd/check.go b/cmd/check.go index e7296231c4..4414ab8999 100644 --- a/cmd/check.go +++ b/cmd/check.go @@ -242,17 +242,19 @@ func initialiseCheck() *checkInitData { initData.result.AddWarnings(refreshResult.Warnings...) // setup the session data - prepared statements and introspection tables - err = workspace.EnsureServiceState(context.Background(), initData.workspace.GetResourceMaps(), initData.client) + // create session data source - for check command, we create prepared statements for ALL queries + sessionDataSource := workspace.NewSessionDataSource(initData.workspace.GetResourceMaps()) + err = workspace.EnsureSessionData(context.Background(), sessionDataSource, initData.client) if err != nil { initData.result.Error = err return initData } - // register EnsureServiceState as a callback on the client. + // register EnsureSessionData as a callback on the client. // if the underlying SQL client has certain errors (for example context expiry) it will reset the session // so our client object calls this callback to restore the session data - initData.client.SetEnsureSessionStateFunc(func(ctx context.Context, client db_common.Client) error { - return workspace.EnsureServiceState(ctx, initData.workspace.GetResourceMaps(), client) + initData.client.SetEnsureSessionDataFunc(func(ctx context.Context, client db_common.Client) error { + return workspace.EnsureSessionData(ctx, sessionDataSource, client) }) return initData diff --git a/cmd/query.go b/cmd/query.go index b15d440f61..ffad8791da 100644 --- a/cmd/query.go +++ b/cmd/query.go @@ -257,16 +257,19 @@ func getQueryInitDataAsync(ctx context.Context, w *workspace.Workspace, initData initData.Result.AddWarnings(res.Warnings...) // setup the session data - prepared statements and introspection tables - err = workspace.EnsureServiceState(context.Background(), preparedStatementProviders, initData.Client) + // create session data source + sessionDataSource := workspace.NewSessionDataSource(w.GetResourceMaps(), preparedStatementProviders) + + err = workspace.EnsureSessionData(context.Background(), sessionDataSource, initData.Client) if err != nil { initData.Result.Error = err return } - // register EnsureServiceState as a callback on the client. + // register EnsureSessionData as a callback on the client. // if the underlying SQL client has certain errors (for example context expiry) it will reset the session // so our client object calls this callback to restore the session data - initData.Client.SetEnsureSessionStateFunc(func(ctx context.Context, client db_common.Client) error { - return workspace.EnsureServiceState(ctx, preparedStatementProviders, client) + initData.Client.SetEnsureSessionDataFunc(func(ctx context.Context, client db_common.Client) error { + return workspace.EnsureSessionData(ctx, sessionDataSource, client) }) }() } diff --git a/db/db_client/db_client.go b/db/db_client/db_client.go index 476d72feb6..b8dabdcafc 100644 --- a/db/db_client/db_client.go +++ b/db/db_client/db_client.go @@ -52,7 +52,7 @@ func establishConnection(connStr string) (*sql.DB, error) { return nil, fmt.Errorf("could not establish connection") } -func (c *DbClient) SetEnsureSessionStateFunc(f db_common.EnsureSessionStateCallback) { +func (c *DbClient) SetEnsureSessionDataFunc(f db_common.EnsureSessionStateCallback) { c.ensureSessionFunc = f } diff --git a/db/db_common/client.go b/db/db_common/client.go index ef137222bc..3fcb88a90f 100644 --- a/db/db_common/client.go +++ b/db/db_common/client.go @@ -28,7 +28,7 @@ type Client interface { CacheOff() error CacheClear() error - SetEnsureSessionStateFunc(EnsureSessionStateCallback) + SetEnsureSessionDataFunc(EnsureSessionStateCallback) // remote client will have empty implementation diff --git a/db/db_local/local_db_client.go b/db/db_local/local_db_client.go index 8cce558363..b092d7570a 100644 --- a/db/db_local/local_db_client.go +++ b/db/db_local/local_db_client.go @@ -75,8 +75,8 @@ func (c *LocalDbClient) Close() error { } // EnsureSessionState implements Client -func (c *LocalDbClient) SetEnsureSessionStateFunc(f db_common.EnsureSessionStateCallback) { - c.client.SetEnsureSessionStateFunc(f) +func (c *LocalDbClient) SetEnsureSessionDataFunc(f db_common.EnsureSessionStateCallback) { + c.client.SetEnsureSessionDataFunc(f) } // SchemaMetadata implements Client diff --git a/workspace/service_state.go b/workspace/session_data.go similarity index 54% rename from workspace/service_state.go rename to workspace/session_data.go index c08778b635..bdd293374c 100644 --- a/workspace/service_state.go +++ b/workspace/session_data.go @@ -5,30 +5,30 @@ import ( "github.com/turbot/steampipe/db/db_common" "github.com/turbot/steampipe/query/queryresult" - "github.com/turbot/steampipe/steampipeconfig/modconfig" "github.com/turbot/steampipe/utils" ) -// EnsureServiceState queries the database and makes sure that workspace temp tables -// and prepared statements are available in the database -func EnsureServiceState(ctx context.Context, preparedStatementProviders *modconfig.WorkspaceResourceMaps, client db_common.Client) error { - utils.LogTime("workspace.EnsureServiceState start") - defer utils.LogTime("workspace.EnsureServiceState end") +// EnsureSessionData determines whether session scoped data (introspection tables and prepared statements) +// exists for this session, and if not, creates it +func EnsureSessionData(ctx context.Context, source *SessionDataSource, client db_common.Client) error { + utils.LogTime("workspace.EnsureSessionData start") + defer utils.LogTime("workspace.EnsureSessionData end") // check for introspection tables result, err := client.ExecuteSync(ctx, "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema LIKE 'pg_temp%' AND table_name='steampipe_mod' ", true) if err != nil { return err } - // since we are quering with a 'select count...', we will always have exactly one cell with the value count := result.Rows[0].(*queryresult.RowResult).Data[0].(int64) + + // if the steampipe_mod table is missing, assume we have no session data - go ahead and create it if count == 0 { - err = db_common.CreatePreparedStatements(context.Background(), preparedStatementProviders, client) + err = db_common.CreatePreparedStatements(context.Background(), source.preparedStatementSource, client) if err != nil { return err } - if err = db_common.CreateIntrospectionTables(ctx, preparedStatementProviders, client); err != nil { + if err = db_common.CreateIntrospectionTables(ctx, source.introspectionTableSource, client); err != nil { return err } } diff --git a/workspace/session_data_source.go b/workspace/session_data_source.go new file mode 100644 index 0000000000..908cfd4c5b --- /dev/null +++ b/workspace/session_data_source.go @@ -0,0 +1,32 @@ +package workspace + +import "github.com/turbot/steampipe/steampipeconfig/modconfig" + +type SessionDataSource struct { + preparedStatementSource, introspectionTableSource *modconfig.WorkspaceResourceMaps +} + +// NewSessionDataSource creates a new SessionDataSource object +// if a single parameter is poassed, this map is used for both prepared statements and introspection tables +// if a second parameter is passed, it will be a minimal set of resources for which we need to create prepared statements +// this will be populated for batch mode querying +func NewSessionDataSource(items ...*modconfig.WorkspaceResourceMaps) *SessionDataSource { + if len(items) == 0 { + panic("NewSessionStateSource called with no parameters") + } + if len(items) > 2 { + panic("NewSessionStateSource called with more than 2 parameters") + } + // default to initialising introspectionTableSource AND preparedStatementSource from the first param, + // which is expected to be the full map of workspace resources + res := &SessionDataSource{ + introspectionTableSource: items[0], + preparedStatementSource: items[0], + } + // is the preparedStatementSource explicitly provided? + if len(items) == 2 { + res.preparedStatementSource = items[1] + } + return res + +}