001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.hbase.security.provider;
019
020import java.io.IOException;
021import java.lang.reflect.InvocationTargetException;
022import java.util.HashMap;
023import java.util.Optional;
024import java.util.ServiceLoader;
025import java.util.concurrent.atomic.AtomicReference;
026import java.util.stream.Collectors;
027import org.apache.hadoop.conf.Configuration;
028import org.apache.yetus.audience.InterfaceAudience;
029import org.slf4j.Logger;
030import org.slf4j.LoggerFactory;
031
032@InterfaceAudience.Private
033public final class SaslServerAuthenticationProviders {
034  private static final Logger LOG =
035    LoggerFactory.getLogger(SaslClientAuthenticationProviders.class);
036
037  public static final String EXTRA_PROVIDERS_KEY = "hbase.server.sasl.provider.extras";
038  private static final AtomicReference<SaslServerAuthenticationProviders> holder =
039    new AtomicReference<>();
040
041  private final HashMap<Byte, SaslServerAuthenticationProvider> providers;
042
043  private SaslServerAuthenticationProviders(Configuration conf,
044    HashMap<Byte, SaslServerAuthenticationProvider> providers) {
045    this.providers = providers;
046  }
047
048  /**
049   * Returns the number of registered providers.
050   */
051  public int getNumRegisteredProviders() {
052    return providers.size();
053  }
054
055  /**
056   * Returns a singleton instance of {@link SaslServerAuthenticationProviders}.
057   */
058  public static SaslServerAuthenticationProviders getInstance(Configuration conf) {
059    SaslServerAuthenticationProviders providers = holder.get();
060    if (null == providers) {
061      synchronized (holder) {
062        // Someone else beat us here
063        providers = holder.get();
064        if (null != providers) {
065          return providers;
066        }
067
068        providers = createProviders(conf);
069        holder.set(providers);
070      }
071    }
072    return providers;
073  }
074
075  /**
076   * Removes the cached singleton instance of {@link SaslServerAuthenticationProviders}.
077   */
078  public static void reset() {
079    synchronized (holder) {
080      holder.set(null);
081    }
082  }
083
084  /**
085   * Adds the given provider into the map of providers if a mapping for the auth code does not
086   * already exist in the map.
087   */
088  static void addProviderIfNotExists(SaslServerAuthenticationProvider provider,
089    HashMap<Byte, SaslServerAuthenticationProvider> providers) {
090    final byte newProviderAuthCode = provider.getSaslAuthMethod().getCode();
091    final SaslServerAuthenticationProvider alreadyRegisteredProvider =
092      providers.get(newProviderAuthCode);
093    if (alreadyRegisteredProvider != null) {
094      throw new RuntimeException("Trying to load SaslServerAuthenticationProvider "
095        + provider.getClass() + ", but " + alreadyRegisteredProvider.getClass()
096        + " is already registered with the same auth code");
097    }
098    providers.put(newProviderAuthCode, provider);
099  }
100
101  /**
102   * Adds any providers defined in the configuration.
103   */
104  static void addExtraProviders(Configuration conf,
105    HashMap<Byte, SaslServerAuthenticationProvider> providers) {
106    for (String implName : conf.getStringCollection(EXTRA_PROVIDERS_KEY)) {
107      Class<?> clz;
108      try {
109        clz = Class.forName(implName);
110      } catch (ClassNotFoundException e) {
111        LOG.warn("Failed to find SaslServerAuthenticationProvider class {}", implName, e);
112        continue;
113      }
114
115      if (!SaslServerAuthenticationProvider.class.isAssignableFrom(clz)) {
116        LOG.warn("Server authentication class {} is not an instance of "
117          + "SaslServerAuthenticationProvider", clz);
118        continue;
119      }
120
121      try {
122        SaslServerAuthenticationProvider provider =
123          (SaslServerAuthenticationProvider) clz.getConstructor().newInstance();
124        addProviderIfNotExists(provider, providers);
125      } catch (InstantiationException | IllegalAccessException | NoSuchMethodException
126        | InvocationTargetException e) {
127        LOG.warn("Failed to instantiate {}", clz, e);
128      }
129    }
130  }
131
132  /**
133   * Loads server authentication providers from the classpath and configuration, and then creates
134   * the SaslServerAuthenticationProviders instance.
135   */
136  static SaslServerAuthenticationProviders createProviders(Configuration conf) {
137    ServiceLoader<SaslServerAuthenticationProvider> loader =
138      ServiceLoader.load(SaslServerAuthenticationProvider.class);
139    HashMap<Byte, SaslServerAuthenticationProvider> providers = new HashMap<>();
140    for (SaslServerAuthenticationProvider provider : loader) {
141      addProviderIfNotExists(provider, providers);
142    }
143
144    addExtraProviders(conf, providers);
145
146    if (LOG.isTraceEnabled()) {
147      String loadedProviders = providers.values().stream()
148        .map((provider) -> provider.getClass().getName()).collect(Collectors.joining(", "));
149      if (loadedProviders.isEmpty()) {
150        loadedProviders = "None!";
151      }
152      LOG.trace("Found SaslServerAuthenticationProviders {}", loadedProviders);
153    }
154
155    // Initialize the providers once, before we get into the RPC path.
156    providers.forEach((b, provider) -> {
157      try {
158        // Give them a copy, just to make sure there is no funny-business going on.
159        provider.init(new Configuration(conf));
160      } catch (IOException e) {
161        LOG.error("Failed to initialize {}", provider.getClass(), e);
162        throw new RuntimeException("Failed to initialize " + provider.getClass().getName(), e);
163      }
164    });
165
166    return new SaslServerAuthenticationProviders(conf, providers);
167  }
168
169  /**
170   * Selects the appropriate SaslServerAuthenticationProvider from those available. If there is no
171   * matching provider for the given {@code authByte}, this method will return null.
172   */
173  public SaslServerAuthenticationProvider selectProvider(byte authByte) {
174    return providers.get(Byte.valueOf(authByte));
175  }
176
177  /**
178   * Extracts the SIMPLE authentication provider.
179   */
180  public SaslServerAuthenticationProvider getSimpleProvider() {
181    Optional<SaslServerAuthenticationProvider> opt = providers.values().stream()
182      .filter((p) -> p instanceof SimpleSaslServerAuthenticationProvider).findFirst();
183    if (!opt.isPresent()) {
184      throw new RuntimeException("SIMPLE authentication provider not available when it should be");
185    }
186    return opt.get();
187  }
188}