// object_store/limit.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! An object store that limits the maximum concurrency of the wrapped implementation

20use crate::{
21    BoxStream, CopyOptions, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload,
22    ObjectMeta, ObjectStore, Path, PutMultipartOptions, PutOptions, PutPayload, PutResult,
23    RenameOptions, Result, StreamExt, UploadPart,
24};
25use async_trait::async_trait;
26use bytes::Bytes;
27use futures_util::{FutureExt, Stream};
28use std::ops::Range;
29use std::pin::Pin;
30use std::sync::Arc;
31use std::task::{Context, Poll};
32use tokio::sync::{OwnedSemaphorePermit, Semaphore};
33
/// Store wrapper that wraps an inner store and limits the maximum number of concurrent
/// object store operations. Each call to an [`ObjectStore`] member function is
/// considered a single operation, even if it may result in more than one network call.
///
/// ```
/// # use object_store::memory::InMemory;
/// # use object_store::limit::LimitStore;
///
/// // Create an in-memory `ObjectStore` limited to 20 concurrent requests
/// let store = LimitStore::new(InMemory::new(), 20);
/// ```
#[derive(Debug)]
pub struct LimitStore<T: ObjectStore> {
    /// The wrapped store. Reference-counted so the stream-returning methods can move
    /// a handle to it into the `'static` streams they hand back.
    inner: Arc<T>,
    /// The limit passed to [`LimitStore::new`]; reported by the `Display` impl.
    max_requests: usize,
    /// Created with `max_requests` permits; each limited operation takes one of them.
    semaphore: Arc<Semaphore>,
}
52
53impl<T: ObjectStore> LimitStore<T> {
54    /// Create new limit store that will limit the maximum
55    /// number of outstanding concurrent requests to
56    /// `max_requests`
57    pub fn new(inner: T, max_requests: usize) -> Self {
58        Self {
59            inner: Arc::new(inner),
60            max_requests,
61            semaphore: Arc::new(Semaphore::new(max_requests)),
62        }
63    }
64}
65
66impl<T: ObjectStore> std::fmt::Display for LimitStore<T> {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        write!(f, "LimitStore({}, {})", self.max_requests, self.inner)
69    }
70}
71
72#[async_trait]
73#[deny(clippy::missing_trait_methods)]
74impl<T: ObjectStore> ObjectStore for LimitStore<T> {
75    async fn put_opts(
76        &self,
77        location: &Path,
78        payload: PutPayload,
79        opts: PutOptions,
80    ) -> Result<PutResult> {
81        let _permit = self.semaphore.acquire().await.unwrap();
82        self.inner.put_opts(location, payload, opts).await
83    }
84
85    async fn put_multipart_opts(
86        &self,
87        location: &Path,
88        opts: PutMultipartOptions,
89    ) -> Result<Box<dyn MultipartUpload>> {
90        let upload = self.inner.put_multipart_opts(location, opts).await?;
91        Ok(Box::new(LimitUpload {
92            semaphore: Arc::clone(&self.semaphore),
93            upload,
94        }))
95    }
96
97    async fn get_opts(&self, location: &Path, options: GetOptions) -> Result<GetResult> {
98        let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap();
99        let r = self.inner.get_opts(location, options).await?;
100        Ok(permit_get_result(r, permit))
101    }
102
103    async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> Result<Vec<Bytes>> {
104        let _permit = self.semaphore.acquire().await.unwrap();
105        self.inner.get_ranges(location, ranges).await
106    }
107
108    fn delete_stream(
109        &self,
110        locations: BoxStream<'static, Result<Path>>,
111    ) -> BoxStream<'static, Result<Path>> {
112        let inner = Arc::clone(&self.inner);
113        let fut = Arc::clone(&self.semaphore)
114            .acquire_owned()
115            .map(move |permit| {
116                let s = inner.delete_stream(locations);
117                PermitWrapper::new(s, permit.unwrap())
118            });
119        fut.into_stream().flatten().boxed()
120    }
121
122    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result<ObjectMeta>> {
123        let prefix = prefix.cloned();
124        let inner = Arc::clone(&self.inner);
125        let fut = Arc::clone(&self.semaphore)
126            .acquire_owned()
127            .map(move |permit| {
128                let s = inner.list(prefix.as_ref());
129                PermitWrapper::new(s, permit.unwrap())
130            });
131        fut.into_stream().flatten().boxed()
132    }
133
134    fn list_with_offset(
135        &self,
136        prefix: Option<&Path>,
137        offset: &Path,
138    ) -> BoxStream<'static, Result<ObjectMeta>> {
139        let prefix = prefix.cloned();
140        let offset = offset.clone();
141        let inner = Arc::clone(&self.inner);
142        let fut = Arc::clone(&self.semaphore)
143            .acquire_owned()
144            .map(move |permit| {
145                let s = inner.list_with_offset(prefix.as_ref(), &offset);
146                PermitWrapper::new(s, permit.unwrap())
147            });
148        fut.into_stream().flatten().boxed()
149    }
150
151    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result<ListResult> {
152        let _permit = self.semaphore.acquire().await.unwrap();
153        self.inner.list_with_delimiter(prefix).await
154    }
155
156    async fn copy_opts(&self, from: &Path, to: &Path, options: CopyOptions) -> Result<()> {
157        let _permit = self.semaphore.acquire().await.unwrap();
158        self.inner.copy_opts(from, to, options).await
159    }
160
161    async fn rename_opts(&self, from: &Path, to: &Path, options: RenameOptions) -> Result<()> {
162        let _permit = self.semaphore.acquire().await.unwrap();
163        self.inner.rename_opts(from, to, options).await
164    }
165}
166
167fn permit_get_result(r: GetResult, permit: OwnedSemaphorePermit) -> GetResult {
168    let payload = match r.payload {
169        #[cfg(all(feature = "fs", not(target_arch = "wasm32")))]
170        v @ GetResultPayload::File(_, _) => v,
171        GetResultPayload::Stream(s) => {
172            GetResultPayload::Stream(PermitWrapper::new(s, permit).boxed())
173        }
174    };
175    GetResult { payload, ..r }
176}
177
/// Combines an [`OwnedSemaphorePermit`] with some other type
///
/// The permit is stored alongside `inner`, so it lives exactly as long as the
/// wrapper does and is dropped together with it.
struct PermitWrapper<T> {
    /// The wrapped value; the `Stream` impl delegates to it.
    inner: T,
    // Never read: the field exists only to keep the permit alive until the
    // wrapper is dropped, hence the `dead_code` allowance.
    #[allow(dead_code)]
    permit: OwnedSemaphorePermit,
}
184
185impl<T> PermitWrapper<T> {
186    fn new(inner: T, permit: OwnedSemaphorePermit) -> Self {
187        Self { inner, permit }
188    }
189}
190
191impl<T: Stream + Unpin> Stream for PermitWrapper<T> {
192    type Item = T::Item;
193
194    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
195        Pin::new(&mut self.inner).poll_next(cx)
196    }
197
198    fn size_hint(&self) -> (usize, Option<usize>) {
199        self.inner.size_hint()
200    }
201}
202
/// A [`MultipartUpload`] wrapper that limits the maximum number of concurrent requests
///
/// Each of `put_part`, `complete` and `abort` takes one permit from `semaphore`
/// before the corresponding work on the wrapped upload is awaited.
#[derive(Debug)]
pub struct LimitUpload {
    /// The wrapped upload that performs the actual work.
    upload: Box<dyn MultipartUpload>,
    /// Source of permits; either created by [`LimitUpload::new`] or shared with the
    /// [`LimitStore`] that produced this upload.
    semaphore: Arc<Semaphore>,
}
209
210impl LimitUpload {
211    /// Create a new [`LimitUpload`] limiting `upload` to `max_concurrency` concurrent requests
212    pub fn new(upload: Box<dyn MultipartUpload>, max_concurrency: usize) -> Self {
213        Self {
214            upload,
215            semaphore: Arc::new(Semaphore::new(max_concurrency)),
216        }
217    }
218}
219
220#[async_trait]
221impl MultipartUpload for LimitUpload {
222    fn put_part(&mut self, data: PutPayload) -> UploadPart {
223        let upload = self.upload.put_part(data);
224        let s = Arc::clone(&self.semaphore);
225        Box::pin(async move {
226            let _permit = s.acquire().await.unwrap();
227            upload.await
228        })
229    }
230
231    async fn complete(&mut self) -> Result<PutResult> {
232        let _permit = self.semaphore.acquire().await.unwrap();
233        self.upload.complete().await
234    }
235
236    async fn abort(&mut self) -> Result<()> {
237        let _permit = self.semaphore.acquire().await.unwrap();
238        self.upload.abort().await
239    }
240}
241
#[cfg(test)]
mod tests {
    use crate::ObjectStore;
    use crate::integration::*;
    use crate::limit::LimitStore;
    use crate::memory::InMemory;
    use futures_util::stream::StreamExt;
    use std::pin::Pin;
    use std::time::Duration;
    use tokio::time::timeout;

    /// Runs the shared integration suite through a `LimitStore`, then checks that
    /// holding `max_requests` open list streams blocks a further request until one
    /// of those streams is dropped.
    #[tokio::test]
    async fn limit_test() {
        let max_requests = 10;
        let store = LimitStore::new(InMemory::new(), max_requests);

        put_get_delete_list(&store).await;
        get_opts(&store).await;
        list_uses_directories_correctly(&store).await;
        list_with_delimiter(&store).await;
        rename_and_copy(&store).await;
        stream_get(&store).await;

        // Open `max_requests` listings and keep them alive so each holds a permit.
        let mut held = Vec::with_capacity(max_requests);
        for _ in 0..max_requests {
            let mut listing = store.list(None).peekable();
            // Peeking drives the stream far enough that the semaphore is acquired.
            Pin::new(&mut listing).peek().await;
            held.push(listing);
        }

        // With every permit taken, one more request must not finish in time.
        let blocked = store.list(None).collect::<Vec<_>>();
        let wait = Duration::from_millis(20);
        assert!(timeout(wait, blocked).await.is_err());

        // Dropping one of the held streams frees a permit ...
        drop(held.pop());

        // ... so a new request can now run to completion.
        store.list(None).collect::<Vec<_>>().await;
    }
}