Core Module

gigaspatial.core

io

DataStore

Bases: ABC

Abstract base class defining the interface for data store implementations. This class serves as a parent for both local and cloud-based storage solutions.

Source code in gigaspatial/core/io/data_store.py
class DataStore(ABC):
    """
    Abstract base class defining the interface for data store implementations.
    This class serves as a parent for both local and cloud-based storage solutions.
    """

    @abstractmethod
    def read_file(self, path: str) -> Any:
        """
        Read contents of a file from the data store.

        Args:
            path: Path to the file to read

        Returns:
            Contents of the file

        Raises:
            IOError: If file cannot be read
        """
        pass

    @abstractmethod
    def write_file(self, path: str, data: Any) -> None:
        """
        Write data to a file in the data store.

        Args:
            path: Path where to write the file
            data: Data to write to the file

        Raises:
            IOError: If file cannot be written
        """
        pass

    @abstractmethod
    def file_exists(self, path: str) -> bool:
        """
        Check if a file exists in the data store.

        Args:
            path: Path to check

        Returns:
            True if file exists, False otherwise
        """
        pass

    @abstractmethod
    def list_files(self, path: str) -> List[str]:
        """
        List all files in a directory.

        Args:
            path: Directory path to list

        Returns:
            List of file paths in the directory
        """
        pass

    @abstractmethod
    def walk(self, top: str) -> Generator:
        """
        Walk through directory tree, similar to os.walk().

        Args:
            top: Starting directory for the walk

        Returns:
            Generator yielding tuples of (dirpath, dirnames, filenames)
        """
        pass

    @abstractmethod
    def open(self, file: str, mode: str = "r") -> Union[str, bytes]:
        """
        Context manager for file operations.

        Args:
            file: Path to the file
            mode: File mode ('r', 'w', 'rb', 'wb')

        Yields:
            File-like object

        Raises:
            IOError: If file cannot be opened
        """
        pass

    @abstractmethod
    def is_file(self, path: str) -> bool:
        """
        Check if path points to a file.

        Args:
            path: Path to check

        Returns:
            True if path is a file, False otherwise
        """
        pass

    @abstractmethod
    def is_dir(self, path: str) -> bool:
        """
        Check if path points to a directory.

        Args:
            path: Path to check

        Returns:
            True if path is a directory, False otherwise
        """
        pass

    @abstractmethod
    def remove(self, path: str) -> None:
        """
        Remove a file.

        Args:
            path: Path to the file to remove

        Raises:
            IOError: If file cannot be removed
        """
        pass

    @abstractmethod
    def rmdir(self, dir: str) -> None:
        """
        Remove a directory and all its contents.

        Args:
            dir: Path to the directory to remove

        Raises:
            IOError: If directory cannot be removed
        """
        pass
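
Because every backend implements the same interface, calling code can stay storage-agnostic. Below is a minimal usage sketch; the LocalDataStore import path and its no-argument constructor are assumptions, so adjust them to your installation (any concrete DataStore, such as ADLSDataStore, works the same way).

# Minimal sketch: copy every .txt file under one prefix to another,
# using only the abstract DataStore API.
from gigaspatial.core.io import LocalDataStore  # assumed import path
from gigaspatial.core.io.data_store import DataStore

def backup_text_files(store: DataStore, src_dir: str, dst_dir: str) -> None:
    for path in store.list_files(src_dir):
        if store.is_file(path) and path.endswith(".txt"):
            filename = path.rsplit("/", 1)[-1]
            store.write_file(f"{dst_dir}/{filename}", store.read_file(path))

store = LocalDataStore()  # could equally be an ADLSDataStore instance
backup_text_files(store, "raw/texts", "backup/texts")
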
file_exists(path) abstractmethod

Check if a file exists in the data store.

Parameters:

    path (str, required): Path to check

Returns:

    bool: True if file exists, False otherwise

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def file_exists(self, path: str) -> bool:
    """
    Check if a file exists in the data store.

    Args:
        path: Path to check

    Returns:
        True if file exists, False otherwise
    """
    pass
is_dir(path) abstractmethod

Check if path points to a directory.

Parameters:

    path (str, required): Path to check

Returns:

    bool: True if path is a directory, False otherwise

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def is_dir(self, path: str) -> bool:
    """
    Check if path points to a directory.

    Args:
        path: Path to check

    Returns:
        True if path is a directory, False otherwise
    """
    pass
is_file(path) abstractmethod

Check if path points to a file.

Parameters:

    path (str, required): Path to check

Returns:

    bool: True if path is a file, False otherwise

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def is_file(self, path: str) -> bool:
    """
    Check if path points to a file.

    Args:
        path: Path to check

    Returns:
        True if path is a file, False otherwise
    """
    pass
list_files(path) abstractmethod

List all files in a directory.

Parameters:

    path (str, required): Directory path to list

Returns:

    List[str]: List of file paths in the directory

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def list_files(self, path: str) -> List[str]:
    """
    List all files in a directory.

    Args:
        path: Directory path to list

    Returns:
        List of file paths in the directory
    """
    pass
open(file, mode='r') abstractmethod

Context manager for file operations.

Parameters:

    file (str, required): Path to the file
    mode (str, default 'r'): File mode ('r', 'w', 'rb', 'wb')

Yields:

    Union[str, bytes]: File-like object

Raises:

    IOError: If file cannot be opened

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def open(self, file: str, mode: str = "r") -> Union[str, bytes]:
    """
    Context manager for file operations.

    Args:
        file: Path to the file
        mode: File mode ('r', 'w', 'rb', 'wb')

    Yields:
        File-like object

    Raises:
        IOError: If file cannot be opened
    """
    pass
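
Any concrete implementation can then be used much like the built-in open(). A short sketch, assuming store is an already-configured DataStore instance (see the example under the class above); the paths are illustrative:

# 'store' is any configured DataStore instance.
with store.open("data/sample.csv", "r") as f:    # text mode
    header = f.readline()

with store.open("data/raw.bin", "wb") as f:      # binary mode
    f.write(b"\x00\x01")
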
read_file(path) abstractmethod

Read contents of a file from the data store.

Parameters:

    path (str, required): Path to the file to read

Returns:

    Any: Contents of the file

Raises:

    IOError: If file cannot be read

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def read_file(self, path: str) -> Any:
    """
    Read contents of a file from the data store.

    Args:
        path: Path to the file to read

    Returns:
        Contents of the file

    Raises:
        IOError: If file cannot be read
    """
    pass
remove(path) abstractmethod

Remove a file.

Parameters:

    path (str, required): Path to the file to remove

Raises:

    IOError: If file cannot be removed

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def remove(self, path: str) -> None:
    """
    Remove a file.

    Args:
        path: Path to the file to remove

    Raises:
        IOError: If file cannot be removed
    """
    pass
rmdir(dir) abstractmethod

Remove a directory and all its contents.

Parameters:

    dir (str, required): Path to the directory to remove

Raises:

    IOError: If directory cannot be removed

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def rmdir(self, dir: str) -> None:
    """
    Remove a directory and all its contents.

    Args:
        dir: Path to the directory to remove

    Raises:
        IOError: If directory cannot be removed
    """
    pass
walk(top) abstractmethod

Walk through directory tree, similar to os.walk().

Parameters:

    top (str, required): Starting directory for the walk

Returns:

    Generator: yields tuples of (dirpath, dirnames, filenames)

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def walk(self, top: str) -> Generator:
    """
    Walk through directory tree, similar to os.walk().

    Args:
        top: Starting directory for the walk

    Returns:
        Generator yielding tuples of (dirpath, dirnames, filenames)
    """
    pass
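
As with os.walk(), the generator lets you traverse every file under a prefix. A short sketch, assuming store is a configured DataStore instance and the prefix is illustrative:

# Collect all .geojson paths under a prefix.
geojson_paths = []
for dirpath, dirnames, filenames in store.walk("datasets"):
    for name in filenames:
        if name.endswith(".geojson"):
            geojson_paths.append(f"{dirpath}/{name}")
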
write_file(path, data) abstractmethod

Write data to a file in the data store.

Parameters:

    path (str, required): Path where to write the file
    data (Any, required): Data to write to the file

Raises:

    IOError: If file cannot be written

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def write_file(self, path: str, data: Any) -> None:
    """
    Write data to a file in the data store.

    Args:
        path: Path where to write the file
        data: Data to write to the file

    Raises:
        IOError: If file cannot be written
    """
    pass

read_dataset(data_store, path, compression=None, **kwargs)

Read data from various file formats stored in both local and cloud-based storage.

Parameters:

data_store : DataStore
    Instance of DataStore for accessing data storage.
path : str, Path
    Path to the file in data storage.
compression : str, optional
    Compression format (e.g. "gzip"); if None, it is inferred from the file extension.
**kwargs : dict
    Additional arguments passed to the specific reader function.

Returns:

pandas.DataFrame or geopandas.GeoDataFrame
    The data read from the file.

Raises:

FileNotFoundError
    If the file doesn't exist in the data store.
ValueError
    If the file type is unsupported or if there's an error reading the file.

Source code in gigaspatial/core/io/readers.py
def read_dataset(data_store: DataStore, path: str, compression: str = None, **kwargs):
    """
    Read data from various file formats stored in both local and cloud-based storage.

    Parameters:
    ----------
    data_store : DataStore
        Instance of DataStore for accessing data storage.
    path : str, Path
        Path to the file in data storage.
    **kwargs : dict
        Additional arguments passed to the specific reader function.

    Returns:
    -------
    pandas.DataFrame or geopandas.GeoDataFrame
        The data read from the file.

    Raises:
    ------
    FileNotFoundError
        If the file doesn't exist in blob storage.
    ValueError
        If the file type is unsupported or if there's an error reading the file.
    """

    # Define supported file formats and their readers
    BINARY_FORMATS = {
        ".shp",
        ".zip",
        ".parquet",
        ".gpkg",
        ".xlsx",
        ".xls",
        ".kmz",
        ".gz",
    }

    PANDAS_READERS = {
        ".csv": pd.read_csv,
        ".xlsx": lambda f, **kw: pd.read_excel(f, engine="openpyxl", **kw),
        ".xls": lambda f, **kw: pd.read_excel(f, engine="xlrd", **kw),
        ".json": pd.read_json,
        # ".gz": lambda f, **kw: pd.read_csv(f, compression="gzip", **kw),
    }

    GEO_READERS = {
        ".shp": gpd.read_file,
        ".zip": gpd.read_file,
        ".geojson": gpd.read_file,
        ".gpkg": gpd.read_file,
        ".parquet": gpd.read_parquet,
        ".kmz": read_kmz,
    }

    COMPRESSION_FORMATS = {
        ".gz": "gzip",
        ".bz2": "bz2",
        ".zip": "zip",
        ".xz": "xz",
    }

    try:
        # Check if file exists
        if not data_store.file_exists(path):
            raise FileNotFoundError(f"File '{path}' not found in blob storage")

        path_obj = Path(path)
        suffixes = path_obj.suffixes
        file_extension = suffixes[-1].lower() if suffixes else ""

        if compression is None and file_extension in COMPRESSION_FORMATS:
            compression_format = COMPRESSION_FORMATS[file_extension]

            # if file has multiple extensions (e.g., .csv.gz), get the inner format
            if len(suffixes) > 1:
                inner_extension = suffixes[-2].lower()

                if inner_extension == ".tar":
                    raise ValueError(
                        "Tar archives (.tar.gz) are not directly supported"
                    )

                if inner_extension in PANDAS_READERS:
                    try:
                        with data_store.open(path, "rb") as f:
                            return PANDAS_READERS[inner_extension](
                                f, compression=compression_format, **kwargs
                            )
                    except Exception as e:
                        raise ValueError(f"Error reading compressed file: {str(e)}")
                elif inner_extension in GEO_READERS:
                    try:
                        with data_store.open(path, "rb") as f:
                            if compression_format == "gzip":
                                import gzip

                                decompressed_data = gzip.decompress(f.read())
                                import io

                                return GEO_READERS[inner_extension](
                                    io.BytesIO(decompressed_data), **kwargs
                                )
                            else:
                                raise ValueError(
                                    f"Compression format {compression_format} not supported for geo data"
                                )
                    except Exception as e:
                        raise ValueError(f"Error reading compressed geo file: {str(e)}")
            else:
                # if just .gz without clear inner type, assume csv
                try:
                    with data_store.open(path, "rb") as f:
                        return pd.read_csv(f, compression=compression_format, **kwargs)
                except Exception as e:
                    raise ValueError(
                        f"Error reading compressed file as CSV: {str(e)}. "
                        f"If not a CSV, specify the format in the filename (e.g., .json.gz)"
                    )

        # Special handling for compressed files
        if file_extension == ".zip":
            # For zip files, we need to use binary mode
            with data_store.open(path, "rb") as f:
                return gpd.read_file(f)

        # Determine if we need binary mode based on file type
        mode = "rb" if file_extension in BINARY_FORMATS else "r"

        # Try reading with appropriate reader
        if file_extension in PANDAS_READERS:
            try:
                with data_store.open(path, mode) as f:
                    return PANDAS_READERS[file_extension](f, **kwargs)
            except Exception as e:
                raise ValueError(f"Error reading file with pandas: {str(e)}")

        if file_extension in GEO_READERS:
            try:
                with data_store.open(path, "rb") as f:
                    return GEO_READERS[file_extension](f, **kwargs)
            except Exception as e:
                # For parquet files, try pandas reader if geopandas fails
                if file_extension == ".parquet":
                    try:
                        with data_store.open(path, "rb") as f:
                            return pd.read_parquet(f, **kwargs)
                    except Exception as e2:
                        raise ValueError(
                            f"Failed to read parquet with both geopandas ({str(e)}) "
                            f"and pandas ({str(e2)})"
                        )
                raise ValueError(f"Error reading file with geopandas: {str(e)}")

        # If we get here, the file type is unsupported
        supported_formats = sorted(set(PANDAS_READERS.keys()) | set(GEO_READERS.keys()))
        supported_compressions = sorted(COMPRESSION_FORMATS.keys())
        raise ValueError(
            f"Unsupported file type: {file_extension}\n"
            f"Supported formats: {', '.join(supported_formats)}"
            f"Supported compressions: {', '.join(supported_compressions)}"
        )

    except Exception as e:
        if isinstance(e, (FileNotFoundError, ValueError)):
            raise
        raise RuntimeError(f"Unexpected error reading dataset: {str(e)}")

read_datasets(data_store, paths, **kwargs)

Read multiple datasets from data storage at once.

Parameters:

data_store : DataStore
    Instance of DataStore for accessing data storage.
paths : list of str
    Paths to files in data storage.
**kwargs : dict
    Additional arguments passed to read_dataset.

Returns:

dict
    Dictionary mapping paths to their corresponding DataFrames/GeoDataFrames.

Source code in gigaspatial/core/io/readers.py
def read_datasets(data_store: DataStore, paths, **kwargs):
    """
    Read multiple datasets from data storage at once.

    Parameters:
    ----------
    data_store : DataStore
        Instance of DataStore for accessing data storage.
    paths : list of str
        Paths to files in data storage.
    **kwargs : dict
        Additional arguments passed to read_dataset.

    Returns:
    -------
    dict
        Dictionary mapping paths to their corresponding DataFrames/GeoDataFrames.
    """
    results = {}
    errors = {}

    for path in paths:
        try:
            results[path] = read_dataset(data_store, path, **kwargs)
        except Exception as e:
            errors[path] = str(e)

    if errors:
        error_msg = "\n".join(f"- {path}: {error}" for path, error in errors.items())
        raise ValueError(f"Errors reading datasets:\n{error_msg}")

    return results
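
A short sketch with illustrative paths; store is any configured DataStore. If any file fails to load, a single ValueError listing every failing path is raised:

from gigaspatial.core.io.readers import read_datasets

frames = read_datasets(store, ["tables/schools.csv", "geo/boundaries.geojson"])
schools = frames["tables/schools.csv"]  # result maps each path to its (Geo)DataFrame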

read_gzipped_json_or_csv(file_path, data_store)

Reads a gzipped file, attempting to parse it as JSON (lines=True) or CSV.

Source code in gigaspatial/core/io/readers.py
def read_gzipped_json_or_csv(file_path, data_store):
    """Reads a gzipped file, attempting to parse it as JSON (lines=True) or CSV."""

    with data_store.open(file_path, "rb") as f:
        g = gzip.GzipFile(fileobj=f)
        text = g.read().decode("utf-8")
        try:
            df = pd.read_json(io.StringIO(text), lines=True)
            return df
        except json.JSONDecodeError:
            try:
                df = pd.read_csv(io.StringIO(text))
                return df
            except pd.errors.ParserError:
                print(f"Error: Could not parse {file_path} as JSON or CSV.")
                return None
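
A short sketch; the path is illustrative and store is any configured DataStore. The file is tried as JSON Lines first, then as CSV, and None is returned if neither parse succeeds:

from gigaspatial.core.io.readers import read_gzipped_json_or_csv

df = read_gzipped_json_or_csv("exports/events.json.gz", data_store=store)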

read_kmz(file_obj, **kwargs)

Helper function to read KMZ files and return a GeoDataFrame.

Source code in gigaspatial/core/io/readers.py
def read_kmz(file_obj, **kwargs):
    """Helper function to read KMZ files and return a GeoDataFrame."""
    try:
        with zipfile.ZipFile(file_obj) as kmz:
            # Find the KML file in the archive (usually doc.kml)
            kml_filename = next(
                name for name in kmz.namelist() if name.endswith(".kml")
            )

            # Read the KML content
            kml_content = io.BytesIO(kmz.read(kml_filename))

            gdf = gpd.read_file(kml_content)

            # Validate the GeoDataFrame
            if gdf.empty:
                raise ValueError(
                    "The KML file is empty or does not contain valid geospatial data."
                )

        return gdf

    except zipfile.BadZipFile:
        raise ValueError("The provided file is not a valid KMZ file.")
    except StopIteration:
        raise ValueError("No KML file found in the KMZ archive.")
    except Exception as e:
        raise RuntimeError(f"An error occurred: {e}")

write_dataset(data, data_store, path, **kwargs)

Write DataFrame, GeoDataFrame, or a generic object (for JSON) to various file formats in DataStore.

Parameters:

data : pandas.DataFrame, geopandas.GeoDataFrame, or any object
    The data to write to data storage.
data_store : DataStore
    Instance of DataStore for accessing data storage.
path : str
    Path where the file will be written in data storage.
**kwargs : dict
    Additional arguments passed to the specific writer function.

Raises:

ValueError
    If the file type is unsupported or if there's an error writing the file.
TypeError
    If the input data is neither a DataFrame/GeoDataFrame nor a generic object destined for a '.json' file.

Source code in gigaspatial/core/io/writers.py
def write_dataset(data, data_store: DataStore, path, **kwargs):
    """
    Write DataFrame, GeoDataFrame, or a generic object (for JSON)
    to various file formats in DataStore.

    Parameters:
    ----------
    data : pandas.DataFrame, geopandas.GeoDataFrame, or any object
        The data to write to data storage.
    data_store : DataStore
        Instance of DataStore for accessing data storage.
    path : str
        Path where the file will be written in data storage.
    **kwargs : dict
        Additional arguments passed to the specific writer function.

    Raises:
    ------
    ValueError
        If the file type is unsupported or if there's an error writing the file.
    TypeError
            If input data is not a DataFrame, GeoDataFrame, AND not a generic object
            intended for a .json file.
    """

    # Define supported file formats and their writers
    BINARY_FORMATS = {".shp", ".zip", ".parquet", ".gpkg", ".xlsx", ".xls"}

    PANDAS_WRITERS = {
        ".csv": lambda df, buf, **kw: df.to_csv(buf, **kw),
        ".xlsx": lambda df, buf, **kw: df.to_excel(buf, engine="openpyxl", **kw),
        ".json": lambda df, buf, **kw: df.to_json(buf, **kw),
        ".parquet": lambda df, buf, **kw: df.to_parquet(buf, **kw),
    }

    GEO_WRITERS = {
        ".geojson": lambda gdf, buf, **kw: gdf.to_file(buf, driver="GeoJSON", **kw),
        ".gpkg": lambda gdf, buf, **kw: gdf.to_file(buf, driver="GPKG", **kw),
        ".parquet": lambda gdf, buf, **kw: gdf.to_parquet(buf, **kw),
    }

    try:
        # Get file suffix and ensure it's lowercase
        suffix = Path(path).suffix.lower()

        # 1. Handle generic JSON data
        is_dataframe_like = isinstance(data, (pd.DataFrame, gpd.GeoDataFrame))
        if not is_dataframe_like:
            if suffix == ".json":
                try:
                    # Pass generic data directly to the write_json function
                    write_json(data, data_store, path, **kwargs)
                    return  # Successfully wrote JSON, so exit
                except Exception as e:
                    raise ValueError(f"Error writing generic JSON data: {str(e)}")
            else:
                # Raise an error if it's not a DataFrame/GeoDataFrame and not a .json file
                raise TypeError(
                    "Input data must be a pandas DataFrame or GeoDataFrame, "
                    "or a generic object destined for a '.json' file."
                )

        # 2. Handle DataFrame/GeoDataFrame
        # Determine if we need binary mode based on file type
        mode = "wb" if suffix in BINARY_FORMATS else "w"

        # Handle different data types and formats
        if isinstance(data, gpd.GeoDataFrame):
            if suffix not in GEO_WRITERS:
                supported_formats = sorted(GEO_WRITERS.keys())
                raise ValueError(
                    f"Unsupported file type for GeoDataFrame: {suffix}\n"
                    f"Supported formats: {', '.join(supported_formats)}"
                )

            try:
                with data_store.open(path, "wb") as f:
                    GEO_WRITERS[suffix](data, f, **kwargs)
            except Exception as e:
                raise ValueError(f"Error writing GeoDataFrame: {str(e)}")

        else:  # pandas DataFrame
            if suffix not in PANDAS_WRITERS:
                supported_formats = sorted(PANDAS_WRITERS.keys())
                raise ValueError(
                    f"Unsupported file type for DataFrame: {suffix}\n"
                    f"Supported formats: {', '.join(supported_formats)}"
                )

            try:
                with data_store.open(path, mode) as f:
                    PANDAS_WRITERS[suffix](data, f, **kwargs)
            except Exception as e:
                raise ValueError(f"Error writing DataFrame: {str(e)}")

    except Exception as e:
        if isinstance(e, (TypeError, ValueError)):
            raise
        raise RuntimeError(f"Unexpected error writing dataset: {str(e)}")

write_datasets(data_dict, data_store, **kwargs)

Write multiple datasets to data storage at once.

Parameters:

data_dict : dict
    Dictionary mapping paths to DataFrames/GeoDataFrames.
data_store : DataStore
    Instance of DataStore for accessing data storage.
**kwargs : dict
    Additional arguments passed to write_dataset.

Raises:

ValueError
    If there are any errors writing the datasets.

Source code in gigaspatial/core/io/writers.py
def write_datasets(data_dict, data_store: DataStore, **kwargs):
    """
    Write multiple datasets to data storage at once.

    Parameters:
    ----------
    data_dict : dict
        Dictionary mapping paths to DataFrames/GeoDataFrames.
    data_store : DataStore
        Instance of DataStore for accessing data storage.
    **kwargs : dict
        Additional arguments passed to write_dataset.

    Raises:
    ------
    ValueError
        If there are any errors writing the datasets.
    """
    errors = {}

    for path, data in data_dict.items():
        try:
            write_dataset(data, data_store, path, **kwargs)
        except Exception as e:
            errors[path] = str(e)

    if errors:
        error_msg = "\n".join(f"- {path}: {error}" for path, error in errors.items())
        raise ValueError(f"Errors writing datasets:\n{error_msg}")

adls_data_store

ADLSDataStore

Bases: DataStore

An implementation of DataStore for Azure Data Lake Storage.

Source code in gigaspatial/core/io/adls_data_store.py
class ADLSDataStore(DataStore):
    """
    An implementation of DataStore for Azure Data Lake Storage.
    """

    def __init__(
        self,
        container: str = config.ADLS_CONTAINER_NAME,
        connection_string: str = config.ADLS_CONNECTION_STRING,
        account_url: str = config.ADLS_ACCOUNT_URL,
        sas_token: str = config.ADLS_SAS_TOKEN,
    ):
        """
        Create a new instance of ADLSDataStore
        :param container: The name of the container in ADLS to interact with.
        """
        if connection_string:
            self.blob_service_client = BlobServiceClient.from_connection_string(
                connection_string
            )
        elif account_url and sas_token:
            self.blob_service_client = BlobServiceClient(
                account_url=account_url, credential=sas_token
            )
        else:
            raise ValueError(
                "Either connection_string or account_url and sas_token must be provided."
            )

        self.container_client = self.blob_service_client.get_container_client(
            container=container
        )
        self.container = container

    def read_file(self, path: str, encoding: Optional[str] = None) -> Union[str, bytes]:
        """
        Read file with flexible encoding support.

        :param path: Path to the file in blob storage
        :param encoding: File encoding (optional)
        :return: File contents as string or bytes
        """
        try:
            blob_client = self.container_client.get_blob_client(path)
            blob_data = blob_client.download_blob().readall()

            # If no encoding specified, return raw bytes
            if encoding is None:
                return blob_data

            # If encoding is specified, decode the bytes
            return blob_data.decode(encoding)

        except Exception as e:
            raise IOError(f"Error reading file {path}: {e}")

    def write_file(self, path: str, data) -> None:
        """
        Write file with support for content type and improved type handling.

        :param path: Destination path in blob storage
        :param data: File contents
        """
        blob_client = self.blob_service_client.get_blob_client(
            container=self.container, blob=path, snapshot=None
        )

        if isinstance(data, str):
            binary_data = data.encode()
        elif isinstance(data, bytes):
            binary_data = data
        else:
            raise Exception(f'Unsupported data type. Only "bytes" or "string" accepted')

        blob_client.upload_blob(binary_data, overwrite=True)

    def upload_file(self, file_path, blob_path):
        """Uploads a single file to Azure Blob Storage."""
        try:
            blob_client = self.container_client.get_blob_client(blob_path)
            with open(file_path, "rb") as data:
                blob_client.upload_blob(data, overwrite=True)
            print(f"Uploaded {file_path} to {blob_path}")
        except Exception as e:
            print(f"Failed to upload {file_path}: {e}")

    def upload_directory(self, dir_path, blob_dir_path):
        """Uploads all files from a directory to Azure Blob Storage."""
        for root, dirs, files in os.walk(dir_path):
            for file in files:
                local_file_path = os.path.join(root, file)
                relative_path = os.path.relpath(local_file_path, dir_path)
                blob_file_path = os.path.join(blob_dir_path, relative_path).replace(
                    "\\", "/"
                )

                self.upload_file(local_file_path, blob_file_path)

    def download_directory(self, blob_dir_path: str, local_dir_path: str):
        """Downloads all files from a directory in Azure Blob Storage to a local directory."""
        try:
            # Ensure the local directory exists
            os.makedirs(local_dir_path, exist_ok=True)

            # List all files in the blob directory
            blob_items = self.container_client.list_blobs(
                name_starts_with=blob_dir_path
            )

            for blob_item in blob_items:
                # Get the relative path of the blob file
                relative_path = os.path.relpath(blob_item.name, blob_dir_path)
                # Construct the local file path
                local_file_path = os.path.join(local_dir_path, relative_path)
                # Create directories if needed
                os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

                # Download the blob to the local file
                blob_client = self.container_client.get_blob_client(blob_item.name)
                with open(local_file_path, "wb") as file:
                    file.write(blob_client.download_blob().readall())

            print(f"Downloaded directory {blob_dir_path} to {local_dir_path}")
        except Exception as e:
            print(f"Failed to download directory {blob_dir_path}: {e}")

    def copy_directory(self, source_dir: str, destination_dir: str):
        """
        Copies all files from a source directory to a destination directory within the same container.

        :param source_dir: The source directory path in the blob storage
        :param destination_dir: The destination directory path in the blob storage
        """
        try:
            # Ensure source directory path ends with a trailing slash
            source_dir = source_dir.rstrip("/") + "/"
            destination_dir = destination_dir.rstrip("/") + "/"

            # List all blobs in the source directory
            source_blobs = self.container_client.list_blobs(name_starts_with=source_dir)

            for blob in source_blobs:
                # Get the relative path of the blob
                relative_path = os.path.relpath(blob.name, source_dir)

                # Construct the new blob path
                new_blob_path = os.path.join(destination_dir, relative_path).replace(
                    "\\", "/"
                )

                # Use copy_file method to copy each file
                self.copy_file(blob.name, new_blob_path, overwrite=True)

            print(f"Copied directory from {source_dir} to {destination_dir}")
        except Exception as e:
            print(f"Failed to copy directory {source_dir}: {e}")

    def copy_file(
        self, source_path: str, destination_path: str, overwrite: bool = False
    ):
        """
        Copies a single file from source to destination within the same container.

        :param source_path: The source file path in the blob storage
        :param destination_path: The destination file path in the blob storage
        :param overwrite: If True, overwrite the destination file if it already exists
        """
        try:
            if not self.file_exists(source_path):
                raise FileNotFoundError(f"Source file not found: {source_path}")

            if self.file_exists(destination_path) and not overwrite:
                raise FileExistsError(
                    f"Destination file already exists and overwrite is False: {destination_path}"
                )

            # Create source and destination blob clients
            source_blob_client = self.container_client.get_blob_client(source_path)
            destination_blob_client = self.container_client.get_blob_client(
                destination_path
            )

            # Start the server-side copy operation
            destination_blob_client.start_copy_from_url(source_blob_client.url)

            print(f"Copied file from {source_path} to {destination_path}")
        except Exception as e:
            print(f"Failed to copy file {source_path}: {e}")
            raise

    def exists(self, path: str) -> bool:
        blob_client = self.blob_service_client.get_blob_client(
            container=self.container, blob=path, snapshot=None
        )
        return blob_client.exists()

    def file_exists(self, path: str) -> bool:
        return self.exists(path) and not self.is_dir(path)

    def file_size(self, path: str) -> float:
        blob_client = self.blob_service_client.get_blob_client(
            container=self.container, blob=path, snapshot=None
        )
        properties = blob_client.get_blob_properties()

        # The size is in bytes, convert it to kilobytes
        size_in_bytes = properties.size
        size_in_kb = size_in_bytes / 1024.0
        return size_in_kb

    def list_files(self, path: str):
        blob_items = self.container_client.list_blobs(name_starts_with=path)
        return [item["name"] for item in blob_items]

    def walk(self, top: str):
        top = top.rstrip("/") + "/"
        blob_items = self.container_client.list_blobs(name_starts_with=top)
        blobs = [item["name"] for item in blob_items]
        for blob in blobs:
            dirpath, filename = os.path.split(blob)
            yield (dirpath, [], [filename])

    def list_directories(self, path: str) -> list:
        """List only directory names (not files) from a given path in ADLS."""
        search_path = path.rstrip("/") + "/" if path else ""

        blob_items = self.container_client.list_blobs(name_starts_with=search_path)

        directories = set()

        for blob_item in blob_items:
            # Get the relative path from the search path
            relative_path = blob_item.name[len(search_path) :]

            # Skip if it's empty (shouldn't happen but just in case)
            if not relative_path:
                continue

            # If there's a "/" in the relative path, it means there's a subdirectory
            if "/" in relative_path:
                # Get the first directory name
                dir_name = relative_path.split("/")[0]
                directories.add(dir_name)

        return sorted(list(directories))

    @contextlib.contextmanager
    def open(self, path: str, mode: str = "r"):
        """
        Context manager for file operations with enhanced mode support.

        :param path: File path in blob storage
        :param mode: File open mode (r, rb, w, wb)
        """
        if mode == "w":
            file = io.StringIO()
            yield file
            self.write_file(path, file.getvalue())

        elif mode == "wb":
            file = io.BytesIO()
            yield file
            self.write_file(path, file.getvalue())

        elif mode == "r":
            data = self.read_file(path, encoding="UTF-8")
            file = io.StringIO(data)
            yield file

        elif mode == "rb":
            data = self.read_file(path)
            file = io.BytesIO(data)
            yield file

    def get_file_metadata(self, path: str) -> dict:
        """
        Retrieve comprehensive file metadata.

        :param path: File path in blob storage
        :return: File metadata dictionary
        """
        blob_client = self.container_client.get_blob_client(path)
        properties = blob_client.get_blob_properties()

        return {
            "name": path,
            "size_bytes": properties.size,
            "content_type": properties.content_settings.content_type,
            "last_modified": properties.last_modified,
            "etag": properties.etag,
        }

    def is_file(self, path: str) -> bool:
        return self.file_exists(path)

    def is_dir(self, path: str) -> bool:
        dir_path = path.rstrip("/") + "/"

        existing_blobs = self.list_files(dir_path)

        if len(existing_blobs) > 1:
            return True
        elif len(existing_blobs) == 1:
            if existing_blobs[0] != path.rstrip("/"):
                return True

        return False

    def rmdir(self, dir: str) -> None:
        # Normalize directory path to ensure it targets all children
        dir_path = dir.rstrip("/") + "/"

        # Azure Blob batch delete has a hard limit on number of sub-requests
        # per batch (currently 256). Delete in chunks to avoid
        # ExceedsMaxBatchRequestCount errors.
        blobs = list(self.list_files(dir_path))
        if not blobs:
            return

        BATCH_LIMIT = 256
        for start_idx in range(0, len(blobs), BATCH_LIMIT):
            batch = blobs[start_idx : start_idx + BATCH_LIMIT]
            self.container_client.delete_blobs(*batch)

    def mkdir(self, path: str, exist_ok: bool = False) -> None:
        """
        Create a directory in Azure Blob Storage.

        In ADLS, directories are conceptual and created by adding a placeholder blob.

        :param path: Path of the directory to create
        :param exist_ok: If False, raise an error if the directory already exists
        """
        dir_path = path.rstrip("/") + "/"

        existing_blobs = list(self.list_files(dir_path))

        if existing_blobs and not exist_ok:
            raise FileExistsError(f"Directory {path} already exists")

        # Create a placeholder blob to represent the directory
        placeholder_blob_path = os.path.join(dir_path, ".placeholder")

        # Only create placeholder if it doesn't already exist
        if not self.file_exists(placeholder_blob_path):
            placeholder_content = (
                b"This is a placeholder blob to represent a directory."
            )
            blob_client = self.blob_service_client.get_blob_client(
                container=self.container, blob=placeholder_blob_path
            )
            blob_client.upload_blob(placeholder_content, overwrite=True)

    def remove(self, path: str) -> None:
        blob_client = self.blob_service_client.get_blob_client(
            container=self.container, blob=path, snapshot=None
        )
        if blob_client.exists():
            blob_client.delete_blob()

    def rename(
        self,
        source_path: str,
        destination_path: str,
        overwrite: bool = False,
        delete_source: bool = True,
        wait: bool = True,
        timeout_seconds: int = 300,
        poll_interval_seconds: int = 1,
    ) -> None:
        """
        Rename (move) a single file by copying to the new path and deleting the source.

        :param source_path: Existing blob path
        :param destination_path: Target blob path
        :param overwrite: Overwrite destination if it already exists
        :param delete_source: Delete original after successful copy
        :param wait: Wait for the copy operation to complete
        :param timeout_seconds: Max time to wait for copy to succeed
        :param poll_interval_seconds: Polling interval while waiting
        """

        if not self.file_exists(source_path):
            raise FileNotFoundError(f"Source file not found: {source_path}")

        if self.file_exists(destination_path) and not overwrite:
            raise FileExistsError(
                f"Destination already exists and overwrite is False: {destination_path}"
            )

        # Use copy_file method to copy the file
        self.copy_file(source_path, destination_path, overwrite=overwrite)

        if wait:
            # Wait for copy to complete if requested
            dest_client = self.container_client.get_blob_client(destination_path)
            deadline = time.time() + timeout_seconds
            while True:
                props = dest_client.get_blob_properties()
                status = getattr(props.copy, "status", None)
                if status == "success":
                    break
                if status in {"aborted", "failed"}:
                    raise IOError(
                        f"Copy failed with status {status} from {source_path} to {destination_path}"
                    )
                if time.time() > deadline:
                    raise TimeoutError(
                        f"Timed out waiting for copy to complete for {destination_path}"
                    )
                time.sleep(poll_interval_seconds)

        if delete_source:
            self.remove(source_path)
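
A minimal usage sketch; the container name and connection string below are placeholders for real Azure credentials:

from gigaspatial.core.io.adls_data_store import ADLSDataStore

store = ADLSDataStore(
    container="my-container",
    connection_string="DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...",  # placeholder
)
store.write_file("demo/hello.txt", "hello from gigaspatial")
print(store.read_file("demo/hello.txt", encoding="utf-8"))
print(store.list_files("demo"))
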
__init__(container=config.ADLS_CONTAINER_NAME, connection_string=config.ADLS_CONNECTION_STRING, account_url=config.ADLS_ACCOUNT_URL, sas_token=config.ADLS_SAS_TOKEN)

Create a new instance of ADLSDataStore.

:param container: The name of the container in ADLS to interact with.

Source code in gigaspatial/core/io/adls_data_store.py
def __init__(
    self,
    container: str = config.ADLS_CONTAINER_NAME,
    connection_string: str = config.ADLS_CONNECTION_STRING,
    account_url: str = config.ADLS_ACCOUNT_URL,
    sas_token: str = config.ADLS_SAS_TOKEN,
):
    """
    Create a new instance of ADLSDataStore
    :param container: The name of the container in ADLS to interact with.
    """
    if connection_string:
        self.blob_service_client = BlobServiceClient.from_connection_string(
            connection_string
        )
    elif account_url and sas_token:
        self.blob_service_client = BlobServiceClient(
            account_url=account_url, credential=sas_token
        )
    else:
        raise ValueError(
            "Either connection_string or account_url and sas_token must be provided."
        )

    self.container_client = self.blob_service_client.get_container_client(
        container=container
    )
    self.container = container
copy_directory(source_dir, destination_dir)

Copies all files from a source directory to a destination directory within the same container.

:param source_dir: The source directory path in the blob storage
:param destination_dir: The destination directory path in the blob storage

Source code in gigaspatial/core/io/adls_data_store.py
def copy_directory(self, source_dir: str, destination_dir: str):
    """
    Copies all files from a source directory to a destination directory within the same container.

    :param source_dir: The source directory path in the blob storage
    :param destination_dir: The destination directory path in the blob storage
    """
    try:
        # Ensure source directory path ends with a trailing slash
        source_dir = source_dir.rstrip("/") + "/"
        destination_dir = destination_dir.rstrip("/") + "/"

        # List all blobs in the source directory
        source_blobs = self.container_client.list_blobs(name_starts_with=source_dir)

        for blob in source_blobs:
            # Get the relative path of the blob
            relative_path = os.path.relpath(blob.name, source_dir)

            # Construct the new blob path
            new_blob_path = os.path.join(destination_dir, relative_path).replace(
                "\\", "/"
            )

            # Use copy_file method to copy each file
            self.copy_file(blob.name, new_blob_path, overwrite=True)

        print(f"Copied directory from {source_dir} to {destination_dir}")
    except Exception as e:
        print(f"Failed to copy directory {source_dir}: {e}")
copy_file(source_path, destination_path, overwrite=False)

Copies a single file from source to destination within the same container.

:param source_path: The source file path in the blob storage
:param destination_path: The destination file path in the blob storage
:param overwrite: If True, overwrite the destination file if it already exists

Source code in gigaspatial/core/io/adls_data_store.py
def copy_file(
    self, source_path: str, destination_path: str, overwrite: bool = False
):
    """
    Copies a single file from source to destination within the same container.

    :param source_path: The source file path in the blob storage
    :param destination_path: The destination file path in the blob storage
    :param overwrite: If True, overwrite the destination file if it already exists
    """
    try:
        if not self.file_exists(source_path):
            raise FileNotFoundError(f"Source file not found: {source_path}")

        if self.file_exists(destination_path) and not overwrite:
            raise FileExistsError(
                f"Destination file already exists and overwrite is False: {destination_path}"
            )

        # Create source and destination blob clients
        source_blob_client = self.container_client.get_blob_client(source_path)
        destination_blob_client = self.container_client.get_blob_client(
            destination_path
        )

        # Start the server-side copy operation
        destination_blob_client.start_copy_from_url(source_blob_client.url)

        print(f"Copied file from {source_path} to {destination_path}")
    except Exception as e:
        print(f"Failed to copy file {source_path}: {e}")
        raise
download_directory(blob_dir_path, local_dir_path)

Downloads all files from a directory in Azure Blob Storage to a local directory.

Source code in gigaspatial/core/io/adls_data_store.py
def download_directory(self, blob_dir_path: str, local_dir_path: str):
    """Downloads all files from a directory in Azure Blob Storage to a local directory."""
    try:
        # Ensure the local directory exists
        os.makedirs(local_dir_path, exist_ok=True)

        # List all files in the blob directory
        blob_items = self.container_client.list_blobs(
            name_starts_with=blob_dir_path
        )

        for blob_item in blob_items:
            # Get the relative path of the blob file
            relative_path = os.path.relpath(blob_item.name, blob_dir_path)
            # Construct the local file path
            local_file_path = os.path.join(local_dir_path, relative_path)
            # Create directories if needed
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

            # Download the blob to the local file
            blob_client = self.container_client.get_blob_client(blob_item.name)
            with open(local_file_path, "wb") as file:
                file.write(blob_client.download_blob().readall())

        print(f"Downloaded directory {blob_dir_path} to {local_dir_path}")
    except Exception as e:
        print(f"Failed to download directory {blob_dir_path}: {e}")
get_file_metadata(path)

Retrieve comprehensive file metadata.

:param path: File path in blob storage
:return: File metadata dictionary

Source code in gigaspatial/core/io/adls_data_store.py
def get_file_metadata(self, path: str) -> dict:
    """
    Retrieve comprehensive file metadata.

    :param path: File path in blob storage
    :return: File metadata dictionary
    """
    blob_client = self.container_client.get_blob_client(path)
    properties = blob_client.get_blob_properties()

    return {
        "name": path,
        "size_bytes": properties.size,
        "content_type": properties.content_settings.content_type,
        "last_modified": properties.last_modified,
        "etag": properties.etag,
    }
list_directories(path)

List only directory names (not files) from a given path in ADLS.

Source code in gigaspatial/core/io/adls_data_store.py
def list_directories(self, path: str) -> list:
    """List only directory names (not files) from a given path in ADLS."""
    search_path = path.rstrip("/") + "/" if path else ""

    blob_items = self.container_client.list_blobs(name_starts_with=search_path)

    directories = set()

    for blob_item in blob_items:
        # Get the relative path from the search path
        relative_path = blob_item.name[len(search_path) :]

        # Skip if it's empty (shouldn't happen but just in case)
        if not relative_path:
            continue

        # If there's a "/" in the relative path, it means there's a subdirectory
        if "/" in relative_path:
            # Get the first directory name
            dir_name = relative_path.split("/")[0]
            directories.add(dir_name)

    return sorted(list(directories))
mkdir(path, exist_ok=False)

Create a directory in Azure Blob Storage.

In ADLS, directories are conceptual and created by adding a placeholder blob.

:param path: Path of the directory to create
:param exist_ok: If False, raise an error if the directory already exists

Source code in gigaspatial/core/io/adls_data_store.py
def mkdir(self, path: str, exist_ok: bool = False) -> None:
    """
    Create a directory in Azure Blob Storage.

    In ADLS, directories are conceptual and created by adding a placeholder blob.

    :param path: Path of the directory to create
    :param exist_ok: If False, raise an error if the directory already exists
    """
    dir_path = path.rstrip("/") + "/"

    existing_blobs = list(self.list_files(dir_path))

    if existing_blobs and not exist_ok:
        raise FileExistsError(f"Directory {path} already exists")

    # Create a placeholder blob to represent the directory
    placeholder_blob_path = os.path.join(dir_path, ".placeholder")

    # Only create placeholder if it doesn't already exist
    if not self.file_exists(placeholder_blob_path):
        placeholder_content = (
            b"This is a placeholder blob to represent a directory."
        )
        blob_client = self.blob_service_client.get_blob_client(
            container=self.container, blob=placeholder_blob_path
        )
        blob_client.upload_blob(placeholder_content, overwrite=True)
open(path, mode='r')

Context manager for file operations with enhanced mode support.

:param path: File path in blob storage
:param mode: File open mode (r, rb, w, wb)

Source code in gigaspatial/core/io/adls_data_store.py
@contextlib.contextmanager
def open(self, path: str, mode: str = "r"):
    """
    Context manager for file operations with enhanced mode support.

    :param path: File path in blob storage
    :param mode: File open mode (r, rb, w, wb)
    """
    if mode == "w":
        file = io.StringIO()
        yield file
        self.write_file(path, file.getvalue())

    elif mode == "wb":
        file = io.BytesIO()
        yield file
        self.write_file(path, file.getvalue())

    elif mode == "r":
        data = self.read_file(path, encoding="UTF-8")
        file = io.StringIO(data)
        yield file

    elif mode == "rb":
        data = self.read_file(path)
        file = io.BytesIO(data)
        yield file
read_file(path, encoding=None)

Read file with flexible encoding support.

:param path: Path to the file in blob storage
:param encoding: File encoding (optional)
:return: File contents as string or bytes

Source code in gigaspatial/core/io/adls_data_store.py
def read_file(self, path: str, encoding: Optional[str] = None) -> Union[str, bytes]:
    """
    Read file with flexible encoding support.

    :param path: Path to the file in blob storage
    :param encoding: File encoding (optional)
    :return: File contents as string or bytes
    """
    try:
        blob_client = self.container_client.get_blob_client(path)
        blob_data = blob_client.download_blob().readall()

        # If no encoding specified, return raw bytes
        if encoding is None:
            return blob_data

        # If encoding is specified, decode the bytes
        return blob_data.decode(encoding)

    except Exception as e:
        raise IOError(f"Error reading file {path}: {e}")
rename(source_path, destination_path, overwrite=False, delete_source=True, wait=True, timeout_seconds=300, poll_interval_seconds=1)

Rename (move) a single file by copying to the new path and deleting the source.

:param source_path: Existing blob path
:param destination_path: Target blob path
:param overwrite: Overwrite destination if it already exists
:param delete_source: Delete original after successful copy
:param wait: Wait for the copy operation to complete
:param timeout_seconds: Max time to wait for copy to succeed
:param poll_interval_seconds: Polling interval while waiting

Source code in gigaspatial/core/io/adls_data_store.py
def rename(
    self,
    source_path: str,
    destination_path: str,
    overwrite: bool = False,
    delete_source: bool = True,
    wait: bool = True,
    timeout_seconds: int = 300,
    poll_interval_seconds: int = 1,
) -> None:
    """
    Rename (move) a single file by copying to the new path and deleting the source.

    :param source_path: Existing blob path
    :param destination_path: Target blob path
    :param overwrite: Overwrite destination if it already exists
    :param delete_source: Delete original after successful copy
    :param wait: Wait for the copy operation to complete
    :param timeout_seconds: Max time to wait for copy to succeed
    :param poll_interval_seconds: Polling interval while waiting
    """

    if not self.file_exists(source_path):
        raise FileNotFoundError(f"Source file not found: {source_path}")

    if self.file_exists(destination_path) and not overwrite:
        raise FileExistsError(
            f"Destination already exists and overwrite is False: {destination_path}"
        )

    # Use copy_file method to copy the file
    self.copy_file(source_path, destination_path, overwrite=overwrite)

    if wait:
        # Wait for copy to complete if requested
        dest_client = self.container_client.get_blob_client(destination_path)
        deadline = time.time() + timeout_seconds
        while True:
            props = dest_client.get_blob_properties()
            status = getattr(props.copy, "status", None)
            if status == "success":
                break
            if status in {"aborted", "failed"}:
                raise IOError(
                    f"Copy failed with status {status} from {source_path} to {destination_path}"
                )
            if time.time() > deadline:
                raise TimeoutError(
                    f"Timed out waiting for copy to complete for {destination_path}"
                )
            time.sleep(poll_interval_seconds)

    if delete_source:
        self.remove(source_path)
upload_directory(dir_path, blob_dir_path)

Uploads all files from a directory to Azure Blob Storage.

Source code in gigaspatial/core/io/adls_data_store.py
def upload_directory(self, dir_path, blob_dir_path):
    """Uploads all files from a directory to Azure Blob Storage."""
    for root, dirs, files in os.walk(dir_path):
        for file in files:
            local_file_path = os.path.join(root, file)
            relative_path = os.path.relpath(local_file_path, dir_path)
            blob_file_path = os.path.join(blob_dir_path, relative_path).replace(
                "\\", "/"
            )

            self.upload_file(local_file_path, blob_file_path)
upload_file(file_path, blob_path)

Uploads a single file to Azure Blob Storage.

Source code in gigaspatial/core/io/adls_data_store.py
def upload_file(self, file_path, blob_path):
    """Uploads a single file to Azure Blob Storage."""
    try:
        blob_client = self.container_client.get_blob_client(blob_path)
        with open(file_path, "rb") as data:
            blob_client.upload_blob(data, overwrite=True)
        print(f"Uploaded {file_path} to {blob_path}")
    except Exception as e:
        print(f"Failed to upload {file_path}: {e}")
write_file(path, data)

Write file with support for content type and improved type handling.

:param path: Destination path in blob storage
:param data: File contents

Source code in gigaspatial/core/io/adls_data_store.py
def write_file(self, path: str, data) -> None:
    """
    Write file with support for content type and improved type handling.

    :param path: Destination path in blob storage
    :param data: File contents
    """
    blob_client = self.blob_service_client.get_blob_client(
        container=self.container, blob=path, snapshot=None
    )

    if isinstance(data, str):
        binary_data = data.encode()
    elif isinstance(data, bytes):
        binary_data = data
    else:
        raise Exception(f'Unsupported data type. Only "bytes" or "string" accepted')

    blob_client.upload_blob(binary_data, overwrite=True)
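
A short sketch of the accepted input types, again assuming store is an instance of the ADLS data store class and the paths are placeholders.

# Strings are UTF-8 encoded before upload; bytes are uploaded as-is.
store.write_file("exports/notes.txt", "country,schools\nMWI,5210\n")
store.write_file("exports/blob.bin", b"\x00\x01\x02")

# Any other type raises, so serialize first, e.g. for a DataFrame:
# store.write_file("exports/schools.csv", df.to_csv(index=False))
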

data_api

GigaDataAPI
Source code in gigaspatial/core/io/data_api.py
class GigaDataAPI:

    def __init__(
        self,
        profile_file: Union[str, Path] = config.API_PROFILE_FILE_PATH,
        share_name: str = config.API_SHARE_NAME,
        schema_name: str = config.API_SCHEMA_NAME,
    ):
        """
        Initialize the GigaDataAPI class with the profile file, share name, and schema name.

        profile_file: Path to the delta-sharing profile file.
        share_name: Name of the share (e.g., "gold").
        schema_name: Name of the schema (e.g., "school-master").
        """
        self.profile_file = profile_file
        self.share_name = share_name
        self.schema_name = schema_name
        self.client = delta_sharing.SharingClient(profile_file)

        self._cache = {}

    def get_country_list(self, sort=True):
        """
        Retrieve a list of available countries in the dataset.

        :param sort: Whether to sort the country list alphabetically (default is True).
        """
        country_list = [
            t.name
            for t in self.client.list_all_tables()
            if t.schema == self.schema_name
        ]
        if sort:
            country_list.sort()
        return country_list

    def load_country_data(self, country, filters=None, use_cache=True):
        """
        Load the dataset for the specified country with optional filtering and caching.

        country: The country code (e.g., "MWI").
        filters: A dictionary with column names as keys and filter values as values.
        use_cache: Whether to use cached data if available (default is True).
        """
        # Check if data is cached
        if use_cache and country in self._cache:
            df_country = self._cache[country]
        else:
            # Load data from the API
            table_url = (
                f"{self.profile_file}#{self.share_name}.{self.schema_name}.{country}"
            )
            df_country = delta_sharing.load_as_pandas(table_url)
            self._cache[country] = df_country  # Cache the data

        # Apply filters if provided
        if filters:
            for column, value in filters.items():
                df_country = df_country[df_country[column] == value]

        return df_country

    def load_multiple_countries(self, countries):
        """
        Load data for multiple countries and combine them into a single DataFrame.

        countries: A list of country codes.
        """
        df_list = []
        for country in countries:
            df_list.append(self.load_country_data(country))
        return pd.concat(df_list, ignore_index=True)

    def get_country_metadata(self, country):
        """
        Retrieve metadata (e.g., column names and data types) for a country's dataset.

        country: The country code (e.g., "MWI").
        """
        df_country = self.load_country_data(country)
        metadata = {
            "columns": df_country.columns.tolist(),
            "data_types": df_country.dtypes.to_dict(),
            "num_records": len(df_country),
        }
        return metadata

    def get_all_cached_data_as_dict(self):
        """
        Retrieve all cached data in a dictionary format, where each key is a country code,
        and the value is the DataFrame of that country.
        """
        return self._cache if self._cache else {}

    def get_all_cached_data_as_json(self):
        """
        Retrieve all cached data in a JSON-like format. Each country is represented as a key,
        and the value is a list of records (i.e., the DataFrame's `to_dict(orient='records')` format).
        """
        if not self._cache:
            return {}

        # Convert each DataFrame in the cache to a JSON-like format (list of records)
        return {
            country: df.to_dict(orient="records") for country, df in self._cache.items()
        }
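
A minimal usage sketch, assuming a valid delta-sharing profile is configured; the filter column name below is illustrative only.

from gigaspatial.core.io.data_api import GigaDataAPI

api = GigaDataAPI()  # uses the configured profile file, share and schema

countries = api.get_country_list()          # sorted list of country tables

# Load one country; the filter column name is illustrative.
df_mwi = api.load_country_data("MWI", filters={"education_level": "Primary"})

# Load and concatenate several countries (each is cached after first load).
df_multi = api.load_multiple_countries(["MWI", "KEN"])

meta = api.get_country_metadata("MWI")      # columns, dtypes, record count
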
__init__(profile_file=config.API_PROFILE_FILE_PATH, share_name=config.API_SHARE_NAME, schema_name=config.API_SCHEMA_NAME)

Initialize the GigaDataAPI class with the profile file, share name, and schema name.

profile_file: Path to the delta-sharing profile file.
share_name: Name of the share (e.g., "gold").
schema_name: Name of the schema (e.g., "school-master").

Source code in gigaspatial/core/io/data_api.py
def __init__(
    self,
    profile_file: Union[str, Path] = config.API_PROFILE_FILE_PATH,
    share_name: str = config.API_SHARE_NAME,
    schema_name: str = config.API_SCHEMA_NAME,
):
    """
    Initialize the GigaDataAPI class with the profile file, share name, and schema name.

    profile_file: Path to the delta-sharing profile file.
    share_name: Name of the share (e.g., "gold").
    schema_name: Name of the schema (e.g., "school-master").
    """
    self.profile_file = profile_file
    self.share_name = share_name
    self.schema_name = schema_name
    self.client = delta_sharing.SharingClient(profile_file)

    self._cache = {}
get_all_cached_data_as_dict()

Retrieve all cached data in a dictionary format, where each key is a country code, and the value is the DataFrame of that country.

Source code in gigaspatial/core/io/data_api.py
def get_all_cached_data_as_dict(self):
    """
    Retrieve all cached data in a dictionary format, where each key is a country code,
    and the value is the DataFrame of that country.
    """
    return self._cache if self._cache else {}
get_all_cached_data_as_json()

Retrieve all cached data in a JSON-like format. Each country is represented as a key, and the value is a list of records (i.e., the DataFrame's to_dict(orient='records') format).

Source code in gigaspatial/core/io/data_api.py
def get_all_cached_data_as_json(self):
    """
    Retrieve all cached data in a JSON-like format. Each country is represented as a key,
    and the value is a list of records (i.e., the DataFrame's `to_dict(orient='records')` format).
    """
    if not self._cache:
        return {}

    # Convert each DataFrame in the cache to a JSON-like format (list of records)
    return {
        country: df.to_dict(orient="records") for country, df in self._cache.items()
    }
get_country_list(sort=True)

Retrieve a list of available countries in the dataset.

:param sort: Whether to sort the country list alphabetically (default is True).

Source code in gigaspatial/core/io/data_api.py
def get_country_list(self, sort=True):
    """
    Retrieve a list of available countries in the dataset.

    :param sort: Whether to sort the country list alphabetically (default is True).
    """
    country_list = [
        t.name
        for t in self.client.list_all_tables()
        if t.schema == self.schema_name
    ]
    if sort:
        country_list.sort()
    return country_list
get_country_metadata(country)

Retrieve metadata (e.g., column names and data types) for a country's dataset.

country: The country code (e.g., "MWI").

Source code in gigaspatial/core/io/data_api.py
def get_country_metadata(self, country):
    """
    Retrieve metadata (e.g., column names and data types) for a country's dataset.

    country: The country code (e.g., "MWI").
    """
    df_country = self.load_country_data(country)
    metadata = {
        "columns": df_country.columns.tolist(),
        "data_types": df_country.dtypes.to_dict(),
        "num_records": len(df_country),
    }
    return metadata
load_country_data(country, filters=None, use_cache=True)

Load the dataset for the specified country with optional filtering and caching.

country: The country code (e.g., "MWI").
filters: A dictionary with column names as keys and filter values as values.
use_cache: Whether to use cached data if available (default is True).

Source code in gigaspatial/core/io/data_api.py
def load_country_data(self, country, filters=None, use_cache=True):
    """
    Load the dataset for the specified country with optional filtering and caching.

    country: The country code (e.g., "MWI").
    filters: A dictionary with column names as keys and filter values as values.
    use_cache: Whether to use cached data if available (default is True).
    """
    # Check if data is cached
    if use_cache and country in self._cache:
        df_country = self._cache[country]
    else:
        # Load data from the API
        table_url = (
            f"{self.profile_file}#{self.share_name}.{self.schema_name}.{country}"
        )
        df_country = delta_sharing.load_as_pandas(table_url)
        self._cache[country] = df_country  # Cache the data

    # Apply filters if provided
    if filters:
        for column, value in filters.items():
            df_country = df_country[df_country[column] == value]

    return df_country
load_multiple_countries(countries)

Load data for multiple countries and combine them into a single DataFrame.

countries: A list of country codes.

Source code in gigaspatial/core/io/data_api.py
def load_multiple_countries(self, countries):
    """
    Load data for multiple countries and combine them into a single DataFrame.

    countries: A list of country codes.
    """
    df_list = []
    for country in countries:
        df_list.append(self.load_country_data(country))
    return pd.concat(df_list, ignore_index=True)

data_store

DataStore

Bases: ABC

Abstract base class defining the interface for data store implementations. This class serves as a parent for both local and cloud-based storage solutions.

Source code in gigaspatial/core/io/data_store.py
class DataStore(ABC):
    """
    Abstract base class defining the interface for data store implementations.
    This class serves as a parent for both local and cloud-based storage solutions.
    """

    @abstractmethod
    def read_file(self, path: str) -> Any:
        """
        Read contents of a file from the data store.

        Args:
            path: Path to the file to read

        Returns:
            Contents of the file

        Raises:
            IOError: If file cannot be read
        """
        pass

    @abstractmethod
    def write_file(self, path: str, data: Any) -> None:
        """
        Write data to a file in the data store.

        Args:
            path: Path where to write the file
            data: Data to write to the file

        Raises:
            IOError: If file cannot be written
        """
        pass

    @abstractmethod
    def file_exists(self, path: str) -> bool:
        """
        Check if a file exists in the data store.

        Args:
            path: Path to check

        Returns:
            True if file exists, False otherwise
        """
        pass

    @abstractmethod
    def list_files(self, path: str) -> List[str]:
        """
        List all files in a directory.

        Args:
            path: Directory path to list

        Returns:
            List of file paths in the directory
        """
        pass

    @abstractmethod
    def walk(self, top: str) -> Generator:
        """
        Walk through directory tree, similar to os.walk().

        Args:
            top: Starting directory for the walk

        Returns:
            Generator yielding tuples of (dirpath, dirnames, filenames)
        """
        pass

    @abstractmethod
    def open(self, file: str, mode: str = "r") -> Union[str, bytes]:
        """
        Context manager for file operations.

        Args:
            file: Path to the file
            mode: File mode ('r', 'w', 'rb', 'wb')

        Yields:
            File-like object

        Raises:
            IOError: If file cannot be opened
        """
        pass

    @abstractmethod
    def is_file(self, path: str) -> bool:
        """
        Check if path points to a file.

        Args:
            path: Path to check

        Returns:
            True if path is a file, False otherwise
        """
        pass

    @abstractmethod
    def is_dir(self, path: str) -> bool:
        """
        Check if path points to a directory.

        Args:
            path: Path to check

        Returns:
            True if path is a directory, False otherwise
        """
        pass

    @abstractmethod
    def remove(self, path: str) -> None:
        """
        Remove a file.

        Args:
            path: Path to the file to remove

        Raises:
            IOError: If file cannot be removed
        """
        pass

    @abstractmethod
    def rmdir(self, dir: str) -> None:
        """
        Remove a directory and all its contents.

        Args:
            dir: Path to the directory to remove

        Raises:
            IOError: If directory cannot be removed
        """
        pass
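
To illustrate the contract, here is a toy in-memory subclass (not part of the library); every abstract method must be implemented before the class can be instantiated.

import io
from typing import Any, Generator, List

from gigaspatial.core.io.data_store import DataStore


class InMemoryDataStore(DataStore):
    """Toy DataStore keeping file contents in a dict, keyed by path."""

    def __init__(self):
        self._files = {}

    def read_file(self, path: str) -> Any:
        try:
            return self._files[path]
        except KeyError as e:
            raise IOError(f"Cannot read {path}") from e

    def write_file(self, path: str, data: Any) -> None:
        self._files[path] = data

    def file_exists(self, path: str) -> bool:
        return path in self._files

    def list_files(self, path: str) -> List[str]:
        prefix = path.rstrip("/") + "/" if path else ""
        return [p for p in self._files if p.startswith(prefix)]

    def walk(self, top: str) -> Generator:
        yield top, [], self.list_files(top)

    def open(self, file: str, mode: str = "r"):
        data = self.read_file(file)
        return io.BytesIO(data) if "b" in mode else io.StringIO(data)

    def is_file(self, path: str) -> bool:
        return path in self._files

    def is_dir(self, path: str) -> bool:
        prefix = path.rstrip("/") + "/"
        return any(p.startswith(prefix) for p in self._files)

    def remove(self, path: str) -> None:
        self._files.pop(path, None)

    def rmdir(self, dir: str) -> None:
        for p in self.list_files(dir):
            del self._files[p]


store = InMemoryDataStore()
store.write_file("demo/a.txt", "hello")
assert store.read_file("demo/a.txt") == "hello"
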
file_exists(path) abstractmethod

Check if a file exists in the data store.

Parameters:

Name Type Description Default
path str

Path to check

required

Returns:

Type Description
bool

True if file exists, False otherwise

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def file_exists(self, path: str) -> bool:
    """
    Check if a file exists in the data store.

    Args:
        path: Path to check

    Returns:
        True if file exists, False otherwise
    """
    pass
is_dir(path) abstractmethod

Check if path points to a directory.

Parameters:

Name Type Description Default
path str

Path to check

required

Returns:

Type Description
bool

True if path is a directory, False otherwise

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def is_dir(self, path: str) -> bool:
    """
    Check if path points to a directory.

    Args:
        path: Path to check

    Returns:
        True if path is a directory, False otherwise
    """
    pass
is_file(path) abstractmethod

Check if path points to a file.

Parameters:

Name Type Description Default
path str

Path to check

required

Returns:

Type Description
bool

True if path is a file, False otherwise

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def is_file(self, path: str) -> bool:
    """
    Check if path points to a file.

    Args:
        path: Path to check

    Returns:
        True if path is a file, False otherwise
    """
    pass
list_files(path) abstractmethod

List all files in a directory.

Parameters:

Name Type Description Default
path str

Directory path to list

required

Returns:

Type Description
List[str]

List of file paths in the directory

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def list_files(self, path: str) -> List[str]:
    """
    List all files in a directory.

    Args:
        path: Directory path to list

    Returns:
        List of file paths in the directory
    """
    pass
open(file, mode='r') abstractmethod

Context manager for file operations.

Parameters:

Name Type Description Default
file str

Path to the file

required
mode str

File mode ('r', 'w', 'rb', 'wb')

'r'

Yields:

Type Description
Union[str, bytes]

File-like object

Raises:

Type Description
IOError

If file cannot be opened

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def open(self, file: str, mode: str = "r") -> Union[str, bytes]:
    """
    Context manager for file operations.

    Args:
        file: Path to the file
        mode: File mode ('r', 'w', 'rb', 'wb')

    Yields:
        File-like object

    Raises:
        IOError: If file cannot be opened
    """
    pass
read_file(path) abstractmethod

Read contents of a file from the data store.

Parameters:

Name Type Description Default
path str

Path to the file to read

required

Returns:

Type Description
Any

Contents of the file

Raises:

Type Description
IOError

If file cannot be read

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def read_file(self, path: str) -> Any:
    """
    Read contents of a file from the data store.

    Args:
        path: Path to the file to read

    Returns:
        Contents of the file

    Raises:
        IOError: If file cannot be read
    """
    pass
remove(path) abstractmethod

Remove a file.

Parameters:

Name Type Description Default
path str

Path to the file to remove

required

Raises:

Type Description
IOError

If file cannot be removed

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def remove(self, path: str) -> None:
    """
    Remove a file.

    Args:
        path: Path to the file to remove

    Raises:
        IOError: If file cannot be removed
    """
    pass
rmdir(dir) abstractmethod

Remove a directory and all its contents.

Parameters:

Name Type Description Default
dir str

Path to the directory to remove

required

Raises:

Type Description
IOError

If directory cannot be removed

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def rmdir(self, dir: str) -> None:
    """
    Remove a directory and all its contents.

    Args:
        dir: Path to the directory to remove

    Raises:
        IOError: If directory cannot be removed
    """
    pass
walk(top) abstractmethod

Walk through directory tree, similar to os.walk().

Parameters:

Name Type Description Default
top str

Starting directory for the walk

required

Returns:

Type Description
Generator

Generator yielding tuples of (dirpath, dirnames, filenames)

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def walk(self, top: str) -> Generator:
    """
    Walk through directory tree, similar to os.walk().

    Args:
        top: Starting directory for the walk

    Returns:
        Generator yielding tuples of (dirpath, dirnames, filenames)
    """
    pass
write_file(path, data) abstractmethod

Write data to a file in the data store.

Parameters:

Name Type Description Default
path str

Path where to write the file

required
data Any

Data to write to the file

required

Raises:

Type Description
IOError

If file cannot be written

Source code in gigaspatial/core/io/data_store.py
@abstractmethod
def write_file(self, path: str, data: Any) -> None:
    """
    Write data to a file in the data store.

    Args:
        path: Path where to write the file
        data: Data to write to the file

    Raises:
        IOError: If file cannot be written
    """
    pass

database

DBConnection

A unified database connection class supporting both Trino and PostgreSQL.

Source code in gigaspatial/core/io/database.py
class DBConnection:
    """
    A unified database connection class supporting both Trino and PostgreSQL.
    """

    DB_CONFIG = global_config.DB_CONFIG or {}

    def __init__(
        self,
        db_type: Literal["postgresql", "trino"] = DB_CONFIG.get(
            "db_type", "postgresql"
        ),
        host: Optional[str] = DB_CONFIG.get("host", None),
        port: Union[int, str] = DB_CONFIG.get("port", None),  # type: ignore
        user: Optional[str] = DB_CONFIG.get("user", None),
        password: Optional[str] = DB_CONFIG.get("password", None),
        catalog: Optional[str] = DB_CONFIG.get("catalog", None),  # For Trino
        database: Optional[str] = DB_CONFIG.get("database", None),  # For PostgreSQL
        schema: str = DB_CONFIG.get("schema", "public"),  # Default for PostgreSQL
        http_scheme: str = DB_CONFIG.get("http_scheme", "https"),  # For Trino
        sslmode: str = DB_CONFIG.get("sslmode", "require"),  # For PostgreSQL
        **kwargs,
    ):
        """
        Initialize a database connection for either Trino or PostgreSQL.

        Args:
            db_type: Either "trino" or "postgresql"
            host: Database server host
            port: Database server port
            user: Username
            password: Password
            catalog: Trino catalog name
            database: PostgreSQL database name
            schema: Default schema name
            http_scheme: For Trino ("http" or "https")
            sslmode: For PostgreSQL (e.g., "require", "verify-full")
            **kwargs: Additional connection parameters
        """
        self.db_type = db_type.lower()
        self.host = host
        self.port = str(port) if port else None
        self.user = user
        self.password = quote_plus(password) if password else None
        self.default_schema = schema

        if self.db_type == "trino":
            self.catalog = catalog
            self.http_scheme = http_scheme
            self.engine = self._create_trino_engine(**kwargs)
        elif self.db_type == "postgresql":
            self.database = database
            self.sslmode = sslmode
            self.engine = self._create_postgresql_engine(**kwargs)
        else:
            raise ValueError(f"Unsupported database type: {db_type}")

        self._add_event_listener()

    def _create_trino_engine(self, **kwargs) -> Engine:
        """Create a Trino SQLAlchemy engine."""
        self._connection_string = (
            f"trino://{self.user}:{self.password}@{self.host}:{self.port}/"
            f"{self.catalog}/{self.default_schema}"
        )
        return create_engine(
            self._connection_string,
            connect_args={"http_scheme": self.http_scheme},
            **kwargs,
        )

    def _create_postgresql_engine(self, **kwargs) -> Engine:
        """Create a PostgreSQL SQLAlchemy engine."""
        self._connection_string = (
            f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/"
            f"{self.database}?sslmode={self.sslmode}"
        )
        return create_engine(self._connection_string, **kwargs)

    def _add_event_listener(self):
        """Add event listeners for schema setting."""
        if self.db_type == "trino":

            @event.listens_for(self.engine, "connect", insert=True)
            def set_current_schema(dbapi_connection, connection_record):
                cursor_obj = dbapi_connection.cursor()
                try:
                    cursor_obj.execute(f"USE {self.default_schema}")
                except Exception as e:
                    warnings.warn(f"Could not set schema to {self.default_schema}: {e}")
                finally:
                    cursor_obj.close()

    def get_connection_string(self) -> str:
        """
        Returns the connection string used to create the engine.

        Returns:
            str: The connection string.
        """
        return self._connection_string

    def get_schema_names(self) -> List[str]:
        """Get list of all schema names."""
        inspector = inspect(self.engine)
        return inspector.get_schema_names()

    def get_table_names(self, schema: Optional[str] = None) -> List[str]:
        """Get list of table names in a schema."""
        schema = schema or self.default_schema
        inspector = inspect(self.engine)
        return inspector.get_table_names(schema=schema)

    def get_view_names(self, schema: Optional[str] = None) -> List[str]:
        """Get list of view names in a schema."""
        schema = schema or self.default_schema
        inspector = inspect(self.engine)
        return inspector.get_view_names(schema=schema)

    def get_column_names(
        self, table_name: str, schema: Optional[str] = None
    ) -> List[str]:
        """Get column names for a specific table."""
        if "." in table_name:
            schema, table_name = table_name.split(".")
        else:
            schema = schema or self.default_schema

        inspector = inspect(self.engine)
        columns = inspector.get_columns(table_name, schema=schema)
        return [col["name"] for col in columns]

    def get_table_info(
        self, table_name: str, schema: Optional[str] = None
    ) -> List[Dict]:
        """Get detailed column information for a table."""
        if "." in table_name:
            schema, table_name = table_name.split(".")
        else:
            schema = schema or self.default_schema

        inspector = inspect(self.engine)
        return inspector.get_columns(table_name, schema=schema)

    def get_primary_keys(
        self, table_name: str, schema: Optional[str] = None
    ) -> List[str]:
        """Get primary key columns for a table."""
        if "." in table_name:
            schema, table_name = table_name.split(".")
        else:
            schema = schema or self.default_schema

        inspector = inspect(self.engine)
        try:
            return inspector.get_pk_constraint(table_name, schema=schema)[
                "constrained_columns"
            ]
        except:
            return []  # Some databases may not support PK constraints

    def table_exists(self, table_name: str, schema: Optional[str] = None) -> bool:
        """Check if a table exists."""
        if "." in table_name:
            schema, table_name = table_name.split(".")
        else:
            schema = schema or self.default_schema

        return table_name in self.get_table_names(schema=schema)

    # PostgreSQL-specific methods
    def get_extensions(self) -> List[str]:
        """Get list of installed PostgreSQL extensions (PostgreSQL only)."""
        if self.db_type != "postgresql":
            raise NotImplementedError(
                "This method is only available for PostgreSQL connections"
            )

        with self.engine.connect() as conn:
            result = conn.execute("SELECT extname FROM pg_extension")
            return [row[0] for row in result]

    def execute_query(
        self, query: str, fetch_results: bool = True, params: Optional[Dict] = None
    ) -> Union[List[tuple], None]:
        """
        Executes a SQL query (works for both PostgreSQL and Trino).

        Args:
            query: SQL query to execute
            fetch_results: Whether to fetch results
            params: Parameters for parameterized queries

        Returns:
            Results as list of tuples or None
        """
        try:
            with self.engine.connect() as connection:
                stmt = text(query)
                result = (
                    connection.execute(stmt, params)
                    if params
                    else connection.execute(stmt)
                )

                if fetch_results and result.returns_rows:
                    return result.fetchall()
                return None
        except SQLAlchemyError as e:
            print(f"Error executing query: {e}")
            raise

    def test_connection(self) -> bool:
        """
        Tests the database connection (works for both PostgreSQL and Trino).

        Returns:
            True if connection successful, False otherwise
        """
        test_query = (
            "SELECT 1"
            if self.db_type == "postgresql"
            else "SELECT 1 AS connection_test"
        )

        try:
            print(
                f"Attempting to connect to {self.db_type} at {self.host}:{self.port}..."
            )
            with self.engine.connect() as conn:
                conn.execute(text(test_query))
            print(f"Successfully connected to {self.db_type.upper()}.")
            return True
        except Exception as e:
            print(f"Failed to connect to {self.db_type.upper()}: {e}")
            return False

    def read_sql_to_dataframe(
        self, query: str, params: Optional[Dict] = None
    ) -> pd.DataFrame:
        """
        Executes query and returns results as pandas DataFrame (works for both).

        Args:
            query: SQL query to execute
            params: Parameters for parameterized queries

        Returns:
            pandas DataFrame with results
        """
        try:
            with self.engine.connect() as connection:
                return pd.read_sql_query(text(query), connection, params=params)
        except SQLAlchemyError as e:
            print(f"Error reading SQL to DataFrame: {e}")
            raise

    def read_sql_to_dask_dataframe(
        self,
        table_name: str,
        index_col: str,
        columns: Optional[List[str]] = None,
        limit: Optional[int] = None,
        **kwargs,
    ) -> pd.DataFrame:
        """
        Reads data to Dask DataFrame (works for both, but connection string differs).

        Args:
            table_name: Table name (schema.table or just table)
            columns: List of columns to select
            limit: Maximum rows to return
            **kwargs: Additional arguments

        Returns:
            Dask DataFrame with results
        """
        try:
            connection_string = self.get_connection_string()

            # Handle schema.table format
            if "." in table_name:
                schema, table = table_name.split(".")
            else:
                schema = self.default_schema
                table = table_name

            metadata = MetaData()
            table_obj = Table(table, metadata, schema=schema, autoload_with=self.engine)

            # Build query
            query = (
                select(*[table_obj.c[col] for col in columns])
                if columns
                else select(table_obj)
            )
            if limit:
                query = query.limit(limit)

            return dd.read_sql_query(
                sql=query, con=connection_string, index_col=index_col, **kwargs
            )
        except Exception as e:
            print(f"Error reading SQL to Dask DataFrame: {e}")
            raise ValueError(f"Failed to read SQL to Dask DataFrame: {e}") from e
__init__(db_type=DB_CONFIG.get('db_type', 'postgresql'), host=DB_CONFIG.get('host', None), port=DB_CONFIG.get('port', None), user=DB_CONFIG.get('user', None), password=DB_CONFIG.get('password', None), catalog=DB_CONFIG.get('catalog', None), database=DB_CONFIG.get('database', None), schema=DB_CONFIG.get('schema', 'public'), http_scheme=DB_CONFIG.get('http_scheme', 'https'), sslmode=DB_CONFIG.get('sslmode', 'require'), **kwargs)

Initialize a database connection for either Trino or PostgreSQL.

Parameters:

Name Type Description Default
db_type Literal['postgresql', 'trino']

Either "trino" or "postgresql"

get('db_type', 'postgresql')
host Optional[str]

Database server host

get('host', None)
port Union[int, str]

Database server port

get('port', None)
user Optional[str]

Username

get('user', None)
password Optional[str]

Password

get('password', None)
catalog Optional[str]

Trino catalog name

get('catalog', None)
database Optional[str]

PostgreSQL database name

get('database', None)
schema str

Default schema name

get('schema', 'public')
http_scheme str

For Trino ("http" or "https")

get('http_scheme', 'https')
sslmode str

For PostgreSQL (e.g., "require", "verify-full")

get('sslmode', 'require')
**kwargs

Additional connection parameters

{}
Source code in gigaspatial/core/io/database.py
def __init__(
    self,
    db_type: Literal["postgresql", "trino"] = DB_CONFIG.get(
        "db_type", "postgresql"
    ),
    host: Optional[str] = DB_CONFIG.get("host", None),
    port: Union[int, str] = DB_CONFIG.get("port", None),  # type: ignore
    user: Optional[str] = DB_CONFIG.get("user", None),
    password: Optional[str] = DB_CONFIG.get("password", None),
    catalog: Optional[str] = DB_CONFIG.get("catalog", None),  # For Trino
    database: Optional[str] = DB_CONFIG.get("database", None),  # For PostgreSQL
    schema: str = DB_CONFIG.get("schema", "public"),  # Default for PostgreSQL
    http_scheme: str = DB_CONFIG.get("http_scheme", "https"),  # For Trino
    sslmode: str = DB_CONFIG.get("sslmode", "require"),  # For PostgreSQL
    **kwargs,
):
    """
    Initialize a database connection for either Trino or PostgreSQL.

    Args:
        db_type: Either "trino" or "postgresql"
        host: Database server host
        port: Database server port
        user: Username
        password: Password
        catalog: Trino catalog name
        database: PostgreSQL database name
        schema: Default schema name
        http_scheme: For Trino ("http" or "https")
        sslmode: For PostgreSQL (e.g., "require", "verify-full")
        **kwargs: Additional connection parameters
    """
    self.db_type = db_type.lower()
    self.host = host
    self.port = str(port) if port else None
    self.user = user
    self.password = quote_plus(password) if password else None
    self.default_schema = schema

    if self.db_type == "trino":
        self.catalog = catalog
        self.http_scheme = http_scheme
        self.engine = self._create_trino_engine(**kwargs)
    elif self.db_type == "postgresql":
        self.database = database
        self.sslmode = sslmode
        self.engine = self._create_postgresql_engine(**kwargs)
    else:
        raise ValueError(f"Unsupported database type: {db_type}")

    self._add_event_listener()
execute_query(query, fetch_results=True, params=None)

Executes a SQL query (works for both PostgreSQL and Trino).

Parameters:

Name Type Description Default
query str

SQL query to execute

required
fetch_results bool

Whether to fetch results

True
params Optional[Dict]

Parameters for parameterized queries

None

Returns:

Type Description
Union[List[tuple], None]

Results as list of tuples or None

Source code in gigaspatial/core/io/database.py
def execute_query(
    self, query: str, fetch_results: bool = True, params: Optional[Dict] = None
) -> Union[List[tuple], None]:
    """
    Executes a SQL query (works for both PostgreSQL and Trino).

    Args:
        query: SQL query to execute
        fetch_results: Whether to fetch results
        params: Parameters for parameterized queries

    Returns:
        Results as list of tuples or None
    """
    try:
        with self.engine.connect() as connection:
            stmt = text(query)
            result = (
                connection.execute(stmt, params)
                if params
                else connection.execute(stmt)
            )

            if fetch_results and result.returns_rows:
                return result.fetchall()
            return None
    except SQLAlchemyError as e:
        print(f"Error executing query: {e}")
        raise
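
For example, reusing a db connection as constructed above; the table name is illustrative and assumes the connected engine allows table creation.

# Statement without results: skip fetching.
db.execute_query(
    "CREATE TABLE IF NOT EXISTS tmp_demo (id INT, name VARCHAR)",
    fetch_results=False,
)

# Parameterized SELECT using SQLAlchemy-style named parameters.
rows = db.execute_query(
    "SELECT id, name FROM tmp_demo WHERE id = :id",
    params={"id": 1},
)
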
get_column_names(table_name, schema=None)

Get column names for a specific table.

Source code in gigaspatial/core/io/database.py
def get_column_names(
    self, table_name: str, schema: Optional[str] = None
) -> List[str]:
    """Get column names for a specific table."""
    if "." in table_name:
        schema, table_name = table_name.split(".")
    else:
        schema = schema or self.default_schema

    inspector = inspect(self.engine)
    columns = inspector.get_columns(table_name, schema=schema)
    return [col["name"] for col in columns]
get_connection_string()

Returns the connection string used to create the engine.

Returns:

Name Type Description
str str

The connection string.

Source code in gigaspatial/core/io/database.py
def get_connection_string(self) -> str:
    """
    Returns the connection string used to create the engine.

    Returns:
        str: The connection string.
    """
    return self._connection_string
get_extensions()

Get list of installed PostgreSQL extensions (PostgreSQL only).

Source code in gigaspatial/core/io/database.py
def get_extensions(self) -> List[str]:
    """Get list of installed PostgreSQL extensions (PostgreSQL only)."""
    if self.db_type != "postgresql":
        raise NotImplementedError(
            "This method is only available for PostgreSQL connections"
        )

    with self.engine.connect() as conn:
        result = conn.execute("SELECT extname FROM pg_extension")
        return [row[0] for row in result]
get_primary_keys(table_name, schema=None)

Get primary key columns for a table.

Source code in gigaspatial/core/io/database.py
def get_primary_keys(
    self, table_name: str, schema: Optional[str] = None
) -> List[str]:
    """Get primary key columns for a table."""
    if "." in table_name:
        schema, table_name = table_name.split(".")
    else:
        schema = schema or self.default_schema

    inspector = inspect(self.engine)
    try:
        return inspector.get_pk_constraint(table_name, schema=schema)[
            "constrained_columns"
        ]
    except:
        return []  # Some databases may not support PK constraints
get_schema_names()

Get list of all schema names.

Source code in gigaspatial/core/io/database.py
def get_schema_names(self) -> List[str]:
    """Get list of all schema names."""
    inspector = inspect(self.engine)
    return inspector.get_schema_names()
get_table_info(table_name, schema=None)

Get detailed column information for a table.

Source code in gigaspatial/core/io/database.py
def get_table_info(
    self, table_name: str, schema: Optional[str] = None
) -> List[Dict]:
    """Get detailed column information for a table."""
    if "." in table_name:
        schema, table_name = table_name.split(".")
    else:
        schema = schema or self.default_schema

    inspector = inspect(self.engine)
    return inspector.get_columns(table_name, schema=schema)
get_table_names(schema=None)

Get list of table names in a schema.

Source code in gigaspatial/core/io/database.py
def get_table_names(self, schema: Optional[str] = None) -> List[str]:
    """Get list of table names in a schema."""
    schema = schema or self.default_schema
    inspector = inspect(self.engine)
    return inspector.get_table_names(schema=schema)
get_view_names(schema=None)

Get list of view names in a schema.

Source code in gigaspatial/core/io/database.py
def get_view_names(self, schema: Optional[str] = None) -> List[str]:
    """Get list of view names in a schema."""
    schema = schema or self.default_schema
    inspector = inspect(self.engine)
    return inspector.get_view_names(schema=schema)
read_sql_to_dask_dataframe(table_name, index_col, columns=None, limit=None, **kwargs)

Reads data to Dask DataFrame (works for both, but connection string differs).

Parameters:

Name Type Description Default
table_name str

Table name (schema.table or just table)

required
columns Optional[List[str]]

List of columns to select

None
limit Optional[int]

Maximum rows to return

None
**kwargs

Additional arguments

{}

Returns:

Type Description
DataFrame

Dask DataFrame with results

Source code in gigaspatial/core/io/database.py
def read_sql_to_dask_dataframe(
    self,
    table_name: str,
    index_col: str,
    columns: Optional[List[str]] = None,
    limit: Optional[int] = None,
    **kwargs,
) -> pd.DataFrame:
    """
    Reads data to Dask DataFrame (works for both, but connection string differs).

    Args:
        table_name: Table name (schema.table or just table)
        columns: List of columns to select
        limit: Maximum rows to return
        **kwargs: Additional arguments

    Returns:
        Dask DataFrame with results
    """
    try:
        connection_string = self.get_connection_string()

        # Handle schema.table format
        if "." in table_name:
            schema, table = table_name.split(".")
        else:
            schema = self.default_schema
            table = table_name

        metadata = MetaData()
        table_obj = Table(table, metadata, schema=schema, autoload_with=self.engine)

        # Build query
        query = (
            select(*[table_obj.c[col] for col in columns])
            if columns
            else select(table_obj)
        )
        if limit:
            query = query.limit(limit)

        return dd.read_sql_query(
            sql=query, con=connection_string, index_col=index_col, **kwargs
        )
    except Exception as e:
        print(f"Error reading SQL to Dask DataFrame: {e}")
        raise ValueError(f"Failed to read SQL to Dask DataFrame: {e}") from e
read_sql_to_dataframe(query, params=None)

Executes query and returns results as pandas DataFrame (works for both).

Parameters:

Name Type Description Default
query str

SQL query to execute

required
params Optional[Dict]

Parameters for parameterized queries

None

Returns:

Type Description
DataFrame

pandas DataFrame with results

Source code in gigaspatial/core/io/database.py
def read_sql_to_dataframe(
    self, query: str, params: Optional[Dict] = None
) -> pd.DataFrame:
    """
    Executes query and returns results as pandas DataFrame (works for both).

    Args:
        query: SQL query to execute
        params: Parameters for parameterized queries

    Returns:
        pandas DataFrame with results
    """
    try:
        with self.engine.connect() as connection:
            return pd.read_sql_query(text(query), connection, params=params)
    except SQLAlchemyError as e:
        print(f"Error reading SQL to DataFrame: {e}")
        raise
table_exists(table_name, schema=None)

Check if a table exists.

Source code in gigaspatial/core/io/database.py
def table_exists(self, table_name: str, schema: Optional[str] = None) -> bool:
    """Check if a table exists."""
    if "." in table_name:
        schema, table_name = table_name.split(".")
    else:
        schema = schema or self.default_schema

    return table_name in self.get_table_names(schema=schema)
test_connection()

Tests the database connection (works for both PostgreSQL and Trino).

Returns:

Type Description
bool

True if connection successful, False otherwise

Source code in gigaspatial/core/io/database.py
def test_connection(self) -> bool:
    """
    Tests the database connection (works for both PostgreSQL and Trino).

    Returns:
        True if connection successful, False otherwise
    """
    test_query = (
        "SELECT 1"
        if self.db_type == "postgresql"
        else "SELECT 1 AS connection_test"
    )

    try:
        print(
            f"Attempting to connect to {self.db_type} at {self.host}:{self.port}..."
        )
        with self.engine.connect() as conn:
            conn.execute(text(test_query))
        print(f"Successfully connected to {self.db_type.upper()}.")
        return True
    except Exception as e:
        print(f"Failed to connect to {self.db_type.upper()}: {e}")
        return False

local_data_store

LocalDataStore

Bases: DataStore

Implementation for local filesystem storage.

Source code in gigaspatial/core/io/local_data_store.py
class LocalDataStore(DataStore):
    """Implementation for local filesystem storage."""

    def __init__(self, base_path: Union[str, Path] = ""):
        super().__init__()
        self.base_path = Path(base_path).resolve()

    def _resolve_path(self, path: str) -> Path:
        """Resolve path relative to base directory."""
        return self.base_path / path

    def read_file(self, path: str) -> bytes:
        full_path = self._resolve_path(path)
        with open(full_path, "rb") as f:
            return f.read()

    def write_file(self, path: str, data: Union[bytes, str]) -> None:
        full_path = self._resolve_path(path)
        self.mkdir(str(full_path.parent), exist_ok=True)

        if isinstance(data, str):
            mode = "w"
            encoding = "utf-8"
        else:
            mode = "wb"
            encoding = None

        with open(full_path, mode, encoding=encoding) as f:
            f.write(data)

    def file_exists(self, path: str) -> bool:
        return self._resolve_path(path).is_file()

    def list_files(self, path: str) -> List[str]:
        full_path = self._resolve_path(path)
        return [
            str(f.relative_to(self.base_path))
            for f in full_path.iterdir()
            if f.is_file()
        ]

    def walk(self, top: str) -> Generator[Tuple[str, List[str], List[str]], None, None]:
        full_path = self._resolve_path(top)
        for root, dirs, files in os.walk(full_path):
            rel_root = str(Path(root).relative_to(self.base_path))
            yield rel_root, dirs, files

    def list_directories(self, path: str) -> List[str]:
        full_path = self._resolve_path(path)

        if not full_path.exists():
            return []

        if not full_path.is_dir():
            return []

        return [d.name for d in full_path.iterdir() if d.is_dir()]

    def open(self, path: str, mode: str = "r") -> IO:
        full_path = self._resolve_path(path)
        self.mkdir(str(full_path.parent), exist_ok=True)
        return open(full_path, mode)

    def is_file(self, path: str) -> bool:
        return self._resolve_path(path).is_file()

    def is_dir(self, path: str) -> bool:
        return self._resolve_path(path).is_dir()

    def remove(self, path: str) -> None:
        full_path = self._resolve_path(path)
        if full_path.is_file():
            os.remove(full_path)

    def copy_file(self, src: str, dst: str) -> None:
        """Copy a file from src to dst."""
        src_path = self._resolve_path(src)
        dst_path = self._resolve_path(dst)
        self.mkdir(str(dst_path.parent), exist_ok=True)
        shutil.copy2(src_path, dst_path)

    def rmdir(self, directory: str) -> None:
        full_path = self._resolve_path(directory)
        if full_path.is_dir():
            os.rmdir(full_path)

    def mkdir(self, path: str, exist_ok: bool = False) -> None:
        full_path = self._resolve_path(path)
        full_path.mkdir(parents=True, exist_ok=exist_ok)

    def exists(self, path: str) -> bool:
        return self._resolve_path(path).exists()
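
A quick end-to-end sketch on a temporary base directory (the paths are placeholders):

from gigaspatial.core.io.local_data_store import LocalDataStore

store = LocalDataStore("/tmp/giga-demo")        # every path resolves under this base

store.write_file("raw/notes.txt", "hello")      # parent directories are created
assert store.file_exists("raw/notes.txt")
print(store.read_file("raw/notes.txt"))         # b'hello' -- read_file returns bytes

with store.open("raw/notes.txt", "r") as f:
    print(f.read())                             # 'hello'

store.copy_file("raw/notes.txt", "backup/notes.txt")
store.remove("raw/notes.txt")
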
copy_file(src, dst)

Copy a file from src to dst.

Source code in gigaspatial/core/io/local_data_store.py
def copy_file(self, src: str, dst: str) -> None:
    """Copy a file from src to dst."""
    src_path = self._resolve_path(src)
    dst_path = self._resolve_path(dst)
    self.mkdir(str(dst_path.parent), exist_ok=True)
    shutil.copy2(src_path, dst_path)

readers

read_dataset(data_store, path, compression=None, **kwargs)

Read data from various file formats stored in both local and cloud-based storage.

Parameters:

data_store : DataStore
    Instance of DataStore for accessing data storage.
path : str, Path
    Path to the file in data storage.
**kwargs : dict
    Additional arguments passed to the specific reader function.

Returns:

pandas.DataFrame or geopandas.GeoDataFrame
    The data read from the file.

Raises:

FileNotFoundError
    If the file doesn't exist in blob storage.
ValueError
    If the file type is unsupported or if there's an error reading the file.

Source code in gigaspatial/core/io/readers.py
def read_dataset(data_store: DataStore, path: str, compression: str = None, **kwargs):
    """
    Read data from various file formats stored in both local and cloud-based storage.

    Parameters:
    ----------
    data_store : DataStore
        Instance of DataStore for accessing data storage.
    path : str, Path
        Path to the file in data storage.
    **kwargs : dict
        Additional arguments passed to the specific reader function.

    Returns:
    -------
    pandas.DataFrame or geopandas.GeoDataFrame
        The data read from the file.

    Raises:
    ------
    FileNotFoundError
        If the file doesn't exist in blob storage.
    ValueError
        If the file type is unsupported or if there's an error reading the file.
    """

    # Define supported file formats and their readers
    BINARY_FORMATS = {
        ".shp",
        ".zip",
        ".parquet",
        ".gpkg",
        ".xlsx",
        ".xls",
        ".kmz",
        ".gz",
    }

    PANDAS_READERS = {
        ".csv": pd.read_csv,
        ".xlsx": lambda f, **kw: pd.read_excel(f, engine="openpyxl", **kw),
        ".xls": lambda f, **kw: pd.read_excel(f, engine="xlrd", **kw),
        ".json": pd.read_json,
        # ".gz": lambda f, **kw: pd.read_csv(f, compression="gzip", **kw),
    }

    GEO_READERS = {
        ".shp": gpd.read_file,
        ".zip": gpd.read_file,
        ".geojson": gpd.read_file,
        ".gpkg": gpd.read_file,
        ".parquet": gpd.read_parquet,
        ".kmz": read_kmz,
    }

    COMPRESSION_FORMATS = {
        ".gz": "gzip",
        ".bz2": "bz2",
        ".zip": "zip",
        ".xz": "xz",
    }

    try:
        # Check if file exists
        if not data_store.file_exists(path):
            raise FileNotFoundError(f"File '{path}' not found in blob storage")

        path_obj = Path(path)
        suffixes = path_obj.suffixes
        file_extension = suffixes[-1].lower() if suffixes else ""

        if compression is None and file_extension in COMPRESSION_FORMATS:
            compression_format = COMPRESSION_FORMATS[file_extension]

            # if file has multiple extensions (e.g., .csv.gz), get the inner format
            if len(suffixes) > 1:
                inner_extension = suffixes[-2].lower()

                if inner_extension == ".tar":
                    raise ValueError(
                        "Tar archives (.tar.gz) are not directly supported"
                    )

                if inner_extension in PANDAS_READERS:
                    try:
                        with data_store.open(path, "rb") as f:
                            return PANDAS_READERS[inner_extension](
                                f, compression=compression_format, **kwargs
                            )
                    except Exception as e:
                        raise ValueError(f"Error reading compressed file: {str(e)}")
                elif inner_extension in GEO_READERS:
                    try:
                        with data_store.open(path, "rb") as f:
                            if compression_format == "gzip":
                                import gzip

                                decompressed_data = gzip.decompress(f.read())
                                import io

                                return GEO_READERS[inner_extension](
                                    io.BytesIO(decompressed_data), **kwargs
                                )
                            else:
                                raise ValueError(
                                    f"Compression format {compression_format} not supported for geo data"
                                )
                    except Exception as e:
                        raise ValueError(f"Error reading compressed geo file: {str(e)}")
            else:
                # if just .gz without clear inner type, assume csv
                try:
                    with data_store.open(path, "rb") as f:
                        return pd.read_csv(f, compression=compression_format, **kwargs)
                except Exception as e:
                    raise ValueError(
                        f"Error reading compressed file as CSV: {str(e)}. "
                        f"If not a CSV, specify the format in the filename (e.g., .json.gz)"
                    )

        # Special handling for compressed files
        if file_extension == ".zip":
            # For zip files, we need to use binary mode
            with data_store.open(path, "rb") as f:
                return gpd.read_file(f)

        # Determine if we need binary mode based on file type
        mode = "rb" if file_extension in BINARY_FORMATS else "r"

        # Try reading with appropriate reader
        if file_extension in PANDAS_READERS:
            try:
                with data_store.open(path, mode) as f:
                    return PANDAS_READERS[file_extension](f, **kwargs)
            except Exception as e:
                raise ValueError(f"Error reading file with pandas: {str(e)}")

        if file_extension in GEO_READERS:
            try:
                with data_store.open(path, "rb") as f:
                    return GEO_READERS[file_extension](f, **kwargs)
            except Exception as e:
                # For parquet files, try pandas reader if geopandas fails
                if file_extension == ".parquet":
                    try:
                        with data_store.open(path, "rb") as f:
                            return pd.read_parquet(f, **kwargs)
                    except Exception as e2:
                        raise ValueError(
                            f"Failed to read parquet with both geopandas ({str(e)}) "
                            f"and pandas ({str(e2)})"
                        )
                raise ValueError(f"Error reading file with geopandas: {str(e)}")

        # If we get here, the file type is unsupported
        supported_formats = sorted(set(PANDAS_READERS.keys()) | set(GEO_READERS.keys()))
        supported_compressions = sorted(COMPRESSION_FORMATS.keys())
        raise ValueError(
            f"Unsupported file type: {file_extension}\n"
            f"Supported formats: {', '.join(supported_formats)}"
            f"Supported compressions: {', '.join(supported_compressions)}"
        )

    except Exception as e:
        if isinstance(e, (FileNotFoundError, ValueError)):
            raise
        raise RuntimeError(f"Unexpected error reading dataset: {str(e)}")
read_datasets(data_store, paths, **kwargs)

Read multiple datasets from data storage at once.

Parameters:

data_store : DataStore
    Instance of DataStore for accessing data storage.
paths : list of str
    Paths to files in data storage.
**kwargs : dict
    Additional arguments passed to read_dataset.

Returns:

dict
    Dictionary mapping paths to their corresponding DataFrames/GeoDataFrames.

Source code in gigaspatial/core/io/readers.py
def read_datasets(data_store: DataStore, paths, **kwargs):
    """
    Read multiple datasets from data storage at once.

    Parameters:
    ----------
    data_store : DataStore
        Instance of DataStore for accessing data storage.
    paths : list of str
        Paths to files in data storage.
    **kwargs : dict
        Additional arguments passed to read_dataset.

    Returns:
    -------
    dict
        Dictionary mapping paths to their corresponding DataFrames/GeoDataFrames.
    """
    results = {}
    errors = {}

    for path in paths:
        try:
            results[path] = read_dataset(data_store, path, **kwargs)
        except Exception as e:
            errors[path] = str(e)

    if errors:
        error_msg = "\n".join(f"- {path}: {error}" for path, error in errors.items())
        raise ValueError(f"Errors reading datasets:\n{error_msg}")

    return results
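
For example, reusing the store and placeholder paths from the previous sketch:

from gigaspatial.core.io.readers import read_datasets

frames = read_datasets(
    store,
    ["tables/schools.csv", "boundaries/admin1.geojson"],
)
schools = frames["tables/schools.csv"]        # the dict is keyed by the input path
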
read_gzipped_json_or_csv(file_path, data_store)

Reads a gzipped file, attempting to parse it as JSON (lines=True) or CSV.

Source code in gigaspatial/core/io/readers.py
def read_gzipped_json_or_csv(file_path, data_store):
    """Reads a gzipped file, attempting to parse it as JSON (lines=True) or CSV."""

    with data_store.open(file_path, "rb") as f:
        g = gzip.GzipFile(fileobj=f)
        text = g.read().decode("utf-8")
        try:
            df = pd.read_json(io.StringIO(text), lines=True)
            return df
        except json.JSONDecodeError:
            try:
                df = pd.read_csv(io.StringIO(text))
                return df
            except pd.errors.ParserError:
                print(f"Error: Could not parse {file_path} as JSON or CSV.")
                return None
read_kmz(file_obj, **kwargs)

Helper function to read KMZ files and return a GeoDataFrame.

Source code in gigaspatial/core/io/readers.py
def read_kmz(file_obj, **kwargs):
    """Helper function to read KMZ files and return a GeoDataFrame."""
    try:
        with zipfile.ZipFile(file_obj) as kmz:
            # Find the KML file in the archive (usually doc.kml)
            kml_filename = next(
                name for name in kmz.namelist() if name.endswith(".kml")
            )

            # Read the KML content
            kml_content = io.BytesIO(kmz.read(kml_filename))

            gdf = gpd.read_file(kml_content)

            # Validate the GeoDataFrame
            if gdf.empty:
                raise ValueError(
                    "The KML file is empty or does not contain valid geospatial data."
                )

        return gdf

    except zipfile.BadZipFile:
        raise ValueError("The provided file is not a valid KMZ file.")
    except StopIteration:
        raise ValueError("No KML file found in the KMZ archive.")
    except Exception as e:
        raise RuntimeError(f"An error occurred: {e}")
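
Usage sketch, assuming data_store is a DataStore holding the KMZ and that geopandas has a KML-capable driver installed; the path is illustrative:

with data_store.open("overlays/schools.kmz", "rb") as f:
    schools_gdf = read_kmz(f)

print(schools_gdf.head())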

snowflake_data_store

SnowflakeDataStore

Bases: DataStore

An implementation of DataStore for Snowflake internal stages. Uses Snowflake stages for file storage and retrieval.

Source code in gigaspatial/core/io/snowflake_data_store.py
class SnowflakeDataStore(DataStore):
    """
    An implementation of DataStore for Snowflake internal stages.
    Uses Snowflake stages for file storage and retrieval.
    """

    def __init__(
        self,
        account: str = config.SNOWFLAKE_ACCOUNT,
        user: str = config.SNOWFLAKE_USER,
        password: str = config.SNOWFLAKE_PASSWORD,
        warehouse: str = config.SNOWFLAKE_WAREHOUSE,
        database: str = config.SNOWFLAKE_DATABASE,
        schema: str = config.SNOWFLAKE_SCHEMA,
        stage_name: str = config.SNOWFLAKE_STAGE_NAME,
    ):
        """
        Create a new instance of SnowflakeDataStore.

        :param account: Snowflake account identifier
        :param user: Snowflake username
        :param password: Snowflake password
        :param warehouse: Snowflake warehouse name
        :param database: Snowflake database name
        :param schema: Snowflake schema name
        :param stage_name: Name of the Snowflake stage to use for file storage
        """
        if not all([account, user, password, warehouse, database, schema, stage_name]):
            raise ValueError(
                "Snowflake connection parameters (account, user, password, warehouse, "
                "database, schema, stage_name) must be provided via config or constructor."
            )

        self.account = account
        self.user = user
        self.password = password
        self.warehouse = warehouse
        self.database = database
        self.schema = schema
        self.stage_name = stage_name

        # Create connection
        self.connection = self._create_connection()
        self.logger = config.get_logger(self.__class__.__name__)

        # Temporary directory for file operations
        self._temp_dir = tempfile.mkdtemp()

    def _create_connection(self):
        """Create and return a Snowflake connection."""
        conn_params = {
            "account": self.account,
            "user": self.user,
            "password": self.password,
            "warehouse": self.warehouse,
            "database": self.database,
            "schema": self.schema,
        }

        connection = snowflake.connector.connect(**conn_params)

        # Explicitly set the database and schema context
        # This ensures the session knows which database/schema to use
        cursor = connection.cursor()
        try:
            # Use database first
            cursor.execute(f'USE DATABASE "{self.database}"')
            # Then use schema (don't need to specify database again)
            cursor.execute(f'USE SCHEMA "{self.schema}"')
            cursor.close()
        except Exception as e:
            cursor.close()
            connection.close()
            error_msg = (
                f"Failed to set database/schema context: {e}\n"
                f"Make sure the database '{self.database}' and schema '{self.schema}' exist.\n"
                f"You may need to run the setup_snowflake_test.sql script first.\n"
                f"Current config - Database: {self.database}, Schema: {self.schema}, Stage: {self.stage_name}"
            )
            raise IOError(error_msg)

        return connection

    def _ensure_connection(self):
        """Ensure the connection is active, reconnect if needed."""
        try:
            self.connection.cursor().execute("SELECT 1")
        except Exception:
            self.connection = self._create_connection()

    def _get_stage_path(self, path: str) -> str:
        """Convert a file path to a Snowflake stage path."""
        # Remove leading/trailing slashes and normalize
        path = path.strip("/")
        # Stage paths use forward slashes and @stage_name/ prefix
        return f"@{self.stage_name}/{path}"

    def _normalize_path(self, path: str) -> str:
        """Normalize path for Snowflake stage operations."""
        return path.strip("/").replace("\\", "/")

    def read_file(self, path: str, encoding: Optional[str] = None) -> Union[str, bytes]:
        """
        Read file from Snowflake stage.

        :param path: Path to the file in the stage
        :param encoding: File encoding (optional)
        :return: File contents as string or bytes
        """
        self._ensure_connection()
        cursor = self.connection.cursor(DictCursor)

        try:
            normalized_path = self._normalize_path(path)
            stage_path = self._get_stage_path(normalized_path)

            # Create temporary directory for download
            temp_download_dir = os.path.join(self._temp_dir, "downloads")
            os.makedirs(temp_download_dir, exist_ok=True)

            # Download file from stage using GET command
            # GET command: GET <stage_path> file://<local_path>
            temp_dir_normalized = temp_download_dir.replace("\\", "/")
            if not temp_dir_normalized.endswith("/"):
                temp_dir_normalized += "/"

            get_command = f"GET {stage_path} 'file://{temp_dir_normalized}'"
            cursor.execute(get_command)

            # Find the downloaded file (Snowflake may add prefixes/suffixes or preserve structure)
            downloaded_files = []
            for root, dirs, files in os.walk(temp_download_dir):
                for f in files:
                    file_path = os.path.join(root, f)
                    # Check if this file matches our expected filename
                    if os.path.basename(normalized_path) in f or normalized_path.endswith(f):
                        downloaded_files.append(file_path)

            if not downloaded_files:
                raise FileNotFoundError(f"File not found in stage: {path}")

            # Read the first matching file
            downloaded_path = downloaded_files[0]
            with open(downloaded_path, "rb") as f:
                data = f.read()

            # Clean up
            os.remove(downloaded_path)
            # Clean up empty directories
            try:
                if os.path.exists(temp_download_dir) and not os.listdir(temp_download_dir):
                    os.rmdir(temp_download_dir)
            except OSError:
                pass

            # Decode if encoding is specified
            if encoding:
                return data.decode(encoding)
            return data

        except Exception as e:
            raise IOError(f"Error reading file {path} from Snowflake stage: {e}")
        finally:
            cursor.close()

    def write_file(self, path: str, data: Union[bytes, str]) -> None:
        """
        Write file to Snowflake stage.

        :param path: Destination path in the stage
        :param data: File contents
        """
        self._ensure_connection()
        cursor = self.connection.cursor()

        try:
            # Convert to bytes if string
            if isinstance(data, str):
                binary_data = data.encode("utf-8")
            elif isinstance(data, bytes):
                binary_data = data
            else:
                raise ValueError('Unsupported data type. Only "bytes" or "string" accepted')

            normalized_path = self._normalize_path(path)

            # Write to temporary file first
            # Use the full path structure for the temp file to preserve directory structure
            temp_file_path = os.path.join(self._temp_dir, normalized_path)
            os.makedirs(os.path.dirname(temp_file_path), exist_ok=True)

            with open(temp_file_path, "wb") as f:
                f.write(binary_data)

            # Upload to stage using PUT command
            # Snowflake PUT requires the local file path and the target stage path
            # Convert Windows paths to Unix-style for Snowflake
            temp_file_normalized = os.path.abspath(temp_file_path).replace("\\", "/")

            # PUT command: PUT 'file://<absolute_local_path>' @stage_name/<path>
            # The file will be stored at the specified path in the stage
            stage_target = f"@{self.stage_name}/"
            if "/" in normalized_path:
                # Include directory structure in stage path
                dir_path = os.path.dirname(normalized_path)
                stage_target = f"@{self.stage_name}/{dir_path}/"

            # Snowflake PUT syntax: PUT 'file://<path>' @stage/path
            put_command = f"PUT 'file://{temp_file_normalized}' {stage_target} OVERWRITE=TRUE AUTO_COMPRESS=FALSE"
            cursor.execute(put_command)

            # Clean up temp file
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)
                # Clean up empty directories if they were created
                try:
                    temp_dir = os.path.dirname(temp_file_path)
                    if temp_dir != self._temp_dir and os.path.exists(temp_dir):
                        os.rmdir(temp_dir)
                except OSError:
                    pass  # Directory not empty or other error, ignore

        except Exception as e:
            raise IOError(f"Error writing file {path} to Snowflake stage: {e}")
        finally:
            cursor.close()

    def upload_file(self, file_path: str, stage_path: str):
        """
        Uploads a single file from local filesystem to Snowflake stage.

        :param file_path: Local file path
        :param stage_path: Destination path in the stage
        """
        try:
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Local file not found: {file_path}")

            # Read the file
            with open(file_path, "rb") as f:
                data = f.read()

            # Write to stage using write_file
            self.write_file(stage_path, data)
            self.logger.info(f"Uploaded {file_path} to {stage_path}")
        except Exception as e:
            self.logger.error(f"Failed to upload {file_path}: {e}")
            raise

    def upload_directory(self, dir_path: str, stage_dir_path: str):
        """
        Uploads all files from a local directory to Snowflake stage.

        :param dir_path: Local directory path
        :param stage_dir_path: Destination directory path in the stage
        """
        if not os.path.isdir(dir_path):
            raise NotADirectoryError(f"Local directory not found: {dir_path}")

        for root, dirs, files in os.walk(dir_path):
            for file in files:
                local_file_path = os.path.join(root, file)
                relative_path = os.path.relpath(local_file_path, dir_path)
                # Normalize path separators for stage
                stage_file_path = os.path.join(stage_dir_path, relative_path).replace("\\", "/")

                self.upload_file(local_file_path, stage_file_path)

    def download_directory(self, stage_dir_path: str, local_dir_path: str):
        """
        Downloads all files from a Snowflake stage directory to a local directory.

        :param stage_dir_path: Source directory path in the stage
        :param local_dir_path: Destination local directory path
        """
        try:
            # Ensure the local directory exists
            os.makedirs(local_dir_path, exist_ok=True)

            # List all files in the stage directory
            files = self.list_files(stage_dir_path)

            for file_path in files:
                # Get the relative path from the stage directory
                if stage_dir_path:
                    if file_path.startswith(stage_dir_path):
                        relative_path = file_path[len(stage_dir_path):].lstrip("/")
                    else:
                        # If file_path doesn't start with stage_dir_path, use it as is
                        relative_path = os.path.basename(file_path)
                else:
                    relative_path = file_path

                # Construct the local file path
                local_file_path = os.path.join(local_dir_path, relative_path)
                # Create directories if needed
                os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

                # Download the file
                data = self.read_file(file_path)
                with open(local_file_path, "wb") as f:
                    if isinstance(data, str):
                        f.write(data.encode("utf-8"))
                    else:
                        f.write(data)

            self.logger.info(f"Downloaded directory {stage_dir_path} to {local_dir_path}")
        except Exception as e:
            self.logger.error(f"Failed to download directory {stage_dir_path}: {e}")
            raise

    def copy_directory(self, source_dir: str, destination_dir: str):
        """
        Copies all files from a source directory to a destination directory within the stage.

        :param source_dir: Source directory path in the stage
        :param destination_dir: Destination directory path in the stage
        """
        try:
            # Normalize directory paths
            source_dir = source_dir.rstrip("/")
            destination_dir = destination_dir.rstrip("/")

            # List all files in the source directory
            files = self.list_files(source_dir)

            for file_path in files:
                # Get relative path from source directory
                if source_dir:
                    if file_path.startswith(source_dir):
                        relative_path = file_path[len(source_dir):].lstrip("/")
                    else:
                        relative_path = os.path.basename(file_path)
                else:
                    relative_path = file_path

                # Construct the destination file path
                if destination_dir:
                    dest_file_path = f"{destination_dir}/{relative_path}".replace("//", "/")
                else:
                    dest_file_path = relative_path

                # Copy each file
                self.copy_file(file_path, dest_file_path, overwrite=True)

            self.logger.info(f"Copied directory from {source_dir} to {destination_dir}")
        except Exception as e:
            self.logger.error(f"Failed to copy directory {source_dir}: {e}")
            raise

    def copy_file(
            self, source_path: str, destination_path: str, overwrite: bool = False
    ):
        """
        Copies a single file within the Snowflake stage.

        :param source_path: Source file path in the stage
        :param destination_path: Destination file path in the stage
        :param overwrite: If True, overwrite the destination file if it already exists
        """
        try:
            if not self.file_exists(source_path):
                raise FileNotFoundError(f"Source file not found: {source_path}")

            if self.file_exists(destination_path) and not overwrite:
                raise FileExistsError(
                    f"Destination file already exists and overwrite is False: {destination_path}"
                )

            # Read from source and write to destination
            data = self.read_file(source_path)
            self.write_file(destination_path, data)

            self.logger.info(f"Copied file from {source_path} to {destination_path}")
        except Exception as e:
            self.logger.error(f"Failed to copy file {source_path}: {e}")
            raise

    def exists(self, path: str) -> bool:
        """Check if a path exists (file or directory)."""
        return self.file_exists(path) or self.is_dir(path)

    def file_exists(self, path: str) -> bool:
        """
        Check if a file exists in the Snowflake stage.

        :param path: Path to check
        :return: True if file exists, False otherwise
        """
        self._ensure_connection()
        cursor = self.connection.cursor(DictCursor)

        try:
            normalized_path = self._normalize_path(path)
            stage_path = self._get_stage_path(normalized_path)

            # List files in stage with the given path pattern
            list_command = f"LIST {stage_path}"
            cursor.execute(list_command)
            results = cursor.fetchall()

            # Check if exact file exists
            for result in results:
                if result["name"].endswith(normalized_path) or result["name"] == stage_path:
                    return True

            return False

        except Exception as e:
            self.logger.warning(f"Error checking file existence {path}: {e}")
            return False
        finally:
            cursor.close()

    def file_size(self, path: str) -> float:
        """
        Get the size of a file in kilobytes.

        :param path: File path in the stage
        :return: File size in kilobytes
        """
        self._ensure_connection()
        cursor = self.connection.cursor(DictCursor)

        try:
            normalized_path = self._normalize_path(path)
            stage_path = self._get_stage_path(normalized_path)

            # LIST command returns file metadata including size
            list_command = f"LIST {stage_path}"
            cursor.execute(list_command)
            results = cursor.fetchall()

            # Find the matching file and get its size
            for result in results:
                file_path = result["name"]
                if normalized_path in file_path.lower() or file_path.endswith(normalized_path):
                    # Size is in bytes, convert to kilobytes
                    size_bytes = result.get("size", 0)
                    size_kb = size_bytes / 1024.0
                    return size_kb

            raise FileNotFoundError(f"File not found: {path}")
        except Exception as e:
            self.logger.error(f"Error getting file size for {path}: {e}")
            raise
        finally:
            cursor.close()

    def list_files(self, path: str) -> List[str]:
        """
        List all files in a directory within the Snowflake stage.

        :param path: Directory path to list
        :return: List of file paths
        """
        self._ensure_connection()
        cursor = self.connection.cursor(DictCursor)

        try:
            normalized_path = self._normalize_path(path)
            stage_path = self._get_stage_path(normalized_path)

            # List files in stage
            list_command = f"LIST {stage_path}"
            cursor.execute(list_command)
            results = cursor.fetchall()

            # Extract file paths relative to the base stage path
            files = []
            for result in results:
                file_path = result["name"]
                # Snowflake LIST returns names in lowercase without @ symbol
                # Remove stage prefix to get relative path
                # Check both @stage_name/ and lowercase stage_name/ formats
                stage_prefixes = [
                    f"@{self.stage_name}/",
                    f"{self.stage_name.lower()}/",
                    f"@{self.stage_name.lower()}/",
                ]

                for prefix in stage_prefixes:
                    if file_path.startswith(prefix):
                        relative_path = file_path[len(prefix):]
                        files.append(relative_path)
                        break
                else:
                    # If no prefix matches, try to extract path after stage name
                    # Sometimes stage name might be in different case
                    stage_name_lower = self.stage_name.lower()
                    if stage_name_lower in file_path.lower():
                        # Find the position after the stage name
                        idx = file_path.lower().find(stage_name_lower)
                        if idx != -1:
                            # Get everything after stage name and '/'
                            after_stage = file_path[idx + len(stage_name_lower):].lstrip("/")
                            if after_stage.startswith(normalized_path):
                                relative_path = after_stage
                                files.append(relative_path)

            return files

        except Exception as e:
            self.logger.warning(f"Error listing files in {path}: {e}")
            return []
        finally:
            cursor.close()

    def walk(self, top: str) -> Generator[Tuple[str, List[str], List[str]], None, None]:
        """
        Walk through directory tree in Snowflake stage, similar to os.walk().

        :param top: Starting directory for the walk
        :return: Generator yielding tuples of (dirpath, dirnames, filenames)
        """
        try:
            normalized_top = self._normalize_path(top)

            # Use list_files to get all files (it handles path parsing correctly)
            all_files = self.list_files(normalized_top)

            # Organize into directory structure
            dirs = {}

            for file_path in all_files:
                # Ensure we're working with paths relative to the top
                if normalized_top and not file_path.startswith(normalized_top):
                    continue

                # Get relative path from top
                if normalized_top and file_path.startswith(normalized_top):
                    relative_path = file_path[len(normalized_top):].lstrip("/")
                else:
                    relative_path = file_path

                if not relative_path:
                    continue

                # Get directory and filename
                if "/" in relative_path:
                    dir_path, filename = os.path.split(relative_path)
                    full_dir_path = f"{normalized_top}/{dir_path}" if normalized_top else dir_path
                    if full_dir_path not in dirs:
                        dirs[full_dir_path] = []
                    dirs[full_dir_path].append(filename)
                else:
                    # File in root of the top directory
                    if normalized_top not in dirs:
                        dirs[normalized_top] = []
                    dirs[normalized_top].append(relative_path)

            # Yield results in os.walk format
            for dir_path, files in dirs.items():
                # Extract subdirectories (simplified - Snowflake stages are flat)
                subdirs = []
                yield (dir_path, subdirs, files)

        except Exception as e:
            self.logger.warning(f"Error walking directory {top}: {e}")
            yield (top, [], [])

    def list_directories(self, path: str) -> List[str]:
        """
        List only directory names (not files) from a given path in the stage.

        :param path: Directory path to list
        :return: List of directory names
        """
        normalized_path = self._normalize_path(path)
        files = self.list_files(normalized_path)

        directories = set()

        for file_path in files:
            # Get relative path from the search path
            if normalized_path:
                if file_path.startswith(normalized_path):
                    relative_path = file_path[len(normalized_path):].lstrip("/")
                else:
                    continue
            else:
                relative_path = file_path

            # Skip if empty
            if not relative_path:
                continue

            # If there's a "/" in the relative path, it means there's a subdirectory
            if "/" in relative_path:
                # Get the first directory name
                dir_name = relative_path.split("/")[0]
                directories.add(dir_name)

        return sorted(list(directories))

    @contextlib.contextmanager
    def open(self, path: str, mode: str = "r"):
        """
        Context manager for file operations.

        :param path: File path in Snowflake stage
        :param mode: File open mode (r, rb, w, wb)
        """
        if mode == "w":
            file = io.StringIO()
            yield file
            self.write_file(path, file.getvalue())

        elif mode == "wb":
            file = io.BytesIO()
            yield file
            self.write_file(path, file.getvalue())

        elif mode == "r":
            data = self.read_file(path, encoding="UTF-8")
            file = io.StringIO(data)
            yield file

        elif mode == "rb":
            data = self.read_file(path)
            file = io.BytesIO(data)
            yield file

        else:
            raise ValueError(f"Unsupported mode: {mode}")

    def get_file_metadata(self, path: str) -> dict:
        """
        Retrieve comprehensive file metadata from Snowflake stage.

        :param path: File path in the stage
        :return: File metadata dictionary
        """
        self._ensure_connection()
        cursor = self.connection.cursor(DictCursor)

        try:
            normalized_path = self._normalize_path(path)
            stage_path = self._get_stage_path(normalized_path)

            # LIST command returns file metadata
            list_command = f"LIST {stage_path}"
            cursor.execute(list_command)
            results = cursor.fetchall()

            # Find the matching file
            for result in results:
                file_path = result["name"]
                if normalized_path in file_path.lower() or file_path.endswith(normalized_path):
                    return {
                        "name": path,
                        "size_bytes": result.get("size", 0),
                        "last_modified": result.get("last_modified"),
                        "md5": result.get("md5"),
                    }

            raise FileNotFoundError(f"File not found: {path}")
        except Exception as e:
            self.logger.error(f"Error getting file metadata for {path}: {e}")
            raise
        finally:
            cursor.close()

    def is_file(self, path: str) -> bool:
        """Check if path points to a file."""
        return self.file_exists(path)

    def is_dir(self, path: str) -> bool:
        """Check if path points to a directory."""
        # First check if it's actually a file (exact match)
        if self.file_exists(path):
            return False

        # In Snowflake stages, directories are conceptual
        # Check if there are files with this path prefix
        normalized_path = self._normalize_path(path)
        files = self.list_files(normalized_path)

        # Filter out files that are exact matches (they're files, not directories)
        exact_match = any(f == normalized_path or f == path for f in files)
        if exact_match:
            return False

        return len(files) > 0

    def rmdir(self, dir: str) -> None:
        """
        Remove a directory and all its contents from the Snowflake stage.

        :param dir: Path to the directory to remove
        """
        self._ensure_connection()
        cursor = self.connection.cursor()

        try:
            normalized_dir = self._normalize_path(dir)
            stage_path = self._get_stage_path(normalized_dir)

            # Remove all files in the directory
            remove_command = f"REMOVE {stage_path}"
            cursor.execute(remove_command)

        except Exception as e:
            raise IOError(f"Error removing directory {dir}: {e}")
        finally:
            cursor.close()

    def mkdir(self, path: str, exist_ok: bool = False) -> None:
        """
        Create a directory in Snowflake stage.

        In Snowflake stages, directories are created implicitly when files are uploaded.
        This method creates a placeholder file if the directory doesn't exist.

        :param path: Path of the directory to create
        :param exist_ok: If False, raise an error if the directory already exists
        """
        # Check if directory already exists
        if self.is_dir(path) and not exist_ok:
            raise FileExistsError(f"Directory {path} already exists")

        # Create a placeholder file to ensure directory exists
        placeholder_path = os.path.join(path, ".placeholder").replace("\\", "/")
        if not self.file_exists(placeholder_path):
            self.write_file(placeholder_path, b"Placeholder file for directory")

    def remove(self, path: str) -> None:
        """
        Remove a file from the Snowflake stage.

        :param path: Path to the file to remove
        """
        self._ensure_connection()
        cursor = self.connection.cursor()

        try:
            normalized_path = self._normalize_path(path)
            stage_path = self._get_stage_path(normalized_path)

            remove_command = f"REMOVE {stage_path}"
            cursor.execute(remove_command)

        except Exception as e:
            raise IOError(f"Error removing file {path}: {e}")
        finally:
            cursor.close()

    def rename(
        self,
        source_path: str,
        destination_path: str,
        overwrite: bool = False,
        delete_source: bool = True,
    ) -> None:
        """
        Rename (move) a single file by copying to the new path and deleting the source.

        :param source_path: Existing file path in the stage
        :param destination_path: Target file path in the stage
        :param overwrite: Overwrite destination if it already exists
        :param delete_source: Delete original after successful copy
        """
        if not self.file_exists(source_path):
            raise FileNotFoundError(f"Source file not found: {source_path}")

        if self.file_exists(destination_path) and not overwrite:
            raise FileExistsError(
                f"Destination already exists and overwrite is False: {destination_path}"
            )

        # Copy file to new location
        self.copy_file(source_path, destination_path, overwrite=overwrite)

        # Delete source if requested
        if delete_source:
            self.remove(source_path)

    def close(self):
        """Close the Snowflake connection."""
        if self.connection:
            self.connection.close()

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.close()
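
A minimal usage sketch. In practice the connection parameters come from the SNOWFLAKE_* config settings; the literal values below are placeholders only:

from gigaspatial.core.io.snowflake_data_store import SnowflakeDataStore

with SnowflakeDataStore(
    account="xy12345",          # placeholder account identifier
    user="my_user",
    password="my_password",
    warehouse="MY_WH",
    database="MY_DB",
    schema="MY_SCHEMA",
    stage_name="MY_STAGE",
) as store:
    # Round-trip a small text file through the internal stage.
    store.write_file("reports/summary.txt", "hello stage")
    print(store.read_file("reports/summary.txt", encoding="utf-8"))
    print(store.list_files("reports"))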
__enter__()

Context manager entry.

Source code in gigaspatial/core/io/snowflake_data_store.py
def __enter__(self):
    """Context manager entry."""
    return self
__exit__(exc_type, exc_val, exc_tb)

Context manager exit.

Source code in gigaspatial/core/io/snowflake_data_store.py
def __exit__(self, exc_type, exc_val, exc_tb):
    """Context manager exit."""
    self.close()
__init__(account=config.SNOWFLAKE_ACCOUNT, user=config.SNOWFLAKE_USER, password=config.SNOWFLAKE_PASSWORD, warehouse=config.SNOWFLAKE_WAREHOUSE, database=config.SNOWFLAKE_DATABASE, schema=config.SNOWFLAKE_SCHEMA, stage_name=config.SNOWFLAKE_STAGE_NAME)

Create a new instance of SnowflakeDataStore.

:param account: Snowflake account identifier
:param user: Snowflake username
:param password: Snowflake password
:param warehouse: Snowflake warehouse name
:param database: Snowflake database name
:param schema: Snowflake schema name
:param stage_name: Name of the Snowflake stage to use for file storage

Source code in gigaspatial/core/io/snowflake_data_store.py
def __init__(
    self,
    account: str = config.SNOWFLAKE_ACCOUNT,
    user: str = config.SNOWFLAKE_USER,
    password: str = config.SNOWFLAKE_PASSWORD,
    warehouse: str = config.SNOWFLAKE_WAREHOUSE,
    database: str = config.SNOWFLAKE_DATABASE,
    schema: str = config.SNOWFLAKE_SCHEMA,
    stage_name: str = config.SNOWFLAKE_STAGE_NAME,
):
    """
    Create a new instance of SnowflakeDataStore.

    :param account: Snowflake account identifier
    :param user: Snowflake username
    :param password: Snowflake password
    :param warehouse: Snowflake warehouse name
    :param database: Snowflake database name
    :param schema: Snowflake schema name
    :param stage_name: Name of the Snowflake stage to use for file storage
    """
    if not all([account, user, password, warehouse, database, schema, stage_name]):
        raise ValueError(
            "Snowflake connection parameters (account, user, password, warehouse, "
            "database, schema, stage_name) must be provided via config or constructor."
        )

    self.account = account
    self.user = user
    self.password = password
    self.warehouse = warehouse
    self.database = database
    self.schema = schema
    self.stage_name = stage_name

    # Create connection
    self.connection = self._create_connection()
    self.logger = config.get_logger(self.__class__.__name__)

    # Temporary directory for file operations
    self._temp_dir = tempfile.mkdtemp()
close()

Close the Snowflake connection.

Source code in gigaspatial/core/io/snowflake_data_store.py
def close(self):
    """Close the Snowflake connection."""
    if self.connection:
        self.connection.close()
copy_directory(source_dir, destination_dir)

Copies all files from a source directory to a destination directory within the stage.

:param source_dir: Source directory path in the stage
:param destination_dir: Destination directory path in the stage

Source code in gigaspatial/core/io/snowflake_data_store.py
def copy_directory(self, source_dir: str, destination_dir: str):
    """
    Copies all files from a source directory to a destination directory within the stage.

    :param source_dir: Source directory path in the stage
    :param destination_dir: Destination directory path in the stage
    """
    try:
        # Normalize directory paths
        source_dir = source_dir.rstrip("/")
        destination_dir = destination_dir.rstrip("/")

        # List all files in the source directory
        files = self.list_files(source_dir)

        for file_path in files:
            # Get relative path from source directory
            if source_dir:
                if file_path.startswith(source_dir):
                    relative_path = file_path[len(source_dir):].lstrip("/")
                else:
                    relative_path = os.path.basename(file_path)
            else:
                relative_path = file_path

            # Construct the destination file path
            if destination_dir:
                dest_file_path = f"{destination_dir}/{relative_path}".replace("//", "/")
            else:
                dest_file_path = relative_path

            # Copy each file
            self.copy_file(file_path, dest_file_path, overwrite=True)

        self.logger.info(f"Copied directory from {source_dir} to {destination_dir}")
    except Exception as e:
        self.logger.error(f"Failed to copy directory {source_dir}: {e}")
        raise
copy_file(source_path, destination_path, overwrite=False)

Copies a single file within the Snowflake stage.

:param source_path: Source file path in the stage
:param destination_path: Destination file path in the stage
:param overwrite: If True, overwrite the destination file if it already exists

Source code in gigaspatial/core/io/snowflake_data_store.py
def copy_file(
        self, source_path: str, destination_path: str, overwrite: bool = False
):
    """
    Copies a single file within the Snowflake stage.

    :param source_path: Source file path in the stage
    :param destination_path: Destination file path in the stage
    :param overwrite: If True, overwrite the destination file if it already exists
    """
    try:
        if not self.file_exists(source_path):
            raise FileNotFoundError(f"Source file not found: {source_path}")

        if self.file_exists(destination_path) and not overwrite:
            raise FileExistsError(
                f"Destination file already exists and overwrite is False: {destination_path}"
            )

        # Read from source and write to destination
        data = self.read_file(source_path)
        self.write_file(destination_path, data)

        self.logger.info(f"Copied file from {source_path} to {destination_path}")
    except Exception as e:
        self.logger.error(f"Failed to copy file {source_path}: {e}")
        raise
download_directory(stage_dir_path, local_dir_path)

Downloads all files from a Snowflake stage directory to a local directory.

:param stage_dir_path: Source directory path in the stage
:param local_dir_path: Destination local directory path

Source code in gigaspatial/core/io/snowflake_data_store.py
def download_directory(self, stage_dir_path: str, local_dir_path: str):
    """
    Downloads all files from a Snowflake stage directory to a local directory.

    :param stage_dir_path: Source directory path in the stage
    :param local_dir_path: Destination local directory path
    """
    try:
        # Ensure the local directory exists
        os.makedirs(local_dir_path, exist_ok=True)

        # List all files in the stage directory
        files = self.list_files(stage_dir_path)

        for file_path in files:
            # Get the relative path from the stage directory
            if stage_dir_path:
                if file_path.startswith(stage_dir_path):
                    relative_path = file_path[len(stage_dir_path):].lstrip("/")
                else:
                    # If file_path doesn't start with stage_dir_path, use it as is
                    relative_path = os.path.basename(file_path)
            else:
                relative_path = file_path

            # Construct the local file path
            local_file_path = os.path.join(local_dir_path, relative_path)
            # Create directories if needed
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

            # Download the file
            data = self.read_file(file_path)
            with open(local_file_path, "wb") as f:
                if isinstance(data, str):
                    f.write(data.encode("utf-8"))
                else:
                    f.write(data)

        self.logger.info(f"Downloaded directory {stage_dir_path} to {local_dir_path}")
    except Exception as e:
        self.logger.error(f"Failed to download directory {stage_dir_path}: {e}")
        raise
exists(path)

Check if a path exists (file or directory).

Source code in gigaspatial/core/io/snowflake_data_store.py
def exists(self, path: str) -> bool:
    """Check if a path exists (file or directory)."""
    return self.file_exists(path) or self.is_dir(path)
file_exists(path)

Check if a file exists in the Snowflake stage.

:param path: Path to check
:return: True if file exists, False otherwise

Source code in gigaspatial/core/io/snowflake_data_store.py
def file_exists(self, path: str) -> bool:
    """
    Check if a file exists in the Snowflake stage.

    :param path: Path to check
    :return: True if file exists, False otherwise
    """
    self._ensure_connection()
    cursor = self.connection.cursor(DictCursor)

    try:
        normalized_path = self._normalize_path(path)
        stage_path = self._get_stage_path(normalized_path)

        # List files in stage with the given path pattern
        list_command = f"LIST {stage_path}"
        cursor.execute(list_command)
        results = cursor.fetchall()

        # Check if exact file exists
        for result in results:
            if result["name"].endswith(normalized_path) or result["name"] == stage_path:
                return True

        return False

    except Exception as e:
        self.logger.warning(f"Error checking file existence {path}: {e}")
        return False
    finally:
        cursor.close()
file_size(path)

Get the size of a file in kilobytes.

:param path: File path in the stage
:return: File size in kilobytes

Source code in gigaspatial/core/io/snowflake_data_store.py
def file_size(self, path: str) -> float:
    """
    Get the size of a file in kilobytes.

    :param path: File path in the stage
    :return: File size in kilobytes
    """
    self._ensure_connection()
    cursor = self.connection.cursor(DictCursor)

    try:
        normalized_path = self._normalize_path(path)
        stage_path = self._get_stage_path(normalized_path)

        # LIST command returns file metadata including size
        list_command = f"LIST {stage_path}"
        cursor.execute(list_command)
        results = cursor.fetchall()

        # Find the matching file and get its size
        for result in results:
            file_path = result["name"]
            if normalized_path in file_path.lower() or file_path.endswith(normalized_path):
                # Size is in bytes, convert to kilobytes
                size_bytes = result.get("size", 0)
                size_kb = size_bytes / 1024.0
                return size_kb

        raise FileNotFoundError(f"File not found: {path}")
    except Exception as e:
        self.logger.error(f"Error getting file size for {path}: {e}")
        raise
    finally:
        cursor.close()
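
Note the unit: the returned value is kilobytes (bytes divided by 1024), not bytes. Sketch, with store a connected SnowflakeDataStore as in the example above and an illustrative path:

size_kb = store.file_size("reports/summary.txt")
print(f"{size_kb:.2f} KB ({size_kb * 1024:.0f} bytes)")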
get_file_metadata(path)

Retrieve comprehensive file metadata from Snowflake stage.

:param path: File path in the stage
:return: File metadata dictionary

Source code in gigaspatial/core/io/snowflake_data_store.py
def get_file_metadata(self, path: str) -> dict:
    """
    Retrieve comprehensive file metadata from Snowflake stage.

    :param path: File path in the stage
    :return: File metadata dictionary
    """
    self._ensure_connection()
    cursor = self.connection.cursor(DictCursor)

    try:
        normalized_path = self._normalize_path(path)
        stage_path = self._get_stage_path(normalized_path)

        # LIST command returns file metadata
        list_command = f"LIST {stage_path}"
        cursor.execute(list_command)
        results = cursor.fetchall()

        # Find the matching file
        for result in results:
            file_path = result["name"]
            if normalized_path in file_path.lower() or file_path.endswith(normalized_path):
                return {
                    "name": path,
                    "size_bytes": result.get("size", 0),
                    "last_modified": result.get("last_modified"),
                    "md5": result.get("md5"),
                }

        raise FileNotFoundError(f"File not found: {path}")
    except Exception as e:
        self.logger.error(f"Error getting file metadata for {path}: {e}")
        raise
    finally:
        cursor.close()
is_dir(path)

Check if path points to a directory.

Source code in gigaspatial/core/io/snowflake_data_store.py
def is_dir(self, path: str) -> bool:
    """Check if path points to a directory."""
    # First check if it's actually a file (exact match)
    if self.file_exists(path):
        return False

    # In Snowflake stages, directories are conceptual
    # Check if there are files with this path prefix
    normalized_path = self._normalize_path(path)
    files = self.list_files(normalized_path)

    # Filter out files that are exact matches (they're files, not directories)
    exact_match = any(f == normalized_path or f == path for f in files)
    if exact_match:
        return False

    return len(files) > 0
is_file(path)

Check if path points to a file.

Source code in gigaspatial/core/io/snowflake_data_store.py
def is_file(self, path: str) -> bool:
    """Check if path points to a file."""
    return self.file_exists(path)
list_directories(path)

List only directory names (not files) from a given path in the stage.

:param path: Directory path to list
:return: List of directory names

Source code in gigaspatial/core/io/snowflake_data_store.py
def list_directories(self, path: str) -> List[str]:
    """
    List only directory names (not files) from a given path in the stage.

    :param path: Directory path to list
    :return: List of directory names
    """
    normalized_path = self._normalize_path(path)
    files = self.list_files(normalized_path)

    directories = set()

    for file_path in files:
        # Get relative path from the search path
        if normalized_path:
            if file_path.startswith(normalized_path):
                relative_path = file_path[len(normalized_path):].lstrip("/")
            else:
                continue
        else:
            relative_path = file_path

        # Skip if empty
        if not relative_path:
            continue

        # If there's a "/" in the relative path, it means there's a subdirectory
        if "/" in relative_path:
            # Get the first directory name
            dir_name = relative_path.split("/")[0]
            directories.add(dir_name)

    return sorted(list(directories))
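
Sketch of the distinction from list_files: stages have no real directories, so list_directories only reports the first path segment below the given prefix. Here store is a connected SnowflakeDataStore and the paths are illustrative:

store.write_file("data/2023/a.csv", "x")
store.write_file("data/2024/b.csv", "y")

print(store.list_files("data"))        # e.g. ['data/2023/a.csv', 'data/2024/b.csv']
print(store.list_directories("data"))  # e.g. ['2023', '2024']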
list_files(path)

List all files in a directory within the Snowflake stage.

:param path: Directory path to list
:return: List of file paths

Source code in gigaspatial/core/io/snowflake_data_store.py
def list_files(self, path: str) -> List[str]:
    """
    List all files in a directory within the Snowflake stage.

    :param path: Directory path to list
    :return: List of file paths
    """
    self._ensure_connection()
    cursor = self.connection.cursor(DictCursor)

    try:
        normalized_path = self._normalize_path(path)
        stage_path = self._get_stage_path(normalized_path)

        # List files in stage
        list_command = f"LIST {stage_path}"
        cursor.execute(list_command)
        results = cursor.fetchall()

        # Extract file paths relative to the base stage path
        files = []
        for result in results:
            file_path = result["name"]
            # Snowflake LIST returns names in lowercase without @ symbol
            # Remove stage prefix to get relative path
            # Check both @stage_name/ and lowercase stage_name/ formats
            stage_prefixes = [
                f"@{self.stage_name}/",
                f"{self.stage_name.lower()}/",
                f"@{self.stage_name.lower()}/",
            ]

            for prefix in stage_prefixes:
                if file_path.startswith(prefix):
                    relative_path = file_path[len(prefix):]
                    files.append(relative_path)
                    break
            else:
                # If no prefix matches, try to extract path after stage name
                # Sometimes stage name might be in different case
                stage_name_lower = self.stage_name.lower()
                if stage_name_lower in file_path.lower():
                    # Find the position after the stage name
                    idx = file_path.lower().find(stage_name_lower)
                    if idx != -1:
                        # Get everything after stage name and '/'
                        after_stage = file_path[idx + len(stage_name_lower):].lstrip("/")
                        if after_stage.startswith(normalized_path):
                            relative_path = after_stage
                            files.append(relative_path)

        return files

    except Exception as e:
        self.logger.warning(f"Error listing files in {path}: {e}")
        return []
    finally:
        cursor.close()
mkdir(path, exist_ok=False)

Create a directory in Snowflake stage.

In Snowflake stages, directories are created implicitly when files are uploaded. This method creates a placeholder file if the directory doesn't exist.

:param path: Path of the directory to create
:param exist_ok: If False, raise an error if the directory already exists

Source code in gigaspatial/core/io/snowflake_data_store.py
def mkdir(self, path: str, exist_ok: bool = False) -> None:
    """
    Create a directory in Snowflake stage.

    In Snowflake stages, directories are created implicitly when files are uploaded.
    This method creates a placeholder file if the directory doesn't exist.

    :param path: Path of the directory to create
    :param exist_ok: If False, raise an error if the directory already exists
    """
    # Check if directory already exists
    if self.is_dir(path) and not exist_ok:
        raise FileExistsError(f"Directory {path} already exists")

    # Create a placeholder file to ensure directory exists
    placeholder_path = os.path.join(path, ".placeholder").replace("\\", "/")
    if not self.file_exists(placeholder_path):
        self.write_file(placeholder_path, b"Placeholder file for directory")
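
Sketch of the placeholder behaviour, with store as above and an illustrative path:

# Stage "directories" only exist once they contain a file, so mkdir uploads a
# .placeholder file to make the path visible in listings.
store.mkdir("staging/incoming", exist_ok=True)
print(store.file_exists("staging/incoming/.placeholder"))  # expected: True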
open(path, mode='r')

Context manager for file operations.

:param path: File path in Snowflake stage
:param mode: File open mode (r, rb, w, wb)

Source code in gigaspatial/core/io/snowflake_data_store.py
@contextlib.contextmanager
def open(self, path: str, mode: str = "r"):
    """
    Context manager for file operations.

    :param path: File path in Snowflake stage
    :param mode: File open mode (r, rb, w, wb)
    """
    if mode == "w":
        file = io.StringIO()
        yield file
        self.write_file(path, file.getvalue())

    elif mode == "wb":
        file = io.BytesIO()
        yield file
        self.write_file(path, file.getvalue())

    elif mode == "r":
        data = self.read_file(path, encoding="UTF-8")
        file = io.StringIO(data)
        yield file

    elif mode == "rb":
        data = self.read_file(path)
        file = io.BytesIO(data)
        yield file

    else:
        raise ValueError(f"Unsupported mode: {mode}")
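
Usage sketch: writes are buffered in memory and uploaded when the context exits, and reads download the whole file into an in-memory buffer. store and the paths are illustrative, as above:

# Text write, then text read.
with store.open("notes/todo.txt", "w") as f:
    f.write("first item\n")

with store.open("notes/todo.txt", "r") as f:
    print(f.read())

# Binary round-trip.
with store.open("blobs/raw.bin", "wb") as f:
    f.write(b"\x00\x01\x02")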
read_file(path, encoding=None)

Read file from Snowflake stage.

:param path: Path to the file in the stage
:param encoding: File encoding (optional)
:return: File contents as string or bytes

Source code in gigaspatial/core/io/snowflake_data_store.py
def read_file(self, path: str, encoding: Optional[str] = None) -> Union[str, bytes]:
    """
    Read file from Snowflake stage.

    :param path: Path to the file in the stage
    :param encoding: File encoding (optional)
    :return: File contents as string or bytes
    """
    self._ensure_connection()
    cursor = self.connection.cursor(DictCursor)

    try:
        normalized_path = self._normalize_path(path)
        stage_path = self._get_stage_path(normalized_path)

        # Create temporary directory for download
        temp_download_dir = os.path.join(self._temp_dir, "downloads")
        os.makedirs(temp_download_dir, exist_ok=True)

        # Download file from stage using GET command
        # GET command: GET <stage_path> file://<local_path>
        temp_dir_normalized = temp_download_dir.replace("\\", "/")
        if not temp_dir_normalized.endswith("/"):
            temp_dir_normalized += "/"

        get_command = f"GET {stage_path} 'file://{temp_dir_normalized}'"
        cursor.execute(get_command)

        # Find the downloaded file (Snowflake may add prefixes/suffixes or preserve structure)
        downloaded_files = []
        for root, dirs, files in os.walk(temp_download_dir):
            for f in files:
                file_path = os.path.join(root, f)
                # Check if this file matches our expected filename
                if os.path.basename(normalized_path) in f or normalized_path.endswith(f):
                    downloaded_files.append(file_path)

        if not downloaded_files:
            raise FileNotFoundError(f"File not found in stage: {path}")

        # Read the first matching file
        downloaded_path = downloaded_files[0]
        with open(downloaded_path, "rb") as f:
            data = f.read()

        # Clean up
        os.remove(downloaded_path)
        # Clean up empty directories
        try:
            if os.path.exists(temp_download_dir) and not os.listdir(temp_download_dir):
                os.rmdir(temp_download_dir)
        except OSError:
            pass

        # Decode if encoding is specified
        if encoding:
            return data.decode(encoding)
        return data

    except Exception as e:
        raise IOError(f"Error reading file {path} from Snowflake stage: {e}")
    finally:
        cursor.close()
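
Sketch: without an encoding the raw bytes are returned; with one, a decoded string. store and the path are as in the earlier example:

raw = store.read_file("reports/summary.txt")                     # bytes
text = store.read_file("reports/summary.txt", encoding="utf-8")  # str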
remove(path)

Remove a file from the Snowflake stage.

:param path: Path to the file to remove

Source code in gigaspatial/core/io/snowflake_data_store.py
def remove(self, path: str) -> None:
    """
    Remove a file from the Snowflake stage.

    :param path: Path to the file to remove
    """
    self._ensure_connection()
    cursor = self.connection.cursor()

    try:
        normalized_path = self._normalize_path(path)
        stage_path = self._get_stage_path(normalized_path)

        remove_command = f"REMOVE {stage_path}"
        cursor.execute(remove_command)

    except Exception as e:
        raise IOError(f"Error removing file {path}: {e}")
    finally:
        cursor.close()
rename(source_path, destination_path, overwrite=False, delete_source=True)

Rename (move) a single file by copying to the new path and deleting the source.

:param source_path: Existing file path in the stage
:param destination_path: Target file path in the stage
:param overwrite: Overwrite destination if it already exists
:param delete_source: Delete original after successful copy

Source code in gigaspatial/core/io/snowflake_data_store.py
def rename(
    self,
    source_path: str,
    destination_path: str,
    overwrite: bool = False,
    delete_source: bool = True,
) -> None:
    """
    Rename (move) a single file by copying to the new path and deleting the source.

    :param source_path: Existing file path in the stage
    :param destination_path: Target file path in the stage
    :param overwrite: Overwrite destination if it already exists
    :param delete_source: Delete original after successful copy
    """
    if not self.file_exists(source_path):
        raise FileNotFoundError(f"Source file not found: {source_path}")

    if self.file_exists(destination_path) and not overwrite:
        raise FileExistsError(
            f"Destination already exists and overwrite is False: {destination_path}"
        )

    # Copy file to new location
    self.copy_file(source_path, destination_path, overwrite=overwrite)

    # Delete source if requested
    if delete_source:
        self.remove(source_path)
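
As a hedged sketch with the same hypothetical store instance: the default call moves a file (copy via copy_file, then remove), while delete_source=False behaves like a copy.

# Move a file: copies to the new path, then deletes the original.
store.rename("exports/schools_v1.csv", "exports/schools.csv", overwrite=True)

# Keep the original by skipping the delete step.
store.rename("exports/schools.csv", "backups/schools.csv", delete_source=False)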
rmdir(dir)

Remove a directory and all its contents from the Snowflake stage.

:param dir: Path to the directory to remove

Source code in gigaspatial/core/io/snowflake_data_store.py
def rmdir(self, dir: str) -> None:
    """
    Remove a directory and all its contents from the Snowflake stage.

    :param dir: Path to the directory to remove
    """
    self._ensure_connection()
    cursor = self.connection.cursor()

    try:
        normalized_dir = self._normalize_path(dir)
        stage_path = self._get_stage_path(normalized_dir)

        # Remove all files in the directory
        remove_command = f"REMOVE {stage_path}"
        cursor.execute(remove_command)

    except Exception as e:
        raise IOError(f"Error removing directory {dir}: {e}")
    finally:
        cursor.close()
upload_directory(dir_path, stage_dir_path)

Uploads all files from a local directory to Snowflake stage.

:param dir_path: Local directory path
:param stage_dir_path: Destination directory path in the stage

Source code in gigaspatial/core/io/snowflake_data_store.py
def upload_directory(self, dir_path: str, stage_dir_path: str):
    """
    Uploads all files from a local directory to Snowflake stage.

    :param dir_path: Local directory path
    :param stage_dir_path: Destination directory path in the stage
    """
    if not os.path.isdir(dir_path):
        raise NotADirectoryError(f"Local directory not found: {dir_path}")

    for root, dirs, files in os.walk(dir_path):
        for file in files:
            local_file_path = os.path.join(root, file)
            relative_path = os.path.relpath(local_file_path, dir_path)
            # Normalize path separators for stage
            stage_file_path = os.path.join(stage_dir_path, relative_path).replace("\\", "/")

            self.upload_file(local_file_path, stage_file_path)
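
A short sketch (local and stage paths are made up); the relative layout under the local directory is preserved in the stage:

# ./output/a/b.csv is uploaded to exports/a/b.csv in the stage.
store.upload_directory("./output", "exports")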
upload_file(file_path, stage_path)

Uploads a single file from local filesystem to Snowflake stage.

:param file_path: Local file path
:param stage_path: Destination path in the stage

Source code in gigaspatial/core/io/snowflake_data_store.py
def upload_file(self, file_path: str, stage_path: str):
    """
    Uploads a single file from local filesystem to Snowflake stage.

    :param file_path: Local file path
    :param stage_path: Destination path in the stage
    """
    try:
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Local file not found: {file_path}")

        # Read the file
        with open(file_path, "rb") as f:
            data = f.read()

        # Write to stage using write_file
        self.write_file(stage_path, data)
        self.logger.info(f"Uploaded {file_path} to {stage_path}")
    except Exception as e:
        self.logger.error(f"Failed to upload {file_path}: {e}")
        raise
walk(top)

Walk through directory tree in Snowflake stage, similar to os.walk().

:param top: Starting directory for the walk
:return: Generator yielding tuples of (dirpath, dirnames, filenames)

Source code in gigaspatial/core/io/snowflake_data_store.py
def walk(self, top: str) -> Generator[Tuple[str, List[str], List[str]], None, None]:
    """
    Walk through directory tree in Snowflake stage, similar to os.walk().

    :param top: Starting directory for the walk
    :return: Generator yielding tuples of (dirpath, dirnames, filenames)
    """
    try:
        normalized_top = self._normalize_path(top)

        # Use list_files to get all files (it handles path parsing correctly)
        all_files = self.list_files(normalized_top)

        # Organize into directory structure
        dirs = {}

        for file_path in all_files:
            # Ensure we're working with paths relative to the top
            if normalized_top and not file_path.startswith(normalized_top):
                continue

            # Get relative path from top
            if normalized_top and file_path.startswith(normalized_top):
                relative_path = file_path[len(normalized_top):].lstrip("/")
            else:
                relative_path = file_path

            if not relative_path:
                continue

            # Get directory and filename
            if "/" in relative_path:
                dir_path, filename = os.path.split(relative_path)
                full_dir_path = f"{normalized_top}/{dir_path}" if normalized_top else dir_path
                if full_dir_path not in dirs:
                    dirs[full_dir_path] = []
                dirs[full_dir_path].append(filename)
            else:
                # File in root of the top directory
                if normalized_top not in dirs:
                    dirs[normalized_top] = []
                dirs[normalized_top].append(relative_path)

        # Yield results in os.walk format
        for dir_path, files in dirs.items():
            # Extract subdirectories (simplified - Snowflake stages are flat)
            subdirs = []
            yield (dir_path, subdirs, files)

    except Exception as e:
        self.logger.warning(f"Error walking directory {top}: {e}")
        yield (top, [], [])
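
Because the generator mimics os.walk, existing traversal code ports over directly; a sketch with the same hypothetical store:

# Print every file under exports/ in the stage.
for dirpath, dirnames, filenames in store.walk("exports"):
    for name in filenames:
        print(f"{dirpath}/{name}")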
write_file(path, data)

Write file to Snowflake stage.

:param path: Destination path in the stage
:param data: File contents

Source code in gigaspatial/core/io/snowflake_data_store.py
def write_file(self, path: str, data: Union[bytes, str]) -> None:
    """
    Write file to Snowflake stage.

    :param path: Destination path in the stage
    :param data: File contents
    """
    self._ensure_connection()
    cursor = self.connection.cursor()

    try:
        # Convert to bytes if string
        if isinstance(data, str):
            binary_data = data.encode("utf-8")
        elif isinstance(data, bytes):
            binary_data = data
        else:
            raise ValueError('Unsupported data type. Only "bytes" or "string" accepted')

        normalized_path = self._normalize_path(path)

        # Write to temporary file first
        # Use the full path structure for the temp file to preserve directory structure
        temp_file_path = os.path.join(self._temp_dir, normalized_path)
        os.makedirs(os.path.dirname(temp_file_path), exist_ok=True)

        with open(temp_file_path, "wb") as f:
            f.write(binary_data)

        # Upload to stage using PUT command
        # Snowflake PUT requires the local file path and the target stage path
        # Convert Windows paths to Unix-style for Snowflake
        temp_file_normalized = os.path.abspath(temp_file_path).replace("\\", "/")

        # PUT command: PUT 'file://<absolute_local_path>' @stage_name/<path>
        # The file will be stored at the specified path in the stage
        stage_target = f"@{self.stage_name}/"
        if "/" in normalized_path:
            # Include directory structure in stage path
            dir_path = os.path.dirname(normalized_path)
            stage_target = f"@{self.stage_name}/{dir_path}/"

        # Snowflake PUT syntax: PUT 'file://<path>' @stage/path
        put_command = f"PUT 'file://{temp_file_normalized}' {stage_target} OVERWRITE=TRUE AUTO_COMPRESS=FALSE"
        cursor.execute(put_command)

        # Clean up temp file
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
            # Clean up empty directories if they were created
            try:
                temp_dir = os.path.dirname(temp_file_path)
                if temp_dir != self._temp_dir and os.path.exists(temp_dir):
                    os.rmdir(temp_dir)
            except OSError:
                pass  # Directory not empty or other error, ignore

    except Exception as e:
        raise IOError(f"Error writing file {path} to Snowflake stage: {e}")
    finally:
        cursor.close()
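
A minimal sketch of writing text and binary content (string data is UTF-8 encoded before the PUT):

# Strings are encoded as UTF-8; bytes are written as-is.
store.write_file("exports/notes.txt", "hello from gigaspatial")
store.write_file("exports/blob.bin", b"\x00\x01\x02")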

writers

write_dataset(data, data_store, path, **kwargs)

Write DataFrame, GeoDataFrame, or a generic object (for JSON) to various file formats in DataStore.

Parameters:

data : pandas.DataFrame, geopandas.GeoDataFrame, or any object
    The data to write to data storage.
data_store : DataStore
    Instance of DataStore for accessing data storage.
path : str
    Path where the file will be written in data storage.
**kwargs : dict
    Additional arguments passed to the specific writer function.

Raises:

ValueError
    If the file type is unsupported or if there's an error writing the file.
TypeError
    If input data is not a DataFrame, GeoDataFrame, AND not a generic object
    intended for a .json file.

Source code in gigaspatial/core/io/writers.py
def write_dataset(data, data_store: DataStore, path, **kwargs):
    """
    Write DataFrame, GeoDataFrame, or a generic object (for JSON)
    to various file formats in DataStore.

    Parameters:
    ----------
    data : pandas.DataFrame, geopandas.GeoDataFrame, or any object
        The data to write to data storage.
    data_store : DataStore
        Instance of DataStore for accessing data storage.
    path : str
        Path where the file will be written in data storage.
    **kwargs : dict
        Additional arguments passed to the specific writer function.

    Raises:
    ------
    ValueError
        If the file type is unsupported or if there's an error writing the file.
    TypeError
            If input data is not a DataFrame, GeoDataFrame, AND not a generic object
            intended for a .json file.
    """

    # Define supported file formats and their writers
    BINARY_FORMATS = {".shp", ".zip", ".parquet", ".gpkg", ".xlsx", ".xls"}

    PANDAS_WRITERS = {
        ".csv": lambda df, buf, **kw: df.to_csv(buf, **kw),
        ".xlsx": lambda df, buf, **kw: df.to_excel(buf, engine="openpyxl", **kw),
        ".json": lambda df, buf, **kw: df.to_json(buf, **kw),
        ".parquet": lambda df, buf, **kw: df.to_parquet(buf, **kw),
    }

    GEO_WRITERS = {
        ".geojson": lambda gdf, buf, **kw: gdf.to_file(buf, driver="GeoJSON", **kw),
        ".gpkg": lambda gdf, buf, **kw: gdf.to_file(buf, driver="GPKG", **kw),
        ".parquet": lambda gdf, buf, **kw: gdf.to_parquet(buf, **kw),
    }

    try:
        # Get file suffix and ensure it's lowercase
        suffix = Path(path).suffix.lower()

        # 1. Handle generic JSON data
        is_dataframe_like = isinstance(data, (pd.DataFrame, gpd.GeoDataFrame))
        if not is_dataframe_like:
            if suffix == ".json":
                try:
                    # Pass generic data directly to the write_json function
                    write_json(data, data_store, path, **kwargs)
                    return  # Successfully wrote JSON, so exit
                except Exception as e:
                    raise ValueError(f"Error writing generic JSON data: {str(e)}")
            else:
                # Raise an error if it's not a DataFrame/GeoDataFrame and not a .json file
                raise TypeError(
                    "Input data must be a pandas DataFrame or GeoDataFrame, "
                    "or a generic object destined for a '.json' file."
                )

        # 2. Handle DataFrame/GeoDataFrame
        # Determine if we need binary mode based on file type
        mode = "wb" if suffix in BINARY_FORMATS else "w"

        # Handle different data types and formats
        if isinstance(data, gpd.GeoDataFrame):
            if suffix not in GEO_WRITERS:
                supported_formats = sorted(GEO_WRITERS.keys())
                raise ValueError(
                    f"Unsupported file type for GeoDataFrame: {suffix}\n"
                    f"Supported formats: {', '.join(supported_formats)}"
                )

            try:
                with data_store.open(path, "wb") as f:
                    GEO_WRITERS[suffix](data, f, **kwargs)
            except Exception as e:
                raise ValueError(f"Error writing GeoDataFrame: {str(e)}")

        else:  # pandas DataFrame
            if suffix not in PANDAS_WRITERS:
                supported_formats = sorted(PANDAS_WRITERS.keys())
                raise ValueError(
                    f"Unsupported file type for DataFrame: {suffix}\n"
                    f"Supported formats: {', '.join(supported_formats)}"
                )

            try:
                with data_store.open(path, mode) as f:
                    PANDAS_WRITERS[suffix](data, f, **kwargs)
            except Exception as e:
                raise ValueError(f"Error writing DataFrame: {str(e)}")

    except Exception as e:
        if isinstance(e, (TypeError, ValueError)):
            raise
        raise RuntimeError(f"Unexpected error writing dataset: {str(e)}")
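
A minimal sketch of writing a DataFrame and a GeoDataFrame; the LocalDataStore import path and all column names are assumptions made for illustration:

import pandas as pd
import geopandas as gpd
from shapely.geometry import Point

from gigaspatial.core.io.writers import write_dataset
# Assumed import path for LocalDataStore; adjust to the actual module layout.
from gigaspatial.core.io.local_data_store import LocalDataStore

store = LocalDataStore()

# Plain DataFrame: extra kwargs are forwarded to the pandas writer (to_csv here).
df = pd.DataFrame({"name": ["a", "b"], "value": [1, 2]})
write_dataset(df, store, "out/table.csv", index=False)

# GeoDataFrame: the file suffix selects the geo writer (GeoJSON here).
gdf = gpd.GeoDataFrame({"name": ["a"]}, geometry=[Point(30.0, -1.9)], crs="EPSG:4326")
write_dataset(gdf, store, "out/points.geojson")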
write_datasets(data_dict, data_store, **kwargs)

Write multiple datasets to data storage at once.

Parameters:

data_dict : dict
    Dictionary mapping paths to DataFrames/GeoDataFrames.
data_store : DataStore
    Instance of DataStore for accessing data storage.
**kwargs : dict
    Additional arguments passed to write_dataset.

Raises:

ValueError
    If there are any errors writing the datasets.

Source code in gigaspatial/core/io/writers.py
def write_datasets(data_dict, data_store: DataStore, **kwargs):
    """
    Write multiple datasets to data storage at once.

    Parameters:
    ----------
    data_dict : dict
        Dictionary mapping paths to DataFrames/GeoDataFrames.
    data_store : DataStore
        Instance of DataStore for accessing data storage.
    **kwargs : dict
        Additional arguments passed to write_dataset.

    Raises:
    ------
    ValueError
        If there are any errors writing the datasets.
    """
    errors = {}

    for path, data in data_dict.items():
        try:
            write_dataset(data, data_store, path, **kwargs)
        except Exception as e:
            errors[path] = str(e)

    if errors:
        error_msg = "\n".join(f"- {path}: {error}" for path, error in errors.items())
        raise ValueError(f"Errors writing datasets:\n{error_msg}")

schemas

entity

BaseGigaEntity

Bases: BaseModel

Base class for all Giga entities with common fields.

Source code in gigaspatial/core/schemas/entity.py
class BaseGigaEntity(BaseModel):
    """Base class for all Giga entities with common fields."""

    source: Optional[str] = Field(None, max_length=100, description="Source reference")
    source_detail: Optional[str] = None

    @property
    def id(self) -> str:
        """Abstract property that must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement id property")
id: str property

Abstract property that must be implemented by subclasses.
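
For illustration only, a hypothetical subclass would typically derive id from its own fields:

from gigaspatial.core.schemas.entity import BaseGigaEntity

class SchoolEntity(BaseGigaEntity):
    # Hypothetical entity type, not part of the library.
    school_code: str

    @property
    def id(self) -> str:
        return f"school-{self.school_code}"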

EntityTable

Bases: BaseModel, Generic[E]

Source code in gigaspatial/core/schemas/entity.py
class EntityTable(BaseModel, Generic[E]):
    entities: List[E] = Field(default_factory=list)
    _cached_kdtree: Optional[cKDTree] = PrivateAttr(
        default=None
    )  # Internal cache for the KDTree

    @classmethod
    def from_file(
        cls: Type["EntityTable"],
        file_path: Union[str, Path],
        entity_class: Type[E],
        data_store: Optional[DataStore] = None,
        **kwargs,
    ) -> "EntityTable":
        """
        Create an EntityTable instance from a file.

        Args:
            file_path: Path to the dataset file
            entity_class: The entity class for validation

        Returns:
            EntityTable instance

        Raises:
            ValidationError: If any row fails validation
            FileNotFoundError: If the file doesn't exist
        """
        data_store = data_store or LocalDataStore()
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        df = read_dataset(data_store, file_path, **kwargs)
        try:
            entities = [entity_class(**row) for row in df.to_dict(orient="records")]
            return cls(entities=entities)
        except ValidationError as e:
            raise ValueError(f"Validation error in input data: {e}")
        except Exception as e:
            raise ValueError(f"Error reading or processing the file: {e}")

    def _check_has_location(self, method_name: str) -> bool:
        """Helper method to check if entities have location data."""
        if not self.entities:
            return False
        if not isinstance(self.entities[0], GigaEntity):
            raise ValueError(
                f"Cannot perform {method_name}: entities of type {type(self.entities[0]).__name__} "
                "do not have location data (latitude/longitude)"
            )
        return True

    def to_dataframe(self) -> pd.DataFrame:
        """Convert the entity table to a pandas DataFrame."""
        return pd.DataFrame([e.model_dump() for e in self.entities])

    def to_geodataframe(self) -> gpd.GeoDataFrame:
        """Convert the entity table to a GeoDataFrame."""
        if not self._check_has_location("to_geodataframe"):
            raise ValueError("Cannot create GeoDataFrame: no entities available")
        df = self.to_dataframe()
        return gpd.GeoDataFrame(
            df,
            geometry=gpd.points_from_xy(df["longitude"], df["latitude"]),
            crs="EPSG:4326",
        )

    def to_coordinate_vector(self) -> np.ndarray:
        """Transforms the entity table into a numpy vector of coordinates"""
        if not self.entities:
            return np.zeros((0, 2))

        if not self._check_has_location("to_coordinate_vector"):
            return np.zeros((0, 2))

        return np.array([[e.latitude, e.longitude] for e in self.entities])

    def get_lat_array(self) -> np.ndarray:
        """Get an array of latitude values."""
        if not self._check_has_location("get_lat_array"):
            return np.array([])
        return np.array([e.latitude for e in self.entities])

    def get_lon_array(self) -> np.ndarray:
        """Get an array of longitude values."""
        if not self._check_has_location("get_lon_array"):
            return np.array([])
        return np.array([e.longitude for e in self.entities])

    def filter_by_admin1(self, admin1_id_giga: str) -> "EntityTable[E]":
        """Filter entities by primary administrative division."""
        return self.__class__(
            entities=[e for e in self.entities if e.admin1_id_giga == admin1_id_giga]
        )

    def filter_by_admin2(self, admin2_id_giga: str) -> "EntityTable[E]":
        """Filter entities by secondary administrative division."""
        return self.__class__(
            entities=[e for e in self.entities if e.admin2_id_giga == admin2_id_giga]
        )

    def filter_by_polygon(self, polygon: Polygon) -> "EntityTable[E]":
        """Filter entities within a polygon"""
        if not self._check_has_location("filter_by_polygon"):
            return self.__class__(entities=[])

        filtered = [
            e for e in self.entities if polygon.contains(Point(e.longitude, e.latitude))
        ]
        return self.__class__(entities=filtered)

    def filter_by_bounds(
        self, min_lat: float, max_lat: float, min_lon: float, max_lon: float
    ) -> "EntityTable[E]":
        """Filter entities whose coordinates fall within the given bounds."""
        if not self._check_has_location("filter_by_bounds"):
            return self.__class__(entities=[])

        filtered = [
            e
            for e in self.entities
            if min_lat <= e.latitude <= max_lat and min_lon <= e.longitude <= max_lon
        ]
        return self.__class__(entities=filtered)

    def get_nearest_neighbors(
        self, lat: float, lon: float, k: int = 5
    ) -> "EntityTable[E]":
        """Find k nearest neighbors to a point using a cached KDTree."""
        if not self._check_has_location("get_nearest_neighbors"):
            return self.__class__(entities=[])

        if not self._cached_kdtree:
            self._build_kdtree()  # Build the KDTree if not already cached

        if not self._cached_kdtree:  # If still None after building
            return self.__class__(entities=[])

        _, indices = self._cached_kdtree.query([[lat, lon]], k=k)
        return self.__class__(entities=[self.entities[i] for i in indices[0]])

    def _build_kdtree(self):
        """Builds and caches the KDTree."""
        if not self._check_has_location("_build_kdtree"):
            self._cached_kdtree = None
            return
        coords = self.to_coordinate_vector()
        if coords.size:  # a multi-element ndarray is not truth-testable; check .size instead
            self._cached_kdtree = cKDTree(coords)

    def clear_cache(self):
        """Clears the KDTree cache."""
        self._cached_kdtree = None

    def to_file(
        self,
        file_path: Union[str, Path],
        data_store: Optional[DataStore] = None,
        **kwargs,
    ) -> None:
        """
        Save the entity data to a file.

        Args:
            file_path: Path to save the file
        """
        if not self.entities:
            raise ValueError("Cannot write to a file: no entities available.")

        data_store = data_store or LocalDataStore()

        write_dataset(self.to_dataframe(), data_store, file_path, **kwargs)

    def __len__(self) -> int:
        return len(self.entities)

    def __iter__(self):
        return iter(self.entities)
clear_cache()

Clears the KDTree cache.

Source code in gigaspatial/core/schemas/entity.py
def clear_cache(self):
    """Clears the KDTree cache."""
    self._cached_kdtree = None
filter_by_admin1(admin1_id_giga)

Filter entities by primary administrative division.

Source code in gigaspatial/core/schemas/entity.py
def filter_by_admin1(self, admin1_id_giga: str) -> "EntityTable[E]":
    """Filter entities by primary administrative division."""
    return self.__class__(
        entities=[e for e in self.entities if e.admin1_id_giga == admin1_id_giga]
    )
filter_by_admin2(admin2_id_giga)

Filter entities by secondary administrative division.

Source code in gigaspatial/core/schemas/entity.py
def filter_by_admin2(self, admin2_id_giga: str) -> "EntityTable[E]":
    """Filter entities by secondary administrative division."""
    return self.__class__(
        entities=[e for e in self.entities if e.admin2_id_giga == admin2_id_giga]
    )
filter_by_bounds(min_lat, max_lat, min_lon, max_lon)

Filter entities whose coordinates fall within the given bounds.

Source code in gigaspatial/core/schemas/entity.py
def filter_by_bounds(
    self, min_lat: float, max_lat: float, min_lon: float, max_lon: float
) -> "EntityTable[E]":
    """Filter entities whose coordinates fall within the given bounds."""
    if not self._check_has_location("filter_by_bounds"):
        return self.__class__(entities=[])

    filtered = [
        e
        for e in self.entities
        if min_lat <= e.latitude <= max_lat and min_lon <= e.longitude <= max_lon
    ]
    return self.__class__(entities=filtered)
filter_by_polygon(polygon)

Filter entities within a polygon

Source code in gigaspatial/core/schemas/entity.py
def filter_by_polygon(self, polygon: Polygon) -> "EntityTable[E]":
    """Filter entities within a polygon"""
    if not self._check_has_location("filter_by_polygon"):
        return self.__class__(entities=[])

    filtered = [
        e for e in self.entities if polygon.contains(Point(e.longitude, e.latitude))
    ]
    return self.__class__(entities=filtered)
from_file(file_path, entity_class, data_store=None, **kwargs) classmethod

Create an EntityTable instance from a file.

Parameters:

Name Type Description Default
file_path Union[str, Path]

Path to the dataset file

required
entity_class Type[E]

The entity class for validation

required

Returns:

Type Description
EntityTable

EntityTable instance

Raises:

Type Description
ValidationError

If any row fails validation

FileNotFoundError

If the file doesn't exist

Source code in gigaspatial/core/schemas/entity.py
@classmethod
def from_file(
    cls: Type["EntityTable"],
    file_path: Union[str, Path],
    entity_class: Type[E],
    data_store: Optional[DataStore] = None,
    **kwargs,
) -> "EntityTable":
    """
    Create an EntityTable instance from a file.

    Args:
        file_path: Path to the dataset file
        entity_class: The entity class for validation

    Returns:
        EntityTable instance

    Raises:
        ValidationError: If any row fails validation
        FileNotFoundError: If the file doesn't exist
    """
    data_store = data_store or LocalDataStore()
    file_path = Path(file_path)
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")

    df = read_dataset(data_store, file_path, **kwargs)
    try:
        entities = [entity_class(**row) for row in df.to_dict(orient="records")]
        return cls(entities=entities)
    except ValidationError as e:
        raise ValueError(f"Validation error in input data: {e}")
    except Exception as e:
        raise ValueError(f"Error reading or processing the file: {e}")
get_lat_array()

Get an array of latitude values.

Source code in gigaspatial/core/schemas/entity.py
def get_lat_array(self) -> np.ndarray:
    """Get an array of latitude values."""
    if not self._check_has_location("get_lat_array"):
        return np.array([])
    return np.array([e.latitude for e in self.entities])
get_lon_array()

Get an array of longitude values.

Source code in gigaspatial/core/schemas/entity.py
def get_lon_array(self) -> np.ndarray:
    """Get an array of longitude values."""
    if not self._check_has_location("get_lon_array"):
        return np.array([])
    return np.array([e.longitude for e in self.entities])
get_nearest_neighbors(lat, lon, k=5)

Find k nearest neighbors to a point using a cached KDTree.

Source code in gigaspatial/core/schemas/entity.py
def get_nearest_neighbors(
    self, lat: float, lon: float, k: int = 5
) -> "EntityTable[E]":
    """Find k nearest neighbors to a point using a cached KDTree."""
    if not self._check_has_location("get_nearest_neighbors"):
        return self.__class__(entities=[])

    if not self._cached_kdtree:
        self._build_kdtree()  # Build the KDTree if not already cached

    if not self._cached_kdtree:  # If still None after building
        return self.__class__(entities=[])

    _, indices = self._cached_kdtree.query([[lat, lon]], k=k)
    return self.__class__(entities=[self.entities[i] for i in indices[0]])
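
A quick sketch, assuming table is an EntityTable of location-bearing (GigaEntity-based) records; the coordinates are arbitrary:

nearest = table.get_nearest_neighbors(lat=-1.95, lon=30.06, k=5)
for entity in nearest:
    print(entity.latitude, entity.longitude)

Note that the cached cKDTree is built directly on (latitude, longitude) pairs, so the query uses Euclidean distance in degrees rather than geodesic distance.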
to_coordinate_vector()

Transforms the entity table into a numpy vector of coordinates

Source code in gigaspatial/core/schemas/entity.py
def to_coordinate_vector(self) -> np.ndarray:
    """Transforms the entity table into a numpy vector of coordinates"""
    if not self.entities:
        return np.zeros((0, 2))

    if not self._check_has_location("to_coordinate_vector"):
        return np.zeros((0, 2))

    return np.array([[e.latitude, e.longitude] for e in self.entities])
to_dataframe()

Convert the entity table to a pandas DataFrame.

Source code in gigaspatial/core/schemas/entity.py
def to_dataframe(self) -> pd.DataFrame:
    """Convert the entity table to a pandas DataFrame."""
    return pd.DataFrame([e.model_dump() for e in self.entities])
to_file(file_path, data_store=None, **kwargs)

Save the entity data to a file.

Parameters:

Name Type Description Default
file_path Union[str, Path]

Path to save the file

required
Source code in gigaspatial/core/schemas/entity.py
def to_file(
    self,
    file_path: Union[str, Path],
    data_store: Optional[DataStore] = None,
    **kwargs,
) -> None:
    """
    Save the entity data to a file.

    Args:
        file_path: Path to save the file
    """
    if not self.entities:
        raise ValueError("Cannot write to a file: no entities available.")

    data_store = data_store or LocalDataStore()

    write_dataset(self.to_dataframe(), data_store, file_path, **kwargs)
to_geodataframe()

Convert the entity table to a GeoDataFrame.

Source code in gigaspatial/core/schemas/entity.py
def to_geodataframe(self) -> gpd.GeoDataFrame:
    """Convert the entity table to a GeoDataFrame."""
    if not self._check_has_location("to_geodataframe"):
        raise ValueError("Cannot create GeoDataFrame: no entities available")
    df = self.to_dataframe()
    return gpd.GeoDataFrame(
        df,
        geometry=gpd.points_from_xy(df["longitude"], df["latitude"]),
        crs="EPSG:4326",
    )
GigaEntity

Bases: BaseGigaEntity

Entity with location data.

Source code in gigaspatial/core/schemas/entity.py
class GigaEntity(BaseGigaEntity):
    """Entity with location data."""

    latitude: float = Field(
        ..., ge=-90, le=90, description="Latitude coordinate of the entity"
    )
    longitude: float = Field(
        ..., ge=-180, le=180, description="Longitude coordinate of the entity"
    )
    admin1: Optional[str] = Field(
        "Unknown", max_length=100, description="Primary administrative division"
    )
    admin1_id_giga: Optional[str] = Field(
        None,
        max_length=50,
        description="Unique identifier for the primary administrative division",
    )
    admin2: Optional[str] = Field(
        "Unknown", max_length=100, description="Secondary administrative division"
    )
    admin2_id_giga: Optional[str] = Field(
        None,
        max_length=50,
        description="Unique identifier for the secondary administrative division",
    )
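
For illustration, constructing an entity directly (field values are made up):

from gigaspatial.core.schemas.entity import GigaEntity

point = GigaEntity(
    latitude=-1.9441,
    longitude=30.0619,
    admin1="Kigali City",
    source="example",
)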
GigaEntityNoLocation

Bases: BaseGigaEntity

Entity without location data.

Source code in gigaspatial/core/schemas/entity.py
class GigaEntityNoLocation(BaseGigaEntity):
    """Entity without location data."""

    pass