diff --git a/pkg/blob/azure.go b/pkg/blob/azure.go index 1b160948f..590dd952d 100644 --- a/pkg/blob/azure.go +++ b/pkg/blob/azure.go @@ -103,7 +103,18 @@ func GetCloudProvider(ctx context.Context, kubeClient kubernetes.Interface, node } else { config.UserAgent = userAgent config.CloudProviderBackoff = true - if err = az.InitializeCloudFromConfig(context.TODO(), config, fromSecret, false); err != nil { + // these environment variables are injected by workload identity webhook + if tenantID := os.Getenv("AZURE_TENANT_ID"); tenantID != "" { + config.TenantID = tenantID + } + if clientID := os.Getenv("AZURE_CLIENT_ID"); clientID != "" { + config.AADClientID = clientID + } + if federatedTokenFile := os.Getenv("AZURE_FEDERATED_TOKEN_FILE"); federatedTokenFile != "" { + config.AADFederatedTokenFile = federatedTokenFile + config.UseFederatedWorkloadIdentityExtension = true + } + if err = az.InitializeCloudFromConfig(ctx, config, fromSecret, false); err != nil { klog.Warningf("InitializeCloudFromConfig failed with error: %v", err) } } diff --git a/pkg/blob/azure_test.go b/pkg/blob/azure_test.go index e1358e1e2..75889d1a7 100644 --- a/pkg/blob/azure_test.go +++ b/pkg/blob/azure_test.go @@ -80,14 +80,19 @@ users: }() tests := []struct { - desc string - createFakeCredFile bool - createFakeKubeConfig bool - kubeconfig string - nodeID string - userAgent string - allowEmptyCloudConfig bool - expectedErr error + desc string + createFakeCredFile bool + createFakeKubeConfig bool + setFederatedWorkloadIdentityEnv bool + kubeconfig string + nodeID string + userAgent string + allowEmptyCloudConfig bool + expectedErr error + aadFederatedTokenFile string + useFederatedWorkloadIdentityExtension bool + aadClientID string + tenantID string }{ { desc: "out of cluster, no kubeconfig, no credential file", @@ -134,6 +139,20 @@ users: allowEmptyCloudConfig: true, expectedErr: nil, }, + { + desc: "[success] get azure client with workload identity", + createFakeKubeConfig: true, + createFakeCredFile: true, + setFederatedWorkloadIdentityEnv: true, + kubeconfig: fakeKubeConfig, + nodeID: "", + userAgent: "useragent", + useFederatedWorkloadIdentityExtension: true, + aadFederatedTokenFile: "fake-token-file", + aadClientID: "fake-client-id", + tenantID: "fake-tenant-id", + expectedErr: nil, + }, } for _, test := range tests { @@ -142,7 +161,7 @@ users: t.Error(err) } defer func() { - if err := os.Remove(fakeKubeConfig); err != nil { + if err := os.Remove(fakeKubeConfig); err != nil && !os.IsNotExist(err) { t.Error(err) } }() @@ -156,7 +175,7 @@ users: t.Error(err) } defer func() { - if err := os.Remove(fakeCredFile); err != nil { + if err := os.Remove(fakeCredFile); err != nil && !os.IsNotExist(err) { t.Error(err) } }() @@ -176,6 +195,12 @@ users: } continue } + if test.setFederatedWorkloadIdentityEnv { + t.Setenv("AZURE_TENANT_ID", test.tenantID) + t.Setenv("AZURE_CLIENT_ID", test.aadClientID) + t.Setenv("AZURE_FEDERATED_TOKEN_FILE", test.aadFederatedTokenFile) + } + cloud, err := GetCloudProvider(context.Background(), kubeClient, test.nodeID, "", "", test.userAgent, test.allowEmptyCloudConfig) if !reflect.DeepEqual(err, test.expectedErr) && test.expectedErr != nil && !strings.Contains(err.Error(), test.expectedErr.Error()) { t.Errorf("desc: %s,\n input: %q, GetCloudProvider err: %v, expectedErr: %v", test.desc, test.kubeconfig, err, test.expectedErr) @@ -185,6 +210,10 @@ users: } else { assert.Equal(t, cloud.Environment.StorageEndpointSuffix, storage.DefaultBaseURL) assert.Equal(t, cloud.UserAgent, test.userAgent) + assert.Equal(t, cloud.AADFederatedTokenFile, test.aadFederatedTokenFile) + assert.Equal(t, cloud.UseFederatedWorkloadIdentityExtension, test.useFederatedWorkloadIdentityExtension) + assert.Equal(t, cloud.AADClientID, test.aadClientID) + assert.Equal(t, cloud.TenantID, test.tenantID) } } }