1use crate::{
21 BoxStream, CopyOptions, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload,
22 ObjectMeta, ObjectStore, Path, PutMultipartOptions, PutOptions, PutPayload, PutResult,
23 RenameOptions, Result, StreamExt, UploadPart,
24};
25use async_trait::async_trait;
26use bytes::Bytes;
27use futures_util::{FutureExt, Stream};
28use std::ops::Range;
29use std::pin::Pin;
30use std::sync::Arc;
31use std::task::{Context, Poll};
32use tokio::sync::{OwnedSemaphorePermit, Semaphore};
33
/// Store wrapper that limits the number of concurrent requests made to the
/// wrapped [`ObjectStore`].
///
/// Each limited operation acquires a permit from `semaphore`, which is created
/// with `max_requests` permits.
#[derive(Debug)]
pub struct LimitStore<T: ObjectStore> {
    /// The wrapped store. Held in an `Arc` so that stream-returning operations can
    /// move a handle to it into `'static` closures.
    inner: Arc<T>,
    /// Number of permits the semaphore was created with; also shown by `Display`.
    max_requests: usize,
    /// Shared permit pool; also handed to every `LimitUpload` this store creates.
    semaphore: Arc<Semaphore>,
}
52
53impl<T: ObjectStore> LimitStore<T> {
54 pub fn new(inner: T, max_requests: usize) -> Self {
58 Self {
59 inner: Arc::new(inner),
60 max_requests,
61 semaphore: Arc::new(Semaphore::new(max_requests)),
62 }
63 }
64}
65
66impl<T: ObjectStore> std::fmt::Display for LimitStore<T> {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 write!(f, "LimitStore({}, {})", self.max_requests, self.inner)
69 }
70}
71
72#[async_trait]
73#[deny(clippy::missing_trait_methods)]
74impl<T: ObjectStore> ObjectStore for LimitStore<T> {
75 async fn put_opts(
76 &self,
77 location: &Path,
78 payload: PutPayload,
79 opts: PutOptions,
80 ) -> Result<PutResult> {
81 let _permit = self.semaphore.acquire().await.unwrap();
82 self.inner.put_opts(location, payload, opts).await
83 }
84
85 async fn put_multipart_opts(
86 &self,
87 location: &Path,
88 opts: PutMultipartOptions,
89 ) -> Result<Box<dyn MultipartUpload>> {
90 let upload = self.inner.put_multipart_opts(location, opts).await?;
91 Ok(Box::new(LimitUpload {
92 semaphore: Arc::clone(&self.semaphore),
93 upload,
94 }))
95 }
96
97 async fn get_opts(&self, location: &Path, options: GetOptions) -> Result<GetResult> {
98 let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap();
99 let r = self.inner.get_opts(location, options).await?;
100 Ok(permit_get_result(r, permit))
101 }
102
103 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> Result<Vec<Bytes>> {
104 let _permit = self.semaphore.acquire().await.unwrap();
105 self.inner.get_ranges(location, ranges).await
106 }
107
108 fn delete_stream(
109 &self,
110 locations: BoxStream<'static, Result<Path>>,
111 ) -> BoxStream<'static, Result<Path>> {
112 let inner = Arc::clone(&self.inner);
113 let fut = Arc::clone(&self.semaphore)
114 .acquire_owned()
115 .map(move |permit| {
116 let s = inner.delete_stream(locations);
117 PermitWrapper::new(s, permit.unwrap())
118 });
119 fut.into_stream().flatten().boxed()
120 }
121
122 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result<ObjectMeta>> {
123 let prefix = prefix.cloned();
124 let inner = Arc::clone(&self.inner);
125 let fut = Arc::clone(&self.semaphore)
126 .acquire_owned()
127 .map(move |permit| {
128 let s = inner.list(prefix.as_ref());
129 PermitWrapper::new(s, permit.unwrap())
130 });
131 fut.into_stream().flatten().boxed()
132 }
133
134 fn list_with_offset(
135 &self,
136 prefix: Option<&Path>,
137 offset: &Path,
138 ) -> BoxStream<'static, Result<ObjectMeta>> {
139 let prefix = prefix.cloned();
140 let offset = offset.clone();
141 let inner = Arc::clone(&self.inner);
142 let fut = Arc::clone(&self.semaphore)
143 .acquire_owned()
144 .map(move |permit| {
145 let s = inner.list_with_offset(prefix.as_ref(), &offset);
146 PermitWrapper::new(s, permit.unwrap())
147 });
148 fut.into_stream().flatten().boxed()
149 }
150
151 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result<ListResult> {
152 let _permit = self.semaphore.acquire().await.unwrap();
153 self.inner.list_with_delimiter(prefix).await
154 }
155
156 async fn copy_opts(&self, from: &Path, to: &Path, options: CopyOptions) -> Result<()> {
157 let _permit = self.semaphore.acquire().await.unwrap();
158 self.inner.copy_opts(from, to, options).await
159 }
160
161 async fn rename_opts(&self, from: &Path, to: &Path, options: RenameOptions) -> Result<()> {
162 let _permit = self.semaphore.acquire().await.unwrap();
163 self.inner.rename_opts(from, to, options).await
164 }
165}
166
167fn permit_get_result(r: GetResult, permit: OwnedSemaphorePermit) -> GetResult {
168 let payload = match r.payload {
169 #[cfg(all(feature = "fs", not(target_arch = "wasm32")))]
170 v @ GetResultPayload::File(_, _) => v,
171 GetResultPayload::Stream(s) => {
172 GetResultPayload::Stream(PermitWrapper::new(s, permit).boxed())
173 }
174 };
175 GetResult { payload, ..r }
176}
177
178struct PermitWrapper<T> {
180 inner: T,
181 #[allow(dead_code)]
182 permit: OwnedSemaphorePermit,
183}
184
185impl<T> PermitWrapper<T> {
186 fn new(inner: T, permit: OwnedSemaphorePermit) -> Self {
187 Self { inner, permit }
188 }
189}
190
191impl<T: Stream + Unpin> Stream for PermitWrapper<T> {
192 type Item = T::Item;
193
194 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
195 Pin::new(&mut self.inner).poll_next(cx)
196 }
197
198 fn size_hint(&self) -> (usize, Option<usize>) {
199 self.inner.size_hint()
200 }
201}
202
/// [`MultipartUpload`] wrapper that gates each part upload, as well as
/// `complete` and `abort`, on a permit from a semaphore.
#[derive(Debug)]
pub struct LimitUpload {
    /// The wrapped upload that performs the actual work.
    upload: Box<dyn MultipartUpload>,
    /// Permit pool: either shared with the `LimitStore` that created this upload,
    /// or private to it when built via `LimitUpload::new`.
    semaphore: Arc<Semaphore>,
}
209
210impl LimitUpload {
211 pub fn new(upload: Box<dyn MultipartUpload>, max_concurrency: usize) -> Self {
213 Self {
214 upload,
215 semaphore: Arc::new(Semaphore::new(max_concurrency)),
216 }
217 }
218}
219
220#[async_trait]
221impl MultipartUpload for LimitUpload {
222 fn put_part(&mut self, data: PutPayload) -> UploadPart {
223 let upload = self.upload.put_part(data);
224 let s = Arc::clone(&self.semaphore);
225 Box::pin(async move {
226 let _permit = s.acquire().await.unwrap();
227 upload.await
228 })
229 }
230
231 async fn complete(&mut self) -> Result<PutResult> {
232 let _permit = self.semaphore.acquire().await.unwrap();
233 self.upload.complete().await
234 }
235
236 async fn abort(&mut self) -> Result<()> {
237 let _permit = self.semaphore.acquire().await.unwrap();
238 self.upload.abort().await
239 }
240}
241
#[cfg(test)]
mod tests {
    use crate::ObjectStore;
    use crate::integration::*;
    use crate::limit::LimitStore;
    use crate::memory::InMemory;
    use futures_util::stream::StreamExt;
    use std::pin::Pin;
    use std::time::Duration;
    use tokio::time::timeout;

    #[tokio::test]
    async fn limit_test() {
        const MAX_REQUESTS: usize = 10;
        let integration = LimitStore::new(InMemory::new(), MAX_REQUESTS);

        // The limiter must still pass the generic store conformance checks.
        put_get_delete_list(&integration).await;
        get_opts(&integration).await;
        list_uses_directories_correctly(&integration).await;
        list_with_delimiter(&integration).await;
        rename_and_copy(&integration).await;
        stream_get(&integration).await;

        // List streams only acquire their permit once polled, so peek each one
        // to force acquisition. After this loop every permit is held.
        let mut held = Vec::with_capacity(MAX_REQUESTS);
        for _ in 0..MAX_REQUESTS {
            let mut listing = integration.list(None).peekable();
            Pin::new(&mut listing).peek().await;
            held.push(listing);
        }

        // With no permits left, a fresh listing cannot make progress.
        let blocked = integration.list(None).collect::<Vec<_>>();
        assert!(timeout(Duration::from_millis(20), blocked).await.is_err());

        // Dropping one held stream returns its permit, so a new listing completes.
        drop(held.pop());
        integration.list(None).collect::<Vec<_>>().await;
    }
}