Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Copy with Remote Object Stores in datafusion-cli #9064

Merged
Merged 17 commits
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 90 additions & 16 deletions datafusion-cli/src/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! Execution functions

use std::collections::HashMap;
use std::io::prelude::*;
use std::io::BufReader;
use std::time::Instant;
Expand All @@ -42,6 +43,8 @@ use datafusion::physical_plan::{collect, execute_stream};
use datafusion::prelude::SessionContext;
use datafusion::sql::{parser::DFParser, sqlparser::dialect::dialect_from_str};

use datafusion::logical_expr::dml::CopyTo;
use datafusion::sql::parser::Statement;
use object_store::http::HttpBuilder;
use object_store::ObjectStore;
use rustyline::error::ReadlineError;
Expand Down Expand Up @@ -221,7 +224,7 @@ async fn exec_and_print(

let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
for statement in statements {
let mut plan = ctx.state().statement_to_plan(statement).await?;
let plan = create_plan(ctx, statement).await?;

// For plans like `Explain` ignore `MaxRows` option and always display all rows
let should_ignore_maxrows = matches!(
Expand All @@ -231,13 +234,6 @@ async fn exec_and_print(
| LogicalPlan::Analyze(_)
);

// Note that cmd is a mutable reference so that create_external_table function can remove all
// datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion
// will raise Configuration errors.
if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan {
create_external_table(ctx, cmd).await?;
}

let df = ctx.execute_logical_plan(plan).await?;
let physical_plan = df.create_physical_plan().await?;

Expand All @@ -260,6 +256,36 @@ async fn exec_and_print(
Ok(())
}

async fn create_plan(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes sense to me. Thank you @manoj-inukolunu

ctx: &mut SessionContext,
statement: Statement,
) -> Result<LogicalPlan, DataFusionError> {
let mut plan = ctx.state().statement_to_plan(statement).await?;

// Note that cmd is a mutable reference so that create_external_table function can remove all
// datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion
// will raise Configuration errors.
if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan {
create_external_table(ctx, cmd).await?;
}

if let LogicalPlan::Copy(copy_to) = &mut plan {
register_object_store(ctx, copy_to).await?;
}
Ok(plan)
}

/// Registers the object store backing `copy_to.output_url` with the
/// session's runtime environment so a subsequent COPY can write to it.
async fn register_object_store(
    ctx: &SessionContext,
    copy_to: &mut CopyTo,
) -> Result<(), DataFusionError> {
    let table_url = ListingTableUrl::parse(copy_to.output_url.as_str())?;
    // COPY statements carry no per-statement store options, so hand the
    // builder an empty (mutable) option map.
    let mut options = HashMap::new();
    let object_store =
        get_object_store(ctx, &mut options, table_url.scheme(), table_url.as_ref())
            .await?;
    ctx.runtime_env()
        .register_object_store(table_url.as_ref(), object_store);
    Ok(())
}

async fn create_external_table(
ctx: &SessionContext,
cmd: &mut CreateExternalTable,
Expand All @@ -269,17 +295,30 @@ async fn create_external_table(
let url: &Url = table_path.as_ref();

// registering the cloud object store dynamically using cmd.options
let store = get_object_store(ctx, &mut cmd.options, scheme, url).await?;

ctx.runtime_env().register_object_store(url, store);

Ok(())
}

async fn get_object_store(
ctx: &SessionContext,
options: &mut HashMap<String, String>,
scheme: &str,
url: &Url,
) -> Result<Arc<dyn ObjectStore>, DataFusionError> {
let store = match scheme {
"s3" => {
let builder = get_s3_object_store_builder(url, cmd).await?;
let builder = get_s3_object_store_builder(url, options).await?;
Arc::new(builder.build()?) as Arc<dyn ObjectStore>
}
"oss" => {
let builder = get_oss_object_store_builder(url, cmd)?;
let builder = get_oss_object_store_builder(url, options)?;
Arc::new(builder.build()?) as Arc<dyn ObjectStore>
}
"gs" | "gcs" => {
let builder = get_gcs_object_store_builder(url, cmd)?;
let builder = get_gcs_object_store_builder(url, options)?;
Arc::new(builder.build()?) as Arc<dyn ObjectStore>
}
"http" | "https" => Arc::new(
Expand All @@ -297,10 +336,7 @@ async fn create_external_table(
})?
}
};

ctx.runtime_env().register_object_store(url, store);

Ok(())
Ok(store)
}

#[cfg(test)]
Expand All @@ -309,7 +345,9 @@ mod tests {

use super::*;
use datafusion::common::plan_err;
use datafusion_common::{file_options::StatementOptions, FileTypeWriterOptions};
use datafusion_common::{
file_options::StatementOptions, FileType, FileTypeWriterOptions,
};

async fn create_external_table_test(location: &str, sql: &str) -> Result<()> {
let ctx = SessionContext::new();
Expand Down Expand Up @@ -352,6 +390,42 @@ mod tests {

Ok(())
}
/// Planning `COPY ... TO` against each remote scheme must succeed and must
/// register the destination's object store with the runtime environment.
#[tokio::test]
async fn copy_to_external_object_store_test() -> Result<()> {
    let destinations = [
        "s3://bucket/path/file.parquet",
        "oss://bucket/path/file.parquet",
        "gcs://bucket/path/file.parquet",
    ];
    let mut ctx = SessionContext::new();
    let task_ctx = ctx.task_ctx();
    let dialect = &task_ctx.session_config().options().sql_parser.dialect;
    let dialect = dialect_from_str(dialect).ok_or_else(|| {
        plan_datafusion_err!(
            "Unsupported SQL dialect: {dialect}. Available dialects: \
         Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
         MsSQL, ClickHouse, BigQuery, Ansi."
        )
    })?;
    for dest in destinations {
        let sql = format!("copy (values (1,2)) to '{}';", dest);
        for statement in DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())? {
            // Planning should not fail for any of the remote schemes.
            match create_plan(&mut ctx, statement).await? {
                LogicalPlan::Copy(copy_to) => {
                    assert_eq!(copy_to.output_url, dest);
                    assert_eq!(copy_to.file_format, FileType::PARQUET);
                    // The destination store must now be resolvable from the
                    // registry.
                    let url = Url::parse(&copy_to.output_url).unwrap();
                    ctx.runtime_env().object_store_registry.get_store(&url)?;
                }
                _ => return plan_err!("LogicalPlan is not a CopyTo"),
            }
        }
    }
    Ok(())
}

#[tokio::test]
async fn create_object_store_table_s3() -> Result<()> {
Expand Down
45 changes: 22 additions & 23 deletions datafusion-cli/src/object_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,34 @@

use async_trait::async_trait;
use aws_credential_types::provider::ProvideCredentials;
use datafusion::{
error::{DataFusionError, Result},
logical_expr::CreateExternalTable,
};
use datafusion::error::{DataFusionError, Result};
use object_store::aws::AwsCredential;
use object_store::{
aws::AmazonS3Builder, gcp::GoogleCloudStorageBuilder, CredentialProvider,
};
use std::collections::HashMap;
use std::sync::Arc;
use url::Url;

pub async fn get_s3_object_store_builder(
url: &Url,
cmd: &mut CreateExternalTable,
options: &mut HashMap<String, String>,
) -> Result<AmazonS3Builder> {
let bucket_name = get_bucket_name(url)?;
let mut builder = AmazonS3Builder::from_env().with_bucket_name(bucket_name);

if let (Some(access_key_id), Some(secret_access_key)) = (
// These options are datafusion-cli specific and must be removed before passing through to datafusion.
// Otherwise, a Configuration error will be raised.
cmd.options.remove("access_key_id"),
cmd.options.remove("secret_access_key"),
options.remove("access_key_id"),
options.remove("secret_access_key"),
) {
println!("removing secret access key!");
builder = builder
.with_access_key_id(access_key_id)
.with_secret_access_key(secret_access_key);

if let Some(session_token) = cmd.options.remove("session_token") {
if let Some(session_token) = options.remove("session_token") {
builder = builder.with_token(session_token);
}
} else {
Expand All @@ -69,7 +67,7 @@ pub async fn get_s3_object_store_builder(
builder = builder.with_credentials(credentials);
}

if let Some(region) = cmd.options.remove("region") {
if let Some(region) = options.remove("region") {
builder = builder.with_region(region);
}

Expand Down Expand Up @@ -102,7 +100,7 @@ impl CredentialProvider for S3CredentialProvider {

pub fn get_oss_object_store_builder(
url: &Url,
cmd: &mut CreateExternalTable,
cmd: &mut HashMap<String, String>,
) -> Result<AmazonS3Builder> {
let bucket_name = get_bucket_name(url)?;
let mut builder = AmazonS3Builder::from_env()
Expand All @@ -111,16 +109,15 @@ pub fn get_oss_object_store_builder(
// oss don't care about the "region" field
.with_region("do_not_care");

if let (Some(access_key_id), Some(secret_access_key)) = (
cmd.options.remove("access_key_id"),
cmd.options.remove("secret_access_key"),
) {
if let (Some(access_key_id), Some(secret_access_key)) =
(cmd.remove("access_key_id"), cmd.remove("secret_access_key"))
{
builder = builder
.with_access_key_id(access_key_id)
.with_secret_access_key(secret_access_key);
}

if let Some(endpoint) = cmd.options.remove("endpoint") {
if let Some(endpoint) = cmd.remove("endpoint") {
builder = builder.with_endpoint(endpoint);
}

Expand All @@ -129,21 +126,20 @@ pub fn get_oss_object_store_builder(

pub fn get_gcs_object_store_builder(
url: &Url,
cmd: &mut CreateExternalTable,
cmd: &mut HashMap<String, String>,
) -> Result<GoogleCloudStorageBuilder> {
let bucket_name = get_bucket_name(url)?;
let mut builder = GoogleCloudStorageBuilder::from_env().with_bucket_name(bucket_name);

if let Some(service_account_path) = cmd.options.remove("service_account_path") {
if let Some(service_account_path) = cmd.remove("service_account_path") {
builder = builder.with_service_account_path(service_account_path);
}

if let Some(service_account_key) = cmd.options.remove("service_account_key") {
if let Some(service_account_key) = cmd.remove("service_account_key") {
builder = builder.with_service_account_key(service_account_key);
}

if let Some(application_credentials_path) =
cmd.options.remove("application_credentials_path")
if let Some(application_credentials_path) = cmd.remove("application_credentials_path")
{
builder = builder.with_application_credentials(application_credentials_path);
}
Expand Down Expand Up @@ -186,7 +182,8 @@ mod tests {
let mut plan = ctx.state().create_logical_plan(&sql).await?;

if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan {
let builder = get_s3_object_store_builder(table_url.as_ref(), cmd).await?;
let builder =
get_s3_object_store_builder(table_url.as_ref(), &mut cmd.options).await?;
// get the actual configuration information, then assert_eq!
let config = [
(AmazonS3ConfigKey::AccessKeyId, access_key_id),
Expand Down Expand Up @@ -218,7 +215,8 @@ mod tests {
let mut plan = ctx.state().create_logical_plan(&sql).await?;

if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan {
let builder = get_oss_object_store_builder(table_url.as_ref(), cmd)?;
let builder =
get_oss_object_store_builder(table_url.as_ref(), &mut cmd.options)?;
// get the actual configuration information, then assert_eq!
let config = [
(AmazonS3ConfigKey::AccessKeyId, access_key_id),
Expand Down Expand Up @@ -250,7 +248,8 @@ mod tests {
let mut plan = ctx.state().create_logical_plan(&sql).await?;

if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan {
let builder = get_gcs_object_store_builder(table_url.as_ref(), cmd)?;
let builder =
get_gcs_object_store_builder(table_url.as_ref(), &mut cmd.options)?;
// get the actual configuration information, then assert_eq!
let config = [
(GoogleConfigKey::ServiceAccount, service_account_path),
Expand Down
Loading