001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *   http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 *
017 */
018
019package org.apache.commons.compress.utils;
020
021import java.io.File;
022import java.io.IOException;
023import java.nio.ByteBuffer;
024import java.nio.channels.ClosedChannelException;
025import java.nio.channels.NonWritableChannelException;
026import java.nio.channels.SeekableByteChannel;
027import java.nio.file.Files;
028import java.nio.file.StandardOpenOption;
029import java.util.ArrayList;
030import java.util.Arrays;
031import java.util.Collections;
032import java.util.List;
033import java.util.Objects;
034
/**
 * Read-Only Implementation of {@link SeekableByteChannel} that
 * concatenates a collection of other {@link SeekableByteChannel}s.
 *
 * <p>This is a loose port of <a
 * href="https://github.com/frugalmechanic/fm-common/blob/master/jvm/src/main/scala/fm/common/MultiReadOnlySeekableByteChannel.scala">MultiReadOnlySeekableByteChannel</a>
 * by Tim Underwood.</p>
 *
 * <p>{@link #read(ByteBuffer)} and {@link #position(long)} are {@code synchronized}
 * on this instance; the other methods are not.</p>
 *
 * @since 1.19
 */
public class MultiReadOnlySeekableByteChannel implements SeekableByteChannel {

    /** The concatenated channels, in order; an unmodifiable copy of the list given to the constructor. */
    private final List<SeekableByteChannel> channels;
    /** Logical position within the concatenated stream, as reported by {@link #position()}. */
    private long globalPosition;
    /** Index into {@link #channels} of the channel the next read starts from. */
    private int currentChannelIdx;
    /**
     * Set by {@link #close()}. Needed so that {@link #isOpen()} reports {@code false} after closing
     * even when there are no wrapped channels whose state could be consulted.
     */
    private volatile boolean closed;

    /**
     * Concatenates the given channels.
     *
     * <p>The list is copied, so later modifications of {@code channels} by the caller do not
     * affect this instance.</p>
     *
     * @param channels the channels to concatenate
     * @throws NullPointerException if channels is null
     */
    public MultiReadOnlySeekableByteChannel(List<SeekableByteChannel> channels) {
        this.channels = Collections.unmodifiableList(new ArrayList<>(
            Objects.requireNonNull(channels, "channels must not be null")));
    }

    /**
     * Reads bytes into {@code dst}, continuing into subsequent channels until {@code dst} is full
     * or every channel has been exhausted.
     *
     * @param dst the buffer to fill
     * @return the number of bytes read, {@code 0} if {@code dst} has no remaining space, or
     *         {@code -1} if no bytes could be read because all channels are exhausted
     * @throws ClosedChannelException if this channel is closed
     * @throws IOException if a wrapped channel fails to read
     */
    @Override
    public synchronized int read(ByteBuffer dst) throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        if (!dst.hasRemaining()) {
            return 0;
        }

