001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.hbase.io.compress.zstd;
019
020import com.github.luben.zstd.ZstdDecompressCtx;
021import com.github.luben.zstd.ZstdDictDecompress;
022import edu.umd.cs.findbugs.annotations.Nullable;
023import java.io.IOException;
024import java.nio.ByteBuffer;
025import org.apache.hadoop.conf.Configuration;
026import org.apache.hadoop.hbase.io.compress.BlockDecompressorHelper;
027import org.apache.hadoop.hbase.io.compress.ByteBuffDecompressor;
028import org.apache.hadoop.hbase.io.compress.CanReinit;
029import org.apache.hadoop.hbase.nio.ByteBuff;
030import org.apache.hadoop.hbase.nio.SingleByteBuff;
031import org.apache.yetus.audience.InterfaceAudience;
032
033/**
034 * Glue for ByteBuffDecompressor on top of zstd-jni
035 */
036@InterfaceAudience.Private
037public class ZstdByteBuffDecompressor implements ByteBuffDecompressor, CanReinit {
038
039  protected int dictId;
040  @Nullable
041  protected ZstdDictDecompress dict;
042  protected ZstdDecompressCtx ctx;
043  // Intended to be set to false by some unit tests
044  private boolean allowByteBuffDecompression;
045
046  ZstdByteBuffDecompressor(@Nullable byte[] dictionary) {
047    ctx = new ZstdDecompressCtx();
048    if (dictionary != null) {
049      this.dictId = ZstdCodec.getDictionaryId(dictionary);
050      this.dict = new ZstdDictDecompress(dictionary);
051      this.ctx.loadDict(this.dict);
052    }
053    allowByteBuffDecompression = true;
054  }
055
056  @Override
057  public boolean canDecompress(ByteBuff output, ByteBuff input) {
058    if (!allowByteBuffDecompression) {
059      return false;
060    }
061    if (output instanceof SingleByteBuff && input instanceof SingleByteBuff) {
062      ByteBuffer nioOutput = output.nioByteBuffers()[0];
063      ByteBuffer nioInput = input.nioByteBuffers()[0];
064      if (nioOutput.isDirect() && nioInput.isDirect()) {
065        return true;
066      } else if (!nioOutput.isDirect() && !nioInput.isDirect()) {
067        return true;
068      }
069    }
070
071    return false;
072  }
073
074  @Override
075  public int decompress(ByteBuff output, ByteBuff input, int inputLen) throws IOException {
076    return BlockDecompressorHelper.decompress(output, input, inputLen, this::decompressRaw);
077  }
078
079  private int decompressRaw(ByteBuff output, ByteBuff input, int inputLen) throws IOException {
080    if (output instanceof SingleByteBuff && input instanceof SingleByteBuff) {
081      ByteBuffer nioOutput = output.nioByteBuffers()[0];
082      ByteBuffer nioInput = input.nioByteBuffers()[0];
083      if (nioOutput.isDirect() && nioInput.isDirect()) {
084        return decompressDirectByteBuffers(nioOutput, nioInput, inputLen);
085      } else if (!nioOutput.isDirect() && !nioInput.isDirect()) {
086        return decompressHeapByteBuffers(nioOutput, nioInput, inputLen);
087      }
088    }
089
090    throw new IllegalStateException("One buffer is direct and the other is not, "
091      + "or one or more not SingleByteBuffs. This is not supported");
092  }
093
094  private int decompressDirectByteBuffers(ByteBuffer output, ByteBuffer input, int inputLen) {
095    int origOutputPos = output.position();
096
097    int n = ctx.decompressDirectByteBuffer(output, output.position(),
098      output.limit() - output.position(), input, input.position(), inputLen);
099
100    output.position(origOutputPos + n);
101    return n;
102  }
103
104  private int decompressHeapByteBuffers(ByteBuffer output, ByteBuffer input, int inputLen) {
105    int origOutputPos = output.position();
106
107    int n = ctx.decompressByteArray(output.array(), output.arrayOffset() + output.position(),
108      output.limit() - output.position(), input.array(), input.arrayOffset() + input.position(),
109      inputLen);
110
111    output.position(origOutputPos + n);
112    return n;
113  }
114
115  @Override
116  public void close() {
117    ctx.close();
118    if (dict != null) {
119      dict.close();
120    }
121  }
122
123  @Override
124  public void reinit(Configuration conf) {
125    if (conf != null) {
126      // Dictionary may have changed
127      byte[] b = ZstdCodec.getDictionary(conf);
128      if (b != null) {
129        // Don't casually create dictionary objects; they consume native memory
130        int thisDictId = ZstdCodec.getDictionaryId(b);
131        if (dict == null || dictId != thisDictId) {
132          dictId = thisDictId;
133          ZstdDictDecompress oldDict = dict;
134          dict = new ZstdDictDecompress(b);
135          ctx.loadDict(dict);
136          if (oldDict != null) {
137            oldDict.close();
138          }
139        }
140      } else {
141        ZstdDictDecompress oldDict = dict;
142        dict = null;
143        dictId = 0;
144        // loadDict((byte[]) accepts null to clear the dictionary
145        ctx.loadDict((byte[]) null);
146        if (oldDict != null) {
147          oldDict.close();
148        }
149      }
150
151      // unit test helper
152      this.allowByteBuffDecompression =
153        conf.getBoolean("hbase.io.compress.zstd.allowByteBuffDecompression", true);
154    }
155  }
156}