Building a Tonic-Based Just-in-Time IAM Credential Vending Service for a Multi-Tenant Spark Cluster


When operating a shared, multi-tenant Apache Spark cluster, controlling access to data is an unavoidable core challenge. When jobs from different tenants need to read and write data in isolated S3 buckets, granting each job least-privilege access safely and dynamically, while preventing credential leakage, becomes the crux of the architecture. A common mistake is to hard-code long-lived AWS access keys in the Spark configuration, or to attach one overly broad IAM Instance Profile to the entire cluster. Both approaches grossly violate the principle of least privilege and are unacceptable in production.

At this architectural crossroads sit two seemingly viable options: lean on the cloud-native ecosystem's built-in integration, or build a dedicated credential-management microservice.

Option A: EKS IAM Roles for Service Accounts (IRSA)

For Spark on Kubernetes, IRSA is the usual first choice. It associates a Kubernetes Service Account with an AWS IAM Role, letting applications inside a Pod obtain temporary IAM credentials automatically via OIDC federation.

graph TD
    subgraph Kubernetes Cluster
        A[Spark Driver Pod] -- "K8s SA Token" --> B(OIDC Provider);
        A -- "Uses ServiceAccount `spark-job-sa`" --> C{K8s Service Account};
        C -- "Annotated with `eks.amazonaws.com/role-arn`" --> D((IAM Role));
    end
    B -- "Validates Token" --> E(AWS STS);
    E -- "AssumeRoleWithWebIdentity" --> F[Temporary IAM Credentials];
    F -- "Returns Credentials" --> A;
    A -- "Uses Credentials" --> G[Tenant S3 Bucket];

Advantages:

  1. Native integration: tightly integrated with EKS; configuration is relatively straightforward and no extra service needs to be deployed.
  2. No long-lived credentials: follows best practice by using short-lived, automatically rotated temporary credentials.

Disadvantages:

  1. Granularity: IRSA permissions are bound to a Service Account. On a multi-tenant Spark platform, dynamically creating a dedicated Service Account for every job of every tenant imposes enormous management overhead; if jobs from multiple tenants share one Service Account, they all receive the same IAM Role, which lands us back at overly broad permissions. Job-level fine-grained access control is simply not achievable.
  2. Limited flexibility: the permission policy is statically attached to the IAM Role. If you need a narrower policy derived from job-specific parameters (for example, allowing access only to s3://tenant-bucket/input-path/*), IRSA cannot do that directly.
  3. Spark executors: the Spark driver and executors may run in different Pods. Executors can be assigned a Service Account via spark.kubernetes.executor.serviceAccountName, but this only compounds the management complexity.

In a real multi-tenant setting, tenant A's jobs must never have any possibility of touching tenant B's data. Under such a strict isolation model, IRSA falls short.

Option B: A Dedicated Credential-Issuing Service (Token Vending Machine)

The core of this option is a standalone, highly available microservice that we call a Token Vending Machine (TVM). The service itself holds a comparatively powerful IAM Role that lets it call the AWS STS AssumeRole API. When a Spark job starts, its driver and executors request, over a secure gRPC channel, short-lived temporary credentials scoped down to exactly what the task needs.

graph TD
    subgraph Spark Application
        SD[Spark Driver] -->|1. gRPC Request| TVM;
        SE[Spark Executor] -->|1. gRPC Request| TVM;
    end
    
    subgraph Secure Infrastructure
        TVM(Credential Service - TVM) -- "Has powerful IAM Role" --> STS;
    end
    
    TVM -- "2. Authenticates Request (e.g., via JWT)" --> TVM;
    TVM -- "3. sts:AssumeRole with dynamic policy" --> STS(AWS STS);
    STS -- "4. Returns Scoped Temporary Credentials" --> TVM;
    TVM -- "5. Returns Credentials via gRPC" --> SD;
    TVM -- "5. Returns Credentials via gRPC" --> SE;

    SD -- "6. Uses Credentials" --> S3[Tenant S3 Bucket];
    SE -- "6. Uses Credentials" --> S3;

Advantages:

  1. Extremely fine-grained control: at runtime, the TVM can generate an inline IAM policy from the caller's attributes (tenant ID, job ID, target path, and so on). Every Spark task can be issued "just enough" permission, realizing the principle of least privilege in full.
  2. Centralized, decoupled logic: the complexity of policy generation is encapsulated inside the TVM. Spark jobs need not know anything about IAM; they only need to know how to call the TVM. This decoupling makes iterating on and auditing permission policies straightforward.
  3. Stack-agnostic: any application that needs AWS resources can reuse the TVM, not just Spark.
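To make the first point concrete, here is a minimal, self-contained sketch of deriving an inline session policy from a tenant bucket and key prefix. The function name and string-built JSON are purely illustrative; a production service should construct the document with a JSON serializer.

```rust
// Illustrative sketch only: derive a per-request inline session policy.
// A real service should build this JSON with a proper serializer.
fn inline_policy_for(bucket: &str, prefix: &str) -> String {
    format!(
        concat!(
            r#"{{"Version":"2012-10-17","Statement":[{{"Effect":"Allow","#,
            r#""Action":["s3:GetObject","s3:PutObject"],"#,
            r#""Resource":["arn:aws:s3:::{}/{}*"]}}]}}"#
        ),
        bucket, prefix
    )
}

fn main() {
    let policy = inline_policy_for("tenant-a-bucket", "jobs/job-42/input/");
    // The effective permission of the issued credentials is the intersection
    // of this session policy and the assumed role's own policy, so a session
    // policy can only narrow access, never widen it.
    assert!(policy.contains("arn:aws:s3:::tenant-a-bucket/jobs/job-42/input/*"));
    println!("ok");
}
```

Because STS intersects the session policy with the role policy, even a buggy TVM cannot grant more than the underlying role allows, which is a useful safety net for this design.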

Disadvantages:

  1. A new component: an additional highly available service must be developed, deployed, and operated. The TVM sits on the system's critical path, so its stability and performance matter greatly.
  2. Implementation complexity: authentication, authorization, STS calls, caching, and error handling all have to be built from scratch.

Given our platform's strict requirements for security and tenant isolation, Option B, despite its higher upfront development cost, offers fine-grained control and long-term maintainability that IRSA cannot match. We ultimately chose to build this high-performance, memory-safe TVM in Rust with Tonic: Rust's safety guarantees and Tonic's performance make the pair an ideal fit for a security-critical core component.

Core Implementation: the Tonic Credential Service and a Custom Spark Provider

The implementation has two parts: a gRPC server built with Tonic (Rust), and a custom AWSCredentialsProvider (Scala) that Spark loads as the client.

1. Defining the gRPC Interface (credentials.proto)

First, we define the service interface in Protobuf. The request carries the context needed to generate a policy dynamically; the response maps directly onto AWS's temporary-credential structure.

syntax = "proto3";

package credentials.v1;

// The core credential vending service.
service CredentialService {
  // Requests short-lived, scoped-down AWS credentials.
  rpc GetScopedCredentials(GetScopedCredentialsRequest) returns (GetScopedCredentialsResponse) {}
}

message GetScopedCredentialsRequest {
  // Identifier for the tenant, extracted from the request's authentication token (e.g., JWT).
  // This field is contextual; the caller's primary identity is verified from the token.
  string tenant_id = 1;

  // The specific S3 path the job needs to access.
  // Example: "s3://my-tenant-bucket/data/2023-10-27/"
  string resource_path = 2;

  // The duration for which the credentials should be valid, in seconds.
  // The service will enforce a maximum duration.
  int32 duration_seconds = 3;
}

message GetScopedCredentialsResponse {
  string access_key_id = 1;
  string secret_access_key = 2;
  string session_token = 3;
  // Expiration timestamp in UTC seconds since epoch.
  int64 expiration = 4;
}

2. The Tonic Server (tvm-server/src/main.rs)

This is the heart of the TVM. It listens for gRPC requests, verifies the caller's identity, and then calls the AWS STS API to mint credentials.

// File: tvm-server/src/main.rs

use tonic::{transport::Server, Request, Response, Status};
use credentials::v1::{
    credential_service_server::{CredentialService, CredentialServiceServer},
    GetScopedCredentialsRequest, GetScopedCredentialsResponse,
};
use aws_sdk_sts::Client as StsClient;
use aws_config::meta::region::RegionProviderChain;
use std::time::Duration;
use tracing::{info, warn, error};
use tracing_subscriber;

// Proto-generated module
pub mod credentials {
    pub mod v1 {
        tonic::include_proto!("credentials.v1");
    }
}

// Configuration loaded from environment variables or a config file
#[derive(Debug)]
struct Config {
    aws_role_arn_to_assume: String,
    max_duration_seconds: u32,
    listen_addr: String,
}

#[derive(Debug)]
pub struct MyCredentialService {
    sts_client: StsClient,
    config: Config,
}

impl MyCredentialService {
    // Helper function to create a dynamically scoped IAM policy.
    // In a real project, this should use a proper JSON builder to avoid injection risks.
    fn create_scoped_policy(&self, resource_path: &str) -> Result<String, Status> {
        // Basic validation: ensure it's a valid S3 path format.
        if !resource_path.starts_with("s3://") {
            warn!("Invalid resource_path format: {}", resource_path);
            return Err(Status::invalid_argument("resource_path must start with s3://"));
        }
        
        let parts: Vec<&str> = resource_path.trim_start_matches("s3://").splitn(2, '/').collect();
        let bucket = parts
            .get(0)
            .filter(|b| !b.is_empty())
            .ok_or_else(|| Status::invalid_argument("Invalid S3 path: missing bucket"))?;
        // Scope object access down to the requested key prefix; a bare bucket means "*".
        let key_pattern = match parts.get(1) {
            Some(p) if !p.is_empty() => format!("{}*", p),
            _ => "*".to_string(),
        };

        let policy = serde_json::json!({
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Effect": "Allow",
                    "Action": [
                        "s3:GetObject",
                        "s3:PutObject",
                        "s3:DeleteObject"
                    ],
                    // Object-level access, limited to the requested key prefix.
                    "Resource": [format!("arn:aws:s3:::{}/{}", bucket, key_pattern)]
                },
                {
                    "Effect": "Allow",
                    // ListBucket applies to the bucket itself, not to objects.
                    "Action": ["s3:ListBucket"],
                    "Resource": [format!("arn:aws:s3:::{}", bucket)]
                }
            ]
        });

        Ok(policy.to_string())
    }
}

#[tonic::async_trait]
impl CredentialService for MyCredentialService {
    async fn get_scoped_credentials(
        &self,
        request: Request<GetScopedCredentialsRequest>,
    ) -> Result<Response<GetScopedCredentialsResponse>, Status> {
        info!("Received credential request");

        // --- 1. Authentication ---
        // In a real system, you'd extract a JWT from request metadata and validate it.
        // The JWT's claims would contain the tenant_id, preventing spoofing.
        // For this example, we trust the tenant_id from the request payload.
        // let tenant_id = self.authenticate(request.metadata())?;
        let inner_req = request.into_inner();
        let tenant_id = inner_req.tenant_id;

        if tenant_id.is_empty() {
             return Err(Status::unauthenticated("Missing tenant_id"));
        }

        // --- 2. Authorization & Validation ---
        let duration = inner_req.duration_seconds.min(self.config.max_duration_seconds as i32);
        if duration < 900 { // AWS STS minimum is 900 seconds
            return Err(Status::invalid_argument("Duration must be at least 900 seconds"));
        }

        // --- 3. Generate Dynamic Policy ---
        let policy_json = self.create_scoped_policy(&inner_req.resource_path)?;
        info!("Generated policy for tenant '{}' and path '{}'", tenant_id, inner_req.resource_path);

        // --- 4. Assume Role with AWS STS ---
        let role_session_name = format!("spark-job-{}-{}", tenant_id, chrono::Utc::now().timestamp_millis());

        // STS session names only allow [A-Za-z0-9_.@-] and at most 64 characters.
        let mut role_session_name: String = role_session_name
            .chars()
            .filter(|c| c.is_ascii_alphanumeric() || matches!(*c, '_' | '.' | '@' | '-'))
            .collect();
        role_session_name.truncate(64);

        let assume_role_output = self.sts_client
            .assume_role()
            .role_arn(&self.config.aws_role_arn_to_assume)
            .role_session_name(role_session_name)
            .policy(policy_json)
            .duration_seconds(duration)
            .send()
            .await
            .map_err(|e| {
                error!("STS AssumeRole failed: {}", e);
                Status::internal("Failed to assume AWS role")
            })?;

        // --- 5. Construct and Send Response ---
        if let Some(creds) = assume_role_output.credentials {
            // `DateTime::secs()` yields epoch seconds; avoid panicking if STS omits the field.
            let expiration = creds.expiration.map(|dt| dt.secs()).unwrap_or_default();
            let response = GetScopedCredentialsResponse {
                access_key_id: creds.access_key_id.unwrap_or_default(),
                secret_access_key: creds.secret_access_key.unwrap_or_default(),
                session_token: creds.session_token.unwrap_or_default(),
                expiration,
            };
            info!("Successfully issued credentials for tenant '{}'", tenant_id);
            Ok(Response::new(response))
        } else {
            error!("STS response was successful but contained no credentials.");
            Err(Status::internal("Invalid response from AWS STS"))
        }
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt::init();
    
    // In production, load from a secure source.
    let config = Config {
        aws_role_arn_to_assume: std::env::var("AWS_ROLE_ARN_TO_ASSUME")
            .expect("AWS_ROLE_ARN_TO_ASSUME must be set"),
        max_duration_seconds: 3600, // 1 hour max
        listen_addr: "0.0.0.0:50051".to_string(),
    };
    
    let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
    let aws_config = aws_config::from_env().region(region_provider).load().await;
    let sts_client = StsClient::new(&aws_config);

    let addr = config.listen_addr.parse()?;
    let service = MyCredentialService { sts_client, config };

    info!("Credential Service listening on {}", addr);

    Server::builder()
        .add_service(CredentialServiceServer::new(service))
        .serve(addr)
        .await?;

    Ok(())
}

3. The Custom Spark Credential Provider (TvmCredentialProvider.scala)

This Scala class is packaged with the Spark job. It implements Hadoop's AWSCredentialsProvider interface and calls our TVM service whenever credentials are needed.

// File: TvmCredentialProvider.scala

package com.mycompany.spark.auth

import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, BasicSessionCredentials}
import com.mycompany.credentials.v1.CredentialServiceGrpc
import com.mycompany.credentials.v1.GetScopedCredentialsRequest
import io.grpc.{ManagedChannel, ManagedChannelBuilder}
import org.apache.hadoop.conf.Configuration
import java.net.URI
import java.util.concurrent.TimeUnit
import org.slf4j.LoggerFactory

class TvmCredentialProvider(uri: URI, conf: Configuration) extends AWSCredentialsProvider {

  private val logger = LoggerFactory.getLogger(classOf[TvmCredentialProvider])

  private val TVM_ADDRESS_KEY = "fs.s3a.custom.tvm.address"
  private val TVM_TENANT_ID_KEY = "fs.s3a.custom.tvm.tenant.id"
  // Reserved for explicitly overriding the resource path; unused here, where the path is derived from the URI.
  private val TVM_RESOURCE_PATH_KEY = "fs.s3a.custom.tvm.resource.path"

  private val tvmAddress: String = conf.get(TVM_ADDRESS_KEY, "localhost:50051")
  
  // Credentials are cached and refreshed once they are within five minutes of expiry.
  // `@volatile` plus the `synchronized` block below keeps the cache thread-safe.
  @volatile private var cachedCredentials: Option[(BasicSessionCredentials, Long)] = None

  private def createGrpcChannel(): ManagedChannel = {
    ManagedChannelBuilder.forTarget(tvmAddress).usePlaintext().build()
  }

  override def getCredentials: AWSCredentials = {
    synchronized {
      val now = System.currentTimeMillis() / 1000
      // Check if cache is valid (exists and not expiring within the next 5 minutes)
      cachedCredentials match {
        case Some((creds, expiration)) if expiration > now + 300 =>
          logger.debug("Using cached credentials")
          return creds
        case _ =>
          logger.info("Credentials expired or not found, fetching new credentials from TVM.")
          // Fall through to fetch new credentials
      }

      val channel = createGrpcChannel()
      try {
        val tenantId = conf.get(TVM_TENANT_ID_KEY, "")
        // We get the resource path from the URI being accessed.
        // This is a simplification; a more robust solution might need more context.
        val resourcePath = s"s3://${uri.getHost}${uri.getPath}"

        if (tenantId.isEmpty) {
          throw new RuntimeException(s"Hadoop config key '$TVM_TENANT_ID_KEY' must be set.")
        }
        
        logger.info(s"Requesting credentials for tenant '$tenantId' and path '$resourcePath'")

        val stub = CredentialServiceGrpc.newBlockingStub(channel)
        val request = GetScopedCredentialsRequest.newBuilder()
          .setTenantId(tenantId)
          .setResourcePath(resourcePath)
          .setDurationSeconds(3600) // Request 1-hour credentials
          .build()

        val response = stub.getScopedCredentials(request)
        
        val newCreds = new BasicSessionCredentials(
          response.getAccessKeyId,
          response.getSecretAccessKey,
          response.getSessionToken
        )

        // Cache the new credentials with their expiration time
        cachedCredentials = Some((newCreds, response.getExpiration))
        newCreds
        
      } catch {
        case e: Exception =>
          logger.error("Failed to get credentials from TVM", e)
          throw new RuntimeException("Failed to retrieve credentials from TVM service", e)
      } finally {
        channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS)
      }
    }
  }

  override def refresh(): Unit = {
    // Force invalidation of the cache
    synchronized {
      cachedCredentials = None
    }
    logger.info("Credentials cache invalidated by refresh() call.")
  }
}

Finally, enable the provider via configuration when submitting the Spark job:

spark-submit \
  --class com.mycompany.Main \
  --conf "spark.hadoop.fs.s3a.aws.credentials.provider=com.mycompany.spark.auth.TvmCredentialProvider" \
  --conf "spark.hadoop.fs.s3a.custom.tvm.address=tvm-service.my-namespace:50051" \
  --conf "spark.hadoop.fs.s3a.custom.tvm.tenant.id=tenant-a" \
  --jars grpc-stub.jar,custom-provider.jar \
  my-spark-app.jar

The key is the fs.s3a.aws.credentials.provider setting, which tells the Hadoop S3A filesystem to obtain credentials through our custom class.

Limitations and Future Work

While this TVM architecture solves the core problem of fine-grained access control, it is no silver bullet. First, the TVM's own availability is critical: production deployments must run multiple TVM instances behind a load balancer. Second, network latency between the TVM and Spark executors slightly increases the first data access, although gRPC's performance and the credential cache largely mitigate this.

Future improvements could focus on the following areas:

  1. Credential caching: implement smarter caching and prefetching in TvmCredentialProvider, for instance asynchronously fetching the next credential a few minutes before the current one expires, so jobs never block waiting for credentials at a critical moment.
  2. Mutual TLS (mTLS): harden the channel between the TVM and its clients with mTLS, so that only trusted Spark jobs can request credentials.
  3. Multi-cloud support: abstract the TVM's core logic into a pluggable platform that targets not only AWS STS but also Google Cloud IAM and Azure Active Directory, providing a single credential entry point for cross-cloud processing.
  4. Auditing and monitoring: emit detailed, structured logs for every credential request and issuance and feed them into the monitoring stack; this is essential for security audits and troubleshooting.
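The refresh-ahead idea in point 1 can be sketched as follows. The types, field names, and thresholds are illustrative, and a deterministic clock parameter stands in for real time; in the actual provider the fetch would run asynchronously so callers never block on it.

```rust
use std::cell::Cell;

// Illustrative refresh-ahead cache: keep serving the cached credential, but
// fetch a replacement once its remaining lifetime drops below a safety margin.
#[derive(Clone, Debug, PartialEq)]
struct CachedCreds {
    access_key_id: String,
    expiration: u64, // epoch seconds
}

struct RefreshAheadCache<F: FnMut() -> CachedCreds> {
    fetch: F,
    cached: Option<CachedCreds>,
    refresh_margin_secs: u64,
}

impl<F: FnMut() -> CachedCreds> RefreshAheadCache<F> {
    fn get(&mut self, now: u64) -> CachedCreds {
        let needs_refresh = match &self.cached {
            Some(c) => c.expiration <= now + self.refresh_margin_secs,
            None => true,
        };
        if needs_refresh {
            // In the real provider this would be an async prefetch, started
            // before expiry, rather than a blocking call on the hot path.
            self.cached = Some((self.fetch)());
        }
        self.cached.clone().unwrap()
    }
}

fn main() {
    let fetch_count = Cell::new(0u32);
    let mut cache = RefreshAheadCache {
        fetch: || {
            fetch_count.set(fetch_count.get() + 1);
            CachedCreds {
                access_key_id: format!("AKIA{}", fetch_count.get()),
                expiration: 1_000,
            }
        },
        cached: None,
        refresh_margin_secs: 300,
    };
    assert_eq!(cache.get(100).access_key_id, "AKIA1"); // first call fetches
    assert_eq!(cache.get(200).access_key_id, "AKIA1"); // still fresh: served from cache
    assert_eq!(cache.get(800).access_key_id, "AKIA2"); // inside the margin: refreshed
    assert_eq!(fetch_count.get(), 2);
    println!("ok");
}
```

The same margin-based check already exists in TvmCredentialProvider; the improvement proposed here is only to move the fetch off the caller's critical path.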