        int totalBytesRead = 0;
        while (dst.hasRemaining() && currentChannelIdx < channels.size()) {
            final SeekableByteChannel currentChannel = channels.get(currentChannelIdx);
            final int newBytesRead = currentChannel.read(dst);
            if (newBytesRead == -1) {
                // EOF for this channel -- advance to next channel idx
                currentChannelIdx += 1;
                continue;
            }
            if (currentChannel.position() >= currentChannel.size()) {
                // We are at the end of the current channel; move on now rather than
                // paying for an extra read that would only return -1.
                currentChannelIdx++;
            }
            totalBytesRead += newBytesRead;
        }
        if (totalBytesRead > 0) {
            globalPosition += totalBytesRead;
            return totalBytesRead;
        }
        return -1;
    }

    /**
     * Closes every wrapped channel, attempting all of them even if some fail.
     *
     * <p>After this method has been called, {@link #isOpen()} returns {@code false}.</p>
     *
     * @throws IOException wrapping the first close failure as its cause; any further failures are
     *         attached to it as suppressed exceptions
     */
    @Override
    public void close() throws IOException {
        closed = true;
        IOException failure = null;
        for (final SeekableByteChannel ch : channels) {
            try {
                ch.close();
            } catch (final IOException ex) {
                if (failure == null) {
                    failure = new IOException("failed to close wrapped channel", ex);
                } else {
                    // Keep later failures visible instead of silently discarding them.
                    failure.addSuppressed(ex);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Tells whether this channel is open.
     *
     * @return {@code false} once {@link #close()} has been called or if any wrapped channel is
     *         closed, {@code true} otherwise
     */
    @Override
    public boolean isOpen() {
        if (closed) {
            return false;
        }
        for (final SeekableByteChannel ch : channels) {
            if (!ch.isOpen()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the logical position within the concatenated stream.
     *
     * @return the current position
     */
    @Override
    public long position() {
        return globalPosition;
    }

    /**
     * Returns the sum of the sizes of all wrapped channels.
     *
     * @return the total size of the concatenation
     * @throws IOException if querying a wrapped channel's size fails
     */
    @Override
    public long size() throws IOException {
        long acc = 0;
        for (final SeekableByteChannel ch : channels) {
            acc += ch.size();
        }
        return acc;
    }

    /**
     * @throws NonWritableChannelException since this implementation is read-only.
     */
    @Override
    public SeekableByteChannel truncate(long size) {
        throw new NonWritableChannelException();
    }

    /**
     * @throws NonWritableChannelException since this implementation is read-only.
     */
    @Override
    public int write(ByteBuffer src) {
        throw new NonWritableChannelException();
    }

    /**
     * Sets the logical position within the concatenated stream.
     *
     * <p>Every wrapped channel is repositioned. Channels wholly before {@code newPosition} are
     * moved to their end, the channel containing it is moved to the matching offset, and all
     * later channels are reset to 0. If {@code newPosition} lies beyond the total size, every
     * channel is moved to its end.</p>
     *
     * @param newPosition the new position, must be non-negative
     * @return this channel
     * @throws IllegalArgumentException if {@code newPosition} is negative
     * @throws ClosedChannelException if this channel is closed
     * @throws IOException if a wrapped channel cannot be queried or repositioned
     */
    @Override
    public synchronized SeekableByteChannel position(long newPosition) throws IOException {
        if (newPosition < 0) {
            throw new IllegalArgumentException("Negative position: " + newPosition);
        }
        if (!isOpen()) {
            throw new ClosedChannelException();
        }

        globalPosition = newPosition;

        // Offset still to be located, relative to the start of channel i; -1 once located.
        long pos = newPosition;

        for (int i = 0; i < channels.size(); i++) {
            final SeekableByteChannel currentChannel = channels.get(i);
            final long size = currentChannel.size();

            final long newChannelPos;
            if (pos == -1L) {
                // Position is already set for the correct channel,
                // the rest of the channels get reset to 0
                newChannelPos = 0;
            } else if (pos <= size) {
                // This channel is where we want to be
                currentChannelIdx = i;
                final long tmp = pos;
                pos = -1L; // Mark pos as already being set
                newChannelPos = tmp;
            } else {
                // newPosition is past this channel. Set channel
                // position to the end and subtract channel size from
                // pos
                pos -= size;
                newChannelPos = size;
            }

            currentChannel.position(newChannelPos);
        }
        return this;
    }

    /**
     * Concatenates the given channels.
     *
     * <p>If exactly one channel is given it is returned as-is, not wrapped, so it is not
     * guaranteed to be read-only.</p>
     *
     * @param channels the channels to concatenate
     * @throws NullPointerException if channels is null
     * @return SeekableByteChannel that concatenates all provided channels
     */
    public static SeekableByteChannel forSeekableByteChannels(SeekableByteChannel... channels) {
        if (Objects.requireNonNull(channels, "channels must not be null").length == 1) {
            return channels[0];
        }
        return new MultiReadOnlySeekableByteChannel(Arrays.asList(channels));
    }

    /**
     * Concatenates the given files.
     *
     * <p>If exactly one file is given, its channel is returned directly. If opening any file
     * fails, the channels already opened for the preceding files are closed before the exception
     * propagates.</p>
     *
     * @param files the files to concatenate
     * @throws NullPointerException if files is null
     * @throws IOException if opening a channel for one of the files fails
     * @return SeekableByteChannel that concatenates all provided files
     */
    public static SeekableByteChannel forFiles(File... files) throws IOException {
        final List<SeekableByteChannel> channels = new ArrayList<>();
        try {
            for (final File f : Objects.requireNonNull(files, "files must not be null")) {
                channels.add(Files.newByteChannel(f.toPath(), StandardOpenOption.READ));
            }
        } catch (final IOException | RuntimeException ex) {
            // Without this, channels opened for earlier files would leak.
            closeAfterFailure(channels, ex);
            throw ex;
        }
        if (channels.size() == 1) {
            return channels.get(0);
        }
        return new MultiReadOnlySeekableByteChannel(channels);
    }

    /**
     * Closes every channel in {@code channels}, recording close failures as suppressed
     * exceptions of {@code primary} so the original failure remains the one reported.
     */
    private static void closeAfterFailure(final List<SeekableByteChannel> channels, final Throwable primary) {
        for (final SeekableByteChannel ch : channels) {
            try {
                ch.close();
            } catch (final IOException suppressed) {
                primary.addSuppressed(suppressed);
            }
        }
    }

}