Skip to content

Commit

Permalink
Merge pull request #8 from doyensec/torchserve_rce
Browse files Browse the repository at this point in the history
TorchServe Management API RCE
  • Loading branch information
execveat authored Dec 22, 2023
2 parents 422aaa2 + 2ccfd7d commit f9e2da1
Show file tree
Hide file tree
Showing 25 changed files with 2,897 additions and 0 deletions.
48 changes: 48 additions & 0 deletions doyensec/detectors/rce/torchserve/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# TorchServe Management API Detection Plugin
## Overview
This plugin detects and assesses the security risks of TorchServe Management API instances. Inspired by the ShellTorch vulnerability chain (disclosed by [Oligo Security](https://www.oligo.security/blog/shelltorch-torchserve-ssrf-vulnerability-cve-2023-43654)), it addresses the critical risks associated with insecure configurations of TorchServe, a widely used open-source application for serving PyTorch models in production.

## Background
TorchServe, before version 0.8.2, bound to `0.0.0.0` by default, potentially exposing its Management API to the internet. Since PyTorch models allow arbitrary code execution, unrestricted model addition poses significant risks including data leakage and user privacy breaches.

The original ShellTorch attack exploited [CVE-2022-1471](https://nvd.nist.gov/vuln/detail/CVE-2022-1471), a vulnerability fixed in TorchServe 0.8.2. However, the risk of executing arbitrary code in models remains in the latest version (0.9.0).

To mitigate these risks, TorchServe introduced the allow_urls feature, limiting model downloads to specified sources. However, a typical `allow_urls` configuration often includes entire services like GCP and AWS, which can be insecure. It's important to configure `allow_urls` carefully to avoid such vulnerabilities.

## Plugin Description
This plugin detects exposed TorchServe Management API instances, assessing the remote code execution (RCE) risk. It supports multiple detection modes:

### Static Mode
**Description:** Manually host a model file on a web server. Most reliable, particularly effective against lenient `allow_urls` configurations.
**Use case:** Ideal when `allow_urls` includes cloud services, posing a security risk.

```
--torchserve-management-api-mode=static --torchserve-management-api-model-static-url=https://s3.amazonaws.com/model.mar
```

### Local Mode
**Description:** Serve the model via an embedded web server. Quicker setup, but may fail against restrictive `allow_urls`.
**Use case:** Best for environments where `allow_urls` is not a limiting factor.

```
--torchserve-management-api-mode=local --torchserve-management-api-local-bind-host=tsunami --torchserve-management-api-local-bind-port=1234 --torchserve-management-api-local-accessible-url=http://mydomain.com/
```

### SSRF Mode
**Description:** Uses Tsunami's callback server as the model source. Indirect verification of RCE risk.
**Use case:** Selected when direct model serving isn't feasible or as an additional verification layer.

```
--torchserve-management-api-mode=ssrf
```

### Basic Mode
**Description:** Default mode that relies solely on Management API fingerprinting.
**Use case:** Automatically selected when callback server isn't available, useful as a preliminary check.

```
--torchserve-management-api-mode=basic
```

## Testing
Utilize the following testbed for assessing plugin functionality: [TorchServe Security Testbed](https://github.com/google/security-testbeds/tree/main/torchserve).
99 changes: 99 additions & 0 deletions doyensec/detectors/rce/torchserve/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
plugins {
id 'java-library'
}

description = 'Tsunami VulnDetector plugin for TorchServe CVE-2023-43654.'
group = 'com.google.tsunami'
version = '0.0.1-SNAPSHOT'

repositories {
maven { // The google mirror is less flaky than mavenCentral()
url 'https://maven-central.storage-download.googleapis.com/repos/central/data/'
}
mavenCentral()
mavenLocal()
}

java {
sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11

jar.manifest {
attributes('Implementation-Title': name,
'Implementation-Version': version,
'Built-By': System.getProperty('user.name'),
'Built-JDK': System.getProperty('java.version'),
'Source-Compatibility': sourceCompatibility,
'Target-Compatibility': targetCompatibility)
}

javadoc.options {
encoding = 'UTF-8'
use = true
links 'https://docs.oracle.com/javase/8/docs/api/'
}

// Log stacktrace to console when test fails.
test {
testLogging {
exceptionFormat = 'full'
showExceptions true
showCauses true
showStackTraces true
}
maxHeapSize = '1500m'
}
}

ext {
tsunamiVersion = '0.0.14'
junitVersion = '4.13'
mockitoVersion = '2.28.2'
truthVersion = '1.0.1'
javaxInjectVersion = '1'
jcommanderVersion = '1.48'
okhttpVersion = '3.12.0'


guavaVersion = '28.2-jre'
guiceVersion = '4.2.3'
tsunamiVersion = '0.0.14'
junitVersion = '4.13'
okhttpVersion = '3.12.0'
truthVersion = '1.0.1'
}

dependencies {
implementation "com.google.tsunami:tsunami-common:${tsunamiVersion}"
implementation "com.google.tsunami:tsunami-plugin:${tsunamiVersion}"
implementation "com.google.tsunami:tsunami-proto:${tsunamiVersion}"

implementation "javax.inject:javax.inject:${javaxInjectVersion}"
implementation "com.beust:jcommander:${jcommanderVersion}"
implementation "com.squareup.okhttp3:okhttp:${okhttpVersion}"

testImplementation "junit:junit:${junitVersion}"
testImplementation "org.mockito:mockito-core:${mockitoVersion}"
testImplementation "com.google.truth:truth:${truthVersion}"
testImplementation "com.google.truth.extensions:truth-java8-extension:${truthVersion}"
testImplementation "com.google.truth.extensions:truth-proto-extension:${truthVersion}"
testImplementation "com.squareup.okhttp3:mockwebserver:${okhttpVersion}"

testImplementation "junit:junit:${junitVersion}"
testImplementation "com.google.guava:guava-testlib:${guavaVersion}"
testImplementation "com.google.inject.extensions:guice-testlib:${guiceVersion}"
testImplementation "com.google.truth:truth:${truthVersion}"
testImplementation "com.google.truth.extensions:truth-java8-extension:${truthVersion}"
testImplementation "com.google.truth.extensions:truth-proto-extension:${truthVersion}"
testImplementation "com.squareup.okhttp3:mockwebserver:${okhttpVersion}"
}

// Generate model.zip file and include it in the jar file.
task createModelsZip(type: Zip) {
from 'src/main/resources/model'
into '/'
destinationDirectory = file("$buildDir/resources/main")
archiveFileName = 'model.mar'
}

processResources.dependsOn createModelsZip
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-bin.zip
networkTimeout=10000
validateDistributionUrl=true
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
Loading

0 comments on commit f9e2da1

Please sign in to comment.