diff --git a/airflow/configuration.py b/airflow/configuration.py index 75edc6739cc86..7565e89f1d115 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -488,7 +488,7 @@ def _create_future_warning(name: str, section: str, current_value: Any, new_valu ) def _env_var_name(self, section: str, key: str) -> str: - return f"{ENV_VAR_PREFIX}{section.upper()}__{key.upper()}" + return f"{ENV_VAR_PREFIX}{section.replace('.', '_').upper()}__{key.upper()}" def _get_env_var_option(self, section: str, key: str): # must have format AIRFLOW__{SECTION}__{KEY} (note double underscore) diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py index b1d754965e0b3..9a202ab156d72 100644 --- a/airflow/providers/odbc/hooks/odbc.py +++ b/airflow/providers/odbc/hooks/odbc.py @@ -30,7 +30,22 @@ class OdbcHook(DbApiHook): """ Interact with odbc data sources using pyodbc. + To configure driver, in addition to supplying as constructor arg, the following are also supported: + * set ``driver`` parameter in ``hook_params`` dictionary when instantiating hook by SQL operators. + * set ``driver`` extra in the connection and set ``allow_driver_in_extra`` to True in + section ``providers.odbc`` section of airflow config. + * patch ``OdbcHook.default_driver`` in ``local_settings.py`` file. + See :doc:`/connections/odbc` for full documentation. + + :param args: passed to DbApiHook + :param database: database to use -- overrides connection ``schema`` + :param driver: name of driver or path to driver. see above for more info + :param dsn: name of DSN to use. overrides DSN supplied in connection ``extra`` + :param connect_kwargs: keyword arguments passed to ``pyodbc.connect`` + :param sqlalchemy_scheme: Scheme sqlalchemy connection. Default is ``mssql+pyodbc`` Only used for + ``get_sqlalchemy_engine`` and ``get_sqlalchemy_connection`` methods. + :param kwargs: passed to DbApiHook """ DEFAULT_SQLALCHEMY_SCHEME = "mssql+pyodbc" @@ -40,6 +55,8 @@ class OdbcHook(DbApiHook): hook_name = "ODBC" supports_autocommit = True + default_driver: str | None = None + def __init__( self, *args, @@ -102,6 +119,19 @@ def connection_extra_lower(self) -> dict: @property def driver(self) -> str | None: """Driver from init param if given; else try to find one in connection extra.""" + extra_driver = self.connection_extra_lower.get("driver") + from airflow.configuration import conf + + if extra_driver and conf.getboolean("providers.odbc", "allow_driver_in_extra", fallback=False): + self._driver = extra_driver + else: + self.log.warning( + "You have supplied 'driver' via connection extra but it will not be used. In order to " + "use 'driver' from extra you must set airflow config setting `allow_driver_in_extra = True` " + "in section `providers.odbc`. Alternatively you may specify driver via 'driver' parameter of " + "the hook constructor or via 'hook_params' dictionary with key 'driver' if using SQL " + "operators." + ) if not self._driver: driver = self.connection_extra_lower.get("driver") if driver: diff --git a/docs/apache-airflow-providers-odbc/connections/odbc.rst b/docs/apache-airflow-providers-odbc/connections/odbc.rst index 176977d8fec04..11d382eeb661b 100644 --- a/docs/apache-airflow-providers-odbc/connections/odbc.rst +++ b/docs/apache-airflow-providers-odbc/connections/odbc.rst @@ -67,6 +67,15 @@ Extra (optional) * This is only used when ``get_uri`` is invoked in :py:meth:`~airflow.providers.common.sql.hooks.sql.DbApiHook.get_sqlalchemy_engine`. By default, the hook uses scheme ``mssql+pyodbc``. You may pass a string value here to override. + - ``driver`` + * The name of the driver to use on your system. Note that this is only considered if ``allow_driver_in_extra`` + is set to True in airflow config section ``providers.odbc`` (by default it is not considered). Note: if setting + this config from env vars, use ``AIRFLOW__PROVIDERS_ODBC__ALLOW_DRIVER_IN_EXTRA=true``. + + .. note:: + If setting ``allow_driver_extra``to True, this allows users to set the driver via the Airflow Connection's + ``extra`` field. By default this is not allowed. If enabling this functionality, you should make sure + that you trust the users who can edit connections in the UI to not use it maliciously. .. note:: You are responsible for installing an ODBC driver on your system. diff --git a/docs/apache-airflow/howto/set-config.rst b/docs/apache-airflow/howto/set-config.rst index 02fa9bac2dc51..b443cff9c81ae 100644 --- a/docs/apache-airflow/howto/set-config.rst +++ b/docs/apache-airflow/howto/set-config.rst @@ -38,6 +38,19 @@ or by creating a corresponding environment variable: export AIRFLOW__DATABASE__SQL_ALCHEMY_CONN=my_conn_string +Note that when the section name has a dot in it, you must replace it with an underscore when setting the env var. +For example consider the pretend section ``providers.some_provider``: + +.. code-block:: ini + + [providers.some_provider>] + this_param = true + +.. code-block:: bash + + export AIRFLOW__PROVIDERS_SOME_PROVIDER__THIS_PARAM=true + + You can also derive the connection string at run time by appending ``_cmd`` to the key like this: diff --git a/tests/core/test_configuration.py b/tests/core/test_configuration.py index 961bc3c7cc1c2..2afc78bcaf000 100644 --- a/tests/core/test_configuration.py +++ b/tests/core/test_configuration.py @@ -226,6 +226,24 @@ def test_command_precedence(self): assert "key4" not in cfg_dict["test"] assert "printf key4_result" == cfg_dict["test"]["key4_cmd"] + def test_can_read_dot_section(self): + test_config = """[test.abc] +key1 = true +""" + test_conf = AirflowConfigParser() + test_conf.read_string(test_config) + section = "test.abc" + key = "key1" + assert test_conf.getboolean(section, key) is True + + with mock.patch.dict( + "os.environ", + { + "AIRFLOW__TEST_ABC__KEY1": "false", # note that the '.' is converted to '_' + }, + ): + assert test_conf.getboolean(section, key) is False + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @conf_vars( { @@ -596,7 +614,6 @@ def test_command_from_env(self): @pytest.mark.parametrize("display_sensitive, result", [(True, "OK"), (False, "< hidden >")]) def test_as_dict_display_sensitivewith_command_from_env(self, display_sensitive, result): - test_cmdenv_conf = AirflowConfigParser() test_cmdenv_conf.sensitive_config_values.add(("testcmdenv", "itsacommand")) with mock.patch.dict("os.environ"): diff --git a/tests/providers/odbc/hooks/test_odbc.py b/tests/providers/odbc/hooks/test_odbc.py index 58dcaf10d7b1f..be0c61ccb2dd1 100644 --- a/tests/providers/odbc/hooks/test_odbc.py +++ b/tests/providers/odbc/hooks/test_odbc.py @@ -177,11 +177,33 @@ def test_driver(self): assert hook.driver == "Blah driver" hook = self.get_hook(hook_params=dict(driver="{Blah driver}")) assert hook.driver == "Blah driver" - hook = self.get_hook(conn_params=dict(extra='{"driver": "Blah driver"}')) + + def test_driver_extra_raises_warning_by_default(self, caplog): + with caplog.at_level(logging.WARNING, logger="airflow.providers.odbc.hooks.test_odbc"): + driver = self.get_hook(conn_params=dict(extra='{"driver": "Blah driver"}')).driver + assert "You have supplied 'driver' via connection extra but it will not be used" in caplog.text + assert driver is None + + @mock.patch.dict("os.environ", {"AIRFLOW__PROVIDERS_ODBC__ALLOW_DRIVER_IN_EXTRA": "TRUE"}) + def test_driver_extra_works_when_allow_driver_extra(self): + hook = self.get_hook( + conn_params=dict(extra='{"driver": "Blah driver"}'), hook_params=dict(allow_driver_extra=True) + ) assert hook.driver == "Blah driver" hook = self.get_hook(conn_params=dict(extra='{"driver": "{Blah driver}"}')) assert hook.driver == "Blah driver" + def test_driver_none_by_default(self): + hook = self.get_hook() + assert hook.driver is None + + def test_driver_extra_raises_warning_and_returns_default_driver_by_default(self, caplog): + with patch.object(OdbcHook, "default_driver", "Blah driver"): + with caplog.at_level(logging.WARNING, logger="airflow.providers.odbc.hooks.test_odbc"): + driver = self.get_hook(conn_params=dict(extra='{"driver": "Blah driver2"}')).driver + assert "have supplied 'driver' via connection extra but it will not be used" in caplog.text + assert driver == "Blah driver" + def test_database(self): hook = self.get_hook(hook_params=dict(database="abc")) assert hook.database == "abc"