001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.hbase.security;
019
020import java.io.BufferedInputStream;
021import java.io.BufferedOutputStream;
022import java.io.DataInputStream;
023import java.io.DataOutputStream;
024import java.io.FilterInputStream;
025import java.io.FilterOutputStream;
026import java.io.IOException;
027import java.io.InputStream;
028import java.io.OutputStream;
029import java.net.InetAddress;
030import java.nio.ByteBuffer;
031import javax.security.sasl.Sasl;
032import javax.security.sasl.SaslException;
033import org.apache.hadoop.conf.Configuration;
034import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
035import org.apache.hadoop.hbase.ipc.FallbackDisallowedException;
036import org.apache.hadoop.hbase.security.provider.SaslClientAuthenticationProvider;
037import org.apache.hadoop.io.WritableUtils;
038import org.apache.hadoop.ipc.RemoteException;
039import org.apache.hadoop.security.SaslInputStream;
040import org.apache.hadoop.security.SaslOutputStream;
041import org.apache.hadoop.security.token.Token;
042import org.apache.hadoop.security.token.TokenIdentifier;
043import org.apache.yetus.audience.InterfaceAudience;
044import org.slf4j.Logger;
045import org.slf4j.LoggerFactory;
046
047import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos;
048
049/**
050 * A utility class that encapsulates SASL logic for RPC client. Copied from
051 * <code>org.apache.hadoop.security</code>
052 */
053@InterfaceAudience.Private
054public class HBaseSaslRpcClient extends AbstractHBaseSaslRpcClient {
055
056  private static final Logger LOG = LoggerFactory.getLogger(HBaseSaslRpcClient.class);
057  private boolean cryptoAesEnable;
058  private CryptoAES cryptoAES;
059  private InputStream saslInputStream;
060  private InputStream cryptoInputStream;
061  private OutputStream saslOutputStream;
062  private OutputStream cryptoOutputStream;
063  private boolean initStreamForCrypto;
064
065  public HBaseSaslRpcClient(Configuration conf, SaslClientAuthenticationProvider provider,
066    Token<? extends TokenIdentifier> token, InetAddress serverAddr, String servicePrincipal,
067    boolean fallbackAllowed) throws IOException {
068    super(conf, provider, token, serverAddr, servicePrincipal, fallbackAllowed);
069  }
070
071  public HBaseSaslRpcClient(Configuration conf, SaslClientAuthenticationProvider provider,
072    Token<? extends TokenIdentifier> token, InetAddress serverAddr, String servicePrincipal,
073    boolean fallbackAllowed, String rpcProtection, boolean initStreamForCrypto) throws IOException {
074    super(conf, provider, token, serverAddr, servicePrincipal, fallbackAllowed, rpcProtection);
075    this.initStreamForCrypto = initStreamForCrypto;
076  }
077
078  private static void readStatus(DataInputStream inStream) throws IOException {
079    int status = inStream.readInt(); // read status
080    if (status != SaslStatus.SUCCESS.state) {
081      throw new RemoteException(WritableUtils.readString(inStream),
082        WritableUtils.readString(inStream));
083    }
084  }
085
086  /**
087   * Do client side SASL authentication with server via the given InputStream and OutputStream
088   * @param inS  InputStream to use
089   * @param outS OutputStream to use
090   * @return true if connection is set up, or false if needs to switch to simple Auth.
091   */
092  public boolean saslConnect(InputStream inS, OutputStream outS) throws IOException {
093    DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS));
094    DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream(outS));
095
096    try {
097      byte[] saslToken = getInitialResponse();
098      if (saslToken != null) {
099        outStream.writeInt(saslToken.length);
100        outStream.write(saslToken, 0, saslToken.length);
101        outStream.flush();
102        if (LOG.isDebugEnabled()) {
103          LOG.debug("Have sent token of size " + saslToken.length + " from initSASLContext.");
104        }
105      }
106      if (!isComplete()) {
107        readStatus(inStream);
108        int len = inStream.readInt();
109        if (len == SaslUtil.SWITCH_TO_SIMPLE_AUTH) {
110          if (!fallbackAllowed) {
111            throw new FallbackDisallowedException();
112          }
113          LOG.debug("Server asks us to fall back to simple auth.");
114          dispose();
115          return false;
116        }
117        saslToken = new byte[len];
118        if (LOG.isDebugEnabled()) {
119          LOG.debug("Will read input token of size " + saslToken.length
120            + " for processing by initSASLContext");
121        }
122        inStream.readFully(saslToken);
123      }
124
125      while (!isComplete()) {
126        saslToken = evaluateChallenge(saslToken);
127        if (saslToken != null) {
128          if (LOG.isDebugEnabled()) {
129            LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext.");
130          }
131          outStream.writeInt(saslToken.length);
132          outStream.write(saslToken, 0, saslToken.length);
133          outStream.flush();
134        }
135        if (!isComplete()) {
136          readStatus(inStream);
137          saslToken = new byte[inStream.readInt()];
138          if (LOG.isDebugEnabled()) {
139            LOG.debug("Will read input token of size " + saslToken.length
140              + " for processing by initSASLContext");
141          }
142          inStream.readFully(saslToken);
143        }
144      }
145
146      if (LOG.isDebugEnabled()) {
147        LOG.debug("SASL client context established. Negotiated QoP: "
148          + saslClient.getNegotiatedProperty(Sasl.QOP));
149      }
150
151      verifyNegotiatedQop();
152
153      // initial the inputStream, outputStream for both Sasl encryption
154      // and Crypto AES encryption if necessary
155      // if Crypto AES encryption enabled, the saslInputStream/saslOutputStream is
156      // only responsible for connection header negotiation,
157      // cryptoInputStream/cryptoOutputStream is responsible for rpc encryption with Crypto AES
158      saslInputStream = new SaslInputStream(inS, saslClient);
159      saslOutputStream = new SaslOutputStream(outS, saslClient);
160      if (initStreamForCrypto) {
161        cryptoInputStream = new WrappedInputStream(inS);
162        cryptoOutputStream = new WrappedOutputStream(outS);
163      }
164
165      return true;
166    } catch (IOException e) {
167      try {
168        saslClient.dispose();
169      } catch (SaslException ignored) {
170        // ignore further exceptions during cleanup
171      }
172      throw e;
173    }
174  }
175
176  public String getSaslQOP() {
177    return (String) saslClient.getNegotiatedProperty(Sasl.QOP);
178  }
179
180  public void initCryptoCipher(RPCProtos.CryptoCipherMeta cryptoCipherMeta, Configuration conf)
181    throws IOException {
182    // create SaslAES for client
183    cryptoAES = EncryptionUtil.createCryptoAES(cryptoCipherMeta, conf);
184    cryptoAesEnable = true;
185  }
186
187  /**
188   * Get a SASL wrapped InputStream. Can be called only after saslConnect() has been called.
189   * @return a SASL wrapped InputStream
190   */
191  public InputStream getInputStream() throws IOException {
192    if (!saslClient.isComplete()) {
193      throw new IOException("Sasl authentication exchange hasn't completed yet");
194    }
195    // If Crypto AES is enabled, return cryptoInputStream which unwrap the data with Crypto AES.
196    if (cryptoAesEnable && cryptoInputStream != null) {
197      return cryptoInputStream;
198    }
199    return saslInputStream;
200  }
201
202  class WrappedInputStream extends FilterInputStream {
203    private ByteBuffer unwrappedRpcBuffer = ByteBuffer.allocate(0);
204
205    public WrappedInputStream(InputStream in) throws IOException {
206      super(in);
207    }
208
209    @Override
210    public int read() throws IOException {
211      byte[] b = new byte[1];
212      int n = read(b, 0, 1);
213      return (n != -1) ? b[0] : -1;
214    }
215
216    @Override
217    public int read(byte b[]) throws IOException {
218      return read(b, 0, b.length);
219    }
220
221    @Override
222    public synchronized int read(byte[] buf, int off, int len) throws IOException {
223      // fill the buffer with the next RPC message
224      if (unwrappedRpcBuffer.remaining() == 0) {
225        readNextRpcPacket();
226      }
227      // satisfy as much of the request as possible
228      int readLen = Math.min(len, unwrappedRpcBuffer.remaining());
229      unwrappedRpcBuffer.get(buf, off, readLen);
230      return readLen;
231    }
232
233    // unwrap messages with Crypto AES
234    private void readNextRpcPacket() throws IOException {
235      LOG.debug("reading next wrapped RPC packet");
236      DataInputStream dis = new DataInputStream(in);
237      int rpcLen = dis.readInt();
238      byte[] rpcBuf = new byte[rpcLen];
239      dis.readFully(rpcBuf);
240
241      // unwrap with Crypto AES
242      rpcBuf = cryptoAES.unwrap(rpcBuf, 0, rpcBuf.length);
243      if (LOG.isDebugEnabled()) {
244        LOG.debug("unwrapping token of length:" + rpcBuf.length);
245      }
246      unwrappedRpcBuffer = ByteBuffer.wrap(rpcBuf);
247    }
248  }
249
250  /**
251   * Get a SASL wrapped OutputStream. Can be called only after saslConnect() has been called.
252   * @return a SASL wrapped OutputStream
253   */
254  public OutputStream getOutputStream() throws IOException {
255    if (!saslClient.isComplete()) {
256      throw new IOException("Sasl authentication exchange hasn't completed yet");
257    }
258    // If Crypto AES is enabled, return cryptoOutputStream which wrap the data with Crypto AES.
259    if (cryptoAesEnable && cryptoOutputStream != null) {
260      return cryptoOutputStream;
261    }
262    return saslOutputStream;
263  }
264
265  class WrappedOutputStream extends FilterOutputStream {
266    public WrappedOutputStream(OutputStream out) throws IOException {
267      super(out);
268    }
269
270    @Override
271    public void write(byte[] buf, int off, int len) throws IOException {
272      if (LOG.isDebugEnabled()) {
273        LOG.debug("wrapping token of length:" + len);
274      }
275
276      // wrap with Crypto AES
277      byte[] wrapped = cryptoAES.wrap(buf, off, len);
278      DataOutputStream dob = new DataOutputStream(out);
279      dob.writeInt(wrapped.length);
280      dob.write(wrapped, 0, wrapped.length);
281      dob.flush();
282    }
283  }
284}