How to use AWS SAML Authentication in Java

Wu Jiaojiao & Li Yuansan
Lazy Tech Leader Series
3 min read · Jul 13, 2022


Suppose you work at a company and want to move data programmatically from on-prem systems to a public cloud service such as AWS S3. Before writing any code, let us first understand how the infrastructure is set up for connectivity.

Network Accessibility

In most cases, the on-prem network is connected to the AWS public cloud as shown in the diagram below. From your organization's perspective, the customer network represents the on-prem network.

For your code to connect to services running in, or routed through, an AWS VPC, the AWS client needs to know the "customer router/firewall" information. This is normally exposed as a "proxy server", and it is the starting point for reaching the AWS network.
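The article's Java code configures the proxy directly on the AWS client (through `ClientConfiguration` in Step 2). For other JDK HTTP code in the same program, a minimal sketch is to set the standard JDK networking properties once at startup, before the first connection is opened. `ProxySettings` is a hypothetical helper name, and the host, port, and bypass list are placeholders. Whether a particular AWS SDK version also reads these properties is not assumed here.

```java
/**
 * Hypothetical helper: routes the JVM's default HTTP(S) traffic through a corporate proxy
 * by setting the standard JDK networking system properties.
 */
public final class ProxySettings {

    private ProxySettings() {
        // static utility, no instances
    }

    /**
     * @param host          proxy host, e.g. "proxy.example.net" (placeholder)
     * @param port          proxy port, e.g. 8443 (placeholder)
     * @param nonProxyHosts "|"-separated hosts that bypass the proxy; "*" wildcards are allowed
     */
    public static void useCorporateProxy(String host, int port, String nonProxyHosts) {
        String portText = Integer.toString(port);
        // Used for http:// URLs.
        System.setProperty("http.proxyHost", host);
        System.setProperty("http.proxyPort", portText);
        // Used for https:// URLs (for example, AWS service endpoints).
        System.setProperty("https.proxyHost", host);
        System.setProperty("https.proxyPort", portText);
        // HTTPS shares the HTTP bypass list; there is no separate https.nonProxyHosts property.
        System.setProperty("http.nonProxyHosts", nonProxyHosts);
    }
}
```

For example, `ProxySettings.useCorporateProxy("proxy.example.net", 8443, "localhost|*.example.com")` would send AWS-bound traffic through the proxy while letting calls to an on-prem host such as `adfs.example.com` go direct.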

Authentication

If the information you need is publicly available and no authentication is required (e.g. a plain HTTP connection), you should be able to establish the connection with the curl command and a proxy setting.
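As a rough Java counterpart of `curl -x http://proxy.example.net:8443 <url>`, the sketch below opens an unauthenticated `HttpURLConnection` through an explicit `java.net.Proxy`. The class name `ProxiedGet` and the proxy address are hypothetical placeholders.

```java
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.URI;

/** Hypothetical helper: the Java counterpart of a proxied, unauthenticated curl request. */
public final class ProxiedGet {

    private ProxiedGet() {
    }

    /** Describes an HTTP proxy; the host name is resolved only when a connection is opened. */
    public static Proxy httpProxy(String host, int port) {
        return new Proxy(Proxy.Type.HTTP, InetSocketAddress.createUnresolved(host, port));
    }

    /** Sends a GET through the given proxy and returns the HTTP status code. */
    public static int get(String url, Proxy proxy) throws IOException {
        HttpURLConnection conn = (HttpURLConnection) URI.create(url).toURL().openConnection(proxy);
        conn.setRequestMethod("GET");
        conn.setConnectTimeout(10_000); // milliseconds
        conn.setReadTimeout(10_000);
        try {
            return conn.getResponseCode();
        } finally {
            conn.disconnect();
        }
    }
}
```

Passing the `Proxy` per connection, rather than setting JVM-wide properties, keeps the proxy choice local to the call that needs it.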

However, what if you want to connect to an AWS service such as S3, where a specific role is required to read and write? How does authentication work in this case?

The diagram below shows one of the commonly used authentication flows. The Enterprise section represents the on-prem identity provider; most companies use Windows Active Directory (AD).

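Whether you use a browser, the AWS CLI, or the AWS Java client to connect to AWS, there are two main steps:

  • Authenticate against your AD.
  • Get the SAML assertion from your identity provider (ADFS in front of AD, in the code below) and send it to AWS STS, which returns temporary credentials for the requested role.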
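The second step needs two ARNs for the STS `AssumeRoleWithSAML` call: the role to assume and the SAML provider (the principal). Both come from the assertion's `https://aws.amazon.com/SAML/Attributes/Role` attribute, where each value is a comma-separated role/provider pair. The order of the pair depends on how the identity provider's claim rule is written, so this minimal sketch classifies each ARN by its `:role/` or `:saml-provider/` marker instead of by position. `RolePair` is a hypothetical name, and the ARNs in the usage note are made up.

```java
import java.util.Optional;

/** Hypothetical value type for one value of the AWS "Role" SAML attribute. */
public final class RolePair {

    public final String roleArn;      // e.g. arn:aws:iam::<account>:role/<name>
    public final String principalArn; // e.g. arn:aws:iam::<account>:saml-provider/<name>

    private RolePair(String roleArn, String principalArn) {
        this.roleArn = roleArn;
        this.principalArn = principalArn;
    }

    /** Parses "roleArn,providerArn" in either order; empty if the value is not such a pair. */
    public static Optional<RolePair> parse(String attributeValue) {
        String[] parts = attributeValue.trim().split(",");
        if (parts.length != 2) {
            return Optional.empty();
        }
        String role = null;
        String provider = null;
        for (String part : parts) {
            String arn = part.trim();
            if (arn.contains(":saml-provider/")) {
                provider = arn;
            } else if (arn.contains(":role/")) {
                role = arn;
            }
        }
        if (role == null || provider == null) {
            return Optional.empty();
        }
        return Optional.of(new RolePair(role, provider));
    }
}
```

For instance, `RolePair.parse("arn:aws:iam::123456789012:saml-provider/ADFS,arn:aws:iam::123456789012:role/ADFS-Dev")` yields the provider ARN as `principalArn` even though it comes first in the value.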

Code

With the network and the authentication flow understood, it is time to write the code:

  • Step 1 — prepare the pom.xml
```xml
<dependency>
    <groupId>com.amazonaws</groupId>
    <artifactId>aws-java-sdk-iam</artifactId>
    <version>1.11.444</version>
</dependency>
<dependency>
    <groupId>com.amazonaws</groupId>
    <artifactId>aws-java-sdk-sts</artifactId>
    <version>1.11.444</version>
</dependency>
<dependency>
    <groupId>com.amazonaws</groupId>
    <artifactId>aws-java-sdk-ec2</artifactId>
    <version>1.11.444</version>
</dependency>
<dependency>
    <groupId>com.amazonaws</groupId>
    <artifactId>aws-java-sdk-s3</artifactId>
    <version>1.11.444</version>
</dependency>
<dependency>
    <groupId>org.jsoup</groupId>
    <artifactId>jsoup</artifactId>
    <version>1.9.2</version>
</dependency>
```
  • Step 2 — Example code to get credentials:
```java
import com.amazonaws.ClientConfiguration;
import com.amazonaws.Protocol;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.AnonymousAWSCredentials;
import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.AssumeRoleWithSAMLRequest;
import com.amazonaws.services.securitytoken.model.AssumeRoleWithSAMLResult;
import com.amazonaws.services.securitytoken.model.Credentials;
import org.jsoup.Connection;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import org.jsoup.nodes.Element;
import org.jsoup.parser.Parser;
import org.jsoup.select.Elements;

import java.io.IOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Map;

public class AWSClientWithADAuth {

    // IdP-initiated ADFS sign-on page for the AWS relying party (placeholder host).
    public static final String ADFS_URL = "https://adfs.example.com/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp=urn:amazon:webservices";
    public static final String ADFS_HOST = "adfs.example.com";

    public BasicSessionCredentials authorizeWithSamlAssertion(String region, String roleARN) throws Exception {
        // AssumeRoleWithSAML is authorized by the SAML assertion itself, so the STS call is made
        // with anonymous (unsigned) credentials. Traffic to AWS goes through the corporate proxy.
        AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard()
                .withRegion(region)
                .withCredentials(new AWSStaticCredentialsProvider(new AnonymousAWSCredentials()))
                .withClientConfiguration(new ClientConfiguration()
                        .withProxyHost("proxy.example.net")
                        .withProxyPort(8443)
                        .withProtocol(Protocol.HTTPS))
                .build();

        // Placeholders: load the AD domain, user name and password from a secure source instead.
        String samlAssertion = getSamlAssertion("AD", "username", "password");

        if (samlAssertion != null) {
            AssumeRoleWithSAMLRequest request = new AssumeRoleWithSAMLRequest()
                    .withRoleArn(roleARN)
                    .withPrincipalArn(extractPrincipalArn(samlAssertion, roleARN))
                    .withDurationSeconds(3600)
                    .withSAMLAssertion(samlAssertion);

            AssumeRoleWithSAMLResult assumeRoleWithSAMLResult = stsClient.assumeRoleWithSAML(request);
            Credentials credentials = assumeRoleWithSAMLResult.getCredentials();
            return new BasicSessionCredentials(credentials.getAccessKeyId(),
                    credentials.getSecretAccessKey(), credentials.getSessionToken());
        } else {
            throw new Exception("Unauthorized Access");
        }
    }

    private String getSamlAssertion(String domain, String principal, String credential)
            throws IOException {
        // ADFS expects the user name in DOMAIN\user form.
        String userName = domain + "\\" + principal;

        // 1. Load the ADFS sign-on page and keep its cookies for the form POST.
        Connection.Response initialResp = Jsoup.connect(ADFS_URL).execute();
        Document loginForm = initialResp.parse();
        Map<String, String> cookies = initialResp.cookies();
        Element loginElement = loginForm.getElementById("loginForm");

        if (loginElement == null) {
            return null;
        }

        // 2. Post the credentials (with the session cookies) to the form's action URL.
        URL loginPostUrl = new URL("https", ADFS_HOST, loginElement.attr("action"));
        Document samlDoc = Jsoup.connect(loginPostUrl.toString())
                .cookies(cookies)
                .data("UserName", userName)
                .data("Kmsi", "false")
                .data("Password", credential)
                .method(Connection.Method.POST)
                .execute()
                .parse();

        // 3. On success, the hidden "SAMLResponse" input holds the Base64-encoded SAML response.
        Elements samlElements = samlDoc.getElementsByAttributeValue("name", "SAMLResponse");
        if (samlElements.isEmpty()) {
            throw new IOException("Invalid SAML Response.");
        }
        return samlElements.first().val();
    }

    private String extractPrincipalArn(String samlAssertion, String roleARN) throws Exception {
        String assertion = new String(Base64.getDecoder().decode(samlAssertion), StandardCharsets.UTF_8);
        Document assertionDoc = Jsoup.parse(assertion, "", Parser.xmlParser());
        Elements samlAttribute = assertionDoc.getElementsByAttributeValue("name",
                "https://aws.amazon.com/SAML/Attributes/Role");

        for (Element attribute : samlAttribute) {
            for (Element samlRoleAttr : attribute.getElementsByTag("AttributeValue")) {
                // Each value pairs an IAM role ARN with a SAML provider ARN ("arnA,arnB").
                // The order depends on the IdP's claim rule, so select the ARNs by content.
                String principalArn = null;
                boolean roleMatches = false;
                for (String part : samlRoleAttr.text().split(",")) {
                    String arn = part.trim();
                    if (arn.equals(roleARN)) {
                        roleMatches = true;
                    } else if (arn.contains(":saml-provider/")) {
                        principalArn = arn;
                    }
                }
                if (roleMatches && principalArn != null) {
                    return principalArn;
                }
            }
        }
        throw new Exception("You are not authorized to assume role: " + roleARN);
    }
}
```
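The request above asks for credentials valid for 3600 seconds (`withDurationSeconds(3600)`), so a long-running transfer can outlive them and must repeat the ADFS and STS steps. A minimal, library-agnostic sketch of that refresh logic follows: it caches a value together with an expiry instant and fetches a new one once the current time is within a safety margin of expiry. `ExpiringCache` and `Expiring` are hypothetical names, and the time source is injected as a `Supplier<Instant>` so the logic can be tested without waiting.

```java
import java.time.Duration;
import java.time.Instant;
import java.util.function.Supplier;

/**
 * Hypothetical helper: caches a value that has an expiry instant and fetches a fresh one
 * once "now" is within refreshMargin of that expiry (or when nothing is cached yet).
 */
public final class ExpiringCache<T> {

    /** A value paired with the instant after which it must no longer be used. */
    public static final class Expiring<V> {
        final V value;
        final Instant expiresAt;

        public Expiring(V value, Instant expiresAt) {
            this.value = value;
            this.expiresAt = expiresAt;
        }
    }

    private final Supplier<Expiring<T>> fetcher;
    private final Duration refreshMargin;
    private final Supplier<Instant> clock;
    private Expiring<T> current;

    public ExpiringCache(Supplier<Expiring<T>> fetcher, Duration refreshMargin, Supplier<Instant> clock) {
        this.fetcher = fetcher;
        this.refreshMargin = refreshMargin;
        this.clock = clock;
    }

    /** Returns the cached value, re-fetching first if it is missing or about to expire. */
    public synchronized T get() {
        Instant now = clock.get();
        if (current == null || !now.isBefore(current.expiresAt.minus(refreshMargin))) {
            current = fetcher.get();
        }
        return current.value;
    }
}
```

To plug this into the class above, the fetcher would have to return the STS `Credentials` (whose `getExpiration()` returns a `java.util.Date`, convertible with `toInstant()`) instead of only `BasicSessionCredentials`; that change is an assumption, not part of the original code. In production the time source would simply be `Instant::now`.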
