# Source code for geo_sampling.visualization

"""Visualization utilities for road segments."""

from typing import List, Dict
import matplotlib.pyplot as plt
from matplotlib import colors

from ._types import RoadSegment


class RoadPlotter:
    """Handles plotting of road segments with automatic styling.

    Each road type is assigned a color name from
    ``matplotlib.colors.CSS4_COLORS`` the first time it is seen. The
    assignment is stored in ``color_map``, so a single plotter instance draws
    a given road type in the same color on every figure it produces.
    """

    def __init__(self, figsize: tuple = (12, 10)):
        """Initialize plotter.

        Args:
            figsize: Figure size passed to ``plt.subplots`` by :meth:`plot`.
        """
        self.figsize = figsize
        # Road type -> color name; filled lazily by _get_road_color.
        self.color_map: Dict[str, str] = {}
        self._color_values = list(colors.CSS4_COLORS.keys())
        # Number of colors handed out so far; wraps around the palette.
        self._color_index = 0

    def plot(
        self,
        segments: List[RoadSegment],
        title: str = "Road Segments",
        show_legend: bool = True,
        grid: bool = True,
    ) -> None:
        """Plot road segments with different colors by road type.

        Prints a message and returns without creating a figure when
        ``segments`` is empty. Otherwise draws on a new figure and shows it.

        Args:
            segments: List of road segments to plot
            title: Plot title
            show_legend: Whether to show legend
            grid: Whether to show grid
        """
        if not segments:
            print("No segments to plot")
            return

        _, ax = plt.subplots(figsize=self.figsize)
        self._plot_on_axis(ax, segments, title, show_legend=show_legend, grid=grid)

        plt.tight_layout()
        plt.show()

    def plot_sample_comparison(
        self,
        all_segments: List[RoadSegment],
        sampled_segments: List[RoadSegment],
        figsize: tuple = (20, 8),
    ) -> None:
        """Plot comparison between all segments and sample.

        The two sets are drawn side by side, each titled with its segment
        count. Both panels use this plotter's color assignments.

        Args:
            all_segments: Complete set of road segments
            sampled_segments: Sampled subset
            figsize: Size of the two-panel figure
        """
        _, (ax_all, ax_sample) = plt.subplots(1, 2, figsize=figsize)

        self._plot_on_axis(
            ax_all, all_segments, f"All Segments (N = {len(all_segments)})"
        )
        self._plot_on_axis(
            ax_sample, sampled_segments, f"Sample (N = {len(sampled_segments)})"
        )

        plt.tight_layout()
        plt.show()

    def _plot_on_axis(
        self,
        ax,
        segments: List[RoadSegment],
        title: str,
        *,
        show_legend: bool = True,
        grid: bool = True,
    ) -> None:
        """Plot segments on a specific axis.

        Args:
            ax: Axes object to draw on
            segments: Road segments to draw, one line per segment
            title: Axis title
            show_legend: Whether to show a legend (only if anything was labelled)
            grid: Whether to show grid
        """
        # Road types that already carry a legend label on this axis.
        legend_labels = set()

        for segment in segments:
            x_coords = [segment.start_long, segment.end_long]
            y_coords = [segment.start_lat, segment.end_lat]
            color = self._get_road_color(segment.osm_type)

            # Label only the first segment of each road type so the legend
            # gets one entry per type rather than one per segment.
            label = segment.osm_type if segment.osm_type not in legend_labels else ""
            if label:
                legend_labels.add(segment.osm_type)

            ax.plot(
                x_coords, y_coords, color=color, linewidth=1.2, label=label, alpha=0.8
            )

        # Turn off offset and scientific notation on both axes' tick labels.
        for axis in (ax.get_yaxis(), ax.get_xaxis()):
            formatter = axis.get_major_formatter()
            formatter.set_useOffset(False)
            formatter.set_scientific(False)

        ax.set_xlabel("Longitude")
        ax.set_ylabel("Latitude")
        ax.set_title(title)

        if grid:
            ax.grid(True, alpha=0.3)

        if show_legend and legend_labels:
            ax.legend(loc="best", fancybox=True, framealpha=0.7)

    def _get_road_color(self, road_type: str) -> str:
        """Get consistent color for a road type.

        An unseen road type takes the next color in the palette, wrapping
        around once every color has been used. Later calls for the same type
        return the stored color.
        """
        if road_type not in self.color_map:
            palette = self._color_values
            self.color_map[road_type] = palette[self._color_index % len(palette)]
            self._color_index += 1
        return self.color_map[road_type]
def plot_road_segments(
    segments: List[RoadSegment],
    title: str = "Road Segments",
    figsize: tuple = (12, 10),
    show_legend: bool = True,
    grid: bool = True,
) -> None:
    """Convenience function to plot road segments.

    Builds a new ``RoadPlotter`` for the call and forwards the arguments to
    its ``plot`` method.

    Args:
        segments: List of road segments to plot
        title: Plot title
        figsize: Figure size
        show_legend: Whether to show legend
        grid: Whether to show grid
    """
    plotter = RoadPlotter(figsize=figsize)
    plotter.plot(segments, title=title, show_legend=show_legend, grid=grid)