001/*
002 * CDDL HEADER START
003 *
004 * The contents of this file are subject to the terms of the
005 * Common Development and Distribution License, Version 1.0 only
006 * (the "License").  You may not use this file except in compliance
007 * with the License.
008 *
009 * You can obtain a copy of the license at legal-notices/CDDLv1_0.txt
010 * or http://forgerock.org/license/CDDLv1.0.html.
011 * See the License for the specific language governing permissions
012 * and limitations under the License.
013 *
014 * When distributing Covered Code, include this CDDL HEADER in each
015 * file and include the License file at legal-notices/CDDLv1_0.txt.
016 * If applicable, add the following below this CDDL HEADER, with the
017 * fields enclosed by brackets "[]" replaced with your own identifying
018 * information:
019 *      Portions Copyright [yyyy] [name of copyright owner]
020 *
021 * CDDL HEADER END
022 *
023 *
024 *      Copyright 2008 Sun Microsystems, Inc.
025 *      Portions Copyright 2015 ForgeRock AS.
026 */
027
028package org.opends.admin.ads.util;
029
import java.io.IOException;
import java.net.InetAddress;
import java.net.Socket;
import java.security.GeneralSecurityException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.WeakHashMap;

import javax.net.SocketFactory;
import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLKeyException;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
044
045/**
046 * An implementation of SSLSocketFactory.
047 */
048public class TrustedSocketFactory extends SSLSocketFactory
049{
050  private static Map<Thread, TrustManager> hmTrustManager = new HashMap<>();
051  private static Map<Thread, KeyManager> hmKeyManager = new HashMap<>();
052
053  private static Map<TrustManager, SocketFactory> hmDefaultFactoryTm = new HashMap<>();
054  private static Map<KeyManager, SocketFactory> hmDefaultFactoryKm = new HashMap<>();
055
056  private SSLSocketFactory innerFactory;
057  private TrustManager trustManager;
058  private KeyManager   keyManager;
059
060  /**
061   * Constructor of the TrustedSocketFactory.
062   * @param trustManager the trust manager to use.
063   * @param keyManager   the key manager to use.
064   */
065  public TrustedSocketFactory(TrustManager trustManager, KeyManager keyManager)
066  {
067    this.trustManager = trustManager;
068    this.keyManager   = keyManager;
069  }
070
071  /**
072   * Sets the provided trust and key manager for the operations in the
073   * current thread.
074   *
075   * @param trustManager
076   *          the trust manager to use.
077   * @param keyManager
078   *          the key manager to use.
079   */
080  public static synchronized void setCurrentThreadTrustManager(
081      TrustManager trustManager, KeyManager keyManager)
082  {
083    setThreadTrustManager(trustManager, Thread.currentThread());
084    setThreadKeyManager  (keyManager, Thread.currentThread());
085  }
086
087  /**
088   * Sets the provided trust manager for the operations in the provided thread.
089   * @param trustManager the trust manager to use.
090   * @param thread the thread where we want to use the provided trust manager.
091   */
092  public static synchronized void setThreadTrustManager(
093      TrustManager trustManager, Thread thread)
094  {
095    TrustManager currentTrustManager = hmTrustManager.get(thread);
096    if (currentTrustManager != null) {
097      hmDefaultFactoryTm.remove(currentTrustManager);
098      hmTrustManager.remove(thread);
099    }
100    if (trustManager != null) {
101      hmTrustManager.put(thread, trustManager);
102    }
103  }
104
105  /**
106   * Sets the provided key manager for the operations in the provided thread.
107   * @param keyManager the key manager to use.
108   * @param thread the thread where we want to use the provided key manager.
109   */
110  public static synchronized void setThreadKeyManager(
111      KeyManager keyManager, Thread thread)
112  {
113    KeyManager currentKeyManager = hmKeyManager.get(thread);
114    if (currentKeyManager != null) {
115      hmDefaultFactoryKm.remove(currentKeyManager);
116      hmKeyManager.remove(thread);
117    }
118    if (keyManager != null) {
119      hmKeyManager.put(thread, keyManager);
120    }
121  }
122
123  //
124  // SocketFactory implementation
125  //
126  /**
127   * Returns the default SSL socket factory. The default
128   * implementation can be changed by setting the value of the
129   * "ssl.SocketFactory.provider" security property (in the Java
130   * security properties file) to the desired class. If SSL has not
131   * been configured properly for this virtual machine, the factory
132   * will be inoperative (reporting instantiation exceptions).
133   *
134   * @return the default SocketFactory
135   */
136  public static synchronized SocketFactory getDefault()
137  {
138    Thread currentThread = Thread.currentThread();
139    TrustManager trustManager = hmTrustManager.get(currentThread);
140    KeyManager   keyManager   = hmKeyManager.get(currentThread);
141    SocketFactory result;
142
143    if (trustManager == null)
144    {
145      if (keyManager == null)
146      {
147        result = new TrustedSocketFactory(null,null);
148      }
149      else
150      {
151        result = hmDefaultFactoryKm.get(keyManager);
152        if (result == null)
153        {
154          result = new TrustedSocketFactory(null,keyManager);
155          hmDefaultFactoryKm.put(keyManager, result);
156        }
157      }
158    }
159    else
160    {
161      if (keyManager == null)
162      {
163        result = hmDefaultFactoryTm.get(trustManager);
164        if (result == null)
165        {
166          result = new TrustedSocketFactory(trustManager, null);
167          hmDefaultFactoryTm.put(trustManager, result);
168        }
169      }
170      else
171      {
172        SocketFactory tmsf = hmDefaultFactoryTm.get(trustManager);
173        SocketFactory kmsf = hmDefaultFactoryKm.get(keyManager);
174        if ( tmsf == null || kmsf == null)
175        {
176          result = new TrustedSocketFactory(trustManager, keyManager);
177          hmDefaultFactoryTm.put(trustManager, result);
178          hmDefaultFactoryKm.put(keyManager, result);
179        }
180        else
181        if ( !tmsf.equals(kmsf) )
182        {
183          result = new TrustedSocketFactory(trustManager, keyManager);
184          hmDefaultFactoryTm.put(trustManager, result);
185          hmDefaultFactoryKm.put(keyManager, result);
186        }
187        else
188        {
189          result = tmsf ;
190        }
191      }
192    }
193
194    return result;
195  }
196
197  /** {@inheritDoc} */
198  public Socket createSocket(InetAddress address, int port) throws IOException {
199    return getInnerFactory().createSocket(address, port);
200  }
201
202  /** {@inheritDoc} */
203  public Socket createSocket(InetAddress address, int port,
204      InetAddress clientAddress, int clientPort) throws IOException
205  {
206    return getInnerFactory().createSocket(address, port, clientAddress,
207        clientPort);
208  }
209
210  /** {@inheritDoc} */
211  public Socket createSocket(String host, int port) throws IOException
212  {
213    return getInnerFactory().createSocket(host, port);
214  }
215
216  /** {@inheritDoc} */
217  public Socket createSocket(String host, int port, InetAddress clientHost,
218      int clientPort) throws IOException
219  {
220    return getInnerFactory().createSocket(host, port, clientHost, clientPort);
221  }
222
223  /** {@inheritDoc} */
224  public Socket createSocket(Socket s, String host, int port, boolean autoClose)
225  throws IOException
226  {
227    return getInnerFactory().createSocket(s, host, port, autoClose);
228  }
229
230  /** {@inheritDoc} */
231  public String[] getDefaultCipherSuites()
232  {
233    try
234    {
235      return getInnerFactory().getDefaultCipherSuites();
236    }
237    catch(IOException x)
238    {
239      return new String[0];
240    }
241  }
242
243  /** {@inheritDoc} */
244  public String[] getSupportedCipherSuites()
245  {
246    try
247    {
248      return getInnerFactory().getSupportedCipherSuites();
249    }
250    catch(IOException x)
251    {
252      return new String[0];
253    }
254  }
255
256  private SSLSocketFactory getInnerFactory() throws IOException {
257    if (innerFactory == null)
258    {
259      String algorithm = "TLSv1";
260      SSLKeyException xx;
261      KeyManager[] km = null;
262      TrustManager[] tm = null;
263
264      try {
265        SSLContext sslCtx = SSLContext.getInstance(algorithm);
266        if (trustManager != null)
267        {
268          tm = new TrustManager[] { trustManager };
269        }
270        if (keyManager != null)
271        {
272          km = new KeyManager[] { keyManager };
273        }
274        sslCtx.init(km, tm, new java.security.SecureRandom() );
275        innerFactory = sslCtx.getSocketFactory();
276      }
277      catch(GeneralSecurityException x) {
278        xx = new SSLKeyException("Failed to create SSLContext for " +
279            algorithm);
280        xx.initCause(x);
281        throw xx;
282      }
283    }
284    return innerFactory;
285  }
286}
287